Learning Causal Models Online
arxiv 10-9-2020
Last updated
Predictive models – learned from observational data that does not cover the complete data distribution – can rely on spurious correlations in the data for making predictions. These correlations make the models brittle and hinder generalization. One solution for achieving strong generalization is to incorporate causal structures in the models; such structures constrain learning by ignoring correlations that contradict them. However, learning these structures is a hard problem in itself. Moreover, it is not clear how to combine the machinery of causality with online continual learning. In this work, we take an indirect approach to discovering causal models: instead of searching for the true causal model directly, we propose an online algorithm that continually detects and removes spurious features. Our algorithm builds on the idea that the correlation of a spurious feature with a target is not constant over time; as a result, the weight associated with that feature is constantly changing. We show that by continually removing such features, our method converges to solutions that generalize strongly. Moreover, combined with random search, our method can also discover non-spurious features from raw sensory data. Finally, our work highlights that the information present in the temporal structure of the problem – destroyed by shuffling the data – is essential for detecting spurious features online.
Learning to make predictions in a Markov decision process (MDP):
a set of states S
a set of actions A
a reward function r
a transition model P(s′ | s, a)
In a prediction problem, the agent has to learn a function f_θ(s) to predict a target y using parameters θ. As the agent transitions to the new state s′, it receives the ground-truth label y from the environment and accumulates regret given by ℓ(f_θ(s), y). The agent can use this label to update its estimate of f_θ.
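As a minimal sketch of this online predict-observe-update loop, the snippet below uses squared error as the regret and the least-mean-squares (LMS) update; the toy stream, where the target equals the first binary feature, is an invented example, not the paper's setup:

```python
import numpy as np

rng = np.random.default_rng(0)

def lms_update(theta, x, y, alpha=0.1):
    """One online least-mean-squares step: predict, then correct toward y."""
    y_hat = theta @ x
    theta = theta + alpha * (y - y_hat) * x
    return theta, (y - y_hat) ** 2  # updated weights, instantaneous regret

# Toy stationary stream: y is 1 whenever feature 0 is active.
theta = np.zeros(3)
total_regret = 0.0
for _ in range(2000):
    x = (rng.random(3) < 0.5).astype(float)  # binary features, as in the text
    y = x[0]                                 # target determined by feature 0
    theta, r = lms_update(theta, x, y)
    total_regret += r

print(theta.round(2))  # theta[0] approaches 1, the other weights approach 0
```

Because the regret is accumulated online from the most recent transition, the learner never needs to store the stream.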
Objective
Consider n features φ_1, …, φ_n that can be linearly combined using parameters w_1, …, w_n to predict a target y. Moreover, assume that all features are binary (0 or 1). Given these features, our goal is to identify and remove the spurious ones.
We define a feature φ_i to have a spurious correlation with the target y if the expected value of the target given φ_i, E[y | φ_i = 1], is not constant in temporally distant parts of the MDP, i.e. E[y | φ_i = 1] slowly changes as the agent interacts with the world.
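To make "not constant in temporally distant parts" concrete, here is an invented stream (the sampler, drift point, and probabilities are all illustrative assumptions) where the empirical E[y | φ = 1] measured in two distant windows disagrees:

```python
import numpy as np

rng = np.random.default_rng(1)

def sample(t):
    """Hypothetical stream: the feature's correlation with y shifts over time."""
    phi = float(rng.random() < 0.5)
    p = 0.2 if t < 50_000 else 0.8  # E[y | phi = 1] differs in distant parts
    y = float(rng.random() < (p if phi else 0.5))
    return phi, y

def empirical_mean(t0, n=20_000):
    """Empirical E[y | phi = 1] over a window of n steps starting at t0."""
    ys = [y for t in range(t0, t0 + n) for phi, y in [sample(t)] if phi]
    return sum(ys) / len(ys)

print(empirical_mean(0), empirical_mean(80_000))  # ~0.2 early vs ~0.8 late
```

A learner that only ever sees a shuffled batch of this stream would observe a single averaged correlation and could not detect the drift.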
Detecting Spurious Features
For a linear prediction problem, detecting whether the i-th feature is spurious is equivalent to tracking the stability of the weight w_i across time, i.e. the online learner always learns from the most recent data using the following update:

w_i ← w_i + α (y − ŷ) φ_i,  where ŷ = Σ_j w_j φ_j
the weight corresponding to a feature with a constant expected value E[y | φ_i = 1] would converge to a fixed magnitude, whereas if E[y | φ_i = 1] is changing, w_i would track this change by changing its magnitude over time. This implies that weights that are constantly changing in a stationary prediction problem encode spurious correlations. We can approximate the change in the weight w_i over time by estimating its variance online. Our hypothesis is that spurious features have weights with high variance.
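A sketch of this detector, combining the LMS update with an exponential-moving-average estimate of each weight's mean and variance; the drift schedule, decay rates, and step size are assumptions for illustration, not the paper's exact choices:

```python
import numpy as np

rng = np.random.default_rng(0)

def weight_variances(steps=60_000, alpha=0.05, beta=0.001):
    """Track each weight's running mean/variance; high variance flags spurious.

    Feature 0 has a fixed relation to y; feature 1's correlation drifts
    slowly, so its weight keeps chasing the drift and accumulates variance.
    """
    w = np.zeros(2)
    mean = np.zeros(2)   # EMA of each weight
    var = np.zeros(2)    # EMA of each weight's squared deviation
    for t in range(steps):
        x = (rng.random(2) < 0.5).astype(float)
        p_spur = 0.5 + 0.5 * np.sin(2 * np.pi * t / 20_000)  # drifting E[y | x1 = 1]
        y = 0.9 * x[0] + p_spur * x[1]          # target mixes stable + drifting parts
        w += alpha * (y - w @ x) * x            # online LMS update
        mean = (1 - beta) * mean + beta * w     # running mean of weights
        var = (1 - beta) * var + beta * (w - mean) ** 2
    return var

var = weight_variances()
print(var)  # the spurious feature's weight variance dominates the stable one's
```

Ranking features by this online variance estimate, and pruning the highest-variance ones, is the removal step the section describes.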