26/05/2022 Technical Blog

On Learning Domain-Invariant Representations for Transfer Learning with Multiple Sources

  • 58 minutes
  • Trung Phung (Research Resident)


1. Introduction

The central question of transfer learning from one domain to another is how to ensure that a model which performs well on the training domain also performs well on the test domain. Here, a domain is defined generally as a joint distribution on the data-label space. Depending on whether the target domain is available during training, we face either the multiple-source domain adaptation (MSDA) problem or the domain generalization (DG) problem.

Domain adaptation (DA) is the special case of MSDA in which we transfer from a single source domain to a target domain that is available during training. In the DA setting, domain-invariant (DI) representations have been studied thoroughly [1]: what they are, how to learn them [3], and the trade-off incurred by enforcing them [6].

On the other hand, because there are multiple source domains and the target domain may be unavailable, establishing a theoretical foundation for DI representations and characterizing them in the MSDA and especially the DG settings are significantly more challenging. Moreover, previous works [1, 5, 4, 7, 6] do not investigate DI representations from a representation learning viewpoint. The common approach connects the loss on the target domain to the divergence between data distributions, the mismatch between labeling functions, and the loss on the source domains, without any mention of representations. In contrast, we argue that the representation must be discussed explicitly in order to conclude anything about DI representations.

2. Main Theory

2.1. Notations

Let \mathcal{X} be a data space, \mathbb{P} be a data distribution on this space, and p(x) be the corresponding density function. We consider the multi-class classification problem with the label set \mathcal{Y}=\left[C\right], where C is the number of classes and \left[C\right]:=\{1,\ldots,C\}. Denote \mathcal{Y}_{\Delta}:=\left\{ \alpha\in\mathbb{R}^{C}:\left\Vert \alpha\right\Vert _{1}=1\,\wedge\,\alpha\geq\mathbf{0}\right\} as the C-simplex label space, and let f:\mathcal{X}\mapsto\mathcal{Y}_{\Delta} be a probabilistic labeling function returning a C-tuple f\left(x\right)=\left[f\left(x,i\right)\right]_{i=1}^{C}, whose element f\left(x,i\right)=p\left(y=i\mid x\right) is the probability of assigning a data sample x\sim\mathbb{P} to class i (i.e., i\in\left\{ 1,...,C\right\}). Moreover, a domain is denoted compactly as a pair of a data distribution and a labeling function \mathbb{D}:=\left(\mathbb{P},f\right). Note that given a data sample x\sim\mathbb{P}, its categorical label y\in\mathcal{Y} is sampled as y\sim Cat\left(f\left(x\right)\right), a categorical distribution with parameter f\left(x\right)\in\mathcal{Y}_{\Delta}.

Let \ell:\mathcal{Y}_{\Delta}\times\mathcal{Y}\mapsto\mathbb{R} be a loss function. The general loss of a classifier \hat{f}:\mathcal{X}\mapsto\mathcal{Y}_{\Delta} on a domain \mathbb{D}\equiv\left(\mathbb{P},f\right) is

    \[ \mathcal{L}\left(\hat{f},f,\mathbb{P}\right)=\mathcal{L}\left(\hat{f},\mathbb{D}\right):=\mathbb{E}_{x\sim\mathbb{P}}\left[\ell\left(\hat{f}(x),f(x)\right)\right]. \]

We consider the multiple-source setting in which we are given multiple source domains \{\mathbb{D}^{S,i}\}_{i=1}^{K} over the common data space \mathcal{X}, each consisting of a data distribution and its own labeling function \mathbb{D}^{S,i}:=\left(\mathbb{P}^{S,i},f^{S,i}\right). Combining the source domains yields a mixture of source distributions \mathbb{D}^{\pi}=\sum_{i=1}^{K}\pi_{i}\mathbb{D}^{S,i}, where the mixing coefficients \pi=\left[\pi_{i}\right]_{i=1}^{K} can conveniently be set to \pi_{i}=\frac{N_{i}}{\sum_{j=1}^{K}N_{j}}, with N_{i} being the training size of the i-th source domain. Similarly, the target domain is \mathbb{D}^{T}:=\left(\mathbb{P}^{T},f^{T}\right).
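On a finite data space, the mixing coefficients and the data-distribution part of the mixture, \sum_{i}\pi_{i}\mathbb{P}^{S,i}, are straightforward to compute. The Python sketch below is illustrative only: it assumes each distribution is stored as a {point: probability} dictionary, and the function names and sample sizes are hypothetical.

```python
from typing import Dict, List, Sequence


def mixing_coefficients(sizes: Sequence[int]) -> List[float]:
    """pi_i = N_i / sum_j N_j, i.e. weight each source by its training-set size."""
    total = sum(sizes)
    if total <= 0:
        raise ValueError("the source domains must contain at least one sample")
    return [n / total for n in sizes]


def mixture_density(pi: Sequence[float], densities: Sequence[Dict[str, float]]) -> Dict[str, float]:
    """Density of sum_i pi_i P^{S,i} on a finite data space ({point: probability} dicts)."""
    mixed: Dict[str, float] = {}
    for weight, density in zip(pi, densities):
        for x, px in density.items():
            mixed[x] = mixed.get(x, 0.0) + weight * px
    return mixed


pi = mixing_coefficients([100, 300])  # hypothetical sizes N_1 = 100, N_2 = 300
print(pi)                             # [0.25, 0.75]
print(mixture_density(pi, [{"a": 1.0}, {"a": 0.5, "b": 0.5}]))  # {'a': 0.625, 'b': 0.375}
```

The larger source dominates the mixture, so with this choice of \pi the source loss and discrepancy terms are weighted toward the domains with more training data.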

Previous works generally upper-bound the target loss by the sum of the source losses, a divergence term between the source and target data distributions, and a constant \lambda involving the labeling functions of the source and target domains:

    \[ \mathcal{L}\left(\hat{f},\mathbb{D}^{T}\right)\leq\sum_{i=1}^{K}\pi_{i}\mathcal{L}\left(\hat{f},\mathbb{D}^{S,i}\right)+D\left(\sum_{i=1}^{K}\pi_{i}\mathbb{P}^{S,i},\mathbb{P}^{T}\right)+\lambda \]

From this formulation, learning DI representations is said to be encouraged by minimizing the divergence D\left(\sum_{i=1}^{K}\pi_{i}\mathbb{P}^{S,i},\mathbb{P}^{T}\right), but on the representation space, e.g., D\left(\sum_{i=1}^{K}\pi_{i}\mathbb{P}_{g}^{S,i},\mathbb{P}_{g}^{T}\right) (with \mathbb{P}_{g} denoting the representation distribution). However, we argue that this formulation is not explicit or concrete enough for discussing DI representation learning.

In our setting, an input is mapped into a latent space \mathcal{Z} by a feature map g:\mathcal{X}\mapsto\mathcal{Z}, and a classifier \hat{h}:\mathcal{Z}\mapsto\mathcal{Y}_{\Delta} is then trained on the representations g\left(\mathcal{X}\right). Let f:\mathcal{X}\mapsto\mathcal{Y}_{\Delta} be the original labeling function. To develop the theory on the latent space, we introduce the representation distribution, i.e., the push-forward distribution \mathbb{P}_{g}:=g_{\#}\mathbb{P}, and the labeling function h:\mathcal{Z}\mapsto\mathcal{Y}_{\Delta} induced by g as h(z)=\frac{\int_{g^{-1}\left(z\right)}f(x)p(x)dx}{\int_{g^{-1}\left(z\right)}p(x)dx}. Returning to the multiple-source setting, the source mixture becomes \mathbb{D}_{g}^{\pi}=\sum_{i}\pi_{i}\mathbb{D}_{g}^{S,i}, where each source domain is \mathbb{D}_{g}^{S,i}=\left(\mathbb{P}_{g}^{S,i},h^{S,i}\right), and the target domain is \mathbb{D}_{g}^{T}=\left(\mathbb{P}_{g}^{T},h^{T}\right).
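For a finite data space, the integrals defining h(z) become sums over the preimage g^{-1}(z). The following minimal sketch assumes {point: probability} dictionaries and a hypothetical feature map g that merges pairs of inputs; it computes \mathbb{P}_{g} and the induced labeling function h. The names are our own.

```python
from collections import defaultdict
from typing import Callable, Dict, Hashable, List

Dist = Dict[Hashable, float]            # finite distribution {point: probability}
Labeling = Dict[Hashable, List[float]]  # {point: probability vector over the C classes}


def push_forward(p: Dist, g: Callable[[Hashable], Hashable]) -> Dist:
    """P_g = g_# P: each latent point z receives the total mass of its preimage g^{-1}(z)."""
    pz: Dict[Hashable, float] = defaultdict(float)
    for x, px in p.items():
        pz[g(x)] += px
    return dict(pz)


def induced_labeling(p: Dist, f: Labeling, g: Callable[[Hashable], Hashable]) -> Labeling:
    """h(z) = sum_{x in g^{-1}(z)} f(x) p(x) / sum_{x in g^{-1}(z)} p(x)."""
    pz = push_forward(p, g)
    weighted: Labeling = {}
    for x, px in p.items():
        acc = weighted.setdefault(g(x), [0.0] * len(f[x]))
        for c, fxc in enumerate(f[x]):
            acc[c] += px * fxc
    # Latent points with zero mass have no well-defined h(z), so they are skipped.
    return {z: [v / pz[z] for v in acc] for z, acc in weighted.items() if pz[z] > 0}


def g(x: int) -> int:
    """Hypothetical feature map that merges inputs {0, 1} and {2, 3}."""
    return x // 2


p = {0: 0.25, 1: 0.25, 2: 0.25, 3: 0.25}
f = {0: [1.0, 0.0], 1: [0.0, 1.0], 2: [1.0, 0.0], 3: [1.0, 0.0]}
print(push_forward(p, g))         # {0: 0.5, 1: 0.5}
print(induced_labeling(p, f, g))  # {0: [0.5, 0.5], 1: [1.0, 0.0]}
```

Because g merges input 0 (class 0) with input 1 (class 1), the induced h(0) is the stochastic vector [0.5, 0.5] even though f is deterministic.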

Finally, our theory uses the Hellinger divergence between two distributions, defined as D_{1 / 2}\left(\mathbb{P}^{1}, \mathbb{P}^{2}\right)=2 \int\left(\sqrt{p^{1}(x)}-\sqrt{p^{2}(x)}\right)^{2} d x, whose square root d_{1 / 2}=\sqrt{D_{1 / 2}} is a proper metric.
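A discrete version of the Hellinger divergence and distance can serve as a sketch (sums replace the integral, and the function names are our own). With the factor 2 in this definition, D_{1/2} ranges over [0, 4] and d_{1/2} over [0, 2].

```python
import math
from typing import Dict, Hashable

Dist = Dict[Hashable, float]  # finite distribution {point: probability}


def hellinger_divergence(p1: Dist, p2: Dist) -> float:
    """D_{1/2}(P1, P2) = 2 * sum_x (sqrt(p1(x)) - sqrt(p2(x)))^2, the discrete analogue of the integral."""
    support = set(p1) | set(p2)
    return 2.0 * sum((math.sqrt(p1.get(x, 0.0)) - math.sqrt(p2.get(x, 0.0))) ** 2 for x in support)


def hellinger_distance(p1: Dist, p2: Dist) -> float:
    """d_{1/2} = sqrt(D_{1/2}), the version that satisfies the triangle inequality."""
    return math.sqrt(hellinger_divergence(p1, p2))


p, q, r = {"a": 1.0}, {"a": 0.5, "b": 0.5}, {"b": 1.0}
print(hellinger_divergence(p, r))  # 4.0, the maximum value (disjoint supports)
print(hellinger_distance(p, r) <= hellinger_distance(p, q) + hellinger_distance(q, r))  # True
```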

2.2. Two types of representation

Theorem 1. Consider a mixture of source domains \mathbb{D}^{\pi}=\sum_{i=1}^{K}\pi_{i}\mathbb{D}^{S,i} and the target domain \mathbb{D}^{T}. Let \ell be any loss function upper-bounded by a positive constant L. For any hypothesis \hat{f}:\mathcal{X}\mapsto\mathcal{Y}_{\Delta} where \hat{f}=\hat{h}\circ g with g:\mathcal{X}\mapsto\mathcal{Z} and \hat{h}:\mathcal{Z}\mapsto\mathcal{Y}_{\Delta}, the target loss on the input space is upper-bounded as

    \[ \mathcal{L}\left(\hat{f},\mathbb{D}^{T}\right)\leq\sum_{i=1}^{K}\pi_{i}\mathcal{L}\left(\hat{f},\mathbb{D}^{S,i}\right)+L\max_{i\in[K]}\mathbb{E}_{\mathbb{P}^{S,i}}\left[\|\Delta p^{i}(y|x)\|_{1}\right]+L\sqrt{2}\,d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right) \tag{1} \]

where \Delta p^{i}(y|x):=\left[\left|f^{T}(x,y)-f^{S,i}(x,y)\right|\right]_{y=1}^{C} is the absolute pointwise label shift on the input space between the source domain \mathbb{D}^{S,i} and the target domain \mathbb{D}^{T}, and [K]:=\left\{ 1,2,...,K\right\}.

The bound in Equation 1 implies that the target loss depends on three terms: (i) the representation discrepancy d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right), (ii) the label shift \max_{i\in[K]}\mathbb{E}_{\mathbb{P}^{S,i}}\left[\|\Delta p^{i}(y|x)\|_{1}\right], and (iii) the general source loss \sum_{i=1}^{K}\pi_{i}\mathcal{L}\left(\hat{f},\mathbb{D}^{S,i}\right). To minimize the target loss on the left-hand side, we need to minimize all three terms. First, the label shift term is a natural characteristic of the domains and is hence almost impossible to tackle. Second, the representation discrepancy term can be tackled explicitly in the MSDA setting, but is almost impossible to tackle in the DG setting, where the target domain is unseen during training. Finally, the general source loss term is convenient to tackle, and minimizing it yields a feature extractor g and a classifier \hat{h}.
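The label-shift term in Equation 1 can be computed directly on a finite data space. In the hypothetical example below, neither g nor \hat{h} appears anywhere in the computation, which illustrates why this term cannot be reduced by learning. The function names and data are illustrative assumptions.

```python
from typing import Dict, Hashable, List, Sequence

Dist = Dict[Hashable, float]            # finite data distribution {x: probability}
Labeling = Dict[Hashable, List[float]]  # {x: [f(x, 1), ..., f(x, C)]}


def expected_label_shift(p_src: Dist, f_src: Labeling, f_tgt: Labeling) -> float:
    """E_{x ~ P^{S,i}} ||Delta p^i(y|x)||_1 with Delta p^i(y|x) = |f^T(x, y) - f^{S,i}(x, y)|."""
    return sum(
        px * sum(abs(ft - fs) for ft, fs in zip(f_tgt[x], f_src[x]))
        for x, px in p_src.items()
    )


def label_shift_term(p_srcs: Sequence[Dist], f_srcs: Sequence[Labeling], f_tgt: Labeling) -> float:
    """max_i E_{P^{S,i}} ||Delta p^i(y|x)||_1, the second term of Equation 1 before scaling by L."""
    return max(expected_label_shift(p, f, f_tgt) for p, f in zip(p_srcs, f_srcs))


f_tgt = {0: [1.0, 0.0], 1: [0.0, 1.0]}
sources = [{0: 0.5, 1: 0.5}, {0: 0.5, 1: 0.5}]
labelings = [
    {0: [1.0, 0.0], 1: [0.0, 1.0]},  # agrees with the target everywhere
    {0: [1.0, 0.0], 1: [1.0, 0.0]},  # flips the label of x = 1
]
print(label_shift_term(sources, labelings, f_tgt))  # 1.0
```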

Contrary to previous works in DA and MSDA [1, 5, 7, 2], which consider both the losses and the data discrepancy on the data space, our bound connects losses on the data space to the discrepancy on the representation space. Our theory therefore provides a natural way to analyze representation learning, especially feature alignment in deep learning practice. Note that although DANN [3] explains its feature alignment method using the theory developed by Ben-David et al. [1], this explanation is not rigorous. In particular, while applying that theory to the representation space does yield a representation discrepancy term, the loss terms are then also defined on the feature space, and minimizing those losses is not the actual learning goal. Finally, our setting is considerably more general: it covers the multi-class, stochastic labeling setting and any bounded loss function.

As the final piece, the representation discrepancy can be decomposed as L\sqrt{2}\,d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right)\leq\sum_{i=1}^{K}\sum_{j=1}^{K}\frac{L\sqrt{2\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,i}\right)+\sum_{i=1}^{K}\sum_{j=1}^{K}\frac{L\sqrt{2\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{S,i},\mathbb{P}_{g}^{S,j}\right). The first sum measures how far the target representations are from each source, and the second how far the sources are from one another. Based on this decomposition, we define two types of DI representations.
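The decomposition can be reconstructed from standard properties of the Hellinger divergence; the following is our own sketch, not the paper's proof. Because D_{1/2} is an f-divergence, it is jointly convex, so writing \mathbb{P}_{g}^{T}=\sum_{j}\pi_{j}\mathbb{P}_{g}^{T} gives D_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right)\le\sum_{j}\pi_{j}D_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,j}\right). The inequality \sqrt{\sum_{j}a_{j}^{2}}\le\sum_{j}a_{j} for a_{j}\ge0 then moves the square root inside the sum, and the triangle inequality for d_{1/2}, averaged over the intermediate source i, introduces the source-source terms:

    \[ \begin{aligned} d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right) & \le\Big(\sum_{j=1}^{K}\pi_{j}D_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,j}\right)\Big)^{1/2}\le\sum_{j=1}^{K}\sqrt{\pi_{j}}\,d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,j}\right)\\ & \le\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}\sum_{i=1}^{K}\Big[d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,i}\right)+d_{1/2}\left(\mathbb{P}_{g}^{S,i},\mathbb{P}_{g}^{S,j}\right)\Big]. \end{aligned} \]

Multiplying both sides by L\sqrt{2} gives the decomposition stated above.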

Definition 2. i) (General Domain-Invariant Representation) A feature map g^{*}\in\mathcal{G} is said to be a DG general domain-invariant (DI) feature map if g^{*} is the solution of the optimization problem (OP) \min_{g\in\mathcal{G}}\min_{\hat{h}\in\mathcal{H}}\sum_{i=1}^{K}\pi_{i}\mathcal{L}\left(\hat{h},\mathbb{D}_{g}^{S,i}\right). The latent representations z=g^{*}\left(x\right) induced by g^{*} are then called general DI representations for the DG setting.

ii) (Compressed Domain-Invariant Representation) A feature map g^{*}\in\mathcal{G} is said to be a DG compressed DI feature map for the source domains \{\mathbb{D}^{S,i}\}_{i=1}^{K} if g^{*} is the solution of the OP \min_{g\in\mathcal{G}}\min_{\hat{h}\in\mathcal{H}}\sum_{i=1}^{K}\pi_{i}\mathcal{L}\left(\hat{h},\mathbb{D}_{g}^{S,i}\right) that satisfies \mathbb{P}_{g^{*}}^{S,1}=\mathbb{P}_{g^{*}}^{S,2}=\ldots=\mathbb{P}_{g^{*}}^{S,K} (i.e., the push-forward distributions of all source domains are identical). The latent representations z=g^{*}\left(x\right) are then called compressed DI representations for the DG setting.
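A minimal sketch of the alignment condition in Definition 2(ii) on a finite space, assuming {point: probability} dictionaries and treating the condition as satisfied when every source's push-forward is within a small Hellinger tolerance of the first one (the tolerance is our choice). The three feature maps are hypothetical.

```python
import math
from collections import defaultdict
from typing import Callable, Dict, Hashable, Sequence

Dist = Dict[Hashable, float]  # finite distribution {point: probability}


def push_forward(p: Dist, g: Callable[[Hashable], Hashable]) -> Dist:
    """P_g = g_# P on a finite space."""
    pz: Dict[Hashable, float] = defaultdict(float)
    for x, px in p.items():
        pz[g(x)] += px
    return dict(pz)


def hellinger_distance(p1: Dist, p2: Dist) -> float:
    """d_{1/2}(P1, P2) = sqrt(2 * sum_z (sqrt(p1(z)) - sqrt(p2(z)))^2)."""
    support = set(p1) | set(p2)
    return math.sqrt(2.0 * sum((math.sqrt(p1.get(z, 0.0)) - math.sqrt(p2.get(z, 0.0))) ** 2 for z in support))


def satisfies_compression(sources: Sequence[Dist], g: Callable[[Hashable], Hashable], tol: float = 1e-9) -> bool:
    """Check P_g^{S,1} = ... = P_g^{S,K} from Definition 2(ii), up to a tolerance."""
    pushed = [push_forward(p, g) for p in sources]
    return all(hellinger_distance(pushed[0], q) <= tol for q in pushed[1:])


s1, s2 = {0: 0.5, 1: 0.5}, {2: 0.5, 3: 0.5}           # two sources with disjoint supports
print(satisfies_compression([s1, s2], lambda x: x))      # False: identity keeps the domains apart
print(satisfies_compression([s1, s2], lambda x: x % 2))  # True: both map to {0: 0.5, 1: 0.5}
print(satisfies_compression([s1, s2], lambda x: 0))      # True, but every input collapses to one point
```

The constant map satisfies the alignment condition trivially yet discards all label information, which is why Definition 2(ii) also requires g^{*} to solve the source-loss OP rather than merely align the sources.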

2.3. Trade-off bound

Similar to the theoretical finding of Zhao et al. 2019 [6] for DA, we find theoretically that compression comes at a cost in MSDA and DG as well. We investigate the representation trade-off, specifically how compressed DI representations affect the classification loss. Concretely, we consider a data processing chain \mathcal{X}\stackrel{g}{\longmapsto}\mathcal{Z}\stackrel{\hat{h}}{\longmapsto}\mathcal{Y}_{\Delta}, where \mathcal{X} is the common data space, \mathcal{Z} is the latent space induced by a feature extractor g, and \hat{h} is a hypothesis on top of the latent space. We define \mathbb{P}_{\mathcal{Y}}^{\pi} and \mathbb{P}_{\mathcal{Y}}^{T} as two distributions over \mathcal{Y}: to draw y\sim\mathbb{P}_{\mathcal{Y}}^{\pi}, we sample k\sim Cat\left(\pi\right), x\sim\mathbb{P}^{S,k}, and y\sim Cat\left(f^{S,k}\left(x\right)\right), and we draw y\sim\mathbb{P}_{\mathcal{Y}}^{T} similarly from the target domain. Our bounds on the trade-off of learning DI representations are expressed in terms of d_{1/2}\left(\mathbb{P}_{\mathcal{Y}}^{\pi},\mathbb{P}_{\mathcal{Y}}^{T}\right).
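The two label marginals can be computed exactly on a finite space by following the sampling procedure above. The sketch uses a hypothetical two-source example in which each source contains a single class, and the function names are our own.

```python
import math
from typing import Dict, Hashable, List, Sequence

Dist = Dict[Hashable, float]            # finite data distribution {x: probability}
Labeling = Dict[Hashable, List[float]]  # {x: probability vector over the classes}


def label_marginal(p: Dist, f: Labeling, num_classes: int) -> List[float]:
    """P_Y(y) = sum_x p(x) f(x, y): draw x ~ P, then y ~ Cat(f(x))."""
    out = [0.0] * num_classes
    for x, px in p.items():
        for y in range(num_classes):
            out[y] += px * f[x][y]
    return out


def mixture_label_marginal(pi: Sequence[float], ps: Sequence[Dist], fs: Sequence[Labeling], num_classes: int) -> List[float]:
    """P_Y^pi: draw k ~ Cat(pi), x ~ P^{S,k}, then y ~ Cat(f^{S,k}(x))."""
    out = [0.0] * num_classes
    for weight, p, f in zip(pi, ps, fs):
        for y, mass in enumerate(label_marginal(p, f, num_classes)):
            out[y] += weight * mass
    return out


def hellinger_distance(a: Sequence[float], b: Sequence[float]) -> float:
    """d_{1/2} between two distributions over the same classes."""
    return math.sqrt(2.0 * sum((math.sqrt(u) - math.sqrt(v)) ** 2 for u, v in zip(a, b)))


# Hypothetical example: source 1 only has class 0, source 2 only class 1, the target only class 0.
p_y_pi = mixture_label_marginal([0.5, 0.5], [{0: 1.0}, {1: 1.0}], [{0: [1.0, 0.0]}, {1: [0.0, 1.0]}], 2)
p_y_t = label_marginal({0: 1.0}, {0: [1.0, 0.0]}, 2)
print(p_y_pi, p_y_t)                      # [0.5, 0.5] [1.0, 0.0]
print(hellinger_distance(p_y_pi, p_y_t))  # sqrt(4 - 2*sqrt(2)), roughly 1.08
```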

Theorem 3. (Trade-off bound) Consider a feature extractor g and a hypothesis \hat{h}. The Hellinger distance between the two label marginal distributions \mathbb{P}_{\mathcal{Y}}^{\pi} and \mathbb{P}_{\mathcal{Y}}^{T} can be upper-bounded as follows:

  1. d_{1/2}\left(\mathbb{P}_{\mathcal{Y}}^{\pi},\mathbb{P}_{\mathcal{Y}}^{T}\right)\leq\left[\sum_{k=1}^{K}\pi_{k}\mathcal{L}\left(\hat{h}\circ g,f^{S,k},\mathbb{P}^{S,k}\right)\right]^{1/2}+d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right)+\mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right)^{1/2}
  2. d_{1/2}\left(\mathbb{P}_{\mathcal{Y}}^{\pi},\mathbb{P}_{\mathcal{Y}}^{T}\right)\leq\left[\sum_{i=1}^{K}\pi_{i}\mathcal{L}\left(\hat{h}\circ g,f^{S,i},\mathbb{P}^{S,i}\right)\right]^{1/2}+\sum_{i=1}^{K}\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{S,i},\mathbb{P}_{g}^{S,j}\right)+\sum_{i=1}^{K}\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,i}\right)+\mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right)^{1/2}.

Here, the general loss \mathcal{L} is built on the Hellinger loss \ell, defined as \ell(\hat{f}(x),f(x))=D_{1/2}(\hat{f}(x),f(x))=2\sum_{i=1}^{C}\left(\sqrt{\hat{f}\left(x,i\right)}-\sqrt{f\left(x,i\right)}\right)^{2} (more discussion can be found in Appendix C of our paper).
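Every quantity in the first inequality of Theorem 3 is computable on a finite space, so the bound can be sanity-checked numerically. The sketch below is our own check rather than code from the paper: it draws random source and target domains, a random classifier \hat{h}, and a hypothetical feature map g, evaluates both sides with the Hellinger loss above, and asserts the inequality. One way to see why it should hold on every instance: by the triangle inequality for d_{1/2}, pass through the label distributions predicted by \hat{h}\circ g on the source mixture and on the target; the middle gap is at most d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right) by the data processing inequality, and each outer gap is at most the square root of the corresponding general loss by joint convexity of D_{1/2}.

```python
import math
import random
from collections import defaultdict
from typing import Callable, Dict, Hashable, List, Sequence, Tuple

Dist = Dict[Hashable, float]            # finite distribution {outcome: probability}
Labeling = Dict[Hashable, List[float]]  # {point: probability vector over the classes}
FeatureMap = Callable[[Hashable], Hashable]


def hellinger_div(a: Dist, b: Dist) -> float:
    """D_{1/2}(a, b) = 2 * sum_k (sqrt(a_k) - sqrt(b_k))^2."""
    keys = set(a) | set(b)
    return 2.0 * sum((math.sqrt(a.get(k, 0.0)) - math.sqrt(b.get(k, 0.0))) ** 2 for k in keys)


def as_dist(vec: Sequence[float]) -> Dist:
    """Turn a probability vector into a {class index: probability} dictionary."""
    return dict(enumerate(vec))


def mix(weights: Sequence[float], dists: Sequence[Dist]) -> Dist:
    """Mixture sum_i weights_i * dists_i."""
    out: Dict[Hashable, float] = defaultdict(float)
    for w, d in zip(weights, dists):
        for k, v in d.items():
            out[k] += w * v
    return dict(out)


def push_forward(p: Dist, g: FeatureMap) -> Dist:
    """P_g = g_# P."""
    out: Dict[Hashable, float] = defaultdict(float)
    for x, px in p.items():
        out[g(x)] += px
    return dict(out)


def label_marginal(p: Dist, f: Labeling) -> Dist:
    """P_Y(y) = sum_x p(x) f(x, y), i.e. the mixture of the vectors f(x) with weights p(x)."""
    return mix([p[x] for x in p], [as_dist(f[x]) for x in p])


def general_loss(h_hat: Labeling, g: FeatureMap, f: Labeling, p: Dist) -> float:
    """L(h_hat o g, f, P) = E_{x ~ P}[D_{1/2}(h_hat(g(x)), f(x))] with the Hellinger loss."""
    return sum(px * hellinger_div(as_dist(h_hat[g(x)]), as_dist(f[x])) for x, px in p.items())


def theorem3_first_bound(pi, ps, fs, p_t, f_t, g, h_hat) -> Tuple[float, float]:
    """Return (lhs, rhs) of the first inequality of Theorem 3."""
    p_y_pi = mix(pi, [label_marginal(p, f) for p, f in zip(ps, fs)])
    lhs = math.sqrt(hellinger_div(p_y_pi, label_marginal(p_t, f_t)))
    source_loss = sum(w * general_loss(h_hat, g, f, p) for w, p, f in zip(pi, ps, fs))
    feature_gap = math.sqrt(hellinger_div(push_forward(p_t, g), mix(pi, [push_forward(p, g) for p in ps])))
    rhs = math.sqrt(source_loss) + feature_gap + math.sqrt(general_loss(h_hat, g, f_t, p_t))
    return lhs, rhs


def random_simplex(rng: random.Random, n: int) -> List[float]:
    """A random probability vector with strictly positive entries."""
    w = [rng.random() + 1e-3 for _ in range(n)]
    total = sum(w)
    return [v / total for v in w]


def merge_pairs(x: int) -> int:
    """Hypothetical feature map g that sends x and x + 3 to the same latent point."""
    return x % 3


rng = random.Random(0)
for _ in range(200):  # random instances: K = 2 sources, 6 inputs, 3 latent points, 3 classes
    ps = [as_dist(random_simplex(rng, 6)) for _ in range(2)]
    fs = [{x: random_simplex(rng, 3) for x in range(6)} for _ in range(2)]
    p_t = as_dist(random_simplex(rng, 6))
    f_t = {x: random_simplex(rng, 3) for x in range(6)}
    h_hat = {z: random_simplex(rng, 3) for z in range(3)}
    lhs, rhs = theorem3_first_bound(random_simplex(rng, 2), ps, fs, p_t, f_t, merge_pairs, h_hat)
    assert lhs <= rhs + 1e-12, (lhs, rhs)
print("first inequality held on all sampled instances")
```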

Remark. Compared to the trade-off bound of Zhao et al. 2019 [6], our context is more general: it concerns MSDA and DG problems with multiple source domains and multi-class probabilistic labeling functions, rather than single-source DA in the binary-class, deterministic setting.

Moreover, the Hellinger distance is more universal in the sense that, unlike the \mathcal{H}-divergence used in earlier bounds such as [1], it does not depend on the choice of the classifier family \mathcal{H} or the loss function \ell.

We rely on the first inequality of Theorem 3 to analyze the trade-off of learning general DI representations. The first term on the right-hand side is the loss of the source mixture, which is controllable and tends to be small when we enforce learning general DI representations. With that in mind, if the two label marginal distributions \mathbb{P}_{\mathcal{Y}}^{\pi} and \mathbb{P}_{\mathcal{Y}}^{T} are distant (i.e., d_{1/2}\left(\mathbb{P}_{\mathcal{Y}}^{\pi},\mathbb{P}_{\mathcal{Y}}^{T}\right) is high), the sum d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right)+\mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right)^{1/2} tends to be high. This leads to two possibilities. In the first scenario, the representation discrepancy d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right) is small, e.g., because it is minimized in the MSDA setting or happens to be small by chance in the DG setting. In this case, the lower bound on the target loss \mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right) is high, possibly hurting the model’s generalization ability. In the second scenario, the discrepancy d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right) is large for some reason; the lower bound on the target loss is then small, but its upper bound is higher, as indicated by Theorem 1.
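Rearranging the first inequality makes the first scenario quantitative: since \mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right)^{1/2}\geq d_{1/2}\left(\mathbb{P}_{\mathcal{Y}}^{\pi},\mathbb{P}_{\mathcal{Y}}^{T}\right)-\left[\sum_{k}\pi_{k}\mathcal{L}\left(\hat{h}\circ g,f^{S,k},\mathbb{P}^{S,k}\right)\right]^{1/2}-d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{\pi}\right) and the left side is nonnegative, the target loss is at least the square of the positive part of the right side. The numbers below are hypothetical and chosen only to show the trend.

```python
import math


def target_loss_lower_bound(label_gap: float, source_loss: float, feature_gap: float) -> float:
    """Lower bound on L(h o g, f^T, P^T) implied by the first inequality of Theorem 3.

    label_gap   = d_{1/2}(P_Y^pi, P_Y^T), fixed by the domains
    source_loss = sum_k pi_k L(h o g, f^{S,k}, P^{S,k})
    feature_gap = d_{1/2}(P_g^T, P_g^pi)
    """
    slack = label_gap - math.sqrt(source_loss) - feature_gap
    return max(0.0, slack) ** 2


# Hypothetical numbers: a fixed label gap of 1.0 and a well-fitted source mixture (loss 0.01).
for feature_gap in (0.0, 0.4, 1.0):
    print(feature_gap, target_loss_lower_bound(1.0, 0.01, feature_gap))
# The bound shrinks from about 0.81 (perfect alignment) to about 0.25 and then to 0.0.
```

With the label gap and source loss held fixed, the better the source and target features are aligned, the larger the guaranteed target loss.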

Based on the second inequality of Theorem 3, we observe that if the two label marginal distributions \mathbb{P}_{\mathcal{Y}}^{\pi} and \mathbb{P}_{\mathcal{Y}}^{T} are distant while we enforce learning compressed DI representations (i.e., both the source loss and the source-source feature discrepancy \left[\sum_{i=1}^{K}\pi_{i}\mathcal{L}\left(\hat{h}\circ g,f^{S,i},\mathbb{P}^{S,i}\right)\right]^{1/2}+\sum_{i=1}^{K}\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{S,i},\mathbb{P}_{g}^{S,j}\right) are low), then the sum \sum_{i=1}^{K}\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,i}\right)+\mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right)^{1/2} is high. In the MSDA setting, the discrepancy \sum_{i=1}^{K}\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,i}\right) is trained to become smaller, which means that the lower bound on the target loss \mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right) is high, hurting the target performance. Similarly, in the DG setting, if the trained feature extractor g happens to reduce \sum_{i=1}^{K}\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,i}\right) for some unseen target domain, it raises the lower bound on the target loss \mathcal{L}\left(\hat{h}\circ g,f^{T},\mathbb{P}^{T}\right). Conversely, if for some target domains the discrepancy \sum_{i=1}^{K}\sum_{j=1}^{K}\frac{\sqrt{\pi_{j}}}{K}d_{1/2}\left(\mathbb{P}_{g}^{T},\mathbb{P}_{g}^{S,i}\right) is high for some reason, then our upper bound (Theorem 1) gives the general target loss a high upper bound, so the target loss is possibly high.

This trade-off between representation discrepancy and target loss suggests that there is a sweet spot of just-right feature alignment, at which the target loss is most likely to be small.

3. Experiments

We wish to study the characteristics of, and the trade-off between, the two kinds of DI representations when predicting on various target domains. Specifically, we apply adversarial learning similar to Ganin et al. 2016 [3], in which a min-max game is played between a domain discriminator \hat{h}^{d}, which tries to identify the source domain of a given representation, and the feature extractor (generator) g, which tries to fool the domain discriminator. Simultaneously, a classifier predicts the label from the representation. Let \mathcal{L}_{gen} and \mathcal{L}_{disc} be the label classification and domain discrimination losses; the training objective becomes:

    \[ \min_{g}\left(\min_{\hat{h}}\mathcal{L}_{gen}+\lambda\max_{\hat{h}^{d}}\mathcal{L}_{disc}\right), \]

where the source compression strength \lambda\geq0 controls the compression extent of the learned representation. More specifically, general DI representations are obtained with \lambda=0, while larger \lambda leads to more compressed DI representations. We conducted experiments on the CMNIST dataset and the real-world PACS dataset, whose details are in our official paper. Our implementation is based on the DomainBed repository.
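With a finite latent space, the inner maximization has a closed form, which makes the role of \lambda visible. This sketch rests on an assumption of ours: that \mathcal{L}_{disc} is the discriminator's expected log-likelihood of the correct source index, so the discriminator maximizes it as the objective's max suggests. The best discriminator is then the posterior over domains, and the maximal value equals I(k;z)-H(\pi), the mutual information between domain index and representation minus the entropy of \pi. Assuming every \pi_k>0, it reaches its minimum \sum_k\pi_k\log\pi_k exactly when all pushed-forward source distributions coincide, i.e., at compressed DI representations, while \lambda=0 leaves only the classification loss (general DI representations). In practice the discriminator is a network trained adversarially, not this closed-form posterior; the function names are hypothetical.

```python
import math
from collections import defaultdict
from typing import Dict, Hashable, Sequence

Dist = Dict[Hashable, float]  # finite latent distribution {z: probability}


def best_discriminator_value(pi: Sequence[float], pushed: Sequence[Dist]) -> float:
    """max over discriminators of E[log q(k | z)], with k ~ Cat(pi) and z ~ P_g^{S,k}.

    The maximiser is the posterior q*(k | z) = pi_k p_k(z) / sum_j pi_j p_j(z),
    so the value equals I(k; z) - H(pi).
    """
    marginal: Dict[Hashable, float] = defaultdict(float)
    for w, pk in zip(pi, pushed):
        for z, v in pk.items():
            marginal[z] += w * v
    value = 0.0
    for w, pk in zip(pi, pushed):
        for z, v in pk.items():
            if w > 0 and v > 0:  # zero-probability pairs contribute nothing
                value += w * v * math.log(w * v / marginal[z])
    return value


def objective(gen_loss: float, lam: float, pi: Sequence[float], pushed: Sequence[Dist]) -> float:
    """L_gen + lambda * max_{h^d} L_disc, the quantity minimised over the feature extractor g."""
    return gen_loss + lam * best_discriminator_value(pi, pushed)


pi = [0.5, 0.5]
aligned = [{0: 0.5, 1: 0.5}, {0: 0.5, 1: 0.5}]  # identical pushed-forward sources (compressed)
separated = [{0: 1.0}, {1: 1.0}]                # disjoint pushed-forward sources
print(best_discriminator_value(pi, aligned))    # about -0.693 = log(0.5), the smallest possible value
print(best_discriminator_value(pi, separated))  # 0.0, the largest possible value
```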

Figure 1a shows the source validation and target accuracies as \lambda increases (i.e., as source compression is encouraged more strongly). Both the source validation accuracy and the target accuracy follow the same pattern: they increase when \lambda is set appropriately, yielding just-right compressed DI representations, and degrade when \lambda is set too high, yielding overly compressed DI representations. Figure 1b shows in detail how the source validation accuracy evolves during training for each specific \lambda. In practice, we should encourage learning both kinds of DI representations simultaneously, finding an appropriate trade-off between them to obtain just-right compressed DI representations.

Figure 1: (CMNIST) (1a) Source validation accuracy and target accuracy on the target domain w.r.t. the compression strength. (1b) Validation accuracy over training steps for different values of λ.

References

[1] Ben-David, S., Blitzer, J., Crammer, K., Kulesza, A., Pereira, F., and Vaughan, J. W. A theory of learning from different domains. Mach. Learn. 79, 1–2 (May 2010), 151–175.

[2] Cortes, C., Mohri, M., and Medina, A. M. Adaptation based on generalized discrepancy. Journal of Machine Learning Research 20, 1 (2019), 1–30.

[3] Ganin, Y., Ustinova, E., Ajakan, H., Germain, P., Larochelle, H., Laviolette, F., Marchand, M., and Lempitsky, V. Domain-adversarial training of neural networks. Journal of Machine Learning Research 17, 59 (2016), 1–35.

[4] Hoffman, J., Mohri, M., and Zhang, N. Algorithms and theory for multiple-source adaptation. In Advances in Neural Information Processing Systems (2018), S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, Eds., vol. 31, Curran Associates, Inc.

[5] Mansour, Y., Mohri, M., and Rostamizadeh, A. Domain adaptation with multiple sources. In Advances in Neural Information Processing Systems (2009), D. Koller, D. Schuurmans, Y. Bengio, and L. Bottou, Eds., vol. 21, Curran Associates, Inc.

[6] Zhao, H., Combes, R. T. D., Zhang, K., and Gordon, G. On learning invariant representations for domain adaptation. In Proceedings of the 36th International Conference on Machine Learning (09–15 Jun 2019), K. Chaudhuri and R. Salakhutdinov, Eds., vol. 97 of Proceedings of Machine Learning Research, PMLR, pp. 7523–7532.

[7] Zhao, H., Zhang, S., Wu, G., Moura, J. M. F., Costeira, J. P., and Gordon, G. J. Adversarial multiple source domain adaptation. In Advances in Neural Information Processing Systems (2018), S. Bengio, H. Wallach, H. Larochelle, K. Grauman, N. Cesa-Bianchi, and R. Garnett, Eds., vol. 31, Curran Associates, Inc.
