tf_gnns with Keras 3 Torch Backend

This notebook demonstrates the higher-level tf_gnns GraphNet models (GraphNetMLP and GraphNetMPNN_MLP) running on the PyTorch backend through Keras 3.

It focuses on model construction and forward passes with graph tensor dictionaries.

1) Backend setup

Set KERAS_BACKEND=torch before importing keras or tf_gnns.
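Keras 3 reads KERAS_BACKEND once, when keras is first imported, and tf_gnns presumably imports keras itself (which is why the order matters for both). Setting the variable after either import silently leaves another backend active. A small guard can make that ordering mistake fail loudly. `set_keras_backend` is a hypothetical helper written for this notebook, not a Keras or tf_gnns function:

```python
import os
import sys


def set_keras_backend(backend: str = "torch") -> None:
    """Set KERAS_BACKEND, refusing if keras has already been imported.

    Hypothetical helper: Keras picks its backend at import time, so
    changing the variable afterwards would not affect this session.
    """
    if "keras" in sys.modules:
        active = os.environ.get("KERAS_BACKEND", "<unset>")
        raise RuntimeError(
            f"keras is already imported (KERAS_BACKEND={active!r}); "
            "restart the kernel and set the backend before importing keras or tf_gnns."
        )
    os.environ["KERAS_BACKEND"] = backend


set_keras_backend("torch")
print(os.environ["KERAS_BACKEND"])  # torch
```

Running it as the very first cell means a stale kernel raises immediately instead of surfacing later as a backend mismatch.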

If needed, install PyTorch wheels:

pip install torch --index-url https://download.pytorch.org/whl/cpu  # CPU-only wheels; for a CUDA build, use the matching index URL from the install selector on pytorch.org
!pip install tf_gnns==0.2.0
import os
os.environ['KERAS_BACKEND'] = 'torch'

import keras
from tf_gnns.models.graphnet import GraphNetMLP, GraphNetMPNN_MLP
from tf_gnns.tfgnns_datastructures import GraphTuple

print('Keras version:', keras.__version__)
print('Active backend:', keras.backend.backend())

2) Build a sample GraphTuple tensor dictionary

# Toy batch holding a single graph: 3 nodes and 3 directed edges (0->1, 1->2, 2->0),
# each with 2 features, plus one 2-feature global vector.
nodes = keras.ops.convert_to_tensor([[1.0, 0.0], [2.0, -1.0], [3.0, 4.0]], dtype='float32')
edges = keras.ops.convert_to_tensor([[0.5, -1.0], [1.5, 2.0], [-0.3, 0.7]], dtype='float32')
senders = keras.ops.convert_to_tensor([0, 1, 2], dtype='int32')
receivers = keras.ops.convert_to_tensor([1, 2, 0], dtype='int32')
global_attr = keras.ops.convert_to_tensor([[0.2, -0.4]], dtype='float32')

gt = GraphTuple(
    nodes=nodes,
    edges=edges,
    senders=senders,
    receivers=receivers,
    n_nodes=[3],
    n_edges=[3],
    global_attr=global_attr,
)

# Build the tensor dict by hand rather than deriving it from `gt`, so this cell stays
# backend-neutral even with older installed tf_gnns versions.
td = {
    'nodes': nodes,
    'edges': edges,
    'senders': senders,
    'receivers': receivers,
    'n_nodes': keras.ops.convert_to_tensor([3], dtype='int32'),
    'n_edges': keras.ops.convert_to_tensor([3], dtype='int32'),
    'n_graphs': keras.ops.convert_to_tensor(1, dtype='int32'),
    'global_attr': global_attr,
    # Graph index of each edge / node: all 0, since the batch holds a single graph.
    'global_reps_for_edges': keras.ops.convert_to_tensor([0, 0, 0], dtype='int32'),
    'global_reps_for_nodes': keras.ops.convert_to_tensor([0, 0, 0], dtype='int32'),
}
td.keys()
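Judging by their names and the all-zero values above, global_reps_for_nodes and global_reps_for_edges appear to give the graph index of every node and edge; this reading is an assumption, not documented tf_gnns behaviour. Under that assumption they follow directly from n_nodes and n_edges, and several graphs can be batched by concatenating features and shifting each graph's senders/receivers by the number of nodes that precede it. A NumPy sketch with a hypothetical `batch_graph_dicts` helper (not a tf_gnns API):

```python
import numpy as np


def batch_graph_dicts(graphs):
    """Concatenate single-graph dicts into one batched dict (hypothetical helper).

    Each input holds 'nodes', 'edges', 'senders', 'receivers', 'global_attr'
    for ONE graph; global_reps_* are assumed to be per-node/per-edge graph ids.
    """
    n_nodes = np.array([len(g["nodes"]) for g in graphs], dtype=np.int32)
    n_edges = np.array([len(g["edges"]) for g in graphs], dtype=np.int32)
    # Position of each graph's first node in the concatenated node array.
    node_offsets = np.concatenate([[0], np.cumsum(n_nodes)[:-1]])
    senders = np.concatenate(
        [np.asarray(g["senders"]) + off for g, off in zip(graphs, node_offsets)]
    ).astype(np.int32)
    receivers = np.concatenate(
        [np.asarray(g["receivers"]) + off for g, off in zip(graphs, node_offsets)]
    ).astype(np.int32)
    graph_ids = np.arange(len(graphs), dtype=np.int32)
    return {
        "nodes": np.concatenate([g["nodes"] for g in graphs]).astype(np.float32),
        "edges": np.concatenate([g["edges"] for g in graphs]).astype(np.float32),
        "senders": senders,
        "receivers": receivers,
        "n_nodes": n_nodes,
        "n_edges": n_edges,
        "n_graphs": np.int32(len(graphs)),
        "global_attr": np.concatenate([g["global_attr"] for g in graphs]).astype(np.float32),
        # One graph index per edge / node, e.g. [0, 0, 0, 1, 1].
        "global_reps_for_edges": np.repeat(graph_ids, n_edges),
        "global_reps_for_nodes": np.repeat(graph_ids, n_nodes),
    }


g1 = {
    "nodes": [[1.0, 0.0], [2.0, -1.0], [3.0, 4.0]],
    "edges": [[0.5, -1.0], [1.5, 2.0], [-0.3, 0.7]],
    "senders": [0, 1, 2], "receivers": [1, 2, 0],
    "global_attr": [[0.2, -0.4]],
}
g2 = {
    "nodes": [[0.0, 1.0], [1.0, 1.0]],
    "edges": [[1.0, 0.0]],
    "senders": [0], "receivers": [1],
    "global_attr": [[0.0, 1.0]],
}
batch = batch_graph_dicts([g1, g2])
print(batch["senders"], batch["receivers"])  # [0 1 2 3] [1 2 0 4]
print(batch["global_reps_for_nodes"])         # [0 0 0 1 1]
print(batch["global_reps_for_edges"])         # [0 0 0 1]
```

Calling it with only `g1` reproduces the single-graph bookkeeping used in `td`. The NumPy arrays could then be turned into backend tensors with keras.ops.convert_to_tensor, as the notebook does for `td`.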

3) Run GraphNetMLP (with globals)

assert keras.backend.backend() == 'torch', f"Active backend is {keras.backend.backend()} (expected 'torch'). Restart kernel and run from top."

# Re-convert every entry to a tensor of the active backend, in case cells were run out of order or td was built under another backend.
td = {k: (None if v is None else keras.ops.convert_to_tensor(v)) for k, v in td.items()}

model_global = GraphNetMLP(
    units=16,
    core_steps=2,
    recurrent=False,
    residual=True,
    node_output_size=6,
    edge_output_size=5,
    global_output_size=4,
)

out_global = model_global(td)
print('nodes shape:', keras.ops.shape(out_global['nodes']))
print('edges shape:', keras.ops.shape(out_global['edges']))
print('global shape:', keras.ops.shape(out_global['global_attr']))
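To build intuition for what `core_steps`, `residual` and `recurrent` control, here is a NumPy sketch of a full graph-network (GN) block in the sense of Battaglia et al. (2018). Edges are updated from their own features, both endpoint nodes and the global vector; nodes from their features, the sum of their incoming updated edges and the global vector; globals from the summed nodes and edges. A random single-layer map stands in for each MLP. The encode/decode layers, sum aggregation, and reading `recurrent=True` as weight sharing across steps are all assumptions; this is not tf_gnns' implementation:

```python
import numpy as np

rng = np.random.default_rng(0)


def dense(in_dim, out_dim, relu=True):
    """Random single-layer map standing in for an MLP (illustrative only)."""
    w = rng.normal(scale=0.1, size=(in_dim, out_dim))
    return lambda x: np.maximum(x @ w, 0.0) if relu else x @ w


def segment_sum(data, segment_ids, num_segments):
    """Sum the rows of `data` that share the same segment id."""
    out = np.zeros((num_segments, data.shape[1]))
    np.add.at(out, segment_ids, data)
    return out


def make_gn_block(units):
    """Update functions for one GN step; all features are `units` wide."""
    return {
        "edge": dense(4 * units, units),    # [e, v_sender, v_receiver, u]
        "node": dense(3 * units, units),    # [v, summed incoming e', u]
        "global": dense(3 * units, units),  # [u, summed v', summed e']
    }


def gn_step(block, g):
    """One full GN step: edges first, then nodes, then globals."""
    n_nodes, n_graphs = len(g["nodes"]), len(g["global_attr"])
    u = g["global_attr"]
    edges = block["edge"](np.concatenate(
        [g["edges"], g["nodes"][g["senders"]], g["nodes"][g["receivers"]],
         u[g["global_reps_for_edges"]]], axis=1))
    incoming = segment_sum(edges, g["receivers"], n_nodes)
    nodes = block["node"](np.concatenate(
        [g["nodes"], incoming, u[g["global_reps_for_nodes"]]], axis=1))
    global_attr = block["global"](np.concatenate(
        [u, segment_sum(nodes, g["global_reps_for_nodes"], n_graphs),
         segment_sum(edges, g["global_reps_for_edges"], n_graphs)], axis=1))
    # Structure keys (senders, receivers, ...) pass through untouched.
    return {**g, "nodes": nodes, "edges": edges, "global_attr": global_attr}


def gn_core(g, units, core_steps=2, recurrent=False, residual=True):
    # Assumption: recurrent=True reuses one block's weights at every step.
    if recurrent:
        blocks = [make_gn_block(units)] * core_steps
    else:
        blocks = [make_gn_block(units) for _ in range(core_steps)]
    for block in blocks:
        new = gn_step(block, g)
        if residual:  # add each step's output to its input features
            for k in ("nodes", "edges", "global_attr"):
                new[k] = new[k] + g[k]
        g = new
    return g


units = 16
g = {
    "nodes": np.array([[1.0, 0.0], [2.0, -1.0], [3.0, 4.0]]),
    "edges": np.array([[0.5, -1.0], [1.5, 2.0], [-0.3, 0.7]]),
    "global_attr": np.array([[0.2, -0.4]]),
    "senders": np.array([0, 1, 2]), "receivers": np.array([1, 2, 0]),
    "global_reps_for_nodes": np.array([0, 0, 0]),
    "global_reps_for_edges": np.array([0, 0, 0]),
}
# Encode the raw 2-feature inputs to `units` wide, run the core, then decode.
for k in ("nodes", "edges", "global_attr"):
    g[k] = dense(2, units)(g[k])
g = gn_core(g, units, core_steps=2)
sizes = {"nodes": 6, "edges": 5, "global_attr": 4}
out = {k: dense(units, d, relu=False)(g[k]) for k, d in sizes.items()}
print({k: v.shape for k, v in out.items()})
# {'nodes': (3, 6), 'edges': (3, 5), 'global_attr': (1, 4)}
```

The residual update only type-checks because every core step keeps features at `units` width; output sizes are applied by the final decode maps.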

4) Run GraphNetMPNN_MLP (no globals)

td_no_global = dict(td)  # shallow copy, so td itself keeps its global_attr
td_no_global['global_attr'] = None

model_mpnn = GraphNetMPNN_MLP(
    units=16,
    core_steps=2,
    recurrent=False,
    residual=True,
    node_output_size=6,
    edge_output_size=5,
)

out_mpnn = model_mpnn(td_no_global)
print('nodes shape:', keras.ops.shape(out_mpnn['nodes']))
print('edges shape:', keras.ops.shape(out_mpnn['edges']))
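Without a global attribute, information can only move along edges. Assuming each of the `core_steps` performs one round of sender-to-receiver message passing (an assumption about how GraphNetMPNN_MLP uses `core_steps`), a node can only be influenced by nodes within `core_steps` hops. This NumPy sketch computes those receptive fields for the toy cycle; with two steps every node already depends on the whole 3-node graph:

```python
import numpy as np


def receptive_fields(senders, receivers, num_nodes, steps):
    """reach[a, r] is True if node a can influence node r after `steps` rounds.

    Each round, node r combines its own state with messages from its
    direct senders, so influence spreads one hop per round.
    """
    adj = np.zeros((num_nodes, num_nodes), dtype=int)
    adj[np.asarray(senders), np.asarray(receivers)] = 1
    reach = np.eye(num_nodes, dtype=int)  # a node always depends on itself
    for _ in range(steps):
        reach = ((reach + reach @ adj) > 0).astype(int)
    return reach.astype(bool)


toy_senders, toy_receivers = np.array([0, 1, 2]), np.array([1, 2, 0])
for k in (1, 2):
    reach = receptive_fields(toy_senders, toy_receivers, num_nodes=3, steps=k)
    # For each node, list the nodes whose features can reach it.
    print(k, [np.flatnonzero(reach[:, r]).tolist() for r in range(3)])
# 1 [[0, 2], [0, 1], [1, 2]]
# 2 [[0, 1, 2], [0, 1, 2], [0, 1, 2]]
```

In GraphNetMLP, by contrast, the global vector can in principle carry information between distant nodes within a single step.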

5) Verify structure tensors are preserved

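The GraphNet models should update only the feature tensors (nodes, edges, global_attr) and pass the graph-structure bookkeeping through unchanged: senders, receivers, n_nodes, n_edges, n_graphs and the two global_reps_* index tensors. The check compares each of these between the input td and the GraphNetMLP output.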

import numpy as np

for key in ['senders', 'receivers', 'n_nodes', 'n_edges', 'global_reps_for_edges', 'global_reps_for_nodes', 'n_graphs']:
    lhs = keras.ops.convert_to_numpy(out_global[key])
    rhs = keras.ops.convert_to_numpy(td[key])
    # np.array_equal also returns False on a shape mismatch, where == could broadcast.
    assert np.array_equal(lhs, rhs), f'Mismatch in {key}'

print('Structure tensors preserved.')
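The pass-through check only shows that the structure tensors did not change. A complementary sketch validates the invariants they must satisfy on their own: counts that add up, edge endpoints that index existing nodes, and edges that stay inside their graph. `check_graph_structure` is a hypothetical helper that assumes this notebook's key layout and, as an assumption, that global_reps_for_* hold each node's and edge's graph index:

```python
import numpy as np


def _require(condition, message):
    if not condition:
        raise ValueError(message)


def check_graph_structure(td):
    """Raise ValueError if a graph dict's structure tensors are inconsistent.

    Hypothetical helper; assumes global_reps_for_* are per-node/per-edge graph ids.
    """
    n_nodes = np.asarray(td["n_nodes"]).reshape(-1)
    n_edges = np.asarray(td["n_edges"]).reshape(-1)
    senders = np.asarray(td["senders"])
    receivers = np.asarray(td["receivers"])
    total_nodes, total_edges = int(n_nodes.sum()), int(n_edges.sum())

    _require(int(np.asarray(td["n_graphs"])) == len(n_nodes) == len(n_edges),
             "n_graphs must match the lengths of n_nodes and n_edges")
    _require(len(td["nodes"]) == total_nodes, "nodes rows != sum(n_nodes)")
    _require(len(td["edges"]) == total_edges, "edges rows != sum(n_edges)")
    _require(len(senders) == len(receivers) == total_edges,
             "senders/receivers length != sum(n_edges)")
    # Every edge endpoint must index an existing node.
    for name, idx in (("senders", senders), ("receivers", receivers)):
        _require(((idx >= 0) & (idx < total_nodes)).all(), f"{name} out of range")

    graph_ids = np.arange(len(n_nodes))
    node_graph = np.asarray(td["global_reps_for_nodes"])
    edge_graph = np.asarray(td["global_reps_for_edges"])
    _require(np.array_equal(node_graph, np.repeat(graph_ids, n_nodes)),
             "global_reps_for_nodes inconsistent with n_nodes")
    _require(np.array_equal(edge_graph, np.repeat(graph_ids, n_edges)),
             "global_reps_for_edges inconsistent with n_edges")
    # An edge must connect two nodes of the graph it belongs to.
    _require(np.array_equal(node_graph[senders], edge_graph)
             and np.array_equal(node_graph[receivers], edge_graph),
             "an edge crosses a graph boundary")


td_np = {
    "nodes": np.zeros((3, 2)), "edges": np.zeros((3, 2)),
    "senders": [0, 1, 2], "receivers": [1, 2, 0],
    "n_nodes": [3], "n_edges": [3], "n_graphs": 1,
    "global_reps_for_nodes": [0, 0, 0], "global_reps_for_edges": [0, 0, 0],
}
check_graph_structure(td_np)  # consistent: returns None without raising
```

In the notebook, the same function could be applied to a model output such as out_global after converting each non-None entry with keras.ops.convert_to_numpy.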