salad.solver.da package
Domain Adaptation solvers
Submodules
salad.solver.da.advdrop module
class salad.solver.da.advdrop.AdversarialDropoutLoss(model, step=1)
Bases: object
Loss Derivation for Adversarial Dropout Regularization
See also
salad.solver.AdversarialDropoutSolver
step1(batch)
step2(batch)
step3(batch)
class salad.solver.da.advdrop.AdversarialDropoutSolver(model, dataset, **kwargs)
Bases: salad.solver.da.base.DABaseSolver
Implementation of “Adversarial Dropout Regularization”
Adversarial Dropout Regularization [1] estimates uncertainties about the classification process by sampling different models using dropout. On the source domain, a standard cross entropy loss is employed. On the target domain, two predictions are sampled from the model.
Both network parts are jointly trained on the source domain using the standard cross entropy loss,
\[\min_{C, G} - \sum_k p^s_k \log y^s_k\]
The classifier part of the network is trained to maximize the symmetric KL distance between two predictions. This distance is one option for measuring uncertainty in a network. In other words, the classifier aims at maximizing uncertainty given two noisy estimates of the current feature vector.
\[\min_{C} - \sum_k p^s_k \log y^s_k - \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]
In contrast, the feature extractor aims at minimizing the uncertainty between two noisy samples given a fixed target classifier.
\[\min_{G} \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]
References
[1] Adversarial Dropout Regularization, Saito et al., ICLR 2018
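Example (sketch): the half symmetric KL term above can be written in a few lines of plain PyTorch. This is an illustration, not the salad implementation; the helper name symmetric_kl and the commented update scheme are assumptions::

    import torch
    import torch.nn.functional as F

    def symmetric_kl(logits_p, logits_q, eps=1e-8):
        """Half symmetric KL between two softmax predictions,
        sum_k (p_k - q_k)/2 * log(p_k / q_k); eps guards against log(0)."""
        p = F.softmax(logits_p, dim=1)
        q = F.softmax(logits_q, dim=1)
        d = 0.5 * (p - q) * (torch.log(p + eps) - torch.log(q + eps))
        return d.sum(dim=1).mean()

    # Adversarial updates (hypothetical names): the classifier C maximizes
    # the divergence of two dropout-sampled predictions, the feature
    # extractor G minimizes it:
    #   loss_C = ce_source - symmetric_kl(C(G(x_t)), C(G(x_t)))
    #   loss_G = symmetric_kl(C(G(x_t)), C(G(x_t)))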
class salad.solver.da.advdrop.SymmetricKL
Bases: torch.nn.modules.module.Module
forward(x, y)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
salad.solver.da.advdrop.pack(*args)
salad.solver.da.advdrop.unpack(arg, n_tensors)
salad.solver.da.advdrop_refactor module
class salad.solver.da.advdrop_refactor.AdversarialDropoutLoss(model, step)
Bases: object
step1(batch)
step2(batch)
step3(batch)
class salad.solver.da.advdrop_refactor.AdversarialDropoutSolver(model, dataset, **kwargs)
Bases: salad.solver.da.base.DABaseSolver
Implementation of “Adversarial Dropout Regularization”
Adversarial Dropout Regularization [1] estimates uncertainties about the classification process by sampling different models using dropout. On the source domain, a standard cross entropy loss is employed. On the target domain, two predictions are sampled from the model.
Both network parts are jointly trained on the source domain using the standard cross entropy loss,
\[\min_{C, G} - \sum_k p^s_k \log y^s_k\]
The classifier part of the network is trained to maximize the symmetric KL distance between two predictions. This distance is one option for measuring uncertainty in a network. In other words, the classifier aims at maximizing uncertainty given two noisy estimates of the current feature vector.
\[\min_{C} - \sum_k p^s_k \log y^s_k - \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]
In contrast, the feature extractor aims at minimizing the uncertainty between two noisy samples given a fixed target classifier.
\[\min_{G} \sum_k \frac{p^t_k - q^t_k}{2} \log \frac{p^t_k}{q^t_k}\]
References
[1] Adversarial Dropout Regularization, Saito et al., ICLR 2018
class salad.solver.da.advdrop_refactor.SymmetricKL
Bases: torch.nn.modules.module.Module
forward(x, y)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
salad.solver.da.association module
Associative Domain Adaptation
Häusser et al., CVPR 2017
class salad.solver.da.association.AssociationLoss(model)
Bases: object
Loss function for associative domain adaptation
Given a model, derive a function that computes arguments for the association loss.
class salad.solver.da.association.AssociativeSolver(model, dataset, learningrate, walker_weight=1.0, visit_weight=0.1, *args, **kwargs)
Bases: salad.solver.da.base.DABaseSolver
Implementation of “Associative Domain Adaptation”
Associative Domain Adaptation [1] leverages a random walk based on feature similarity as a distance between source and target feature correlations. The algorithm is based on two loss functions that are added to the standard cross entropy loss on the source domain.
Given features for source and target domain, a kernel function is used to measure similarity between both domains. The original implementation uses the scalar product between feature vectors, scaled by an exponential,
\[K_{ij} = k(x^s_i, x^t_j) = \exp(\langle x^s_i, x^t_j \rangle)\]
This kernel is then used to compute transition probabilities
\[p(x^t_j | x^s_i) = \frac{K_{ij}}{\sum_{l} K_{il}}\]
and
\[p(x^s_k | x^t_j) = \frac{K_{kj}}{\sum_{l} K_{lj}}\]
to compute the roundtrip probability
\[p(x^s_k | x^s_i) = \sum_{j} p(x^s_k | x^t_j) \, p(x^t_j | x^s_i)\]
It is then required that
- WalkerLoss: the roundtrip ends at a sample with the same class label, i.e., \(y^s_i = y^s_k\)
- VisitLoss: each target sample is visited with a certain probability
As one possible modification, different kernel functions could be used to measure similarity between the domains. With this solver, it is advised to use large sample sizes for the target domain and ensure that a sufficient number of source samples is available for each batch.
TODO: Possibly implement functionality in the solver class to aggregate batches to avoid memory issues.
Parameters:
- model (torch.nn.Module) – A PyTorch model to be trained by association
- dataset (StackedDataset) – A dataset suitable for an unsupervised solver
- learningrate (int) – TODO
References
[1] Associative Domain Adaptation, Häusser et al., CVPR 2017, https://arxiv.org/abs/1708.00938
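Example (sketch): the walker and visit objectives can be written in a few lines of PyTorch. This is an illustration of the two terms, not the salad implementation; the function name and argument layout are assumptions::

    import torch
    import torch.nn.functional as F

    def association_losses(feats_s, feats_t, labels_s):
        """Walker and visit losses (sketch of the objectives above).

        feats_s: (Ns, d) source features, feats_t: (Nt, d) target
        features, labels_s: (Ns,) integer source labels.
        """
        sim = feats_s @ feats_t.t()          # <x^s_i, x^t_j>
        p_st = F.softmax(sim, dim=1)         # p(x^t_j | x^s_i)
        p_ts = F.softmax(sim.t(), dim=1)     # p(x^s_k | x^t_j)
        p_sts = p_st @ p_ts                  # roundtrip p(x^s_k | x^s_i)

        # Walker loss: KL between a uniform distribution over same-class
        # source samples and the roundtrip distribution.
        same = (labels_s[:, None] == labels_s[None, :]).float()
        target = same / same.sum(dim=1, keepdim=True)
        walker = F.kl_div(torch.log(p_sts + 1e-8), target, reduction='batchmean')

        # Visit loss: KL between a uniform target distribution and the
        # average probability of visiting each target sample.
        visit_p = p_st.mean(dim=0)
        uniform = torch.full_like(visit_p, 1.0 / visit_p.numel())
        visit = F.kl_div(torch.log(visit_p + 1e-8), uniform, reduction='sum')
        return walker, visit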
salad.solver.da.base module
Solver classes for domain adaptation experiments
class salad.solver.da.base.BaselineDASolver(*args, **kwargs)
Bases: salad.solver.da.base.DABaseSolver
A domain adaptation solver that does not actually run any adaptation algorithm.
This is useful to establish baseline results for the case of no adaptation, and to measure the domain shift between datasets.
class salad.solver.da.base.DABaseSolver(*args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver
Base Class for Unsupervised Domain Adaptation Approaches
Unsupervised DA assumes the presence of a single source domain \(\mathcal S\) along with a target domain \(\mathcal T\) known at training time. Given a labeled sample of points drawn from \(\mathcal S\), \(\{x^s_i, y^s_i\}_{i}^{N_s}\), and an unlabeled sample drawn from \(\mathcal T\), \(\{x^t_i\}_{i}^{N_t}\), unsupervised adaptation aims at solving
\[\min_\theta \mathcal{R}^l_{\mathcal S} (\theta) + \lambda \mathcal{R}^u_{\mathcal {S \times T}} (\theta),\]
leveraging an unsupervised risk term \(\mathcal{R}^u_{\mathcal {S \times T}} (\theta)\) that depends on feature representations \(f_\theta(x^s,s)\) and \(f_\theta(x^t,t)\), classifier outputs \(h_\theta(x^s,s), h_\theta(x^t,t)\), as well as source labels \(y^s\). The full model \(h = g \circ f\) is a composition of a feature extractor \(f\) and a classifier \(g\), both of which can possibly depend on the domain label \(s\) or \(t\) for domain-specific computations.
Notes
This solver adds two accuracies with keys acc_s and acc_t for the source and target domain, respectively. Make sure to include the derivation of these accuracies in your loss computation.
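Example (sketch): a training step in this setting typically combines the supervised source risk with a weighted unsupervised term, as in the objective above. The snippet below is generic PyTorch, not the salad solver's actual step; feature_extractor, classifier and unsup_loss are placeholders::

    import torch.nn.functional as F

    def da_step(feature_extractor, classifier, unsup_loss, batch, lam=1.0):
        """One unsupervised-DA step: R^l_S + lambda * R^u_{SxT}."""
        (x_s, y_s), x_t = batch            # labeled source, unlabeled target
        f_s = feature_extractor(x_s)
        f_t = feature_extractor(x_t)

        supervised = F.cross_entropy(classifier(f_s), y_s)  # R^l_S
        unsupervised = unsup_loss(f_s, f_t)                 # method-specific
        return supervised + lam * unsupervised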
class salad.solver.da.base.DABaselineLoss(solver)
Bases: object
class salad.solver.da.base.DATeacher(model, teacher, dataset, *args, **kwargs)
Bases: salad.solver.base.Solver
Base Class for Unsupervised Domain Adaptation Approaches using a teacher model
class salad.solver.da.base.DGBaseSolver(*args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver
Base Class for Domain Generalization Approaches
Domain generalization assumes the presence of multiple source domains alongside a target domain unknown at training time. Following Shankar et al. (ICLR 2018), this setting requires a dataset of training examples \(\{x_i, y_i, d_i\}_{i}^{N}\) with class and domain labels. Importantly, the domains present at training time should reflect the kind of variability that can be expected during inference. The ERM problem is then approached as
\[\min_\theta \sum_d \mathcal{R}^l_{\mathcal S_d} (\theta) = \sum_d \lambda_d \mathbb{E}_{x,y \sim \mathcal S_d }[\ell ( f_\theta(x), h_\theta(x), y, d) ].\]
In contrast to the unsupervised setting, samples are now presented in a single batch comprised of inputs \(x\), labels \(y\) and domains \(d\). In addition to a feature extractor \(f_\theta\) and classifier \(g_\theta\), models should also provide a feature extractor \(f^d_\theta\) to derive domain features, along with a domain classifier \(g^d_\theta\), with possibly shared parameters.
In contrast to unsupervised DA, this training setting leverages information from multiple labeled source domains with the goal of generalizing well on data from a previously unseen domain during test time.
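Example (sketch): the domain-weighted ERM objective above reduces to weighting each sample's loss by its domain's \(\lambda_d\). The snippet below is an illustration, not the salad implementation; model, x, y, d and domain_weights are placeholders::

    import torch.nn.functional as F

    def dg_erm_loss(model, x, y, d, domain_weights):
        """Domain-weighted ERM over multiple source domains.

        x: inputs, y: class labels, d: integer domain label per sample,
        domain_weights: 1-D tensor holding one lambda_d per domain.
        """
        per_sample = F.cross_entropy(model(x), y, reduction='none')
        # Weight each sample's loss by its domain's lambda_d and average.
        return (domain_weights[d] * per_sample).mean()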
salad.solver.da.coral module
Losses for Correlation Alignment
Deep CORAL: Correlation Alignment for Deep Domain Adaptation Paper: https://arxiv.org/pdf/1607.01719.pdf
Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z
class salad.solver.da.coral.CentroidDistanceLossSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver
Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z and: https://arxiv.org/pdf/1705.08180.pdf
class salad.solver.da.coral.CentroidLoss(model)
Bases: object
class salad.solver.da.coral.CorrDistanceSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver
Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z
class salad.solver.da.coral.CorrelationDistanceLoss(model, n_steps_recompute=10, nullspace=False)
Bases: object
class salad.solver.da.coral.CorrelationDistanceSolver(model, dataset, *args, **kwargs)
class salad.solver.da.coral.DeepCoralSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver
Deep CORAL: Correlation Alignment for Deep Domain Adaptation Paper: https://arxiv.org/pdf/1607.01719.pdf
Loss function:
\[\mathcal{L}(x^s, x^t) = \frac{1}{4d^2} \| C_s - C_t \|^2_F\]
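Example (sketch): the CORAL distance above compares second-order statistics of the two feature distributions. A minimal PyTorch version, assuming feature matrices of shape (n, d); this is not the salad implementation::

    import torch

    def coral_loss(f_s, f_t):
        """||C_s - C_t||_F^2 / (4 d^2) with C the feature covariances."""
        d = f_s.size(1)

        def covariance(f):
            f = f - f.mean(dim=0, keepdim=True)
            return f.t() @ f / (f.size(0) - 1)

        diff = covariance(f_s) - covariance(f_t)
        return (diff * diff).sum() / (4.0 * d * d)

    # Example usage with random features:
    # loss = coral_loss(torch.randn(128, 64), torch.randn(96, 64))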
class salad.solver.da.coral.DeepLogCoralSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.da.coral.CorrelationDistanceSolver
Minimal Entropy Correlation Alignment for Unsupervised Domain Adaptation Paper: https://openreview.net/pdf?id=rJWechg0Z
salad.solver.da.crossgrad module
Cross Gradient Training
ICLR 2018
class salad.solver.da.crossgrad.CrossGradLoss(solver)
Bases: object
Cross Gradient Training
http://arxiv.org/abs/1804.10745
pertub(x, loss, eps=1e-05)
class salad.solver.da.crossgrad.CrossGradSolver(model, *args, **kwargs)
Bases: salad.solver.da.base.DGBaseSolver
Cross Gradient Optimizer
A domain generalization solver based on Cross Gradient Training [1].
\[p(y | x) = \int_d p(y|x,d) \, p(d|x) \, \mathrm{d}d\]
\[x_d = x + \epsilon \nabla_x L(y), \qquad x_y = x + \epsilon \nabla_x L(d)\]
References
[1] Shankar et al., Generalizing Across Domains via Cross-Gradient Training, ICLR 2018
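Example (sketch): the input perturbations above can be computed with plain autograd. This is one plausible reading of the pertub(x, loss, eps) signature, not the actual implementation::

    import torch

    def perturb(x, loss, eps=1e-5):
        """Perturb inputs along the gradient of a scalar loss.

        x must have requires_grad=True; loss must be computed from x.
        """
        grad, = torch.autograd.grad(loss, x, retain_graph=True)
        return (x + eps * grad).detach()

    # Illustrative usage (hypothetical names): perturb the input with the
    # domain loss gradient for the label network, and vice versa.
    #   x_y = perturb(x, domain_loss)
    #   x_d = perturb(x, label_loss)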
class salad.solver.da.crossgrad.Model(n_classes, n_domains)
Bases: torch.nn.modules.module.Module
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
forward_domain(x)
parameters_classifier()
parameters_domain()
class salad.solver.da.crossgrad.MultiDomainModule(n_domains)
Bases: torch.nn.modules.module.Module
forward(x)
Defines the computation performed at every call.
Should be overridden by all subclasses.
Note
Although the recipe for the forward pass needs to be defined within this function, one should call the Module instance afterwards instead of this, since the former takes care of running the registered hooks while the latter silently ignores them.
forward_domain(x)
parameters_classifier()
parameters_domain()
salad.solver.da.crossgrad.concat(x, z)
Concat a 4D tensor with an expanded 2D tensor.
salad.solver.da.crossgrad.conv2d(m, n, k, act=True)
salad.solver.da.crossgrad.features(inp)
salad.solver.da.crossgrad.get_dataset(noisemodels, batch_size, shuffle=True, num_workers=0, which='train')
salad.solver.da.dann module
class salad.solver.da.dann.AdversarialLoss(G, D, train_G=True)
Bases: object
class salad.solver.da.dann.DANNSolver(model, discriminator, dataset, learningrate, *args, **kwargs)
Bases: salad.solver.da.base.DABaseSolver
Domain Adversarial Neural Networks Solver
This builds upon the standard classification solver, which uses CrossEntropy or BinaryCrossEntropy losses for optimizing neural networks.
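Example (sketch): the domain-adversarial objective behind DANN can be written as follows. This is an illustration in plain PyTorch, not the salad API; G, D and the batch layout are assumptions, and the flipped-label update is one common alternative to a gradient reversal layer::

    import torch
    import torch.nn.functional as F

    def dann_losses(G, D, x_s, x_t):
        """Domain-adversarial losses in the spirit of DANN.

        G: feature extractor, D: domain discriminator emitting one
        logit per sample (shape (n, 1) is assumed here).
        """
        z_s, z_t = G(x_s), G(x_t)
        logits = torch.cat([D(z_s), D(z_t)]).squeeze(1)
        domains = torch.cat([torch.zeros(z_s.size(0)),
                             torch.ones(z_t.size(0))]).to(logits.device)

        # The discriminator learns to tell the domains apart ...
        loss_D = F.binary_cross_entropy_with_logits(logits, domains)
        # ... while the feature extractor is trained against flipped
        # labels, pushing source and target features to become
        # indistinguishable.
        loss_G = F.binary_cross_entropy_with_logits(logits, 1.0 - domains)
        return loss_D, loss_G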
salad.solver.da.dirtt module
class salad.solver.da.dirtt.DIRTT(model, teacher)
Bases: object
class salad.solver.da.dirtt.DIRTTSolver(model, teacher, dataset, *args, **kwargs)
Bases: salad.solver.base.Solver
Virtual Adversarial Domain Adaptation
class salad.solver.da.dirtt.VADA(G, D, train_G=True)
class salad.solver.da.dirtt.VADASolver(model, discriminator, dataset, *args, **kwargs)
Bases: salad.solver.da.dann.DANNSolver
Virtual Adversarial Domain Adaptation
salad.solver.da.dirtt_re module
Self Ensembling for Visual Domain Adaptation
class salad.solver.da.dirtt_re.DIRTT(model, teacher)
Bases: object
class salad.solver.da.dirtt_re.DIRTTSolver(model, teacher, dataset, learningrate, *args, **kwargs)
salad.solver.da.djdot module
class salad.solver.da.djdot.DJDOTSolver(model, dataset, *args, **kwargs)
Bases: salad.solver.classification.BaseClassSolver
Deep Joint Optimal Transport solver
TODO
derive_losses(batch)