salad.layers package

Submodules

salad.layers.association module

class salad.layers.association.Accuracy

Bases: torch.nn.modules.module.Module

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.AssociationMatrix(verbose=False)

Bases: torch.nn.modules.module.Module

forward(xs, xt)

xs: source feature batch of shape (Ns, K, …)
xt: target feature batch of shape (Nt, K, …)
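
In Haeusser et al. (2017), the association matrix is built from dot-product similarities between source and target embeddings, which are then normalized into transition probabilities. The sketch below illustrates that construction under the shape convention above; the function name and the flattening of trailing dimensions are assumptions, not the exact salad implementation.

    import torch
    import torch.nn.functional as F

    def association_probabilities(xs, xt):
        """Illustrative sketch: round-trip transition probabilities from
        dot-product similarities between source and target embeddings."""
        xs = xs.reshape(xs.size(0), -1)   # (Ns, K)
        xt = xt.reshape(xt.size(0), -1)   # (Nt, K)

        sim = xs @ xt.t()                 # (Ns, Nt) similarity matrix
        p_st = F.softmax(sim, dim=1)      # source -> target probabilities
        p_ts = F.softmax(sim.t(), dim=1)  # target -> source probabilities
        p_sts = p_st @ p_ts               # round trip source -> target -> source
        return p_st, p_ts, p_sts

    xs = torch.randn(8, 32)    # (Ns, K)
    xt = torch.randn(12, 32)   # (Nt, K)
    p_st, p_ts, p_sts = association_probabilities(xs, xt)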

class salad.layers.association.AssociativeLoss(walker_weight=1.0, visit_weight=1.0)

Bases: torch.nn.modules.module.Module

Association Loss for Domain Adaptation

Reference: Associative Domain Adaptation, Haeusser et al. (2017)

forward(xs, xt, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
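
A hedged usage sketch of the loss above, based on the constructor and forward signatures documented here; the embedding dimension, batch sizes, and number of classes are illustrative assumptions.

    import torch
    from salad.layers.association import AssociativeLoss

    criterion = AssociativeLoss(walker_weight=1.0, visit_weight=0.1)

    xs = torch.randn(64, 128, requires_grad=True)   # source embeddings (Ns, K)
    xt = torch.randn(64, 128, requires_grad=True)   # target embeddings (Nt, K)
    y  = torch.randint(0, 10, (64,))                # source class labels (assumed format)

    loss = criterion(xs, xt, y)   # weighted walker + visit loss
    loss.backward()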

class salad.layers.association.AugmentationLoss(aug_loss_func=MSELoss(), use_rampup=True)

Bases: torch.nn.modules.module.Module

Augmentation Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt

forward()

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.ClassBalanceLoss

Bases: torch.nn.modules.module.Module

Class Balance Loss from https://github.com/Britefury/self-ensemble-visual-domain-adapt

forward(tea_out)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
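
In the self-ensembling code referenced above, the class-balance term pushes the average predicted class distribution of a minibatch towards a uniform distribution, preventing the model from collapsing onto a few classes. The sketch below shows that idea in isolation; the function name and the use of plain binary cross entropy are assumptions, not the exact salad implementation.

    import torch
    import torch.nn.functional as F

    def class_balance_penalty(probs):
        """Illustrative sketch: penalize deviation of the mean predicted
        class distribution from uniform. `probs` is (N, C), rows sum to 1."""
        mean_per_class = probs.mean(dim=0)                             # (C,)
        uniform = torch.full_like(mean_per_class, 1.0 / probs.size(1))
        return F.binary_cross_entropy(mean_per_class, uniform)

    probs = torch.softmax(torch.randn(32, 10), dim=1)
    penalty = class_balance_penalty(probs)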

class salad.layers.association.OTLoss

Bases: torch.nn.modules.module.Module

forward(xs, ys, xt, yt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.VisitLoss

Bases: torch.nn.modules.module.Module

forward(Pt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.WalkerLoss

Bases: torch.nn.modules.module.Module

forward(Psts, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.association.WassersteinLoss

Bases: torch.nn.modules.module.Module

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

salad.layers.base module

class salad.layers.base.AccuracyScore

Bases: torch.nn.modules.module.Module

forward(y, t)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.base.KLDivWithLogits

Bases: torch.nn.modules.module.Module

forward(x, y)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
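
Judging by the class name, forward(x, y) receives raw logits rather than probabilities. A common way to compute the KL divergence in that setting is shown below; the argument order (target distribution second) is an assumption, not necessarily the convention used by salad.

    import torch
    import torch.nn.functional as F

    def kl_div_with_logits(x, y):
        """Illustrative sketch: KL(p_y || p_x) for two batches of logits."""
        log_p_x = F.log_softmax(x, dim=1)
        p_y = F.softmax(y, dim=1)
        # F.kl_div expects log-probabilities as input and probabilities as target.
        return F.kl_div(log_p_x, p_y, reduction="batchmean")

    x = torch.randn(16, 10)
    y = torch.randn(16, 10)
    loss = kl_div_with_logits(x, y)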

class salad.layers.base.MeanAccuracyScore

Bases: torch.nn.modules.module.Module

forward(y, t)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.base.WeightedCE(confidence_threshold=0.96837722)

Bases: torch.nn.modules.module.Module

Adapted from the self-ensembling repository (https://github.com/Britefury/self-ensemble-visual-domain-adapt)

forward(logits, logits_target)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
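
The confidence_threshold parameter suggests a consistency loss in which only samples whose target (teacher) prediction is sufficiently confident contribute, as in the self-ensembling approach. The sketch below illustrates that idea; the function name, the use of hard pseudo-labels, and the masking scheme are assumptions, not the exact salad code.

    import torch
    import torch.nn.functional as F

    def confidence_weighted_ce(logits, logits_target, confidence_threshold=0.968):
        """Illustrative sketch: cross entropy against pseudo-labels, masked
        by the confidence of the target prediction."""
        target_probs = F.softmax(logits_target, dim=1)
        confidence, pseudo_labels = target_probs.max(dim=1)
        mask = (confidence > confidence_threshold).float()

        per_sample_ce = F.cross_entropy(logits, pseudo_labels, reduction="none")
        # Samples below the confidence threshold are zeroed out before averaging.
        return (mask * per_sample_ce).mean()

    student_logits = torch.randn(32, 10, requires_grad=True)
    teacher_logits = torch.randn(32, 10)
    loss = confidence_weighted_ce(student_logits, teacher_logits)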

robust_binary_crossentropy(pred, tgt)

salad.layers.coral module

class salad.layers.coral.AffineInvariantDivergence

Bases: salad.layers.coral.CorrelationDistance

class salad.layers.coral.CoralLoss

Bases: salad.layers.coral.CorrelationDistance

Deep CORAL loss from the paper: https://arxiv.org/pdf/1607.01719.pdf
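
Deep CORAL matches second-order statistics of source and target features: the squared Frobenius distance between their covariance matrices, scaled by 1/(4 d^2). The sketch below implements that definition directly for batches of feature vectors; it is a minimal illustration, not the salad implementation (which builds on CorrelationDistance).

    import torch

    def coral_loss(xs, xt):
        """Illustrative Deep CORAL loss: squared Frobenius distance between
        source and target feature covariance matrices, scaled by 1/(4 d^2)."""
        d = xs.size(1)

        def covariance(x):
            x = x - x.mean(dim=0, keepdim=True)
            return x.t() @ x / (x.size(0) - 1)

        cs, ct = covariance(xs), covariance(xt)
        return ((cs - ct) ** 2).sum() / (4.0 * d * d)

    xs = torch.randn(64, 128)   # source features (Ns, d)
    xt = torch.randn(64, 128)   # target features (Nt, d)
    loss = coral_loss(xs, xt)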

class salad.layers.coral.CorrelationDistance(distance=<function euclid>)

Bases: torch.nn.modules.module.Module

forward(xs, xt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.coral.JeffreyDivergence

Bases: salad.layers.coral.CorrelationDistance

Jeffrey divergence between correlation matrices

class salad.layers.coral.LogCoralLoss

Bases: torch.nn.modules.module.Module

forward(xs, xt)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.coral.SteinDivergence

Bases: salad.layers.coral.CorrelationDistance

Stein divergence between correlation matrices

salad.layers.da module

class salad.layers.da.AutoAlign2d

Bases: torch.nn.modules.batchnorm.BatchNorm2d

forward()

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

class salad.layers.da.FeatureAwareNormalization

Bases: torch.nn.modules.module.Module

salad.layers.funcs module

salad.layers.funcs.concat(x, z)

Concatenate a 4D tensor with an expanded 2D tensor
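
A typical way to combine a per-sample conditioning vector with a feature map is to expand the 2D tensor over the spatial dimensions of the 4D tensor and concatenate along the channel axis. The sketch below shows that pattern; the function name and the exact expansion are assumptions about what concat does, not a verified reading of the salad source.

    import torch

    def concat_4d_2d(x, z):
        """Illustrative sketch: broadcast a (N, Cz) tensor over the spatial
        dimensions of a (N, Cx, H, W) tensor and concatenate along channels."""
        n, _, h, w = x.shape
        z_expanded = z.view(n, -1, 1, 1).expand(-1, -1, h, w)
        return torch.cat([x, z_expanded], dim=1)   # (N, Cx + Cz, H, W)

    x = torch.randn(8, 16, 32, 32)
    z = torch.randn(8, 4)
    out = concat_4d_2d(x, z)   # shape (8, 20, 32, 32)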

salad.layers.mat module

Metrics and Divergences for Correlation Matrices

salad.layers.mat.affineinvariant(A, B)
salad.layers.mat.apply(C, func)
salad.layers.mat.cov(x, eps=1e-05)

Estimate the covariance matrix (see the sketch at the end of this module listing)

salad.layers.mat.euclid(A, B)
salad.layers.mat.getdata(N, d, std)
salad.layers.mat.jeffrey(A, B)
salad.layers.mat.logeuclid(A, B)
salad.layers.mat.riemann(A, B)
salad.layers.mat.stable_logdet(A)

Compute the logarithm of the determinant of a matrix in a numerically stable way (see the sketch at the end of this module listing)

salad.layers.mat.stein(A, B)
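
The sketches below illustrate the cov and stable_logdet helpers documented above, assuming cov operates on an (N, d) batch of samples with a small diagonal ridge eps, and that the log-determinant can be obtained via torch.slogdet; these are assumptions about the interface, not the exact salad implementations.

    import torch

    def cov(x, eps=1e-5):
        """Illustrative covariance estimate for an (N, d) batch of samples,
        with a small ridge `eps` on the diagonal for numerical stability."""
        n, d = x.shape
        x = x - x.mean(dim=0, keepdim=True)
        return x.t() @ x / (n - 1) + eps * torch.eye(d)

    def stable_logdet(A):
        """Illustrative numerically stable log-determinant: torch.slogdet
        works in log space and avoids overflow in det(A) itself."""
        sign, logabsdet = torch.slogdet(A)
        return logabsdet

    A = cov(torch.randn(256, 8))
    value = stable_logdet(A)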

salad.layers.vat module

class salad.layers.vat.ConditionalEntropy

Bases: torch.nn.modules.module.Module

Estimates the conditional entropy of the input.

$$
\frac{1}{n} \sum_i \sum_c p(y_i = c \mid x_i) \log p(y_i = c \mid x_i)
$$

By default, samples are assumed to lie along the first dimension and class probabilities along the second.

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.
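
A minimal sketch of the averaged term above, assuming forward receives raw logits with samples along the first dimension; the function name is illustrative, and the entropy itself would be the negative of the returned value.

    import torch
    import torch.nn.functional as F

    def conditional_entropy_term(logits):
        """Illustrative sketch: average over samples of
        sum_c p(y=c|x) log p(y=c|x), computed from raw logits."""
        log_p = F.log_softmax(logits, dim=1)   # class probabilities along dim 1
        p = log_p.exp()
        return (p * log_p).sum(dim=1).mean()

    logits = torch.randn(32, 10)
    value = conditional_entropy_term(logits)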

class salad.layers.vat.VATLoss(model, radius=1)

Bases: torch.nn.modules.module.Module

Virtual Adversarial Training Loss function

Reference: Virtual Adversarial Training, Miyato et al. (2018)

forward(x, p)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this since the former takes care of running the registered hooks while the latter silently ignores them.

salad.layers.vat.normalize_perturbation(d)
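
A hedged sketch of the virtual adversarial training idea behind VATLoss and normalize_perturbation: find a perturbation of norm radius that maximally changes the model's prediction, approximated by a single power-iteration step through the KL divergence, and penalize the divergence at that perturbation. The functional form, the xi constant, and the single iteration are assumptions, not the exact salad implementation.

    import torch
    import torch.nn.functional as F

    def normalize_perturbation(d):
        """Scale each sample's perturbation to unit L2 norm (illustrative)."""
        norms = d.flatten(1).norm(dim=1).view(-1, *([1] * (d.dim() - 1)))
        return d / (norms + 1e-12)

    def vat_loss(model, x, p, radius=1.0, xi=1e-6):
        """Illustrative VAT sketch: one power-iteration step to find an
        adversarial direction, then KL between the clean prediction `p`
        (probabilities) and the prediction at x + radius * direction."""
        d = normalize_perturbation(torch.randn_like(x))
        d.requires_grad_(True)

        pred_hat = model(x + xi * d)
        adv_kl = F.kl_div(F.log_softmax(pred_hat, dim=1), p, reduction="batchmean")
        grad = torch.autograd.grad(adv_kl, d)[0]

        r_adv = radius * normalize_perturbation(grad.detach())
        pred_adv = model(x + r_adv)
        return F.kl_div(F.log_softmax(pred_adv, dim=1), p, reduction="batchmean")

    model = torch.nn.Linear(20, 10)            # toy model, illustrative only
    x = torch.randn(16, 20)
    p = F.softmax(model(x).detach(), dim=1)    # clean predictions
    loss = vat_loss(model, x, p)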