Learning Causal Models Online

arXiv, 10-9-2020

Motivation

Predictive models learned from observational data that does not cover the complete data distribution can rely on spurious correlations for making predictions. These correlations make the models brittle and hinder generalization. One solution for achieving strong generalization is to incorporate causal structures into the models; such structures constrain learning by ignoring correlations that contradict them. However, learning these structures is a hard problem in itself, and it is not clear how to combine the machinery of causality with online continual learning. In this work, we take an indirect approach to discovering causal models: instead of searching for the true causal model directly, we propose an online algorithm that continually detects and removes spurious features. Our algorithm builds on the idea that the correlation of a spurious feature with the target is not constant over time; as a result, the weight associated with that feature keeps changing. We show that by continually removing such features, our method converges to solutions that generalize strongly. Moreover, combined with random search, our method can also discover non-spurious features from raw sensory data. Finally, our work highlights that the information present in the temporal structure of the problem, which is destroyed by shuffling the data, is essential for detecting spurious features online.

Problem Setup

Learning to make predictions in a Markov decision process (MDP) $(\mathcal{S}, \mathcal{A}, r, p)$:

  • set of states $\mathcal{S}$

  • set of actions $\mathcal{A}$

  • reward function $r$

  • transition model $p$

In a prediction problem, the agent has to learn a function $f_{\theta}(s_t, a_t)$, parameterized by $\theta$, whose output $\hat{y}_t$ predicts a target $y_t$. As the agent transitions to the new state $s_{t+1}$, it receives the ground-truth label $y_t$ from the environment and accumulates regret given by $\mathcal{L}(y_t, \hat{y}_t)$. The agent can use this label to update its estimate of $f_{\theta}$.
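
A minimal sketch of this predict-then-reveal loop, assuming a linear predictor, squared-error regret, and a hypothetical `ToyEnv` whose interface (and the omission of actions) is an illustrative simplification rather than anything from the paper:

```python
import numpy as np

class ToyEnv:
    """Hypothetical stand-in environment: states are random binary feature
    vectors and the label is a fixed linear function of the state."""
    def __init__(self, rng):
        self.rng = rng
        self.s = rng.integers(0, 2, size=3).astype(float)

    def step(self):
        y = 2.0 * self.s[0]                # ground-truth label for the state being left
        self.s = self.rng.integers(0, 2, size=3).astype(float)
        return self.s, y                   # next state s_{t+1} and revealed label y_t

def run_prediction_agent(env, num_steps, step_size=0.1):
    """Predict-then-reveal loop: commit to a prediction for s_t, observe the
    ground-truth label on transitioning, accumulate regret, then update."""
    w = np.zeros(3)                        # parameters theta of a linear f_theta
    s, regret = env.s, 0.0
    for _ in range(num_steps):
        y_hat = w @ s                      # prediction f_theta(s_t)
        s_next, y = env.step()             # transition to s_{t+1} reveals y_t
        regret += (y - y_hat) ** 2         # accumulate regret L(y_t, y_hat_t)
        w += step_size * (y - y_hat) * s   # use the label to update f_theta
        s = s_next
    return w, regret

w, regret = run_prediction_agent(ToyEnv(np.random.default_rng(0)), num_steps=1000)
print(w)  # approaches [2, 0, 0]: the label only ever depended on the first feature
```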

Objective

Consider $n$ features $x = f_1, f_2, \ldots, f_n$ that can be linearly combined using parameters $w_1, w_2, \ldots, w_n$ to predict a target $y$. Moreover, assume that all features are binary (0 or 1). Given these features, our goal is to identify and remove the spurious ones.

We define a feature $f_i$ to have a spurious correlation with the target $y$ if the expected value of the target given $f_i$ is not constant in temporally distant parts of the MDP, i.e., $\mathbb{E}[y \mid f_i = 1]$ slowly changes as the agent interacts with the world.
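
As a toy illustration of this definition (the linear drift, the feature-isolating mask, and all constants below are illustrative assumptions, not values from the paper), $\mathbb{E}[y \mid f_i = 1]$ can be estimated in two temporally distant windows and compared:

```python
import numpy as np

rng = np.random.default_rng(0)
T = 20_000
f_stable = rng.integers(0, 2, size=T)                 # stable feature
f_spur = rng.integers(0, 2, size=T)                   # spurious feature
drift = np.linspace(1.0, -1.0, T)                     # its correlation with y drifts slowly
y = 2.0 * f_stable + drift * f_spur + 0.1 * rng.standard_normal(T)

early, late = slice(0, T // 4), slice(-T // 4, None)  # temporally distant windows
for name, f, other in [("stable", f_stable, f_spur), ("spurious", f_spur, f_stable)]:
    for w_name, w in [("early", early), ("late", late)]:
        mask = (f[w] == 1) & (other[w] == 0)          # isolate this feature's effect
        print(f"{name:9s}{w_name}: E[y | f=1] ~ {y[w][mask].mean():+.2f}")
# the stable feature stays near +2.0 in both windows,
# while the spurious one moves from roughly +0.75 to -0.75
```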

Detecting Spurious Features

For a linear prediction problem, detecting whether the $i$-th feature is spurious is equivalent to tracking the stability of its weight $w_i$ across time: suppose the online learner is always learning from the most recent data with the standard least-mean-squares update

$$w_{t+1} = w_t + \alpha \, (y_t - \hat{y}_t) \, x_t.$$

Then the weight corresponding to a feature with a constant expected value, $\mathbb{E}[y \mid f_i = 1] = c$, converges to a fixed magnitude, whereas if $\mathbb{E}[y \mid f_i = 1]$ is changing, $w_i$ tracks this change by changing its magnitude over time. This implies that weights that are constantly changing in a stationary prediction problem encode spurious correlations. We can approximate the change in the weight $w_i$ over time by estimating its variance online. Our hypothesis is that spurious features have weights with high variance.
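
A minimal sketch of this detection idea, assuming the least-mean-squares update above and an exponential-moving-average estimate of each weight's mean and variance; the decay constant, threshold, and toy stream are illustrative choices, not values from the paper:

```python
import numpy as np

def detect_spurious(stream, n_features, step_size=0.1, decay=0.99, threshold=0.01):
    """Online linear learner that flags features whose weights keep changing.

    A weight that settles to a fixed magnitude has low running variance; a
    weight that keeps tracking a drifting correlation has high variance."""
    w = np.zeros(n_features)
    w_mean = np.zeros(n_features)                     # running mean of each weight
    w_var = np.zeros(n_features)                      # running variance of each weight
    for x_t, y_t in stream:
        y_hat = w @ x_t
        w += step_size * (y_t - y_hat) * x_t          # always track the most recent data
        w_mean = decay * w_mean + (1 - decay) * w
        w_var = decay * w_var + (1 - decay) * (w - w_mean) ** 2
    return w_var > threshold                          # True where a feature looks spurious

def drifting_stream(rng, steps=20_000):
    """Two binary features: f_1 stably predicts y, f_2's correlation drifts."""
    for t in range(steps):
        x = rng.integers(0, 2, size=2).astype(float)
        drift = np.sin(2 * np.pi * t / 1000)          # E[y | f_2 = 1] keeps changing
        yield x, 2.0 * x[0] + drift * x[1] + 0.1 * rng.standard_normal()

flags = detect_spurious(drifting_stream(np.random.default_rng(0)), n_features=2)
print(flags)  # expected: [False  True] -- only the drifting feature is flagged
```

The decay constant trades off how quickly the variance estimate reacts against how noisy it is: a slower decay averages over a longer window and can flag slower drifts, at the cost of responding later.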
