Adaptive Risk Minimization: A Meta-Learning Approach for Tackling Group Shift

9-28-2020

Motivation

A fundamental assumption of most machine learning algorithms is that the training and test data are drawn from the same underlying distribution.

However, this assumption is violated in almost all practical applications: machine learning systems are regularly tested on data that are structurally different from the training set, due to temporal correlations, specific end users, or other factors.

We refer to this kind of shift as structural shift, or group shift.

Imagine an image classification system: after training on a large database of past images, the system is deployed to specific end users. Each user takes photos with different cameras, in different locations, and of different subjects, leading to a shift in the input distribution.

This test scenario must be carefully considered when building machine learning systems for real world applications.

Overview

We aim to achieve both robustness to group shift and strong predictive performance.

To do so, we introduce a new assumption for tackling group shift: the unlabeled test data from a group arrive together in a batch, and they are only available at test time.

Unlike in the domain adaptation and transductive learning settings:

  • we do not need to specify the test distribution at training time

  • we are not limited to a single test distribution

This assumption is reasonable in many standard supervised learning setups.
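To make the batch assumption concrete, here is a minimal Python sketch of one simple way a model can exploit an unlabeled test batch: standardizing each feature with the batch's own statistics, in the spirit of batch-normalization-style adaptation. This is an illustration, not the paper's implementation. The functions `normalize_with_batch_stats` and `predict`, the one-feature linear classifier, and the toy data are hypothetical. The toy data is a group whose inputs are offset by +10, and its labels are assumed to depend on each point's position relative to the group's own mean.

```python
from statistics import fmean, pstdev


def normalize_with_batch_stats(batch, eps=1e-5):
    """Standardize each feature using the mean/std of the unlabeled batch itself.

    Assumes every row in `batch` comes from the same group (the batch assumption).
    """
    n_features = len(batch[0])
    means = [fmean(x[j] for x in batch) for j in range(n_features)]
    stds = [pstdev([x[j] for x in batch]) for j in range(n_features)]
    return [[(x[j] - means[j]) / (stds[j] + eps) for j in range(n_features)] for x in batch]


def predict(batch, w, b=0.0):
    """Hypothetical linear classifier trained on zero-centered data: label 1 iff w.x + b > 0."""
    return [1 if sum(wj * xj for wj, xj in zip(w, x)) + b > 0 else 0 for x in batch]


# Hypothetical test group whose inputs are shifted by +10 (e.g., a brighter camera).
shifted_batch = [[9.0], [9.5], [10.5], [11.0]]
w = [1.0]

print(predict(shifted_batch, w))                              # [1, 1, 1, 1]: the shift fools the model
print(predict(normalize_with_batch_stats(shifted_batch), w))  # [0, 0, 1, 1]: adapted with batch stats
```

A single test point provides no usable spread, so this kind of adaptation only becomes possible when several unlabeled points from the same group arrive together.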

Key idea: rather than only trying to be robust, we use meta-learning to train a model that adapts to group distribution shift using unlabeled data.

We acquire such models by learning to adapt to training batches sampled according to different sub-distributions (groups, tasks), which simulate structural distribution shifts that may occur at test time.
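The group-based batch sampling described above can be sketched in Python as follows. The sketch assumes that each training example carries a group label (for example, a user id) and that groups are chosen uniformly at random. `make_group_batches` and the toy data are hypothetical names for illustration. A full meta-training loop would also update the model and its adaptation procedure on each sampled batch.

```python
import random
from collections import defaultdict


def make_group_batches(examples, batch_size, num_batches, seed=0):
    """Yield (group_id, batch) pairs where every batch comes from a single group.

    `examples` is a list of (x, y, group_id) triples; group ids are assumed to be
    known at training time. Each batch simulates one possible test distribution.
    """
    rng = random.Random(seed)
    by_group = defaultdict(list)
    for x, y, g in examples:
        by_group[g].append((x, y))
    groups = sorted(by_group)
    for _ in range(num_batches):
        g = rng.choice(groups)  # sample a group (assumed uniform here)
        pool = by_group[g]
        batch = rng.sample(pool, min(batch_size, len(pool)))  # batch drawn within that group
        yield g, batch


# Toy data: two hypothetical users whose inputs live in different ranges.
data = [(i, i % 2, "user_a") for i in range(6)] + [(100 + i, i % 2, "user_b") for i in range(6)]

for g, batch in make_group_batches(data, batch_size=3, num_batches=4):
    print(g, batch)  # every (x, y) in `batch` comes from group g
```

Sampling within a single group, rather than shuffling all groups together as in standard training, exposes the model during training to the same kind of shifted batch it will see from a new user at test time.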

We propose adaptive risk minimization (ARM), a formalization of this setting that lends itself to meta-learning methods.
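One way to write an objective of this kind is sketched below in our own notation, which is not necessarily the paper's. A group $z$ is sampled, and a batch of $K$ labeled examples is drawn from that group. The model parameters $\theta$ are adapted to $\theta'$ by a learned procedure $h_\phi$ that sees only the unlabeled inputs. The loss $\ell$ of the adapted model $f(\cdot\,; \theta')$ is then averaged over the batch.

```latex
\min_{\theta,\,\phi}\;
\mathbb{E}_{z \sim p(z)}\;
\mathbb{E}_{(x_k, y_k)_{k=1}^{K} \sim p(x, y \mid z)}
\left[ \frac{1}{K} \sum_{k=1}^{K} \ell\big( f(x_k; \theta'),\, y_k \big) \right],
\qquad
\theta' = h_\phi\big( \theta;\, x_1, \ldots, x_K \big)
```

If $h_\phi$ simply returns $\theta$ (no adaptation), the nested expectation collapses to the ordinary expected risk over $p(x, y)$, which is what empirical risk minimization targets.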

Figure: A schematic of our problem setting and approach.

NOTE: update the note based on the newer version of the paper.
