Modular Meta-Learning with Shrinkage
8-24-2020
The ability to meta-learn large models with only a few task-specific components is important in many real-world problems:
multi-speaker text-to-speech synthesis
Updating only these task-specific modules allows the model to be adapted to low-data tasks for as many steps as necessary without risking overfitting.
Existing meta-learning approaches either do not scale to long adaptation horizons or rely on handcrafted task-specific architectures.
A new meta-learning approach based on Bayesian shrinkage is proposed to automatically discover and learn both task-specific and general reusable modules.
It works well in the few-shot text-to-speech domain.
MAML, iMAML, and Reptile are special cases of the proposed method.
Task $t$: each of $T$ tasks is associated with a finite dataset $\mathcal{D}_t$, split into a training subset $\mathcal{D}_t^{\text{train}}$ and a validation subset $\mathcal{D}_t^{\text{val}}$
$\phi$: meta parameters (shared across tasks)
$\theta_t$: task-specific parameters
$\ell(\theta_t, \mathcal{D})$: loss function (e.g., the negative log-likelihood $-\log p(\mathcal{D} \mid \theta_t)$)
Algorithm 1 gives the structure of a typical meta-learning algorithm (a toy sketch of this skeleton follows this discussion), which could be:
MAML
iMAML
Reptile
TASKADAPT: task adaptation (inner loop)
META-UPDATE: the meta update $\Delta_t$ specifies the contribution of task $t$ to the meta parameters (outer loop)
MAML task adaptation: minimize the training loss by gradient descent w.r.t. the task parameters $\theta_t$, starting from the initialization $\phi$
Meta parameter update: gradient descent on the validation loss, resulting in the meta update (gradient) for task $t$: $\Delta_t^{\text{MAML}} = \nabla_\phi\, \ell(\theta_t(\phi), \mathcal{D}_t^{\text{val}})$
This approach treats the task parameters as a function of the meta parameters, and hence requires back-propagation through the entire L-step task adaptation process. When L is large, this becomes computationally prohibitive.
Reptile optimizes $\theta_t$ on the entire dataset $\mathcal{D}_t$ and moves $\phi$ towards the adapted task parameters, yielding $\Delta_t^{\text{Reptile}} = \phi - \theta_t$.
iMAML introduces an L2 regularizer $\frac{\lambda}{2}\|\theta_t - \phi\|^2$ to the training loss, and optimizes the task parameters on the regularized training loss.
Provided that this task adaptation process converges to a stationary point, implicit differentiation enables the computation of the meta gradient based only on the final solution of the adaptation process: $\Delta_t^{\text{iMAML}} = \big(I + \tfrac{1}{\lambda} \nabla^2_{\theta_t} \ell(\theta_t, \mathcal{D}_t^{\text{train}})\big)^{-1} \nabla_{\theta_t} \ell(\theta_t, \mathcal{D}_t^{\text{val}})$
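To make the Algorithm-1 skeleton concrete, here is a minimal runnable sketch on a toy 1-D regression task distribution, with gradient-descent TASKADAPT and a Reptile-style meta update. All names (`make_task`, `task_adapt`) and hyperparameters are illustrative, not from the paper.

```python
import numpy as np

rng = np.random.default_rng(0)

def make_task():
    """Sample a toy task: y = w * x + noise, with a task-specific slope w."""
    w = rng.normal(1.0, 0.5)
    x = rng.uniform(-1.0, 1.0, size=20)
    y = w * x + 0.05 * rng.normal(size=20)
    return x, y

def task_adapt(phi, x, y, steps=50, lr=0.1):
    """TASKADAPT (inner loop): L steps of gradient descent on the training loss."""
    theta = phi
    for _ in range(steps):
        grad = np.mean(2.0 * (theta * x - y) * x)  # d/dtheta of the MSE loss
        theta = theta - lr * grad
    return theta

phi = 0.0                              # meta parameter: the shared initialisation
for _ in range(500):                   # outer loop over sampled tasks
    x, y = make_task()
    theta = task_adapt(phi, x, y)      # inner loop: adapt to the task
    delta = phi - theta                # Reptile-style meta update for this task
    phi = phi - 0.1 * delta            # move phi toward the adapted parameters
print(f"meta-learned initialisation phi = {phi:.2f} (mean task slope is 1.0)")
```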
Standard meta-learning: all parameters of a single base learner (neural network) are adapted at the same time.
This is inefficient and prone to overfitting.
Split the network parameters into two groups:
a group varying across tasks
a group that is shared across tasks
In this paper (task-independent modules):
We assume that, in general, the network parameters can be partitioned into $M$ disjoint modules: $\theta_t = (\theta_{t1}, \dots, \theta_{tM})$,
where $\theta_{tm}$: the parameters in module $m$ for task $t$.
A shrinkage parameter $\sigma_m^2$ is associated with each module $m$ to control the deviation of the task-dependent parameters $\theta_{tm}$ from the central value (meta parameters) $\phi_m$. A toy illustration follows the examples below.
Layers: $\theta_{tm}$ could be the weights of the $m$-th layer of a neural network for task $t$ [this paper treats each layer as a module]
Receptive fields
the encoder and decoder in an auto-encoder
the heads in a multi-task learning model
any other grouping of interest
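As a toy illustration of module partitioning with per-module shrinkage, a minimal sketch; the module names, shapes, and $\sigma_m^2$ values here are all hypothetical, chosen only to show the bookkeeping.

```python
import numpy as np

rng = np.random.default_rng(0)

# Central (meta) parameters phi, one entry per module m; here each layer is a module.
phi = {
    "conv1": rng.normal(size=(3, 3, 8)),
    "conv2": rng.normal(size=(3, 3, 8)),
    "head":  rng.normal(size=(8, 5)),
}
# One scalar shrinkage parameter per module; a small sigma2_m pins the module to phi_m.
sigma2 = {"conv1": 1e-6, "conv2": 1e-6, "head": 1.0}

def shrinkage_penalty(theta, phi, sigma2):
    """-log p(theta | phi, sigma2) up to constants: sum_m ||theta_m - phi_m||^2 / (2 sigma2_m)."""
    return sum(np.sum((theta[m] - phi[m]) ** 2) / (2.0 * sigma2[m]) for m in phi)

# Perturb every module equally: the penalty is dominated by the low-sigma2 modules,
# so those modules are effectively forced to stay task independent.
theta = {m: v + 0.01 * rng.normal(size=v.shape) for m, v in phi.items()}
print(shrinkage_penalty(theta, phi, sigma2))
```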
The hierarchical model places a factored probability density over each task's parameters: $p(\theta_t \mid \phi, \sigma^2) = \prod_{m=1}^{M} \mathcal{N}(\theta_{tm} \mid \phi_m, \sigma_m^2 I)$
$\phi$: shared meta parameters, the initialization of the network parameters for each task
The task parameters $\theta_t$ are conditionally independent of those of all other tasks given the "central" parameters $\phi$. Namely: $\theta_{tm} \sim \mathcal{N}(\phi_m, \sigma_m^2 I)$, with mean $\phi_m$ and variance $\sigma_m^2 I$ ($I$ is the identity matrix)
$\sigma^2 = (\sigma_1^2, \dots, \sigma_M^2)$: shrinkage parameters.
$\sigma_m^2$: the scalar shrinkage parameter of module $m$, which measures the degree to which $\theta_{tm}$ can deviate from $\phi_m$. If $\sigma_m^2$ is large, $\theta_{tm}$ is free to adapt to each task; when $\sigma_m^2$ shrinks to 0, the parameters of module $m$ become task independent.
For values of $\sigma_m^2$ near zero, the difference between the task parameters $\theta_{tm}$ and the mean $\phi_m$ is shrunk to zero, and thus module $m$ becomes task independent ($\theta_{tm} = \phi_m$).
Thus, by learning $\sigma^2$, we can discover which modules are task independent.
These independent modules can be re-used at meta-test time
reducing the computational burden of adaptation
and likely improving generalization
Standard solution: Maximize the marginal likelihood (intractable)
Approximation: two principled alternative approaches for parameter estimation based on maximizing the predictive likelihood over the validation subsets (the "query set" in few-shot terminology).
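For concreteness, a hedged sketch of the two objectives in the notation above (the exact forms are in the paper); the integral over $\theta_t$ is what makes both intractable:

```latex
% Marginal likelihood: integrates out the task parameters (intractable)
p(\mathcal{D}_{1:T} \mid \phi, \sigma^2)
  = \prod_{t=1}^{T} \int p(\mathcal{D}_t \mid \theta_t)\,
      p(\theta_t \mid \phi, \sigma^2)\, \mathrm{d}\theta_t

% Predictive likelihood over the validation (query) subsets instead
p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \phi, \sigma^2)
  = \int p(\mathcal{D}_t^{\text{val}} \mid \theta_t)\,
      p(\theta_t \mid \mathcal{D}_t^{\text{train}}, \phi, \sigma^2)\, \mathrm{d}\theta_t
```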
(Relation with MAML and iMAML)
Our goal is to minimize the average negative predictive log-likelihood over the tasks' validation subsets: $-\frac{1}{T}\sum_{t=1}^{T} \log p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \phi, \sigma^2)$ (eq. 2)
(There might be a typo in equation (2) of the paper: a minus sign appears to be missing.)
New problem: the computation of $p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \phi, \sigma^2)$ is not feasible due to the intractable integral in eq. (2).
So, we use a simple maximum a posteriori (MAP) approximation of the task parameters: $\hat{\theta}_t(\phi, \sigma^2) = \arg\max_{\theta_t} \big[ \log p(\mathcal{D}_t^{\text{train}} \mid \theta_t) + \log p(\theta_t \mid \phi, \sigma^2) \big]$ (eq. 4)
After obtaining the MAP estimates, we can approximate the predictive log-likelihood as follows: $\log p(\mathcal{D}_t^{\text{val}} \mid \mathcal{D}_t^{\text{train}}, \phi, \sigma^2) \approx \log p(\mathcal{D}_t^{\text{val}} \mid \hat{\theta}_t)$ (eq. 5)
Individual task adaptation follows from equation (4).
Meta updating follows from minimizing equation (5). (A toy sketch of this MAP adaptation follows below.)
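A minimal sketch of the MAP task adaptation of eq. (4) on a toy linear model, treating the slope and the bias as two one-parameter "modules" with different shrinkage; the data, $\phi$, and $\sigma^2$ values are all illustrative.

```python
import numpy as np

rng = np.random.default_rng(1)
x = rng.uniform(-1.0, 1.0, size=30)
y = 2.0 * x + 0.5 + 0.05 * rng.normal(size=30)   # task truth: slope 2.0, bias 0.5

phi = np.array([1.0, 0.0])        # central parameters (meta-learned init), one per module
sigma2 = np.array([1.0, 1e-3])    # module 0 (slope) may adapt; module 1 (bias) is shrunk

theta = phi.copy()
for _ in range(2000):             # gradient descent on train NLL + shrinkage penalty
    resid = theta[0] * x + theta[1] - y
    grad_nll = np.array([np.sum(2.0 * resid * x), np.sum(2.0 * resid)])
    grad_prior = (theta - phi) / sigma2   # from sum_m ||theta_m - phi_m||^2 / (2 sigma2_m)
    theta -= 1e-3 * (grad_nll + grad_prior)

print(theta)  # slope moves toward 2.0; bias stays pinned near phi[1] = 0.0
```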
Compute the gradient of the approximate predictive log-likelihood w.r.t. the meta parameters $\psi = (\phi, \sigma^2)$: $\nabla_\psi \log p(\mathcal{D}_t^{\text{val}} \mid \hat{\theta}_t) = \big(\tfrac{d\hat{\theta}_t}{d\psi}\big)^{\top} \nabla_\theta \log p(\mathcal{D}_t^{\text{val}} \mid \hat{\theta}_t)$
$\nabla_\theta \log p(\mathcal{D}_t^{\text{val}} \mid \hat{\theta}_t)$ is straightforward.
$d\hat{\theta}_t / d\psi$ is not.
Two methods for computing $d\hat{\theta}_t / d\psi$: back-propagating through the adaptation process, or using the implicit function theorem.
We are more interested in long adaptation horizons, where back-propagation is prohibitive.
Implicit Function Theorem (compute the gradient of $\hat{\theta}_t$ w.r.t. $\phi$ and $\sigma^2$):
$\psi = (\phi, \sigma^2)$: the full vector of meta parameters
$\dfrac{d\hat{\theta}_t}{d\psi} = -\big[\nabla^2_{\theta} G(\hat{\theta}_t, \psi)\big]^{-1} \nabla_\psi \nabla_\theta G(\hat{\theta}_t, \psi)$, namely, the implicit gradients w.r.t. both $\phi$ and $\sigma^2$, where $G$ is the regularized training objective defined in the derivation below.
Derivatives are evaluated at the stationary point $\hat{\theta}_t$.
Derivation:
Rewrite eq. (4) as: $\hat{\theta}_t(\psi) = \arg\min_{\theta} G(\theta, \psi)$
where $G(\theta, \psi) = -\log p(\mathcal{D}_t^{\text{train}} \mid \theta) - \log p(\theta \mid \phi, \sigma^2)$
So, $\hat{\theta}_t$ is a stationary point of the function $G$: $\nabla_\theta G(\hat{\theta}_t, \psi) = 0$.
Based on the implicit function theorem, $\dfrac{d\hat{\theta}_t}{d\psi} = -\big[\nabla^2_{\theta} G(\hat{\theta}_t, \psi)\big]^{-1} \nabla_\psi \nabla_\theta G(\hat{\theta}_t, \psi)$
Plug the equation above into $\nabla_\psi \log p(\mathcal{D}_t^{\text{val}} \mid \hat{\theta}_t) = \big(\tfrac{d\hat{\theta}_t}{d\psi}\big)^{\top} \nabla_\theta \log p(\mathcal{D}_t^{\text{val}} \mid \hat{\theta}_t)$. END.
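A quick numerical sanity check of this implicit gradient on a 1-D quadratic (illustrative, not the paper's setup): with $G(\theta; \phi) = \tfrac{h}{2}(\theta - \theta^*)^2 + \tfrac{1}{2\sigma^2}(\theta - \phi)^2$, the IFT gives $d\hat{\theta}/d\phi = \frac{1/\sigma^2}{h + 1/\sigma^2}$, the scalar analogue of iMAML's $(I + \tfrac{1}{\lambda}\nabla^2 \ell)^{-1}$ factor.

```python
import numpy as np

h, t_star, sigma2 = 3.0, 2.0, 0.5     # toy curvature, task optimum, shrinkage

def theta_hat(phi, steps=10_000, lr=1e-2):
    """Task adaptation: plain gradient descent on G until (numerically) stationary."""
    theta = phi
    for _ in range(steps):
        theta -= lr * (h * (theta - t_star) + (theta - phi) / sigma2)
    return theta

phi0, eps = 1.0, 1e-5
fd = (theta_hat(phi0 + eps) - theta_hat(phi0 - eps)) / (2 * eps)  # finite difference
ift = (1.0 / sigma2) / (h + 1.0 / sigma2)                         # implicit gradient
print(f"finite difference: {fd:.6f}   implicit function theorem: {ift:.6f}")  # both 0.4
```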
The meta update for $\phi$ is equivalent to that of iMAML when $\sigma_m^2$ is constant for all $m$.
Why?
According to equation (4): $\hat{\theta}_t = \arg\max_{\theta_t} \big[ \log p(\mathcal{D}_t^{\text{train}} \mid \theta_t) + \log p(\theta_t \mid \phi, \sigma^2) \big]$
Also, because all modules share a constant variance, $\sigma_m^2 = \sigma^2$ for all $m$,
we can expand the log-prior term for the task parameters in equation (4) and plug in the normal prior assumption as follows: $\log p(\theta_t \mid \phi, \sigma^2) = -\frac{1}{2\sigma^2}\|\theta_t - \phi\|^2 + \text{const}$
Meta-update of iMAML: $\Delta_t^{\text{iMAML}} = \big(I + \tfrac{1}{\lambda} \nabla^2_{\theta_t} \ell(\theta_t, \mathcal{D}_t^{\text{train}})\big)^{-1} \nabla_{\theta_t} \ell(\theta_t, \mathcal{D}_t^{\text{val}})$
If we define the regularization scale $\lambda = 1/\sigma^2$ and plug this definition into the meta-update of iMAML, the final equation is exactly the same as equation (41).
Relation with Reptile
Integrating out the central parameters $\phi$, and noting that the task parameters then depend on all tasks' training data, we can rewrite the predictive likelihood in terms of the joint posterior over $(\theta_{1:T}, \phi)$, i.e. $p(\mathcal{D}_{1:T}^{\text{val}} \mid \mathcal{D}_{1:T}^{\text{train}}, \sigma^2) = \int p(\mathcal{D}_{1:T}^{\text{val}} \mid \theta_{1:T})\, p(\theta_{1:T}, \phi \mid \mathcal{D}_{1:T}^{\text{train}}, \sigma^2)\, d\theta_{1:T}\, d\phi$ (eq. 8)
Similarly, this is still intractable, so we again make use of a MAP approximation, now for both the task parameters and the central meta parameters: $(\hat{\theta}_{1:T}, \hat{\phi}) = \arg\max_{\theta_{1:T},\, \phi} \sum_{t=1}^{T} \big[ \log p(\mathcal{D}_t^{\text{train}} \mid \theta_t) + \log p(\theta_t \mid \phi, \sigma^2) \big] + \log p(\phi)$ (eq. 9)
(Assuming $p(\phi)$ is a flat, non-informative prior, effectively no prior at all, the $\log p(\phi)$ term above is dropped.)
By plugging the MAP estimates of eq. (9) into eq. (8) and scaling by $1/T$, we derive the approximate predictive log-likelihood as: $\frac{1}{T} \sum_{t=1}^{T} \log p(\mathcal{D}_t^{\text{val}} \mid \hat{\theta}_t)$
The meta update for $\phi$ can be obtained by differentiating eq. (9) w.r.t. $\phi$.
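To see the Reptile connection explicitly (a sketch in the reconstructed notation): holding the adapted task parameters fixed and setting the gradient of the MAP objective in eq. (9) with respect to $\phi_m$ to zero gives

```latex
\nabla_{\phi_m} \sum_{t=1}^{T} \log p(\hat{\theta}_t \mid \phi, \sigma^2)
  = \sum_{t=1}^{T} \frac{\hat{\theta}_{tm} - \phi_m}{\sigma_m^2} = 0
\quad\Longrightarrow\quad
\hat{\phi}_m = \frac{1}{T} \sum_{t=1}^{T} \hat{\theta}_{tm}
```

so the optimal central parameters are the average of the adapted task parameters, recovering (per module) Reptile's move-toward-the-adapted-parameters update.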
To derive the gradient of Eq. (42) with respect to $\sigma^2$, notice that when $\phi$ is estimated as the MAP on the training subsets of all tasks, it becomes a function of $\sigma^2$.
(The experiments use a 9-layer network: 4 conv layers, 4 batch-norm layers, and a linear output layer.)
This work proposes a hierarchical Bayesian model for meta-learning that places a shrinkage prior on each module to allow learning the extent to which each module should adapt, without a limitation on the adaptation horizon.
Our formulation includes MAML, Reptile, and iMAML as special cases, empirically discovers a small set of task-specific modules in various domains, and shows promising improvement in a practical TTS application with low data and long task adaptation.
As a general modular meta-learning framework, it allows many interesting extensions, including incorporating alternative Bayesian inference algorithms, modular structure learning, and learn-to-optimize methods.
Adaptation may not generalize to a task with fundamentally different characteristics from the training distribution. Applying the method to a new task without examining task similarity runs the risk of transferring inductive bias from meta-training to an out-of-sample task.
Think about the connections among these papers (a probabilistic-graphical-model explanation):
Similar to the ANIL paper
Meta-Learning Probabilistic Inference For Prediction
Probabilistic model-agnostic meta-learning
Amortized Bayesian Meta-Learning
Bayesian TAML