# Modular Meta-Learning with Shrinkage

## Motivation

The ability to meta-learn large models with only a small number of task-specific components is important in many real-world problems, for example:

* multi-speaker text-to-speech synthesis

Updating only these task-specific modules allows the model to be adapted to low-data tasks for as many steps as necessary without risking overfitting.

Existing meta-learning approaches either **do not scale to long adaptation** or else rely on handcrafted task-specific architectures.

A new meta-learning approach based on Bayesian shrinkage is proposed to automatically discover and learn both task-specific and general reusable modules.

The method works well in the few-shot text-to-speech domain.

MAML, iMAML, and Reptile are special cases of the proposed method.

## Gradient-based Meta-Learning

* Each task $$t$$, denoted $$\mathcal{T}_t$$, is associated with a finite dataset $$\mathcal{D}_{t}=\{\mathbf{x}_{t, n}\}_{n=1}^{N_t}$$
* The dataset of task $$t$$ is split into training and validation subsets: $$\mathcal{D}_{t}^{train}, \mathcal{D}_{t}^{val}$$
* meta parameters $$\boldsymbol{\phi} \in \mathbb{R}^{D}$$
* task-specific parameters $$\boldsymbol{\theta}_{t} \in \mathbb{R}^{D}$$
* loss function $$\ell\left(\mathcal{D}_{t} ; \boldsymbol{\theta}_{t}\right)$$

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1ytVrabx-stcHcDXA%2F-MG2GQq565UtOb53Y8sR%2Falgo1.png?alt=media\&token=aad85e43-c329-405a-bce6-0ddce8a1eefc)

**Algorithm 1** is a structure of a typical meta-learning algorithm, which could be:

* MAML
* iMAML
* Reptile

1. TASKADAPT: task adaptation (**inner loop**)
2. The meta-update $$\Delta_t$$ specifies the contribution of task $$t$$ to the meta parameters. (**outer loop**)

### MAML

1. **task adaptation**: minimize the training loss $$\ell_t^{train}(\boldsymbol{\theta}_t)=\ell(\mathcal{D}_t^{train}; \boldsymbol{\theta}_t)$$ by gradient descent w.r.t. the task parameters
2. **meta parameter update**: by gradient descent on the validation loss $$\ell_t^{val}(\boldsymbol{\theta}_t)=\ell(\mathcal{D}_t^{val}; \boldsymbol{\theta}_t)$$, resulting in the meta update (gradient) for task $$t$$: $$\Delta_t^{\text{MAML}} = \nabla_{\boldsymbol{\phi}}\ell_t^{\text{val}}(\boldsymbol{\theta}_t(\boldsymbol{\phi}))$$

This approach treats the task parameters as a function of the meta parameters, and hence requires back-propagating through the entire $$L$$-step task adaptation process. When $$L$$ is large, this becomes computationally prohibitive.
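The backprop-through-adaptation idea can be sketched on a toy 1-D quadratic task (the losses and constants below are illustrative, not from the paper): the meta-gradient differentiates the validation loss through the inner gradient step by the chain rule.

```python
# Toy 1-D MAML step: train loss (theta - a)^2, val loss (theta - b)^2.
# One inner gradient step, then the meta-gradient is obtained by
# differentiating through that step, as MAML does.

def maml_meta_grad(phi, a, b, alpha):
    # inner step: theta' = phi - alpha * d/dphi (phi - a)^2
    theta = phi - alpha * 2.0 * (phi - a)
    # d theta' / d phi = 1 - 2*alpha  (backprop through the inner step)
    dtheta_dphi = 1.0 - 2.0 * alpha
    # meta-gradient: d/dphi (theta' - b)^2
    return 2.0 * (theta - b) * dtheta_dphi

def fd_meta_grad(phi, a, b, alpha, eps=1e-6):
    # finite-difference check of the analytic meta-gradient
    def val(p):
        th = p - alpha * 2.0 * (p - a)
        return (th - b) ** 2
    return (val(phi + eps) - val(phi - eps)) / (2.0 * eps)
```

With more inner steps, the chain of `dtheta_dphi` factors grows, which is exactly the back-propagation cost the text mentions.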

### Reptile

Reptile optimizes $$\boldsymbol{\theta}_t$$ on the entire dataset $$\mathcal{D}_t$$, then moves $$\boldsymbol{\phi}$$ towards the adapted task parameters, yielding $$\Delta_t^{\text{Reptile}}=\boldsymbol{\phi}-\boldsymbol{\theta}_t$$
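A minimal sketch of this update (the task loss and step sizes are made up for illustration): adapt with plain SGD on the full task loss, then return the Reptile direction $$\phi-\theta_t$$.

```python
def reptile_delta(phi, task_grad, alpha=0.1, k=50):
    # adapt theta on the full task dataset, starting from phi
    theta = phi
    for _ in range(k):
        theta -= alpha * task_grad(theta)
    # Reptile meta-update direction: Delta_t = phi - theta_t
    return phi - theta
```

For a task whose loss is minimized at c, the adapted theta converges to c, so the direction approaches phi - c.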

### iMAML

iMAML introduces an L2 regularizer $$\frac{\lambda}{2}\|\boldsymbol{\theta}_t-\boldsymbol{\phi}\|^2$$ to the training loss, and optimizes the task parameters on the regularized training loss.

Provided that this task adaptation process converges to a stationary point, implicit differentiation enables computing the meta gradient from only the final solution of the adaptation process: $$\Delta_{t}^{\mathrm{iMAML}}=\left(\mathbf{I}+\frac{1}{\lambda} \nabla_{\boldsymbol{\theta}_{t}}^{2} \ell_{t}^{\operatorname{train}}\left(\boldsymbol{\theta}_{t}\right)\right)^{-1} \nabla_{\boldsymbol{\theta}_{t}} \ell_{t}^{\mathrm{val}}\left(\boldsymbol{\theta}_{t}\right)$$
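The iMAML meta-gradient above can be sketched numerically (a toy instance with an illustrative quadratic train loss, not the paper's implementation):

```python
import numpy as np

def imaml_delta(hess_train, grad_val, lam):
    """Implicit meta-gradient: (I + H/lam)^{-1} grad_val."""
    d = grad_val.shape[0]
    return np.linalg.solve(np.eye(d) + hess_train / lam, grad_val)
```

For a quadratic train loss $$\|\theta-a\|^2$$ the Hessian is $$2\mathbf{I}$$, so with $$\lambda=2$$ the validation gradient is simply halved; in practice the linear solve is approximated (e.g. by conjugate gradient) rather than done exactly.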

## Modular Bayesian Meta-Learning

* Standard meta-learning: all parameters of a single base learner (a neural network) are updated at once.
  * inefficient and prone to overfitting
* Split the network parameters into two groups:

  * a group varying across tasks
  * a group that is shared across tasks

**In this paper (task-independent modules):**

In general, we assume that the network parameters can be partitioned into $$M$$ disjoint modules:

$$\boldsymbol{\theta}_t = (\boldsymbol{\theta}_{t,1}, \boldsymbol{\theta}_{t,2}, \dots, \boldsymbol{\theta}_{t,m}, \dots, \boldsymbol{\theta}_{t,M})$$

where $$\boldsymbol{\theta}_{t,m}$$ denotes the parameters in module $$m$$ for task $$t$$.
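A small sketch of such a partition (the helper and sizes are hypothetical, for illustration only):

```python
def partition_params(params, module_sizes):
    """Split a flat parameter vector into disjoint modules theta_{t,1..M}."""
    modules, start = [], 0
    for size in module_sizes:
        modules.append(params[start:start + size])
        start += size
    if start != len(params):
        raise ValueError("module sizes must cover all parameters exactly")
    return modules
```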

![Fig.1 Modular meta-learning](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1qm7XgOqIo5elumHg%2F-MG1szyb-BXAcx9vJiGT%2Ffig1.png?alt=media\&token=fd9f7791-eba2-4df7-a12e-37e126b18c6a)

{% hint style="warning" %}
**A shrinkage parameter** $$\sigma_m$$ **is associated with each module to control the deviation of the task-dependent parameters** $$\boldsymbol{\theta}_{t,m}$$ **from the central value** $$\boldsymbol{\phi}_m$$ **(meta parameters)**
{% endhint %}

#### How to define modules? Modules can correspond to:

* **Layers**: $$\boldsymbol{\theta}_{t,m}$$ could be the weights of the $$m\text{-th}$$ layer of a NN for task $$t$$ \[***this paper treats each layer as a module***]
* Receptive fields
* the encoder and decoder in an auto-encoder
* the heads in a multi-task learning model
* any other grouping of interest

### Hierarchical Bayesian Model

![Bayesian Shrinkage Graphical Model](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MFxiSc-QSm_8oqxSMe7%2F-MG-IeuTpZ7R_H4QeKnZ%2FHierBaye.jpg?alt=media\&token=88fe13ca-63ae-4b0a-8785-14343c81cb47)

(with a factored probability density):

$$
p\left(\boldsymbol{\theta}_{1:T}, \mathcal{D} \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi} \right) = \prod_{t=1}^{T}\prod_{m=1}^{M}\mathcal{N}(\boldsymbol{\theta}_{t,m} \mid \boldsymbol{\phi}_m, \sigma_m^2\mathbf{I})\prod_{t=1}^{T}p(\mathcal{D}_t \mid \boldsymbol{\theta}_t)
$$

* $$\boldsymbol{\phi}$$: **shared meta parameters**, the initialization of the NN parameters $$\boldsymbol{\theta}_t$$ for each task
* The $$\boldsymbol{\theta}_{t,m}$$ are conditionally independent of the parameters of all other tasks given the "central" parameters. Namely: $$\boldsymbol{\theta}_{t,m} \sim \mathcal{N}(\boldsymbol{\theta}_{t,m} \mid \boldsymbol{\phi}_m, \sigma_m^2\mathbf{I})$$ with mean $$\boldsymbol{\phi}_m$$ and variance $$\sigma_m^2$$ ($$\mathbf{I}$$ is the identity matrix)
* $$\boldsymbol{\sigma}$$: shrinkage parameters.
* $$\sigma_m^2$$: the scalar shrinkage parameter of the $$m\text{-th}$$ module, which measures the degree to which $$\boldsymbol{\theta}_{t,m}$$ can deviate from $$\boldsymbol{\phi}_m$$. If $$\sigma_m\approx 0$$, then $$\boldsymbol{\theta}_{t,m} \approx \boldsymbol{\phi}_m$$; as $$\sigma_m$$ shrinks to 0, the parameters of module $$m$$ become task independent.
* $$\color{blue}\text{meta parameters}$$: $$\boldsymbol{\Phi}=\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

For values of $$\sigma_m^2$$ near zero, the difference between the parameters $$\boldsymbol{\theta}_{t,m}$$ and the mean $$\boldsymbol{\phi}_m$$ is shrunk to zero, and thus module $$m$$ becomes $$\text{\color{red}task independent}$$.

* Thus, by learning $$\boldsymbol{\sigma}^2$$ , we can discover **which modules are task independent**.&#x20;
* These independent modules can be re-used **at meta-test time**
  * reducing the computational burden of adaptation&#x20;
  * and likely improving generalization
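A minimal numerical sketch of this discovery idea (the closed-form Gaussian MLE below is my simplification; the paper learns $$\boldsymbol{\sigma}^2$$ by gradient-based meta-optimization): modules whose adapted parameters barely deviate from $$\boldsymbol{\phi}_m$$ get $$\sigma_m^2 \approx 0$$ and can be frozen at meta-test time.

```python
import numpy as np

def module_shrinkage(thetas, phi):
    """Gaussian MLE of sigma_m^2: mean squared deviation of the adapted
    per-task parameters of one module from the central value phi_m."""
    return float(np.mean((np.asarray(thetas) - phi) ** 2))

def task_independent_modules(per_module_thetas, per_module_phis, tol=1e-3):
    """Indices of modules whose estimated sigma_m^2 is ~0 (reusable as-is)."""
    return [m for m, (th, ph) in enumerate(zip(per_module_thetas, per_module_phis))
            if module_shrinkage(th, ph) < tol]
```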

## Meta-Learning as Parameter Estimation

#### Goal: estimate the parameters $$\boldsymbol{\phi}$$ and $$\boldsymbol{\sigma}^2$$ .

Standard solution: Maximize the marginal likelihood (**intractable**)

$$p(\mathcal{D} \mid \boldsymbol{\phi},\boldsymbol{\sigma}^2)=\int p(\mathcal{D} \mid \boldsymbol{\theta})\,p(\boldsymbol{\theta} \mid \boldsymbol{\phi},\boldsymbol{\sigma}^2)\,\mathrm{d}\boldsymbol{\theta}$$

**Approximation**: two principled alternative approaches for parameter estimation, both based on **maximizing the predictive likelihood** ***over validation subsets*** (the **query set**).

## Approximation Method 1: Parameter estimation via the predictive likelihood

* **(Relation with MAML and iMAML)**

Our goal is to minimize the average negative predictive log-likelihood over $$T$$ validation tasks:

![the average predictive log-likelihood](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1Hb8jmK3-WedzXGQk%2F-MG1R2XrCsDI4vC6jaHL%2Fequ2.jpg?alt=media\&token=bad6e283-4329-471b-8dd3-c4f633a45129)

* *there might be a typo here (missing a minus) in equation (2) in the paper*

{% hint style="info" %}
**Justification**:

Based on **two assumptions**:

* the training and validation data (support and query sets) are distributed i.i.d. according to some distribution $$\nu(\mathcal{D}_t^{train}, \mathcal{D}_t^{val})$$
* The law of large numbers.&#x20;

as $$T\rightarrow \infty$$ :

$$\ell_{\mathrm{PLL}}\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right) \rightarrow \mathbb{E}_{\color{blue}\nu\left(\mathcal{D}_{t}^{\text{train}}\right)}\left[\mathrm{KL}\left(\nu\left(\mathcal{D}_{t}^{\text{val}} \mid \mathcal{D}_{t}^{\text{train}}\right) \,\|\, p\left(\mathcal{D}_{t}^{\text{val}} \mid \mathcal{D}_{t}^{\text{train}}, \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)\right)\right]+\mathrm{H}\left(\nu\left(\mathcal{D}_{t}^{\text{val}} \mid \mathcal{D}_{t}^{\text{train}}\right)\right)$$

* $$\nu\left(\mathcal{D}_{t}^{\text{val}} \mid \mathcal{D}_{t}^{\text{train}}\right)$$: true distribution
* $$p\left(\mathcal{D}_{t}^{\text{val}} \mid \mathcal{D}_{t}^{\text{train}}, \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$: predictive distribution

Thus, **minimizing** $$\ell_{PLL}$$ **w.r.t. the meta parameters** corresponds to selecting the predictive distribution that is, on average, closest to the true predictive distribution.

**Proof**:

* <https://wiseodd.github.io/techblog/2017/01/26/kl-mle/>
* <https://glassboxmedicine.com/2019/12/07/connections-log-likelihood-cross-entropy-kl-divergence-logistic-regression-and-neural-networks/>
* <https://www.jessicayung.com/maximum-likelihood-as-minimising-kl-divergence/>
* <https://www.hongliangjie.com/2012/07/12/maximum-likelihood-as-minimize-kl-divergence/>

~~**Note**~~: $$\color{blue}\nu(\mathcal{D}_t^{train})$$ ?
{% endhint %}

**New problem**: the computation of $$\ell_{PLL}$$ is not feasible due to the intractable integral in equation (2).

So we use a simple **maximum a posteriori (MAP) approximation of the task parameters:**

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1mcPL_riFtWoHGs_e%2F-MG1nZzulb2QQV-zy6xl%2Feq4.png?alt=media\&token=faf251e8-d0c9-4d44-9700-6b88b77e3656)

{% hint style="info" %}
**Where is this from?**

&#x20;<img src="https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1nmOB5O5WqEgAechj%2F-MG1o-PQk44zJ1zTsqxc%2Feq111.jpg?alt=media&#x26;token=2705766a-ab38-4c28-8d33-57d8bc27c8c5" alt="" data-size="original">  (1)

For a specific individual task $$t$$ :

$$\begin{align} \ell_t^{train} :=& -\log p\left(\boldsymbol{\theta}_{t}, \mathcal{D}_t \mid \boldsymbol{\sigma}^2, \boldsymbol{\phi} \right) \\ =& -\log p(\boldsymbol{\theta}_t \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^2)-\log p(\mathcal{D}_t \mid \boldsymbol{\theta}_t) \end{align}$$
{% endhint %}

After obtaining the MAP estimates, we can approximate $$\ell_{PLL}(\boldsymbol{\sigma}^2, \boldsymbol{\phi})$$ as follows:

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1qm7XgOqIo5elumHg%2F-MG1s2iMkljzO0CBDchU%2Feq5.png?alt=media\&token=56cb05b1-74ad-4989-83e8-7ace9778d3f3)

* **Individual task adaptation** follows from equation (4)
* **Meta updating** from minimizing equation (5)

#### So minimizing equation (5) is a bi-level optimization problem which requires solving equation (4).
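The bi-level structure can be sketched in closed form on a 1-D quadratic (all constants below are illustrative): the inner MAP problem of equation (4) balances the data fit against the shrinkage prior, and the outer objective of equation (5) evaluates the adapted parameter on validation data.

```python
def inner_map(a, phi, sigma2):
    # argmin_theta of (theta - a)^2 + (theta - phi)^2 / (2 * sigma2),
    # a 1-D instance of the MAP problem in eq. (4); solved in closed form
    return (2.0 * a + phi / sigma2) / (2.0 + 1.0 / sigma2)

def outer_val_loss(a, b, phi, sigma2):
    # eq. (5): validation loss evaluated at the inner MAP solution
    return (inner_map(a, phi, sigma2) - b) ** 2
```

Note the two limits: with large sigma2 the inner solution follows the task data optimum a, while with tiny sigma2 it is pinned to phi (task independent), matching the shrinkage discussion above.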

#### How to solve this problem?

Compute the gradient of the approximate predictive log-likelihood $$\ell_t^{val}(\hat{\boldsymbol{\theta}}_t)$$ w.r.t. the meta parameters $$\boldsymbol{\Phi}=\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$:

$$
\nabla_{\boldsymbol{\Phi}} \ell_{t}^{\mathrm{val}}\left(\hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi})\right)=\nabla_{\hat{\boldsymbol{\theta}}_{t}}\ell_{t}^{\mathrm{val}}\left(\hat{\boldsymbol{\theta}}_{t}\right)\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi}) \qquad ({\color{red}Eq.*})
$$

* $$\nabla_{\hat{\boldsymbol{\theta}}_{t}}\ell_{t}^{\mathrm{val}}\left(\hat{\boldsymbol{\theta}}_{t}\right)$$ is straightforward.
* $$\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi})$$ is not.

**Two methods for** $$\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi})$$**:**

### 1) Explicit computation of $$\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi})$$

#### (Relation with MAML)

{% hint style="info" %} <img src="https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1uS2OA-VgI7IhFb8g%2F-MG1v0YdMR0sNP5lOYmQ%2Feq6.png?alt=media&#x26;token=9d0b7b64-da34-4ed5-8451-86c891ece92c" alt="" data-size="original"> (6)

If optimizing $$\ell_t^{train}$$ requires **only a small number of local gradient steps**, we can compute the update for the meta parameters ($$\boldsymbol{\phi}, \boldsymbol{\sigma}^2$$) by back-propagating through $$\hat{\boldsymbol{\theta}}_t$$, yielding (6).

This update reduces to that of MAML if $$\sigma^2_m\rightarrow \infty$$ for all modules and is thus denoted $$\sigma$$-MAML.

**Why?**  (my explanation)

According to equation (4):

$$\ell_{t}^{\operatorname{train}}:=-\log p\left(\mathcal{D}_{t}^{\operatorname{train}} \mid \boldsymbol{\theta}_{t}\right)-\log p\left(\boldsymbol{\theta}_{t} \mid \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

$$p\left(\boldsymbol{\theta}_{t} \mid \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)=\mathcal{N}(\boldsymbol{\theta}_{t} \mid \boldsymbol{\phi}, \boldsymbol{\sigma}^{2}) = \frac{1}{\sigma \sqrt{2\pi}}e^{-\frac{(\theta_t-\phi)^2}{2\sigma^2}}$$

As $$\sigma^2 \rightarrow \infty$$, the log-prior term becomes flat, so $$\ell_t^{train}$$ reduces to MAML's unregularized training loss.
{% endhint %}
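The reduction can also be seen numerically (a toy sketch with made-up values): the gradient of equation (4) adds a prior term $$(\theta-\phi)/\sigma^2$$ to the data gradient, and that term vanishes as $$\sigma^2 \rightarrow \infty$$, leaving MAML's unregularized inner gradient.

```python
def inner_grad(theta, phi, sigma2, data_grad):
    # gradient of eq. (4): data NLL term plus Gaussian log-prior term
    return data_grad(theta) + (theta - phi) / sigma2
```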

### 2) Implicit computation of $$\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi})$$

#### (Relation with iMAML)

* We are more interested in **long adaptation horizons.**
* **Implicit Function Theorem (compute the gradient of** $$\hat{\boldsymbol{\theta}}_t$$ **w.r.t.** $$\boldsymbol{\phi}$$ **and** $$\boldsymbol{\sigma}^2$$**)**

{% hint style="warning" %} <img src="https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MG1uS2OA-VgI7IhFb8g%2F-MG1wWTSE7UpZA72UX9P%2Feq7.png?alt=media&#x26;token=e8d089bb-1b91-4dd9-95bf-52459204e8aa" alt="" data-size="original"> (7)

* $$\boldsymbol{\Phi}=\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$: full vector of meta parameters
* $$\mathbf{H}_{ab}=\nabla_{a, b}^{2} \ell_{t}^{\text{train}}$$, namely $$\mathbf{H}_{\theta_t\theta_t}=\nabla_{\theta_t}^{2} \ell_{t}^{\text{train}}$$ and $$\mathbf{H}_{\theta_t\Phi}=\nabla_{\theta_t,\Phi}^{2} \ell_{t}^{\text{train}}$$
* Derivatives are evaluated at the stationary point $$\boldsymbol{\theta}_{t}=\hat{\boldsymbol{\theta}}_{t}\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

**Derivation**:

Rewrite equ (4) as:

$$\hat{\boldsymbol{\theta}}_{t}\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)=\underset{\boldsymbol{\theta}_{t}}{\operatorname{argmin}}\ \ell_{t}^{\operatorname{train}}\left(\boldsymbol{\theta}_{t}, \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

where $$\ell_{t}^{\operatorname{train}}:=-\log p\left(\mathcal{D}_{t}^{\operatorname{train}} \mid \boldsymbol{\theta}_{t}\right)-\log p\left(\boldsymbol{\theta}_{t} \mid \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

So, $$\hat{\boldsymbol{\theta}}_{t}\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$ is the stationary point of $$\ell_{t}^{\operatorname{train}}\left(\boldsymbol{\theta}_{t}, \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$.

Based on the implicit function theorem,

$$\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi}) = -\left(\nabla_{\boldsymbol{\theta}_{t}, \boldsymbol{\theta}_{t}}^{2}\ \ell_{t}^{\mathrm{train}}\right)^{-1} \nabla_{\boldsymbol{\theta}_{t},\boldsymbol{\Phi}}^{2}\ \ell_{t}^{\mathrm{train}}$$

Plugging the equation above into $$({\color{red}Eq.*})$$ completes the derivation.

**The meta update for** $$\boldsymbol{\phi}$$ **is equivalent to that of iMAML when** $$\sigma_m$$ **is constant for all** $$m$$**.**

**Why?**

According to equation (4):

$$\ell_{t}^{\operatorname{train}}:=-\log p\left(\mathcal{D}_{t}^{\operatorname{train}} \mid \boldsymbol{\theta}_{t}\right)-\log p\left(\boldsymbol{\theta}_{t} \mid \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

Because all modules share a constant variance, $$\sigma_m = \sigma$$, we can expand the log-prior term for the task parameters $$\boldsymbol{\theta}_t$$ in equation (4), plugging in the normal prior assumption, as follows:
{% endhint %}
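As a sanity check of the implicit gradient (a 1-D sketch with illustrative constants): for the inner loss $$(\theta-a)^2 + (\theta-\phi)^2/(2\sigma^2)$$, the implicit-function-theorem formula matches the derivative of the closed-form minimizer.

```python
def ift_dtheta_dphi(sigma2):
    # implicit-function-theorem gradient -H_tt^{-1} H_tp for the 1-D
    # inner loss (theta - a)^2 + (theta - phi)^2 / (2 * sigma2)
    h_tt = 2.0 + 1.0 / sigma2      # d^2 l / d theta^2
    h_tp = -1.0 / sigma2           # d^2 l / d theta d phi
    return -h_tp / h_tt

def fd_dtheta_dphi(sigma2, a=1.0, phi=0.0, eps=1e-6):
    # finite-difference derivative of the closed-form minimizer
    def theta_hat(p):
        return (2.0 * a + p / sigma2) / (2.0 + 1.0 / sigma2)
    return (theta_hat(phi + eps) - theta_hat(phi - eps)) / (2.0 * eps)
```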

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MHAJltIp06cF3xJz76N%2F-MHAXD6fM5yA6EiRMmgN%2Fb2.png?alt=media\&token=59df61bd-9ad4-4342-a414-dfb2c628cda1)

{% hint style="warning" %}
Meta-update of iMAML: $$\Delta_{t}^{\mathrm{iMAML}}=\left(\mathbf{I}+\frac{1}{\lambda} \nabla_{\boldsymbol{\theta}_{t}}^{2} \ell_{t}^{\operatorname{train}}\left(\boldsymbol{\theta}_{t}\right)\right)^{-1} \nabla_{\boldsymbol{\theta}_{t}} \ell_{t}^{\mathrm{val}}\left(\boldsymbol{\theta}_{t}\right)$$

If we define the regularization scale $$\lambda=\frac{1}{\sigma^2}$$ and plug the definition $$\ell^{train}_t:=-\log p(\mathcal{D}^{train}_t \mid \boldsymbol{\theta}_t)$$ into the meta-update of iMAML, the final equation is exactly the same as equation (41).
{% endhint %}

## Approximation Method 2: Estimate $$\phi$$ via MAP approximation

* **Relation with Reptile**

Considering that $$\boldsymbol{\phi}$$ depends on all training data and integrating it out, we can rewrite the predictive likelihood in terms of the joint posterior over $$(\boldsymbol{\theta}_{1:T},\boldsymbol{\phi})$$, i.e.

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MHAJltIp06cF3xJz76N%2F-MHAeGUoRf39UMy4VMv5%2Feq8.png?alt=media\&token=407d3891-3334-4ac4-ac62-9c95477c2b6b)

This is similarly intractable, so we use a MAP approximation for both the task parameters **and the central meta parameters**:

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MHAJltIp06cF3xJz76N%2F-MHAem_v3uCr3rKGlpSd%2Feq9.png?alt=media\&token=158d3bc4-44cf-431e-800f-b8368ffca750)

(assuming a flat, non-informative prior on $$\boldsymbol{\phi}$$, i.e. effectively no prior at all, the second term above is dropped)

By plugging the MAP estimate of $$\boldsymbol{\phi}$$ (equation 9) into equation (8) and scaling by $$1/T$$, we derive the approximate predictive log-likelihood as:

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MHAJltIp06cF3xJz76N%2F-MHAgqoxGFroXM6Q4ij6%2Feq42.png?alt=media\&token=8810b8fb-581a-4508-a9e1-b057ab459247)

{% hint style="info" %}
compare this to Equ.5
{% endhint %}

The meta update for $$\phi$$ can be obtained by differentiating (9) w\.r.t. $$\phi$$ .

To derive the gradient of Eq. (42) with respect to $$\sigma^2$$ , notice that when $$\phi$$ is estimated as the MAP on the training subsets of all tasks, it becomes a function of $$\sigma^2$$ .

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MHAJltIp06cF3xJz76N%2F-MHAprjDzAENtMsJaEaA%2Feq10.png?alt=media\&token=19960acb-461b-4c51-8af8-13454baa941a)

{% hint style="info" %}
For a specific individual task $$t$$ ,&#x20;

$$\nabla_{\boldsymbol{\sigma}^2} \ell_{t}^{\mathrm{val}}\left(\hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\sigma}^2)\right)=\nabla_{\hat{\boldsymbol{\theta}}_{t}}\ell_{t}^{\mathrm{val}}\left(\hat{\boldsymbol{\theta}}_{t}\right)\nabla_{\boldsymbol{\sigma}^2} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\sigma}^2) \qquad ({\color{red}Eq.**})$$

We approximate by ignoring the dependence of $$\hat{\boldsymbol{\phi}}$$ on $$\boldsymbol{\sigma}^2$$. Then $$\boldsymbol{\phi}$$ becomes a constant when computing $$\nabla_{\boldsymbol{\sigma}^2} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\sigma}^2)$$, and the derivation above for iMAML applies with $$\boldsymbol{\Phi}$$ replaced by $$\boldsymbol{\sigma}^2$$, giving the implicit gradient in equation (10).

**Derivation**:

Rewrite equ (4) as:

$$\hat{\boldsymbol{\theta}}_{t}\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)=\underset{\boldsymbol{\theta}_{t}}{\operatorname{argmin}}\ \ell_{t}^{\operatorname{train}}\left(\boldsymbol{\theta}_{t}, \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

where $$\ell_{t}^{\operatorname{train}}:=-\log p\left(\mathcal{D}_{t}^{\operatorname{train}} \mid \boldsymbol{\theta}_{t}\right)-\log p\left(\boldsymbol{\theta}_{t} \mid \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$

So, $$\hat{\boldsymbol{\theta}}_{t}\left(\boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$ is the stationary point of $$\ell_{t}^{\operatorname{train}}\left(\boldsymbol{\theta}_{t}, \boldsymbol{\sigma}^{2}, \boldsymbol{\phi}\right)$$.

Based on the implicit function theorem,

$$\nabla_{\boldsymbol{\Phi}} \hat{\boldsymbol{\theta}}_{t}(\boldsymbol{\Phi}) = -\left(\nabla_{\boldsymbol{\theta}_{t}, \boldsymbol{\theta}_{t}}^{2}\ \ell_{t}^{\mathrm{train}}\right)^{-1} \nabla_{\boldsymbol{\theta}_{t},\boldsymbol{\Phi}}^{2}\ \ell_{t}^{\mathrm{train}}$$

Plugging the equation above into $$({\color{red}Eq.**})$$ completes the derivation.

**Reptile is a special case of this method when** $$\sigma^2_m\rightarrow \infty$$ **and we choose a learning rate proportional to** $$\sigma_m^2$$ **for** $$\phi_m$$**. We thus refer to it as** $$\sigma$$**-Reptile.**

$$\Delta_t^{\text{Reptile}}=\boldsymbol{\phi}-\boldsymbol{\theta}_t$$
{% endhint %}
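A small numerical sketch of this reduction (the helpers and constants are mine, for illustration): the $$\boldsymbol{\phi}$$-gradient of the log-prior term is $$\sum_t (\phi-\theta_t)/\sigma^2$$, so a learning rate proportional to $$\sigma^2$$ cancels the $$1/\sigma^2$$ factor and leaves the Reptile direction $$\sum_t(\phi-\theta_t)$$.

```python
def phi_grad(thetas, phi, sigma2):
    # gradient w.r.t. phi of sum_t ||theta_t - phi||^2 / (2 * sigma2)
    return sum(phi - th for th in thetas) / sigma2

def sigma_reptile_update(thetas, phi, sigma2, beta):
    # step size beta * sigma2 cancels the 1/sigma2 factor,
    # giving a Reptile-style update independent of sigma2
    return phi - (beta * sigma2) * phi_grad(thetas, phi, sigma2)
```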

## Results

### Module discovery in image classification

* A 9-layer network: 4 conv layers, 4 batch-norm layers, and a linear output layer
* The findings are similar to those of the ANIL paper (<https://openreview.net/forum?id=rkgMkCEtPB>)

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MHArv_EMiOwpdUNf-AN%2F-MHAs6H-5uewd6097fMi%2Fres1.jpg?alt=media\&token=633f6633-0ce7-4a87-bda0-2ce508555a87)

### Module discovery in text-to-speech

![](https://1687130946-files.gitbook.io/~/files/v0/b/gitbook-legacy-files/o/assets%2F-MEnQbUIupyAn8eMmrmG%2F-MHArv_EMiOwpdUNf-AN%2F-MHAtDyLvFfSbafH95eF%2Ffig2.png?alt=media\&token=6794f567-b611-4a8a-bbbd-4097d56b2cc1)

## Conclusion

This work proposes a **hierarchical Bayesian model for meta-learning** that places a shrinkage prior on each module to allow learning the extent to which each module should adapt, without a limitation on the adaptation horizon.&#x20;

Our formulation includes MAML, Reptile, and iMAML as special cases, empirically discovers a small set of task-specific modules in various domains, and shows promising improvement in a practical TTS application with low data and long task adaptation.

As a general modular meta-learning framework, it allows many interesting extensions, including incorporating alternative Bayesian inference algorithms, modular structure learning, and learn-to-optimize methods.

**Adaptation may not generalize to a task with fundamentally different characteristics from the training distribution.** Applying our method to a new task without examining task similarity risks transferring inductive bias from meta-training to an out-of-distribution task.

Think about the connections among these papers (from a probabilistic graphical model perspective):

* Meta-Learning Probabilistic Inference For Prediction (<https://openreview.net/pdf?id=HkxStoC5F7>)
* Probabilistic model-agnostic meta-learning (<https://arxiv.org/abs/1806.02817>)
* Amortized Bayesian Meta-Learning (<https://openreview.net/pdf?id=rkgpy3C5tX>)
* Bayesian TAML (<https://openreview.net/forum?id=rkeZIJBYvr>)

## Reference

* <https://arxiv.org/pdf/1909.05557.pdf>
