Meta-Learning with Implicit Gradient

NIPS 2019


8/18/2020

Motivation

MAML has some important limitations: its meta-learning process requires higher-order derivatives, imposes a non-trivial computational and memory burden, and can suffer from vanishing gradients. These limitations make it harder to scale optimization-based meta-learning methods to tasks involving medium or large datasets, or those that require many inner-loop optimization steps.

Contributions

The development of the implicit MAML (iMAML) algorithm, an approach for optimization-based meta-learning with deep neural networks that removes the need for differentiating through the optimization path.

The algorithm aims to learn a set of parameters such that an optimization algorithm that is initialized at and regularized to this parameter vector leads to good generalization for a variety of learning tasks.

Few-shot supervised learning and MAML (Bi-level optimization)

Notations:

Proximal regularization in the inner level

In cases like ill-conditioned optimization landscapes and medium-shot learning, we may want to take many inner gradient steps, which exposes two challenges in MAML: the full optimization path must be stored and differentiated through, and the dependence of the adapted parameters on the initialization shrinks as the number of steps grows.

A more explicitly regularized algorithm is therefore considered: instead of a fixed number of gradient steps, the inner level solves for the minimizer of the task training loss plus a proximal term $\frac{\lambda}{2}\|\phi' - \theta\|^2$ that keeps the adapted parameters close to the meta-parameters.
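A minimal sketch of this proximally regularized inner loop in PyTorch, assuming a generic `support_loss` callable that evaluates $\hat{\mathcal{L}}_i$ on the task's support set (all names here are illustrative, not from the authors' code):

```python
import torch

def proximal_inner_loop(meta_params, support_loss, lam=1.0, inner_lr=1e-2, num_steps=50):
    """Approximately solve  min_phi  L_hat(phi) + (lam / 2) * ||phi - theta||^2
    by plain gradient descent, starting from the meta-parameters theta."""
    # phi starts at the meta-initialization but is optimized as a separate set of tensors
    phi = [p.detach().clone().requires_grad_(True) for p in meta_params]
    for _ in range(num_steps):
        task_loss = support_loss(phi)  # L_hat_i(phi) on the support set D_i^tr
        prox = sum(((p - t.detach()) ** 2).sum() for p, t in zip(phi, meta_params))
        obj = task_loss + 0.5 * lam * prox  # proximally regularized inner objective
        grads = torch.autograd.grad(obj, phi)
        with torch.no_grad():
            for p, g in zip(phi, grads):
                p -= inner_lr * g
    return phi  # approximate Alg*(theta) for this task
```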

New Bi-level optimization problem after proximal regularization

Simplify the notation:

Total and Partial Derivatives:

Implicit MAML

Goal: to solve the bi-level meta-learning problem in Eq. (4) using an iterative gradient-based algorithm of the form:

Specifically:

Note: many iterative gradient-based algorithms and other optimization methods could be used here.
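A sketch of the corresponding outer loop, continuing the inner-loop sketch above. Here `meta_gradient_fn` stands in for whatever per-task meta-gradient approximation is used (e.g., the conjugate-gradient procedure discussed below); all names are illustrative:

```python
import torch

def outer_step(meta_params, task_batch, meta_gradient_fn, meta_lr=1e-3):
    """One iMAML-style meta-update:
    theta <- theta - eta * (1/M) * sum_i g_i, where g_i approximates the task meta-gradient."""
    avg_grads = [torch.zeros_like(p) for p in meta_params]
    for task in task_batch:
        g_i = meta_gradient_fn(meta_params, task)  # approximate meta-gradient for task i
        for a, g in zip(avg_grads, g_i):
            a += g / len(task_batch)
    with torch.no_grad():
        for p, g in zip(meta_params, avg_grads):
            p -= meta_lr * g
    return meta_params
```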

Meta-Gradient Computation

In theory

According to the stationary point conditions, we have:

which is an implicit equation.

When the derivative exists:

Implicit Jacobian

In practice

Two issues with the theoretical solution in practice:

  • The meta-gradient requires $\mathcal{A}lg_i^{\star}(\theta)$, the exact solution of the inner optimization problem, which can only be approximated.

  • Explicitly forming and inverting the matrix in Eq. (6) to compute the Jacobian may be intractable for large deep networks.

1) For the first issue, we consider an approximate solution to the inner optimization problem, obtained with iterative optimization algorithms like gradient descent (red).

2) For the second issue, we perform a partial or approximate matrix inversion (green).
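A minimal sketch of how such an approximate inversion can be done with the conjugate gradient method and Hessian-vector products, so that the matrix $\big(I + \frac{1}{\lambda}\nabla^2_{\phi}\hat{\mathcal{L}}_i(\phi_i)\big)$ is never formed explicitly (a generic CG routine on flattened parameters under illustrative names, not the authors' implementation):

```python
import torch

def hessian_vector_product(loss, params, vec):
    """Compute (d^2 loss / d params^2) @ vec via double backprop, never forming the Hessian."""
    grads = torch.autograd.grad(loss, params, create_graph=True)
    flat_grad = torch.cat([g.reshape(-1) for g in grads])
    hv = torch.autograd.grad(flat_grad @ vec, params, retain_graph=True)
    return torch.cat([h.reshape(-1) for h in hv])

def conjugate_gradient(matvec, b, num_iters=20, tol=1e-10):
    """Solve A x = b for symmetric positive-definite A, given only the map x -> A x."""
    x = torch.zeros_like(b)
    r = b.clone()                      # residual b - A x (x starts at zero)
    p = r.clone()
    rs_old = r @ r
    for _ in range(num_iters):
        Ap = matvec(p)
        alpha = rs_old / (p @ Ap)
        x = x + alpha * p
        r = r - alpha * Ap
        rs_new = r @ r
        if rs_new.sqrt() < tol:
            break
        p = r + (rs_new / rs_old) * p
        rs_old = rs_new
    return x

def approx_meta_gradient(support_loss_value, test_grad, phi, lam):
    """g_i ~= (I + (1/lam) H)^{-1} grad_phi L_i(phi_i), with H the support-loss Hessian at phi_i."""
    matvec = lambda v: v + hessian_vector_product(support_loss_value, phi, v) / lam
    return conjugate_gradient(matvec, test_grad)
```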

Some Questions:

Use Figure 1 to explain the differences between MAML, first-order MAML, and implicit MAML. Appendix A might be helpful for this.

I will write a separate answer to this question later, including the intuition and the math details.

In Section 3.1, it talks about the high memory cost when the number of gradient steps is large for MAML. What is the memory cost with respect to the number of gradient steps?

Using an iterative algorithm (gradient descent) for the inner-loop optimization has a drawback: the meta-gradient depends explicitly on the optimization path, which has to be fully stored in memory, quickly becoming intractable when the number of gradient steps needed is large.
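A toy illustration of that memory behaviour, assuming a simple quadratic inner loss (illustrative only): MAML-style unrolling keeps every inner iterate in the autograd graph via `create_graph=True`, so memory grows with the number of steps, while the implicit approach can run the inner loop detached and only needs the final iterate (consistent with the $2\cdot\mathrm{Mem}(\nabla\hat{\mathcal{L}}_i)$ figure from Theorem 1 stated later in these notes).

```python
import torch

def unrolled_adapt(theta, num_steps, lr=0.1):
    """MAML-style: every intermediate phi stays in the graph, so memory grows with num_steps."""
    phi = theta
    for _ in range(num_steps):
        inner_loss = (phi ** 2).sum()                    # stand-in for L_hat(phi)
        g, = torch.autograd.grad(inner_loss, phi, create_graph=True)
        phi = phi - lr * g                               # chained to the whole optimization path
    return phi

def implicit_adapt(theta, num_steps, lr=0.1):
    """iMAML-style: the inner loop runs detached; only the final phi is kept."""
    phi = theta.detach().clone()
    for _ in range(num_steps):
        phi = phi - lr * 2 * phi                         # gradient of the same quadratic loss
    return phi                                           # meta-gradient later comes from Eq. (6), not the path

theta = torch.tensor([1.0, -2.0], requires_grad=True)
phi_unrolled = unrolled_adapt(theta, num_steps=100)      # graph holds all 100 steps
phi_implicit = implicit_adapt(theta, num_steps=100)      # no graph at all
```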

According to Theorem 1, Algorithm 2 can be implemented using at most $\tilde{O}\left(\sqrt{\kappa}\,\log\left(\frac{\mathrm{poly}(\kappa, D, B, L, \rho, \mu, \lambda)}{\epsilon}\right)\right)$ gradient computations of $\hat{\mathcal{L}}_i(\cdot)$ and $2\cdot\mathrm{Mem}(\nabla\hat{\mathcal{L}}_i)$ memory.

A detailed proof will be given later.

Please provide a review of the conjugate gradient algorithm that is used in Section 3.1.

A separate tutorial on the conjugate gradient method will be given.

According to definition 2:

In order to solve the following problem:

we have:

Possible iterative optimization solvers here include (see the sketch after this list):

  • gradient descent

  • Nesterov's accelerated gradient descent
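A minimal sketch of the second option, Nesterov's accelerated gradient descent, applied to the same regularized inner problem as the plain gradient-descent sketch earlier (a generic look-ahead formulation with illustrative names, not the paper's code):

```python
import torch

def nesterov_inner_loop(meta_params, support_loss, lam=1.0, lr=1e-2, momentum=0.9, num_steps=50):
    """Approximately solve  min_phi  L_hat(phi) + (lam / 2) * ||phi - theta||^2
    with Nesterov momentum: evaluate the gradient at a look-ahead point, then update."""
    phi = [p.detach().clone() for p in meta_params]
    velocity = [torch.zeros_like(p) for p in phi]
    for _ in range(num_steps):
        # Look-ahead point phi + momentum * velocity, where the gradient is evaluated.
        lookahead = [(p + momentum * v).requires_grad_(True) for p, v in zip(phi, velocity)]
        obj = support_loss(lookahead) + 0.5 * lam * sum(
            ((q - t.detach()) ** 2).sum() for q, t in zip(lookahead, meta_params))
        grads = torch.autograd.grad(obj, lookahead)
        velocity = [momentum * v - lr * g for v, g in zip(velocity, grads)]
        phi = [p + v for p, v in zip(phi, velocity)]
    return phi  # approximate Alg*(theta), as in Definition 2
```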

According to definition 2:

Reference:

$\theta^*_{ML}$: optimal meta-learned parameters.

$M$: the number of tasks in meta-training; $i$: the index of task $i$.

$\mathcal{D}^{tr}_i$: support set and $\mathcal{D}^{test}_i$: query set of task $i$.

$\mathcal{L}(\phi, \mathcal{D})$: loss function with parameter vector $\phi$ and dataset $\mathcal{D}$.

$\phi_i = \mathcal{A}lg(\theta, \mathcal{D}^{tr}_{i}) = \theta - \alpha\nabla_{\theta}\mathcal{L}(\theta, \mathcal{D}^{tr}_i)$: one (or multiple) steps of gradient descent initialized at $\theta$ [the inner level of MAML].

We need to store and differentiate through the long optimization path of $\mathcal{A}lg$, imposing a considerable computation and memory burden.

The dependence of the model parameters $\{\phi_i\}$ on the meta-parameters $\theta$ shrinks and vanishes as the number of gradient steps in $\mathcal{A}lg$ grows, making meta-learning difficult.

$\mathcal{L}_i(\phi) := \mathcal{L}(\phi, \mathcal{D}_i^{test}), \quad \hat{\mathcal{L}}_i(\phi) := \mathcal{L}(\phi, \mathcal{D}_i^{tr}), \quad \mathcal{A}lg_i(\theta) := \mathcal{A}lg(\theta, \mathcal{D}^{tr}_i)$

($d$ denotes the total derivative, $\nabla$ denotes the partial derivative.)

$$d_{\theta}\mathcal{L}_i(\mathcal{A}lg_i(\theta)) = \frac{d\mathcal{A}lg_i(\theta)}{d\theta}\,\nabla_{\phi}\mathcal{L}_i(\phi)\Big|_{\phi=\mathcal{A}lg_i(\theta)} = \frac{d\mathcal{A}lg_i(\theta)}{d\theta}\,\nabla_{\phi}\mathcal{L}_i(\mathcal{A}lg_i(\theta))$$

$$\theta \leftarrow \theta - \eta\, d_{\theta}F(\theta)$$

$$\theta \leftarrow \theta - \eta\,\frac{1}{M}\sum_{i=1}^{M} \frac{d\mathcal{A}lg_i^{\star}(\theta)}{d\theta}\,\nabla_{\phi}\mathcal{L}_i\big(\mathcal{A}lg_i^{\star}(\theta)\big)$$

$\nabla_{\phi}\mathcal{L}_i(\mathcal{A}lg_i^{\star}(\theta))$ can be easily obtained in practice via automatic differentiation.

$\frac{d\mathcal{A}lg_i^{\star}(\theta)}{d\theta}$ presents the primary challenge: $\mathcal{A}lg_i^{\star}(\theta)$ is implicitly defined as the solution of the optimization problem in Eq. (4).

Theoretically, we can compute this meta-gradient term $\frac{d\mathcal{A}lg_i^{\star}(\theta)}{d\theta}$ exactly using the following lemma.

Lemma 1 (Implicit Jacobian). Consider $\mathcal{A}lg_i^{\star}(\theta)$ as defined in Eq. (4) for task $\mathcal{T}_i$, and let $\phi_i = \mathcal{A}lg_i^{\star}(\theta)$ be its result. If $\big(I + \frac{1}{\lambda}\nabla^2_{\phi}\hat{\mathcal{L}}_i(\phi_i)\big)$ is invertible, then the derivative Jacobian is

$$\frac{d\mathcal{A}lg_i^{\star}(\theta)}{d\theta} = \Big(I + \frac{1}{\lambda}\nabla^2_{\phi}\hat{\mathcal{L}}_i(\phi_i)\Big)^{-1} \qquad (6)$$

Proof: we drop the $i$ subscripts for convenience.

$\phi$ is the minimizer of $G(\phi', \theta)$, namely:

$$\phi = \mathcal{A}lg^{\star}(\theta) := \underset{\phi' \in \Phi}{\operatorname{argmin}}\; G(\phi', \theta), \quad \text{where } G(\phi', \theta) = \hat{\mathcal{L}}(\phi') + \frac{\lambda}{2}\|\phi' - \theta\|^2$$

The stationarity condition at $\phi$ gives

$$\nabla_{\phi'} G(\phi', \theta)\big|_{\phi'=\phi} = 0 \;\Longrightarrow\; \nabla\hat{\mathcal{L}}(\phi) + \lambda(\phi - \theta) = 0 \;\Longrightarrow\; \phi = \theta - \frac{1}{\lambda}\nabla\hat{\mathcal{L}}(\phi)$$

Differentiating both sides with respect to $\theta$ (note that $\phi$ depends on $\theta$):

$$\frac{d\phi}{d\theta} = I - \frac{1}{\lambda}\nabla^2\hat{\mathcal{L}}(\phi)\,\frac{d\phi}{d\theta} \;\Longrightarrow\; \Big(I + \frac{1}{\lambda}\nabla^2\hat{\mathcal{L}}(\phi)\Big)\frac{d\phi}{d\theta} = I \;\Longrightarrow\; \frac{d\phi}{d\theta} = \Big(I + \frac{1}{\lambda}\nabla^2\hat{\mathcal{L}}(\phi)\Big)^{-1}$$
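A quick numerical check of Lemma 1 on a quadratic $\hat{\mathcal{L}}$ (illustrative only): with $\hat{\mathcal{L}}(\phi) = \frac{1}{2}\phi^\top H\phi + b^\top\phi$, the inner problem has a closed-form minimizer, and its Jacobian with respect to $\theta$ matches $(I + \frac{1}{\lambda}H)^{-1}$.

```python
import torch

torch.manual_seed(1)
d, lam = 4, 2.0

A = torch.randn(d, d)
H = A @ A.T + torch.eye(d)          # Hessian of a quadratic stand-in for L_hat
b = torch.randn(d)

def alg_star(theta):
    """Closed-form minimizer of 0.5 phi^T H phi + b^T phi + (lam/2)||phi - theta||^2."""
    return torch.linalg.solve(H + lam * torch.eye(d), lam * theta - b)

theta = torch.randn(d)
jacobian = torch.autograd.functional.jacobian(alg_star, theta)   # d phi / d theta
lemma_rhs = torch.linalg.inv(torch.eye(d) + H / lam)             # (I + H/lam)^{-1}

print(torch.allclose(jacobian, lemma_rhs, atol=1e-4))            # True, as Lemma 1 predicts
```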

The meta-gradient requires the computation of $\mathcal{A}lg_i^{\star}(\theta)$, the exact solution to the inner optimization problem; in practice, only an approximation can be obtained.

$$\tilde{O}\left(\sqrt{\kappa}\,\log\left(\frac{\mathrm{poly}(\kappa, D, B, L, \rho, \mu, \lambda)}{\epsilon}\right)\right) \text{ gradient computations of } \hat{\mathcal{L}}_i(\cdot)$$

and $2\cdot\mathrm{Mem}(\nabla\hat{\mathcal{L}}_i)$ memory.

Please explain why $g_i$ can be obtained as an approximate solution to Problem (7).

$$\left\|g_{i} - \left(I + \frac{1}{\lambda}\nabla_{\phi}^{2}\hat{\mathcal{L}}_{i}(\phi_{i})\right)^{-1}\nabla_{\phi}\mathcal{L}_{i}(\phi_{i})\right\| \leq \delta'$$

$g_i$ is an approximation of the meta-gradient for task $i$.

$$\min_{w} f(w) = \min_{w}\; \frac{1}{2} w^{\top}\left(I + \frac{1}{\lambda}\nabla_{\phi}^{2}\hat{\mathcal{L}}_{i}(\phi_{i})\right) w - w^{\top}\nabla_{\phi}\mathcal{L}_{i}(\phi_{i}) \qquad (7)$$

$$\frac{df(w)}{dw} = \left(I + \frac{1}{\lambda}\nabla_{\phi}^{2}\hat{\mathcal{L}}_{i}(\phi_{i})\right) w - \nabla_{\phi}\mathcal{L}_{i}(\phi_{i}) = 0 \;\Longrightarrow\; w = \left(I + \frac{1}{\lambda}\nabla_{\phi}^{2}\hat{\mathcal{L}}_{i}(\phi_{i})\right)^{-1}\nabla_{\phi}\mathcal{L}_{i}(\phi_{i})$$

So $g_i$ can be obtained as an approximate solution to optimization problem (7).
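A small numerical sanity check of this on a random low-dimensional quadratic (illustrative only): the minimizer of problem (7), found here by plain gradient descent, matches the explicit $(I + \frac{1}{\lambda}\nabla^2_{\phi}\hat{\mathcal{L}}_i)^{-1}\nabla_{\phi}\mathcal{L}_i$.

```python
import torch

torch.manual_seed(0)
d, lam = 5, 2.0

A = torch.randn(d, d)
H = A @ A.T                                   # stands in for the support-loss Hessian at phi_i
grad_test = torch.randn(d)                    # stands in for grad_phi L_i(phi_i)

# Explicit meta-gradient: (I + H/lam)^{-1} grad_test.
g_exact = torch.linalg.solve(torch.eye(d) + H / lam, grad_test)

# Minimize f(w) = 0.5 * w^T (I + H/lam) w - w^T grad_test by gradient descent.
w = torch.zeros(d, requires_grad=True)
opt = torch.optim.SGD([w], lr=0.01)
for _ in range(5000):
    opt.zero_grad()
    f = 0.5 * w @ (torch.eye(d) + H / lam) @ w - w @ grad_test
    f.backward()
    opt.step()

print(torch.allclose(w.detach(), g_exact, atol=1e-3))   # True: the argmin of (7) is the meta-gradient
```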

About Line 3 in Algorithm 2: what iterative optimization solver can be used, and how many iterations are enough so that the error is $\leq \delta$?

If Nesterov's accelerated gradient descent is used to compute $\phi$, the number of iterations could be (Theorem 2):

$$2\sqrt{\kappa}\,\log\left(8\kappa D\left(\frac{B_1}{\epsilon} + \frac{\rho}{\mu}\right)\right)$$

What is the relationship between $g_i$ in Definition 2 and the derivative Jacobian?

$$\left\|g_{i} - \left(I + \frac{1}{\lambda}\nabla_{\phi}^{2}\hat{\mathcal{L}}_{i}(\phi_{i})\right)^{-1}\nabla_{\phi}\mathcal{L}_{i}(\phi_{i})\right\| \leq \delta'$$

$g_i$ is an approximation of the meta-gradient for task $i$.

$\left(I + \frac{1}{\lambda}\nabla_{\phi}^{2}\hat{\mathcal{L}}_{i}(\phi_{i})\right)^{-1}$ is the derivative Jacobian.

https://www.inference.vc/notes-on-imaml-meta-learning-without-differentiating-through/
https://www.youtube.com/watch?v=u5BkO8XMS2I
https://papers.nips.cc/paper/8306-meta-learning-with-implicit-gradients
The goal of meta-learning is to learn meta-parameters that produce good task-specific parameters after adaptation.
bi-level meta-learning problem (more general)