Package exports

Public package interface for tf_gnns.

class tf_gnns.Graph(nodes, edges, global_attr=None, NO_VALIDATION=True)[source]

Bases: object

Object graph made of node and edge instances.

Parameters:
  • nodes – List of Node instances.

  • edges – List of Edge instances.

  • global_attr – Optional graph-level attributes.

  • NO_VALIDATION – If False, run connectivity validation checks.

is_equal_by_value(g2)[source]

Checks whether the two graphs have the same values for node and edge attributes.

compare_connectivity(g2)[source]

Checks if the connectivity of two graphs is the same.

static validate_graph(self)[source]
copy()[source]
get_subgraph_from_nodes(nodes, edge_trimming_mode='+from+to')[source]

Create a subgraph by filtering nodes and incident edges.

Parameters:
  • nodes – Node subset to keep.

  • edge_trimming_mode – Edge filter mode. Supported values are "+from+to" (keep edges where both endpoints are in nodes) and "-from+to" (keep edges where both endpoints are not in nodes).

Returns:

A new Graph with copied nodes and matching copied edges.
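
The "+from+to" rule described above can be illustrated without the library. The following is a minimal sketch, assuming edges are plain (sender, receiver) pairs; trim_edges is a hypothetical helper and not part of tf_gnns.

```python
def trim_edges(edges, kept_nodes):
    """Keep only edges whose two endpoints are both in kept_nodes.

    Mirrors the documented "+from+to" behaviour on plain tuples;
    this is an illustration, not the tf_gnns implementation.
    """
    kept = set(kept_nodes)
    return [(u, v) for (u, v) in edges if u in kept and v in kept]


edges = [("a", "b"), ("b", "c"), ("c", "a"), ("c", "d")]
sub_edges = trim_edges(edges, kept_nodes=["a", "b", "c"])
print(sub_edges)  # [('a', 'b'), ('b', 'c'), ('c', 'a')]
```

The edge ("c", "d") is dropped because "d" is not among the kept nodes.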

class tf_gnns.GraphTuple(nodes, edges, senders, receivers, n_nodes, n_edges, global_attr=None, global_reps_for_nodes=None, global_reps_for_edges=None, n_graphs=None)[source]

Bases: object

Batched graph representation used by GraphNet tensor-dict paths.

A GraphTuple stores all node and edge features in contiguous tensors and keeps graph boundaries via n_nodes and n_edges vectors.

The GraphTuple makes multiple smaller graphs appear as a single large graph, with contiguous indexing for nodes and edges. This allows fast batched computation and takes advantage of default performance optimizations in deep learning frameworks.
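
The contiguous indexing described above can be sketched in plain Python. This is an illustration under stated assumptions, not the library's code: batch_graphs and the dict layout of the input graphs are hypothetical, and lists stand in for tensors.

```python
from itertools import accumulate


def batch_graphs(graphs):
    """Flatten a list of small graphs into one contiguous index space.

    Each input graph is a dict with "n_nodes" (int) and "edges"
    (list of (sender, receiver) pairs using graph-local indices).
    """
    n_nodes = [g["n_nodes"] for g in graphs]
    n_edges = [len(g["edges"]) for g in graphs]
    # Node offset of each graph = number of nodes in all earlier graphs.
    offsets = [0] + list(accumulate(n_nodes))[:-1]
    senders, receivers = [], []
    for g, off in zip(graphs, offsets):
        for s, r in g["edges"]:
            senders.append(s + off)
            receivers.append(r + off)
    return {"senders": senders, "receivers": receivers,
            "n_nodes": n_nodes, "n_edges": n_edges}


g0 = {"n_nodes": 3, "edges": [(0, 1), (1, 2)]}
g1 = {"n_nodes": 2, "edges": [(0, 1), (1, 0)]}
batch = batch_graphs([g0, g1])
print(batch["senders"])    # [0, 1, 3, 4]
print(batch["receivers"])  # [1, 2, 4, 3]
```

The second graph's node indices are shifted by 3, the node count of the first graph, so every index is unique within the batch.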

__init__(nodes, edges, senders, receivers, n_nodes, n_edges, global_attr=None, global_reps_for_nodes=None, global_reps_for_edges=None, n_graphs=None)[source]

Initialize a graph batch.

Parameters:
  • nodes – Tensor-like node feature array with shape [sum(n_nodes), d_n].

  • edges – Tensor-like edge feature array with shape [sum(n_edges), d_e].

  • senders – Sender node indices for each edge. Indices are unique across graphs in the flattened representation.

  • receivers – Receiver node indices for each edge. Indices are unique across graphs in the flattened representation.

  • n_nodes – Per-graph node counts.

  • n_edges – Per-graph edge counts.

  • global_attr – Optional graph-level features of shape [n_graphs, d_g].

  • global_reps_for_nodes – Optional precomputed mapping from node rows to graph ids.

  • global_reps_for_edges – Optional precomputed mapping from edge rows to graph ids.

  • n_graphs – Optional number of graphs.

update_reps_for_globals()[source]

Build helper vectors mapping nodes/edges to graph indices.
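
As a rough illustration of what such helper vectors contain, the sketch below repeats each graph index once per node (or edge) of that graph. The function name reps_for_globals is hypothetical, and plain lists stand in for tensors.

```python
def reps_for_globals(counts):
    """Return a list mapping each row to the index of its graph.

    counts[i] is the number of rows (nodes or edges) owned by graph i.
    """
    return [graph_id for graph_id, n in enumerate(counts) for _ in range(n)]


n_nodes = [3, 2]
n_edges = [2, 2]
global_reps_for_nodes = reps_for_globals(n_nodes)
global_reps_for_edges = reps_for_globals(n_edges)
print(global_reps_for_nodes)  # [0, 0, 0, 1, 1]
print(global_reps_for_edges)  # [0, 0, 1, 1]
```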

assign_global(global_attr, check_shape=False)[source]

Assign graph-level features.

Parameters:
  • global_attr – Tensor-like global features.

  • check_shape – If True, assert first dimension equals n_graphs.

is_equal_by_value(other_graph_tuple)[source]
copy()[source]
get_graph(graph_index)[source]

Extract a single Graph from this batch.

Parameters:

graph_index – Zero-based index of the graph to extract.

Returns:

A new Graph object containing copied node/edge features.
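
Extracting one graph from a flattened batch amounts to slicing by cumulative counts and shifting edge endpoints back to local indices. The sketch below shows that bookkeeping on plain lists; slice_graph is a hypothetical helper and does not build a tf_gnns Graph.

```python
def slice_graph(nodes, senders, receivers, n_nodes, n_edges, graph_index):
    """Extract one graph from flattened batch arrays (plain lists here)."""
    node_start = sum(n_nodes[:graph_index])
    edge_start = sum(n_edges[:graph_index])
    node_end = node_start + n_nodes[graph_index]
    edge_end = edge_start + n_edges[graph_index]
    # Shift edge endpoints back to graph-local node indices.
    local_senders = [s - node_start for s in senders[edge_start:edge_end]]
    local_receivers = [r - node_start for r in receivers[edge_start:edge_end]]
    return nodes[node_start:node_end], local_senders, local_receivers


nodes = [[0.1], [0.2], [0.3], [1.0], [2.0]]  # 3 nodes, then 2 nodes
senders = [0, 1, 3, 4]
receivers = [1, 2, 4, 3]
sub_nodes, sub_s, sub_r = slice_graph(
    nodes, senders, receivers, n_nodes=[3, 2], n_edges=[2, 2], graph_index=1
)
print(sub_nodes)     # [[1.0], [2.0]]
print(sub_s, sub_r)  # [0, 1] [1, 0]
```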

to_tensor_dict()[source]

Convert this graph batch to a GraphNet tensor dictionary.

class tf_gnns.Node(node_attr_tensor)[source]

Bases: object

Graph node with a feature tensor.

Parameters:

node_attr_tensor – Tensor-like node attributes of rank at least 2. The first dimension is usually batch-like in object graph mode.

get_state()[source]
set_tensor(tensor)[source]
copy()[source]

class tf_gnns.Edge(edge_attr_tensor, node_from, node_to)[source]

Bases: object

Directed edge connecting two Node objects.

Parameters:
  • edge_attr_tensor – Tensor-like edge attributes.

  • node_from – Source node.

  • node_to – Destination node.

set_tensor(edge_tensor)[source]
copy(nodes_correspondence)[source]

tf_gnns.make_graph_tuple_from_graph_list(list_of_graphs)[source]

Create a GraphTuple from a list of object graphs.

Parameters:

list_of_graphs – List of Graph objects with consistent feature dimensionality.

Returns:

A GraphTuple with flattened node/edge tensors and bookkeeping vectors (senders, receivers, n_nodes, n_edges).

Notes

This helper currently expects node and edge attributes in each input graph to have first dimension equal to 1.

class tf_gnns.GraphNet(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)[source]

Bases: object

One GraphNet computation block.

A GraphNet instance wraps edge, node, and optional global update functions and evaluates them over either object graphs or tensor-dict graph tuples.

__init__(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)[source]

Initialize a GraphNet block.

Parameters:
  • edge_function – Keras model for edge updates.

  • node_function – Keras model for node updates.

  • global_function – Optional Keras model for global updates.

  • edge_aggregation_function – Aggregation used for edge-to-node and edge-to-global reductions in non-segment paths.

  • node_to_global_aggregation_function – Aggregation used for node-to-global reductions.

  • graph_independent – If True, disable message passing and apply independent edge/node/global transforms.

  • use_global_input – Whether global features are provided as model inputs.

  • name – Optional block name.

static make_from_path(path)[source]
scan_edge_to_node_aggregation_function(fn)[source]

Scans the inputs and outputs of the aggregation function and keeps track of them for subsequent computation. Raises an error if the aggregation is not compatible with the other GraphNet functions that have been defined.

scan_edge_function()[source]

The edge function signature (whether or not it takes a given input) is inferred from the names of its inputs. Scans the inputs of the edge function to keep track of which graph variables it uses, and raises errors for combinations that do not make sense. Creates a dict that resolves this correspondence for evaluation.

scan_global_function()[source]
scan_node_function()[source]

Basic sanity checks for the node function.

get_graphnet_input_shapes()[source]
get_graphnet_output_shapes()[source]
summary()[source]
graph_tuple_eval(tf_graph_tuple: GraphTuple)[source]
edge_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)[source]
node_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)[source]
global_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)[source]
eval_tensor_dict(d)[source]

For better performance, this method takes a dictionary of all the necessary tensors rather than a GraphTuple. A GraphTuple is easily converted to a dictionary of tensors by the _graphtuple_to_tensor_dict function.

d is an ordered dictionary containing the following keys:

'edges', 'nodes', 'senders', 'receivers', 'n_edges', 'n_nodes', 'global_attr', 'global_reps_for_edges', 'global_reps_for_nodes'
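
To make the expected layout concrete, here is a minimal sketch of such a dictionary for a batch of two graphs, together with a consistency check. It rests on assumptions: nested lists stand in for tensors, and check_tensor_dict is a hypothetical helper that is not part of tf_gnns.

```python
from collections import OrderedDict

# Key order as listed in the eval_tensor_dict description.
EXPECTED_KEYS = [
    "edges", "nodes", "senders", "receivers", "n_edges", "n_nodes",
    "global_attr", "global_reps_for_edges", "global_reps_for_nodes",
]


def check_tensor_dict(d):
    """Hypothetical helper: verify keys and row counts of a tensor dictionary."""
    missing = [k for k in EXPECTED_KEYS if k not in d]
    if missing:
        raise KeyError(f"missing keys: {missing}")
    if len(d["nodes"]) != sum(d["n_nodes"]):
        raise ValueError("nodes rows must equal sum(n_nodes)")
    if len(d["edges"]) != sum(d["n_edges"]):
        raise ValueError("edges rows must equal sum(n_edges)")
    if not (len(d["senders"]) == len(d["receivers"]) == len(d["edges"])):
        raise ValueError("senders/receivers need one entry per edge")
    return True


# Two graphs: 3 nodes / 2 edges and 2 nodes / 2 edges.
d = OrderedDict(
    edges=[[0.5], [0.5], [1.5], [1.5]],
    nodes=[[0.1], [0.2], [0.3], [1.0], [2.0]],
    senders=[0, 1, 3, 4],
    receivers=[1, 2, 4, 3],
    n_edges=[2, 2],
    n_nodes=[3, 2],
    global_attr=[[0.0], [0.0]],
    global_reps_for_edges=[0, 0, 1, 1],
    global_reps_for_nodes=[0, 0, 0, 1, 1],
)
print(check_tensor_dict(d))             # True
print(list(d.keys()) == EXPECTED_KEYS)  # True
```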

graph_eval(graph: Graph, **kwargs)[source]

Evaluate a single Graph by routing through GraphTuple execution.

Legacy safe / batched modes were removed in favor of the unified GraphTuple path, which is the supported runtime route.

save(path)[source]

Save the model. Iterates over the required Keras models and saves them in a folder.

Parameters:

path – Path to save to. It is created if it does not exist.

static load_graph_functions(path)[source]

Returns a list of loaded graph functions.

load(path)[source]

Load a model from disk. If the model is already initialized, the current GraphNet functions are simply overwritten. If the model is uninitialized, this is called from a static factory method to create a new object with consistent properties.

tf_gnns.make_node_mlp(units, edge_message_input_shape=None, node_state_input_shape=None, global_state_input_shape=None, node_emb_size=None, use_global_input=False, use_edge_state_agg_input=True, graph_indep=False, use_node_state_input=True, activation='relu', **kwargs)[source]

tf_gnns.make_edge_mlp(units, edge_state_input_shape=None, sender_node_state_output_shape=None, global_to_edge_state_size=None, receiver_node_state_shape=None, edge_output_state_size=None, use_edge_state=True, use_sender_out_state=True, use_receiver_state=True, use_global_state=False, graph_indep=False, activation='relu', **kwargs)[source]

When this is a graph-independent edge function, the sender and receiver node states are not used. Like make_node_mlp, it uses a list of named keras.Input layers.

tf_gnns.make_keras_simple_agg(input_size, agg_type)[source]

For consistency, the aggregator is built as a Keras model, which makes saving and loading easier. It is intended for the “naive” GraphNet evaluators; the fully batched aggregator based on segment sums should be preferred.

Parameters:
  • input_size – Size of the expected input (mainly useful to enforce consistency).

  • agg_type – One of 'mean', 'sum', 'min', or 'max'.
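
The four aggregation types, and the idea of segment-based aggregation that the batched path relies on, can be sketched in plain Python. This is an illustration only: segment_aggregate is a hypothetical helper operating on scalar messages, and returning 0.0 for empty segments is an assumption of this sketch rather than documented library behaviour.

```python
def segment_aggregate(messages, segment_ids, num_segments, agg_type="sum"):
    """Aggregate scalar messages per segment (e.g. per receiver node).

    Empty segments yield 0.0 (a convention chosen for this sketch).
    """
    buckets = [[] for _ in range(num_segments)]
    for value, seg in zip(messages, segment_ids):
        buckets[seg].append(value)
    reducers = {
        "sum": sum,
        "mean": lambda xs: sum(xs) / len(xs),
        "min": min,
        "max": max,
    }
    reduce_fn = reducers[agg_type]
    return [reduce_fn(b) if b else 0.0 for b in buckets]


messages = [1.0, 2.0, 3.0, 4.0]
receivers = [1, 2, 2, 0]
print(segment_aggregate(messages, receivers, 3, "sum"))   # [4.0, 1.0, 5.0]
print(segment_aggregate(messages, receivers, 3, "mean"))  # [4.0, 1.0, 2.5]
print(segment_aggregate(messages, receivers, 3, "max"))   # [4.0, 1.0, 3.0]
```

Using the receiver index as the segment id aggregates edge messages per node in a single pass over the edges, without looping over individual graphs.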

tf_gnns.make_mlp_graphnet_functions(units, node_input_size, node_output_size, edge_input_size=None, edge_output_size=None, create_global_function=False, global_input_size=None, global_output_size=None, use_global_input=False, use_global_to_edge=False, use_global_to_node=False, node_mlp_use_edge_state_agg_input=True, graph_indep=False, message_size='auto', aggregation_function='mean', node_to_global_aggr_fn=None, edge_to_global_aggr_fn=None, activation='relu', activate_last_layer=False, **kwargs)[source]

Build callables required to construct a GraphNet block.

Parameters:
  • units – Width specification used when creating edge/node/global MLPs.

  • node_input_size – Input node feature size.

  • node_output_size – Output node feature size.

  • edge_input_size – Optional input edge feature size. Defaults to node_input_size.

  • edge_output_size – Optional output edge feature size. Defaults to node_output_size.

  • create_global_function – If True, create a global update model.

  • global_input_size – Optional input global feature size.

  • global_output_size – Optional output global feature size.

  • use_global_input – If True, consume global_attr from inputs.

  • use_global_to_edge – If True, include input global state in the edge model inputs.

  • use_global_to_node – If True, include input global state in the node model inputs.

  • node_mlp_use_edge_state_agg_input – If True, node MLP consumes aggregated edge messages.

  • graph_indep – If True, create a graph-independent block (no message passing).

  • message_size – Edge message size, or "auto" to infer from selected aggregation.

  • aggregation_function – Aggregation mode used for message passing.

  • node_to_global_aggr_fn – Optional override for node-to-global aggregation.

  • edge_to_global_aggr_fn – Optional override for edge-to-global aggregation.

  • activation – Activation used for created MLPs.

  • activate_last_layer – Whether to apply activation on final MLP layer.

  • **kwargs – Forwarded to low-level MLP factory helpers.

Returns:

Dictionary of GraphNet constructor keyword arguments, including edge_function, node_function, optional global_function, and aggregation callables.

Raises:

ValueError – If global state is requested by edge/node models while use_global_input is False.

tf_gnns.make_global_mlp(units, global_in_size=None, global_emb_size=None, node_in_size=None, edge_in_size=None, use_node_agg_input=True, use_edge_agg_input=True, use_global_state_input=True, node_to_global_agg=None, graph_indep=False, activation='relu', **kwargs)[source]

Always uses the global state input. May also use node/edge aggregators (the full GraphNet case).

tf_gnns.make_full_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', **kwargs)[source]

Create GraphNet callables for a full message-passing block.

Parameters:
  • units – Width specification used for created MLPs.

  • node_or_core_input_size – Input node feature size.

  • node_or_core_output_size – Optional output node feature size.

  • edge_input_size – Optional input edge feature size.

  • edge_output_size – Optional output edge feature size.

  • global_input_size – Optional input global feature size.

  • global_output_size – Optional output global feature size.

  • aggregation_function – Aggregation mode for message passing.

  • **kwargs – Forwarded to make_mlp_graphnet_functions().

Returns:

Dictionary of constructor arguments for GraphNet configured with edge, node, and global update functions.

tf_gnns.make_graph_indep_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', create_global_function=True, use_global_input=True, **kwargs)[source]

Create GraphNet callables for a graph-independent block.

Parameters:
  • units – Width specification used for created MLPs.

  • node_or_core_input_size – Input node feature size.

  • node_or_core_output_size – Optional output node feature size.

  • edge_input_size – Optional input edge feature size.

  • edge_output_size – Optional output edge feature size.

  • global_input_size – Optional input global feature size.

  • global_output_size – Optional output global feature size.

  • aggregation_function – Aggregation mode (retained for interface consistency).

  • create_global_function – If True, create global output function.

  • use_global_input – If True, consume global input state.

  • **kwargs – Forwarded to make_mlp_graphnet_functions().

Returns:

Dictionary of constructor arguments for GraphNet configured in graph-independent mode.

tf_gnns.make_graph_to_graph_and_global_functions(units, node_or_core_input_size, global_output_size, node_or_core_output_size=None, edge_output_size=None, edge_input_size=None, aggregation_function='mean', **kwargs)[source]

Create a graph-to-graph-plus-global GraphNet function bundle.

This wrapper builds a block that does not consume input globals but still produces global outputs by aggregating node and edge states.

Parameters:
  • units – Width specification used for constructed MLPs.

  • node_or_core_input_size – Input node feature size.

  • global_output_size – Output global feature size.

  • node_or_core_output_size – Optional output node feature size.

  • edge_output_size – Optional output edge feature size.

  • edge_input_size – Optional input edge feature size.

  • aggregation_function – Aggregation mode used in message passing.

  • **kwargs – Forwarded to make_mlp_graphnet_functions().

Returns:

Dictionary of constructor arguments for GraphNet.
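
To illustrate how global outputs can be produced without input globals, the sketch below averages node and edge states per graph and concatenates the results into one row per graph, which a global update function could then consume. It is a sketch under assumptions: per_graph_mean is a hypothetical helper, mean is just one possible aggregation, and the actual wiring inside tf_gnns may differ.

```python
def per_graph_mean(rows, counts):
    """Average feature rows per graph; graph i owns counts[i] consecutive rows.

    Assumes every graph has at least one row.
    """
    means, start = [], 0
    for n in counts:
        block = rows[start:start + n]
        means.append([sum(col) / n for col in zip(*block)])
        start += n
    return means


nodes = [[1.0, 0.0], [3.0, 2.0], [5.0, 5.0]]  # graph 0: 2 nodes, graph 1: 1 node
edges = [[2.0], [4.0], [6.0]]                 # graph 0: 1 edge,  graph 1: 2 edges
node_agg = per_graph_mean(nodes, [2, 1])
edge_agg = per_graph_mean(edges, [1, 2])
# One concatenated row per graph, usable as input to a global update function.
global_inputs = [n + e for n, e in zip(node_agg, edge_agg)]
print(global_inputs)  # [[2.0, 1.0, 2.0], [5.0, 5.0, 5.0]]
```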

class tf_gnns.GraphNetMLP(*args, **kwargs)[source]

Bases: Layer

Encode-process-decode GraphNet model implemented with MLPs.

build(d_shapes)[source]
call(g_)[source]

Evaluate the full encode-process-decode GraphNet pipeline.

class tf_gnns.GraphIndep(*args, **kwargs)[source]

Bases: Layer

Graph-independent block with optional global input handling.

build(input_shape)[source]
call(g_)[source]

Apply graph-independent transforms to nodes/edges/globals.

class tf_gnns.GNCellMLP(*args, **kwargs)[source]

Bases: Layer

Single GraphNet processing block implemented with MLPs.

This layer is close to MLPGraphNetwork from the graph_nets examples.

build(input_shape)[source]
call(g_)[source]

Run one GraphNet update step on a tensor dictionary.

class tf_gnns.GraphNetMPNN_MLP(*args, **kwargs)[source]

Bases: Layer

MLP-based MPNN without graph-level global state updates.

build(d_shapes)[source]

Build encoder, processor, and decoder GraphNet modules.

call(g_)[source]

Run the MPNN-style no-global encode-process-decode stack.

class tf_gnns.SparseGCNConv(*args, **kwargs)[source]

Bases: Layer

Single sparse GCN convolution layer.

Inputs are expected to follow the graph tensor-dict format used by this repository (keys: nodes, senders, receivers, and counts).

This layer supports a paper-faithful update structure with a separate root transform, normalized neighborhood aggregation, and optional BatchNorm:

\[h_v^{l} = \sigma\left(\mathrm{Norm}\left( h_v^{l-1}W_r^l + \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\hat d_u \hat d_v}} h_u^{l-1}W^l \right)\right)\]

where dropout is typically applied by the outer SparseGCN stack.

Reference:

“Classic GNNs are Strong Baselines” (2024), https://arxiv.org/pdf/2406.08993
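
The normalized neighborhood sum in the update above can be sketched in plain Python for scalar node features. This is a simplified illustration, not the layer's implementation: gcn_aggregate is a hypothetical helper, the degree is assumed to count incoming edges plus one self-loop, and the weight matrices, root transform, Norm, and activation are all omitted.

```python
import math


def gcn_aggregate(h, senders, receivers):
    """Normalized neighborhood sum with self-loops, for scalar node features.

    For each node v, sums h[u] / sqrt(d_hat[u] * d_hat[v]) over all
    neighbors u of v and over v itself, where d_hat counts incoming
    edges plus one self-loop.
    """
    n = len(h)
    # Append one self-loop per node.
    s = list(senders) + list(range(n))
    r = list(receivers) + list(range(n))
    d_hat = [0] * n
    for v in r:
        d_hat[v] += 1
    out = [0.0] * n
    for u, v in zip(s, r):
        out[v] += h[u] / math.sqrt(d_hat[u] * d_hat[v])
    return out


# Undirected path 0 - 1 - 2, stored as directed edges in both directions.
senders = [0, 1, 1, 2]
receivers = [1, 0, 2, 1]
h = [1.0, 1.0, 1.0]
out = gcn_aggregate(h, senders, receivers)
print([round(x, 4) for x in out])
```

With self-loops the degrees are 2, 3, and 2, so the middle node receives 1/sqrt(6) from each neighbor plus 1/3 from itself.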

build(input_shape)[source]
call(d, training=False)[source]

class tf_gnns.SparseGCN(*args, **kwargs)[source]

Bases: Layer

Multi-layer sparse GCN stack for node-level representations.

This layer composes multiple SparseGCNConv blocks and applies optional dropout between hidden layers. By default, each convolution uses BatchNorm, matching the “Norm” term used in common strong-baseline GCN formulations. Inputs follow the graph tensor-dict structure used across tf_gnns:

  • nodes: node feature matrix with shape [N, F]

  • senders: source-node indices for edges, shape [E]

  • receivers: target-node indices for edges, shape [E]

  • edge_weights (optional): scalar edge weights, shape [E]

  • graph bookkeeping keys such as n_nodes/n_edges are passed through

The stack performs repeated message passing in sparse form. Each SparseGCNConv can add self-loops and symmetric normalization D^{-1/2} A D^{-1/2} (configurable).

Typical usage (2-layer GCN for node classification):

gcn = SparseGCN(
    hidden_units=[128],
    output_units=num_classes,
    activation="relu",
    dropout_rate=0.5,
    add_self_loops=True,
    normalize=True,
)
out_td = gcn(input_td, training=True)
logits = out_td["nodes"]

A deeper variant with custom dtypes:

gcn = SparseGCN(
    hidden_units=[256, 256, 128],
    output_units=64,
    activation="relu",
    dropout_rate=0.2,
    feature_dtype="float32",
    index_dtype="int32",
)

Parameters:
  • hidden_units – Width(s) of hidden GCN layers. Accepts an int for a single hidden layer or a non-empty list of int values.

  • output_units – Optional final output width appended after hidden layers. For classification this is typically the number of classes.

  • activation – Activation for hidden layers. Final layer uses None by default so downstream losses can consume logits directly.

  • dropout_rate – Dropout probability applied to node features between convolution layers (not applied after the final layer).

  • add_self_loops – Whether each SparseGCNConv injects identity edges.

  • normalize – Whether each SparseGCNConv applies degree-based normalization to edge weights.

  • batchnorm – Whether each SparseGCNConv applies BatchNormalization to node features after message/root aggregation and before activation. Enabled by default.

  • layernorm – Whether each SparseGCNConv applies LayerNormalization to node features after message/root aggregation and before activation. Must not be enabled at the same time as batchnorm.

  • jit_compile – Flag forwarded to SparseGCNConv for backend-specific compiled execution behavior.

  • residual – If True, add residual node skip-connections between hidden layers when feature dimensions match. Residual links are not applied on the final layer.

  • residual_projection – If True and residual is enabled, add a learnable linear projection on residual paths when hidden feature dimensions do not match. Projection is only used for hidden layers (not the final output layer).

  • feature_dtype – Optional dtype override for feature tensors and weights.

  • index_dtype – Optional dtype override for graph index tensors.

  • **kwargs – Forwarded to keras.layers.Layer.

Returns:

A tensor-dict with updated nodes and preserved graph structure keys.

Raises:

ValueError – If hidden_units is empty.

call(d, training=False)[source]

Apply stacked sparse graph convolutions.

Parameters:
  • d – Graph tensor-dict containing at least nodes, senders, and receivers.

  • training – Whether to run in training mode (controls dropout between hidden layers).

Returns:

Graph tensor-dict with transformed nodes.