Quickstart

This example applies the high-level GraphNetMLP layer to a single graph stored as a dictionary of arrays (a tensor-dictionary graph tuple).

import numpy as np
from tf_gnns.models.graphnet import GraphNetMLP

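# One graph with 4 nodes and 6 directed edges; every feature vector has 8 entries.
graph = {
    "nodes": np.random.randn(4, 8).astype("float32"),  # node features, shape (4, 8)
    "edges": np.random.randn(6, 8).astype("float32"),  # edge features, shape (6, 8)
    "senders": np.array([0, 0, 1, 2, 3, 3], dtype="int32"),  # source node index of each edge
    "receivers": np.array([1, 2, 2, 3, 0, 1], dtype="int32"),  # target node index of each edge
    "n_nodes": np.array([4], dtype="int32"),  # node count per graph
    "n_edges": np.array([6], dtype="int32"),  # edge count per graph
    "global_attr": np.random.randn(1, 8).astype("float32"),  # one global feature row per graph
    "global_reps_for_nodes": np.zeros((4,), dtype="int32"),  # one entry per node (all 0: single graph)
    "global_reps_for_edges": np.zeros((6,), dtype="int32"),  # one entry per edge (all 0: single graph)
    "n_graphs": 1,
}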

model = GraphNetMLP(units=32, core_steps=2)
output = model(graph)

print(output["nodes"].shape, output["edges"].shape, output["global_attr"].shape)
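
Most problems with a hand-built graph dictionary come from mismatched shapes or out-of-range indices. The sketch below checks the dictionary used above for internal consistency before it is passed to a model. It uses only NumPy. `check_graph_dict` is a hypothetical helper written for this page, not part of tf_gnns, and the rules it enforces are inferred from the shapes in the example rather than taken from the library's documentation.

```python
import numpy as np


def check_graph_dict(graph):
    """Hypothetical helper (not part of tf_gnns): sanity-check a graph dictionary.

    Returns (n_nodes, n_edges) on success and raises ValueError otherwise.
    """
    n_nodes = graph["nodes"].shape[0]
    n_edges = graph["edges"].shape[0]

    # Every edge needs exactly one sender and one receiver index,
    # and each index must point at an existing node.
    for key in ("senders", "receivers"):
        idx = graph[key]
        if idx.shape != (n_edges,):
            raise ValueError(f"{key} must have shape ({n_edges},), got {idx.shape}")
        if idx.size and (idx.min() < 0 or idx.max() >= n_nodes):
            raise ValueError(f"{key} contains an index outside [0, {n_nodes})")

    # Per-graph counts must add up to the total number of rows.
    if int(graph["n_nodes"].sum()) != n_nodes:
        raise ValueError("n_nodes does not sum to the number of node rows")
    if int(graph["n_edges"].sum()) != n_edges:
        raise ValueError("n_edges does not sum to the number of edge rows")

    # Assumption: one entry per node / per edge, one global row per graph.
    if graph["global_reps_for_nodes"].shape != (n_nodes,):
        raise ValueError("global_reps_for_nodes must have one entry per node")
    if graph["global_reps_for_edges"].shape != (n_edges,):
        raise ValueError("global_reps_for_edges must have one entry per edge")
    if graph["global_attr"].shape[0] != graph["n_graphs"]:
        raise ValueError("global_attr must have one row per graph")

    return n_nodes, n_edges


# Same layout as the quickstart graph above.
graph = {
    "nodes": np.random.randn(4, 8).astype("float32"),
    "edges": np.random.randn(6, 8).astype("float32"),
    "senders": np.array([0, 0, 1, 2, 3, 3], dtype="int32"),
    "receivers": np.array([1, 2, 2, 3, 0, 1], dtype="int32"),
    "n_nodes": np.array([4], dtype="int32"),
    "n_edges": np.array([6], dtype="int32"),
    "global_attr": np.random.randn(1, 8).astype("float32"),
    "global_reps_for_nodes": np.zeros((4,), dtype="int32"),
    "global_reps_for_edges": np.zeros((6,), dtype="int32"),
    "n_graphs": 1,
}

n_nodes, n_edges = check_graph_dict(graph)
print(n_nodes, n_edges)
```

Running a check like this before `model(graph)` tends to give clearer error messages than a shape failure raised from inside the layer.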

For more complete workflows, see the rendered notebooks in Tutorials and examples.