Motivation
Many graph analytics problems can be solved by iterative algorithms that operate over the graph structure, and the solutions of these algorithms are often characterized by a set of steady-state conditions.
PageRank: score of a node
Mean field inference: the posterior distribution of a variable
Instead of designing algorithms for each individual graph problem, we take a different perspective:
Can we design a learning framework for a diverse range of graph problems that learns the algorithm over large graphs and reaches the steady-state solutions efficiently and effectively?
How should we represent the meta-learner for such an algorithm, and how do we carry out the learning of these algorithms?
Iterative algorithm over graphs
For a graph $\mathcal{G}=(\mathcal{V}, \mathcal{E})$ with node set $\mathcal{V}$ and edge set $\mathcal{E}$, many iterative algorithms over graphs can be written as

$$h_{v}^{(t+1)} \leftarrow \mathcal{T}\left(\left\{h_{u}^{(t)}\right\}_{u \in \mathcal{N}(v)}\right), \forall t \geqslant 1, \quad \text{and} \quad h_{v}^{(0)} \leftarrow \text{constant}, \forall v \in \mathcal{V} \qquad (1)$$

which is run until the steady-state conditions are met:

$$h_{v}^{*}=\mathcal{T}\left(\left\{h_{u}^{*}\right\}_{u \in \mathcal{N}(v)}\right), \forall v \in \mathcal{V} \qquad (2)$$

Notation:
$v, u \in \mathcal{V}$: nodes
$\mathcal{N}(v)$: the set of neighbor nodes of $v$
$\mathcal{T}(\cdot)$: some operator
$h_v$: intermediate representation of node $v$
$h_v^{(t)}$: intermediate representation of node $v$ at step $t$
$h_v^{*}$: final (converged/steady) intermediate representation of node $v$
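To make this template concrete, here is a minimal Python sketch (my own illustration, not from the paper) that runs update (1) with a user-supplied operator until condition (2) is met up to a tolerance. The adjacency-dict graph representation and the `operator` signature are assumptions chosen for readability.

```python
def fixed_point_iteration(adj, operator, init=0.0, max_iters=100, tol=1e-6):
    """Run h_v^(t+1) <- T({h_u^(t)}_{u in N(v)}) from a constant init until convergence.

    adj: dict mapping each node to a list of its neighbors (assumed representation).
    operator: function (node, list_of_neighbor_values) -> new value, playing the role of T.
    """
    h = {v: init for v in adj}                                         # h_v^(0) <- constant
    for _ in range(max_iters):
        new_h = {v: operator(v, [h[u] for u in adj[v]]) for v in adj}  # update (1)
        if max(abs(new_h[v] - h[v]) for v in adj) < tol:               # condition (2) approximately met
            return new_h
        h = new_h
    return h
```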
Specifically:
1) Graph component detection problem
Goal: find all nodes within the same connected component as a source node $s \in \mathcal{V}$.
How: iteratively propagate the label at node $s$ to the other nodes:

$$y_{v}^{(t+1)}=\max _{u \in \mathcal{N}(v)} y_{u}^{(t)}, \qquad y_{s}^{(0)}=1,\; y_{v}^{(0)}=0, \forall v \in \mathcal{V}$$

$y_s$: label of node $s$
$y_s^{(t)}$: label of node $s$ at step $t$
At the initial step $t=0$, the label of node $s$ is set to 1 (infected), and to 0 for all other nodes.
Steady state: the nodes in the same connected component as $s$ are infected (labelled 1):

$$y_v^{*}= \max_{u\in \mathcal{N}(v)} y_u^{*}$$
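A small sketch of this label propagation (same hypothetical adjacency-dict representation as above; the node's own label is included in the max, which keeps the propagation monotone without changing the steady state):

```python
def connected_component_labels(adj, s):
    """Infect all nodes reachable from source s: y_s = 1, then y_v <- max over N(v)."""
    y = {v: 0.0 for v in adj}
    y[s] = 1.0
    for _ in range(len(adj)):                        # at most |V| sweeps are needed
        y_new = {v: max([y[v]] + [y[u] for u in adj[v]]) for v in adj}
        if y_new == y:                               # steady state reached
            return y_new
        y = y_new
    return y

# Two components {0, 1, 2} and {3, 4}; only the first gets infected from s = 0.
adj = {0: [1], 1: [0, 2], 2: [1], 3: [4], 4: [3]}
print(connected_component_labels(adj, s=0))          # {0: 1.0, 1: 1.0, 2: 1.0, 3: 0.0, 4: 0.0}
```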
2) PageRank scores for node importance
Goal: estimate the importance of each node in a graph
How : update score of each node iteratively
$$r_{v}^{(t+1)} \leftarrow \frac{(1-\lambda)}{|\mathcal{V}|}+\frac{\lambda}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} r_{u}^{(t)}, \quad \forall v \in \mathcal{V}$$

$r_v^{(t)}$: score of node $v$ at step $t$
Initialization: $r_v^{(0)} = 0, \forall v \in \mathcal{V}$
Steady state: $r_{v}^{*} = \frac{(1-\lambda)}{|\mathcal{V}|}+\frac{\lambda}{|\mathcal{N}(v)|} \sum_{u \in \mathcal{N}(v)} r_{u}^{*}$
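A direct implementation of the update exactly as written above (note it normalizes by $|\mathcal{N}(v)|$ of the node being updated, and assumes no isolated nodes):

```python
def pagerank(adj, lam=0.85, max_iters=100, tol=1e-10):
    """Iterate r_v <- (1 - lam)/|V| + lam/|N(v)| * sum_{u in N(v)} r_u from r_v^(0) = 0."""
    n = len(adj)
    r = {v: 0.0 for v in adj}
    for _ in range(max_iters):
        r_new = {v: (1 - lam) / n + lam / len(adj[v]) * sum(r[u] for u in adj[v])
                 for v in adj}
        if max(abs(r_new[v] - r[v]) for v in adj) < tol:   # steady state reached
            return r_new
        r = r_new
    return r
```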
3) Mean field inference in graphical model
Goal: approximate the marginal distributions of the variables $x_v$ in a graphical model defined on $\mathcal{G}$.
In a graphical model, the joint distribution factorizes as:

$$p\left(\left\{x_{v}\right\}_{v \in \mathcal{V}}\right) \propto \prod_{v \in \mathcal{V}} \phi\left(x_{v}\right) \prod_{(u, v) \in \mathcal{E}} \phi\left(x_{u}, x_{v}\right)$$

$p\left(\left\{x_{v}\right\}_{v \in \mathcal{V}}\right)$: the true joint distribution (up to normalization)
$\phi(x_v)$ and $\phi(x_u,x_v)$: node and edge potentials, respectively
How: the marginal approximation can be obtained in an iterative fashion by the following mean field update:

$$q^{(t+1)}\left(x_{v}\right) \leftarrow \phi\left(x_{v}\right) \prod_{u \in \mathcal{N}(v)} \exp \left(\int_{x_u} q^{(t)}\left(x_{u}\right) \log \phi\left(x_{u}, x_{v}\right) \mathrm{d} x_u\right)$$

$q^{(t+1)}\left(x_{v}\right)$: marginal approximation of $x_v$ at step $t+1$

Steady state:

$$q^{*}\left(x_{v}\right) = \phi\left(x_{v}\right) \prod_{u \in \mathcal{N}(v)} \exp \left(\int_{x_u} q^{*}\left(x_{u}\right) \log \phi\left(x_{u}, x_{v}\right) \mathrm{d} x_u\right)$$
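For discrete variables with $K$ states the integral becomes a sum, and the update can be sketched as below (assuming, for brevity, a single edge potential shared by all edges; `adj`, `node_pot`, and `edge_pot` are hypothetical inputs, and node ids are assumed to be 0..n-1):

```python
import numpy as np

def mean_field(adj, node_pot, edge_pot, n_iters=50):
    """Mean field updates for discrete x_v with K states.

    adj: dict node -> list of neighbors.
    node_pot: (n, K) array of phi(x_v); edge_pot: (K, K) array of phi(x_u, x_v),
    assumed shared by all edges for simplicity.
    Returns q: (n, K) array of approximate marginals.
    """
    n, K = node_pot.shape
    q = np.full((n, K), 1.0 / K)                    # uniform initialization
    log_edge = np.log(edge_pot)
    for _ in range(n_iters):
        for v in range(n):
            # log q(x_v) = log phi(x_v) + sum_{u in N(v)} E_{q(x_u)}[log phi(x_u, x_v)]
            log_q = np.log(node_pot[v]) + sum(q[u] @ log_edge for u in adj[v])
            log_q -= log_q.max()                    # for numerical stability
            q[v] = np.exp(log_q) / np.exp(log_q).sum()
    return q
```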
4) Compute long range graph convolution features (node classification)
Goal: extract long-range features from the graph and use those features to capture the relation between the graph topology and external labels.
How: one possible parametrization of the graph convolution features $h_v$, updated from a zero initialization, is:
$$h_{v}^{(t+1)} \leftarrow \sigma\left(W_{1} x_{v}+W_{2} \sum_{u \in \mathcal{N}(v)} h_{u}^{(t)}\right)$$

$h_{v}^{(t+1)}$: graph convolution features for node $v$ at step $t+1$
$\sigma$: a nonlinear element-wise operation
$W_1, W_2$: parameters of the operator

Steady state: $h_{v}^{*} = \sigma\left(W_{1} x_{v}+W_{2} \sum_{u \in \mathcal{N}(v)} h_{u}^{*}\right)$
After that, the label for each node is determined from the steady-state feature $h_v^{*}$ by a labeling function $f(h_v^{*})$.
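A sketch of iterating this parametrization to an approximate steady state (assuming the weights are given and small enough in norm for the iteration to settle; in the actual task they would be learned through the labeling function $f$; node ids are assumed to be 0..n-1):

```python
import numpy as np

def graph_conv_features(adj, X, W1, W2, n_iters=30):
    """Iterate h_v <- ReLU(W1 x_v + W2 * sum_{u in N(v)} h_u) from a zero initialization."""
    n, d = X.shape[0], W1.shape[0]
    H = np.zeros((n, d))
    for _ in range(n_iters):
        H_new = np.zeros_like(H)
        for v in range(n):
            agg = H[adj[v]].sum(axis=0) if adj[v] else np.zeros(d)
            H_new[v] = np.maximum(0.0, W1 @ X[v] + W2 @ agg)   # sigma = ReLU
        H = H_new
    return H
```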
The Algorithm Learning Problem: framework of algorithm design
Assumption: we have collected the output of an iterative algorithm $\mathcal{T}$ over a single large graph.
Training dataset (input of the proposed algorithm) consists of:
the input graph $\mathcal{G}=(\mathcal{V},\mathcal{E})$
the output of the iterative algorithm for a subset of nodes $\mathcal{V}^{(y)} \subseteq \mathcal{V}$ (the labeled nodes):

$$\mathcal{D}=\left\{f_v^{*}:=f(h_v^{*}) \;\middle|\; h_v^{*}=\mathcal{T}\left[\{h_u^{*}\}_{u\in \mathcal{N}(v)}\right],\; v\in \mathcal{V}^{(y)}\right\} \qquad (3)$$
$h_v^{*}$: the quantity in the iterative algorithm satisfying the steady-state conditions
$f(\cdot)$: an additional labeling function that takes the steady-state quantity as input and produces the final label for each node
$f_v^{*}$: the ground truth for node $v$
Given the above $\mathcal{D}$:
Goal: learn a parameterized algorithm $\mathcal{A}_{\Theta}$ such that the output of $\mathcal{A}_{\Theta}$ mimics the output of the original algorithm $\mathcal{T}$.
Namely:
The output of $\mathcal{A}_{\Theta}$ is $\mathcal{A}_{\Theta}[\mathcal{G}]=\{\widehat{f}_v\}_{v\in \mathcal{V}^{(y)}}$, which should be close to $f_v^{*}$ according to some loss function.
The algorithm learning problem for $\mathcal{A}_{\Theta}$ can be formulated as the following optimization problem:
$$\min_{\Theta} \sum_{v \in \mathcal{V}^{(y)}} \ell\left(f_{v}^{*}, \widehat{f}_{v}\right) \qquad (4)$$
$$\text{s.t. } \left\{\widehat{f}_{v}\right\}_{v \in \mathcal{V}^{(y)}}=\mathcal{A}_{\Theta}[\mathcal{G}] \qquad (5)$$

$\ell\left(f_{v}^{*}, \widehat{f}_{v}\right)$: the loss function
Design goal: respect the steady-state conditions and learn fast.
Core components:
a steady-state operator $\mathcal{T}_{\Theta}$ between vector embedding representations of nodes
a link function $g$ mapping the embedding to the algorithm output
Note: $\mathcal{A}_{\Theta}$ consists of $\mathcal{T}_{\Theta}$ and $g$.
Steady-state operator and link function

$$\text{output}: \left\{\widehat{f}_{v}:=g(\widehat{h}_{v})\right\}_{v \in \mathcal{V}} \qquad (6)$$
$$\text{s.t. } \widehat{h}_{v}=\mathcal{T}_{\Theta}\left[\left\{\widehat{h}_{u}\right\}_{u \in \mathcal{N}(v)}\right] \qquad (7)$$

Initialization: $\widehat{h}_v \leftarrow \text{constant}$ for all $v\in \mathcal{V}$, then update using equation (7).
Operator $\mathcal{T}_{\Theta}$: a two-layer NN
A general nonlinear function class.
The operator enforces the steady-state condition of node embeddings based on 1-hop local neighborhood information.
Due to the variety of graph structures, this function should be able to handle different numbers of inputs (i.e., different numbers of neighbor nodes).
$$\widehat{h}_v = \mathcal{T}_{\Theta}\left[\left\{\widehat{h}_{u}\right\}_{u \in \mathcal{N}(v)}\right]=W_{1} \sigma\left(W_{2}\left[x_{v}, \sum_{u \in \mathcal{N}(v)}\left[\widehat{h}_{u}, x_{u}\right]\right]\right) \qquad (9)$$

$\sigma(\cdot)$: element-wise activation function (e.g., sigmoid, ReLU)
$W_1, W_2$: weight matrices of the NN ($W_1$: first layer, $W_2$: second layer)
$x_v$: the optional feature representation of node $v$
Link function (prediction function) $g$: a two-layer NN
A general nonlinear function class.
It predicts the corresponding algorithm output (e.g., the label of a node).
$$g\left(\widehat{h}_{v}\right)=\sigma\left(V_{2}^{\top} \operatorname{ReLU}\left(V_{1}^{\top} \widehat{h}_{v}\right)\right) \qquad (10)$$

$\widehat{h}_{v}$: node embedding
$V_1, V_2$: parameters of $g$ ($V_1$: first layer, $V_2$: second layer)
$\sigma$: task-specific activation function
linear regression: the identity function $\sigma(x)=x$
multi-class classification: $\sigma(\cdot)$ is the softmax (outputs lie on the probability simplex)
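A hedged PyTorch sketch of how $\mathcal{T}_{\Theta}$ (equation 9) and $g$ (equation 10) could be written; the module names, dimensions, and the zero-padding of neighbor lists are my own choices rather than the paper's reference implementation, and the task-specific $\sigma$ of equation (10) is assumed to be folded into the loss (e.g., softmax inside cross-entropy):

```python
import torch
import torch.nn as nn

class SteadyStateOperator(nn.Module):
    """T_Theta of equation (9): h_v = W1 * sigma(W2 [x_v, sum_{u in N(v)} [h_u, x_u]])."""
    def __init__(self, x_dim, h_dim, hidden_dim):
        super().__init__()
        self.W2 = nn.Linear(x_dim + h_dim + x_dim, hidden_dim, bias=False)  # inner layer
        self.W1 = nn.Linear(hidden_dim, h_dim, bias=False)                  # outer layer

    def forward(self, x_v, h_neighbors, x_neighbors):
        # x_v: (B, x_dim); h_neighbors: (B, max_deg, h_dim); x_neighbors: (B, max_deg, x_dim)
        # neighbor lists are assumed zero-padded to max_deg, so the sum ignores the padding
        agg = torch.cat([h_neighbors, x_neighbors], dim=-1).sum(dim=1)
        return self.W1(torch.relu(self.W2(torch.cat([x_v, agg], dim=-1))))

class LinkFunction(nn.Module):
    """g of equation (10): V2^T ReLU(V1^T h_v); the task-specific sigma is applied by the loss."""
    def __init__(self, h_dim, hidden_dim, out_dim):
        super().__init__()
        self.V1 = nn.Linear(h_dim, hidden_dim, bias=False)
        self.V2 = nn.Linear(hidden_dim, out_dim, bias=False)

    def forward(self, h_v):
        return self.V2(torch.relu(self.V1(h_v)))
```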
The overall optimization problem
$$\min _{\left\{W_{i}, V_{i}\right\}_{i=1}^{2}} \mathcal{L}\left(\left\{W_{i}, V_{i}\right\}_{i=1}^{2}\right):=\frac{1}{\left|\mathcal{V}^{(y)}\right|} \sum_{v \in \mathcal{V}^{(y)}} \ell\left(f_{v}^{*}, g(\widehat{h}_{v})\right) \quad \text{s.t. } \widehat{h}_{v}=\mathcal{T}_{\Theta}\left[\left\{\widehat{h}_{u}\right\}_{u \in \mathcal{N}(v)}\right], \forall v \in \mathcal{V} \qquad (11)$$

$W_1, W_2$: parameters of $\mathcal{T}_{\Theta}$
$V_1, V_2$: parameters of $g$
My understanding: this is essentially a semi-supervised learning setup.
(See also: https://www.wikiwand.com/en/Fixed-point_iteration)

How to solve (11): Stochastic Steady-state Embedding (SSE)
An alternating algorithm: alternate between:
using the most current model to find the embeddings and make predictions
using the gradient of the loss with respect to $\{W_1, W_2, V_1, V_2\}$ to update these parameters
Intuition:
RL (policy iteration): improve the policy that minimizes the cost defined by $f^{*}$ by updating the parameters of $\mathcal{T}_{\Theta}$ and $g$
steady-state embedding $\widehat{h}_v$ for each node: the "value function"
embedding operator $\mathcal{T}_{\Theta}$ and classifier function $g$: the "policy"
"Value" estimation: estimate steady-state
h v ^ \hat{h_v} h v ā ^ ā limitation : it is prohibitive to solve the steady-state equation exactly in large-scale graph with millions of vertices since it requires visiting all the nodes in the graph .
Solution: stochastic fixed-point iteration, which injects randomness by sampling the constraints at each step and thus tackles the group of equations approximately.
In the $k\text{-th}$ step, first sample a set of nodes $\tilde{\mathcal{V}}=\{v_1,v_2,\dots,v_N\}\subseteq \mathcal{V}$ from the entire node set rather than just the labeled set, then update the embeddings by a moving average:

$$\widehat{h}_{v_{i}}^{(k)} \leftarrow(1-\alpha)\, \widehat{h}_{v_{i}}^{(k-1)}+\alpha\, \mathcal{T}_{\Theta}\left[\{\widehat{h}_{u}^{(k-1)}\}_{u \in \mathcal{N}\left(v_{i}\right)}\right], \quad \forall v_{i} \in \tilde{\mathcal{V}} \qquad (12)$$

$\alpha$: step size, $0 \leq \alpha \leq 1$
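One inner step of this "value" estimation might look like the following plain-Python sketch (same adjacency-dict assumption as earlier; `T_theta` stands for the current embedding operator and is a placeholder, not a real API):

```python
import random

def value_estimation_step(h, adj, T_theta, batch_size, alpha=0.5):
    """One stochastic fixed-point step (equation 12) over a sampled node set."""
    sampled = random.sample(list(adj), batch_size)   # V~ is drawn from all nodes, not just labeled ones
    # compute every update from the old embeddings h^(k-1), then write them back
    updates = {v: (1 - alpha) * h[v] + alpha * T_theta(v, [h[u] for u in adj[v]])
               for v in sampled}
    h.update(updates)
    return h
```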
"Policy" improvement: update parameters of
T Ī Ā andĀ g \mathcal{T}_{\Theta} \text{ and } g T Ī ā Ā andĀ g At the k -th k\text{-th} k -th step, once we have { h ^ v ( k ) } v ā V \{\widehat{h}_{v}^{(k)}\}_{v\in \mathcal{V}} { h v ( k ) ā } v ā V ā satisfying the steady-state equation, we use vanilla stochastic gradient descent to update parameters { W 1 , W 2 , V 1 , V 2 } \{W_1, W_2, V_1, V_2\} { W 1 ā , W 2 ā , V 1 ā , V 2 ā } :
ā L ā V i = E ^ [ ā ā ( f v ā , g ( h ^ v k ) ) ā g ( h ^ v k ) ā g ( h ^ v k ) ā V i ] ā L ā W i = E ^ [ ā ā ( f v ā , g ( h ^ v k ) ) ā h ^ v k . ā T Ī ā W i ] \begin{aligned}
\frac{\partial \mathcal{L}}{\partial \color{red}V_{i}} &=\widehat{\mathbb{E}}\left[\frac{\partial \ell\left(f_{v}^{*}, g\left(\hat{h}_{v}^{k}\right)\right)}{\partial g\left(\hat{h}_{v}^{k}\right)} \frac{\partial g\left(\hat{h}_{v}^{k}\right)}{\partial \color{red}V_{i}}\right] \\
\frac{\partial \mathcal{L}}{\partial \color{red}W_{i}} &=\widehat{\mathbb{E}}\left[\frac{\partial \ell\left(f_{v}^{*}, g\left(\hat{h}_{v}^{k}\right)\right)}{\partial \widehat{h}_{v}^{k}} \frac{.\partial \mathcal{T}_{\Theta}}{\partial \color{red}W_{i}}\right]
\end{aligned} ā V i ā ā L ā ā W i ā ā L ā ā = E ā ā g ( h ^ v k ā ) ā ā ( f v ā ā , g ( h ^ v k ā ) ) ā ā V i ā ā g ( h ^ v k ā ) ā ā = E ā ā h v k ā ā ā ( f v ā ā , g ( h ^ v k ā ) ) ā ā W i ā . ā T Ī ā ā ā ā E ^ [ ā
] \widehat{\mathbb{E}}[\cdot] E [ ā
] : the expectation is w.r.t. uniform distribution over labeled nodes V ( y ) \mathcal{V}^{(y)} V ( y ) .
$n_h$: the number of inner loops in "value" estimation
$n_f$: the number of inner loops in "policy" improvement
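Putting the two phases together, a skeleton of the SSE training loop could look like the sketch below (PyTorch; the compact `nn.Sequential` stand-ins for $\mathcal{T}_{\Theta}$ and $g$, the cross-entropy loss, and all hyperparameters are simplifying assumptions on my part rather than the paper's exact recipe):

```python
import random
import torch
import torch.nn as nn

def sse_train(adj, X, labels, labeled_nodes, h_dim=16, hidden=32, n_classes=2,
              n_outer=200, n_h=5, n_f=1, batch_size=32, alpha=0.5, lr=1e-3):
    """Skeleton of the SSE alternating algorithm (illustrative sketch only).

    adj: dict node -> list of neighbor ids (nodes are assumed to be 0..n-1);
    X: (n, x_dim) float tensor of node features;
    labels: dict node -> class index; labeled_nodes: list of nodes in V^(y).
    """
    n, x_dim = X.shape
    # Compact stand-ins for T_Theta (equation 9) and g (equation 10); sigma of g is folded into the loss.
    t_theta = nn.Sequential(nn.Linear(x_dim + h_dim + x_dim, hidden), nn.ReLU(),
                            nn.Linear(hidden, h_dim))
    g = nn.Sequential(nn.Linear(h_dim, hidden), nn.ReLU(), nn.Linear(hidden, n_classes))
    opt = torch.optim.Adam(list(t_theta.parameters()) + list(g.parameters()), lr=lr)
    H = torch.zeros(n, h_dim)                     # persistent embedding table, O(|V|) memory

    def apply_operator(nodes):
        """Apply T_Theta to each node in `nodes`, reading neighbor embeddings from H."""
        rows = []
        for v in nodes:
            agg_h = H[adj[v]].sum(0) if adj[v] else torch.zeros(h_dim)
            agg_x = X[adj[v]].sum(0) if adj[v] else torch.zeros(x_dim)
            rows.append(torch.cat([X[v], agg_h, agg_x]))
        return t_theta(torch.stack(rows))

    for _ in range(n_outer):
        # "Value" estimation: n_h stochastic fixed-point steps (equation 12) over sampled nodes.
        for _ in range(n_h):
            batch = random.sample(list(adj), min(batch_size, n))
            with torch.no_grad():
                H[batch] = (1 - alpha) * H[batch] + alpha * apply_operator(batch)
        # "Policy" improvement: n_f SGD steps on the labeled nodes only.
        for _ in range(n_f):
            batch = random.sample(labeled_nodes, min(batch_size, len(labeled_nodes)))
            h_batch = apply_operator(batch)       # gradients flow through one application of T_Theta
            y = torch.tensor([labels[v] for v in batch])
            loss = nn.functional.cross_entropy(g(h_batch), y)
            opt.zero_grad()
            loss.backward()
            opt.step()
    return t_theta, g, H
```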
Policy Iteration for comparison
Complexity
Memory space
$O(|\mathcal{V}|)$: the dominating part is the persistent node embedding matrix $\{\widehat{h}_v\}_{v\in \mathcal{V}}$
$O(T|\mathcal{V}|)$: for the $T\text{-hop}$ GNN family (for comparison)
Time: the computational cost in each iteration is just proportional to the number of edges in each mini-batch.
"policy" improvement: O ( M ⣠E ⣠⣠V ⣠) O(M\frac{|\mathcal{E}|}{|\mathcal{V}|}) O ( M ⣠V ⣠⣠E ⣠ā )
"value" estimation: O ( N ⣠E ⣠⣠V ⣠) O(N\frac{|\mathcal{E}|}{|\mathcal{V}|}) O ( N ⣠V ⣠⣠E ⣠ā )