DARTS: Differentiable Architecture Search

ICLR 2019 8/19/2020

Motivation

This paper addresses the scalability challenge of architecture search by formulating the task in a differentiable manner.

The best existing architecture search algorithms are computationally demanding despite their remarkable performance.

Conventional approaches search over a discrete and non-differentiable space; obtaining a state-of-the-art architecture for CIFAR-10 and ImageNet costs:

  • Evolution: 3150 GPU days

  • Reinforcement learning: 2000 GPU days

Method in the paper:

  • based on the continuous relaxation of the architecture representation, allowing efficient search of the architecture using gradient descent.

  • a novel algorithm for differentiable network architecture search based on bilevel optimization, which is applicable to both convolutional and recurrent architectures.

  • remarkable efficiency improvement (reducing the cost of architecture discovery to a few GPU days)

Differentiable Architecture Search

Cell: A cell is a directed acyclic graph consisting of an ordered sequence of N nodes. For example, [0,1,2] in Fig.1

Node: each node $x^{(i)}$ is a latent representation (e.g., a feature map in a CNN)

Edge: each directed edge $(i,j)$ is associated with some operation $o^{(i,j)}$ that transforms $x^{(i)}$

Note: candidate operations on each edge (i.e., between two nodes) include:

  • convolution

  • max pooling

  • zero operation: a lack of connection between two nodes

Assumption: the cell has two input nodes and one single output node.

  • input nodes:

    • For convolutional cells: the input nodes are defined as the cell outputs in the previous two layers.

      • example: in cell [0,1,2], nodes 0 and 1 are the input nodes and node 2 is the output node

    • For recurrent cells: input nodes are the input at the current step and the state carried from the previous step.

  • output:

    • the output of the cell is obtained by applying a reduction operation (e.g., concatenation) to all the intermediate nodes.
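To make the cell structure concrete, here is a minimal PyTorch sketch (my own, not the authors' code). It assumes a fixed channel count everywhere and uses a plain 3x3 convolution as a placeholder operation on every edge; the mixed operation from the next section would replace it.

```python
import torch
import torch.nn as nn

class SketchCell(nn.Module):
    """Toy DARTS-style cell: 2 input nodes + `steps` intermediate nodes.

    Each intermediate node sums one operation applied to every previous node;
    the cell output concatenates the intermediate nodes along channels.
    Here every edge uses a plain 3x3 conv as a placeholder operation.
    """
    def __init__(self, channels, steps=4):
        super().__init__()
        self.steps = steps
        self.ops = nn.ModuleList()
        for i in range(steps):
            for _ in range(2 + i):  # one edge from every previous node
                self.ops.append(
                    nn.Conv2d(channels, channels, 3, padding=1, bias=False))

    def forward(self, s0, s1):
        states = [s0, s1]                       # the two input nodes
        offset = 0
        for _ in range(self.steps):
            s = sum(self.ops[offset + j](h) for j, h in enumerate(states))
            offset += len(states)
            states.append(s)                    # a new intermediate node
        return torch.cat(states[2:], dim=1)     # reduce: concat intermediates

cell = SketchCell(channels=16)
out = cell(torch.randn(2, 16, 8, 8), torch.randn(2, 16, 8, 8))
```

With `steps=4` and 16 channels per node, the cell output has 64 channels (the four intermediate nodes concatenated).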

Continuous relaxation

GOAL: find the optimal operation on each edge. To make the search space continuous, DARTS relaxes the categorical choice of a particular operation to a softmax over all candidate operations, so the task of architecture search reduces to learning a set of mixing probabilities $\alpha = \{\alpha^{(i,j)}\}$.

$$\bar{o}^{(i, j)}(x)=\sum_{o \in \mathcal{O}} \frac{\exp\left(\alpha_{o}^{(i, j)}\right)}{\sum_{o^{\prime} \in \mathcal{O}} \exp\left(\alpha_{o^{\prime}}^{(i, j)}\right)}\, o(x)$$

where $\alpha^{(i,j)}$ is a vector of dimension $|\mathcal{O}|$ containing the mixing weights between nodes $i$ and $j$ over the candidate operations, and $\alpha_{o}^{(i,j)}$ is its entry for operation $o$.
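Below is a minimal sketch of this mixed operation on a single edge, assuming a small illustrative candidate set (3x3 conv, 3x3 max pool, identity, zero) rather than the paper's full search space; `MixedOp` and the channel size are my own placeholders.

```python
import torch
import torch.nn as nn
import torch.nn.functional as F

class Zero(nn.Module):
    """Zero operation: models the absence of a connection."""
    def forward(self, x):
        return torch.zeros_like(x)

class MixedOp(nn.Module):
    """Softmax-weighted mixture of candidate operations on one edge."""
    def __init__(self, channels):
        super().__init__()
        self.ops = nn.ModuleList([
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),  # conv
            nn.MaxPool2d(3, stride=1, padding=1),                     # max pool
            nn.Identity(),                                            # skip
            Zero(),                                                   # no edge
        ])

    def forward(self, x, alpha):
        # alpha is the |O|-dimensional vector alpha^{(i,j)} for this edge.
        weights = F.softmax(alpha, dim=-1)      # mixing probabilities
        return sum(w * op(x) for w, op in zip(weights, self.ops))

edge_alpha = nn.Parameter(1e-3 * torch.randn(4))  # one alpha^{(i,j)} vector
mixed = MixedOp(channels=16)
y = mixed(torch.randn(2, 16, 8, 8), edge_alpha)
```

One such `alpha` vector is kept per edge, and all of them together form the architecture encoding $\alpha$.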

A bilevel optimization problem arises because we want to optimize both the network weights $w$ and the architecture representation $\alpha$:

$$\begin{aligned}
\min_{\alpha} \quad & \mathcal{L}_{val}\left(w^{*}(\alpha), \alpha\right) \\
\text{s.t.} \quad & w^{*}(\alpha)=\operatorname{argmin}_{w} \mathcal{L}_{train}(w, \alpha)
\end{aligned}$$
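In practice this is solved by alternating updates. A hedged sketch of the simpler first-order variant (i.e., $\xi = 0$, no unrolling) is below; `model`, `loss_fn`, the data loaders, and the optimizer grouping are placeholder names, not the paper's API.

```python
import torch

def search_epoch(model, train_loader, val_loader, loss_fn, w_opt, alpha_opt):
    """Alternate: one alpha step on validation data, one w step on training data.

    Assumes `w_opt` was built over the network weights w and `alpha_opt` over
    the architecture parameters alpha (two disjoint parameter groups of `model`).
    """
    for (x_tr, y_tr), (x_val, y_val) in zip(train_loader, val_loader):
        # 1) update alpha by descending the validation loss
        alpha_opt.zero_grad()
        loss_fn(model(x_val), y_val).backward()
        alpha_opt.step()

        # 2) update w by descending the training loss
        w_opt.zero_grad()
        loss_fn(model(x_tr), y_tr).backward()
        w_opt.step()
```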

Approximation

Due to the expensive inner optimization, evaluating the architecture gradient exactly can be prohibitive. A simple approximation scheme was proposed:

$$\begin{aligned}
& \nabla_{\alpha} \mathcal{L}_{val}\left(w^{*}(\alpha), \alpha\right) \\
\approx\ & \nabla_{\alpha} \mathcal{L}_{val}\left(w-\xi \nabla_{w} \mathcal{L}_{train}(w, \alpha), \alpha\right) \\
=\ & \nabla_{\alpha} \mathcal{L}_{val}(w', \alpha) + \frac{d w'}{d \alpha}\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha) \\
=\ & \nabla_{\alpha} \mathcal{L}_{val}(w', \alpha) - \xi\, \nabla_{\alpha, w}^{2} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)
\end{aligned}$$

where $w' = w - \xi \nabla_{w} \mathcal{L}_{train}(w, \alpha)$ are the weights after a single inner training step, so that $\frac{d w'}{d \alpha} = -\xi\, \nabla_{\alpha, w}^{2} \mathcal{L}_{train}(w, \alpha)$.
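A tiny self-contained toy (made-up quadratic losses and plain tensors instead of networks) showing how automatic differentiation through the single unrolled step yields exactly the two terms in the last line above:

```python
import torch

w = torch.randn(5, requires_grad=True)       # stand-in for network weights
alpha = torch.randn(3, requires_grad=True)   # stand-in for architecture params
xi = 0.01                                    # inner-loop learning rate

def L_train(w, alpha):                       # toy training loss (assumption)
    return ((w * alpha.sum()) ** 2).mean()

def L_val(w, alpha):                         # toy validation loss (assumption)
    return ((w - alpha.mean()) ** 2).mean()

# w' = w - xi * grad_w L_train(w, alpha); create_graph=True keeps the
# dependence of this gradient on alpha so it can be differentiated again.
g_w = torch.autograd.grad(L_train(w, alpha), w, create_graph=True)[0]
w_prime = w - xi * g_w

# grad_alpha L_val(w'(alpha), alpha): backprop through the unrolled step.
arch_grad = torch.autograd.grad(L_val(w_prime, alpha), alpha)[0]
```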

Fortunately, the complexity can be substantially reduced using the finite difference approximation:

  • $\epsilon$ is a small scalar

  • $w^{\pm}=w \pm \epsilon\, \nabla_{w'} \mathcal{L}_{val}(w', \alpha)$

The expensive matrix-vector product is a directional derivative of $\nabla_{\alpha} \mathcal{L}_{train}(\cdot, \alpha)$ with respect to $w$ along $v = \nabla_{w'} \mathcal{L}_{val}(w', \alpha)$, so a central difference needing only two extra gradient evaluations gives

$$\nabla_{\alpha, w}^{2} \mathcal{L}_{train}(w, \alpha)\, \nabla_{w'} \mathcal{L}_{val}\left(w', \alpha\right) \approx \frac{\nabla_{\alpha} \mathcal{L}_{train}\left(w^{+}, \alpha\right)-\nabla_{\alpha} \mathcal{L}_{train}\left(w^{-}, \alpha\right)}{2 \epsilon}$$
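Continuing the same toy setup, a sketch of this finite-difference estimate: the Hessian-vector product is replaced by two extra gradient evaluations at $w^{\pm}$. The losses and the fixed $\epsilon$ are assumptions for illustration (in the paper $\epsilon$ is chosen relative to the norm of $\nabla_{w'}\mathcal{L}_{val}$; a constant is used here for simplicity).

```python
import torch

def L_train(w, alpha):                        # same toy training loss as above
    return ((w * alpha.sum()) ** 2).mean()

def hvp_finite_diff(w, alpha, v, eps=1e-2):
    """Central-difference estimate of grad^2_{alpha,w} L_train(w, alpha) @ v."""
    w_plus = (w + eps * v).detach().requires_grad_(True)
    w_minus = (w - eps * v).detach().requires_grad_(True)
    g_plus = torch.autograd.grad(L_train(w_plus, alpha), alpha)[0]
    g_minus = torch.autograd.grad(L_train(w_minus, alpha), alpha)[0]
    return (g_plus - g_minus) / (2 * eps)

w = torch.randn(5, requires_grad=True)
alpha = torch.randn(3, requires_grad=True)
v = torch.randn(5)                   # stands in for grad_{w'} L_val(w', alpha)
hvp = hvp_finite_diff(w, alpha, v)   # same shape as alpha
```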

Questions:

"A cell is a directed acyclic graph consisting of an ordered sequence of N nodes"

How is the ordered sequence decided in practice?

The nodes are ordered randomly at first.

How are the intermediate nodes decided?

(I am not sure about the definition of an intermediate node.)

Each cell has 2 input nodes and 1 output node, so the intermediate nodes are the remaining nodes in between. In the small [0,1,2] example there are none; whether a cell has intermediate nodes depends on how many nodes it is given in the whole architecture.

Can we decompose ResNet into such cells, each of which has two input nodes, and one output node?

Yes.

(Figure: the original ResNet block.)

(Figure: the same block redrawn as a cell.)

"The output of the cell is obtained by applying a reduction operation (e.g. concatenation) to all the intermediate nodes."

Use an example to illustrate this.

Not very confident about this. What is the output of one cell? Each cell in the example has three nodes, each with its own output, so the output of the cell might be the concatenation of the outputs of those nodes.

After watching the video again, c_{k-2}, c_{k-1}, and c_{k} were also called nodes of one cell.

I am confused by this: if c_{k-2}, c_{k-1}, and c_{k} are nodes, then what are 0, 1, 2, 3?

How are Eqs. (7) and (8) derived?

Please see the above Approximation section.

To form each node in the discrete architecture, we retain the top-k strongest operations (from distinct nodes) among all non-zero candidate operations collected from all the previous nodes.

Is $k$ a hyperparameter?

Yes. This is a hyperparameter.
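A hypothetical sketch of that discretization step; the operation names, the layout of `alphas`, and `k=2` are all assumptions for illustration, not the paper's code.

```python
import torch
import torch.nn.functional as F

OPS = ["conv_3x3", "max_pool_3x3", "skip", "zero"]   # assumed candidate set

def discretize(alphas, k=2):
    """alphas: {node: {prev_node: alpha vector over OPS}} -> kept edges per node.

    For each intermediate node, rank incoming edges by the softmax weight of
    their strongest non-zero operation and keep only the top-k edges.
    """
    arch = {}
    for node, edges in alphas.items():
        scored = []
        for prev, a in edges.items():
            w = F.softmax(a, dim=-1).clone()
            w[OPS.index("zero")] = -1.0              # never pick the zero op
            best = int(torch.argmax(w))
            scored.append((float(w[best]), prev, OPS[best]))
        arch[node] = sorted(scored, reverse=True)[:k]
    return arch

# usage: node 2 receives candidate edges from nodes 0 and 1
alphas = {2: {0: torch.randn(4), 1: torch.randn(4)}}
print(discretize(alphas, k=2))
```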

What is the zero operation?

It represents a lack of connection between two nodes.

Why is batch normalization effective here?

Will finish this after discussing with Xujiang and Tianhao.

References:

  • https://openreview.net/forum?id=S1eYHoC5FX
  • https://lilianweng.github.io/lil-log/2020/08/06/neural-architecture-search.html
  • https://github.com/quark0/darts
  • https://towardsdatascience.com/investigating-differentiable-neural-architecture-search-for-scientific-datasets-62899be8714e
  • http://torch.ch/blog/2016/02/04/resnets.html