Modular Meta-Learning with Shrinkage

8-24-2020

Motivation

The ability to meta-learn large models with only a few task-specific components is important in many real-world problems:

  • multi-speaker text-to-speech synthesis

Updating only these task-specific modules allows the model to be adapted to low-data tasks for as many steps as necessary without risking overfitting.

Existing meta-learning approaches either do not scale to long adaptation horizons or rely on handcrafted task-specific architectures.

A new meta-learning approach based on Bayesian shrinkage is proposed to automatically discover and learn both task-specific and general reusable modules.

It works well in the few-shot text-to-speech domain.

MAML, iMAML, and Reptile are special cases of the proposed method.

Gradient-based Meta-Learning

  • Task $\mathcal{T}_t$ is associated with a finite dataset $\mathcal{D}_t = \{\mathbf{x}_{t,n}\}_{n=1}^{N_t}$

  • Task $\mathcal{T}_t$'s data is split into $\mathcal{D}_t^{\text{train}}$ and $\mathcal{D}_t^{\text{val}}$

  • Meta parameters $\boldsymbol{\phi} \in \mathbb{R}^D$

  • Task-specific parameters $\boldsymbol{\theta}_t \in \mathbb{R}^D$

  • Loss function $\ell(\mathcal{D}_t; \boldsymbol{\theta}_t)$

Algorithm 1 gives the structure of a typical meta-learning algorithm (see the sketch after the lists below), which could be:

  • MAML

  • iMAML

  • Reptile

  1. TASKADAPT: task adaptation (inner loop)

  2. The meta update $\Delta_t$ specifies the contribution of task $t$ to the meta parameters (outer loop)
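Schematically, this skeleton can be written as follows (a minimal runnable sketch; `task_adapt`, `meta_step`, and the toy quadratic tasks are illustrative stand-ins, and a Reptile-style $\Delta_t$ is used only for concreteness):

```python
# Generic structure of Algorithm 1: an inner task-adaptation loop and an outer
# loop that accumulates each task's contribution Delta_t to the meta parameters.
import numpy as np

def task_adapt(phi, task, num_steps=10, lr=0.1):
    """Inner loop: adapt task parameters starting from the meta parameters."""
    theta = phi.copy()
    for _ in range(num_steps):
        grad = task["A"] @ theta - task["b"]  # gradient of a toy quadratic loss
        theta -= lr * grad
    return theta

def meta_step(phi, tasks, meta_lr=0.5):
    """Outer loop: average the per-task contributions Delta_t, then update phi."""
    delta = np.zeros_like(phi)
    for task in tasks:
        theta_t = task_adapt(phi, task)
        delta += phi - theta_t                # Reptile-style Delta_t, for concreteness
    return phi - meta_lr * delta / len(tasks)

rng = np.random.default_rng(0)
tasks = [{"A": np.eye(2), "b": rng.normal(size=2)} for _ in range(8)]
phi = np.zeros(2)
for _ in range(100):
    phi = meta_step(phi, tasks)
print(phi)  # approaches the average of the per-task optima b_t
```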

MAML

  1. Task adaptation: minimize the training loss $\ell_t^{\text{train}}(\boldsymbol{\theta}_t) = \ell(\mathcal{D}_t^{\text{train}}; \boldsymbol{\theta}_t)$ by gradient descent w.r.t. the task parameters.

  2. Meta parameter update: by gradient descent on the validation loss $\ell_t^{\text{val}}(\boldsymbol{\theta}_t) = \ell(\mathcal{D}_t^{\text{val}}; \boldsymbol{\theta}_t)$, resulting in the meta update (gradient) for task $t$: $\Delta_t^{\text{MAML}} = \nabla_{\boldsymbol{\phi}}\ell_t^{\text{val}}(\boldsymbol{\theta}_t(\boldsymbol{\phi}))$.

This approach treats the task parameters as a function of the meta parameters, and hence requires back-propagation through the entire $L$-step task adaptation process. When $L$ is large, this becomes computationally prohibitive.
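A minimal JAX sketch of this trade-off on a toy 1-D regression problem (all names and values are illustrative): `jax.grad` differentiates through the unrolled inner loop, which is exactly the part that becomes prohibitive for large $L$.

```python
# MAML meta-gradient on a toy 1-D linear regression task: JAX back-propagates
# through the unrolled L-step inner loop to get Delta_t^MAML.
import jax
import jax.numpy as jnp

def loss(theta, x, y):
    return jnp.mean((x * theta - y) ** 2)

def adapt(phi, x_tr, y_tr, L=5, inner_lr=0.1):
    theta = phi
    for _ in range(L):                     # L-step task adaptation
        theta = theta - inner_lr * jax.grad(loss)(theta, x_tr, y_tr)
    return theta

def maml_objective(phi, x_tr, y_tr, x_val, y_val):
    theta_t = adapt(phi, x_tr, y_tr)       # theta_t is a function of phi
    return loss(theta_t, x_val, y_val)     # validation loss

# Delta_t^MAML: gradient of the validation loss w.r.t. the meta parameters,
# computed by differentiating through all L inner steps.
x = jnp.linspace(-1.0, 1.0, 8)
delta_t = jax.grad(maml_objective)(jnp.array(0.0), x, 2.0 * x, x, 2.0 * x)
print(delta_t)
```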

Reptile

Reptile optimizes $\boldsymbol{\theta}_t$ on the entire dataset $\mathcal{D}_t$ and moves $\boldsymbol{\phi}$ towards the adapted task parameters, yielding $\Delta_t^{\text{Reptile}} = \boldsymbol{\phi} - \boldsymbol{\theta}_t$.

iMAML

iMAML introduces an L2 regularizer $\frac{\lambda}{2}\|\boldsymbol{\theta}_t - \boldsymbol{\phi}\|^2$ to the training loss, and optimizes the task parameters on the regularized training loss.

Provided that this task adaptation process converges to a stationary point, implicit differentiation enables the computation of the meta gradient based only on the final solution of the adaptation process:

$$\Delta_t^{\mathrm{iMAML}} = \left(\mathbf{I} + \frac{1}{\lambda}\nabla^2_{\boldsymbol{\theta}_t}\ell_t^{\text{train}}(\boldsymbol{\theta}_t)\right)^{-1}\nabla_{\boldsymbol{\theta}_t}\ell_t^{\text{val}}(\boldsymbol{\theta}_t)$$
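A small sketch of this computation on a toy problem (shapes and values are illustrative; a practical implementation would approximate the inverse with conjugate gradient rather than a dense solve):

```python
# iMAML meta-gradient: (I + H/lambda)^{-1} grad_val, on a problem small enough
# to solve the linear system explicitly.
import jax
import jax.numpy as jnp

lam = 1.0

def train_loss(theta, x, y):
    return jnp.mean((x @ theta - y) ** 2)

def val_loss(theta, x, y):
    return jnp.mean((x @ theta - y) ** 2)

# Assume theta_t is (approximately) a stationary point of the regularized
# training loss; the meta gradient then needs only this final solution.
theta_t = jnp.array([1.0, -0.5])
x_tr, y_tr = jnp.ones((4, 2)), jnp.ones(4)
x_val, y_val = jnp.eye(2), jnp.zeros(2)

H = jax.hessian(train_loss)(theta_t, x_tr, y_tr)    # D x D Hessian of train loss
g_val = jax.grad(val_loss)(theta_t, x_val, y_val)
delta_imaml = jnp.linalg.solve(jnp.eye(2) + H / lam, g_val)
print(delta_imaml)
```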

Modular Bayesian Meta-Learning

  • Standard meta-learning: all parameters of a single base learner (a neural network) are updated at the same time.

    • inefficient and prone to overfitting

  • Split the network parameters into two groups:

    • a group varying across tasks

    • a group that is shared across tasks

In this paper, we assume that, in general, the network parameters can be partitioned into $M$ disjoint modules (some of which may turn out to be task independent):

$$\boldsymbol{\theta}_t = (\boldsymbol{\theta}_{t,1}, \boldsymbol{\theta}_{t,2}, \dots, \boldsymbol{\theta}_{t,m}, \dots, \boldsymbol{\theta}_{t,M})$$

where $\boldsymbol{\theta}_{t,m}$ denotes the parameters of module $m$ for task $t$.

Fig.1 Modular meta-learning

How to define modules? Modules can correspond to (see the sketch after this list):

  • Layers: $\boldsymbol{\theta}_{t,m}$ could be the weights of the $m$-th layer of a NN for task $t$ [this paper treats each layer as a module]

  • Receptive fields

  • the encoder and decoder in an auto-encoder

  • the heads in a multi-task learning model

  • any other grouping of interest
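For example, with layers as modules, the task parameters can be kept in a per-module container with one shrinkage scale per module (a minimal sketch; names and shapes are made up):

```python
# Task parameters partitioned into M disjoint modules, one per layer here.
# Each module m gets its own shrinkage parameter sigma2[m] in the model below.
import numpy as np

rng = np.random.default_rng(0)
theta_t = {
    "layer1": {"w": rng.normal(size=(16, 8)), "b": np.zeros(8)},
    "layer2": {"w": rng.normal(size=(8, 1)),  "b": np.zeros(1)},
}
sigma2 = {"layer1": 1e-4, "layer2": 1.0}  # small sigma2 => nearly task independent
```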

Hierarchical Bayesian Model

Bayesian Shrinkage Graphical Model

(with a factored probability density):

$$p(\boldsymbol{\theta}_{1:T}, \mathcal{D} \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \prod_{t=1}^{T}\prod_{m=1}^{M}\mathcal{N}(\boldsymbol{\theta}_{t,m} \mid \boldsymbol{\phi}_m, \sigma_m^2\mathbf{I})\prod_{t=1}^{T}p(\mathcal{D}_t \mid \boldsymbol{\theta}_t)$$

  • $\boldsymbol{\phi}$: shared meta parameters, the initialization of the NN parameters $\boldsymbol{\theta}_t$ for each task

  • $\boldsymbol{\theta}_{t,m}$ is conditionally independent of the parameters of all other tasks given the "central" parameters; namely, $\boldsymbol{\theta}_{t,m} \sim \mathcal{N}(\boldsymbol{\phi}_m, \sigma_m^2\mathbf{I})$ with mean $\boldsymbol{\phi}_m$ and variance $\sigma_m^2$ ($\mathbf{I}$ is the identity matrix)

  • $\boldsymbol{\sigma}^2$: shrinkage parameters.

  • $\sigma_m^2$: the $m$-th module's scalar shrinkage parameter, which measures the degree to which $\boldsymbol{\theta}_{t,m}$ can deviate from $\boldsymbol{\phi}_m$. If $\sigma_m \approx 0$, then $\boldsymbol{\theta}_{t,m} \approx \boldsymbol{\phi}_m$; as $\sigma_m$ shrinks to 0, the parameters of module $m$ become task independent.

  • Meta parameters: $\boldsymbol{\Phi} = (\boldsymbol{\sigma}^2, \boldsymbol{\phi})$

For values of $\sigma_m^2$ near zero, the difference between the parameters $\boldsymbol{\theta}_{t,m}$ and the mean $\boldsymbol{\phi}_m$ is shrunk to zero, and thus module $m$ becomes task independent (a small sketch of this prior term follows the list below).

  • Thus, by learning $\boldsymbol{\sigma}^2$, we can discover which modules are task independent.

  • These independent modules can be re-used at meta-test time

    • reducing the computational burden of adaptation

    • and likely improving generalization
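Concretely, the shrinkage prior contributes a per-module quadratic penalty to the negative log of the joint density above. A minimal sketch of that term (dropping the $2\pi$ constants; the container layout follows the earlier sketch):

```python
# Negative log of prod_m N(theta_{t,m} | phi_m, sigma2_m I), up to constants:
# sum_m [ ||theta_{t,m} - phi_m||^2 / (2 sigma2_m) + (D_m / 2) log sigma2_m ].
import numpy as np

def neg_log_prior(theta_t, phi, sigma2):
    total = 0.0
    for m in theta_t:  # modules, e.g. layers
        diff2 = sum(np.sum((theta_t[m][k] - phi[m][k]) ** 2) for k in theta_t[m])
        dim = sum(theta_t[m][k].size for k in theta_t[m])
        total += diff2 / (2.0 * sigma2[m]) + 0.5 * dim * np.log(sigma2[m])
    return total

phi = {"layer1": {"w": np.zeros((4, 2))}, "layer2": {"w": np.zeros((2, 1))}}
theta_t = {m: {k: v + 0.1 for k, v in phi[m].items()} for m in phi}
print(neg_log_prior(theta_t, phi, {"layer1": 1e-2, "layer2": 1.0}))
```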

Meta-Learning as Parameter Estimation

Goal: estimate the parameters $\boldsymbol{\phi}$ and $\boldsymbol{\sigma}^2$.

Standard solution: Maximize the marginal likelihood (intractable)

$$p(\mathcal{D} \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2) = \int p(\mathcal{D} \mid \boldsymbol{\theta})\,p(\boldsymbol{\theta} \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2)\,\mathrm{d}\boldsymbol{\theta}$$

Approximation: two principled alternative approaches for parameter estimation, based on maximizing the predictive likelihood over the validation subsets (query sets).

Approximation Method 1: Parameter estimation via the predictive likelihood

  • (Relation with MAML and iMAML)

Our goal is to minimize the average negative predictive log-likelihood over $T$ validation tasks:

$$\ell_{\mathrm{PLL}}(\boldsymbol{\sigma}^2, \boldsymbol{\phi}) = -\frac{1}{T}\sum_{t=1}^{T}\log p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \boldsymbol{\sigma}^2, \boldsymbol{\phi}) \qquad (2)$$

  • Note: there might be a typo (a missing minus sign) in equation (2) in the paper.

Justification:

Based on two assumptions:

  • The training and validation data (support and query sets) are distributed i.i.d. according to some distribution $\nu(\mathcal{D}_t^{\text{train}}, \mathcal{D}_t^{\text{val}})$

  • The law of large numbers.

As $T \rightarrow \infty$:

$$\ell_{\mathrm{PLL}}(\boldsymbol{\sigma}^2, \boldsymbol{\phi}) \rightarrow \mathbb{E}_{\nu(\mathcal{D}_t^{\text{train}})}\left[\mathrm{KL}\left(\nu(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}})\,\big\|\,p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \boldsymbol{\sigma}^2, \boldsymbol{\phi})\right)\right] + \mathrm{H}\left(\nu(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}})\right)$$

  • $\nu(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}})$: the true distribution

  • $p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \boldsymbol{\sigma}^2, \boldsymbol{\phi})$: the predictive distribution

Thus, minimizing $\ell_{\mathrm{PLL}}$ w.r.t. the meta parameters corresponds to selecting the predictive distribution that is, on average, approximately closest (in KL divergence) to the true predictive distribution.

Proof sketch: by the law of large numbers, the empirical average $-\frac{1}{T}\sum_t \log p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \boldsymbol{\sigma}^2, \boldsymbol{\phi})$ converges to $-\mathbb{E}_\nu\left[\log p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \boldsymbol{\sigma}^2, \boldsymbol{\phi})\right]$; adding and subtracting $\mathbb{E}_\nu\left[\log \nu(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}})\right]$ splits this expectation into the KL term and the entropy term.

Note: $\nu(\mathcal{D}_t^{\text{train}})$?

New problem: the computation of $\ell_{\mathrm{PLL}}$ is not feasible due to the intractable integral in equation (2).

So, we use a simple maximum a posteriori (MAP) approximation of the task parameters:

Where is this from?

$$p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \boldsymbol{\sigma}^2, \boldsymbol{\phi}) \approx p(\mathcal{D}_t^{\text{val}} \mid \hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi})) \qquad (3)$$

For a specific individual task $t$:

$$\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \underset{\boldsymbol{\theta}_t}{\operatorname{argmin}}\;\ell_t^{\text{train}}(\boldsymbol{\theta}_t, \boldsymbol{\sigma}^2, \boldsymbol{\phi}), \quad \ell_t^{\text{train}} := -\log p(\mathcal{D}_t^{\text{train}} \mid \boldsymbol{\theta}_t) - \log p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi}) \qquad (4)$$

After obtaining the MAP estimates, we can approximate $\ell_{\mathrm{PLL}}(\boldsymbol{\sigma}^2, \boldsymbol{\phi})$ as follows:

$$\ell_{\mathrm{PLL}}(\boldsymbol{\sigma}^2, \boldsymbol{\phi}) \approx \frac{1}{T}\sum_{t=1}^{T}\ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi})), \quad \text{where } \ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t) := -\log p(\mathcal{D}_t^{\text{val}} \mid \hat{\boldsymbol{\theta}}_t) \qquad (5)$$

  • Individual task adaptation follows from equation (4)

  • Meta updating follows from minimizing equation (5)

Minimizing equation (5) is thus a bi-level optimization problem, which requires solving equation (4) in the inner loop.
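A minimal JAX sketch of the inner problem in equation (4) on a toy task (the Gaussian likelihood, module names, shapes, and plain gradient descent are all illustrative choices):

```python
# MAP task adaptation (equation 4): minimize the negative log-likelihood plus
# the per-module shrinkage penalty ||theta_m - phi_m||^2 / (2 sigma2_m).
import jax
import jax.numpy as jnp

def map_objective(theta, phi, sigma2, x, y):
    nll = jnp.mean((x @ theta["w"] + theta["b"] - y) ** 2)  # toy Gaussian likelihood
    penalty = sum(jnp.sum((theta[m] - phi[m]) ** 2) / (2.0 * sigma2[m]) for m in theta)
    return nll + penalty

phi = {"w": jnp.zeros(3), "b": jnp.zeros(())}
sigma2 = {"w": 1.0, "b": 0.1}   # the smaller sigma2 ties "b" more tightly to phi["b"]
theta, x, y = phi, jnp.ones((5, 3)), jnp.ones(5)
for _ in range(200):            # adapt as long as needed; the prior curbs overfitting
    g = jax.grad(map_objective)(theta, phi, sigma2, x, y)
    theta = jax.tree_util.tree_map(lambda p, gp: p - 0.05 * gp, theta, g)
print(theta)
```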

How to solve this problem?

Compute the gradient of the approximate predictive log-likelihood $\ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t)$ w.r.t. the meta parameters $\boldsymbol{\Phi} = (\boldsymbol{\sigma}^2, \boldsymbol{\phi})$:

$$\nabla_{\boldsymbol{\Phi}}\ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})) = \nabla_{\hat{\boldsymbol{\theta}}_t}\ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t)\,\nabla_{\boldsymbol{\Phi}}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi}) \qquad (\text{Eq.}*)$$

  • $\nabla_{\hat{\boldsymbol{\theta}}_t}\ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t)$ is straightforward.

  • $\nabla_{\boldsymbol{\Phi}}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$ is not.

Two methods for computing $\nabla_{\boldsymbol{\Phi}}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$:

1) Explicit computation of $\nabla_{\boldsymbol{\Phi}}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$.

(Relation with MAML)

(6)

If optimizing $\ell_t^{\text{train}}$ requires only a small number of local gradient steps, we can compute the update for the meta parameters $(\boldsymbol{\phi}, \boldsymbol{\sigma}^2)$ by back-propagation through $\hat{\boldsymbol{\theta}}_t$, yielding (6).

This update reduces to that of MAML if $\sigma_m^2 \rightarrow \infty$ for all modules and is thus denoted σ-MAML.

Why? (my explanation)

According to equation (4):

$$\ell_t^{\text{train}} := -\log p(\mathcal{D}_t^{\text{train}} \mid \boldsymbol{\theta}_t) - \log p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi})$$

$$p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \mathcal{N}(\boldsymbol{\theta}_t \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2), \quad \text{e.g., in one dimension } \mathcal{N}(\theta_t \mid \phi, \sigma^2) = \frac{1}{\sigma\sqrt{2\pi}}\,e^{-\frac{(\theta_t - \phi)^2}{2\sigma^2}}$$
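Concretely, the gradient of this prior term w.r.t. the task parameters of module $m$ is

$$\nabla_{\boldsymbol{\theta}_{t,m}}\left[-\log \mathcal{N}(\boldsymbol{\theta}_{t,m} \mid \boldsymbol{\phi}_m, \sigma_m^2\mathbf{I})\right] = \frac{\boldsymbol{\theta}_{t,m} - \boldsymbol{\phi}_m}{\sigma_m^2} \rightarrow \mathbf{0} \quad \text{as } \sigma_m^2 \rightarrow \infty,$$

so in this limit the inner objective reduces to $-\log p(\mathcal{D}_t^{\text{train}} \mid \boldsymbol{\theta}_t)$, which is exactly MAML's task adaptation loss.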

2) Implicit computation of $\nabla_{\boldsymbol{\Phi}}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$.

(Relation with iMAML)

  • We are more interested in long adaptation horizons.

  • Implicit Function Theorem (compute the gradient of $\hat{\boldsymbol{\theta}}_t$ w.r.t. $\boldsymbol{\phi}$ and $\boldsymbol{\sigma}^2$)

Approximation Method 2: Estimate $\boldsymbol{\phi}$ via MAP approximation

  • Relation with Reptile

Integrating out the central parameters $\boldsymbol{\phi}$, and considering that $\boldsymbol{\phi}$ depends on all of the training data, we can rewrite the predictive likelihood in terms of the joint posterior over $(\boldsymbol{\theta}_{1:T}, \boldsymbol{\phi})$, i.e.

This is likewise intractable, so we again use a MAP approximation, now for both the task parameters and the central meta parameters:

(Assuming a flat, non-informative prior on $\boldsymbol{\phi}$, i.e., effectively no prior at all, the second term above is dropped.)
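Under this flat prior, the MAP estimate of $\boldsymbol{\phi}$ has a simple closed form (a quick derivation; presumably this is the content of equation (9)): setting $\nabla_{\boldsymbol{\phi}_m} \sum_t \frac{1}{2\sigma_m^2}\|\hat{\boldsymbol{\theta}}_{t,m} - \boldsymbol{\phi}_m\|^2 = \sum_t \frac{\boldsymbol{\phi}_m - \hat{\boldsymbol{\theta}}_{t,m}}{\sigma_m^2} = 0$ gives the per-module average

$$\hat{\boldsymbol{\phi}}_m = \frac{1}{T}\sum_{t=1}^{T}\hat{\boldsymbol{\theta}}_{t,m}.$$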

By plugging the MAP estimate of $\boldsymbol{\phi}$ (equation 9) into equation (8), and scaling by $1/T$, we derive the approximate predictive log-likelihood as:

Compare this to equation (5).

The meta update for $\boldsymbol{\phi}$ can be obtained by differentiating (9) w.r.t. $\boldsymbol{\phi}$.

To derive the gradient of Eq. (42) with respect to $\boldsymbol{\sigma}^2$, notice that when $\boldsymbol{\phi}$ is estimated as the MAP on the training subsets of all tasks, it becomes a function of $\boldsymbol{\sigma}^2$.

For a specific individual task $t$:

$$\nabla_{\boldsymbol{\sigma}^2}\ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2)) = \nabla_{\hat{\boldsymbol{\theta}}_t}\ell_t^{\text{val}}(\hat{\boldsymbol{\theta}}_t)\,\nabla_{\boldsymbol{\sigma}^2}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2) \qquad (\text{Eq.}**)$$

We make an approximation and ignore the dependence of $\hat{\boldsymbol{\phi}}$ on $\boldsymbol{\sigma}^2$. Then $\boldsymbol{\phi}$ becomes a constant when computing $\nabla_{\boldsymbol{\sigma}^2}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2)$, and the iMAML-style derivation above applies with $\boldsymbol{\Phi}$ replaced by $\boldsymbol{\sigma}^2$, giving the implicit gradient in Eq. (10).

Derivation:

Rewrite equation (4) as:

$$\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \underset{\boldsymbol{\theta}_t}{\operatorname{argmin}}\;\ell_t^{\text{train}}(\boldsymbol{\theta}_t, \boldsymbol{\sigma}^2, \boldsymbol{\phi})$$

where $\ell_t^{\text{train}} := -\log p(\mathcal{D}_t^{\text{train}} \mid \boldsymbol{\theta}_t) - \log p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi})$

So $\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi})$ is a stationary point of the function $\ell_t^{\text{train}}(\boldsymbol{\theta}_t, \boldsymbol{\sigma}^2, \boldsymbol{\phi})$.

Based on the implicit function theorem,

$$\nabla_{\boldsymbol{\Phi}}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi}) = -\left(\nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\theta}_t}\ell_t^{\text{train}}\right)^{-1}\nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\Phi}}\ell_t^{\text{train}}$$

Plugging the equation above into $(\text{Eq.}**)$ completes the derivation.
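A toy JAX check of this identity (illustrative quadratic objective; `jax.hessian` and `jax.jacfwd` supply the two second-derivative blocks):

```python
# Implicit gradient of the inner solution w.r.t. phi via the implicit function
# theorem: d theta_hat / d phi = -(H_theta_theta)^{-1} H_theta_phi.
import jax
import jax.numpy as jnp

def train_loss(theta, phi, sigma2=0.5):
    data_fit = jnp.sum((theta - jnp.array([1.0, 2.0])) ** 2)  # toy likelihood term
    prior = jnp.sum((theta - phi) ** 2) / (2.0 * sigma2)      # shrinkage term
    return data_fit + prior

phi = jnp.array([0.0, 0.0])
# Closed-form stationary point of this quadratic inner objective:
theta_hat = (2.0 * jnp.array([1.0, 2.0]) + phi / 0.5) / (2.0 + 1.0 / 0.5)

H_tt = jax.hessian(train_loss, argnums=0)(theta_hat, phi)  # d^2 l / d theta^2
H_tp = jax.jacfwd(jax.grad(train_loss, argnums=0), argnums=1)(theta_hat, phi)
dtheta_dphi = -jnp.linalg.solve(H_tt, H_tp)
print(dtheta_dphi)  # here: (1/sigma2) / (2 + 1/sigma2) * I = 0.5 * I
```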

Reptile is a special case of this method when $\sigma_m^2 \rightarrow \infty$ and we choose a learning rate proportional to $\sigma_m^2$ for $\boldsymbol{\phi}_m$. We thus refer to it as σ-Reptile.

$$\Delta_t^{\text{Reptile}} = \boldsymbol{\phi} - \boldsymbol{\theta}_t$$
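One way to verify this (a quick check using the Gaussian prior term above): the only $\boldsymbol{\phi}$-dependent part of the training objectives is $\sum_t \frac{1}{2\sigma_m^2}\|\hat{\boldsymbol{\theta}}_{t,m} - \boldsymbol{\phi}_m\|^2$, whose gradient w.r.t. $\boldsymbol{\phi}_m$ is $\sum_t (\boldsymbol{\phi}_m - \hat{\boldsymbol{\theta}}_{t,m})/\sigma_m^2$. A gradient step with learning rate $\alpha\sigma_m^2/T$ therefore gives

$$\boldsymbol{\phi}_m \leftarrow \boldsymbol{\phi}_m - \frac{\alpha}{T}\sum_{t=1}^{T}\left(\boldsymbol{\phi}_m - \hat{\boldsymbol{\theta}}_{t,m}\right),$$

i.e., the Reptile update averaged over tasks, with the $\sigma_m^2$ factors cancelling.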

Results:

Module discovery in image classification

Module discovery in text-to-speech

Conclusion

This work proposes a hierarchical Bayesian model for meta-learning that places a shrinkage prior on each module to allow learning the extent to which each module should adapt, without a limitation on the adaptation horizon.

Our formulation includes MAML, Reptile, and iMAML as special cases, empirically discovers a small set of task-specific modules in various domains, and shows promising improvement in a practical TTS application with low data and long task adaptation.

As a general modular meta-learning framework, it allows many interesting extensions, including incorporating alternative Bayesian inference algorithms, modular structure learning, and learn-to-optimize methods.

Adaptation may not generalize to a task with fundamentally different characteristics from the training distribution. Applying our method to a new task without examining task similarity runs the risk of transferring inductive bias from meta-training to an out-of-sample task.

Think about the connections among these papers (probabilistic graphical model explanation):

Reference

Y. Chen, A. L. Friesen, F. Behbahani, A. Doucet, D. Budden, M. W. Hoffman, N. de Freitas. Modular Meta-Learning with Shrinkage. NeurIPS 2020. arXiv:1909.05557.
