tf_gnns: A Hackable graph_nets library
largely inspired by the graph_nets project.
First install the library on the colab notebook and import it
# Install the tf_gnns library with pip:
!pip install tf_gnns==0.1.7
import tf_gnns
GraphNet basics
The data structure over which GraphNets (GNs) operate is an attributed multi-graph. Some examples of how to define graphs are given below.
As detailed later in the notebook, the GraphTuple or a dictionary that corresponds to a GraphTuple are the recommended data structures for achieving good performance.
Creating a single Graph
To specify a Graph you need to specify a set of nodes and edges. (Note that this data structure is not very efficient. It is, however, more intuitive to work with.)
# Creation of edges and nodes:
import numpy as np
from tf_gnns import Graph, Edge, Node
nw_state_size = 10 # size of the node and edge attribute vectors (they can differ if necessary)
# Defining the connectivity of a graph:
adj_A = [(2,4),(3,4),(2,2) , (2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(8,7)]
nodes_A = [Node(np.random.randn(1,nw_state_size)) for n in range(10)]
edges_A = [Edge(np.random.randn(1,nw_state_size), node_from= nodes_A[e_ij[0]], node_to= nodes_A[e_ij[1]]) for e_ij in adj_A]
graph_A = Graph(nodes=nodes_A, edges= edges_A)
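To make the term "attributed multi-graph" concrete, here is a minimal plain-Python sketch that does not use tf_gnns. The names node_attrs and edge_list and the attribute size are illustrative assumptions, not library API. It shows that the same ordered node pair may carry more than one edge, each with its own attribute vector, and that self-loops are allowed.

```python
import random
from collections import Counter

random.seed(0)
state_size = 3  # hypothetical attribute size, chosen only for illustration

# Connectivity as (sender, receiver) pairs, same as adj_A in the cell above.
adj = [(2, 4), (3, 4), (2, 2), (2, 5), (5, 1), (1, 2),
       (2, 3), (3, 4), (4, 5), (6, 5), (7, 6), (8, 7)]

# One attribute vector per node and one per edge.
node_attrs = [[random.gauss(0, 1) for _ in range(state_size)] for _ in range(10)]
edge_list = [{"sender": s, "receiver": r,
              "attr": [random.gauss(0, 1) for _ in range(state_size)]}
             for s, r in adj]

# Multi-graph: an ordered node pair may appear more than once.
pair_counts = Counter((e["sender"], e["receiver"]) for e in edge_list)
parallel_pairs = [pair for pair, c in pair_counts.items() if c > 1]
self_loops = [e for e in edge_list if e["sender"] == e["receiver"]]

print(parallel_pairs)   # node pairs connected by more than one edge
print(len(self_loops))  # number of self-loops
```

Here the pair (3, 4) is connected twice and node 2 has a self-loop, which is why a plain adjacency matrix would not be enough to hold the edge attributes.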
For expressiveness, some operators are overloaded. When unsure what an operation does, inspect the docstrings and the source code directly.
help(graph_A.__add__)
Help on method __add__ in module tf_gnns.datastructures:
__add__(graph) method of tf_gnns.datastructures.Graph instance
This should only work with graphs that have compatible node and edge features
Assumed also that the two graphs have the same connectivity (otherwise this will fail ugly)
Simple graphs are not convenient for batched computation and in general should be avoided. Sometimes it may be “safer” to use graphs as a data structure when experimenting with new algorithms.
For most use cases a GraphTuple should be used. GraphTuples are the only data structures supported by the graph_nets library.
Creating a GraphTuple
The GraphTuple data structure
The GraphTuple is a set of graphs packed into an object that allows for easier parallelization of the GraphNet computation block. It is the default (and only) possible data-structure in DeepMind’s GraphNets library.
You can create a GraphTuple either directly from tensors or numpy arrays, or from a list of Graph objects (as above). Utilities are provided to perform such manipulations. Note that in general it’s better to avoid using Graphs altogether, but they are useful for testing.
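The packing idea can be sketched without the library. The following is a plain-Python illustration under assumed field names ("senders", "receivers" and the dict layout are hypothetical; only n_nodes and n_edges appear as GraphTuple attributes later in this notebook). Node and edge rows of all graphs are concatenated, and the local node indices of each graph are shifted by the number of nodes packed before it.

```python
def pack_graphs(graphs):
    """Concatenate several small graphs into one flat batch.

    Each graph is a dict with 'nodes' and 'edges' (lists of attribute
    vectors) and 'senders' / 'receivers' (node indices local to that graph).
    """
    batch = {"nodes": [], "edges": [], "senders": [], "receivers": [],
             "n_nodes": [], "n_edges": []}
    offset = 0
    for g in graphs:
        batch["nodes"].extend(g["nodes"])
        batch["edges"].extend(g["edges"])
        # Shift local node indices so they point into the packed node list.
        batch["senders"].extend(s + offset for s in g["senders"])
        batch["receivers"].extend(r + offset for r in g["receivers"])
        batch["n_nodes"].append(len(g["nodes"]))
        batch["n_edges"].append(len(g["edges"]))
        offset += len(g["nodes"])
    return batch

graph_a = {"nodes": [[0.0, 0.0]] * 3, "edges": [[1.0, 1.0]] * 2,
           "senders": [0, 1], "receivers": [1, 2]}
graph_b = {"nodes": [[0.0, 0.0]] * 2, "edges": [[1.0, 1.0]],
           "senders": [1], "receivers": [0]}
batch = pack_graphs([graph_a, graph_b])
print(batch["n_nodes"], batch["n_edges"])    # per-graph sizes
print(batch["senders"], batch["receivers"])  # indices into the packed node list
```

Because every graph lives in the same flat arrays, a single call to the node or edge function can process all graphs at once.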
Note on performance
For good performance it is important to wrap the training iteration code (dataset sampling, forward pass, gradient computation) in a tf.function-decorated function. This places some limitations on the data structures allowed as arguments of the wrapped function: only tf.Tensors, built-in Python datatypes, or collections of tf.Tensors are allowed as tf.function arguments (GraphTuples are not). A work-around for using tf_gnns with tf.function is to cast the GraphTuple to a dictionary.
Creation of a GraphTuple
Here is how to create a GraphTuple from a list of Graph objects:
from tf_gnns import make_graph_tuple_from_graph_list
# Some graphs to compute with:
adj_A = [(2,4),(3,4),(2,2) , (2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(4,2)]
adj_B = [(3,4),(2,2) , (2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(8,9)]
nw_state_size = 2
nodes_A = [Node(np.random.randn(1,nw_state_size)) for n in range(15)]
edges_A = [Edge(np.random.randn(1,nw_state_size), node_from= nodes_A[e_ij[0]], node_to= nodes_A[e_ij[1]]) for e_ij in adj_A]
graph_A = Graph(nodes=nodes_A, edges= edges_A)
nodes_B = [Node(np.random.randn(1,nw_state_size)) for n in range(10)]
edges_B = [Edge(np.random.randn(1,nw_state_size), node_from=nodes_B[e_ij[0]], node_to = nodes_B[e_ij[1]]) for e_ij in adj_B]
graph_B = Graph(nodes=nodes_B, edges= edges_B)
gt = make_graph_tuple_from_graph_list([graph_A, graph_B])
print("number of nodes of graphtuple:")
print(gt.nodes.shape[0])
print('GraphTuple tensor shapes:')
print(gt.nodes.shape, gt.edges.shape)
print("Shapes of a single node from graph_A: [%i,%i]"%(graph_A.nodes[0].shape))
print("Number of nodes of graphs: A: %i, B: %i"%(len(graph_A.nodes), len(graph_B.nodes)))
print("")
print("size of node attributes: [%i]"%( graph_A.nodes[0].node_attr_tensor.shape[-1]))
print("size of edge attributes: [%i]"%( graph_A.edges[0].edge_tensor.shape[-1]))
number of nodes of graphtuple:
25
GraphTuple tensor shapes:
(25, 2) (23, 2)
Shapes of a single node from graph_A: [1,2]
Number of nodes of graphs: A: 15, B: 10
size of node attributes: [2]
size of edge attributes: [2]
Creating a Custom GraphNet
In order to create a GraphNet (without global variables) one needs to define the following:
a node function
an edge function
an edge aggregation function (except if a GraphIndependent network is implemented)
In addition, the input sizes of the node and edge functions must be consistent with the input graph, and the outputs of the edge aggregation function (if it exists) must be consistent with the expected inputs of the node function (if the network is not graph independent). Moreover, each of these functions may take different inputs related to the input graph.
For instance:
Edge functions may have as inputs:
the edge state,
the sender node state,
the receiver node state.
Node functions may have as inputs:
the node state,
an aggregated message incoming from the edges that point to that node.
These cases are identified internally by the names of the inputs of the provided functions. See the examples to understand how to implement your own functions in case you want something special to happen in them. The tf_gnns library is built with tf.keras in mind: it relies on keras constructs, such as the naming of input and output variables and the keras functional API (tf.keras.Model), to facilitate building GraphNets.
In contrast, DeepMind’s GraphNet library uses Sonnet for the same purpose, and puts its own constraints on how computation and model construction happen. For some use cases (such as implementing bootstrapping or potential extensions to sparse graphs) this was found not to be very convenient, which is why this library was built.
One can think of tf_gnns as a graph_nets library without the tf1.x and Sonnet “baggage”.
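The edge-then-node update described earlier in this section can be sketched in plain Python. This is only an illustration of the data flow, not the tf_gnns implementation: the toy edge_fn and node_fn below simply add their inputs, and the mean aggregator is one possible choice. The argument names mirror the input names that appear in the printed model summaries (edge_state, sender_node_state, receiver_node_state, node_state, edge_state_agg).

```python
def edge_fn(edge_state, sender_node_state, receiver_node_state):
    # Toy edge update: elementwise sum of the three inputs.
    return [e + s + r for e, s, r in
            zip(edge_state, sender_node_state, receiver_node_state)]

def node_fn(node_state, edge_state_agg):
    # Toy node update: node state plus the aggregated incoming messages.
    return [n + a for n, a in zip(node_state, edge_state_agg)]

def gn_step(nodes, edges, senders, receivers):
    # 1) Update every edge from its own state and its endpoint node states.
    new_edges = [edge_fn(e, nodes[s], nodes[r])
                 for e, s, r in zip(edges, senders, receivers)]
    # 2) Aggregate (here: mean) the updated edges arriving at each node.
    size = len(nodes[0])
    sums = [[0.0] * size for _ in nodes]
    counts = [0] * len(nodes)
    for e, r in zip(new_edges, receivers):
        sums[r] = [a + b for a, b in zip(sums[r], e)]
        counts[r] += 1
    agg = [[v / c for v in s] if c else s for s, c in zip(sums, counts)]
    # 3) Update every node from its state and its aggregated messages.
    new_nodes = [node_fn(n, a) for n, a in zip(nodes, agg)]
    return new_nodes, new_edges

nodes = [[1.0], [2.0], [3.0]]
edges = [[0.5], [0.5]]
senders, receivers = [0, 1], [2, 2]
new_nodes, new_edges = gn_step(nodes, edges, senders, receivers)
print(new_edges)
print(new_nodes)
```

A graph-independent block skips steps 1's endpoint inputs and step 2 entirely: each edge and node is transformed from its own state only.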
from tf_gnns import make_mlp_graphnet_functions, GraphNet, Node, Edge, Graph, GraphTuple, make_graph_tuple_from_graph_list
## Naming:
# ... _enc : parameter of encoder block GNN (graph indep.)
# ... _core : parameter of core block GNN
# ... _dec : parameter of decoder block GNN
gnn_size = 50
node_input_size_enc, node_output_size_enc = [2, 10]
node_input_size_core, node_output_size_core = [10, 10]
node_input_size_dec, node_output_size_dec = [10,2]
edge_input_size_enc = 2
edge_output_size_dec = 2
graph_fcn_enc = make_mlp_graphnet_functions(gnn_size,
node_input_size = node_input_size_enc,
node_output_size = node_output_size_enc,
edge_input_size = edge_input_size_enc,
graph_indep=True)
graph_fcn_core = make_mlp_graphnet_functions(gnn_size,
node_input_size = node_input_size_core,
node_output_size = node_output_size_core,
graph_indep=False)
graph_fcn_dec = make_mlp_graphnet_functions(gnn_size,
node_input_size = node_input_size_dec,
node_output_size = node_output_size_dec,
edge_output_size= edge_output_size_dec,
graph_indep=True)
graph_fcn_enc['graph_independent'] = True
graph_fcn_dec['graph_independent'] = True
gnn_dicts = [graph_fcn_enc, graph_fcn_core, graph_fcn_dec]
# A full encode-core-decode set of GNNs. The core block may be evaluated multiple times.
gnns = [GraphNet(**fcns) for fcns in gnn_dicts]
In the example above make_mlp_graphnet_functions is a factory method. It conveniently returns a set of functions (keras models) that are in turn used to create GNNs.
Easy inspection of GN functions
The tf_gnns.GraphNet can be pretty-printed in jupyter for easier inspection. Click on the buttons to see the contents of each GN function:
gnns[0]
Graph Indep. GNN function (@140643679157648)
Model: "node_fn/model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
node_state (InputLayer) [(None, 2)] 0
_________________________________________________________________
node_fn/dense_0 (Dense) (None, 50) 100
_________________________________________________________________
node_fn/dense_1 (Dense) (None, 50) 2550
_________________________________________________________________
node_fn/dense_2 (Dense) (None, 50) 2550
_________________________________________________________________
node_fn/dense_3 (Dense) (None, 10) 510
=================================================================
Total params: 5,710
Trainable params: 5,710
Non-trainable params: 0
_________________________________________________________________
Model: "edge_fn/model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
edge_state (InputLayer) [(None, 2)] 0
_________________________________________________________________
edge_fn/dense_0 (Dense) (None, 50) 100
_________________________________________________________________
edge_fn/dense_1 (Dense) (None, 50) 2550
_________________________________________________________________
edge_fn/dense_2 (Dense) (None, 50) 2550
_________________________________________________________________
edge_fn/dense_3 (Dense) (None, 10) 510
=================================================================
Total params: 5,710
Trainable params: 5,710
Non-trainable params: 0
_________________________________________________________________
gnns[1]
GNN function (@140643679157792)
Model: "node_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
node_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
edge_state_agg (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
concatenate_1 (Concatenate) (None, 20) 0 node_state[0][0]
edge_state_agg[0][0]
__________________________________________________________________________________________________
node_fn/dense_0 (Dense) (None, 50) 1000 concatenate_1[0][0]
__________________________________________________________________________________________________
node_fn/dense_1 (Dense) (None, 50) 2550 node_fn/dense_0[0][0]
__________________________________________________________________________________________________
node_fn/dense_2 (Dense) (None, 50) 2550 node_fn/dense_1[0][0]
__________________________________________________________________________________________________
node_fn/dense_3 (Dense) (None, 10) 510 node_fn/dense_2[0][0]
==================================================================================================
Total params: 6,610
Trainable params: 6,610
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "edge_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
edge_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
sender_node_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
receiver_node_state (InputLayer [(None, 10)] 0
__________________________________________________________________________________________________
concatenate (Concatenate) (None, 30) 0 edge_state[0][0]
sender_node_state[0][0]
receiver_node_state[0][0]
__________________________________________________________________________________________________
edge_fn/dense_0 (Dense) (None, 50) 1500 concatenate[0][0]
__________________________________________________________________________________________________
edge_fn/dense_1 (Dense) (None, 50) 2550 edge_fn/dense_0[0][0]
__________________________________________________________________________________________________
edge_fn/dense_2 (Dense) (None, 50) 2550 edge_fn/dense_1[0][0]
__________________________________________________________________________________________________
edge_fn/dense_3 (Dense) (None, 10) 510 edge_fn/dense_2[0][0]
==================================================================================================
Total params: 7,110
Trainable params: 7,110
Non-trainable params: 0
__________________________________________________________________________________________________
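The parameter counts printed above can be checked by hand. The sketch below reproduces them under one assumption inferred from the numbers themselves (not confirmed from the library source): the first dense layer has no bias, while the remaining layers do. The helper names are hypothetical.

```python
def dense_params(n_in, n_out, use_bias=True):
    # Weight matrix plus optional bias vector.
    return n_in * n_out + (n_out if use_bias else 0)

def mlp_params(n_in, hidden, n_out, n_hidden_layers=3):
    # Assumption inferred from the printed counts: the first layer has no bias.
    total = dense_params(n_in, hidden, use_bias=False)
    for _ in range(n_hidden_layers - 1):
        total += dense_params(hidden, hidden)
    total += dense_params(hidden, n_out)
    return total

print(mlp_params(2, 50, 10))   # encoder node/edge function -> 5710
print(mlp_params(20, 50, 10))  # core node function (node_state + edge_state_agg) -> 6610
print(mlp_params(30, 50, 10))  # core edge function (edge + sender + receiver) -> 7110
```

The input width of each core function is the sum of the widths of its concatenated inputs, which is why the core edge function (three 10-wide inputs) has more parameters than the core node function (two 10-wide inputs).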
Each of the GraphNets can be easily saved using the .save method. You can invoke the summary method of any of these GNNs to inspect the inputs and outputs of the GN functions.
The overloaded __add__() operator shown above comes in handy when computing residual connections:
CORE_STEPS = 2
def eval_full_shared_resid_core(G : GraphTuple, core_steps =CORE_STEPS):
    ## The actual computation
    G = gnns[0].graph_tuple_eval(G) # The "encode" GraphNet block
    for ncore in range(0,core_steps):
        # gnns[1] is the "core" GraphNet block
        G += gnns[1].graph_tuple_eval(G) # Overloaded sum operator - assign-add implements the residual connection (no projection needed - same size)
    G = gnns[-1].graph_tuple_eval(G) # The "decode" GraphNet block
    return G
gt_eval = eval_full_shared_resid_core(gt.copy())
print("Output GraphTuple:")
edges_shape, nodes_shape = gt_eval.edges.shape.as_list(),gt_eval.nodes.shape.as_list()
print(" nodes shape: [%i,%i]"%(nodes_shape[0], nodes_shape[1]))
print(" edges shape: [%i,%i]"%(edges_shape[0], edges_shape[1]))
Output GraphTuple:
nodes shape: [25,2]
edges shape: [23,2]
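How an overloaded addition yields a residual connection can be illustrated with a toy class. ToyGraphState and toy_core below are hypothetical stand-ins written for this sketch; they are not the tf_gnns classes, and they use one scalar feature per node and edge for brevity.

```python
class ToyGraphState:
    """Hypothetical stand-in for a batched graph: flat node and edge features."""

    def __init__(self, nodes, edges):
        self.nodes = list(nodes)
        self.edges = list(edges)

    def __add__(self, other):
        # Only meaningful when both operands have matching sizes and connectivity.
        if len(self.nodes) != len(other.nodes) or len(self.edges) != len(other.edges):
            raise ValueError("graphs must have matching node and edge counts")
        return ToyGraphState(
            [a + b for a, b in zip(self.nodes, other.nodes)],
            [a + b for a, b in zip(self.edges, other.edges)],
        )

def toy_core(g):
    # Stand-in for a core block: halves every feature.
    return ToyGraphState([0.5 * x for x in g.nodes], [0.5 * x for x in g.edges])

g = ToyGraphState(nodes=[1.0, 2.0], edges=[4.0])
for _ in range(2):
    g += toy_core(g)  # no __iadd__ defined, so this is g = g + toy_core(g)
print(g.nodes, g.edges)
```

Since the core block keeps the feature sizes unchanged, its output can be added to its input directly, with no projection layer in between.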
Global blocks
In what follows the manual creation of a Global block is shown.
from tf_gnns.graphnet_utils import make_keras_simple_agg
from tensorflow.keras.layers import Input, Dense
from tensorflow.keras import Model
GN_STATE = 2
GLOB_STATE_OUT = 1
def make_graph_tuple_to_global(insize = GN_STATE,output_size = GLOB_STATE_OUT, agg_type = 'mean'):
    agg_fcn = make_keras_simple_agg(insize,agg_type)
    agg_fcn = agg_fcn[1] # The graph tuple version of the aggregator
    # Constructing the node+edge -> global function.
    xx = Input(shape = (insize,))
    out = Dense(insize, 'tanh')(xx)
    out = Dense(insize, 'tanh')(out)
    out = Dense(output_size, activation = None, use_bias = False)(out)
    global_fcn = Model(inputs = xx, outputs = out)
    def fcn_node_and_edge(gt_):
        # Aggregate the node and edge info:
        # First, retrieve the segment IDs (as done on the full GN step with no global aggregation):
        graph_indices_nodes = []
        for k_,k in enumerate(gt_.n_nodes):
            graph_indices_nodes.extend(np.ones(k).astype("int")*k_)
        graph_indices_edges = []
        for k_,k in enumerate(gt_.n_edges):
            graph_indices_edges.extend(np.ones(k).astype("int")*k_)
        o1 = agg_fcn(gt_.nodes,graph_indices_nodes, gt_.n_graphs) # node_to_global aggregation
        o2 = agg_fcn(gt_.edges,graph_indices_edges, gt_.n_graphs) # edge_to_global aggregation
        return global_fcn(o1+o2) # the aggregated node and edge information is added here (concatenation is an alternative)
    return fcn_node_and_edge, global_fcn # (the global fcn is also returned to later retrieve its weights)
# Evaluation example:
# gt_eval = eval_full_shared_resid_core(gt.copy())
graph_tuple_to_global, global_fcn = make_graph_tuple_to_global(insize = GN_STATE,agg_type ="mean")
res = graph_tuple_to_global(gt_eval)
print("%i graphs, %i dimensional output:"%(gt.n_graphs, GLOB_STATE_OUT))
res
2 graphs, 1 dimensional output:
<tf.Tensor: shape=(2, 1), dtype=float32, numpy=
array([[0.07865859],
[0.05267247]], dtype=float32)>
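The per-graph aggregation used in the global block above can be sketched in plain Python. The helpers segment_ids and segment_mean are hypothetical names written for this illustration; they mimic what the segment-ID loops and a mean aggregator compute, without claiming to match the library's internals.

```python
def segment_ids(counts):
    # [2, 3] -> [0, 0, 1, 1, 1]: one graph index per row of the packed array.
    return [graph_idx for graph_idx, n in enumerate(counts) for _ in range(n)]

def segment_mean(rows, ids, n_segments):
    # Average the rows that share the same segment (graph) index.
    size = len(rows[0])
    sums = [[0.0] * size for _ in range(n_segments)]
    counts = [0] * n_segments
    for row, i in zip(rows, ids):
        sums[i] = [a + b for a, b in zip(sums[i], row)]
        counts[i] += 1
    return [[v / c for v in s] if c else s for s, c in zip(sums, counts)]

n_nodes = [2, 3]  # two graphs with 2 and 3 nodes
node_rows = [[1.0, 1.0], [3.0, 5.0], [0.0, 0.0], [3.0, 6.0], [6.0, 0.0]]
ids = segment_ids(n_nodes)
node_means = segment_mean(node_rows, ids, n_segments=2)
print(ids)
print(node_means)
```

The result has one row per graph, which is the shape the global function expects as input.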
All weights of a GN:
Gathering the weights (trainable parameters) of GraphNet blocks in tf_gnns is easy. Here is how to get all the parameters of the above computation:
all_weights=[]
for g in gnns: # the encode/core/decode layers:
    ww = g.weights
    all_weights.extend(ww)
all_weights.extend(global_fcn.weights) # The final "graph to global" layer.
Taking gradients of the whole operation:
The whole operation above is differentiable. Here is how to simply take gradients of the above with tf.GradientTape():
import tensorflow as tf
opt = tf.keras.optimizers.Adam(learning_rate = 0.001)
# Just some random numbers to work with:
Y = np.random.randn(gt.n_graphs,GLOB_STATE_OUT)
def full_eval_with_global(G):
    G = eval_full_shared_resid_core(G)
    return graph_tuple_to_global(G)
def loss(Ypred, Yactual):
    return tf.reduce_mean(tf.square(Ypred - Yactual))
# The following also gives us direct access to the gradients which may be
# useful in some contexts:
with tf.GradientTape() as tape:
    Yhat = full_eval_with_global(gt.copy())
    loss_current = loss(Yhat, Y)
grads = tape.gradient(loss_current,all_weights)
opt.apply_gradients(zip(grads,all_weights))
nweights_per_block = [len(g.weights) for g in gnns]
nweights_per_block.append(len(global_fcn.weights))
cumsum_nweights_per_block = np.cumsum([0,*nweights_per_block])
cumsum_nweights_per_block
array([ 0, 14, 28, 42, 47])
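The cumulative sums serve as slice boundaries into the flat list of gradients. A minimal standard-library sketch of that bookkeeping, using the block sizes printed above and a dummy list in place of the real gradients:

```python
from itertools import accumulate

weights_per_block = [14, 14, 14, 5]  # counts taken from the output above
boundaries = list(accumulate([0, *weights_per_block]))
flat = list(range(boundaries[-1]))   # stand-in for the flat gradient list

# Consecutive boundary pairs delimit the entries that belong to each block.
names = ["encode", "core", "decode", "global"]
blocks = {name: flat[start:end]
          for name, start, end in zip(names, boundaries[:-1], boundaries[1:])}
print(boundaries)
print({name: len(items) for name, items in blocks.items()})
```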
A plot of the gradient histograms from different blocks:
import matplotlib.pyplot as plt
sc = 0.5
plt.figure(figsize = (20*sc,5*sc))
for k,name in enumerate(['encode','core','decode','global']):
    plt.subplot(1,4,1+k)
    # Gradients of the weights that belong to the k-th block:
    start_idx, end_idx = cumsum_nweights_per_block[k:k+2]
    for g_ in grads[start_idx:end_idx]:
        plt.hist(g_.numpy().flatten(), alpha = 0.9)
    plt.grid()
    plt.title("%s network\nGrad. histogram"%name)
Support for global variables
Global-to-edge and global-to-node functions are also supported.
The GN block does nothing special to the global variable (it simply passes it to the inputs of the node and edge functions). To use the global variable, set use_global_to_edge = True and/or use_global_to_node = True at construction, and also specify the size/shape of the global variable. Shapes with rank > 1 for the global variable are currently untested and will probably lead to bugs.
For more custom computations and more advanced features, check the source of the factory method make_mlp_graphnet_functions(...) to understand naming conventions etc. The global variable is appended to the inputs of the node/edge functions when used.
For even more control over the functions, you may use make_node_mlp(...) and similar factories and pass their outputs directly to the GraphNet constructor.
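What "appending the global variable to the inputs" amounts to can be sketched in plain Python. The helpers below are hypothetical and only illustrate the shapes involved; how tf_gnns broadcasts the global state internally is an assumption here. Each graph's global vector is repeated once per node of that graph and concatenated with the node state and the aggregated edge state.

```python
def broadcast_global(global_states, n_nodes):
    # Repeat each graph's global vector once per node of that graph.
    return [g for g, n in zip(global_states, n_nodes) for _ in range(n)]

def concat_inputs(node_states, edge_aggs, globals_per_node):
    # Per-node input: [node_state | edge_state_agg | global_state]
    return [n + a + g for n, a, g in zip(node_states, edge_aggs, globals_per_node)]

n_nodes = [2, 1]                                   # two graphs: 2 nodes and 1 node
node_states = [[0.1] * 10 for _ in range(3)]
edge_aggs = [[0.2] * 10 for _ in range(3)]
global_states = [[1.0] * 10, [2.0] * 10]           # one global vector per graph

per_node_global = broadcast_global(global_states, n_nodes)
inputs = concat_inputs(node_states, edge_aggs, per_node_global)
print(len(inputs), len(inputs[0]))
```

With three 10-wide inputs the concatenated width is 30, which agrees with the (None, 30) concatenation in the gn_core node function summary printed further down.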
In the code below, some typical utility GN function factories are used.
from tf_gnns import make_mlp_graphnet_functions, make_graph_to_graph_and_global_functions, make_full_graphnet_functions
from tf_gnns import GraphNet, Node, Edge, Graph, GraphTuple, make_graph_tuple_from_graph_list
gnn_size = 50
node_input_size_enc, node_output_size_enc = [2, 10]
node_input_size_core, node_output_size_core = [10, 10]
edge_input_size_core, edge_output_size_core = [10,10]
node_input_size_dec, node_output_size_dec = [10,2]
global_core_state = 10
edge_input_size_enc = 2
edge_output_size_dec = 2
graph_fcn_enc = make_graph_to_graph_and_global_functions(gnn_size,
node_or_core_input_size = node_input_size_enc,
node_or_core_output_size = node_output_size_enc,
global_output_size = global_core_state)
graph_fcn_core = make_full_graphnet_functions(gnn_size,
node_or_core_input_size = node_output_size_enc,
node_or_core_output_size = node_output_size_enc,
global_output_size = global_core_state)
graph_fcn_dec = make_mlp_graphnet_functions(gnn_size,
node_input_size = node_input_size_dec,
node_output_size = node_output_size_dec,
edge_output_size= edge_output_size_dec,
graph_indep=True)
graph_fcn_enc['graph_independent'] = True
graph_fcn_dec['graph_independent'] = True
graph_tuple_to_global, global_fcn = make_graph_tuple_to_global(insize = node_output_size_core,
output_size=global_core_state,
agg_type ="mean")
gn_enc = GraphNet(**graph_fcn_enc)
gn_core = GraphNet(**graph_fcn_core)
gn_dec = GraphNet(**graph_fcn_dec)
Inspection of the functions:
gn_enc
Graph Indep. GNN function (@140641199316128)
Model: "node_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
node_state (InputLayer) [(None, 2)] 0
__________________________________________________________________________________________________
edge_state_agg (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
concatenate_3 (Concatenate) (None, 12) 0 node_state[0][0]
edge_state_agg[0][0]
__________________________________________________________________________________________________
node_fn/dense_0 (Dense) (None, 50) 600 concatenate_3[0][0]
__________________________________________________________________________________________________
node_fn/dense_1 (Dense) (None, 50) 2550 node_fn/dense_0[0][0]
__________________________________________________________________________________________________
node_fn/dense_2 (Dense) (None, 50) 2550 node_fn/dense_1[0][0]
__________________________________________________________________________________________________
node_fn/dense_3 (Dense) (None, 10) 510 node_fn/dense_2[0][0]
==================================================================================================
Total params: 6,210
Trainable params: 6,210
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "edge_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
edge_state (InputLayer) [(None, 2)] 0
__________________________________________________________________________________________________
sender_node_state (InputLayer) [(None, 2)] 0
__________________________________________________________________________________________________
receiver_node_state (InputLayer [(None, 2)] 0
__________________________________________________________________________________________________
concatenate_2 (Concatenate) (None, 6) 0 edge_state[0][0]
sender_node_state[0][0]
receiver_node_state[0][0]
__________________________________________________________________________________________________
edge_fn/dense_0 (Dense) (None, 50) 300 concatenate_2[0][0]
__________________________________________________________________________________________________
edge_fn/dense_1 (Dense) (None, 50) 2550 edge_fn/dense_0[0][0]
__________________________________________________________________________________________________
edge_fn/dense_2 (Dense) (None, 50) 2550 edge_fn/dense_1[0][0]
__________________________________________________________________________________________________
edge_fn/dense_3 (Dense) (None, 10) 510 edge_fn/dense_2[0][0]
==================================================================================================
Total params: 5,910
Trainable params: 5,910
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "global_mlp/glob_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
node_state_agg (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
edge_state_agg (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
concatenate_4 (Concatenate) (None, 20) 0 node_state_agg[0][0]
edge_state_agg[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_0 (Den (None, 50) 1000 concatenate_4[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_1 (Den (None, 50) 2550 global_mlp/glob_fn/dense_0[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_2 (Den (None, 50) 2550 global_mlp/glob_fn/dense_1[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_3 (Den (None, 10) 510 global_mlp/glob_fn/dense_2[0][0]
==================================================================================================
Total params: 6,610
Trainable params: 6,610
Non-trainable params: 0
__________________________________________________________________________________________________
gn_core
GNN function (@140641198037840)
Model: "node_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
node_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
edge_state_agg (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
global_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
concatenate_6 (Concatenate) (None, 30) 0 node_state[0][0]
edge_state_agg[0][0]
global_state[0][0]
__________________________________________________________________________________________________
node_fn/dense_0 (Dense) (None, 50) 1500 concatenate_6[0][0]
__________________________________________________________________________________________________
node_fn/dense_1 (Dense) (None, 50) 2550 node_fn/dense_0[0][0]
__________________________________________________________________________________________________
node_fn/dense_2 (Dense) (None, 50) 2550 node_fn/dense_1[0][0]
__________________________________________________________________________________________________
node_fn/dense_3 (Dense) (None, 10) 510 node_fn/dense_2[0][0]
==================================================================================================
Total params: 7,110
Trainable params: 7,110
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "edge_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
edge_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
sender_node_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
receiver_node_state (InputLayer [(None, 10)] 0
__________________________________________________________________________________________________
global_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
concatenate_5 (Concatenate) (None, 40) 0 edge_state[0][0]
sender_node_state[0][0]
receiver_node_state[0][0]
global_state[0][0]
__________________________________________________________________________________________________
edge_fn/dense_0 (Dense) (None, 50) 2000 concatenate_5[0][0]
__________________________________________________________________________________________________
edge_fn/dense_1 (Dense) (None, 50) 2550 edge_fn/dense_0[0][0]
__________________________________________________________________________________________________
edge_fn/dense_2 (Dense) (None, 50) 2550 edge_fn/dense_1[0][0]
__________________________________________________________________________________________________
edge_fn/dense_3 (Dense) (None, 10) 510 edge_fn/dense_2[0][0]
==================================================================================================
Total params: 7,610
Trainable params: 7,610
Non-trainable params: 0
__________________________________________________________________________________________________
Model: "global_mlp/glob_fn/model"
__________________________________________________________________________________________________
Layer (type) Output Shape Param # Connected to
==================================================================================================
global_state (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
node_state_agg (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
edge_state_agg (InputLayer) [(None, 10)] 0
__________________________________________________________________________________________________
concatenate_7 (Concatenate) (None, 30) 0 global_state[0][0]
node_state_agg[0][0]
edge_state_agg[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_0 (Den (None, 50) 1500 concatenate_7[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_1 (Den (None, 50) 2550 global_mlp/glob_fn/dense_0[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_2 (Den (None, 50) 2550 global_mlp/glob_fn/dense_1[0][0]
__________________________________________________________________________________________________
global_mlp/glob_fn/dense_3 (Den (None, 10) 510 global_mlp/glob_fn/dense_2[0][0]
==================================================================================================
Total params: 7,110
Trainable params: 7,110
Non-trainable params: 0
__________________________________________________________________________________________________
gn_dec
Graph Indep. GNN function (@140641192216272)
Model: "node_fn/model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
node_state (InputLayer) [(None, 10)] 0
_________________________________________________________________
node_fn/dense_0 (Dense) (None, 50) 500
_________________________________________________________________
node_fn/dense_1 (Dense) (None, 50) 2550
_________________________________________________________________
node_fn/dense_2 (Dense) (None, 50) 2550
_________________________________________________________________
node_fn/dense_3 (Dense) (None, 2) 102
=================================================================
Total params: 5,702
Trainable params: 5,702
Non-trainable params: 0
_________________________________________________________________
Model: "edge_fn/model"
_________________________________________________________________
Layer (type) Output Shape Param #
=================================================================
edge_state (InputLayer) [(None, 10)] 0
_________________________________________________________________
edge_fn/dense_0 (Dense) (None, 50) 500
_________________________________________________________________
edge_fn/dense_1 (Dense) (None, 50) 2550
_________________________________________________________________
edge_fn/dense_2 (Dense) (None, 50) 2550
_________________________________________________________________
edge_fn/dense_3 (Dense) (None, 2) 102
=================================================================
Total params: 5,702
Trainable params: 5,702
Non-trainable params: 0
_________________________________________________________________
Computation using graph tuples:
g_ = gn_enc.graph_tuple_eval(gt.copy())
# The following lines implement the full-GN block:
g_ = gn_core.graph_tuple_eval(g_) # <-- Note that the global variables are used here.
glob_accum = graph_tuple_to_global(g_)
g_ = gn_core.graph_tuple_eval(g_)
Faster computation using tf.function
TensorFlow supports compilation of eager code into graph-mode code. The GraphNet can be included in tf.function-compilable code through the .eval_tensor_dict(...) method, which operates on tensor dictionaries that correspond to GraphTuples. This should match the performance of other libraries.
def eval_graph_tuple(gt):
    g_ = gn_enc.graph_tuple_eval(gt)
    g_ = gn_core.graph_tuple_eval(g_)
    return g_
@tf.function
def eval_tensordict(td_):
    g_ = gn_enc.eval_tensor_dict(td_)
    g_ = gn_core.eval_tensor_dict(g_)
    return g_
from time import time
nreps = 500
t0 = time()
for i in range(nreps):
    eval_graph_tuple(gt.copy())
dt_graph_tuple =( time() - t0)/nreps
t0 =time()
g_ = gt.to_tensor_dict()
for i in range(nreps):
    eval_tensordict(g_)
dt_tf_function = (time()-t0)/nreps
print("GraphTuple: %2.3f[s/rep]\ntf.function: %2.3f[s/rep]"%(dt_graph_tuple, dt_tf_function))
print("Perf. improvement: %2.3f"%((dt_graph_tuple-dt_tf_function)/dt_graph_tuple * 100)+"%")
GraphTuple: 0.009[s/rep]
tf.function: 0.001[s/rep]
Perf. improvement: 86.149%
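When timing a tf.function, note that its first call also traces the Python code, so that call is slower than the ones that follow. A small standard-library helper, sketched here with hypothetical names, excludes warm-up calls from the average and computes the relative improvement with the same formula as in the cell above. It is demonstrated on a plain Python function so that it runs without TensorFlow.

```python
from time import perf_counter

def time_per_call(fn, *args, n_reps=100, n_warmup=1):
    # Warm-up calls are excluded so one-off costs (e.g. tracing) are not averaged in.
    for _ in range(n_warmup):
        fn(*args)
    t0 = perf_counter()
    for _ in range(n_reps):
        fn(*args)
    return (perf_counter() - t0) / n_reps

def relative_improvement(t_slow, t_fast):
    # Percentage of the slower time that is saved by the faster variant.
    return (t_slow - t_fast) / t_slow * 100.0

dt = time_per_call(sum, range(1000), n_reps=50)
print("%.6f s/call" % dt)
print("%.1f%%" % relative_improvement(0.010, 0.002))
```

In the notebook, the same helper could be applied to eval_graph_tuple and eval_tensordict in place of the two hand-written loops.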