Modular Meta-Learning with Shrinkage

8-24-2020


Motivation

The ability to meta-learn large models with only a few task-specific components is important in many real-world problems:

  • multi-speaker text-to-speech synthesis

Updating only these task-specific modules then allows the model to be adapted to low-data tasks for as many steps as necessary without risking overfitting.

Existing meta-learning approaches either do not scale to long adaptation horizons or rely on handcrafted task-specific architectures.

A new meta-learning approach based on Bayesian shrinkage is proposed to automatically discover and learn both task-specific and general reusable modules.

The method works well in the few-shot text-to-speech domain.

MAML, iMAML, and Reptile are special cases of the proposed method.

Gradient-based Meta-Learning

  • Task $t$, $\mathcal{T}_t$, is associated with a finite dataset $\mathcal{D}_t = \{\mathbf{x}_{t,n}\}_{n=1}^{N_t}$

  • Each task's data is split into $\mathcal{D}_t^{train}$ and $\mathcal{D}_t^{val}$

  • Meta parameters: $\boldsymbol{\phi} \in \mathbb{R}^D$

  • Task-specific parameters: $\boldsymbol{\theta}_t \in \mathbb{R}^D$

  • Loss function: $\ell(\mathcal{D}_t; \boldsymbol{\theta}_t)$

Algorithm 1 gives the structure of a typical gradient-based meta-learning algorithm, which could be instantiated as:

  • MAML

  • iMAML

  • Reptile

  1. TASKADAPT: task adaptation (inner loop)

  2. The meta-update $\Delta_t$ specifies the contribution of task $t$ to the meta parameters (outer loop)
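
A minimal sketch of this two-loop structure (my illustration, not the paper's code). The hooks `sample_tasks`, `task_adapt`, and `meta_update_delta` are hypothetical names: plugging in different choices of the latter two recovers MAML, Reptile, or iMAML.

```python
# A sketch of Algorithm 1 (generic gradient-based meta-learning).
import torch

def meta_train(sample_tasks, phi, task_adapt, meta_update_delta,
               meta_lr=1e-3, meta_steps=1000, batch_size=4):
    """phi: list of meta-parameter tensors (the initialization shared across tasks)."""
    for _ in range(meta_steps):
        deltas = [torch.zeros_like(p) for p in phi]
        for d_train, d_val in sample_tasks(batch_size):
            theta_t = task_adapt(phi, d_train)                    # inner loop: TASKADAPT
            for acc, d in zip(deltas, meta_update_delta(phi, theta_t, d_val)):
                acc += d                                          # contribution of task t
        with torch.no_grad():
            for p, acc in zip(phi, deltas):                       # outer loop: meta update
                p -= meta_lr * acc / batch_size
    return phi
```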

MAML

  1. Task adaptation: minimize the training loss $\ell_t^{train}(\boldsymbol{\theta}_t) = \ell(\mathcal{D}_t^{train}; \boldsymbol{\theta}_t)$ by gradient descent w.r.t. the task parameters.

  2. Meta parameter update: gradient descent on the validation loss $\ell_t^{val}(\boldsymbol{\theta}_t) = \ell(\mathcal{D}_t^{val}; \boldsymbol{\theta}_t)$, resulting in the meta update (gradient) for task $t$: $\Delta_t^{\text{MAML}} = \nabla_{\boldsymbol{\phi}} \ell_t^{val}(\boldsymbol{\theta}_t(\boldsymbol{\phi}))$

This approach treats the task parameters as a function of the meta parameters, and hence requires back-propagation through the entire $L$-step task adaptation process. When $L$ is large, this becomes computationally prohibitive.

Reptile

Reptile optimizes $\boldsymbol{\theta}_t$ on the entire dataset $\mathcal{D}_t$ and moves $\boldsymbol{\phi}$ toward the adapted task parameters, yielding $\Delta_t^{\text{Reptile}} = \boldsymbol{\phi} - \boldsymbol{\theta}_t$.
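
A sketch of Reptile-style hooks for the generic loop above, assuming a hypothetical `task_loss(params, data)` helper:

```python
# Reptile plug-ins (a sketch, not the authors' code): adapt on the whole task
# dataset with plain SGD, then move phi toward the adapted parameters.
import torch

def reptile_adapt(phi, d_task, inner_lr=1e-2, inner_steps=50):
    theta = [p.detach().clone().requires_grad_(True) for p in phi]
    for _ in range(inner_steps):
        grads = torch.autograd.grad(task_loss(theta, d_task), theta)  # task_loss is assumed
        with torch.no_grad():
            for t, g in zip(theta, grads):
                t -= inner_lr * g
    return theta

def reptile_delta(phi, theta_t, d_val):
    # Delta_t^Reptile = phi - theta_t; note d_val is unused by Reptile.
    return [p.detach() - t.detach() for p, t in zip(phi, theta_t)]
```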

iMAML

iMAML introduces an L2 regularizer $\frac{\lambda}{2}\|\boldsymbol{\theta}_t - \boldsymbol{\phi}\|^2$ to the training loss and optimizes the task parameters on the regularized training loss.

Provided that this task adaptation process converges to a stationary point, implicit differentiation enables the computation of the meta gradient based only on the final solution of the adaptation process:

$$\Delta_t^{\mathrm{iMAML}} = \left(\mathbf{I} + \frac{1}{\lambda}\nabla^2_{\boldsymbol{\theta}_t}\ell_t^{train}(\boldsymbol{\theta}_t)\right)^{-1}\nabla_{\boldsymbol{\theta}_t}\ell_t^{val}(\boldsymbol{\theta}_t)$$
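
A sketch of how this meta gradient can be computed without ever forming the Hessian, via conjugate gradient with Hessian-vector products (my illustration; `hvp`, `imaml_meta_grad`, `train_loss_fn`, `val_loss_fn`, and `lam` are assumed names, and `theta` is assumed to be a single flat 1-D tensor with `requires_grad=True`):

```python
# Approximate Delta_t^iMAML = (I + H/lam)^{-1} grad_val with conjugate gradient.
import torch

def hvp(loss_fn, theta, vec):
    """Hessian-vector product of loss_fn at theta with vec."""
    grad = torch.autograd.grad(loss_fn(theta), theta, create_graph=True)[0]
    return torch.autograd.grad(grad @ vec, theta)[0]

def imaml_meta_grad(theta, train_loss_fn, val_loss_fn, lam, cg_iters=10):
    b = torch.autograd.grad(val_loss_fn(theta), theta)[0]   # grad of validation loss
    x = torch.zeros_like(b)
    r, p = b.clone(), b.clone()
    for _ in range(cg_iters):
        Ap = p + hvp(train_loss_fn, theta, p) / lam          # (I + H/lam) p
        alpha = (r @ r) / (p @ Ap)
        x = x + alpha * p
        r_new = r - alpha * Ap
        beta = (r_new @ r_new) / (r @ r)
        p, r = r_new + beta * p, r_new
    return x                                                 # approximates Delta_t^iMAML
```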

Modular Bayesian Meta-Learning

  • Standard meta-learning: all parameters of a single base learner (a neural network) are updated at the same time.

    • inefficient and prone to overfitting

  • Split the network parameters into two groups:

    • a group varying across tasks

    • a group that is shared across tasks

In this paper (task-independent modules):

We assume that the network parameters can, in general, be partitioned into $M$ disjoint modules:

$$\boldsymbol{\theta}_t = (\boldsymbol{\theta}_{t,1}, \boldsymbol{\theta}_{t,2}, \dots, \boldsymbol{\theta}_{t,m}, \dots, \boldsymbol{\theta}_{t,M})$$

where $\boldsymbol{\theta}_{t,m}$ denotes the parameters of module $m$ for task $t$.

A shrinkage parameter $\sigma_m$ is associated with each module to control the deviation of the task-dependent parameters $\boldsymbol{\theta}_{t,m}$ from the central value $\boldsymbol{\phi}_m$ (the meta parameters).

How to define modules? Modules can correspond to:

  • Layers: $\boldsymbol{\theta}_{t,m}$ could be the weights of the $m$-th layer of a NN for task $t$ [this paper treats each layer as a module]

  • Receptive fields

  • the encoder and decoder in an auto-encoder

  • the heads in a multi-task learning model

  • any other grouping of interest

Hierarchical Bayesian Model

(with a factored probability density):

$$p(\boldsymbol{\theta}_{1:T}, \mathcal{D} \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \prod_{t=1}^{T}\prod_{m=1}^{M}\mathcal{N}(\boldsymbol{\theta}_{t,m} \mid \boldsymbol{\phi}_m, \sigma_m^2\mathbf{I}) \prod_{t=1}^{T} p(\mathcal{D}_t \mid \boldsymbol{\theta}_t)$$

  • $\boldsymbol{\phi}$: shared meta parameters, the initialization of the NN parameters $\boldsymbol{\theta}_t$ for each task

  • The $\boldsymbol{\theta}_{t,m}$ are conditionally independent of those of all other tasks given the "central" parameters; namely, $\boldsymbol{\theta}_{t,m} \sim \mathcal{N}(\boldsymbol{\theta}_{t,m} \mid \boldsymbol{\phi}_m, \sigma_m^2\mathbf{I})$ with mean $\boldsymbol{\phi}_m$ and variance $\sigma_m^2$ ($\mathbf{I}$ is the identity matrix)

  • $\boldsymbol{\sigma}$: shrinkage parameters

  • $\sigma_m^2$: the $m$-th module's scalar shrinkage parameter, which measures the degree to which $\boldsymbol{\theta}_{t,m}$ can deviate from $\boldsymbol{\phi}_m$. If $\sigma_m \approx 0$, then $\boldsymbol{\theta}_{t,m} \approx \boldsymbol{\phi}_m$; as $\sigma_m$ shrinks to 0, the parameters of module $m$ become task independent.

  • Meta parameters: $\boldsymbol{\Phi} = (\boldsymbol{\sigma}^2, \boldsymbol{\phi})$

For values of $\sigma_m^2$ near zero, the difference between the parameters $\boldsymbol{\theta}_{t,m}$ and the mean $\boldsymbol{\phi}_m$ is shrunk to zero, and thus module $m$ becomes task independent.

  • Thus, by learning $\boldsymbol{\sigma}^2$, we can discover which modules are task independent.

  • These independent modules can be re-used at meta-test time

    • reducing the computational burden of adaptation

    • and likely improving generalization
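
As a quick illustration of the shrinkage prior, a toy sketch of the model's generative process for task parameters (module sizes and variances are hypothetical):

```python
# Sampling theta_{t,m} ~ N(phi_m, sigma2_m I): with sigma2_m ~ 0 the module is
# effectively task independent (theta_{t,m} ~ phi_m).
import torch

def sample_task_params(phi, sigma2):
    return [p + s.sqrt() * torch.randn_like(p) for p, s in zip(phi, sigma2)]

phi = [torch.randn(20), torch.randn(10)]           # central values, 2 modules
sigma2 = [torch.tensor(0.5), torch.tensor(1e-8)]   # module 2: sigma ~ 0
theta_t = sample_task_params(phi, sigma2)          # module 2 stays ~ phi[1]
```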

Meta-Learning as Parameter Estimation

Goal: estimate the parameters $\boldsymbol{\phi}$ and $\boldsymbol{\sigma}^2$.

Standard solution: Maximize the marginal likelihood (intractable)

$$p(\mathcal{D} \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2) = \int p(\mathcal{D} \mid \boldsymbol{\theta}) \, p(\boldsymbol{\theta} \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2) \, d\boldsymbol{\theta}$$

Approximation: two principled alternative approaches for parameter estimation, based on maximizing the predictive likelihood over validation subsets (query sets).

Approximation Method 1: Parameter estimation via the predictive likelihood

  • (Relation with MAML and iMAML)

Our goal is to minimize the average negative predictive log-likelihood over $T$ validation tasks:

  • there might be a typo (a missing minus sign) in equation (2) of the paper

Justification:

Based on two assumptions:

  • the training and validation data (support and query sets) are distributed i.i.d. according to some distribution $\nu(\mathcal{D}_t^{train}, \mathcal{D}_t^{val})$

  • The law of large numbers.

As $T \rightarrow \infty$:

$$\ell_{\mathrm{PLL}}(\boldsymbol{\sigma}^2, \boldsymbol{\phi}) \rightarrow \mathbb{E}_{\nu(\mathcal{D}_t^{train})}\left[\mathrm{KL}\left(\nu(\mathcal{D}_t^{val} \mid \mathcal{D}_t^{train}) \,\big\|\, p(\mathcal{D}_t^{val} \mid \mathcal{D}_t^{train}, \boldsymbol{\sigma}^2, \boldsymbol{\phi})\right)\right] + \mathrm{H}\left(\nu(\mathcal{D}_t^{val} \mid \mathcal{D}_t^{train})\right)$$

  • $\nu(\mathcal{D}_t^{val} \mid \mathcal{D}_t^{train})$: the true distribution

  • $p(\mathcal{D}_t^{val} \mid \mathcal{D}_t^{train}, \boldsymbol{\sigma}^2, \boldsymbol{\phi})$: the predictive distribution

Thus, minimizing $\ell_{\mathrm{PLL}}$ w.r.t. the meta parameters corresponds to selecting the predictive distribution that is, on average, closest to the true predictive distribution.

Proof:

Note: $\nu(\mathcal{D}_t^{train})$?

New problem: computing $\ell_{\mathrm{PLL}}$ is not feasible due to the intractable integral in equation (2).

So, we use a simple maximum a posteriori (MAP) approximation of the task parameters:

Where is this from?

For a specific individual task $t$:

$$\begin{aligned} \ell_t^{train} :=& -\log p(\boldsymbol{\theta}_t, \mathcal{D}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi}) \\ =& -\log p(\boldsymbol{\theta}_t \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2) - \log p(\mathcal{D}_t \mid \boldsymbol{\theta}_t) \end{aligned}$$
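
A direct transcription of this objective (a sketch, assuming per-module parameter lists and scalar shrinkage variances; `nll` is the data term $-\log p(\mathcal{D}_t \mid \boldsymbol{\theta}_t)$ computed elsewhere):

```python
# The shrinkage-regularized task objective: data NLL plus the Gaussian
# shrinkage penalty per module. A sketch with assumed shapes.
import math
import torch

def regularized_train_loss(nll, theta, phi, sigma2):
    """nll: -log p(D_t | theta_t). theta, phi: lists of per-module tensors.
    sigma2: list of scalar tensors (one shrinkage variance per module)."""
    loss = nll
    for t, p, s in zip(theta, phi, sigma2):
        loss = loss + ((t - p) ** 2).sum() / (2 * s)       # ||theta_m - phi_m||^2 / (2 sigma2_m)
        loss = loss + 0.5 * t.numel() * torch.log(2 * math.pi * s)  # log-normalizer (matters when learning sigma2)
    return loss
```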

After obtaining the MAP estimates, we can approximate $\ell_{\mathrm{PLL}}(\boldsymbol{\sigma}^2, \boldsymbol{\phi})$ as follows:

  • Individual task adaptation follows from equation (4)

  • The meta update comes from minimizing equation (5)

So minimizing equation (5) is a bi-level optimization problem which requires solving equation (4).

How to solve this problem?

Compute the gradient of the approximate predictive log-likelihood $\ell_t^{val}(\hat{\boldsymbol{\theta}}_t)$ w.r.t. the meta parameters $\boldsymbol{\Phi} = (\boldsymbol{\sigma}^2, \boldsymbol{\phi})$:

$$\nabla_{\boldsymbol{\Phi}} \ell_t^{val}\left(\hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})\right) = \nabla_{\hat{\boldsymbol{\theta}}_t} \ell_t^{val}\left(\hat{\boldsymbol{\theta}}_t\right) \nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi}) \qquad (\text{Eq.}\ast)$$

  • $\nabla_{\hat{\boldsymbol{\theta}}_t} \ell_t^{val}(\hat{\boldsymbol{\theta}}_t)$ is straightforward.

  • $\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$ is not.

Two methods for computing $\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$:

1) Explicit computation of $\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$

(Relation with MAML)

If optimizing $\ell_t^{train}$ requires only a small number of local gradient steps, we can compute the update for the meta parameters $(\boldsymbol{\phi}, \boldsymbol{\sigma}^2)$ by back-propagating through $\hat{\boldsymbol{\theta}}_t$, yielding (6).

This update reduces to that of MAML if $\sigma_m^2 \rightarrow \infty$ for all modules, and is thus denoted $\sigma$-MAML.

Why? (my explanation)

According to equation (4):

$$\ell_t^{train} := -\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t) - \log p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi})$$

$$p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \mathcal{N}(\boldsymbol{\theta}_t \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2) = \frac{1}{\sigma\sqrt{2\pi}} e^{-\frac{(\theta_t - \phi)^2}{2\sigma^2}} \quad \text{(written element-wise)}$$
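
Taking the negative log of this prior makes the reduction explicit (my derivation):

$$-\log p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \sum_{m=1}^{M}\left[\frac{1}{2\sigma_m^2}\left\|\boldsymbol{\theta}_{t,m} - \boldsymbol{\phi}_m\right\|^2 + \frac{D_m}{2}\log\left(2\pi\sigma_m^2\right)\right]$$

As $\sigma_m^2 \rightarrow \infty$ for all $m$, the penalty term vanishes (and the log-normalizer is constant in $\boldsymbol{\theta}_t$), so $\ell_t^{train}$ reduces to the unregularized loss $-\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t)$ and the meta update coincides with $\Delta_t^{\text{MAML}}$.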

2) Implicit computation of $\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi})$

(Relation with iMAML)

  • We are more interested in long adaptation horizons.

  • Implicit Function Theorem: compute the gradient of $\hat{\boldsymbol{\theta}}_t$ w.r.t. $\boldsymbol{\phi}$ and $\boldsymbol{\sigma}^2$

  • $\boldsymbol{\Phi} = (\boldsymbol{\sigma}^2, \boldsymbol{\phi})$: the full vector of meta parameters

  • $\mathbf{H}_{ab} = \nabla^2_{a,b}\ell_t^{train}$; namely, $\mathbf{H}_{\theta_t\theta_t} = \nabla^2_{\theta_t}\ell_t^{train}$ and $\mathbf{H}_{\theta_t\Phi} = \nabla^2_{\theta_t,\Phi}\ell_t^{train}$

  • Derivatives are evaluated at the stationary point $\boldsymbol{\theta}_t = \hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi})$

Derivation:

Rewrite equation (4) as:

$$\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi}) = \underset{\boldsymbol{\theta}_t}{\operatorname{argmin}} \, \ell_t^{train}(\boldsymbol{\theta}_t, \boldsymbol{\sigma}^2, \boldsymbol{\phi})$$

where $\ell_t^{train} := -\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t) - \log p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi})$

So $\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2, \boldsymbol{\phi})$ is a stationary point of the function $\ell_t^{train}(\boldsymbol{\theta}_t, \boldsymbol{\sigma}^2, \boldsymbol{\phi})$.

Based on the implicit function theorem,

$$\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_t(\boldsymbol{\Phi}) = -\left(\nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\theta}_t}\ell_t^{train}\right)^{-1} \nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\Phi}}\ell_t^{train}$$

Plugging this into $(\text{Eq.}\ast)$ completes the derivation.

The meta update for $\boldsymbol{\phi}$ is equivalent to that of iMAML when $\sigma_m$ is constant for all $m$.

Why?

According to equation (4):

$$\ell_t^{train} := -\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t) - \log p(\boldsymbol{\theta}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi})$$

Also, because all modules share a constant variance, $\sigma_m = \sigma$,

we can expand the log-prior term for the task parameters $\boldsymbol{\theta}_t$ in equation (4) and plug in the normal prior assumption as follows.

Meta update of iMAML:

$$\Delta_t^{\mathrm{iMAML}} = \left(\mathbf{I} + \frac{1}{\lambda}\nabla^2_{\boldsymbol{\theta}_t}\ell_t^{train}(\boldsymbol{\theta}_t)\right)^{-1}\nabla_{\boldsymbol{\theta}_t}\ell_t^{val}(\boldsymbol{\theta}_t)$$

If we define the regularization scale $\lambda = \frac{1}{\sigma^2}$ and plug the definition $\ell_t^{train} := -\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t)$ into the meta update of iMAML, the final equation is exactly the same as equation (41).
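
Writing this step out (my derivation): with the regularized objective $\ell_t^{train}(\boldsymbol{\theta}_t, \boldsymbol{\phi}) = -\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t) + \frac{\lambda}{2}\|\boldsymbol{\theta}_t - \boldsymbol{\phi}\|^2$ and $\lambda = 1/\sigma^2$, the second derivatives are

$$\nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\theta}_t}\ell_t^{train} = \nabla^2_{\boldsymbol{\theta}_t}\left(-\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t)\right) + \lambda\mathbf{I}, \qquad \nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\phi}}\ell_t^{train} = -\lambda\mathbf{I},$$

so the implicit function theorem gives

$$\nabla_{\boldsymbol{\phi}}\hat{\boldsymbol{\theta}}_t = \lambda\left(\nabla^2_{\boldsymbol{\theta}_t}\left(-\log p\right) + \lambda\mathbf{I}\right)^{-1} = \left(\mathbf{I} + \frac{1}{\lambda}\nabla^2_{\boldsymbol{\theta}_t}\left(-\log p(\mathcal{D}_t^{train} \mid \boldsymbol{\theta}_t)\right)\right)^{-1},$$

and substituting into $(\text{Eq.}\ast)$ reproduces $\Delta_t^{\mathrm{iMAML}}$ (where iMAML's $\ell_t^{train}$ denotes the unregularized data term).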

Approximation Method 2: Estimate $\boldsymbol{\phi}$ via MAP approximation

  • Relation with Reptile

Integrating out the central parameters $\boldsymbol{\phi}$, and considering that $\boldsymbol{\phi}$ depends on all training data, we can rewrite the predictive likelihood in terms of the joint posterior over $(\boldsymbol{\theta}_{1:T}, \boldsymbol{\phi})$, i.e.

Similarly, this is still intractable, so we use a MAP approximation for both the task parameters and the central meta parameters:

(assuming a flat, non-informative prior on $\boldsymbol{\phi}$, i.e., no prior at all, the second term above is dropped)

By plugging the MAP estimate of $\boldsymbol{\phi}$ (equation (9)) into equation (8) and scaling by $1/T$, we derive the approximate predictive log-likelihood as:

Compare this to equation (5).

The meta update for $\boldsymbol{\phi}$ can be obtained by differentiating (9) w.r.t. $\boldsymbol{\phi}$.

To derive the gradient of equation (42) with respect to $\boldsymbol{\sigma}^2$, notice that when $\boldsymbol{\phi}$ is estimated as the MAP on the training subsets of all tasks, it becomes a function of $\boldsymbol{\sigma}^2$.

For a specific individual task $t$,

$$\nabla_{\boldsymbol{\sigma}^2} \ell_t^{val}\left(\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2)\right) = \nabla_{\hat{\boldsymbol{\theta}}_t} \ell_t^{val}\left(\hat{\boldsymbol{\theta}}_t\right) \nabla_{\boldsymbol{\sigma}^2} \hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2) \qquad (\text{Eq.}\ast\ast)$$

We take an approximation and ignore the dependence of $\hat{\boldsymbol{\phi}}$ on $\boldsymbol{\sigma}^2$. Then $\boldsymbol{\phi}$ becomes a constant when computing $\nabla_{\boldsymbol{\sigma}^2}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2)$, and the derivation above for iMAML applies with $\boldsymbol{\Phi}$ replaced by $\boldsymbol{\sigma}^2$, giving the implicit gradient in equation (10).

Derivation: identical to the derivation for $(\text{Eq.}\ast)$ above with $\boldsymbol{\Phi}$ replaced by $\boldsymbol{\sigma}^2$; the implicit function theorem gives $\nabla_{\boldsymbol{\sigma}^2}\hat{\boldsymbol{\theta}}_t(\boldsymbol{\sigma}^2) = -\left(\nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\theta}_t}\ell_t^{train}\right)^{-1}\nabla^2_{\boldsymbol{\theta}_t,\boldsymbol{\sigma}^2}\ell_t^{train}$, which we plug into $(\text{Eq.}\ast\ast)$.

Reptile is a special case of this method when $\sigma_m^2 \rightarrow \infty$ and we choose a learning rate proportional to $\sigma_m^2$ for $\phi_m$. We thus refer to it as $\sigma$-Reptile.

$$\Delta_t^{\text{Reptile}} = \boldsymbol{\phi} - \boldsymbol{\theta}_t$$
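
Filling in why (my reading): differentiating the shrinkage term of equation (9) w.r.t. $\boldsymbol{\phi}$ gives, per task,

$$\nabla_{\boldsymbol{\phi}} \, \frac{1}{2\sigma^2}\left\|\boldsymbol{\theta}_t - \boldsymbol{\phi}\right\|^2 = \frac{1}{\sigma^2}(\boldsymbol{\phi} - \boldsymbol{\theta}_t),$$

so a gradient step on $\boldsymbol{\phi}$ with learning rate $\eta \propto \sigma^2$ moves $\boldsymbol{\phi}$ toward $\boldsymbol{\theta}_t$ by an amount independent of $\sigma^2$; and as $\sigma_m^2 \rightarrow \infty$ the prior no longer constrains task adaptation, so the procedure matches Reptile's $\Delta_t^{\text{Reptile}} = \boldsymbol{\phi} - \boldsymbol{\theta}_t$.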

Results:

Module discovery for image classification

  • (9-layer network: 4 conv layers, 4 batch-norm layers, and a linear output layer)

Module discovery for text-to-speech

Conclusion

This work proposes a hierarchical Bayesian model for meta-learning that places a shrinkage prior on each module to allow learning the extent to which each module should adapt, without a limitation on the adaptation horizon.

Our formulation includes MAML, Reptile, and iMAML as special cases, empirically discovers a small set of task-specific modules in various domains, and shows promising improvement in a practical TTS application with low data and long task adaptation.

As a general modular meta-learning framework, it allows many interesting extensions, including incorporating alternative Bayesian inference algorithms, modular structure learning, and learn-to-optimize methods.

Adaptation may not generalize to a task with fundamentally different characteristics from the training distribution. Applying the method to a new task without examining task similarity runs the risk of transferring inductive bias from meta-training to an out-of-sample task.

Think about the connections among these papers (from a probabilistic graphical model perspective):

Reference

  • Modular Meta-Learning with Shrinkage: https://arxiv.org/pdf/1909.05557.pdf

  • Similar to the ANIL paper: https://openreview.net/forum?id=rkgMkCEtPB

  • Meta-Learning Probabilistic Inference For Prediction: https://openreview.net/pdf?id=HkxStoC5F7

  • Probabilistic Model-Agnostic Meta-Learning: https://arxiv.org/abs/1806.02817

  • Amortized Bayesian Meta-Learning: https://openreview.net/pdf?id=rkgpy3C5tX

  • Bayesian TAML: https://openreview.net/forum?id=rkeZIJBYvr

  • On KL, MLE, and cross-entropy:

    • https://wiseodd.github.io/techblog/2017/01/26/kl-mle/

    • https://glassboxmedicine.com/2019/12/07/connections-log-likelihood-cross-entropy-kl-divergence-logistic-regression-and-neural-networks/

    • https://www.jessicayung.com/maximum-likelihood-as-minimising-kl-divergence/

    • https://www.hongliangjie.com/2012/07/12/maximum-likelihood-as-minimize-kl-divergence/
[Fig. 1: Modular meta-learning]
[Figure: Bayesian shrinkage graphical model]
[Figure: the average predictive log-likelihood]