## Submodules

class salad.layers.association.Accuracy

Bases: torch.nn.modules.module.Module

forward(input)

Defines the computation performed at every call.

Should be overridden by all subclasses.

Note

Although the recipe for the forward pass must be defined within this function, you should call the Module instance rather than this method directly: calling the instance runs the registered hooks, whereas calling forward() silently ignores them.

class salad.layers.association.AssociationMatrix(verbose=False)

Bases: torch.nn.modules.module.Module

forward(xs, xt)

xs: source tensor of shape (Ns, K, …)

xt: target tensor of shape (Nt, K, …)
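The association matrix turns source/target similarities into transition probabilities for a round-trip "walk" (Haeusser et al., 2017). Below is a minimal sketch under stated assumptions, not salad's implementation. It assumes that trailing dimensions are flattened and that similarity is the plain dot product. The name `association_matrices` is hypothetical. The outputs are named after the arguments of `WalkerLoss.forward(Psts, y)` and `VisitLoss.forward(Pt)`.

```python
import torch
import torch.nn.functional as F


def association_matrices(xs: torch.Tensor, xt: torch.Tensor):
    """Hypothetical helper: round-trip and visit probabilities.

    xs: (Ns, K, ...) source embeddings; xt: (Nt, K, ...) target embeddings.
    """
    # Flatten everything after the batch dimension (assumption).
    xs = xs.reshape(xs.size(0), -1)
    xt = xt.reshape(xt.size(0), -1)

    # Dot-product similarity of every source/target pair: (Ns, Nt).
    W = xs @ xt.t()

    # P(target j | source i): softmax over the target axis.
    Pst = F.softmax(W, dim=1)
    # P(source i | target j): softmax over the source axis.
    Pts = F.softmax(W.t(), dim=1)

    # Round trip source -> target -> source: (Ns, Ns).
    Psts = Pst @ Pts
    # Average probability of visiting each target sample: (1, Nt).
    Pt = Pst.mean(dim=0, keepdim=True)
    return Psts, Pt


Psts, Pt = association_matrices(torch.randn(6, 8, 2), torch.randn(5, 8, 2))
```

Every row of `Psts` is a probability distribution over source samples, and `Pt` is a single distribution over target samples.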

class salad.layers.association.AssociativeLoss(walker_weight=1.0, visit_weight=1.0)

Bases: torch.nn.modules.module.Module

Reference: Associative Domain Adaptation, Haeusser et al. (2017)

forward(xs, xt, y)

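For reference, here is the objective of Haeusser et al. (2017) in the notation of this page. $P^{st}$ holds the source-to-target step probabilities, $P^{sts} = P^{st}P^{ts}$ is the round trip, and $y$ are the source labels. Reading `walker_weight` and `visit_weight` as the weights $w$ is an assumption based on their names. The paper states both terms as cross-entropies. An implementation may instead use the KL form, which differs only by a constant with respect to $P$.

$$
\mathcal{L} = w_{\mathrm{walker}}\, H\big(T, P^{sts}\big) + w_{\mathrm{visit}}\, H\big(V, P^{t}\big)
$$

$$
H\big(T, P^{sts}\big) = -\frac{1}{N_s} \sum_{i} \sum_{j} T_{ij} \log P^{sts}_{ij}, \qquad T_{ij} = \frac{[y_i = y_j]}{\sum_k [y_i = y_k]}
$$

$$
H\big(V, P^{t}\big) = -\sum_{j} V_j \log P^{t}_j, \qquad V_j = \frac{1}{N_t}, \qquad P^{t}_j = \frac{1}{N_s} \sum_i P^{st}_{ij}
$$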

class salad.layers.association.AugmentationLoss(aug_loss_func=MSELoss(), use_rampup=True)

Bases: torch.nn.modules.module.Module

forward()

class salad.layers.association.ClassBalanceLoss

Bases: torch.nn.modules.module.Module

forward(tea_out)

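A class-balance loss of this kind appears in self-ensembling for domain adaptation (French et al., 2018). There, the teacher's predictions are averaged over the batch and pushed towards a uniform class distribution with a binary cross-entropy. The sketch below implements that idea as an assumption about `ClassBalanceLoss`, not as its confirmed behavior. It assumes `tea_out` already holds class probabilities of shape (N, C). The name `class_balance_loss` is hypothetical.

```python
import torch


def class_balance_loss(tea_out: torch.Tensor, eps: float = 1e-6) -> torch.Tensor:
    """Hypothetical sketch: BCE between the mean prediction and 1/C.

    tea_out: (N, C) class probabilities (assumed, not logits).
    """
    n_classes = tea_out.size(1)
    # Mean predicted probability of each class over the batch: (C,).
    avg_prob = tea_out.mean(dim=0)
    target = 1.0 / n_classes
    # Binary cross-entropy per class; eps keeps both logarithms finite.
    bce = -(target * torch.log(avg_prob + eps)
            + (1.0 - target) * torch.log(1.0 - avg_prob + eps))
    return bce.mean()


balanced = torch.full((4, 3), 1.0 / 3)
skewed = torch.tensor([[0.98, 0.01, 0.01]] * 4)
```

A batch whose average prediction is uniform gives a lower loss than one that always predicts the same class.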

class salad.layers.association.OTLoss

Bases: torch.nn.modules.module.Module

forward(xs, ys, xt, yt)

class salad.layers.association.VisitLoss

Bases: torch.nn.modules.module.Module

forward(Pt)

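The visit loss encourages the walk to reach all target samples evenly. Below is a minimal sketch, assuming `Pt` has shape (1, Nt) and holds the average probability of visiting each target sample. It uses the cross-entropy against a uniform distribution from Haeusser et al. (2017). The name `visit_loss` is hypothetical.

```python
import torch


def visit_loss(Pt: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch: cross-entropy between uniform V and Pt."""
    n_target = Pt.size(1)
    # Uniform visit distribution V_j = 1 / Nt.
    V = torch.full_like(Pt, 1.0 / n_target)
    # eps avoids log(0) for target samples that are never visited.
    return -(V * torch.log(Pt + eps)).sum(dim=1).mean()


uniform = torch.full((1, 4), 0.25)
peaked = torch.tensor([[0.97, 0.01, 0.01, 0.01]])
```

The loss is smallest, equal to log(Nt), when every target sample is visited equally often.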

class salad.layers.association.WalkerLoss

Bases: torch.nn.modules.module.Module

forward(Psts, y)

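The walker loss compares the round-trip matrix with a target that spreads probability uniformly over source samples sharing the same label. Below is a minimal sketch of that target and the resulting cross-entropy (Haeusser et al., 2017). It assumes `Psts` is the (Ns, Ns) round-trip matrix with rows summing to one and that `y` holds integer labels. The name `walker_loss` is hypothetical.

```python
import torch


def walker_loss(Psts: torch.Tensor, y: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch of the walker loss."""
    # same[i, j] = 1 if samples i and j share a label, else 0.
    same = (y.view(-1, 1) == y.view(1, -1)).float()
    # T[i, j] = 1 / |class(i)| for same-class pairs, 0 otherwise.
    T = same / same.sum(dim=1, keepdim=True)
    # Row-wise cross-entropy H(T_i, Psts_i), averaged over source samples.
    return -(T * torch.log(Psts + eps)).sum(dim=1).mean()


y = torch.tensor([0, 0, 1, 1])
```

A round trip that always returns to a sample of the correct class scores lower than one that returns anywhere at random.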

class salad.layers.association.WassersteinLoss

Bases: torch.nn.modules.module.Module

forward(input)

class salad.layers.base.AccuracyScore

Bases: torch.nn.modules.module.Module

forward(y, t)

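Below is a minimal sketch of a batch accuracy score. It assumes `y` holds per-class scores of shape (N, C) and `t` holds integer targets of shape (N,). These argument roles are inferred from the signature only, and the function name is hypothetical.

```python
import torch


def accuracy_score(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: fraction of samples whose top score matches t."""
    pred = y.argmax(dim=1)
    return (pred == t).float().mean()


scores = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
targets = torch.tensor([0, 1, 1, 1])
```

Here three of the four predictions match, so the score is 0.75.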

class salad.layers.base.KLDivWithLogits

Bases: torch.nn.modules.module.Module

forward(x, y)

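Below is a sketch of a KL divergence computed from two sets of logits. It assumes `y` provides the reference distribution and `x` the prediction, so the result is KL(softmax(y) || softmax(x)) averaged over the batch. The argument roles and the batch-mean reduction are assumptions. The function name is hypothetical.

```python
import torch
import torch.nn.functional as F


def kl_div_with_logits(x: torch.Tensor, y: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: KL(softmax(y) || softmax(x)), batch mean."""
    log_p = F.log_softmax(x, dim=1)  # predicted log-probabilities
    q = F.softmax(y, dim=1)          # reference probabilities
    # F.kl_div expects log-probabilities as input and probabilities as target.
    return F.kl_div(log_p, q, reduction="batchmean")


a = torch.randn(5, 4)
b = torch.randn(5, 4)
```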

class salad.layers.base.MeanAccuracyScore

Bases: torch.nn.modules.module.Module

forward(y, t)

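One common reading of "mean accuracy" is class-averaged (balanced) accuracy: compute accuracy separately for each class present in the targets, then average. The sketch below assumes that interpretation, which the signature alone does not confirm. It expects scores `y` of shape (N, C) and integer targets `t` of shape (N,), and the function name is hypothetical.

```python
import torch


def mean_accuracy_score(y: torch.Tensor, t: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: accuracy per class, averaged over classes in t."""
    pred = y.argmax(dim=1)
    per_class = []
    for c in torch.unique(t):
        mask = t == c
        # Accuracy restricted to samples whose true label is c.
        per_class.append((pred[mask] == c).float().mean())
    return torch.stack(per_class).mean()


scores = torch.tensor([[2.0, 0.1], [0.3, 1.5], [0.9, 0.2], [0.1, 0.4]])
targets = torch.tensor([0, 1, 1, 1])
```

With class 0 fully correct and class 1 two-thirds correct, this gives 5/6, whereas plain accuracy on the same batch is 0.75.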

class salad.layers.base.WeightedCE(confidence_threshold=0.96837722)

Bases: torch.nn.modules.module.Module

forward(logits, logits_target)

robust_binary_crossentropy(pred, tgt)
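The `confidence_threshold` default suggests confidence-based masking of target predictions, as in self-ensembling (French et al., 2018). The sketch below assumes one plausible reading, and both the weighting scheme and the function name are assumptions, not taken from salad:

- Samples whose softmax of `logits_target` has a maximum probability below the threshold get zero weight.
- The remaining samples contribute a cross-entropy against the target's argmax, used as a pseudo-label.

```python
import torch
import torch.nn.functional as F


def weighted_ce(logits: torch.Tensor, logits_target: torch.Tensor,
                confidence_threshold: float = 0.96837722) -> torch.Tensor:
    """Hypothetical sketch: cross-entropy on confident pseudo-labels only."""
    with torch.no_grad():
        probs = F.softmax(logits_target, dim=1)
        confidence, pseudo_labels = probs.max(dim=1)
        # 1.0 for samples above the threshold, 0.0 otherwise.
        weight = (confidence > confidence_threshold).float()
    per_sample = F.cross_entropy(logits, pseudo_labels, reduction="none")
    # Average over all samples, so unconfident ones contribute zero.
    return (weight * per_sample).mean()


student = torch.randn(4, 3)
teacher = torch.tensor([[10.0, 0.0, 0.0], [0.2, 0.1, 0.0],
                        [0.0, 9.0, 0.0], [0.1, 0.3, 0.2]])
```

Only the first and third teacher rows clear the threshold, so only those samples affect the loss.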

class salad.layers.coral.AffineInvariantDivergence

class salad.layers.coral.CoralLoss

Deep CORAL loss, from the paper "Deep CORAL: Correlation Alignment for Deep Domain Adaptation" (Sun and Saenko, 2016): https://arxiv.org/pdf/1607.01719.pdf
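For reference, the loss in the cited Deep CORAL paper compares the feature covariances of the two domains. Given $d$-dimensional source features $D_S \in \mathbb{R}^{n_S \times d}$ and target features $D_T \in \mathbb{R}^{n_T \times d}$, it reads as follows, with $C_T$ defined analogously to $C_S$. This is the paper's definition; whether CoralLoss uses exactly the same normalization is not stated here.

$$
\ell_{\mathrm{CORAL}} = \frac{1}{4d^2}\,\lVert C_S - C_T \rVert_F^2,
\qquad
C_S = \frac{1}{n_S - 1}\left(D_S^\top D_S - \frac{1}{n_S}\,(\mathbf{1}^\top D_S)^\top(\mathbf{1}^\top D_S)\right)
$$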

class salad.layers.coral.CorrelationDistance(distance=<function euclid>)

Bases: torch.nn.modules.module.Module

forward(xs, xt)

class salad.layers.coral.JeffreyDivergence

Log Coral Loss

class salad.layers.coral.LogCoralLoss

Bases: torch.nn.modules.module.Module

forward(xs, xt)

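A log-CORAL loss commonly compares covariance matrices in the log-Euclidean metric, that is, after taking the matrix logarithm of each symmetric positive-definite (SPD) covariance. The sketch below uses that metric (compare `salad.layers.mat.logeuclid`). Whether LogCoralLoss uses the same metric and scaling is an assumption. Both inputs are assumed to be SPD already, for example covariances with a small diagonal term added. The function names are hypothetical, and the matrix logarithm is computed by eigendecomposition.

```python
import torch


def spd_logm(A: torch.Tensor) -> torch.Tensor:
    """Matrix logarithm of a symmetric positive-definite matrix."""
    eigvals, eigvecs = torch.linalg.eigh(A)
    # log acts on the eigenvalues; the eigenvectors are unchanged.
    return eigvecs @ torch.diag(torch.log(eigvals)) @ eigvecs.t()


def log_euclidean_distance(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: squared Frobenius norm of logm(A) - logm(B)."""
    return ((spd_logm(A) - spd_logm(B)) ** 2).sum()


A = torch.diag(torch.tensor([1.0, 2.0, 4.0]))
B = torch.eye(3)
```

For these diagonal matrices the distance is (log 1)^2 + (log 2)^2 + (log 4)^2, approximately 2.4023.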

class salad.layers.coral.SteinDivergence

Log Coral Loss

class salad.layers.da.AutoAlign2d

Bases: torch.nn.modules.batchnorm.BatchNorm2d

forward()

class salad.layers.da.FeatureAwareNormalization

Bases: torch.nn.modules.module.Module

salad.layers.funcs.concat(x, z)

Concatenate the 4D tensor x with the 2D tensor z, after expanding z to four dimensions.
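Below is a minimal sketch of the operation described for concat. It assumes `x` has shape (N, C, H, W), `z` has shape (N, D), and concatenation happens along the channel dimension. The channel axis and the name `concat_sketch` are assumptions, since the description only says the 2D tensor is expanded.

```python
import torch


def concat_sketch(x: torch.Tensor, z: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: tile z over x's spatial grid, then concat channels."""
    n, _, h, w = x.shape
    # (N, D) -> (N, D, 1, 1) -> (N, D, H, W); expand returns a view, no copy.
    z_map = z.view(n, -1, 1, 1).expand(-1, -1, h, w)
    return torch.cat([x, z_map], dim=1)


x = torch.randn(2, 3, 4, 5)
z = torch.randn(2, 6)
out = concat_sketch(x, z)
```

Every spatial position of the result carries the same copy of `z` in its last D channels.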

Metrics and Divergences for Correlation Matrices

salad.layers.mat.affineinvariant(A, B)

salad.layers.mat.apply(C, func)

salad.layers.mat.cov(x, eps=1e-05)

Estimate the covariance matrix
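Below is a sketch of a covariance estimate with a small ridge term. It assumes the rows of `x` are samples and the columns are features. It also guesses that `eps` is added to the diagonal to keep the matrix positive definite; the role of `eps` is not documented. The function name is hypothetical.

```python
import torch


def cov_sketch(x: torch.Tensor, eps: float = 1e-5) -> torch.Tensor:
    """Hypothetical sketch: unbiased covariance of (N, d) data plus eps * I."""
    n, d = x.shape
    centered = x - x.mean(dim=0, keepdim=True)
    # Unbiased estimate: divide by N - 1.
    C = centered.t() @ centered / (n - 1)
    return C + eps * torch.eye(d, dtype=x.dtype, device=x.device)


x = torch.randn(50, 4, dtype=torch.float64)
C = cov_sketch(x)
```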

salad.layers.mat.euclid(A, B)

salad.layers.mat.getdata(N, d, std)

salad.layers.mat.jeffrey(A, B)

salad.layers.mat.logeuclid(A, B)

salad.layers.mat.riemann(A, B)

salad.layers.mat.stable_logdet(A)

Compute the logarithm of the determinant of A in a numerically stable way.
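For a symmetric positive-definite matrix, a standard numerically stable route is the Cholesky factorization $A = LL^\top$. It gives $\log\det A = 2\sum_i \log L_{ii}$ without forming the determinant, which may overflow or underflow. The sketch assumes SPD input. Whether stable_logdet uses this exact route is not stated, and the function name is hypothetical.

```python
import torch


def stable_logdet_sketch(A: torch.Tensor) -> torch.Tensor:
    """Hypothetical sketch: log det of an SPD matrix via Cholesky."""
    L = torch.linalg.cholesky(A)
    # det(A) = prod(diag(L))^2, so log det(A) = 2 * sum(log(diag(L))).
    return 2.0 * torch.log(torch.diagonal(L)).sum()


# det(A) = 1e-400 underflows float64, but its logarithm is about -921.03.
A = torch.eye(200, dtype=torch.float64) * 0.01
```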

salad.layers.mat.stein(A, B)
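The sketch below computes two divergences between SPD matrices using their textbook definitions:

- The Stein divergence (also called the Jensen-Bregman LogDet divergence) is $S(A, B) = \log\det\frac{A+B}{2} - \tfrac12\log\det(AB)$.
- The Jeffrey divergence is the symmetrized KL divergence between the zero-mean Gaussians $\mathcal{N}(0, A)$ and $\mathcal{N}(0, B)$: $J(A, B) = \tfrac12\operatorname{tr}(A^{-1}B + B^{-1}A) - d$.

Whether `salad.layers.mat.stein` and `salad.layers.mat.jeffrey` use the same scaling is an assumption. The function names are hypothetical.

```python
import torch


def stein_divergence(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """S(A, B) = logdet((A + B) / 2) - (logdet A + logdet B) / 2."""
    # logdet(AB) = logdet(A) + logdet(B) for SPD A and B.
    return torch.logdet((A + B) / 2) - 0.5 * (torch.logdet(A) + torch.logdet(B))


def jeffrey_divergence(A: torch.Tensor, B: torch.Tensor) -> torch.Tensor:
    """J(A, B) = tr(A^-1 B + B^-1 A) / 2 - d."""
    d = A.size(0)
    # solve(A, B) computes A^-1 B without forming the inverse explicitly.
    return 0.5 * torch.trace(torch.linalg.solve(A, B) + torch.linalg.solve(B, A)) - d


A = torch.tensor([[2.0, 0.5], [0.5, 1.0]], dtype=torch.float64)
B = torch.eye(2, dtype=torch.float64)
```

Both divergences vanish when the two matrices are equal and are positive otherwise.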

class salad.layers.vat.ConditionalEntropy

Bases: torch.nn.modules.module.Module

Estimates the conditional entropy of the input:

$$-\frac{1}{n} \sum_i \sum_c p(y_i = c \mid x_i) \log p(y_i = c \mid x_i)$$

By default, samples are assumed to lie along the first dimension and class probabilities along the second.

forward(input)

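Below is a minimal sketch of the conditional entropy, computed as the batch mean of $-\sum_c p_{ic} \log p_{ic}$. It assumes the input holds class probabilities laid out as (samples, classes). The clamp guarding against log(0) and the function name are assumptions.

```python
import torch


def conditional_entropy(p: torch.Tensor, eps: float = 1e-8) -> torch.Tensor:
    """Hypothetical sketch: -(1/n) sum_i sum_c p_ic log p_ic for p of shape (n, C)."""
    # Clamp only inside the log, so zero probabilities contribute 0 * log(eps) = 0.
    return -(p * torch.log(p.clamp(min=eps))).sum(dim=1).mean()


uniform = torch.full((3, 4), 0.25)
one_hot = torch.eye(4)[:3]
```

Uniform predictions give the maximum value, log(C); confident one-hot predictions give zero.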

class salad.layers.vat.VATLoss(model, radius=1)

Bases: torch.nn.modules.module.Module

Reference: TODO

forward(x, p)

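Virtual adversarial training (Miyato et al.) penalizes how much a model's prediction changes under the small input perturbation that changes it most. The sketch below is a single power-iteration version built on several assumptions:

- `p` holds the model's logits on the clean `x`.
- `radius` sets the size of the adversarial perturbation.
- The divergence is KL on softmax outputs.
- The finite-difference step `xi` (value chosen for illustration), `vat_loss`, and the helper `_unit_norm` are hypothetical; salad's own normalize_perturbation(d) may differ.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F


def _unit_norm(d: torch.Tensor, eps: float = 1e-12) -> torch.Tensor:
    # Scale each sample's perturbation to unit L2 norm.
    norm = d.flatten(start_dim=1).norm(dim=1).view(-1, *([1] * (d.dim() - 1)))
    return d / (norm + eps)


def vat_loss(model: nn.Module, x: torch.Tensor, p: torch.Tensor,
             radius: float = 1.0, xi: float = 1e-3) -> torch.Tensor:
    """Hypothetical sketch of the VAT loss with one power iteration."""
    target = F.softmax(p.detach(), dim=1)  # clean prediction, held fixed

    # Power iteration: the gradient of the KL w.r.t. a tiny random
    # perturbation approximates the most sensitive direction.
    d = (xi * _unit_norm(torch.randn_like(x))).requires_grad_()
    kl = F.kl_div(F.log_softmax(model(x + d), dim=1), target, reduction="batchmean")
    grad, = torch.autograd.grad(kl, d)

    # Adversarial perturbation of length `radius` along that direction.
    r_adv = radius * _unit_norm(grad.detach())
    return F.kl_div(F.log_softmax(model(x + r_adv), dim=1), target, reduction="batchmean")


torch.manual_seed(0)
model = nn.Sequential(nn.Linear(4, 8), nn.Tanh(), nn.Linear(8, 3))
x = torch.randn(5, 4)
loss = vat_loss(model, x, model(x))
loss.backward()
```

Gradients reach the model parameters only through the second, adversarial forward pass. The perturbation itself is treated as a constant.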

salad.layers.vat.normalize_perturbation(d)