High-level models

Keras layer wrappers for common GraphNet model stacks.

class tf_gnns.models.graphnet.GNCellMLP(*args, **kwargs)[source]

Bases: Layer

Single GraphNet processing block implemented with MLPs.

This layer is close to MLPGraphNetwork from the graph_nets examples.

build(input_shape)[source]
call(g_)[source]

Run one GraphNet update step on a tensor dictionary.

class tf_gnns.models.graphnet.GraphNetMLP(*args, **kwargs)[source]

Bases: Layer

Encode-process-decode GraphNet model implemented with MLPs.

build(d_shapes)[source]
call(g_)[source]

Evaluate the full encode-process-decode GraphNet pipeline.

class tf_gnns.models.graphnet.GraphIndep(*args, **kwargs)[source]

Bases: Layer

Graph-independent block with optional global input handling.

build(input_shape)[source]
call(g_)[source]

Apply graph-independent transforms to nodes/edges/globals.

class tf_gnns.models.graphnet.GraphNetMPNN_MLP(*args, **kwargs)[source]

Bases: Layer

MLP-based MPNN without graph-level global state updates.

build(d_shapes)[source]

Build encoder, processor, and decoder GraphNet modules.

call(g_)[source]

Run the MPNN-style no-global encode-process-decode stack.

Backend-agnostic sparse GCN layers and models.

This module provides:

  • SparseGCNConv: a single graph convolution layer over GraphTuple-like tensor dictionaries.

  • SparseGCN: a multi-layer stack for node-level prediction.

The implementation uses only keras.ops and backend facade ops from tf_gnns.backend_ops so it works across Keras backends.

class tf_gnns.models.gcn.SparseGCNConv(*args, **kwargs)[source]

Bases: Layer

Single sparse GCN convolution layer.

Inputs are expected to follow the graph tensor-dict format used by this repository (keys: nodes, senders, receivers, and counts).

This layer supports a paper-faithful update structure with a separate root transform, normalized neighborhood aggregation, and optional BatchNorm:

\[h_v^{l} = \sigma\left(\mathrm{Norm}\left( h_v^{l-1}W_r^l + \sum_{u \in \mathcal{N}(v) \cup \{v\}} \frac{1}{\sqrt{\hat d_u \hat d_v}} h_u^{l-1}W^l \right)\right)\]

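Here W_r^l is the root transform and W^l is the neighborhood transform. Dropout does not appear in this update because it is typically applied by the outer SparseGCN stack.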
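The normalization factor inside the sum can be sketched in plain Python. This is an illustrative sketch, not the library's implementation: the helper name gcn_edge_coefficients is hypothetical, degrees are assumed to be counted over incoming edges after self-loops are added, and each undirected edge is assumed to be listed in both directions.

```python
import math


def gcn_edge_coefficients(senders, receivers, num_nodes, add_self_loops=True):
    """Return (senders, receivers, coeffs) with coeff = 1 / sqrt(d_u * d_v).

    Hypothetical helper for illustration only; plain lists stand in for tensors.
    """
    senders = list(senders)
    receivers = list(receivers)
    if add_self_loops:
        # Realize N(v) ∪ {v} by appending one identity edge per node.
        senders += list(range(num_nodes))
        receivers += list(range(num_nodes))

    # Assumption: the degree of a node is its number of incoming edges,
    # counted after the self-loops were appended.
    degree = [0] * num_nodes
    for v in receivers:
        degree[v] += 1

    coeffs = []
    for u, v in zip(senders, receivers):
        d = degree[u] * degree[v]
        # Guard against isolated nodes when self-loops are disabled.
        coeffs.append(1.0 / math.sqrt(d) if d > 0 else 0.0)
    return senders, receivers, coeffs


# Path graph 0 - 1 - 2, each undirected edge listed in both directions.
senders, receivers, coeffs = gcn_edge_coefficients(
    [0, 1, 1, 2], [1, 0, 2, 1], num_nodes=3
)
```

Each message h_u W is then scaled by its coefficient before being summed at the receiver. If self-loops were already added during preprocessing, passing add_self_loops=False in this sketch avoids counting them twice.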

Reference:

“Classic GNNs are Strong Baselines” (2024), https://arxiv.org/pdf/2406.08993

build(input_shape)[source]
call(d, training=False)[source]

class tf_gnns.models.gcn.SparseGCN(*args, **kwargs)[source]

Bases: Layer

Multi-layer sparse GCN stack for node-level representations.

This layer composes multiple SparseGCNConv blocks and applies optional dropout between hidden layers. By default, each convolution uses BatchNorm, matching the “Norm” term used in common strong-baseline GCN formulations. Inputs follow the graph tensor-dict structure used across tf_gnns:

  • nodes: node feature matrix with shape [N, F]

  • senders: source-node indices for edges, shape [E]

  • receivers: target-node indices for edges, shape [E]

  • edge_weights (optional): scalar edge weights, shape [E]

  • graph bookkeeping keys such as n_nodes/n_edges are passed through
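As a rough illustration of this input format, the sketch below assembles a tensor-dict for a single small graph. The helper make_graph_td is hypothetical, plain Python lists stand in for backend tensors, and storing n_nodes/n_edges as one count per graph is an assumption rather than something documented here.

```python
def make_graph_td(node_features, edges, edge_weights=None):
    """Build a graph tensor-dict from node features and (sender, receiver) pairs.

    Hypothetical helper; in real use the values would be backend tensors.
    """
    td = {
        "nodes": [list(row) for row in node_features],  # shape [N, F]
        "senders": [u for u, _ in edges],               # shape [E]
        "receivers": [v for _, v in edges],             # shape [E]
        # Assumption: bookkeeping keys hold one count per graph in the batch.
        "n_nodes": [len(node_features)],
        "n_edges": [len(edges)],
    }
    if edge_weights is not None:
        if len(edge_weights) != len(edges):
            raise ValueError("edge_weights must have one entry per edge")
        td["edge_weights"] = list(edge_weights)         # shape [E]
    return td


# Three nodes with two features each, connected as a path 0 - 1 - 2.
input_td_example = make_graph_td(
    node_features=[[1.0, 0.0], [0.0, 1.0], [1.0, 1.0]],
    edges=[(0, 1), (1, 0), (1, 2), (2, 1)],
)
```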

The stack performs repeated message passing in sparse form. Each SparseGCNConv can add self-loops and symmetric normalization D^{-1/2} A D^{-1/2} (configurable).

Typical usage (2-layer GCN for node classification):

gcn = SparseGCN(
    hidden_units=[128],
    output_units=num_classes,
    activation="relu",
    dropout_rate=0.5,
    add_self_loops=True,
    normalize=True,
)
out_td = gcn(input_td, training=True)
logits = out_td["nodes"]

A deeper variant with custom dtypes:

gcn = SparseGCN(
    hidden_units=[256, 256, 128],
    output_units=64,
    activation="relu",
    dropout_rate=0.2,
    feature_dtype="float32",
    index_dtype="int32",
)

Parameters:
  • hidden_units – Width(s) of hidden GCN layers. Accepts an int for a single hidden layer or a non-empty list of int values.

  • output_units – Optional width of a final output layer appended after the hidden layers. For classification this is typically the number of classes.

  • activation – Activation for hidden layers. The final layer uses no activation (None) by default so that downstream losses can consume logits directly.

  • dropout_rate – Dropout probability applied to node features between convolution layers (not applied after the final layer).

  • add_self_loops – Whether each SparseGCNConv injects identity edges.

  • normalize – Whether each SparseGCNConv applies degree-based normalization to edge weights.

  • batchnorm – Whether each SparseGCNConv applies BatchNormalization to node features after message/root aggregation and before activation. Enabled by default.

  • layernorm – Whether each SparseGCNConv applies LayerNormalization to node features after message/root aggregation and before activation. Must not be enabled at the same time as batchnorm.

  • jit_compile – Flag forwarded to SparseGCNConv for backend-specific compiled execution behavior.

  • residual – If True, add residual node skip-connections between hidden layers when feature dimensions match. Residual links are not applied on the final layer.

  • residual_projection – If True and residual is enabled, add a learnable linear projection on residual paths when hidden feature dimensions do not match. Projection is only used for hidden layers (not the final output layer).

  • feature_dtype – Optional dtype override for feature tensors and weights.

  • index_dtype – Optional dtype override for graph index tensors.

  • **kwargs – Forwarded to keras.layers.Layer.

Returns:

A tensor-dict with updated nodes and preserved graph structure keys.

Raises:

ValueError – If hidden_units is empty.
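How hidden_units and output_units could translate into per-layer widths is sketched below. The function resolve_layer_widths is a hypothetical stand-in written from the parameter descriptions above, not code taken from the library.

```python
def resolve_layer_widths(hidden_units, output_units=None):
    """Return the width of every convolution layer in the stack.

    Hypothetical helper mirroring the documented behavior:
    an int means one hidden layer, an empty list is rejected,
    and output_units (if given) adds one final layer.
    """
    if isinstance(hidden_units, int):
        widths = [hidden_units]
    else:
        widths = list(hidden_units)
    if not widths:
        raise ValueError("hidden_units must be an int or a non-empty list")
    if output_units is not None:
        widths.append(output_units)
    return widths


two_layer = resolve_layer_widths([128], output_units=7)
deeper = resolve_layer_widths([256, 256, 128], output_units=64)
```

Under these assumptions, hidden_units=[128] with an output width yields two layers, which is consistent with the "2-layer GCN" usage example.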

call(d, training=False)[source]

Apply stacked sparse graph convolutions.

Parameters:
  • d – Graph tensor-dict containing at least nodes, senders, and receivers.

  • training – Whether to run in training mode (controls dropout between hidden layers).

Returns:

Graph tensor-dict with transformed nodes.

class tf_gnns.models.gcn.GCNv2(*args, **kwargs)[source]

Bases: Layer

High-level GCN stack matching tunedGNN’s OGBN-Arxiv recipe.

This layer mirrors the common main-arxiv.py setup used in tunedGNN:

  • repeated hidden GCN layers (all with same hidden width),

  • each hidden layer uses the conv’s internal root transform,

  • optional BatchNorm/LayerNorm, ReLU, and dropout per hidden layer,

  • optional input dropout before the first hidden layer,

  • final prediction via a dense linear head (not a graph conv).

__init__(hidden_units, output_units, num_layers=3, add_self_loops=True, normalize=True, residual=True, residual_projection=True, batchnorm=True, layernorm=False, input_dropout_rate=0.0, dropout_rate=0.0, jit_compile=False, feature_dtype=None, index_dtype=None, use_shortcut=True, use_bias=True, **kwargs)[source]

Initializes a tunedGNN-style multi-layer GCN stack.

This constructor configures a fixed-width hidden stack of SparseGCNConv layers followed by a dense prediction head. The design is intended to mirror the effective forward path of tunedGNN’s MPNNs model used on OGBN-Arxiv, where each hidden layer applies:

  1. graph convolution,

  2. residual/additive skip,

  3. normalization,

  4. ReLU,

  5. dropout.
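The ordering of these five steps can be sketched with toy callables. Everything in the block is illustrative: hidden_layer_step and the stand-in conv, skip, norm, and dropout functions are hypothetical, and each node carries a single scalar feature to keep the arithmetic readable.

```python
def relu(xs):
    return [max(0.0, x) for x in xs]


def hidden_layer_step(h, conv, skip, norm, dropout):
    """Apply one hidden layer in the documented order (toy sketch)."""
    out = conv(h)                                 # 1. graph convolution
    out = [o + s for o, s in zip(out, skip(h))]   # 2. residual/additive skip
    out = norm(out)                               # 3. normalization
    out = relu(out)                               # 4. ReLU
    return dropout(out)                           # 5. dropout


def center(xs):
    # Toy normalization: subtract the mean over nodes.
    mean = sum(xs) / len(xs)
    return [x - mean for x in xs]


layer_out = hidden_layer_step(
    [1.0, -2.0, 4.0],
    conv=lambda h: [2.0 * x for x in h],  # stand-in for a graph convolution
    skip=lambda h: h,                     # identity residual branch
    norm=center,
    dropout=lambda h: h,                  # inference mode: dropout is a no-op
)
```

With residual_projection enabled, the skip callable would correspond to a learned linear map rather than the identity used here.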

Parameters:
  • hidden_units – Integer hidden width used for all graph-convolution layers.

  • output_units – Integer output width for the final prediction head (typically number of classes for node classification).

  • num_layers – Number of hidden graph-convolution layers. Must be greater than or equal to 1.

  • add_self_loops – If True, each convolution adds self-loops to edge indices internally. Set this to False when self-loops are already added during graph preprocessing to avoid double-counting.

  • normalize – If True, applies symmetric degree normalization in each convolution.

  • residual – If True, enables per-layer additive skip connections.

  • residual_projection – If True (and residual is enabled), uses a learned linear projection on each residual branch. This most closely matches tunedGNN’s GCNConv(...) + Linear(...) pattern.

  • batchnorm – Enables BatchNormalization inside each hidden convolution layer.

  • layernorm – Enables LayerNormalization inside each hidden convolution layer. Mutually exclusive with batchnorm.

  • input_dropout_rate – Dropout probability applied once on input node features before the first hidden layer.

  • dropout_rate – Dropout probability applied after each hidden layer.

  • jit_compile – Backend hint for compiled execution in lower-level ops where supported.

  • feature_dtype – Optional floating-point dtype used for feature computations and trainable weights.

  • index_dtype – Optional integer dtype used for graph indices.

  • use_shortcut – If True, accumulates intermediate hidden outputs into an extra shortcut path that is added before the final head.

  • use_bias – If True, enables bias terms in hidden convolutions.

  • **kwargs – Additional keyword arguments forwarded to keras.layers.Layer.

Raises:
  • ValueError – If both batchnorm and layernorm are enabled.

  • ValueError – If num_layers < 1.
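The two documented error conditions amount to checks like the following. This is a hypothetical sketch: the function name, the messages, and the order of the checks are assumptions.

```python
def check_gcnv2_args(num_layers, batchnorm=True, layernorm=False):
    """Validate constructor arguments the way the Raises section describes."""
    if batchnorm and layernorm:
        raise ValueError("batchnorm and layernorm are mutually exclusive")
    if num_layers < 1:
        raise ValueError("num_layers must be >= 1")


# The default combination passes; to use LayerNorm, disable BatchNorm explicitly.
check_gcnv2_args(num_layers=3)
check_gcnv2_args(num_layers=3, batchnorm=False, layernorm=True)
```

Because batchnorm defaults to True, passing layernorm=True on its own would hit the first check.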

Notes

Optimization on OGBN-Arxiv is highly sensitive to batch normalization settings. In internal benchmarking, the BatchNorm hyperparameters (especially momentum and epsilon) significantly affect both convergence speed and best validation/test accuracy. Small mismatches from reference implementations can produce large performance gaps even when architecture and optimizer settings are otherwise aligned.
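One common source of such mismatches, assuming the reference implementation is PyTorch-based, is that Keras and PyTorch define BatchNorm momentum in opposite directions and use different default epsilon values (Keras: momentum=0.99, epsilon=1e-3; PyTorch: momentum=0.1, eps=1e-5). The sketch below only illustrates the two update conventions; the function names are hypothetical.

```python
def keras_style_update(moving, batch_stat, momentum=0.99):
    # Keras convention: momentum weights the OLD moving statistic.
    return moving * momentum + batch_stat * (1.0 - momentum)


def torch_style_update(running, batch_stat, momentum=0.1):
    # PyTorch convention: momentum weights the NEW batch statistic.
    return (1.0 - momentum) * running + momentum * batch_stat


def to_keras_momentum(torch_momentum):
    # Converting between conventions: keras_momentum = 1 - torch_momentum.
    return 1.0 - torch_momentum


# A PyTorch momentum of 0.1 corresponds to a Keras momentum of 0.9.
keras_value = keras_style_update(0.0, 1.0, momentum=to_keras_momentum(0.1))
torch_value = torch_style_update(0.0, 1.0, momentum=0.1)
```

When porting a recipe, both the converted momentum and the epsilon value would need to be set explicitly rather than left at framework defaults.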

call(d, training=False)[source]