salad.layers package
Submodules
salad.layers.association module
class salad.layers.association.Accuracy
    Bases: torch.nn.modules.module.Module
    forward(input)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
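The note on calling the Module instance applies to every forward() listed on this page. The following minimal sketch illustrates it with a toy module; `Doubler` is a hypothetical class made up for this example and is not part of salad.

```python
import torch
from torch import nn


class Doubler(nn.Module):
    """Toy module (hypothetical, not part of salad) used to illustrate the note."""

    def forward(self, input):
        return 2 * input


calls = []
module = Doubler()
# A forward hook runs after forward() whenever the module instance is called.
handle = module.register_forward_hook(lambda mod, inp, out: calls.append(out.shape))

x = torch.ones(3)
y_call = module(x)            # runs forward() and the registered hook
y_direct = module.forward(x)  # runs forward() only; the hook is skipped
handle.remove()

print(len(calls))  # 1: only the instance call triggered the hook
```

Both calls return the same tensor; the difference is only whether registered hooks are executed.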
class salad.layers.association.AssociationMatrix(verbose=False)
    Bases: torch.nn.modules.module.Module
    forward(xs, xt)
        xs: (Ns, K, …)
        xt: (Nt, K, …)
class salad.layers.association.AssociativeLoss(walker_weight=1.0, visit_weight=1.0)
    Bases: torch.nn.modules.module.Module
    Association Loss for Domain Adaptation
    Reference: Associative Domain Adaptation, Hausser et al. (2017)
    forward(xs, xt, y)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
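For orientation, the sketch below shows one way the walker and visit terms of the cited paper can be combined. It is not salad's implementation: the function names, the `eps` constant, the dot-product similarity, and the restriction to 2D embeddings of shape (N, K) are all assumptions made for this example. Only `walker_weight`, `visit_weight`, and the argument names `xs`, `xt`, `y` are taken from the signatures above.

```python
import torch
import torch.nn.functional as F


def association_probabilities(xs, xt):
    """Round-trip and visit probabilities for xs: (Ns, K) and xt: (Nt, K)."""
    sim = xs @ xt.t()                 # (Ns, Nt) dot-product similarities
    p_st = F.softmax(sim, dim=1)      # source -> target transition probabilities
    p_ts = F.softmax(sim.t(), dim=1)  # target -> source transition probabilities
    p_sts = p_st @ p_ts               # (Ns, Ns) round-trip probabilities
    p_t = p_st.mean(dim=0)            # (Nt,) probability of visiting each target
    return p_sts, p_t


def walker_loss(p_sts, y, eps=1e-8):
    """Cross entropy between p_sts and a uniform distribution over same-class samples."""
    same = (y[:, None] == y[None, :]).float()
    target = same / same.sum(dim=1, keepdim=True)
    return -(target * torch.log(p_sts + eps)).sum(dim=1).mean()


def visit_loss(p_t, eps=1e-8):
    """Cross entropy between the visit probabilities and a uniform distribution."""
    uniform = torch.full_like(p_t, 1.0 / p_t.numel())
    return -(uniform * torch.log(p_t + eps)).sum()


def associative_loss(xs, xt, y, walker_weight=1.0, visit_weight=1.0):
    p_sts, p_t = association_probabilities(xs, xt)
    return walker_weight * walker_loss(p_sts, y) + visit_weight * visit_loss(p_t)


torch.manual_seed(0)
xs, xt = torch.randn(6, 4), torch.randn(5, 4)
y = torch.tensor([0, 0, 1, 1, 2, 2])
loss = associative_loss(xs, xt, y)
```

The walker term rewards round trips that end on a source sample of the same class, while the visit term encourages every target sample to be visited with similar probability.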
class salad.layers.association.AugmentationLoss(aug_loss_func=MSELoss(), use_rampup=True)
    Bases: torch.nn.modules.module.Module
    Augmentation Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt
    forward()
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.association.ClassBalanceLoss
    Bases: torch.nn.modules.module.Module
    Class Balance Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt
    forward(tea_out)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.association.OTLoss
    Bases: torch.nn.modules.module.Module
    forward(xs, ys, xt, yt)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.association.VisitLoss
    Bases: torch.nn.modules.module.Module
    forward(Pt)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.association.WalkerLoss
    Bases: torch.nn.modules.module.Module
    forward(Psts, y)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.association.WassersteinLoss
    Bases: torch.nn.modules.module.Module
    forward(input)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
salad.layers.base module
class salad.layers.base.AccuracyScore
    Bases: torch.nn.modules.module.Module
    forward(y, t)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.base.KLDivWithLogits
    Bases: torch.nn.modules.module.Module
    forward(x, y)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.base.MeanAccuracyScore
    Bases: torch.nn.modules.module.Module
    forward(y, t)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.base.WeightedCE(confidence_threshold=0.96837722)
    Bases: torch.nn.modules.module.Module
    Adapted from Self-Ensembling repository
    forward(logits, logits_target)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
    robust_binary_crossentropy(pred, tgt)
salad.layers.coral module
class salad.layers.coral.AffineInvariantDivergence
class salad.layers.coral.CoralLoss
    Bases: salad.layers.coral.CorrelationDistance
    Deep CORAL loss from paper: https://arxiv.org/pdf/1607.01719.pdf
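As a reference point, the Deep CORAL paper defines the loss as the squared Frobenius distance between source and target feature covariances, scaled by 1 / (4 d^2). The sketch below implements that definition directly. The helper names `covariance` and `coral_loss` are made up for this example, and salad's `CoralLoss` may differ in details such as normalization or regularization.

```python
import torch


def covariance(x):
    """Unbiased sample covariance of x with shape (N, d)."""
    xc = x - x.mean(dim=0, keepdim=True)
    return xc.t() @ xc / (x.shape[0] - 1)


def coral_loss(xs, xt):
    """Squared Frobenius distance between covariances, scaled by 1 / (4 d^2)."""
    d = xs.shape[1]
    diff = covariance(xs) - covariance(xt)
    return (diff ** 2).sum() / (4 * d * d)


torch.manual_seed(0)
xs = torch.randn(32, 8)
xt = 2.0 * torch.randn(32, 8)  # target features with a larger spread
print(coral_loss(xs, xs).item())  # 0.0: identical batches have identical covariances
print(coral_loss(xs, xt).item())  # positive: the covariances differ
```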
class salad.layers.coral.CorrelationDistance(distance=<function euclid>)
    Bases: torch.nn.modules.module.Module
    forward(xs, xt)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.coral.JeffreyDivergence
    Bases: salad.layers.coral.CorrelationDistance
    Log Coral Loss
class salad.layers.coral.LogCoralLoss
    Bases: torch.nn.modules.module.Module
    forward(xs, xt)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.coral.SteinDivergence
    Bases: salad.layers.coral.CorrelationDistance
    Log Coral Loss
salad.layers.da module
class salad.layers.da.AutoAlign2d
    Bases: torch.nn.modules.batchnorm.BatchNorm2d
    forward()
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
class salad.layers.da.FeatureAwareNormalization
    Bases: torch.nn.modules.module.Module
salad.layers.funcs module
salad.layers.funcs.concat(x, z)
    Concat 4D tensor with expanded 2D tensor
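The one-line description suggests broadcasting a per-sample vector over the spatial dimensions before concatenating along the channel axis. A minimal sketch of that idea follows. The function name `concat_expanded`, the shapes (N, C, H, W) and (N, D), and the choice of the channel axis are assumptions, since the source does not spell them out.

```python
import torch


def concat_expanded(x, z):
    """Concatenate x of shape (N, C, H, W) with z of shape (N, D) along channels."""
    n, _, h, w = x.shape
    # Add two singleton spatial axes, then broadcast z to the spatial size of x.
    z_map = z[:, :, None, None].expand(n, z.shape[1], h, w)
    return torch.cat([x, z_map], dim=1)


x = torch.zeros(2, 3, 4, 5)
z = torch.ones(2, 7)
out = concat_expanded(x, z)
print(out.shape)  # torch.Size([2, 10, 4, 5])
```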
salad.layers.mat module
Metrics and Divergences for Correlation Matrices
salad.layers.mat.affineinvariant(A, B)
salad.layers.mat.apply(C, func)
salad.layers.mat.cov(x, eps=1e-05)
    Estimate the covariance matrix.
salad.layers.mat.euclid(A, B)
salad.layers.mat.getdata(N, d, std)
salad.layers.mat.jeffrey(A, B)
salad.layers.mat.logeuclid(A, B)
salad.layers.mat.riemann(A, B)
salad.layers.mat.stable_logdet(A)
    Compute the logarithm of the determinant of a matrix in a numerically stable way.
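The source does not say which method `stable_logdet` uses. One common approach for symmetric positive definite matrices, such as regularized covariance estimates, is to sum the log-diagonal of a Cholesky factor instead of taking the log of the determinant directly. The sketch below shows that approach under this assumption; `logdet_cholesky` is a name made up for the example.

```python
import torch


def logdet_cholesky(A):
    """log det(A) for a symmetric positive definite matrix A."""
    L = torch.linalg.cholesky(A)  # A = L @ L.T with L lower triangular
    # det(A) = prod(diag(L))^2, so log det(A) = 2 * sum(log(diag(L))).
    return 2.0 * torch.log(torch.diagonal(L)).sum()


torch.manual_seed(0)
X = torch.randn(50, 4, dtype=torch.float64)
# A small multiple of the identity keeps the matrix positive definite.
A = X.t() @ X / 50 + 1e-5 * torch.eye(4, dtype=torch.float64)

# torch.linalg.slogdet offers a general-purpose alternative for comparison.
sign, logabsdet = torch.linalg.slogdet(A)
print(torch.allclose(logdet_cholesky(A), logabsdet))  # True
```

Working with log-diagonal entries avoids the overflow or underflow that can occur when the determinant itself is very large or very small.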
salad.layers.mat.stein(A, B)
salad.layers.vat module
class salad.layers.vat.ConditionalEntropy
    Bases: torch.nn.modules.module.Module
    Estimates the conditional cross entropy of the input:
    $$-\frac{1}{n} \sum_i \sum_c p(y_i = c \mid x_i) \log p(y_i = c \mid x_i)$$
    By default, samples are assumed to lie along the first dimension and class probabilities along the second.
    forward(input)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
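The formula for ConditionalEntropy can be written in a few lines of PyTorch. The sketch below assumes the input holds unnormalized logits of shape (n, classes); whether salad's forward(input) expects logits or probabilities is not stated in the source, and `conditional_entropy` is a name chosen for this example.

```python
import torch
import torch.nn.functional as F


def conditional_entropy(logits):
    """Mean entropy of softmax(logits); samples on dim 0, classes on dim 1."""
    log_p = F.log_softmax(logits, dim=1)
    # -1/n * sum_i sum_c p(y_i = c | x_i) * log p(y_i = c | x_i)
    return -(log_p.exp() * log_p).sum(dim=1).mean(dim=0)


uniform = torch.zeros(4, 10)  # equal logits give uniform predictions
print(conditional_entropy(uniform).item())  # log(10), about 2.3026
```

Using `log_softmax` rather than taking the log of a softmax output avoids evaluating log(0) for classes with vanishing probability.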
class salad.layers.vat.VATLoss(model, radius=1)
    Bases: torch.nn.modules.module.Module
    Virtual Adversarial Training Loss function
    Reference: TODO
    forward(x, p)
        Defines the computation performed at every call.
        Should be overridden by all subclasses.
        Note: Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
salad.layers.vat.normalize_perturbation(d)
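The entries for VATLoss and normalize_perturbation carry no description, so the following is only a generic sketch of virtual adversarial training with a single power-iteration step, not salad's code. The names `normalize_l2` and `vat_loss`, the step size `xi`, the `eps` constant, and the assumption that the second argument holds the model's logits at x are all illustrative choices. Only `model` and `radius` mirror the constructor signature above.

```python
import torch
import torch.nn.functional as F
from torch import nn


def normalize_l2(d, eps=1e-8):
    """Scale each sample in the batch to unit L2 norm."""
    flat = d.reshape(d.shape[0], -1)
    norm = flat.norm(dim=1).reshape(-1, *([1] * (d.dim() - 1)))
    return d / (norm + eps)


def vat_loss(model, x, logits, radius=1.0, xi=1e-2):
    """KL(p(y|x) || p(y|x + r_adv)) with one power-iteration step for r_adv."""
    p = F.softmax(logits.detach(), dim=1)  # clean predictions, treated as constants
    d = normalize_l2(torch.randn_like(x)).requires_grad_(True)
    # The gradient of the KL term w.r.t. d approximates the most sensitive direction.
    kl = F.kl_div(F.log_softmax(model(x + xi * d), dim=1), p, reduction="batchmean")
    grad, = torch.autograd.grad(kl, d)
    r_adv = radius * normalize_l2(grad.detach())
    # Penalize the change in predictions under the adversarial perturbation.
    return F.kl_div(F.log_softmax(model(x + r_adv), dim=1), p, reduction="batchmean")


torch.manual_seed(0)
model = nn.Linear(5, 3)  # stand-in classifier for the example
x = torch.randn(4, 5)
loss = vat_loss(model, x, model(x))
```

The returned scalar is differentiable with respect to the model parameters and can be added to a supervised loss term.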