{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "vln3LX7iYdts" }, "source": [ "# `tf_gnns`: A Hackable `graph_nets` library\n", "largely inspired by the `graph_nets` project.\n", "\n", "First, install the library in the Colab notebook and import it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8OZrOK7oJ9XW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tf_gnns==0.1.2b in /home/charilaos/.local/lib/python3.8/site-packages (0.1.2b0)\r\n" ] } ], "source": [ "# Install the tf_gnns library with pip:\n", "!pip install tf_gnns==0.1.7\n", "import tf_gnns" ] }, { "cell_type": "markdown", "metadata": { "id": "38YcdOq2KIb6" }, "source": [ "# **GraphNet basics**\n", "The data structure on which GraphNets (GNs) operate is an attributed multigraph. The following examples show how graphs can be defined.\n", "\n", "\n", "As detailed later in the notebook, the `GraphTuple` (or a dictionary that corresponds to a `GraphTuple`) is the recommended data structure for achieving good performance.\n", "\n", "## Creating a single `Graph`\n", "To specify a `Graph`, you need to provide a set of nodes and a set of edges. *(Note that this data structure is not very efficient. It is, however, more intuitive to work with.)*\n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "wQfUezXNKOUi" }, "outputs": [], "source": [ "# Creation of edges and nodes:\n", "import numpy as np\n", "from tf_gnns import Graph, Edge, Node\n", "\n", "nw_state_size = 10 # size of the node and edge attribute vectors. 
They can be different if necessary.\n", "\n", "# Defining the connectivity of a graph:\n", "adj_A = [(2,4),(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(8,7)]\n", "\n", "nodes_A = [Node(np.random.randn(1,nw_state_size)) for _ in range(10)]\n", "edges_A = [Edge(np.random.randn(1,nw_state_size), node_from=nodes_A[e_ij[0]], node_to=nodes_A[e_ij[1]]) for e_ij in adj_A]\n", "graph_A = Graph(nodes=nodes_A, edges=edges_A)" ] }, { "cell_type": "markdown", "metadata": { "id": "3Zt-TYiiaG1f" }, "source": [ "For more expressive code, some operators (such as `+`, implemented by `Graph.__add__`) are overloaded. When in doubt about what an operation does, inspect its docstring or the source code directly. " ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Kt_Kx5XMaYCO", "outputId": "1ea575ba-10a6-45c2-c218-6f923f1ffe40" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on method __add__ in module tf_gnns.datastructures:\n", "\n", "__add__(graph) method of tf_gnns.datastructures.Graph instance\n", " This should only work with graphs that have compatible node and edge features\n", " Assumed also that the two graphs have the same connectivity (otherwise this will fail ugly)\n", "\n" ] } ], "source": [ "help(graph_A.__add__)" ] }, { "cell_type": "markdown", "metadata": { "id": "fiisOCW1KsJa" }, "source": [ "Simple `Graph` objects are not convenient for batched computation and should generally be avoided. They can, however, be a \"safer\" data structure when experimenting with new algorithms. \n", "\n", "For most use cases, a `GraphTuple` should be used. 
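To see why batching works, consider the packing scheme that DeepMind's `graph_nets` uses for its `GraphsTuple`. The node and edge arrays of all graphs are concatenated. Each graph's edge endpoints (`senders` / `receivers`) are offset by the number of nodes in the graphs before it, and per-graph counts are kept in `n_node` / `n_edge`. The NumPy sketch below illustrates this. The function `pack_graphs` and these field names are illustrative assumptions; `tf_gnns` may store its `GraphTuple` differently internally.

```python
import numpy as np

# Minimal sketch of graph batching in the style of graph_nets' GraphsTuple.
# pack_graphs and the field names are hypothetical, not the tf_gnns API.
def pack_graphs(graphs):
    # Each graph is a dict with 'nodes' [N_i, F], 'edges' [E_i, F] and
    # 'senders' / 'receivers' [E_i] holding local node indices.
    nodes, edges, senders, receivers = [], [], [], []
    n_node, n_edge = [], []
    offset = 0  # number of nodes in all previously packed graphs
    for g in graphs:
        nodes.append(g['nodes'])
        edges.append(g['edges'])
        # Shift local indices so they point into the concatenated node array.
        senders.append(np.asarray(g['senders']) + offset)
        receivers.append(np.asarray(g['receivers']) + offset)
        n_node.append(len(g['nodes']))
        n_edge.append(len(g['edges']))
        offset += len(g['nodes'])
    return {
        'nodes': np.concatenate(nodes, axis=0),
        'edges': np.concatenate(edges, axis=0),
        'senders': np.concatenate(senders),
        'receivers': np.concatenate(receivers),
        'n_node': np.array(n_node),
        'n_edge': np.array(n_edge),
    }

# Usage: two random graphs with the same sizes as graph_A and graph_B in the
# GraphTuple example (15 nodes / 12 edges and 10 nodes / 11 edges, 2 features).
rng = np.random.default_rng(0)

def random_graph(n_nodes, n_edges, n_feat=2):
    return {
        'nodes': rng.standard_normal((n_nodes, n_feat)),
        'edges': rng.standard_normal((n_edges, n_feat)),
        'senders': rng.integers(0, n_nodes, n_edges),
        'receivers': rng.integers(0, n_nodes, n_edges),
    }

packed = pack_graphs([random_graph(15, 12), random_graph(10, 11)])
print(packed['nodes'].shape, packed['edges'].shape)  # (25, 2) (23, 2)
```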
`GraphTuple`s are the only data structure supported by the `graph_nets` library.\n", "\n", "## Creating a `GraphTuple`" ] }, { "cell_type": "markdown", "metadata": { "id": "U08TIII0aiOF" }, "source": [ "# The `GraphTuple` data structure\n", "A `GraphTuple` is a set of graphs packed into a single object, which makes it easier to parallelize the computation of a `GraphNet` block. It is the default (and only) data structure in DeepMind's `graph_nets` library.\n", "You can create a `GraphTuple` either directly from tensors or NumPy arrays, or from a list of `Graph` objects (as above); utilities are provided for both. In general it is better to avoid `Graph`s altogether, but they are useful for testing.\n", "\n", "## Note on performance\n", "For good performance, it is important to wrap the training iteration (dataset sampling, forward pass, gradient computation) in a `tf.function`-decorated function. This limits the data structures that can be passed as arguments to the wrapped function: only `tf.Tensor`s, built-in Python types, or collections of `tf.Tensor`s are supported as `tf.function` arguments (`GraphTuple`s are not). 
A workaround for using `tf_gnns` with `tf.function` is to convert the `GraphTuple` to a dictionary.\n", "\n", "## Creation of a `GraphTuple`\n", "\n", "Here is how to create a `GraphTuple` from a list of `Graph` objects:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "tNZYYcuZa8Fh" }, "outputs": [], "source": [ "from tf_gnns import make_graph_tuple_from_graph_list\n", "\n", "# Some graphs to compute with:\n", "adj_A = [(2,4),(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(4,2)]\n", "adj_B = [(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(8,9)]\n", "\n", "nw_state_size = 2\n", "\n", "nodes_A = [Node(np.random.randn(1,nw_state_size)) for _ in range(15)]\n", "edges_A = [Edge(np.random.randn(1,nw_state_size), node_from=nodes_A[e_ij[0]], node_to=nodes_A[e_ij[1]]) for e_ij in adj_A]\n", "graph_A = Graph(nodes=nodes_A, edges=edges_A)\n", "\n", "nodes_B = [Node(np.random.randn(1,nw_state_size)) for _ in range(10)]\n", "edges_B = [Edge(np.random.randn(1,nw_state_size), node_from=nodes_B[e_ij[0]], node_to=nodes_B[e_ij[1]]) for e_ij in adj_B]\n", "graph_B = Graph(nodes=nodes_B, edges=edges_B)\n", "gt = make_graph_tuple_from_graph_list([graph_A, graph_B])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0eU4-zIDbLoH", "outputId": "84623566-7d43-4289-aa4d-60574652e8f0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of nodes of graphtuple:\n", "25\n", "GraphTuple tensor shapes:\n", "(25, 2) (23, 2)\n", "Shapes of a single node from graph_A: [1,2]\n", "Number of nodes of graphs: A: 15, B: 10\n", "\n", "size of node attributes: [2]\n", "size of edge attributes: [2]\n" ] } ], "source": [ "print(\"number of nodes of graphtuple:\")\n", "print(gt.nodes.shape[0])\n", "\n", "print('GraphTuple tensor shapes:')\n", "print(gt.nodes.shape, gt.edges.shape)\n", "\n", "print(\"Shapes of a single node from graph_A: 
[%i,%i]\"%(graph_A.nodes[0].shape))\n", "\n", "\n", "print(\"Number of nodes of graphs: A: %i, B: %i\"%(len(graph_A.nodes), len(graph_B.nodes)))\n", "print(\"\")\n", "print(\"size of node attributes: [%i]\"%( graph_A.nodes[0].node_attr_tensor.shape[-1]))\n", "print(\"size of edge attributes: [%i]\"%( graph_A.edges[0].edge_tensor.shape[-1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "eaBvkXVkbz0D" }, "source": [ "# **Creating a Custom `GraphNet`**\n", "\n", "To create a GraphNet (without global variables), you need to define:\n", "\n", "* a node function\n", "* an edge function\n", "* an edge aggregation function (except when a graph-independent network is implemented)\n", "\n", "In addition, make sure that the input sizes of the node and edge functions are consistent with the input graph, and that the outputs of the edge aggregation function (if it exists) are consistent with the inputs expected by the node function (unless the network is graph-independent). Moreover, each of these functions may take different inputs derived from the input graph.\n", "\n", "For instance:\n", "\n", "* Edge functions may take as inputs:\n", "  1. the edge state,\n", "  2. the sender node state,\n", "  3. the receiver node state.\n", "\n", "* Node functions may take as inputs:\n", "  1. the node state,\n", "  2. an aggregated message from the edges that point to that node.\n", "\n", "These cases are distinguished internally by the names of the inputs of the provided functions. See the examples to learn how to implement your own functions if you need something special to happen in them. 
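As a rough illustration of this naming convention, here is a hand-written edge function built with the Keras functional API. It is a sketch, not the library's own code. The input names are taken from the model summaries printed later in this notebook, and we assume (without having checked the `tf_gnns` source) that `GraphNet` recognises its inputs by exactly these names. How such a model is passed to `GraphNet` is not shown, since the keyword names it expects are not documented here.

```python
import numpy as np
import tensorflow as tf

# Hypothetical hand-written edge function. The input names mirror those in the
# printed model summaries (edge_state, sender_node_state, receiver_node_state);
# matching by name inside GraphNet is an assumption, not verified API.
state_size = 10   # node/edge state size, as in the core block
hidden_size = 50  # width of the hidden layer

edge_state = tf.keras.Input(shape=(state_size,), name='edge_state')
sender_node_state = tf.keras.Input(shape=(state_size,), name='sender_node_state')
receiver_node_state = tf.keras.Input(shape=(state_size,), name='receiver_node_state')

# Concatenate the three inputs and pass them through a small MLP.
x = tf.keras.layers.Concatenate()([edge_state, sender_node_state, receiver_node_state])
x = tf.keras.layers.Dense(hidden_size, activation='relu')(x)
new_edge_state = tf.keras.layers.Dense(state_size)(x)

edge_fn = tf.keras.Model(
    inputs=[edge_state, sender_node_state, receiver_node_state],
    outputs=new_edge_state,
)

# Quick check on a batch of 4 edges: one array per named input.
batch = [np.zeros((4, state_size), dtype='float32') for _ in range(3)]
out = edge_fn(batch)
print(out.shape)
```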
The `tf_gnns` library is built with `tf.keras` in mind: it imposes conventions that exploit Keras constructs, such as the naming of input and output variables and the Keras functional API (`tf.keras.Model`), to make building `GraphNet`s easier.\n", "\n", "\n", "In contrast, DeepMind's `graph_nets` library uses `Sonnet` for the same purpose, and constrains how computation and model construction happen. For some use cases (such as implementing bootstrapping, or potential extensions to sparse graphs) this was found to be inconvenient, which is why this library was built.\n", "\n", "One can think of `tf_gnns` as a `graph_nets` library without the `tf1.x` and `Sonnet` \"baggage\".\n", "\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SNodjfRfcHJO", "outputId": "855ca824-812a-4be9-c8f4-9f81245e3a2b" }, "outputs": [], "source": [ "from tf_gnns import make_mlp_graphnet_functions, GraphNet, Node, Edge, Graph, GraphTuple, make_graph_tuple_from_graph_list\n", "## Naming:\n", "# ... _enc : parameter of encoder block GNN (graph indep.)\n", "# ... _core : parameter of core block GNN \n", "# ... 
_dec : parameter of decoder block GNN\n", "gnn_size = 50\n", "node_input_size_enc, node_output_size_enc = [2, 10]\n", "node_input_size_core, node_output_size_core = [10, 10]\n", "node_input_size_dec, node_output_size_dec = [10, 2]\n", "\n", "edge_input_size_enc = 2\n", "edge_output_size_dec = 2\n", "\n", "graph_fcn_enc = make_mlp_graphnet_functions(gnn_size,\n", " node_input_size = node_input_size_enc,\n", " node_output_size = node_output_size_enc,\n", " edge_input_size = edge_input_size_enc,\n", " graph_indep=True)\n", "\n", "graph_fcn_core = make_mlp_graphnet_functions(gnn_size,\n", " node_input_size = node_input_size_core, \n", " node_output_size = node_output_size_core, \n", " graph_indep=False)\n", "\n", "graph_fcn_dec = make_mlp_graphnet_functions(gnn_size,\n", " node_input_size = node_input_size_dec, \n", " node_output_size = node_output_size_dec, \n", " edge_output_size= edge_output_size_dec,\n", " graph_indep=True)\n", "\n", "\n", "graph_fcn_enc['graph_independent'] = True\n", "graph_fcn_dec['graph_independent'] = True\n", "\n", "gnn_dicts = [graph_fcn_enc, graph_fcn_core, graph_fcn_dec]\n", "\n", "# A full encode-core-decode set of GNNs. The core block may be evaluated multiple times.\n", "gnns = [GraphNet(**fcns) for fcns in gnn_dicts]\n", " " ] }, { "cell_type": "markdown", "metadata": { "id": "CDMVNfemgOsD" }, "source": [ "In the example above, `make_mlp_graphnet_functions` is a factory function: it returns a dictionary holding a set of functions (Keras models) that are in turn used to create GNNs.\n", "\n", "## Easy inspection of GN functions\n", "A `tf_gnns.GraphNet` can be pretty-printed in Jupyter for easier inspection. Click on the buttons to see the contents of each GN function:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AeUjRGf9j0c5", "outputId": "2b9c557f-95d2-4acb-a2f0-e6d85a58e878" }, "outputs": [ { "data": { "text/html": [ "
Model: \"node_fn/model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"node_state (InputLayer) [(None, 2)] 0 \n",
"_________________________________________________________________\n",
"node_fn/dense_0 (Dense) (None, 50) 100 \n",
"_________________________________________________________________\n",
"node_fn/dense_1 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"node_fn/dense_2 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"node_fn/dense_3 (Dense) (None, 10) 510 \n",
"=================================================================\n",
"Total params: 5,710\n",
"Trainable params: 5,710\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Model: \"edge_fn/model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"edge_state (InputLayer) [(None, 2)] 0 \n",
"_________________________________________________________________\n",
"edge_fn/dense_0 (Dense) (None, 50) 100 \n",
"_________________________________________________________________\n",
"edge_fn/dense_1 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"edge_fn/dense_2 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"edge_fn/dense_3 (Dense) (None, 10) 510 \n",
"=================================================================\n",
"Total params: 5,710\n",
"Trainable params: 5,710\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Model: \"node_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"node_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"edge_state_agg (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate_1 (Concatenate) (None, 20) 0 node_state[0][0] \n",
" edge_state_agg[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_0 (Dense) (None, 50) 1000 concatenate_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_1 (Dense) (None, 50) 2550 node_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_2 (Dense) (None, 50) 2550 node_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_3 (Dense) (None, 10) 510 node_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 6,610\n",
"Trainable params: 6,610\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"edge_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"edge_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"sender_node_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"receiver_node_state (InputLayer [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate (Concatenate) (None, 30) 0 edge_state[0][0] \n",
" sender_node_state[0][0] \n",
" receiver_node_state[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_0 (Dense) (None, 50) 1500 concatenate[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_1 (Dense) (None, 50) 2550 edge_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_2 (Dense) (None, 50) 2550 edge_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_3 (Dense) (None, 10) 510 edge_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 7,110\n",
"Trainable params: 7,110\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"node_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"node_state (InputLayer) [(None, 2)] 0 \n",
"__________________________________________________________________________________________________\n",
"edge_state_agg (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate_3 (Concatenate) (None, 12) 0 node_state[0][0] \n",
" edge_state_agg[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_0 (Dense) (None, 50) 600 concatenate_3[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_1 (Dense) (None, 50) 2550 node_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_2 (Dense) (None, 50) 2550 node_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_3 (Dense) (None, 10) 510 node_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 6,210\n",
"Trainable params: 6,210\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"edge_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"edge_state (InputLayer) [(None, 2)] 0 \n",
"__________________________________________________________________________________________________\n",
"sender_node_state (InputLayer) [(None, 2)] 0 \n",
"__________________________________________________________________________________________________\n",
"receiver_node_state (InputLayer [(None, 2)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate_2 (Concatenate) (None, 6) 0 edge_state[0][0] \n",
" sender_node_state[0][0] \n",
" receiver_node_state[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_0 (Dense) (None, 50) 300 concatenate_2[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_1 (Dense) (None, 50) 2550 edge_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_2 (Dense) (None, 50) 2550 edge_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_3 (Dense) (None, 10) 510 edge_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 5,910\n",
"Trainable params: 5,910\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"global_mlp/glob_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"node_state_agg (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"edge_state_agg (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate_4 (Concatenate) (None, 20) 0 node_state_agg[0][0] \n",
" edge_state_agg[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_0 (Den (None, 50) 1000 concatenate_4[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_1 (Den (None, 50) 2550 global_mlp/glob_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_2 (Den (None, 50) 2550 global_mlp/glob_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_3 (Den (None, 10) 510 global_mlp/glob_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 6,610\n",
"Trainable params: 6,610\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"node_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"node_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"edge_state_agg (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"global_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate_6 (Concatenate) (None, 30) 0 node_state[0][0] \n",
" edge_state_agg[0][0] \n",
" global_state[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_0 (Dense) (None, 50) 1500 concatenate_6[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_1 (Dense) (None, 50) 2550 node_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_2 (Dense) (None, 50) 2550 node_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"node_fn/dense_3 (Dense) (None, 10) 510 node_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 7,110\n",
"Trainable params: 7,110\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"edge_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"edge_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"sender_node_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"receiver_node_state (InputLayer [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"global_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate_5 (Concatenate) (None, 40) 0 edge_state[0][0] \n",
" sender_node_state[0][0] \n",
" receiver_node_state[0][0] \n",
" global_state[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_0 (Dense) (None, 50) 2000 concatenate_5[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_1 (Dense) (None, 50) 2550 edge_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_2 (Dense) (None, 50) 2550 edge_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"edge_fn/dense_3 (Dense) (None, 10) 510 edge_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 7,610\n",
"Trainable params: 7,610\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"global_mlp/glob_fn/model\"\n",
"__________________________________________________________________________________________________\n",
"Layer (type) Output Shape Param # Connected to \n",
"==================================================================================================\n",
"global_state (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"node_state_agg (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"edge_state_agg (InputLayer) [(None, 10)] 0 \n",
"__________________________________________________________________________________________________\n",
"concatenate_7 (Concatenate) (None, 30) 0 global_state[0][0] \n",
" node_state_agg[0][0] \n",
" edge_state_agg[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_0 (Den (None, 50) 1500 concatenate_7[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_1 (Den (None, 50) 2550 global_mlp/glob_fn/dense_0[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_2 (Den (None, 50) 2550 global_mlp/glob_fn/dense_1[0][0] \n",
"__________________________________________________________________________________________________\n",
"global_mlp/glob_fn/dense_3 (Den (None, 10) 510 global_mlp/glob_fn/dense_2[0][0] \n",
"==================================================================================================\n",
"Total params: 7,110\n",
"Trainable params: 7,110\n",
"Non-trainable params: 0\n",
"__________________________________________________________________________________________________\n",
"Model: \"node_fn/model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"node_state (InputLayer) [(None, 10)] 0 \n",
"_________________________________________________________________\n",
"node_fn/dense_0 (Dense) (None, 50) 500 \n",
"_________________________________________________________________\n",
"node_fn/dense_1 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"node_fn/dense_2 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"node_fn/dense_3 (Dense) (None, 2) 102 \n",
"=================================================================\n",
"Total params: 5,702\n",
"Trainable params: 5,702\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"Model: \"edge_fn/model\"\n",
"_________________________________________________________________\n",
"Layer (type) Output Shape Param # \n",
"=================================================================\n",
"edge_state (InputLayer) [(None, 10)] 0 \n",
"_________________________________________________________________\n",
"edge_fn/dense_0 (Dense) (None, 50) 500 \n",
"_________________________________________________________________\n",
"edge_fn/dense_1 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"edge_fn/dense_2 (Dense) (None, 50) 2550 \n",
"_________________________________________________________________\n",
"edge_fn/dense_3 (Dense) (None, 2) 102 \n",
"=================================================================\n",
"Total params: 5,702\n",
"Trainable params: 5,702\n",
"Non-trainable params: 0\n",
"_________________________________________________________________\n",
"