{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# tf_gnns with Keras 3 Torch Backend\n", "\n", "This notebook demonstrates the higher-level `tf_gnns` GraphNet constructs running on the **PyTorch backend** via Keras 3.\n", "\n", "It focuses on model construction and forward passes with graph tensor dictionaries." ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 1) Backend setup\n", "\n", "Set `KERAS_BACKEND=torch` **before** importing `keras` or `tf_gnns`; the backend cannot be changed once `keras` has been imported.\n", "\n", "If needed, install PyTorch wheels:\n", "\n", "```bash\n", "pip install torch --index-url https://download.pytorch.org/whl/cpu  # CPU-only wheels; pick a different index URL for a GPU build\n", "```" ] }, { "cell_type": "code", "execution_count": null, "id": "114be5b2", "metadata": {}, "outputs": [], "source": [ "!pip install tf_gnns==0.2.0" ] }, { "cell_type": "code", "execution_count": null, "id": "4440f648", "metadata": {}, "outputs": [], "source": [ "import os\n", "os.environ['KERAS_BACKEND'] = 'torch'\n", "\n", "import keras\n", "from tf_gnns.models.graphnet import GraphNetMLP, GraphNetMPNN_MLP\n", "from tf_gnns.tfgnns_datastructures import GraphTuple\n", "\n", "print('Keras version:', keras.__version__)\n", "print('Active backend:', keras.backend.backend())" ] }, { "cell_type": "markdown", "id": "8bc5262c", "metadata": {}, "source": [ "## 2) Build a sample `GraphTuple` tensor dictionary" ] }, { "cell_type": "code", "execution_count": null, "id": "e06592fa", "metadata": {}, "outputs": [], "source": [ "# Toy batch containing a single graph: 3 nodes and 3 directed edges (0->1, 1->2, 2->0)\n", "nodes = keras.ops.convert_to_tensor([[1.0, 0.0], [2.0, -1.0], [3.0, 4.0]], dtype='float32')\n", "edges = keras.ops.convert_to_tensor([[0.5, -1.0], [1.5, 2.0], [-0.3, 0.7]], dtype='float32')\n", "senders = keras.ops.convert_to_tensor([0, 1, 2], dtype='int32')\n", "receivers = keras.ops.convert_to_tensor([1, 2, 0], dtype='int32')\n", "global_attr = keras.ops.convert_to_tensor([[0.2, -0.4]], dtype='float32')\n", "\n", "gt = GraphTuple(\n", "    nodes=nodes,\n", "    edges=edges,\n", "    senders=senders,\n", "    receivers=receivers,\n", "    n_nodes=[3],\n", "    n_edges=[3],\n", "    global_attr=global_attr,\n", ")\n", "\n", "# Build the tensor dict explicitly to stay backend-neutral even on older installed tf_gnns versions.\n", "td = {\n", "    'nodes': nodes,\n", "    'edges': edges,\n", "    'senders': senders,\n", "    'receivers': receivers,\n", "    'n_nodes': keras.ops.convert_to_tensor([3], dtype='int32'),\n", "    'n_edges': keras.ops.convert_to_tensor([3], dtype='int32'),\n", "    'n_graphs': keras.ops.convert_to_tensor(1, dtype='int32'),\n", "    'global_attr': global_attr,\n", "    'global_reps_for_edges': keras.ops.convert_to_tensor([0, 0, 0], dtype='int32'),\n", "    'global_reps_for_nodes': keras.ops.convert_to_tensor([0, 0, 0], dtype='int32'),\n", "}\n", "td.keys()" ] }, { "cell_type": "markdown", "id": "a98d8def", "metadata": {}, "source": [ "## 3) Run `GraphNetMLP` (with globals)" ] }, { "cell_type": "code", "execution_count": null, "id": "ba0c1f48", "metadata": {}, "outputs": [], "source": [ "assert keras.backend.backend() == 'torch', f\"Active backend is {keras.backend.backend()} (expected 'torch'). 
Restart the kernel and run from the top.\"\n", "\n", "# Normalize the tensor dict to active-backend tensors (defensive against mixed notebook states).\n", "td = {k: (None if v is None else keras.ops.convert_to_tensor(v)) for k, v in td.items()}\n", "\n", "model_global = GraphNetMLP(\n", "    units=16,\n", "    core_steps=2,\n", "    recurrent=False,\n", "    residual=True,\n", "    node_output_size=6,\n", "    edge_output_size=5,\n", "    global_output_size=4,\n", ")\n", "\n", "out_global = model_global(td)\n", "print('nodes shape:', keras.ops.shape(out_global['nodes']))\n", "print('edges shape:', keras.ops.shape(out_global['edges']))\n", "print('global shape:', keras.ops.shape(out_global['global_attr']))" ] }, { "cell_type": "markdown", "id": "51c40bb8", "metadata": {}, "source": [ "## 4) Run `GraphNetMPNN_MLP` (no globals)" ] }, { "cell_type": "code", "execution_count": null, "id": "c6b37efa", "metadata": {}, "outputs": [], "source": [ "# Shallow copy so that `td` itself keeps its global attribute.\n", "td_no_global = dict(td)\n", "td_no_global['global_attr'] = None\n", "\n", "model_mpnn = GraphNetMPNN_MLP(\n", "    units=16,\n", "    core_steps=2,\n", "    recurrent=False,\n", "    residual=True,\n", "    node_output_size=6,\n", "    edge_output_size=5,\n", ")\n", "\n", "out_mpnn = model_mpnn(td_no_global)\n", "print('nodes shape:', keras.ops.shape(out_mpnn['nodes']))\n", "print('edges shape:', keras.ops.shape(out_mpnn['edges']))" ] }, { "cell_type": "markdown", "id": "f4633c53", "metadata": {}, "source": [ "## 5) Verify structure tensors are preserved\n", "\n", "Higher-level blocks should update the feature tensors (`nodes`, `edges`, `global_attr`) while leaving the graph-structure bookkeeping (`senders`, `receivers`, `n_nodes`, `n_edges`, `n_graphs`, and the `global_reps_for_*` indices) unchanged." 
] }, { "cell_type": "code", "execution_count": null, "id": "19d6a9f3", "metadata": {}, "outputs": [], "source": [ "for key in ['senders', 'receivers', 'n_nodes', 'n_edges', 'global_reps_for_edges', 'global_reps_for_nodes', 'n_graphs']:\n", "    lhs = keras.ops.convert_to_numpy(out_global[key])\n", "    rhs = keras.ops.convert_to_numpy(td[key])\n", "    assert (lhs == rhs).all(), f'Mismatch in {key}'\n", "\n", "print('Structure tensors preserved.')" ] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "name": "python", "version": "3.11" } }, "nbformat": 4, "nbformat_minor": 5 }