Learning Steady-States of Iterative Algorithms over Graphs

ICML 2018

Motivation

Many graph analytics problems can be solved by iterative algorithms that operate over the graph structure, and the solutions of these algorithms are often characterized by a set of steady-state conditions.

  • PageRank: score of a node

  • Mean field inference: the posterior distribution of a variable

Instead of designing algorithms for each individual graph problem, we take a different perspective:

Can we design a learning framework, applicable to a diverse range of graph problems, that learns such algorithms over large graphs and reaches the steady-state solutions efficiently and effectively?

How should we represent the meta-learner for such algorithms, and how should the learning of these algorithms be carried out?

Iterative algorithm over graphs

For a graph $\mathcal{G}=(\mathcal{V}, \mathcal{E})$, with node set $\mathcal{V}$ and edge set $\mathcal{E}$, many iterative algorithms over graphs can be formulated as (a small code sketch follows the notation list below):

$$h_{v}^{(t+1)} \leftarrow \mathcal{T}\left(\left\{h_{u}^{(t)}\right\}_{u \in \mathcal{N}(v)}\right), \forall t \geqslant 1, \quad \text{and} \quad h_{v}^{(0)} \leftarrow \text{constant}, \forall v \in \mathcal{V} \qquad (1)$$

until the steady-state conditions are met:

$$h_{v}^{*}=\mathcal{T}\left(\left\{h_{u}^{*}\right\}_{u \in \mathcal{N}(v)}\right), \forall v \in \mathcal{V} \qquad (2)$$

  • node: $v, u \in \mathcal{V}$

  • the set of neighbor nodes of $v$: $\mathcal{N}(v)$

  • some operator: $\mathcal{T}(\cdot)$

  • intermediate representation of node $v$: $h_v$

  • intermediate representation of node $v$ at step $t$: $h_v^{(t)}$

  • final (converged/steady) representation of node $v$: $h_v^{*}$
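To make the generic scheme of Eqs. (1)–(2) concrete, here is a minimal Python sketch, assuming (for simplicity) that the per-node representation is a scalar and that `operator` is just a placeholder function of the neighbor values:

```python
# Generic fixed-point iteration of Eqs. (1)-(2): start from a constant and repeatedly
# apply the operator T over 1-hop neighborhoods until the values stop changing.
def iterate_to_steady_state(neighbors, operator, init=0.0, tol=1e-8, max_iter=1000):
    # neighbors[v]: list of neighbor ids of node v; operator(values) -> new scalar value.
    h = {v: init for v in neighbors}                                            # h_v^(0) <- constant
    for _ in range(max_iter):
        new_h = {v: operator([h[u] for u in neighbors[v]]) for v in neighbors}  # Eq. (1)
        if max(abs(new_h[v] - h[v]) for v in neighbors) < tol:                  # Eq. (2), up to tol
            return new_h
        h = new_h
    return h
```

The two examples below (component detection and PageRank) are essentially this loop with different choices of `operator`.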

Specifically:

1) Graph component detection problem

Goal: find all nodes within the same connected component as a source node $s\in \mathcal{V}$

How: iteratively propagate the label at node $s$ to the other nodes (see the sketch below the steady-state condition):

$$y_{v}^{(t+1)}=\max _{u \in \mathcal{N}(v)} y_{u}^{(t)}, \qquad y_{s}^{(0)}=1, \quad y_{v}^{(0)}=0, \forall v \in \mathcal{V} \setminus \{s\}$$

  • $y_s$: label of node $s$

  • $y_s^{(t)}$: label of node $s$ at step $t$

  • at the initial step $t=0$, the label at node $s$ is set to 1 (infected) and 0 for all other nodes.

Steady state: nodes in the same connected component as $s$ are infected (labeled 1):

$$y_v^{*}= \max_{u\in \mathcal{N}(v)}y_u^{*}$$
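A small sketch of this propagation, assuming an adjacency-dict encoding of the graph; I also keep each node's own label inside the max (a minor practical tweak) so the propagation is monotone and the loop terminates cleanly:

```python
# Label propagation for component detection: y_v <- max over {y_v} U {y_u : u in N(v)}.
def connected_component_labels(neighbors, source):
    y = {v: 0 for v in neighbors}
    y[source] = 1                                             # y_s^(0) = 1, all others 0
    while True:
        new_y = {v: max([y[v]] + [y[u] for u in neighbors[v]]) for v in neighbors}
        if new_y == y:                                        # steady state: y_v* = max_u y_u*
            return y                                          # y_v* = 1 iff v is in s's component
        y = new_y

# Example: connected_component_labels({0: [1], 1: [0, 2], 2: [1], 3: []}, source=0)
# returns {0: 1, 1: 1, 2: 1, 3: 0}.
```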

2) PageRank scores for node importance

Goal: estimate the importance of each node in a graph

How: update the score of each node iteratively:

$$r_{v}^{(t+1)} \leftarrow \frac{1-\lambda}{|\mathcal{V}|}+\frac{\lambda}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} r_{u}^{(t)}, \forall v \in \mathcal{V}$$

  • $r_v^{(t)}$: score of node $v$ at step $t$

  • Initialization: $r_v^{(0)} = 0, \forall v \in \mathcal{V}$

Steady state: $r_{v}^{*} = \frac{1-\lambda}{|\mathcal{V}|}+\frac{\lambda}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} r_{u}^{*}$
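A quick sketch of this iteration, written exactly as the update above (note it normalizes by $|\mathcal{N}(v)|$, following these notes); the adjacency-dict encoding, tolerance, and iteration cap are my own choices:

```python
# PageRank-style iteration: damped neighbor averaging plus a uniform teleport term.
def pagerank_scores(neighbors, lam=0.85, tol=1e-10, max_iter=1000):
    n = len(neighbors)
    r = {v: 0.0 for v in neighbors}                             # r_v^(0) = 0
    for _ in range(max_iter):
        new_r = {v: (1 - lam) / n
                    + lam / max(len(neighbors[v]), 1) * sum(r[u] for u in neighbors[v])
                 for v in neighbors}
        if max(abs(new_r[v] - r[v]) for v in neighbors) < tol:  # steady state (up to tol)
            return new_r
        r = new_r
    return r
```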

3) Mean field inference in graphical models

Goal: approximate the marginal distributions of the variables $x_v$ in a graphical model defined on $\mathcal{G}$

In the graphical model, the joint distribution factorizes as:

$$p\left(\left\{x_{v}\right\}_{v \in \mathcal{V}}\right) \propto \prod_{v \in \mathcal{V}} \phi\left(x_{v}\right) \prod_{(u, v) \in \mathcal{E}} \phi\left(x_{u}, x_{v}\right)$$

  • $p\left(\left\{x_{v}\right\}_{v \in \mathcal{V}}\right)$: the joint distribution over all variables

  • $\phi(x_v)$ and $\phi(x_u,x_v)$ are the node and edge potentials, respectively

How: the marginal approximations can be obtained in an iterative fashion by the following mean field update (a discrete-variable sketch follows below):

$$q^{(t+1)}\left(x_{v}\right) \leftarrow \phi\left(x_{v}\right) \prod_{u \in \mathcal{N}(v)} \exp \left(\int q^{(t)}\left(x_{u}\right) \log \phi\left(x_{u}, x_{v}\right) \mathrm{d} x_{u}\right)$$

  • $q^{(t+1)}\left(x_{v}\right)$: approximate marginal of variable $x_v$ at step $t+1$

Steady state: $q^{*}\left(x_{v}\right) = \phi\left(x_{v}\right) \prod_{u \in \mathcal{N}(v)} \exp \left(\int q^{*}\left(x_{u}\right) \log \phi\left(x_{u}, x_{v}\right) \mathrm{d} x_{u}\right)$
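For discrete variables the integral becomes a sum over the states of $x_u$, which gives a compact sketch; the array-based potentials (`node_pot`, `edge_pot`, assumed to be indexed for both orientations of each edge), the uniform initialization, and the fixed number of sweeps are my own assumptions:

```python
# Mean field update for a discrete MRF:
# q(x_v) ∝ phi(x_v) * prod_{u in N(v)} exp(E_{q(x_u)}[log phi(x_u, x_v)]).
import numpy as np

def mean_field(neighbors, node_pot, edge_pot, n_states, n_iters=50):
    # node_pot[v]: [n_states] array phi(x_v); edge_pot[(u, v)]: [n_states, n_states] phi(x_u, x_v).
    q = {v: np.full(n_states, 1.0 / n_states) for v in neighbors}
    for _ in range(n_iters):
        for v in neighbors:
            log_q = np.log(node_pot[v])
            for u in neighbors[v]:
                log_q += q[u] @ np.log(edge_pot[(u, v)])      # expectation over q(x_u), per value of x_v
            q[v] = np.exp(log_q - log_q.max())                # subtract max for numerical stability
            q[v] /= q[v].sum()                                # normalize the approximate marginal
    return q
```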

4) Compute long range graph convolution features (node classification)

Goal: extract long-range features from the graph and use them to capture the relation between the graph topology and external labels

How: one possible parametrization of the graph convolution features $h_v$, updated from zero initialization, is (see the sketch below):

$$h_{v}^{(t+1)} \leftarrow \sigma\left(W_{1} x_{v}+W_{2} \sum_{u \in \mathcal{N}(v)} h_{u}^{(t)}\right)$$

  • $h_{v}^{(t+1)}$: graph convolution features of node $v$ at step $t+1$

  • $\sigma$: a nonlinear element-wise operation

  • $W_1, W_2$: parameters of the operator

Steady state: $h_{v}^{*} = \sigma\left(W_{1} x_{v}+W_{2} \sum_{u \in \mathcal{N}(v)} h_{u}^{*}\right)$

After that: the label of each node is determined from the steady-state feature $h_v^{*}$ by a labeling function $f(h_v^{*})$.
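A numpy sketch of this fixed-point feature computation from zero initialization; $W_1, W_2$ are passed in as plain matrices, $\sigma$ is taken to be tanh, and the fixed number of sweeps is my own choice (in practice the iteration only converges when the map is a contraction, e.g. when $W_2$ is small enough):

```python
# Fixed-point graph convolution: h_v = sigma(W1 x_v + W2 * sum_{u in N(v)} h_u), from h = 0.
import numpy as np

def graph_conv_steady_state(x, neighbors, W1, W2, n_iters=50):
    # x: [N, d_in] node features; W1: [d_h, d_in]; W2: [d_h, d_h]; neighbors[v]: list of node ids.
    sigma = np.tanh
    h = np.zeros((x.shape[0], W1.shape[0]))                   # zero initialization
    for _ in range(n_iters):
        agg = np.stack([h[neighbors[v]].sum(axis=0) for v in range(len(neighbors))])
        h = sigma(x @ W1.T + agg @ W2.T)
    return h
```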

The Algorithm Learning Problem: framework of algorithm design

Assumption: we have collected the output of an iterative algorithm $\mathcal{T}$ over a single large graph.

Training dataset (input of the proposed algorithm) consists of:

  • input graph $\mathcal{G}=(\mathcal{V},\mathcal{E})$

  • output of the iterative algorithm for a subset of nodes $\mathcal{V}^{(y)} \subseteq \mathcal{V}$ (the labeled nodes) of the graph

$$\mathcal{D}=\left\{f_v^{*}:=f(h_v^{*}) \;\middle|\; h_v^{*}=\mathcal{T}\left[\{h_u^{*}\}_{u\in \mathcal{N}(v)}\right],\; v\in \mathcal{V}^{(y)}\right\} \qquad (3)$$

  • $h_v^{*}$: the quantity in the iterative algorithm satisfying the steady-state conditions

  • $f(\cdot)$: an additional labeling function that takes the steady-state quantity as input and produces the final label for each node

  • $f_v^{*}$: the ground-truth output for node $v$

Given the above $\mathcal{D}$,

Goal: learn a parameterized algorithm $\mathcal{A}_{\Theta}$ such that the output of $\mathcal{A}_{\Theta}$ mimics the output of the original algorithm $\mathcal{T}$.

Namely:

The output of $\mathcal{A}_{\Theta}$ is $\mathcal{A}_{\Theta}[\mathcal{G}]=\{\hat{f}_v\}_{v\in \mathcal{V}^{(y)}}$, which should be close to $f_v^{*}$ according to some loss function.

The algorithm learning problem for $\mathcal{A}_{\Theta}$ can be formulated as the following optimization problem:

$$\min_{\Theta} \sum_{v \in \mathcal{V}^{(y)}} \ell\left(f_{v}^{*}, \widehat{f}_{v}\right) \qquad (4)$$

$$\text{s.t. } \left\{\widehat{f}_{v}\right\}_{v \in \mathcal{V}^{(y)}}=\mathcal{A}_{\Theta}[\mathcal{G}] \qquad (5)$$

  • $\ell\left(f_{v}^{*}, \widehat{f}_{v}\right)$: loss function

Design goal: respect steady-state conditions and learn fast.

Core:

  • a steady-state operator $\mathcal{T}_{\Theta}$ between vector embedding representations of nodes

  • a link function $g$ mapping the embedding to the algorithm output

  • Note: $\mathcal{A}_{\Theta}$ thus consists of $\mathcal{T}_{\Theta}$ and $g$

$$\text{output}: \left\{\widehat{f}_{v}:=g(\widehat{h}_{v})\right\}_{v \in \mathcal{V}} \qquad (6)$$

$$\text{s.t. } \widehat{h}_{v}=\mathcal{T}_{\Theta}\left[\left\{\widehat{h}_{u}\right\}_{u \in \mathcal{N}(v)}\right] \qquad (7)$$

  • initialization: $\widehat{h}_v \leftarrow \text{constant}$ for all $v\in \mathcal{V}$

  • update using equation (7)

Operator $\mathcal{T}_{\Theta}$: a two-layer NN

  • General nonlinear function class

  • The operator enforces the steady-state condition of node embeddings based on 1-hop local neighborhood information.

  • Due to the variety of graph structures, this function should be able to handle a varying number of inputs (i.e., different numbers of neighbor nodes).

$$\widehat{h}_v = \mathcal{T}_{\Theta}\left[\left\{\widehat{h}_{u}\right\}_{u \in \mathcal{N}(v)}\right]=W_{1}\, \sigma\left(W_{2}\left[x_{v}, \sum_{u \in \mathcal{N}(v)}\left[\widehat{h}_{u}, x_{u}\right]\right]\right) \qquad (9)$$

  • $\sigma(\cdot)$: element-wise activation function (e.g., sigmoid, ReLU)

  • $W_1, W_2$: weight matrices of the NN; in Eq. (9), $W_2$ is the first (inner) layer and $W_1$ the second (outer) layer

  • $x_v$: the optional feature representation of node $v$

Link function $g$: a two-layer NN

  • General nonlinear function class

  • input: node embeddings

  • predict: the corresponding algorithm output (e.g., the label of a node)

$$g\left(\widehat{h}_{v}\right)=\sigma\left(V_{2}^{\top} \operatorname{ReLU}\left(V_{1}^{\top} \widehat{h}_{v}\right)\right) \qquad (10)$$

  • $\widehat{h}_{v}$: node embedding

  • $V_1, V_2$: parameters of $g$; $V_1$: first layer, $V_2$: second layer

  • $\sigma$: task-specific activation function

    • linear regression: identity function $\sigma(x)=x$

    • multi-class classification: $\sigma(\cdot)$ is the softmax (the output lies on the probability simplex)
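A minimal PyTorch sketch of the two networks in Eqs. (9) and (10). The module and dimension names, the choice of ReLU for the inner $\sigma$ of Eq. (9), and the per-node aggregation loop are my own assumptions, not the paper's reference implementation; the task-specific $\sigma$ of Eq. (10) is left to the loss (e.g., softmax inside cross-entropy):

```python
import torch
import torch.nn as nn

class SteadyStateOperator(nn.Module):
    """T_Theta of Eq. (9): h_v = W1 * sigma(W2 [x_v, sum_{u in N(v)} [h_u, x_u]])."""
    def __init__(self, d_x, d_h, d_hidden):
        super().__init__()
        self.W2 = nn.Linear(2 * d_x + d_h, d_hidden)   # inner layer, acts on [x_v, aggregate]
        self.W1 = nn.Linear(d_hidden, d_h)             # outer layer, back to embedding space

    def forward(self, x, h, nodes, neighbors):
        # x: [N, d_x] node features, h: [N, d_h] current embeddings,
        # nodes: ids of the nodes to update, neighbors[v]: LongTensor of v's neighbor ids.
        agg = torch.stack([torch.cat([h[neighbors[v]], x[neighbors[v]]], dim=1).sum(0)
                           for v in nodes])
        z = torch.cat([x[nodes], agg], dim=1)
        return self.W1(torch.relu(self.W2(z)))

class LinkFunction(nn.Module):
    """g of Eq. (10), up to the task-specific output activation."""
    def __init__(self, d_h, d_hidden, d_out):
        super().__init__()
        self.V1 = nn.Linear(d_h, d_hidden)
        self.V2 = nn.Linear(d_hidden, d_out)

    def forward(self, h):
        return self.V2(torch.relu(self.V1(h)))
```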

The overall optimization problem

$$\begin{array}{c} \min _{\left\{W_{i}, V_{i}\right\}_{i=1}^{2}} \mathcal{L}\left(\left\{W_{i}, V_{i}\right\}_{i=1}^{2}\right):=\frac{1}{\left|\mathcal{V}^{(y)}\right|} \sum_{v \in \mathcal{V}^{(y)}} \ell\left(f_{v}^{*}, g(\widehat{h}_{v})\right) \\ \text{s.t. } \widehat{h}_{v}=\mathcal{T}_{\Theta}\left[\left\{\widehat{h}_{u}\right\}_{u \in \mathcal{N}(v)}\right], \forall v \in \mathcal{V} \end{array} \qquad (11)$$

  • $W_1, W_2$: parameters of $\mathcal{T}_{\Theta}$

  • $V_1, V_2$: parameters of $g$

  • My understanding: this is a form of semi-supervised learning.

How to solve (11): Stochastic Steady-State Embedding (SSE)

An alternating algorithm that alternates between:

  • using the current model to compute the embeddings and make predictions

  • using the gradient of the loss with respect to $\{W_1, W_2, V_1, V_2\}$ to update these parameters

Intuition:

  • RL (policy iteration): improve the "policy" by updating the parameters of $\mathcal{T}_{\Theta}$ and $g$ so as to minimize the loss against $f^{*}$

    • steady-state embedding $\hat{h}_v$ of each node: the "value function"

    • embedding operator $\mathcal{T}_{\Theta}$ and classifier function $g$: the "policy"

  • analogy with k-means and EM (my own observation)

"Value" estimation: estimate steady-state hv^\hat{h_v}

limitation: it is prohibitive to solve the steady-state equation exactly in large-scale graph with millions of vertices since it requires visiting all the nodes in the graph.

Solution: stochastic fixed point iteration, the extra randomness on the constraints for sampling the constraints to tackle the groups of equations approximately.

In k-thk\text{-th} step, first sample a set of nodes V~={v1,v2,,vN}V\tilde{\mathcal{V}}=\{v_1,v_2,\dots,v_N\}\in \mathcal{V} from the entire node set rather of the labeled set. Update the new embedding by moving average:

h^vi(k)(1α)h^vi(k1)+αTΘ[{h^u(k1)}uN(vi)],viV~        (12)\widehat{h}_{v_{i}}^{(k)} \leftarrow(1-\alpha) \hat{h}_{v_{i}}^{(k-1)}+\alpha \mathcal{T}_{\Theta}\left[\{\widehat{h}_{u}^{(k-1)}\}_{u \in \mathcal{N}\left(v_{i}\right)}\right], \forall v_{i} \in \tilde{\mathcal{V}} \space\space\space\space\space\space\space\space (12)
  • α:0α1\alpha: 0 \leq \alpha\leq1
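A tiny numpy illustration of Eq. (12): only the sampled nodes are refreshed, and the refresh is damped by $\alpha$. Here `t_theta` is just a stand-in callable for the learned operator acting on the 1-hop embeddings, which is an assumption of this sketch:

```python
# Stochastic fixed-point step: refresh a random subset of nodes with a damped T_Theta update.
import numpy as np

def stochastic_fixed_point_step(h, neighbors, t_theta, alpha=0.5, n_sample=32, rng=None):
    # h: [N, d] array of embeddings; neighbors[v]: list of neighbor indices of node v.
    rng = rng or np.random.default_rng()
    sampled = rng.choice(len(h), size=min(n_sample, len(h)), replace=False)
    new_vals = {int(v): t_theta(h[neighbors[int(v)]]) for v in sampled}   # reads the old h only
    for v, h_new in new_vals.items():
        h[v] = (1 - alpha) * h[v] + alpha * h_new                         # Eq. (12)
    return h
```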

"Policy" improvement: update parameters of TΘ and g\mathcal{T}_{\Theta} \text{ and } g

At the k-thk\text{-th} step, once we have {h^v(k)}vV\{\widehat{h}_{v}^{(k)}\}_{v\in \mathcal{V}} satisfying the steady-state equation, we use vanilla stochastic gradient descent to update parameters {W1,W2,V1,V2}\{W_1, W_2, V_1, V_2\} :

LVi=E^[(fv,g(h^vk))g(h^vk)g(h^vk)Vi]LWi=E^[(fv,g(h^vk))h^vk.TΘWi]\begin{aligned} \frac{\partial \mathcal{L}}{\partial \color{red}V_{i}} &=\widehat{\mathbb{E}}\left[\frac{\partial \ell\left(f_{v}^{*}, g\left(\hat{h}_{v}^{k}\right)\right)}{\partial g\left(\hat{h}_{v}^{k}\right)} \frac{\partial g\left(\hat{h}_{v}^{k}\right)}{\partial \color{red}V_{i}}\right] \\ \frac{\partial \mathcal{L}}{\partial \color{red}W_{i}} &=\widehat{\mathbb{E}}\left[\frac{\partial \ell\left(f_{v}^{*}, g\left(\hat{h}_{v}^{k}\right)\right)}{\partial \widehat{h}_{v}^{k}} \frac{.\partial \mathcal{T}_{\Theta}}{\partial \color{red}W_{i}}\right] \end{aligned}
  • E^[]\widehat{\mathbb{E}}[\cdot] : the expectation is w.r.t. uniform distribution over labeled nodes V(y)\mathcal{V}^{(y)} .

  • nhn_h : the # of inner loops in "value" estimation

  • nfn_f : the # of inner loops in "policy" improvement
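A hedged sketch of how one outer round of the alternation could look, reusing the `SteadyStateOperator` / `LinkFunction` modules sketched earlier (any torch modules with those call signatures would do). The mini-batch sizes, the optimizer passed in as `opt`, and the cross-entropy loss (i.e., a node-classification task) are my own choices, not the paper's:

```python
# One SSE round: n_h damped "value" updates (Eq. 12), then n_f "policy" gradient steps
# on mini-batches of labeled nodes.
import random
import torch

def sse_round(x, h, neighbors, labels, labeled_idx, t_theta, g, opt,
              n_h=5, n_f=1, alpha=0.5, batch=64):
    loss_fn = torch.nn.CrossEntropyLoss()
    all_nodes = list(range(x.size(0)))
    for _ in range(n_h):                                  # "value" estimation
        vs = random.sample(all_nodes, min(batch, len(all_nodes)))
        with torch.no_grad():
            h[vs] = (1 - alpha) * h[vs] + alpha * t_theta(x, h, vs, neighbors)
    for _ in range(n_f):                                  # "policy" improvement
        vs = random.sample(labeled_idx, min(batch, len(labeled_idx)))
        h_batch = t_theta(x, h, vs, neighbors)            # re-apply T so grads reach W1, W2
        loss = loss_fn(g(h_batch), labels[vs])
        opt.zero_grad()
        loss.backward()
        opt.step()
    return h
```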

Complexity

  • Memory space

    • $O(|\mathcal{V}|)$: the dominating part is the persistent node embedding matrix $\{\widehat{h}_v\}_{v\in \mathcal{V}}$

    • by comparison, the $T$-hop GNN family needs $O(T|\mathcal{V}|)$, since the embeddings of all $T$ propagation steps are kept

  • Time: the computational cost in each iteration is just proportional to the number of edges in each mini-batch.

    • "policy" improvement: O(MEV)O(M\frac{|\mathcal{E}|}{|\mathcal{V}|})

    • "value" estimation: O(NEV)O(N\frac{|\mathcal{E}|}{|\mathcal{V}|})
