Package exports
Public package interface for tf_gnns.
- class tf_gnns.Graph(nodes, edges, global_attr=None, NO_VALIDATION=True)
Bases: object
Object graph made of Node and Edge instances.
- is_equal_by_value(g2)
Check whether this graph and g2 have the same values for their node and edge attributes.
- get_subgraph_from_nodes(nodes, edge_trimming_mode='+from+to')
Create a subgraph by filtering nodes and incident edges.
- Parameters:
nodes – Node subset to keep.
edge_trimming_mode – Edge filter mode. Supported values are "+from+to" (keep edges where both endpoints are in nodes) and "-from+to" (keep edges where both endpoints are not in nodes).
- Returns:
A new Graph with copied nodes and matching copied edges.
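To make the two edge_trimming_mode rules concrete, here is a minimal plain-Python sketch that applies them to an edge list. It follows the wording above literally and is not the tf_gnns implementation: the trim_edges helper and the (node_from, node_to) tuple representation are hypothetical, whereas the real method works on Node and Edge objects and returns a copied Graph.

```python
# Hypothetical sketch of the documented edge_trimming_mode rules, using plain
# (node_from, node_to) tuples instead of tf_gnns Node/Edge objects.

def trim_edges(edges, kept_nodes, mode="+from+to"):
    """Return the edges that survive the given trimming mode."""
    kept = set(kept_nodes)
    if mode == "+from+to":
        # Keep edges whose endpoints are both in the kept node set.
        return [(u, v) for (u, v) in edges if u in kept and v in kept]
    if mode == "-from+to":
        # As documented: keep edges whose endpoints are both outside the node set.
        return [(u, v) for (u, v) in edges if u not in kept and v not in kept]
    raise ValueError(f"unsupported edge_trimming_mode: {mode!r}")


edges = [(0, 1), (1, 2), (2, 3), (3, 0)]
print(trim_edges(edges, [0, 1, 2]))                # [(0, 1), (1, 2)]
print(trim_edges(edges, [0, 1], mode="-from+to"))  # [(2, 3)]
```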
- class tf_gnns.GraphTuple(nodes, edges, senders, receivers, n_nodes, n_edges, global_attr=None, global_reps_for_nodes=None, global_reps_for_edges=None, n_graphs=None)
Bases: object
Batched graph representation used by GraphNet tensor-dict paths.
A GraphTuple stores all node and edge features in contiguous tensors and keeps graph boundaries via the n_nodes and n_edges vectors.
The GraphTuple makes multiple smaller graphs appear as a single large graph, with contiguous indexing for nodes and edges. This allows fast batched computation, because the whole batch is processed with one set of vectorized tensor operations instead of one graph at a time.
- __init__(nodes, edges, senders, receivers, n_nodes, n_edges, global_attr=None, global_reps_for_nodes=None, global_reps_for_edges=None, n_graphs=None)
Initialize a graph batch.
- Parameters:
nodes – Tensor-like node feature array with shape [sum(n_nodes), d_n].
edges – Tensor-like edge feature array with shape [sum(n_edges), d_e].
senders – Sender node indices for each edge. Indices refer to rows of the flattened nodes array, so they are unique across graphs.
receivers – Receiver node indices for each edge. Indices refer to rows of the flattened nodes array, so they are unique across graphs.
n_nodes – Per-graph node counts.
n_edges – Per-graph edge counts.
global_attr – Optional graph-level features of shape [n_graphs, d_g].
global_reps_for_nodes – Optional precomputed mapping from node rows to graph ids.
global_reps_for_edges – Optional precomputed mapping from edge rows to graph ids.
n_graphs – Optional number of graphs.
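The bookkeeping vectors can be illustrated without TensorFlow. The sketch below assumes that global_reps_for_nodes and global_reps_for_edges hold one graph id per flattened row, as the parameter descriptions suggest; the helper names repeat_graph_ids and broadcast_globals are hypothetical.

```python
# Plain-Python sketch of GraphTuple bookkeeping: n_nodes / n_edges delimit each
# graph inside the flattened arrays, and a "row -> graph id" vector lets
# per-graph global features be broadcast to every node or edge row.

def repeat_graph_ids(counts):
    """Map each flattened row to the id of the graph it belongs to."""
    return [g for g, c in enumerate(counts) for _ in range(c)]


def broadcast_globals(global_attr, counts):
    """Gather one global feature row per node/edge row."""
    return [global_attr[g] for g in repeat_graph_ids(counts)]


n_nodes = [3, 2]               # graph 0 has 3 nodes, graph 1 has 2
n_edges = [2, 1]
global_attr = [[0.5], [-1.0]]  # shape [n_graphs, d_g] with d_g = 1

print(repeat_graph_ids(n_nodes))  # [0, 0, 0, 1, 1]
print(repeat_graph_ids(n_edges))  # [0, 0, 1]
print(broadcast_globals(global_attr, n_nodes))
# [[0.5], [0.5], [0.5], [-1.0], [-1.0]]
```

In TensorFlow the same id vector can be produced with tf.repeat(tf.range(n_graphs), n_nodes), and tf.gather(global_attr, ids) performs the broadcast; whether tf_gnns computes it exactly this way is not stated here.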
- assign_global(global_attr, check_shape=False)
Assign graph-level features.
- Parameters:
global_attr – Tensor-like global features.
check_shape – If True, assert that the first dimension equals n_graphs.
- class tf_gnns.Node(node_attr_tensor)
Bases: object
Graph node with a feature tensor.
- Parameters:
node_attr_tensor – Tensor-like node attributes with at least rank 2. The first dimension is usually batch-like in object graph mode.
- class tf_gnns.Edge(edge_attr_tensor, node_from, node_to)
Bases: object
Directed edge connecting two Node objects.
- Parameters:
edge_attr_tensor – Tensor-like edge attributes.
node_from – Source node.
node_to – Destination node.
- tf_gnns.make_graph_tuple_from_graph_list(list_of_graphs)
Create a GraphTuple from a list of object graphs.
- Parameters:
list_of_graphs – List of Graph objects with consistent feature dimensionality.
- Returns:
A GraphTuple with flattened node/edge tensors and bookkeeping vectors (senders, receivers, n_nodes, n_edges).
Notes
This helper currently expects node and edge attributes in each input graph to have a first dimension equal to 1.
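The core of the conversion is an index shift: each graph's sender and receiver indices are offset by the number of nodes in the graphs before it. A minimal sketch, assuming each input graph is given as a hypothetical dict of plain lists rather than a Graph object (flatten_graphs is an illustrative name, not the tf_gnns helper):

```python
# Hypothetical flattening step: concatenate per-graph node and edge rows, and
# shift sender/receiver indices so they address rows of the flattened node list.

def flatten_graphs(graphs):
    nodes, edges, senders, receivers, n_nodes, n_edges = [], [], [], [], [], []
    offset = 0  # number of nodes in all previous graphs
    for g in graphs:
        nodes.extend(g["nodes"])
        edges.extend(g["edges"])
        senders.extend(s + offset for s in g["senders"])
        receivers.extend(r + offset for r in g["receivers"])
        n_nodes.append(len(g["nodes"]))
        n_edges.append(len(g["edges"]))
        offset += len(g["nodes"])
    return {"nodes": nodes, "edges": edges, "senders": senders,
            "receivers": receivers, "n_nodes": n_nodes, "n_edges": n_edges}


g0 = {"nodes": [[1.0], [2.0]], "edges": [[0.1]],
      "senders": [0], "receivers": [1]}
g1 = {"nodes": [[3.0], [4.0], [5.0]], "edges": [[0.2], [0.3]],
      "senders": [0, 2], "receivers": [1, 0]}
flat = flatten_graphs([g0, g1])
print(flat["senders"], flat["receivers"])  # [0, 2, 4] [1, 3, 2]
```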
- class tf_gnns.GraphNet(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)
Bases: object
One GraphNet computation block.
A GraphNet instance wraps edge, node, and optional global update functions and evaluates them over either object graphs or tensor-dict graph tuples.
- __init__(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)
Initialize a GraphNet block.
- Parameters:
edge_function – Keras model for edge updates.
node_function – Keras model for node updates.
global_function – Optional Keras model for global updates.
edge_aggregation_function – Aggregation used for edge-to-node and edge-to-global reductions in non-segment paths.
node_to_global_aggregation_function – Aggregation used for node-to-global reductions.
graph_independent – If True, disable message passing and apply independent edge/node/global transforms.
use_global_input – Whether global features are provided as model inputs.
name – Optional block name.
- scan_edge_to_node_aggregation_function(fn)
Scan the inputs and outputs of the aggregation function and record them for subsequent computation. Raises an error if the aggregation is not compatible with the other GraphNet functions defined for this block.
- scan_edge_function()
The edge function signature (which inputs it takes) is inferred from the names of its inputs. This method scans those inputs to record which graph variables the edge function uses, raises an error for invalid input combinations, and builds a dict that maps each input to its graph variable for evaluation.
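Name-based input resolution can be sketched as a lookup table from input names to getters on a tensor dict. The input names used here (edge_state, sender_node_state, receiver_node_state) and the validation rule are hypothetical placeholders, not what tf_gnns expects; the sketch only shows the mechanism of validating names once and reusing the resulting mapping at evaluation time.

```python
# Hypothetical name-based resolution of edge-function inputs. The names are
# placeholders chosen for illustration, not the names tf_gnns uses.

GETTERS = {
    "edge_state": lambda d: d["edges"],
    "sender_node_state": lambda d: [d["nodes"][s] for s in d["senders"]],
    "receiver_node_state": lambda d: [d["nodes"][r] for r in d["receivers"]],
}


def resolve_edge_inputs(input_names, graph_independent=False):
    """Validate input names once and return a name -> getter mapping."""
    unknown = [n for n in input_names if n not in GETTERS]
    if unknown:
        raise ValueError(f"unrecognized edge-function inputs: {unknown}")
    if graph_independent and any(n != "edge_state" for n in input_names):
        # In this sketch, a graph-independent edge function may not read node states.
        raise ValueError("graph-independent edge functions only take edge_state")
    return {n: GETTERS[n] for n in input_names}


def gather_edge_inputs(resolved, tensor_dict):
    """Fetch every resolved input from a tensor dict at evaluation time."""
    return {n: get(tensor_dict) for n, get in resolved.items()}


td = {"nodes": [[1.0], [2.0], [3.0]], "edges": [[0.1], [0.2]],
      "senders": [0, 1], "receivers": [1, 2]}
resolved = resolve_edge_inputs(["edge_state", "sender_node_state"])
print(gather_edge_inputs(resolved, td))
# {'edge_state': [[0.1], [0.2]], 'sender_node_state': [[1.0], [2.0]]}
```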
- graph_tuple_eval(tf_graph_tuple: GraphTuple)
- edge_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)
- node_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)
- global_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)
- eval_tensor_dict(d)
For better performance, this method takes a dictionary of all the necessary tensors rather than a GraphTuple. A GraphTuple is easily converted to such a dictionary with the _graphtuple_to_tensor_dict function.
- d is an ordered dictionary containing the following keys:
'edges', 'nodes', 'senders', 'receivers', 'n_edges', 'n_nodes', 'global_attr', 'global_reps_for_edges', 'global_reps_for_nodes'
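As an illustration of what the segment-based path computes, the following plain-Python sketch averages edge messages per receiver node using the keys listed above. Mean aggregation is assumed here (the documented default is 'mean'); one TensorFlow equivalent is tf.math.unsorted_segment_mean(edges, receivers, num_segments), though this is not necessarily the call tf_gnns uses.

```python
# Plain-Python sketch of segment-mean aggregation of edge messages onto
# receiver nodes; nodes without incoming edges receive zeros.

def segment_mean(messages, segment_ids, num_segments):
    dim = len(messages[0]) if messages else 0
    sums = [[0.0] * dim for _ in range(num_segments)]
    counts = [0] * num_segments
    for msg, seg in zip(messages, segment_ids):
        for k, value in enumerate(msg):
            sums[seg][k] += value
        counts[seg] += 1
    # Divide each accumulated sum by the number of contributing edges.
    return [[v / c for v in row] if c else row for row, c in zip(sums, counts)]


d = {
    "nodes": [[0.0], [0.0], [0.0]],
    "edges": [[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]],
    "senders": [0, 1, 0],
    "receivers": [1, 1, 2],
    "n_nodes": [3],
    "n_edges": [3],
}
agg = segment_mean(d["edges"], d["receivers"], num_segments=sum(d["n_nodes"]))
print(agg)  # [[0.0, 0.0], [2.0, 3.0], [5.0, 6.0]]
```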
- graph_eval(graph: Graph, **kwargs)
Evaluate a single Graph by routing through GraphTuple execution.
Legacy safe / batched modes were removed in favor of the unified GraphTuple path, which is the supported runtime route.
- tf_gnns.make_node_mlp(units, edge_message_input_shape=None, node_state_input_shape=None, global_state_input_shape=None, node_emb_size=None, use_global_input=False, use_edge_state_agg_input=True, graph_indep=False, use_node_state_input=True, activation='relu', **kwargs)
- tf_gnns.make_edge_mlp(units, edge_state_input_shape=None, sender_node_state_output_shape=None, global_to_edge_state_size=None, receiver_node_state_shape=None, edge_output_state_size=None, use_edge_state=True, use_sender_out_state=True, use_receiver_state=True, use_global_state=False, graph_indep=False, activation='relu', **kwargs)
When this is a graph-independent edge function, the sender and receiver node states are not used. Like make_node_mlp, it uses a list of named keras.Input layers.
- tf_gnns.make_keras_simple_agg(input_size, agg_type)
For consistency, this is built as a Keras model (which makes saving and loading easier). It is intended for the "naive" GraphNet evaluators; the fully batched aggregator based on segment sums should be preferred.
- Parameters:
input_size – The size of the expected input (mainly useful to enforce consistency).
agg_type – One of 'mean', 'sum', 'min', 'max'.
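The four agg_type values reduce each feature dimension independently over the messages arriving at a node. The sketch below shows only that arithmetic in plain Python; naive_aggregate is a hypothetical helper, and the real function returns a Keras model rather than computing values directly.

```python
# Hypothetical sketch of the four agg_type reductions over the messages that
# arrive at a single node (each message is a feature vector of length input_size).

REDUCERS = {
    "mean": lambda col: sum(col) / len(col),
    "sum": sum,
    "min": min,
    "max": max,
}


def naive_aggregate(messages, input_size, agg_type):
    if agg_type not in REDUCERS:
        raise ValueError(f"agg_type must be one of {sorted(REDUCERS)}")
    if any(len(m) != input_size for m in messages):
        raise ValueError("all messages must have length input_size")
    # zip(*messages) yields one tuple per feature dimension.
    return [REDUCERS[agg_type](col) for col in zip(*messages)]


msgs = [[1.0, 4.0], [3.0, 2.0]]
for t in ["mean", "sum", "min", "max"]:
    print(t, naive_aggregate(msgs, 2, t))
# mean [2.0, 3.0]
# sum [4.0, 6.0]
# min [1.0, 2.0]
# max [3.0, 4.0]
```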
- tf_gnns.make_mlp_graphnet_functions(units, node_input_size, node_output_size, edge_input_size=None, edge_output_size=None, create_global_function=False, global_input_size=None, global_output_size=None, use_global_input=False, use_global_to_edge=False, use_global_to_node=False, node_mlp_use_edge_state_agg_input=True, graph_indep=False, message_size='auto', aggregation_function='mean', node_to_global_aggr_fn=None, edge_to_global_aggr_fn=None, activation='relu', activate_last_layer=False, **kwargs)
Build callables required to construct a GraphNet block.
- Parameters:
units – Width specification used when creating edge/node/global MLPs.
node_input_size – Input node feature size.
node_output_size – Output node feature size.
edge_input_size – Optional input edge feature size. Defaults to node_input_size.
edge_output_size – Optional output edge feature size. Defaults to node_output_size.
create_global_function – If True, create a global update model.
global_input_size – Optional input global feature size.
global_output_size – Optional output global feature size.
use_global_input – If True, consume global_attr from inputs.
use_global_to_edge – If True, include input global state in the edge model inputs.
use_global_to_node – If True, include input global state in the node model inputs.
node_mlp_use_edge_state_agg_input – If True, node MLP consumes aggregated edge messages.
graph_indep – If True, create a graph-independent block (no message passing).
message_size – Edge message size, or "auto" to infer from selected aggregation.
aggregation_function – Aggregation mode used for message passing.
node_to_global_aggr_fn – Optional override for node-to-global aggregation.
edge_to_global_aggr_fn – Optional override for edge-to-global aggregation.
activation – Activation used for created MLPs.
activate_last_layer – Whether to apply activation on final MLP layer.
**kwargs – Forwarded to low-level MLP factory helpers.
- Returns:
Dictionary of GraphNet constructor keyword arguments, including edge_function, node_function, optional global_function, and aggregation callables.
- Raises:
ValueError – If global state is requested by edge/node models while use_global_input is False.
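The documented size defaults and the ValueError condition can be summarized as a small validation step. This sketch assumes that "global state requested by edge/node models" corresponds to use_global_to_edge or use_global_to_node being True; resolve_graphnet_config is a hypothetical name, and no models are built.

```python
# Hypothetical sketch of the documented defaults and argument check for
# make_mlp_graphnet_functions; it mirrors the documented rules only.

def resolve_graphnet_config(node_input_size, node_output_size,
                            edge_input_size=None, edge_output_size=None,
                            use_global_input=False, use_global_to_edge=False,
                            use_global_to_node=False):
    # Global state can only feed edge/node models if it is an input at all.
    if (use_global_to_edge or use_global_to_node) and not use_global_input:
        raise ValueError(
            "use_global_to_edge/use_global_to_node require use_global_input=True")
    # Edge sizes fall back to the corresponding node sizes.
    if edge_input_size is None:
        edge_input_size = node_input_size
    if edge_output_size is None:
        edge_output_size = node_output_size
    return {"edge_input_size": edge_input_size,
            "edge_output_size": edge_output_size}


print(resolve_graphnet_config(16, 32))
# {'edge_input_size': 16, 'edge_output_size': 32}
```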
- tf_gnns.make_global_mlp(units, global_in_size=None, global_emb_size=None, node_in_size=None, edge_in_size=None, use_node_agg_input=True, use_edge_agg_input=True, use_global_state_input=True, node_to_global_agg=None, graph_indep=False, activation='relu', **kwargs)
Always uses the global state input. May also use aggregated node and edge states (the full GraphNet case).
- tf_gnns.make_full_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', **kwargs)
Create GraphNet callables for a full message-passing block.
- Parameters:
units – Width specification used for created MLPs.
node_or_core_input_size – Input node feature size.
node_or_core_output_size – Optional output node feature size.
edge_input_size – Optional input edge feature size.
edge_output_size – Optional output edge feature size.
global_input_size – Optional input global feature size.
global_output_size – Optional output global feature size.
aggregation_function – Aggregation mode for message passing.
**kwargs – Forwarded to make_mlp_graphnet_functions().
- Returns:
Dictionary of constructor arguments for GraphNet configured with edge, node, and global update functions.
- tf_gnns.make_graph_indep_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', create_global_function=True, use_global_input=True, **kwargs)
Create GraphNet callables for a graph-independent block.
- Parameters:
units – Width specification used for created MLPs.
node_or_core_input_size – Input node feature size.
node_or_core_output_size – Optional output node feature size.
edge_input_size – Optional input edge feature size.
edge_output_size – Optional output edge feature size.
global_input_size – Optional input global feature size.
global_output_size – Optional output global feature size.
aggregation_function – Aggregation mode (retained for interface consistency).
create_global_function – If True, create a global output function.
use_global_input – If True, consume global input state.
**kwargs – Forwarded to make_mlp_graphnet_functions().
- Returns:
Dictionary of constructor arguments for GraphNet configured in graph-independent mode.
- tf_gnns.make_graph_to_graph_and_global_functions(units, node_or_core_input_size, global_output_size, node_or_core_output_size=None, edge_output_size=None, edge_input_size=None, aggregation_function='mean', **kwargs)
Create a graph-to-graph-plus-global GraphNet function bundle.
This wrapper builds a block that does not consume input globals but still produces global outputs by aggregating node and edge states.
- Parameters:
units – Width specification used for constructed MLPs.
node_or_core_input_size – Input node feature size.
global_output_size – Output global feature size.
node_or_core_output_size – Optional output node feature size.
edge_output_size – Optional output edge feature size.
edge_input_size – Optional input edge feature size.
aggregation_function – Aggregation mode used in message passing.
**kwargs – Forwarded to make_mlp_graphnet_functions().
- Returns:
Dictionary of constructor arguments for GraphNet.
- class tf_gnns.GraphNetMLP(*args, **kwargs)
Bases: Layer
Encode-process-decode GraphNet model implemented with MLPs.
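The encode-process-decode pattern can be shown with plain functions standing in for the MLP blocks: encode once, apply a shared core for several steps, then decode. This is a simplified sketch; the graph_nets demo model also concatenates the encoded input with the latent state before each core step, and whether GraphNetMLP does the same, and how many steps it uses, is not stated here. The toy functions and num_steps value are illustrative assumptions.

```python
# Simplified encode-process-decode loop with plain functions standing in for
# the GraphNet blocks.

def encode_process_decode(x, encoder, core, decoder, num_steps):
    h = encoder(x)              # map raw inputs to a latent representation once
    for _ in range(num_steps):  # apply the same core block repeatedly
        h = core(h)
    return decoder(h)           # map the final latent state to outputs


out = encode_process_decode(
    1.0,
    encoder=lambda v: 2 * v,
    core=lambda v: v + 1,
    decoder=lambda v: v * 10,
    num_steps=3,
)
print(out)  # 50.0
```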
- class tf_gnns.GraphIndep(*args, **kwargs)
Bases: Layer
Graph-independent block with optional global input handling.
- class tf_gnns.GNCellMLP(*args, **kwargs)
Bases: Layer
Single GraphNet processing block implemented with MLPs.
This layer is close to MLPGraphNetwork from the graph_nets examples.
- class tf_gnns.GraphNetMPNN_MLP(*args, **kwargs)
Bases: Layer
MLP-based MPNN without graph-level global state updates.
- class tf_gnns.SparseGCNConv(*args, **kwargs)
Bases: Layer
Single sparse GCN convolution layer.
Inputs are expected to follow the graph tensor-dict format used by this repository (keys: nodes, senders, receivers, and counts).
This layer supports a paper-faithful update structure with a separate root transform, normalized neighborhood aggregation, and optional BatchNorm:
\[h_v^{l} = \sigma\left(\mathrm{Norm}\left( h_v^{l-1}W_r^l + \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\hat d_u \hat d_v}} h_u^{l-1}W^l \right)\right)\]
where \(\hat d_v\) is the degree of node \(v\) counting its self-loop, and dropout is typically applied by the outer SparseGCN stack.
- Reference:
“Classic GNNs are Strong Baselines” (2024), https://arxiv.org/pdf/2406.08993
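The normalized neighborhood sum in the equation can be checked by hand on a tiny graph. This plain-Python sketch uses scalar features and scalar weights (standing in for \(W^l\) and \(W_r^l\)), omits Norm and \(\sigma\), and assumes an undirected graph with each edge listed in both directions so that in-degree equals degree; gcn_propagate is a hypothetical helper, not SparseGCNConv.

```python
import math


def gcn_propagate(h, senders, receivers, w=1.0, w_root=0.0):
    """Scalar sketch of h_v*W_r + sum over N(v) plus v of h_u*W / sqrt(d_u*d_v)."""
    n = len(h)
    # Add one self-loop per node so the sum runs over N(v) together with v.
    src = list(senders) + list(range(n))
    dst = list(receivers) + list(range(n))
    # Degree after self-loops (edges are listed in both directions).
    deg = [0] * n
    for v in dst:
        deg[v] += 1
    agg = [0.0] * n
    for u, v in zip(src, dst):
        agg[v] += h[u] * w / math.sqrt(deg[u] * deg[v])
    # Root transform plus normalized neighborhood sum (Norm and sigma omitted).
    return [h[v] * w_root + agg[v] for v in range(n)]


# Two nodes joined by one undirected edge, stored as 0->1 and 1->0.
print(gcn_propagate([1.0, 3.0], senders=[0, 1], receivers=[1, 0]))  # [2.0, 2.0]
```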
- class tf_gnns.SparseGCN(*args, **kwargs)
Bases: Layer
Multi-layer sparse GCN stack for node-level representations.
This layer composes multiple SparseGCNConv blocks and applies optional dropout between hidden layers. By default, each convolution uses BatchNorm, matching the "Norm" term used in common strong-baseline GCN formulations. Inputs follow the graph tensor-dict structure used across tf_gnns:
- nodes: node feature matrix with shape [N, F]
- senders: source-node indices for edges, shape [E]
- receivers: target-node indices for edges, shape [E]
- edge_weights (optional): scalar edge weights, shape [E]
- graph bookkeeping keys such as n_nodes / n_edges are passed through.
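The stack performs repeated message passing in sparse form. Each SparseGCNConv can add self-loops and apply the symmetric normalization D^{-1/2} A D^{-1/2} (both configurable).
Typical usage (2-layer GCN for node classification):

    from tf_gnns import SparseGCN

    # input_td: graph tensor-dict with "nodes", "senders", "receivers", ...
    # num_classes: number of target classes
    gcn = SparseGCN(
        hidden_units=[128],
        output_units=num_classes,
        activation="relu",
        dropout_rate=0.5,
        add_self_loops=True,
        normalize=True,
    )
    out_td = gcn(input_td, training=True)
    logits = out_td["nodes"]  # final layer has no activation, so these are logits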
A deeper variant with custom dtypes:

    gcn = SparseGCN(
        hidden_units=[256, 256, 128],
        output_units=64,
        activation="relu",
        dropout_rate=0.2,
        feature_dtype="float32",
        index_dtype="int32",
    )
- Parameters:
hidden_units – Width(s) of hidden GCN layers. Accepts an int for a single hidden layer or a non-empty list of int values.
output_units – Optional final output width appended after hidden layers. For classification this is typically the number of classes.
activation – Activation for hidden layers. The final layer uses None by default so downstream losses can consume logits directly.
dropout_rate – Dropout probability applied to node features between convolution layers (not applied after the final layer).
add_self_loops – Whether each SparseGCNConv injects identity edges.
normalize – Whether each SparseGCNConv applies degree-based normalization to edge weights.
batchnorm – Whether each SparseGCNConv applies BatchNormalization to node features after message/root aggregation and before activation. Enabled by default.
layernorm – Whether each SparseGCNConv applies LayerNormalization to node features after message/root aggregation and before activation. Must not be enabled at the same time as batchnorm.
jit_compile – Flag forwarded to SparseGCNConv for backend-specific compiled execution behavior.
residual – If True, add residual node skip-connections between hidden layers when feature dimensions match. Residual links are not applied on the final layer.
residual_projection – If True and residual is enabled, add a learnable linear projection on residual paths when hidden feature dimensions do not match. Projection is only used for hidden layers (not the final output layer).
feature_dtype – Optional dtype override for feature tensors and weights.
index_dtype – Optional dtype override for graph index tensors.
**kwargs – Forwarded to keras.layers.Layer.
- Returns:
A tensor-dict with updated nodes and preserved graph structure keys.
- Raises:
ValueError – If hidden_units is empty.
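A hypothetical sketch of how the options above could map to a per-layer plan (widths, activation, dropout, residual type). It assumes the last layer in the stack (output_units if given, otherwise the last hidden layer) is the "final layer", and that residual links are considered on every non-final layer; plan_sparse_gcn is an illustrative name, not part of tf_gnns, and no layers are built.

```python
# Hypothetical layer planner mirroring the documented SparseGCN options.
# Assumption: the last width in the stack is the "final layer".

def plan_sparse_gcn(input_dim, hidden_units, output_units=None, activation="relu",
                    dropout_rate=0.0, residual=False, residual_projection=False):
    # An int means a single hidden layer; an empty list is rejected.
    if isinstance(hidden_units, int):
        hidden_units = [hidden_units]
    if not hidden_units:
        raise ValueError("hidden_units must not be empty")
    widths = list(hidden_units)
    if output_units is not None:
        widths.append(output_units)

    plan, in_dim = [], input_dim
    for i, out_dim in enumerate(widths):
        is_final = i == len(widths) - 1
        skip = None
        if residual and not is_final:
            if in_dim == out_dim:
                skip = "identity"    # widths match: plain skip connection
            elif residual_projection:
                skip = "projection"  # widths differ: learnable projection
        plan.append({
            "in": in_dim,
            "out": out_dim,
            "activation": None if is_final else activation,  # logits at the end
            "dropout_after": 0.0 if is_final else dropout_rate,
            "residual": skip,
        })
        in_dim = out_dim
    return plan


for layer in plan_sparse_gcn(64, [64, 32], output_units=10,
                             dropout_rate=0.5, residual=True):
    print(layer)
```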