MAML, FO-MAML, Reptile

Gradient-based (optimization-based) meta-learning algorithms

MAML (Bi-level Optimization)

Problem setting

Meta-Train

The goal of meta-learning is to learn meta-parameters that produce good task-specific parameters after adaptation.

Notations:

  • $\theta^*_{ML}$: optimal meta-learned parameters

  • $\phi_i$: task-specific parameters for task $i$

  • $M$: the number of tasks in meta-train; $i$: the task index

  • $\mathcal{D}^{tr}_i$: support set and $\mathcal{D}^{test}_i$: query set of task $i$

  • $\mathcal{L}(\phi, \mathcal{D})$: loss function evaluated with parameter vector $\phi$ on dataset $\mathcal{D}$

  • $\phi_i = \mathcal{A}lg(\theta, \mathcal{D}^{tr}_{i}) = \theta - \alpha\nabla_{\theta}\mathcal{L}(\theta, \mathcal{D}^{tr}_i)$: one (or multiple) steps of gradient descent initialized at $\theta$ [inner level of MAML]
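
With this notation, meta-training solves a bi-level problem: the inner level is $\mathcal{A}lg$ above, and the outer level evaluates each adapted $\phi_i$ on its query set,

$$\theta^*_{ML} = \arg\min_{\theta} \sum_{i=1}^{M} \mathcal{L}\big(\mathcal{A}lg(\theta, \mathcal{D}^{tr}_i),\, \mathcal{D}^{test}_i\big)$$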

Meta-Test

At meta-test time, the meta-learned $\theta^*_{ML}$ is adapted to a new task with the same procedure, $\phi_{new} = \mathcal{A}lg(\theta^*_{ML}, \mathcal{D}^{tr}_{new})$, and evaluated on that task's query set.

Gradient-based Meta-Learning

  • Task $t$, $\mathcal{T}_t$, is associated with a finite dataset $\mathcal{D}_{t}=\left\{\mathbf{x}_{t, n}\right\}_{n=1}^{N_t}$

  • Each task's dataset is split into $\mathcal{D}_{t}^{train}$ and $\mathcal{D}_{t}^{val}$

  • Meta parameters $\boldsymbol{\phi} \in \mathbb{R}^{D}$

  • Task-specific parameters $\boldsymbol{\theta}_{t} \in \mathbb{R}^{D}$

  • Loss function $\ell\left(\mathcal{D}_{t} ; \boldsymbol{\theta}_{t}\right)$

Algorithm 1 shows the structure of a typical gradient-based meta-learning algorithm, which could be:

  • MAML

  • iMAML

  • Reptile

  1. TASKADAPT: task adaptation (inner loop)

  2. The meta-update $\Delta_t$ specifies the contribution of task $t$ to the meta parameters (outer loop). The algorithms below differ only in how $\Delta_t$ is computed, as in the sketch after this list.
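
To make the skeleton concrete, here is a minimal JAX sketch on a toy linear-regression loss. The names `task_adapt` and `meta_delta` are hypothetical placeholders for the per-algorithm choices described below, and the loss, learning rate `beta`, and averaging scheme are illustrative assumptions rather than details from any of the papers.

```python
import jax
import jax.numpy as jnp

def loss(theta, data):
    # Toy squared-error loss standing in for l(D_t; theta_t).
    x, y = data
    return jnp.mean((x @ theta - y) ** 2)

def meta_step(phi, tasks, task_adapt, meta_delta, beta=0.01):
    """One outer-loop iteration of the Algorithm 1 skeleton.

    tasks: iterable of (d_train, d_val) pairs.
    task_adapt, meta_delta: per-algorithm choices (MAML, Reptile, iMAML).
    """
    deltas = []
    for d_train, d_val in tasks:
        theta_t = task_adapt(phi, d_train)                       # TASKADAPT (inner loop)
        deltas.append(meta_delta(phi, theta_t, d_train, d_val))  # Delta_t (outer loop)
    # Aggregate the per-task contributions and take one meta step.
    return phi - beta * sum(deltas) / len(deltas)
```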

MAML

  1. Task adaptation: minimize the training loss $\ell_t^{train}(\boldsymbol{\theta}_t)=\ell(\mathcal{D}_t^{train}; \boldsymbol{\theta}_t)$ by gradient descent w.r.t. the task parameters.

  2. Meta parameter update: gradient descent on the validation loss $\ell_t^{val}(\boldsymbol{\theta}_t)=\ell(\mathcal{D}_t^{val}; \boldsymbol{\theta}_t)$, resulting in the meta update (gradient) for task $t$: $\Delta_t^{\text{MAML}} = \nabla_{\phi}\ell_t^{\text{val}}(\boldsymbol{\theta}_t(\phi))$

This approach treats the task parameters as a function of the meta parameters and hence requires back-propagating through the entire $L$-step task adaptation process, which becomes computationally prohibitive when $L$ is large.
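
A minimal JAX sketch of this meta-gradient, reusing the toy `loss` above; `alpha` and `L` are illustrative assumptions. Because `jax.grad` is taken through all $L$ inner steps, this is exactly the expensive back-propagation just described.

```python
def task_adapt_maml(phi, d_train, alpha=0.1, L=5):
    # L steps of gradient descent on the training loss, starting from phi.
    theta = phi
    for _ in range(L):
        theta = theta - alpha * jax.grad(loss)(theta, d_train)
    return theta

def maml_delta(phi, theta_t, d_train, d_val):
    # Delta_t^MAML = grad_phi l_val(theta_t(phi)). Adaptation is re-run inside
    # the differentiated function (theta_t is not reused) so that jax.grad
    # back-propagates through every inner step.
    def val_after_adapt(p):
        return loss(task_adapt_maml(p, d_train), d_val)
    return jax.grad(val_after_adapt)(phi)
```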

Reptile

Reptile optimizes $\boldsymbol{\theta}_t$ on the entire dataset $\mathcal{D}_t$ and moves $\phi$ towards the adapted task parameters, yielding $\Delta_t^{\text{Reptile}}=\phi-\boldsymbol{\theta}_t$, as in the sketch below.
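
The same toy setup in JAX. Note that Reptile uses no train/val split, so adaptation would run on the whole $\mathcal{D}_t$; the unused `d_train`/`d_val` arguments are kept only to match the skeleton's `meta_delta` signature.

```python
def task_adapt_reptile(phi, d_task, alpha=0.1, L=5):
    # Plain SGD on the whole task dataset D_t, starting from phi.
    theta = phi
    for _ in range(L):
        theta = theta - alpha * jax.grad(loss)(theta, d_task)
    return theta

def reptile_delta(phi, theta_t, d_train=None, d_val=None):
    # Delta_t^Reptile = phi - theta_t: descending on it moves phi toward theta_t.
    return phi - theta_t
```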

iMAML

iMAML introduces an L2 regularizer $\frac{\lambda}{2}\|\boldsymbol{\theta}_t-\phi\|^2$ to the training loss and optimizes the task parameters on the regularized training loss.

Provided that this task adaptation process converges to a stationary point, implicit differentiation enables the computation of the meta gradient based only on the final solution of the adaptation process:

$$\Delta_{t}^{\mathrm{iMAML}}=\left(\mathbf{I}+\frac{1}{\lambda} \nabla_{\boldsymbol{\theta}_{t}}^{2} \ell_{t}^{\mathrm{train}}\left(\boldsymbol{\theta}_{t}\right)\right)^{-1} \nabla_{\boldsymbol{\theta}_{t}} \ell_{t}^{\mathrm{val}}\left(\boldsymbol{\theta}_{t}\right)$$
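
In practice the inverse is never formed: the linear system is solved approximately with conjugate gradient using only Hessian-vector products. A minimal JAX sketch under the same toy assumptions; `lam` and `maxiter` are illustrative choices, and `theta_t` is assumed to already solve the regularized adaptation problem.

```python
from jax.scipy.sparse.linalg import cg

def imaml_delta(phi, theta_t, d_train, d_val, lam=1.0):
    # Solve (I + (1/lam) * H) x = grad l_val for x, where H is the Hessian of
    # the training loss at theta_t, without ever materializing H.
    g_val = jax.grad(loss)(theta_t, d_val)

    def matvec(v):
        # Hessian-vector product via forward-over-reverse differentiation.
        hv = jax.jvp(jax.grad(lambda t: loss(t, d_train)), (theta_t,), (v,))[1]
        return v + hv / lam

    x, _ = cg(matvec, g_val, maxiter=20)  # approximate Delta_t^iMAML
    return x
```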

Derivation of the meta-updates

FO-MAML

First-order MAML drops the second-order terms in $\Delta_t^{\text{MAML}}$: it ignores the Jacobian of $\boldsymbol{\theta}_t$ w.r.t. $\phi$ and uses the validation gradient at the adapted parameters directly, $\Delta_t^{\text{FO-MAML}} = \nabla_{\boldsymbol{\theta}_t}\ell_t^{\text{val}}(\boldsymbol{\theta}_t)$.

Reptile

For an inner loop of $L$ SGD steps with iterates $\boldsymbol{\theta}^{(1)}=\phi, \dots, \boldsymbol{\theta}^{(L)}$, we have $\boldsymbol{\theta}_t = \phi - \alpha\sum_{l=1}^{L}\nabla\ell(\boldsymbol{\theta}^{(l)})$, so $\Delta_t^{\text{Reptile}} = \phi - \boldsymbol{\theta}_t = \alpha\sum_{l=1}^{L}\nabla\ell(\boldsymbol{\theta}^{(l)})$: the Reptile update is the scaled sum of the gradients along the adaptation trajectory.

How iMAML generalizes them

In the iMAML meta gradient, the strength $\lambda$ of the proximal regularizer controls how much curvature information is retained, as the expansion below shows.
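
Expanding the matrix inverse in $\Delta_t^{\mathrm{iMAML}}$ as a Neumann series (valid for large enough $\lambda$):

$$\left(\mathbf{I}+\tfrac{1}{\lambda}\nabla^2_{\boldsymbol{\theta}_t}\ell_t^{\mathrm{train}}\right)^{-1} = \mathbf{I}-\tfrac{1}{\lambda}\nabla^2_{\boldsymbol{\theta}_t}\ell_t^{\mathrm{train}}+O(\lambda^{-2})$$

As $\lambda \to \infty$ the iMAML update reduces to $\nabla_{\boldsymbol{\theta}_t}\ell_t^{\text{val}}(\boldsymbol{\theta}_t)$, exactly the FO-MAML update, while finite $\lambda$ keeps a curvature correction. The same proximal term $\frac{\lambda}{2}\|\boldsymbol{\theta}_t-\phi\|^2$ also pulls $\boldsymbol{\theta}_t$ back toward $\phi$, which loosely mirrors the direction $\phi-\boldsymbol{\theta}_t$ that Reptile descends.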
