MAML (Bi-level optimization)
Problem setting
Meta-Train
Notations:
θ*_ML: the optimal meta-learned parameters
φ_i: task-specific parameters for task i
M: the number of tasks in meta-train; i indexes the tasks
D_i^tr: support set and D_i^test: query set of task i
L(φ, D): loss function taking a parameter vector and a dataset
φ_i = Alg(θ, D_i^tr) = θ − α∇_θ L(θ, D_i^tr): one (or multiple) steps of gradient descent initialized at θ [inner level of MAML]
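The inner-level update Alg(θ, D_i^tr) above can be sketched in a few lines. This is a minimal illustration, assuming a toy quadratic loss L(θ, c) = ½‖θ − c‖² (so the "dataset" is just a target vector c and ∇_θ L = θ − c); `inner_update` and `grad_loss` are names chosen here, not from any library.

```python
import numpy as np

def inner_update(theta, D_tr, grad_loss, alpha=0.1, steps=1):
    """One or more gradient-descent steps on the support set, initialized at theta.
    grad_loss(theta, D) returns the gradient of L(theta, D) w.r.t. theta."""
    phi = theta.copy()
    for _ in range(steps):
        phi = phi - alpha * grad_loss(phi, D_tr)
    return phi

# Toy quadratic loss: L(θ, c) = ½||θ − c||², so ∇_θ L = θ − c.
grad_loss = lambda theta, c: theta - c
theta = np.zeros(2)
phi_i = inner_update(theta, np.array([1.0, -1.0]), grad_loss, alpha=0.5)
# One step with α = 0.5 moves θ halfway toward c: phi_i = [0.5, -0.5]
```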
Meta-test
Gradient-based Meta-Learning
Task t, denoted T_t, is associated with a finite dataset D_t = {x_{t,n}}_{n=1}^{N_t}
Each task's data is split into D_t^train and D_t^val
Meta parameters φ ∈ R^D
Task-specific parameters θ_t ∈ R^D
Loss function ℓ(D_t; θ_t)
Algorithm 1 gives the structure of a typical meta-learning algorithm, built from two components:
TASKADAPT: task adaptation (the inner loop)
The meta-update Δ_t specifies the contribution of task t to the meta parameters (the outer loop)
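The two-component structure above can be sketched as a generic meta-training loop. This is a skeleton under assumptions: `meta_train`, `task_adapt`, and `meta_delta` are placeholder names, and the toy instantiation uses a quadratic task loss ½‖θ − c_t‖² whose inner loop converges exactly to θ_t = c_t (a Reptile-style choice of Δ_t).

```python
import numpy as np

def meta_train(phi, tasks, task_adapt, meta_delta, beta=0.1, meta_steps=200):
    """Skeleton of a typical meta-learning algorithm: run TASKADAPT (inner loop)
    per task, then step phi along the averaged meta-updates Δ_t (outer loop)."""
    for _ in range(meta_steps):
        deltas = [meta_delta(phi, task_adapt(phi, D_t), D_t) for D_t in tasks]
        phi = phi - beta * np.mean(deltas, axis=0)
    return phi

# Toy instantiation: each task's loss is ½||θ − c_t||², so full adaptation
# returns θ_t = c_t; the Reptile-style meta-update is Δ_t = φ − θ_t.
tasks = [np.array([1.0, 0.0]), np.array([-1.0, 2.0])]
phi_star = meta_train(np.zeros(2), tasks,
                      task_adapt=lambda phi, c: c,
                      meta_delta=lambda phi, theta, D: phi - theta)
# phi converges toward the mean of the task optima, [0, 1]
```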
MAML
task adaptation: minimize the training loss ℓ_t^train(θ_t) = ℓ(D_t^train; θ_t) by gradient descent w.r.t. the task parameters
meta parameter update: gradient descent on the validation loss ℓ_t^val(θ_t) = ℓ(D_t^val; θ_t), giving the meta update (gradient) for task t: Δ_t^MAML = ∇_φ ℓ_t^val(θ_t(φ))
This approach treats the task parameters as a function of the meta parameters, and hence requires back-propagation through the entire L-step task adaptation process. When the number of inner steps L is large, this becomes computationally prohibitive.
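For a toy quadratic loss the chain rule through the inner step can be worked out by hand, which makes the "task parameters as a function of meta parameters" point concrete. This is an illustrative sketch, assuming one inner step and losses ℓ_train(θ) = ½‖θ − c_tr‖², ℓ_val(θ) = ½‖θ − c_val‖² (the names `maml_grad` and `fo_grad` are chosen here for illustration):

```python
import numpy as np

# One inner step: θ_t(φ) = φ − α(φ − c_tr), so the Jacobian dθ_t/dφ = (1 − α)·I
# and the full MAML meta gradient is
#   Δ_t^MAML = ∇_φ ℓ_val(θ_t(φ)) = (1 − α)(θ_t − c_val).
alpha = 0.5
c_tr, c_val = np.array([1.0, 0.0]), np.array([0.0, 1.0])
phi = np.zeros(2)

theta_t = phi - alpha * (phi - c_tr)          # inner adaptation step
maml_grad = (1 - alpha) * (theta_t - c_val)   # back-prop through the inner step
fo_grad = theta_t - c_val                     # first-order approx. drops (1 − α)
```

With L inner steps the Jacobian factor compounds, which is exactly why back-propagating through long adaptation trajectories becomes expensive (and why first-order approximations are attractive).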
Reptile
Reptile optimizes θ_t on the entire dataset D_t and moves φ towards the adapted task parameters, yielding Δ_t^Reptile = φ − θ_t
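The Reptile meta-update requires no differentiation through the inner loop at all. A minimal sketch, again assuming a toy quadratic loss with gradient ∇L(θ, c) = θ − c (the function name `reptile_delta` is chosen here, not from the Reptile paper's code):

```python
import numpy as np

def reptile_delta(phi, D_t, grad_loss, alpha=0.1, steps=5):
    """Adapt θ_t on the whole task dataset by gradient descent,
    then return the Reptile meta-update Δ_t = φ − θ_t."""
    theta = phi.copy()
    for _ in range(steps):
        theta = theta - alpha * grad_loss(theta, D_t)
    return phi - theta

# Toy quadratic loss (assumed for illustration): ∇L(θ, c) = θ − c.
grad_loss = lambda theta, c: theta - c
phi = np.zeros(2)
delta = reptile_delta(phi, np.array([2.0, 0.0]), grad_loss, alpha=0.5, steps=3)
# The outer update φ ← φ − β·Δ_t then moves φ toward the adapted θ_t.
```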
iMAML
iMAML introduces an L2 regularizer (λ/2)‖θ_t − φ‖² to the training loss, and optimizes the task parameters on the regularized training loss.
Provided that this task adaptation process converges to a stationary point, implicit differentiation enables the computation of the meta gradient based only on the final solution of the adaptation process: Δ_t^iMAML = (I + (1/λ)∇²_{θ_t} ℓ_t^train(θ_t))^{-1} ∇_{θ_t} ℓ_t^val(θ_t)
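In practice the inverse in Δ_t^iMAML is applied as a linear solve (or approximated with conjugate gradient) rather than formed explicitly. A minimal dense-matrix sketch, assuming the Hessian ∇²ℓ_t^train and validation gradient at the adapted θ_t are given; the toy values use ∇²ℓ_train = I so the answer is easy to check by hand:

```python
import numpy as np

def imaml_delta(grad_val, hess_train, lam):
    """Implicit meta-gradient: solve (I + (1/λ)∇²ℓ_train) x = ∇ℓ_val
    at the (approximately) converged task parameters θ_t."""
    d = grad_val.shape[0]
    A = np.eye(d) + hess_train / lam
    return np.linalg.solve(A, grad_val)

# Toy case: ∇²ℓ_train = I and λ = 1, so A = 2I and Δ = ∇ℓ_val / 2.
g_val = np.array([1.0, -2.0])
delta = imaml_delta(g_val, np.eye(2), lam=1.0)  # → [0.5, -1.0]
```

Note that only the final θ_t enters the computation; the optimization path taken by the inner loop is irrelevant, which is the key computational advantage over back-propagating through L adaptation steps.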
Derivative process
FO-MAML: approximates the meta gradient by dropping the Jacobian of the inner loop, Δ_t ≈ ∇_{θ_t} ℓ_t^val(θ_t)
Reptile: Δ_t^Reptile = φ − θ_t requires no second-order information at all
How iMAML generalizes them: replacing the inverse term (I + (1/λ)∇²_{θ_t} ℓ_t^train)^{-1} in Δ_t^iMAML with the identity recovers the FO-MAML update, while the exact inverse gives the true meta gradient at the adapted solution without back-propagating through the optimization path