The ability to meta-learn large models with only a few task-specific components is important in many real-world problems:
multi-speaker text-to-speech synthesis
Updating only these task-specific modules allows the model to be adapted to low-data tasks for as many steps as necessary without risking overfitting.
Existing meta-learning approaches either do not scale to long adaptation or else rely on handcrafted task-specific architectures.
A new meta-learning approach based on Bayesian shrinkage is proposed to automatically discover and learn both task-specific and general reusable modules.
The method works well in the few-shot text-to-speech domain.
MAML, iMAML, and Reptile are special cases of the proposed method.
Gradient-based Meta-Learning
Each task $t = 1, \ldots, T$ is associated with a finite dataset $\mathcal{D}_t = \{x_{t,n}\}_{n=1}^{N_t}$
Each task's data is split into training and validation subsets: $\mathcal{D}_t^{\text{train}}, \mathcal{D}_t^{\text{val}}$
Meta parameters: $\phi \in \mathbb{R}^D$
Task-specific parameters: $\theta_t \in \mathbb{R}^D$
Loss function: $\ell(\mathcal{D}_t; \theta_t)$
Algorithm 1 gives the structure of a typical gradient-based meta-learning algorithm, instances of which include:
MAML
iMAML
Reptile
TASKADAPT: task adaptation (inner loop)
The meta-update $\Delta_t$ specifies the contribution of task $t$ to the meta parameters. (outer loop)
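As a rough illustration (not the paper's code), here is a minimal Python sketch of this Algorithm 1 structure. The names `task_adapt` and `meta_delta` are hypothetical callbacks standing in for TASKADAPT and the per-task meta-update $\Delta_t$:

```python
import numpy as np

def meta_learn(tasks, phi, task_adapt, meta_delta, meta_lr=1e-3, n_iters=1000):
    """Skeleton of Algorithm 1: an outer loop over meta-iterations, a
    task-adaptation inner loop per sampled task, and an averaged meta-update."""
    rng = np.random.default_rng(0)
    for _ in range(n_iters):
        batch = rng.choice(len(tasks), size=min(4, len(tasks)), replace=False)
        deltas = []
        for t in batch:
            theta_t = task_adapt(phi, tasks[t])                # TASKADAPT (inner loop)
            deltas.append(meta_delta(phi, theta_t, tasks[t]))  # Delta_t for task t
        phi = phi - meta_lr * np.mean(deltas, axis=0)          # meta update (outer loop)
    return phi
```

MAML, iMAML, and Reptile differ only in how `task_adapt` and `meta_delta` are filled in, as the sketches below show.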
MAML
task adaptation: minimize the training loss $\ell_t^{\text{train}}(\theta_t) = \ell(\mathcal{D}_t^{\text{train}}; \theta_t)$ by gradient descent w.r.t. the task parameters
meta parameter update: gradient descent on the validation loss $\ell_t^{\text{val}}(\theta_t) = \ell(\mathcal{D}_t^{\text{val}}; \theta_t)$, resulting in the meta update (gradient) for task $t$: $\Delta_t^{\text{MAML}} = \nabla_\phi \ell_t^{\text{val}}(\theta_t(\phi))$
This approach treats the task parameters as a function of the meta parameters, and hence requires back-propagation through the entire $L$-step task adaptation process. When $L$ is large, this becomes computationally prohibitive.
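A minimal JAX sketch of this computation (toy least-squares losses and hypothetical names, not the paper's code). Note that `jax.grad` differentiates through all `L` unrolled inner steps, which is exactly the cost being described:

```python
import jax
import jax.numpy as jnp

def train_loss(theta, x, y):   # stands in for l_t^train
    return jnp.mean((x @ theta - y) ** 2)

def val_loss(theta, x, y):     # stands in for l_t^val
    return jnp.mean((x @ theta - y) ** 2)

def adapt(phi, x_tr, y_tr, inner_lr=0.1, L=5):
    # L steps of gradient descent on the training loss, initialized at phi.
    theta = phi
    for _ in range(L):
        theta = theta - inner_lr * jax.grad(train_loss)(theta, x_tr, y_tr)
    return theta

def maml_delta(phi, x_tr, y_tr, x_val, y_val):
    # Delta_t^MAML = grad_phi l_t^val(theta_t(phi)); JAX back-propagates
    # through the entire unrolled L-step adaptation process.
    outer = lambda p: val_loss(adapt(p, x_tr, y_tr), x_val, y_val)
    return jax.grad(outer)(phi)
```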
Reptile
Reptile optimizes $\theta_t$ on the entire dataset $\mathcal{D}_t$ and moves $\phi$ towards the adapted task parameters, yielding $\Delta_t^{\text{Reptile}} = \phi - \theta_t$
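Under the same toy setup as the MAML sketch above, the Reptile meta-update needs no back-propagation through adaptation at all:

```python
def reptile_delta(phi, x_all, y_all, inner_lr=0.1, L=100):
    # Adapt theta_t on the full task dataset D_t (reusing `adapt` from the
    # MAML sketch above), then move phi toward the adapted parameters.
    theta = adapt(phi, x_all, y_all, inner_lr=inner_lr, L=L)
    return phi - theta  # Delta_t^Reptile = phi - theta_t
```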
iMAML
iMAML introduces an L2 regularizer $\frac{\lambda}{2}\|\theta_t - \phi\|^2$ to the training loss, and optimizes the task parameters on the regularized training loss.
Provided that this task adaptation process converges to a stationary point, implicit differentiation enables computing the meta gradient from the final solution of the adaptation process alone: $\Delta_t^{\text{iMAML}} = \left(I + \frac{1}{\lambda}\nabla^2_{\theta_t}\ell_t^{\text{train}}(\theta_t)\right)^{-1}\nabla_{\theta_t}\ell_t^{\text{val}}(\theta_t)$
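A sketch of the iMAML meta-gradient under the same toy setup. The explicit Hessian and dense linear solve are for clarity only; a practical implementation would use conjugate gradient with Hessian-vector products:

```python
def imaml_delta(phi, x_tr, y_tr, x_val, y_val, lam=1.0, inner_lr=0.1, L=100):
    # Inner loop: gradient descent on the L2-regularized training loss
    # l_t^train(theta) + (lam/2) ||theta - phi||^2.
    reg_loss = lambda th: train_loss(th, x_tr, y_tr) + 0.5 * lam * jnp.sum((th - phi) ** 2)
    theta = phi
    for _ in range(L):
        theta = theta - inner_lr * jax.grad(reg_loss)(theta)
    # Implicit meta-gradient at the (approximately) stationary solution:
    # Delta = (I + (1/lam) H)^{-1} grad_theta l_t^val, H = Hessian of l_t^train.
    H = jax.hessian(train_loss)(theta, x_tr, y_tr)
    g_val = jax.grad(val_loss)(theta, x_val, y_val)
    return jnp.linalg.solve(jnp.eye(theta.shape[0]) + H / lam, g_val)
```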
Modular Bayesian Meta-Learning
Standard meta-learning: all parameters of a single base learner (a neural network) are updated at the same time.
inefficient and prone to overfitting
Split the network parameters into two groups:
a group varying across tasks
a group that is shared across tasks
In this paper, the shared group is referred to as the task-independent modules.
In general, we assume that the network parameters can be partitioned into $M$ disjoint modules:
$\theta_t = (\theta_{t,1}, \theta_{t,2}, \ldots, \theta_{t,m}, \ldots, \theta_{t,M})$
where $\theta_{t,m}$ denotes the parameters of module $m$ for task $t$.
A shrinkage parameter $\sigma_m$ is associated with each module to control the deviation of the task-dependent parameters $\theta_{t,m}$ from the central value $\phi_m$ (meta parameters).
How to define modules? Modules can correspond to:
Layers: $\theta_{t,m}$ could be the weights of the $m$-th layer of a NN for task $t$ [this paper treats each layer as a module]
$\phi$: shared meta parameters, the initialization of the NN parameters $\theta_t$ for each task
$\theta_{t,m}$ are conditionally independent of the parameters of all other tasks given some "central" parameters. Namely: $\theta_{t,m} \sim \mathcal{N}(\theta_{t,m} \mid \phi_m, \sigma_m^2 I)$ with mean $\phi_m$ and variance $\sigma_m^2$ ($I$ is the identity matrix); see the sketch after this list.
$\sigma$: shrinkage parameters.
$\sigma_m^2$: the scalar shrinkage parameter of the $m$-th module, which measures the degree to which $\theta_{t,m}$ can deviate from $\phi_m$. If $\sigma_m \approx 0$, then $\theta_{t,m} \approx \phi_m$; as $\sigma_m$ shrinks to 0, the parameters of module $m$ become task independent.
Meta parameters: $\Phi = (\sigma^2, \phi)$
For values of $\sigma_m^2$ near zero, the difference between the parameters $\theta_{t,m}$ and the mean $\phi_m$ will shrink to zero, and thus module $m$ will become task independent.
Thus, by learning $\sigma^2$, we can discover which modules are task independent.
These independent modules can be re-used at meta-test time
reducing the computational burden of adaptation
and likely improving generalization
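A small sketch of the per-module shrinkage prior described above (hypothetical function and variable names; module parameters represented as a list of arrays):

```python
import jax.numpy as jnp

def neg_log_prior(theta_modules, phi_modules, sigma2):
    # -log p(theta_t | phi, sigma^2) summed over the M modules:
    # module m contributes ||theta_m - phi_m||^2 / (2 sigma_m^2)
    # plus the Gaussian normalizer (D_m / 2) log(2 pi sigma_m^2).
    total = 0.0
    for m, (th, ph) in enumerate(zip(theta_modules, phi_modules)):
        total += jnp.sum((th - ph) ** 2) / (2.0 * sigma2[m])
        total += 0.5 * th.size * jnp.log(2.0 * jnp.pi * sigma2[m])
    return total
```

Adding this term to the task training loss, with $\sigma^2$ learned, is what lets small-$\sigma_m^2$ modules collapse onto $\phi_m$ and become reusable.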
Meta-Learning as Parameter Estimation
Goal: estimate the parameters $\phi$ and $\sigma^2$.
Standard solution: maximize the marginal likelihood (intractable):
$p(\mathcal{D} \mid \phi, \sigma^2) = \int p(\mathcal{D} \mid \theta)\, p(\theta \mid \phi, \sigma^2)\, d\theta$
Approximation: two principled alternative approaches for parameter estimation, based on maximizing the predictive likelihood over validation subsets (query sets).
Approximation Method 1: Parameter estimation via the predictive likelihood
(Relation with MAML and iMAML)
Our goal is to minimize the average negative predictive log-likelihood over $T$ validation tasks:
(There might be a typo here, a missing minus sign, in equation (2) of the paper.)
Justification:
Based on two assumptions:
the training and validation data (support and query sets) are distributed i.i.d. according to some distribution $\nu(\mathcal{D}_t^{\text{train}}, \mathcal{D}_t^{\text{val}})$
Thus, minimizing $\ell^{\text{PLL}}$ w.r.t. the meta parameters corresponds to selecting the predictive distribution that is, on average, approximately closest to the true predictive distribution.
If optimizing $\ell_t^{\text{train}}$ requires only a small number of local gradient steps, we can compute the update for the meta parameters $(\phi, \sigma^2)$ with back-propagation through $\hat{\theta}_t$, yielding (6).
This update reduces to that of MAML if $\sigma_m^2 \to \infty$ for all modules and is thus denoted $\sigma$-MAML.
Why? (my explanation)
According to equation (4):
$\ell_t^{\text{train}} := -\log p(\mathcal{D}_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
$p(\theta_t \mid \sigma^2, \phi) = \mathcal{N}(\theta_t \mid \phi, \sigma^2) = \frac{1}{\sigma\sqrt{2\pi}}\, e^{-\frac{(\theta_t - \phi)^2}{2\sigma^2}}$
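To finish the "why" explicitly, here is the per-module expansion of the log-prior term (a worked step consistent with the prior defined above, not quoted from the paper):

```latex
-\log p(\theta_t \mid \sigma^2, \phi)
  = \sum_{m=1}^{M}\left[
      \frac{\|\theta_{t,m} - \phi_m\|^2}{2\sigma_m^2}
      + \frac{D_m}{2}\log\!\left(2\pi\sigma_m^2\right)
    \right]
% As sigma_m^2 -> infinity for every module, the quadratic penalty vanishes and
% the remaining terms are constant in theta_t, so minimizing l_t^train reduces
% to minimizing -log p(D_t^train | theta_t): exactly MAML's inner loop.
```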
2) Implicit computation of $\nabla_\Phi \hat{\theta}_t(\Phi)$.
(Relation with iMAML)
We are more interested in long adaptation horizons.
Implicit Function Theorem (compute the gradient of $\hat{\theta}_t$ w.r.t. $\phi$ and $\sigma^2$)
The meta update for $\phi$ is equivalent to that of iMAML when $\sigma_m$ is constant for all $m$.
Why?
According to equation (4):
$\ell_t^{\text{train}} := -\log p(\mathcal{D}_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
Also, because all modules share a constant variance, $\sigma_m = \sigma$,
we can expand the log-prior term for the task parameters $\theta_t$ in equation (4) and plug in the normal prior assumption as follows:
Meta-update of iMAML: $\Delta_t^{\text{iMAML}} = \left(I + \frac{1}{\lambda}\nabla^2_{\theta_t}\ell_t^{\text{train}}(\theta_t)\right)^{-1}\nabla_{\theta_t}\ell_t^{\text{val}}(\theta_t)$
If we define the regularization scale $\lambda = \frac{1}{\sigma^2}$ and plug the definition $\ell_t^{\text{train}} := -\log p(\mathcal{D}_t^{\text{train}} \mid \theta_t)$ into the meta-update of iMAML, the final equation is exactly the same as equation (41).
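Written out (a sketch of the step the notes describe): with a shared variance and $\lambda = \frac{1}{\sigma^2}$, the regularized objective matches iMAML's,

```latex
\ell_t^{\text{train}}(\theta_t)
  = -\log p(\mathcal{D}_t^{\text{train}} \mid \theta_t)
    + \frac{1}{2\sigma^2}\,\|\theta_t - \phi\|^2 + \text{const}
  = -\log p(\mathcal{D}_t^{\text{train}} \mid \theta_t)
    + \frac{\lambda}{2}\,\|\theta_t - \phi\|^2 + \text{const}
% so the adapted stationary point and the implicit meta-gradient coincide with
% iMAML's, recovering Delta_t^iMAML above.
```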
Approximation Method 2: Estimate ϕ via MAP approximation
Relation with Reptile
Integrating out the central parameter $\phi$, and noting that $\phi$ depends on all training data, we can rewrite the predictive likelihood in terms of the joint posterior over $(\theta_{1:T}, \phi)$, i.e.
This is similarly intractable, so we use a MAP approximation for both the task parameters and the central meta parameters:
(assuming a flat, non-informative prior on $\phi$, the second term above is dropped)
By plugging the MAP of $\phi$ (Eq. (9)) into Eq. (8) and scaling by $1/T$, we derive the approximate predictive log-likelihood as:
Compare this to Eq. (5).
The meta update for $\phi$ can be obtained by differentiating (9) w.r.t. $\phi$ (a small sketch follows below).
To derive the gradient of Eq. (42) with respect to $\sigma^2$, notice that when $\phi$ is estimated as the MAP on the training subsets of all tasks, it becomes a function of $\sigma^2$.
We take an approximation and ignore the dependence of $\hat{\phi}$ on $\sigma^2$. Then $\phi$ becomes a constant when computing $\nabla_{\sigma^2}\hat{\theta}_t(\sigma^2)$, and the derivation above for iMAML applies with $\Phi$ replaced by $\sigma^2$, giving the implicit gradient in Eq. (10).
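A minimal sketch of the MAP step for $\phi$ and the resulting Reptile-like per-task direction (hypothetical names; flat hyperprior on $\phi$ and a shared scalar variance, as assumed above):

```python
def map_phi(thetas):
    # Flat prior on phi plus Gaussian p(theta_t | phi, sigma^2): setting
    # sum_t (phi - theta_t) / sigma^2 = 0 gives the MAP as the task mean.
    return jnp.mean(jnp.stack(thetas), axis=0)

def phi_direction(phi, thetas, sigma2):
    # Gradient of sum_t -log p(theta_t | phi, sigma^2) w.r.t. phi: each task
    # contributes (phi - theta_t) / sigma^2, a rescaled Reptile update phi - theta_t.
    return sum((phi - th) / sigma2 for th in thetas)
```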
Derivation:
Rewrite Eq. (4) as:
$\hat{\theta}_t(\sigma^2, \phi) = \arg\min_{\theta_t} \ell_t^{\text{train}}(\theta_t, \sigma^2, \phi)$
where $\ell_t^{\text{train}} := -\log p(\mathcal{D}_t^{\text{train}} \mid \theta_t) - \log p(\theta_t \mid \sigma^2, \phi)$
So $\hat{\theta}_t(\sigma^2, \phi)$ is a stationary point of the function $\ell_t^{\text{train}}(\theta_t, \sigma^2, \phi)$.
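The implicit-function-theorem step this derivation sets up, written out (a sketch, not quoted from the paper):

```latex
% Stationarity of the adapted solution:
\nabla_{\theta_t}\,\ell_t^{\text{train}}\!\big(\hat{\theta}_t(\sigma^2,\phi),\,\sigma^2,\,\phi\big) = 0
% Differentiating this identity w.r.t. Phi = (sigma^2, phi) and rearranging:
\frac{d\hat{\theta}_t}{d\Phi}
  = -\Big[\nabla^2_{\theta_t}\ell_t^{\text{train}}\Big]^{-1}
     \nabla_{\Phi}\nabla_{\theta_t}\ell_t^{\text{train}}
% which is the gradient of theta_hat w.r.t. the meta parameters needed for the
% implicit meta-update, evaluated only at the final adapted solution.
```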
This work proposes a hierarchical Bayesian model for meta-learning that places a shrinkage prior on each module to allow learning the extent to which each module should adapt, without a limitation on the adaptation horizon.
Our formulation includes MAML, Reptile, and iMAML as special cases, empirically discovers a small set of task-specific modules in various domains, and shows promising improvement in a practical TTS application with low data and long task adaptation.
As a general modular meta-learning framework, it allows many interesting extensions, including incorporating alternative Bayesian inference algorithms, modular structure learning, and learn-to-optimize methods.
Adaptation may not generalize to a task whose characteristics differ fundamentally from the training distribution. Applying our method to a new task without examining task similarity risks transferring bias induced during meta-training to an out-of-sample task.
To think about: the connections among these papers (a probabilistic graphical model explanation):