High-level models
Keras layer wrappers for common GraphNet model stacks.
- class tf_gnns.models.graphnet.GNCellMLP(*args, **kwargs)[source]
Bases: Layer
Single GraphNet processing block implemented with MLPs.
This layer is close to MLPGraphNetwork from the graph_nets examples.
- class tf_gnns.models.graphnet.GraphNetMLP(*args, **kwargs)[source]
Bases: Layer
Encode-process-decode GraphNet model implemented with MLPs.
- class tf_gnns.models.graphnet.GraphIndep(*args, **kwargs)[source]
Bases: Layer
Graph-independent block with optional global input handling.
- class tf_gnns.models.graphnet.GraphNetMPNN_MLP(*args, **kwargs)[source]
Bases: Layer
MLP-based MPNN without graph-level global state updates.
Backend-agnostic sparse GCN layers and models.
This module provides:
- SparseGCNConv: a single graph convolution layer over GraphTuple-like tensor dictionaries.
- SparseGCN: a multi-layer stack for node-level prediction.
The implementation uses only keras.ops and backend facade ops from
tf_gnns.backend_ops so it works across Keras backends.
- class tf_gnns.models.gcn.SparseGCNConv(*args, **kwargs)[source]
Bases: Layer
Single sparse GCN convolution layer.
Inputs are expected to follow the graph tensor-dict format used by this repository (keys: nodes, senders, receivers, and counts).
This layer supports a paper-faithful update structure with a separate root transform, normalized neighborhood aggregation, and optional BatchNorm:
\[h_v^{l} = \sigma\left(\mathrm{Norm}\left( h_v^{l-1}W_r^l + \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\hat d_u \hat d_v}} h_u^{l-1}W^l \right)\right)\]
where dropout is typically applied by the outer SparseGCN stack.
- Reference:
“Classic GNNs are Strong Baselines” (2024), https://arxiv.org/pdf/2406.08993
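The update rule above can be checked against a small NumPy reference. This is a minimal sketch, not the layer's actual implementation: the function name `gcn_conv_reference` is hypothetical, the Norm step is omitted, ReLU stands in for the activation, and the degree \hat d is assumed to count incoming edges plus one self-loop per node.

```python
import numpy as np


def gcn_conv_reference(nodes, senders, receivers, w_root, w_msg):
    """NumPy reference for one GCN update (no Norm, ReLU as sigma)."""
    n = nodes.shape[0]
    # Add one self-loop per node, matching the N(v) ∪ {v} term.
    loops = np.arange(n)
    src = np.concatenate([senders, loops])
    dst = np.concatenate([receivers, loops])
    # Assumption: \hat d counts incoming edges, self-loop included.
    deg = np.bincount(dst, minlength=n).astype(nodes.dtype)
    coef = 1.0 / np.sqrt(deg[src] * deg[dst])
    # Scatter-add the normalized messages h_u W into each receiver v.
    messages = (nodes @ w_msg)[src] * coef[:, None]
    agg = np.zeros((n, w_msg.shape[1]), dtype=nodes.dtype)
    np.add.at(agg, dst, messages)
    # Root transform plus aggregation, then the nonlinearity.
    return np.maximum(nodes @ w_root + agg, 0.0)


# Path graph 0 - 1 - 2 with both edge directions stored explicitly.
senders = np.array([0, 1, 1, 2])
receivers = np.array([1, 0, 2, 1])
nodes = np.array([[1.0], [2.0], [3.0]])
out = gcn_conv_reference(
    nodes, senders, receivers,
    w_root=np.zeros((1, 1)), w_msg=np.ones((1, 1)),
)
```

With the root weight set to zero and the message weight set to one, `out` reduces to the symmetric-normalized adjacency (self-loops included) applied to `nodes`, which makes the sketch easy to compare against a dense computation.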
- class tf_gnns.models.gcn.SparseGCN(*args, **kwargs)[source]
Bases: Layer
Multi-layer sparse GCN stack for node-level representations.
This layer composes multiple SparseGCNConv blocks and applies optional dropout between hidden layers. By default, each convolution uses BatchNorm, matching the “Norm” term used in common strong-baseline GCN formulations.
Inputs follow the graph tensor-dict structure used across tf_gnns:
- nodes: node feature matrix with shape [N, F]
- senders: source-node indices for edges, shape [E]
- receivers: target-node indices for edges, shape [E]
- edge_weights (optional): scalar edge weights, shape [E]
- graph bookkeeping keys such as n_nodes/n_edges are passed through
The stack performs repeated message passing in sparse form. Each SparseGCNConv can add self-loops and symmetric normalization D^{-1/2} A D^{-1/2} (configurable).
Typical usage (2-layer GCN for node classification):
    gcn = SparseGCN(
        hidden_units=[128],
        output_units=num_classes,
        activation="relu",
        dropout_rate=0.5,
        add_self_loops=True,
        normalize=True,
    )
    out_td = gcn(input_td, training=True)
    logits = out_td["nodes"]
A deeper variant with custom dtypes:
    gcn = SparseGCN(
        hidden_units=[256, 256, 128],
        output_units=64,
        activation="relu",
        dropout_rate=0.2,
        feature_dtype="float32",
        index_dtype="int32",
    )
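The `input_td` used in the usage example is not constructed in this section. As a rough illustration of the tensor-dict layout, the sketch below batches two small graphs with plain NumPy. The helper `make_graph_td` is hypothetical, and the index-offset batching (as in graph_nets-style GraphTuples) is an assumption about how tf_gnns expects several graphs to be combined.

```python
import numpy as np


def make_graph_td(graphs):
    """Concatenate (nodes, senders, receivers) triples into one tensor-dict.

    Hypothetical helper: the edge indices of each graph are shifted by the
    number of nodes that precede it in the batch.
    """
    nodes, senders, receivers, n_nodes, n_edges = [], [], [], [], []
    offset = 0
    for x, s, r in graphs:
        nodes.append(np.asarray(x, dtype="float32"))
        senders.append(np.asarray(s, dtype="int32") + offset)
        receivers.append(np.asarray(r, dtype="int32") + offset)
        n_nodes.append(len(x))
        n_edges.append(len(s))
        offset += len(x)
    return {
        "nodes": np.concatenate(nodes, axis=0),         # [N, F]
        "senders": np.concatenate(senders),             # [E]
        "receivers": np.concatenate(receivers),         # [E]
        "n_nodes": np.asarray(n_nodes, dtype="int32"),  # one entry per graph
        "n_edges": np.asarray(n_edges, dtype="int32"),  # one entry per graph
    }


# A directed triangle and a two-node graph, each with 4 node features.
triangle = (np.ones((3, 4)), [0, 1, 2], [1, 2, 0])
pair = (np.zeros((2, 4)), [0, 1], [1, 0])
input_td = make_graph_td([triangle, pair])
```

The optional edge_weights key is left out here; if used, it would be a float array of shape [E] aligned with senders and receivers.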
- Parameters:
hidden_units – Width(s) of hidden GCN layers. Accepts an int for a single hidden layer or a non-empty list of int values.
output_units – Optional final output width appended after hidden layers. For classification this is typically the number of classes.
activation – Activation for hidden layers. Final layer uses None by default so downstream losses can consume logits directly.
dropout_rate – Dropout probability applied to node features between convolution layers (not applied after the final layer).
add_self_loops – Whether each SparseGCNConv injects identity edges.
normalize – Whether each SparseGCNConv applies degree-based normalization to edge weights.
batchnorm – Whether each SparseGCNConv applies BatchNormalization to node features after message/root aggregation and before activation. Enabled by default.
layernorm – Whether each SparseGCNConv applies LayerNormalization to node features after message/root aggregation and before activation. Must not be enabled at the same time as batchnorm.
jit_compile – Flag forwarded to SparseGCNConv for backend-specific compiled execution behavior.
residual – If True, add residual node skip-connections between hidden layers when feature dimensions match. Residual links are not applied on the final layer.
residual_projection – If True and residual is enabled, add a learnable linear projection on residual paths when hidden feature dimensions do not match. Projection is only used for hidden layers (not the final output layer).
feature_dtype – Optional dtype override for feature tensors and weights.
index_dtype – Optional dtype override for graph index tensors.
**kwargs – Forwarded to keras.layers.Layer.
- Returns:
A tensor-dict with updated nodes and preserved graph structure keys.
- Raises:
ValueError – If hidden_units is empty.
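The way hidden_units and output_units combine into a list of layer widths can be sketched as follows. This is an illustrative reading of the parameter descriptions above, not the library's code; `resolve_layer_widths` is a hypothetical name.

```python
def resolve_layer_widths(hidden_units, output_units=None):
    """Return per-layer output widths for a SparseGCN-like stack (sketch)."""
    # An int is treated as a single hidden layer.
    if isinstance(hidden_units, int):
        hidden_units = [hidden_units]
    widths = list(hidden_units)
    # Mirrors the documented ValueError for an empty hidden_units.
    if not widths:
        raise ValueError("hidden_units must be an int or a non-empty list")
    # The optional output width is appended after the hidden layers.
    if output_units is not None:
        widths.append(output_units)
    return widths


two_layer = resolve_layer_widths(128, output_units=7)
deeper = resolve_layer_widths([256, 256, 128], output_units=64)
```

Under this reading, the "2-layer GCN" in the usage example corresponds to one hidden convolution followed by one output convolution.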
- class tf_gnns.models.gcn.GCNv2(*args, **kwargs)[source]
Bases: Layer
High-level GCN stack matching tunedGNN’s OGBN-Arxiv recipe.
This layer mirrors the common main-arxiv.py setup used in tunedGNN:
- repeated hidden GCN layers (all with the same hidden width),
- each hidden layer uses the conv’s internal root transform,
- optional BatchNorm/LayerNorm, ReLU, and dropout per hidden layer,
- optional input dropout before the first hidden layer,
- final prediction via a dense linear head (not a graph conv).
- __init__(hidden_units, output_units, num_layers=3, add_self_loops=True, normalize=True, residual=True, residual_projection=True, batchnorm=True, layernorm=False, input_dropout_rate=0.0, dropout_rate=0.0, jit_compile=False, feature_dtype=None, index_dtype=None, use_shortcut=True, use_bias=True, **kwargs)[source]
Initializes a tunedGNN-style multi-layer GCN stack.
This constructor configures a fixed-width hidden stack of SparseGCNConv layers followed by a dense prediction head. The design is intended to mirror the effective forward path of tunedGNN’s MPNNs model used on OGBN-Arxiv, where each hidden layer applies:
- graph convolution,
- residual/additive skip,
- normalization,
- ReLU,
- dropout.
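The per-layer order listed above can be illustrated with a small NumPy sketch. Everything here is an assumption made for illustration: the function names are hypothetical, a dense matrix `a_norm` stands in for the sparse normalized adjacency, normalization is a plain layer norm without learned parameters, and the use_shortcut path is not modeled.

```python
import numpy as np


def layer_norm(x, eps=1e-5):
    """Per-node feature normalization without learned scale or offset."""
    mean = x.mean(axis=-1, keepdims=True)
    var = x.var(axis=-1, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def hidden_layer(h, a_norm, w_conv, w_res, rate=0.0, rng=None, training=False):
    """One hidden step: conv, additive skip, norm, ReLU, dropout."""
    out = a_norm @ (h @ w_conv)          # graph convolution
    out = out + h @ w_res                # projected residual branch
    out = layer_norm(out)                # normalization
    out = np.maximum(out, 0.0)           # ReLU
    if training and rate > 0.0:
        rng = np.random.default_rng() if rng is None else rng
        keep = rng.random(out.shape) >= rate
        out = out * keep / (1.0 - rate)  # inverted dropout
    return out


def forward(x, a_norm, layers, w_head, **kwargs):
    """Stack of hidden layers followed by a dense linear head."""
    h = x
    for w_conv, w_res in layers:
        h = hidden_layer(h, a_norm, w_conv, w_res, **kwargs)
    return h @ w_head


rng = np.random.default_rng(0)
x = rng.normal(size=(4, 3))
a_norm = np.eye(4)  # placeholder for D^{-1/2} (A + I) D^{-1/2}
layers = [
    (rng.normal(size=(3, 8)), rng.normal(size=(3, 8))),
    (rng.normal(size=(8, 8)), rng.normal(size=(8, 8))),
]
w_head = rng.normal(size=(8, 2))
logits = forward(x, a_norm, layers, w_head)
```

The first layer shows why a residual projection is useful: the input width (3) differs from the hidden width (8), so the skip branch needs its own linear map before it can be added.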
- Parameters:
hidden_units – Integer hidden width used for all graph-convolution layers.
output_units – Integer output width for the final prediction head (typically number of classes for node classification).
num_layers – Number of hidden graph-convolution layers. Must be greater than or equal to 1.
add_self_loops – If True, each convolution adds self-loops to edge indices internally. Set this to False when self-loops are already added during graph preprocessing to avoid double-counting.
normalize – If True, applies symmetric degree normalization in each convolution.
residual – If True, enables per-layer additive skip connections.
residual_projection – If True (and residual is enabled), uses a learned linear projection on each residual branch. This most closely matches tunedGNN’s GCNConv(...) + Linear(...) pattern.
batchnorm – Enables BatchNormalization inside each hidden convolution layer.
layernorm – Enables LayerNormalization inside each hidden convolution layer. Mutually exclusive with batchnorm.
input_dropout_rate – Dropout probability applied once on input node features before the first hidden layer.
dropout_rate – Dropout probability applied after each hidden layer.
jit_compile – Backend hint for compiled execution in lower-level ops where supported.
feature_dtype – Optional floating-point dtype used for feature computations and trainable weights.
index_dtype – Optional integer dtype used for graph indices.
use_shortcut – If True, accumulates intermediate hidden outputs into an extra shortcut path that is added before the final head.
use_bias – If True, enables bias terms in hidden convolutions.
**kwargs – Additional keyword arguments forwarded to keras.layers.Layer.
- Raises:
ValueError – If both batchnorm and layernorm are enabled.
ValueError – If num_layers < 1.
Notes
Batch normalization settings are highly sensitive for optimization on OGBN-Arxiv. In internal benchmarking, the BatchNorm hyperparameters (especially momentum and epsilon) significantly affect both convergence speed and best validation/test accuracy. Small mismatches from reference implementations can produce large performance gaps even when architecture and optimizer settings are otherwise aligned.
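One common source of such mismatches, assuming the reference implementation is PyTorch-based, is that the two frameworks define BatchNorm momentum in opposite directions and ship different defaults (PyTorch: momentum=0.1, eps=1e-5; Keras: momentum=0.99, epsilon=1e-3). The sketch below uses hypothetical helper names and plain Python to show the conversion; it does not describe what GCNv2 sets internally.

```python
def keras_momentum_from_torch(torch_momentum):
    """Map a PyTorch BatchNorm momentum to the Keras convention."""
    return 1.0 - torch_momentum


def torch_running_mean(running, batch_mean, momentum):
    # PyTorch: running <- (1 - momentum) * running + momentum * batch
    return (1.0 - momentum) * running + momentum * batch_mean


def keras_moving_mean(moving, batch_mean, momentum):
    # Keras: moving <- momentum * moving + (1 - momentum) * batch
    return momentum * moving + (1.0 - momentum) * batch_mean


# PyTorch's default momentum of 0.1 corresponds to 0.9 in Keras, not 0.99.
keras_momentum = keras_momentum_from_torch(0.1)
torch_step = torch_running_mean(0.0, 1.0, 0.1)
keras_step = keras_moving_mean(0.0, 1.0, keras_momentum)
```

After conversion, both update rules move the running statistic by the same amount per step; leaving the Keras default of 0.99 in place would make it roughly ten times slower to track batch statistics than the PyTorch default.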