Modular Meta-Learning with Shrinkage
8-24-2020
Motivation
The ability to meta-learn large models with only a few task-specific components is important in many real-world problems:
multi-speaker text-to-speech synthesis
Updating only these task-specific modules then allows the model to be adapted to low-data tasks for as many steps as necessary without risking overfitting.
Existing meta-learning approaches either do not scale to long adaptation or else rely on handcrafted task-specific architectures.
A new meta-learning approach based on Bayesian shrinkage is proposed to automatically discover and learn both task-specific and general reusable modules.
It works well in the few-shot text-to-speech domain.
MAML, iMAML, and Reptile are special cases of the proposed method.
Gradient-based Meta-Learning
Task $t$, $\mathcal{T}_t$, is associated with a finite dataset $D_t = \{x_{t,n}\}_{n=1}^{N_t}$
Task $t$, $\mathcal{T}_t$: $D_t^{\text{train}}$, $D_t^{\text{val}}$
meta parameters $\phi \in \mathbb{R}^D$
task-specific parameters $\theta_t \in \mathbb{R}^D$
loss function $\ell(D_t; \theta_t)$

Algorithm 1 is a structure of a typical meta-learning algorithm, which could be:
MAML
iMAML
Reptile
TASKADAPT: task adaptation (inner loop)
The meta update $\Delta_t$ specifies the contribution of task $t$ to the meta parameters (outer loop).
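As a rough illustration, this generic structure can be sketched in code. This is my own minimal sketch, not the paper's implementation: `task_adapt` and `meta_update` are hypothetical stand-ins for TASKADAPT and $\Delta_t$, demonstrated here with a Reptile-style update on 1-D toy tasks whose losses are $(\theta - c_t)^2$.

```python
import random

def meta_learn(phi, tasks, task_adapt, meta_update, meta_lr=0.1, steps=100):
    """Skeleton of Algorithm 1: sample tasks, run the inner loop (TASKADAPT),
    then aggregate each task's contribution Delta_t into the meta parameters."""
    for _ in range(steps):
        deltas = []
        for task in random.sample(tasks, len(tasks)):   # a batch of tasks
            theta_t = task_adapt(phi, task)             # inner loop
            deltas.append(meta_update(phi, theta_t))    # Delta_t for task t
        phi = phi - meta_lr * sum(deltas) / len(deltas)  # outer loop
    return phi

# Toy instantiation: scalar tasks with loss (theta - c_t)^2.
def task_adapt(phi, c, lr=0.2, n_steps=20):
    theta = phi
    for _ in range(n_steps):
        theta -= lr * 2 * (theta - c)   # gradient step on (theta - c)^2
    return theta

reptile_delta = lambda phi, theta: phi - theta   # Delta_t^Reptile

phi_star = meta_learn(0.0, [1.0, 3.0], task_adapt, reptile_delta)
# phi_star ends up near the mean of the task optima
```

With the Reptile-style $\Delta_t$, the meta parameters converge toward the average of the adapted task parameters, which for these two toy tasks is near 2.0.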
MAML
task adaptation: minimize the training loss $\ell_t^{\text{train}}(\theta_t) = \ell(D_t^{\text{train}}; \theta_t)$ by gradient descent w.r.t. the task parameters
meta parameter update: gradient descent on the validation loss $\ell_t^{\text{val}}(\theta_t) = \ell(D_t^{\text{val}}; \theta_t)$, resulting in the meta update (gradient) for task $t$: $\Delta_t^{\text{MAML}} = \nabla_\phi \ell_t^{\text{val}}(\theta_t(\phi))$
This approach treats the task parameters as a function of the meta parameters, and hence requires back-propagation through the entire L-step task adaptation process. When L is large, this becomes computationally prohibitive.
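To make the back-propagation-through-adaptation concrete, here is a tiny hand-derived sketch (my own, on a 1-D quadratic, not from the paper): for $\ell(\theta) = (\theta - c)^2$ each inner gradient step is affine in $\theta$, so $d\theta_L/d\phi = (1 - 2\alpha)^L$, and the MAML meta gradient follows by the chain rule.

```python
def adapt(phi, c_train, alpha=0.1, L=3):
    """L inner gradient-descent steps on the train loss (theta - c_train)^2."""
    theta = phi
    for _ in range(L):
        theta = theta - alpha * 2 * (theta - c_train)
    return theta

def maml_meta_grad(phi, c_train, c_val, alpha=0.1, L=3):
    """d/dphi of l_val(theta_L(phi)), back-propagated through the L steps.
    For this quadratic, each step multiplies d theta / d phi by (1 - 2*alpha)."""
    theta = adapt(phi, c_train, alpha, L)
    dtheta_dphi = (1 - 2 * alpha) ** L
    return 2 * (theta - c_val) * dtheta_dphi   # chain rule

# Sanity check against a central finite difference of the composed objective.
def val_after_adapt(phi, c_train=1.0, c_val=2.0):
    return (adapt(phi, c_train) - c_val) ** 2

h = 1e-6
fd_grad = (val_after_adapt(h) - val_after_adapt(-h)) / (2 * h)
analytic_grad = maml_meta_grad(0.0, 1.0, 2.0)
```

The factor $(1 - 2\alpha)^L$ is exactly what makes long horizons expensive and unstable in general networks: the full Jacobian of the adaptation path must be carried through all $L$ steps.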
Reptile
Reptile optimizes $\theta_t$ on the entire dataset $D_t$ and moves $\phi$ towards the adapted task parameters, yielding $\Delta_t^{\text{Reptile}} = \phi - \theta_t$
iMAML
iMAML introduces an L2 regularizer $\frac{\lambda}{2}\|\theta_t - \phi\|^2$ to the training loss, and optimizes the task parameters on the regularized training loss.
Provided that this task adaptation process converges to a stationary point, implicit differentiation enables the computation of the meta gradient based only on the final solution of the adaptation process: $\Delta_t^{\text{iMAML}} = \left(I + \frac{1}{\lambda}\nabla^2_{\theta_t}\ell_t^{\text{train}}(\theta_t)\right)^{-1}\nabla_{\theta_t}\ell_t^{\text{val}}(\theta_t)$
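The implicit meta gradient can be checked on a 1-D quadratic where the regularized inner problem has a closed-form minimizer. This is my own illustrative sketch of the scalar case, not the paper's code: with $\ell_t^{\text{train}}(\theta) = (\theta - c)^2$ the Hessian is the constant 2, and the iMAML formula matches the exact derivative of $\ell^{\text{val}}(\hat{\theta}(\phi))$.

```python
def theta_hat(phi, c_train, lam):
    """argmin_theta (theta - c_train)^2 + (lam/2)*(theta - phi)^2, closed form."""
    return (2 * c_train + lam * phi) / (2 + lam)

def imaml_meta_grad(phi, c_train, c_val, lam):
    """Scalar version of Delta_t^iMAML = (I + H/lam)^{-1} grad l_val(theta_hat)."""
    th = theta_hat(phi, c_train, lam)
    grad_val = 2 * (th - c_val)    # d l_val / d theta at theta_hat
    hess_train = 2.0               # d^2 l_train / d theta^2 (constant here)
    return grad_val / (1 + hess_train / lam)

# The implicit formula should equal the exact chain rule:
# d theta_hat / d phi = lam / (2 + lam)
exact = 2 * (theta_hat(0.5, 1.0, 1.0) - 2.0) * (1.0 / (2 + 1.0))
implicit = imaml_meta_grad(0.5, 1.0, 2.0, lam=1.0)
```

Note that only the stationary point `theta_hat` enters the formula: no adaptation path is stored, which is why iMAML scales to long adaptation horizons.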
Modular Bayesian Meta-Learning
Standard meta-learning: all parameters of a single base learner (a neural network) are updated at the same time.
This is inefficient and prone to overfitting.
Split the network parameters into two groups:
a group varying across tasks
a group that is shared across tasks
In this paper, we assume that the network parameters can be partitioned into $M$ disjoint modules:
$\theta_t = (\theta_{t,1}, \theta_{t,2}, \ldots, \theta_{t,m}, \ldots, \theta_{t,M})$
where $\theta_{t,m}$ denotes the parameters of module $m$ for task $t$.

A shrinkage parameter $\sigma_m$ is associated with each module to control the deviation of the task-dependent parameters $\theta_{t,m}$ from the central value $\phi_m$ (the meta parameters).
How to define modules? Modules can correspond to:
Layers: $\theta_{t,m}$ could be the weights of the $m$-th layer of a NN for task $t$ [this paper treats each layer as a module]
Receptive fields
the encoder and decoder in an auto-encoder
the heads in a multi-task learning model
any other grouping of interest
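One way to picture the module grouping and its per-module shrinkage variance is a small helper that evaluates $\log p(\theta_t \mid \phi, \sigma^2) = \sum_m \log \mathcal{N}(\theta_{t,m} \mid \phi_m, \sigma_m^2 I)$. This is my own sketch; the module names and the list-of-floats parameter representation are hypothetical.

```python
import math

def shrinkage_log_prior(theta_t, phi, sigma2):
    """Factored Gaussian log-prior over modules: each module m has one
    scalar variance sigma2[m] shared by all parameters in that module."""
    total = 0.0
    for m, params in theta_t.items():           # m indexes modules (e.g. layers)
        s2 = sigma2[m]
        for tp, pp in zip(params, phi[m]):
            total += -0.5 * math.log(2 * math.pi * s2) - (tp - pp) ** 2 / (2 * s2)
    return total

phi = {"conv1": [0.1, -0.2], "head": [1.0]}
theta_t = {"conv1": [0.1, -0.2], "head": [1.3]}   # only the head deviates
sigma2 = {"conv1": 1e-4, "head": 1.0}             # conv1 is strongly shrunk
lp = shrinkage_log_prior(theta_t, phi, sigma2)
```

With a tiny $\sigma_m^2$, even a small deviation of that module from $\phi_m$ incurs a huge log-prior penalty, which is exactly the mechanism that pins task-independent modules to the shared values.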
Hierarchical Bayesian Model

(with a factored probability density):
$\phi$: shared meta parameters, the initialization of the network parameters $\theta_t$ for each task
$\theta_{t,m}$: conditionally independent of the parameters of all other tasks given some "central" parameters, namely $\theta_{t,m} \sim \mathcal{N}(\theta_{t,m} \mid \phi_m, \sigma_m^2 I)$ with mean $\phi_m$ and variance $\sigma_m^2$ ($I$ is the identity matrix)
$\sigma$: shrinkage parameters. $\sigma_m^2$ is the scalar shrinkage parameter of module $m$; it measures the degree to which $\theta_{t,m}$ can deviate from $\phi_m$. If $\sigma_m \approx 0$ then $\theta_{t,m} \approx \phi_m$: as $\sigma_m$ shrinks to 0, the parameters of module $m$ become task independent.
meta parameters: $\Phi = (\sigma^2, \phi)$
For values of $\sigma_m^2$ near zero, the difference between the parameters $\theta_{t,m}$ and the mean $\phi_m$ shrinks to zero, and module $m$ becomes task independent.
Thus, by learning $\sigma^2$, we can discover which modules are task independent.
These independent modules can be re-used at meta-test time
reducing the computational burden of adaptation
and likely improving generalization
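The shrinkage effect can be seen directly in a scalar MAP problem (my own toy example, not from the paper): minimizing $(\theta - c)^2 + \frac{1}{2\sigma^2}(\theta - \phi)^2$ has a closed-form solution that interpolates between $\phi$ (as $\sigma^2 \to 0$: task independent) and the task optimum $c$ (as $\sigma^2 \to \infty$: fully adapted).

```python
def map_adapt(phi, c, sigma2):
    """argmin_theta (theta - c)^2 + (theta - phi)^2 / (2 * sigma2), closed form:
    setting the derivative to zero gives theta = (2*c*sigma2 + phi) / (2*sigma2 + 1)."""
    return (2 * c * sigma2 + phi) / (2 * sigma2 + 1)

shrunk  = map_adapt(0.0, 1.0, 1e-8)   # sigma^2 ~ 0: pinned to phi = 0
adapted = map_adapt(0.0, 1.0, 1e8)    # sigma^2 -> inf: reaches task optimum c = 1
```

Learning $\sigma^2$ thus amounts to learning, per module, where on this interpolation the adapted parameters should sit.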
Meta-Learning as Parameter Estimation
Goal: estimate the parameters ϕ and σ2 .
Standard solution: Maximize the marginal likelihood (intractable)
$p(D \mid \phi, \sigma^2) = \int p(D \mid \theta)\, p(\theta \mid \phi, \sigma^2)\, d\theta$
Approximation: Two principled alternative approaches for parameter estimation based on maximizing the predictive likelihood over validation subsets. (query set)
Approximation Method 1: Parameter estimation via the predictive likelihood
(Relation with MAML and iMAML)
Our goal is to minimize the average negative predictive log-likelihood over $T$ validation tasks:
$\ell_{\text{PLL}}(\sigma^2, \phi) = -\frac{1}{T}\sum_{t=1}^{T} \log p(D_t^{\text{val}} \mid D_t^{\text{train}}, \sigma^2, \phi)$
there might be a typo here (missing a minus) in equation (2) in the paper
Justification:
Based on two assumptions:
the training and validation data (support and query sets) are distributed i.i.d. according to some distribution $\nu(D_t^{\text{train}}, D_t^{\text{val}})$
The law of large numbers.
as $T \to \infty$:
$\ell_{\text{PLL}}(\sigma^2, \phi) \to \mathbb{E}_{\nu(D_t^{\text{train}})}\Big[\mathrm{KL}\big(\nu(D_t^{\text{val}} \mid D_t^{\text{train}}) \,\big\|\, p(D_t^{\text{val}} \mid D_t^{\text{train}}, \sigma^2, \phi)\big)\Big] + H\big(\nu(D_t^{\text{val}} \mid D_t^{\text{train}})\big)$
$\nu(D_t^{\text{val}} \mid D_t^{\text{train}})$: the true distribution
$p(D_t^{\text{val}} \mid D_t^{\text{train}}, \sigma^2, \phi)$: the predictive distribution
Thus, minimizing $\ell_{\text{PLL}}$ w.r.t. the meta parameters corresponds to selecting the predictive distribution that is, on average, approximately closest (in KL divergence) to the true predictive distribution.
Proof:
Note: $\nu(D_t^{\text{train}})$?
New problem: computing $\ell_{\text{PLL}}$ is infeasible due to the intractable integral in Eq. (2).
So we use a simple maximum a posteriori (MAP) approximation of the task parameters (Eq. (4) in the paper):
$\hat{\theta}_t(\sigma^2, \phi) = \arg\min_{\theta_t} \ell_t^{\text{train}}(\theta_t, \sigma^2, \phi), \quad \ell_t^{\text{train}} := -\log p(D_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
Where is this from?
(1)
For a specific individual task t :
After having the MAP estimates, we can approximate $\ell_{\text{PLL}}(\sigma^2, \phi)$ as follows (Eq. (5) in the paper):
$\ell_{\text{PLL}}(\sigma^2, \phi) \approx \frac{1}{T}\sum_{t=1}^{T} \ell_t^{\text{val}}\big(\hat{\theta}_t(\sigma^2, \phi)\big)$
Individual task adaptation follows from equation (4)
Meta updating from minimizing equation (5)
So minimizing equation (5) is a bi-level optimization problem which requires solving equation (4).
How to solve this problem?
Compute the gradient of the approximate predictive log-likelihood $\ell_t^{\text{val}}(\hat{\theta}_t)$ w.r.t. the meta parameters $\Phi = (\sigma^2, \phi)$ via the chain rule: $\nabla_\Phi \ell_t^{\text{val}}(\hat{\theta}_t(\Phi)) = \nabla_{\hat{\theta}_t} \ell_t^{\text{val}}(\hat{\theta}_t)\, \nabla_\Phi \hat{\theta}_t(\Phi)$ (Eq. $*$)
$\nabla_{\hat{\theta}_t} \ell_t^{\text{val}}(\hat{\theta}_t)$ is straightforward to compute.
$\nabla_\Phi \hat{\theta}_t(\Phi)$ is not.
Two methods for computing $\nabla_\Phi \hat{\theta}_t(\Phi)$:
1) Explicit computation of $\nabla_\Phi \hat{\theta}_t(\Phi)$.
(Relation with MAML)
(6): $\Delta_t = \nabla_\Phi \ell_t^{\text{val}}\big(\hat{\theta}_t(\Phi)\big)$, back-propagated through the adaptation steps
If optimizing $\ell_t^{\text{train}}$ requires only a small number of local gradient steps, we can compute the update for the meta parameters $(\phi, \sigma^2)$ with back-propagation through $\hat{\theta}_t$, yielding (6).
This update reduces to that of MAML if $\sigma_m^2 \to \infty$ for all modules and is thus denoted $\sigma$-MAML.
Why? (my explanation)
According to equation (4):
$\ell_t^{\text{train}} := -\log p(D_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
$p(\theta_t \mid \sigma^2, \phi) = \mathcal{N}(\theta_t \mid \phi, \sigma^2) = \frac{1}{\sqrt{2\pi\sigma^2}}\, e^{-\frac{(\theta_t - \phi)^2}{2\sigma^2}}$ (written for a scalar parameter), so as $\sigma^2 \to \infty$ the prior term flattens out and $\ell_t^{\text{train}}$ reduces to the plain training loss used by MAML.
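A one-line numerical check of this reduction (my own sketch; `data_grad` stands in for $\nabla_\theta[-\log p(D_t^{\text{train}} \mid \theta_t)]$, and the values are hypothetical):

```python
def sigma_maml_inner_grad(theta, phi, sigma2, data_grad):
    """Gradient of l_train = -log p(D|theta) - log N(theta | phi, sigma2):
    the data-likelihood term plus the shrinkage term (theta - phi) / sigma2."""
    return data_grad + (theta - phi) / sigma2

near_maml   = sigma_maml_inner_grad(0.5, 0.0, 1e12, data_grad=2.0)  # huge sigma^2
regularized = sigma_maml_inner_grad(0.5, 0.0, 0.01, data_grad=2.0)  # tight prior
```

With a huge $\sigma^2$ the shrinkage term vanishes and the inner gradient equals MAML's; with a tight prior the shrinkage term dominates and pulls $\theta$ back toward $\phi$.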
2) Implicit computation of $\nabla_\Phi \hat{\theta}_t(\Phi)$.
(Relation with iMAML)
We are more interested in long adaptation horizons.
Implicit Function Theorem (computes the gradient of $\hat{\theta}_t$ w.r.t. $\phi$ and $\sigma^2$):
(7): $\nabla_\Phi \hat{\theta}_t(\Phi) = -H_{\theta_t \theta_t}^{-1} H_{\theta_t \Phi}$
$\Phi = (\sigma^2, \phi)$: full vector of meta parameters
$H_{ab} = \nabla^2_{a,b}\, \ell_t^{\text{train}}$, namely $H_{\theta_t \theta_t} = \nabla^2_{\theta_t} \ell_t^{\text{train}}$ and $H_{\theta_t \Phi} = \nabla^2_{\theta_t, \Phi} \ell_t^{\text{train}}$
Derivatives are evaluated at the stationary point $\theta_t = \hat{\theta}_t(\sigma^2, \phi)$
Derivation:
Rewrite Eq. (4) as:
$\hat{\theta}_t(\sigma^2, \phi) = \arg\min_{\theta_t} \ell_t^{\text{train}}(\theta_t, \sigma^2, \phi)$
where $\ell_t^{\text{train}} := -\log p(D_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
So $\hat{\theta}_t(\sigma^2, \phi)$ is a stationary point of $\ell_t^{\text{train}}(\theta_t, \sigma^2, \phi)$.
By the implicit function theorem,
$\nabla_\Phi \hat{\theta}_t(\Phi) = -\big(\nabla^2_{\theta_t, \theta_t} \ell_t^{\text{train}}\big)^{-1} \nabla^2_{\theta_t, \Phi} \ell_t^{\text{train}}$
Plug the equation above into (Eq. $*$). END.
The meta update for $\phi$ is equivalent to that of iMAML when $\sigma_m$ is constant for all $m$.
Why?
According to equation (4):
$\ell_t^{\text{train}} := -\log p(D_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
Because all modules share a constant variance, $\sigma_m = \sigma$,
we can expand the log-prior term for the task parameters $\theta_t$ in equation (4), and plug in the normal prior assumption as follows:
$\ell_t^{\text{train}} = -\log p(D_t^{\text{train}} \mid \theta_t) + \frac{1}{2\sigma^2}\|\theta_t - \phi\|^2 + \text{const}$
Meta update of iMAML: $\Delta_t^{\text{iMAML}} = \left(I + \frac{1}{\lambda}\nabla^2_{\theta_t}\ell_t^{\text{train}}(\theta_t)\right)^{-1}\nabla_{\theta_t}\ell_t^{\text{val}}(\theta_t)$
If we define the regularization scale $\lambda = \frac{1}{\sigma^2}$ and plug the definition $\ell_t^{\text{train}} := -\log p(D_t^{\text{train}} \mid \theta_t)$ into the meta update of iMAML, the final equation is exactly the same as equation (41).
Approximation Method 2: Estimate ϕ via MAP approximation
Relation with Reptile
Integrating out the central parameters $\phi$, and noting that $\phi$ depends on all training data, we can rewrite the predictive likelihood in terms of the joint posterior over $(\theta_{1:T}, \phi)$, i.e.

Similarly, this is still intractable, so we use a MAP approximation for both the task parameters and the central meta parameters:
$(\hat{\theta}_{1:T}, \hat{\phi}) = \arg\max_{\theta_{1:T},\, \phi} \; \sum_{t=1}^{T} \log p(D_t^{\text{train}}, \theta_t \mid \sigma^2, \phi) + \log p(\phi)$
(assuming a flat, non-informative prior on $\phi$, the second term above is dropped)
By plugging the MAP of $\phi$ (Eq. (9)) into Eq. (8) and scaling by $1/T$, we derive the approximate predictive log-likelihood as:

compare this to Equ.5
The meta update for ϕ can be obtained by differentiating (9) w.r.t. ϕ .
To derive the gradient of Eq. (42) with respect to σ2 , notice that when ϕ is estimated as the MAP on the training subsets of all tasks, it becomes a function of σ2 .

For a specific individual task t ,
$\nabla_{\sigma^2} \ell_t^{\text{val}}\big(\hat{\theta}_t(\sigma^2)\big) = \nabla_{\hat{\theta}_t} \ell_t^{\text{val}}(\hat{\theta}_t)\, \nabla_{\sigma^2} \hat{\theta}_t(\sigma^2)$ (Eq. $**$)
We take an approximation and ignore the dependence of $\hat{\phi}$ on $\sigma^2$. Then $\phi$ becomes a constant when computing $\nabla_{\sigma^2} \hat{\theta}_t(\sigma^2)$, and the derivation above for iMAML applies with $\Phi$ replaced by $\sigma^2$, giving the implicit gradient in Eq. (10).
Derivation:
Rewrite Eq. (4) as:
$\hat{\theta}_t(\sigma^2, \phi) = \arg\min_{\theta_t} \ell_t^{\text{train}}(\theta_t, \sigma^2, \phi)$
where $\ell_t^{\text{train}} := -\log p(D_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
So $\hat{\theta}_t(\sigma^2, \phi)$ is a stationary point of $\ell_t^{\text{train}}(\theta_t, \sigma^2, \phi)$.
By the implicit function theorem, treating $\phi$ as a constant,
$\nabla_{\sigma^2} \hat{\theta}_t(\sigma^2) = -\big(\nabla^2_{\theta_t, \theta_t} \ell_t^{\text{train}}\big)^{-1} \nabla^2_{\theta_t, \sigma^2} \ell_t^{\text{train}}$
Plug the equation above into (Eq. $**$). END.
Reptile is a special case of this method when $\sigma_m^2 \to \infty$ and we choose a learning rate proportional to $\sigma_m^2$ for $\phi_m$. We thus refer to it as $\sigma$-Reptile.
$\Delta_t^{\text{Reptile}} = \phi - \theta_t$
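To see why (my own quick derivation): the only term in the objective that couples $\phi_m$ to the tasks is the log-prior, whose gradient is

$$\nabla_{\phi_m}\left[-\sum_{t=1}^{T}\log \mathcal{N}\big(\theta_{t,m} \mid \phi_m, \sigma_m^2 I\big)\right] = \sum_{t=1}^{T}\frac{\phi_m - \theta_{t,m}}{\sigma_m^2}$$

so a gradient step on $\phi_m$ with learning rate $\eta_m \propto \sigma_m^2$ contributes $\phi_m - \theta_{t,m}$ per task, matching $\Delta_t^{\text{Reptile}}$; and as $\sigma_m^2 \to \infty$ the prior no longer constrains task adaptation, matching Reptile's unregularized inner loop.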
Results:
Module discovery on image classification
(9-layer network (4 conv layers, 4 batch-norm layers, and a linear output layer))
Similar to ANIL paper (https://openreview.net/forum?id=rkgMkCEtPB)

Module discovery on text-to-speech

Conclusion
This work proposes a hierarchical Bayesian model for meta-learning that places a shrinkage prior on each module to allow learning the extent to which each module should adapt, without a limitation on the adaptation horizon.
Our formulation includes MAML, Reptile, and iMAML as special cases, empirically discovers a small set of task-specific modules in various domains, and shows promising improvement in a practical TTS application with low data and long task adaptation.
As a general modular meta-learning framework, it allows many interesting extensions, including incorporating alternative Bayesian inference algorithms, modular structure learning, and learn-to-optimize methods.
Adaptation may not generalize to a task with fundamentally different characteristics from the training distribution. Applying our method to a new task without examining task similarity runs the risk of transferring inductive bias from meta-training to an out-of-sample task.
Think of the connection among those papers (probabilistic graphical model explanation):
Meta-Learning Probabilistic Inference For Prediction (https://openreview.net/pdf?id=HkxStoC5F7)
Probabilistic model-agnostic meta-learning (https://arxiv.org/abs/1806.02817)
Amortized Bayesian Meta-Learning (https://openreview.net/pdf?id=rkgpy3C5tX)
Bayesian TAML (https://openreview.net/forum?id=rkeZIJBYvr)