{ "cells": [ { "cell_type": "markdown", "metadata": { "id": "vln3LX7iYdts" }, "source": [ "# `tf_gnns`: A Hackable `graph_nets` library\n", "Largely inspired by the `graph_nets` project.\n", "\n", "First, install the library in the Colab notebook and import it:" ] }, { "cell_type": "code", "execution_count": null, "metadata": { "id": "8OZrOK7oJ9XW" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Requirement already satisfied: tf_gnns==0.1.2b in /home/charilaos/.local/lib/python3.8/site-packages (0.1.2b0)\r\n" ] } ], "source": [ "# Install the tf_gnns library with pip:\n", "!pip install tf_gnns==0.1.7\n", "import tf_gnns" ] }, { "cell_type": "markdown", "metadata": { "id": "38YcdOq2KIb6" }, "source": [ "# **GraphNet basics**\n", "GraphNets (GNs) operate on attributed multigraphs. The examples below show how such graphs can be defined.\n", "\n", "As detailed later in the notebook, a `GraphTuple` (or a dictionary corresponding to a `GraphTuple`) is the recommended data structure for good performance.\n", "\n", "## Creating a single `Graph`\n", "To define a `Graph`, you need to specify a set of nodes and a set of edges. *(Note that this data structure is not very efficient; it is, however, more intuitive to work with.)*\n", "\n" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "id": "wQfUezXNKOUi" }, "outputs": [], "source": [ "# Create the nodes and edges:\n", "import numpy as np\n", "from tf_gnns import Graph, Edge, Node\n", "\n", "nw_state_size = 10 # size of the node and edge attribute vectors. 
They can be different if necessary.\n", "\n", "# Connectivity of the graph as (sender, receiver) node-index pairs:\n", "adj_A = [(2,4),(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(8,7)]\n", "\n", "nodes_A = [Node(np.random.randn(1,nw_state_size)) for _ in range(10)]\n", "edges_A = [Edge(np.random.randn(1,nw_state_size), node_from=nodes_A[e_ij[0]], node_to=nodes_A[e_ij[1]]) for e_ij in adj_A]\n", "graph_A = Graph(nodes=nodes_A, edges=edges_A)" ] }, { "cell_type": "markdown", "metadata": { "id": "3Zt-TYiiaG1f" }, "source": [ "For more expressive code, some operators are overloaded. When in doubt about what an operator does, inspect its docstring or the source code directly:" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "Kt_Kx5XMaYCO", "outputId": "1ea575ba-10a6-45c2-c218-6f923f1ffe40" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Help on method __add__ in module tf_gnns.datastructures:\n", "\n", "__add__(graph) method of tf_gnns.datastructures.Graph instance\n", " This should only work with graphs that have compatible node and edge features\n", " Assumed also that the two graphs have the same connectivity (otherwise this will fail ugly)\n", "\n" ] } ], "source": [ "help(graph_A.__add__)" ] }, { "cell_type": "markdown", "metadata": { "id": "fiisOCW1KsJa" }, "source": [ "Individual `Graph` objects are not convenient for batched computation and should generally be avoided, although they can be a \"safer\" data structure when experimenting with new algorithms.\n", "\n", "For most use cases a `GraphTuple` should be used. 
`GraphTuple`s are the only data structure supported by the `graph_nets` library." ] }, { "cell_type": "markdown", "metadata": { "id": "U08TIII0aiOF" }, "source": [ "# The `GraphTuple` data structure\n", "A `GraphTuple` packs a set of graphs into a single object, which makes it easier to parallelize the computation of a `GraphNet` block. It is the default (and only) data structure in DeepMind's `graph_nets` library.\n", "You can create a `GraphTuple` either directly from tensors or NumPy arrays, or from a list of `Graph` objects (as above); utilities are provided for both. In general it is better to avoid `Graph`s altogether, but they are useful for testing.\n", "\n", "## Note on performance\n", "For good performance, it is important to wrap the training iteration code (dataset sampling, forward pass, gradient computation) in a `tf.function`-decorated block. This restricts the data structures that can be passed as arguments to the wrapped function: only `tf.Tensor`s, built-in Python types, or collections of `tf.Tensor`s are allowed (`GraphTuple`s are not). 
A workaround for using `tf_gnns` with `tf.function` is to convert the `GraphTuple` to a dictionary.\n", "\n", "## Creation of a `GraphTuple`\n", "\n", "Here is how to create a `GraphTuple` from a list of `Graph` objects:" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "id": "tNZYYcuZa8Fh" }, "outputs": [], "source": [ "from tf_gnns import make_graph_tuple_from_graph_list\n", "\n", "# Two graphs to compute with, given as (sender, receiver) node-index pairs:\n", "adj_A = [(2,4),(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(4,2)]\n", "adj_B = [(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(8,9)]\n", "\n", "nw_state_size = 2\n", "\n", "nodes_A = [Node(np.random.randn(1,nw_state_size)) for _ in range(15)]\n", "edges_A = [Edge(np.random.randn(1,nw_state_size), node_from=nodes_A[e_ij[0]], node_to=nodes_A[e_ij[1]]) for e_ij in adj_A]\n", "graph_A = Graph(nodes=nodes_A, edges=edges_A)\n", "\n", "nodes_B = [Node(np.random.randn(1,nw_state_size)) for _ in range(10)]\n", "edges_B = [Edge(np.random.randn(1,nw_state_size), node_from=nodes_B[e_ij[0]], node_to=nodes_B[e_ij[1]]) for e_ij in adj_B]\n", "graph_B = Graph(nodes=nodes_B, edges=edges_B)\n", "gt = make_graph_tuple_from_graph_list([graph_A, graph_B])" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "0eU4-zIDbLoH", "outputId": "84623566-7d43-4289-aa4d-60574652e8f0" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "number of nodes of graphtuple:\n", "25\n", "GraphTuple tensor shapes:\n", "(25, 2) (23, 2)\n", "Shapes of a single node from graph_A: [1,2]\n", "Number of nodes of graphs: A: 15, B: 10\n", "\n", "size of node attributes: [2]\n", "size of edge attributes: [2]\n" ] } ], "source": [ "print(\"number of nodes of graphtuple:\")\n", "print(gt.nodes.shape[0])\n", "\n", "print('GraphTuple tensor shapes:')\n", "print(gt.nodes.shape, gt.edges.shape)\n", "\n", "print(\"Shapes of a single node from graph_A: 
[%i,%i]\"%(graph_A.nodes[0].shape))\n", "\n", "\n", "print(\"Number of nodes of graphs: A: %i, B: %i\"%(len(graph_A.nodes), len(graph_B.nodes)))\n", "print(\"\")\n", "print(\"size of node attributes: [%i]\"%( graph_A.nodes[0].node_attr_tensor.shape[-1]))\n", "print(\"size of edge attributes: [%i]\"%( graph_A.edges[0].edge_tensor.shape[-1]))" ] }, { "cell_type": "markdown", "metadata": { "id": "eaBvkXVkbz0D" }, "source": [ "# **Creating a Custom `GraphNet`**\n", "\n", "To create a GraphNet (without global variables), you need to define:\n", "\n", "* a node function\n", "* an edge function\n", "* an edge aggregation function (unless a graph-independent network is implemented)\n", "\n", "In addition, the input sizes of the node and edge functions must be consistent with the input graph, and the outputs of the edge aggregation function (if it exists) must be consistent with the inputs expected by the node function (unless the network is graph independent). Moreover, each of these functions can take different inputs from the input graph.\n", "\n", "For instance:\n", "\n", "* Edge functions may take as inputs:\n", " 1) the edge state,\n", " 2) the sender node state,\n", " 3) the receiver node state.\n", "\n", "* Node functions may take as inputs:\n", " 1) the node state,\n", " 2) an aggregated message from the edges that point to that node.\n", "\n", "These cases are identified internally by the names of the inputs of the provided functions. See the examples to learn how to implement your own functions if you need something special to happen in them. 
The `tf_gnns` library is built with `tf.keras` in mind. It imposes a few conventions that exploit Keras constructs, such as the naming of input and output variables and the Keras functional API (`tf.keras.Model`), to make building `GraphNets` easier.\n", "\n", "In contrast, DeepMind's `graph_nets` library relies on `Sonnet` for the same purpose and constrains how computation and model construction happen. For some use cases (such as implementing bootstrapping, or potential extensions to sparse graphs) this was found to be inconvenient, which is why this library was built.\n", "\n", "One can think of `tf_gnns` as a `graph_nets` library without the `tf1.x` and `Sonnet` \"baggage\".\n", "\n" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "SNodjfRfcHJO", "outputId": "855ca824-812a-4be9-c8f4-9f81245e3a2b" }, "outputs": [], "source": [ "from tf_gnns import make_mlp_graphnet_functions, GraphNet, Node, Edge, Graph, GraphTuple, make_graph_tuple_from_graph_list\n", "## Naming:\n", "# ... _enc : parameter of encoder block GNN (graph indep.)\n", "# ... _core : parameter of core block GNN\n", "# ... 
_dec : parameter of decoder block GNN\n", "gnn_size = 50 # width of the hidden layers of each MLP\n", "node_input_size_enc, node_output_size_enc = [2, 10]\n", "node_input_size_core, node_output_size_core = [10, 10]\n", "node_input_size_dec, node_output_size_dec = [10, 2]\n", "\n", "edge_input_size_enc = 2\n", "edge_output_size_dec = 2\n", "\n", "graph_fcn_enc = make_mlp_graphnet_functions(gnn_size,\n", " node_input_size = node_input_size_enc,\n", " node_output_size = node_output_size_enc,\n", " edge_input_size = edge_input_size_enc,\n", " graph_indep=True)\n", "\n", "graph_fcn_core = make_mlp_graphnet_functions(gnn_size,\n", " node_input_size = node_input_size_core,\n", " node_output_size = node_output_size_core,\n", " graph_indep=False)\n", "\n", "graph_fcn_dec = make_mlp_graphnet_functions(gnn_size,\n", " node_input_size = node_input_size_dec,\n", " node_output_size = node_output_size_dec,\n", " edge_output_size = edge_output_size_dec,\n", " graph_indep=True)\n", "\n", "graph_fcn_enc['graph_independent'] = True\n", "graph_fcn_dec['graph_independent'] = True\n", "\n", "gnn_dicts = [graph_fcn_enc, graph_fcn_core, graph_fcn_dec]\n", "\n", "# A full encode-core-decode set of GNNs. The core block may be evaluated multiple times.\n", "gnns = [GraphNet(**fcns) for fcns in gnn_dicts]" ] }, { "cell_type": "markdown", "metadata": { "id": "CDMVNfemgOsD" }, "source": [ "In the example above, `make_mlp_graphnet_functions` is a factory function: it returns a set of functions (Keras models) that are then used to create the GNNs.\n", "\n", "## Easy inspection of GN functions\n", "A `tf_gnns.GraphNet` can be pretty-printed in Jupyter for easier inspection. Click on the buttons to see the contents of each GN function:" ] }, { "cell_type": "code", "execution_count": 14, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "AeUjRGf9j0c5", "outputId": "2b9c557f-95d2-4acb-a2f0-e6d85a58e878" }, "outputs": [ { "data": { "text/html": [ "

Graph Indep. GNN function (@140643679157648)

\n", "
\n", "
Model: \"node_fn/model\"\n",
       "_________________________________________________________________\n",
       "Layer (type)                 Output Shape              Param #   \n",
       "=================================================================\n",
       "node_state (InputLayer)      [(None, 2)]               0         \n",
       "_________________________________________________________________\n",
       "node_fn/dense_0 (Dense)      (None, 50)                100       \n",
       "_________________________________________________________________\n",
       "node_fn/dense_1 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "node_fn/dense_2 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "node_fn/dense_3 (Dense)      (None, 10)                510       \n",
       "=================================================================\n",
       "Total params: 5,710\n",
       "Trainable params: 5,710\n",
       "Non-trainable params: 0\n",
       "_________________________________________________________________
\n", "
\n", "\n", "
\n", "
Model: \"edge_fn/model\"\n",
       "_________________________________________________________________\n",
       "Layer (type)                 Output Shape              Param #   \n",
       "=================================================================\n",
       "edge_state (InputLayer)      [(None, 2)]               0         \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_0 (Dense)      (None, 50)                100       \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_1 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_2 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_3 (Dense)      (None, 10)                510       \n",
       "=================================================================\n",
       "Total params: 5,710\n",
       "Trainable params: 5,710\n",
       "Non-trainable params: 0\n",
       "_________________________________________________________________
\n", "
\n", "" ], "text/plain": [ "" ] }, "execution_count": 14, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gnns[0]" ] }, { "cell_type": "code", "execution_count": 15, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

GNN function (@140643679157792)

\n", "
\n", "
Model: \"node_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "node_state (InputLayer)         [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "edge_state_agg (InputLayer)     [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate_1 (Concatenate)     (None, 20)           0           node_state[0][0]                 \n",
       "                                                                 edge_state_agg[0][0]             \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_0 (Dense)         (None, 50)           1000        concatenate_1[0][0]              \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_1 (Dense)         (None, 50)           2550        node_fn/dense_0[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_2 (Dense)         (None, 50)           2550        node_fn/dense_1[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_3 (Dense)         (None, 10)           510         node_fn/dense_2[0][0]            \n",
       "==================================================================================================\n",
       "Total params: 6,610\n",
       "Trainable params: 6,610\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "\n", "
\n", "
Model: \"edge_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "edge_state (InputLayer)         [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "sender_node_state (InputLayer)  [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "receiver_node_state (InputLayer [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate (Concatenate)       (None, 30)           0           edge_state[0][0]                 \n",
       "                                                                 sender_node_state[0][0]          \n",
       "                                                                 receiver_node_state[0][0]        \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_0 (Dense)         (None, 50)           1500        concatenate[0][0]                \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_1 (Dense)         (None, 50)           2550        edge_fn/dense_0[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_2 (Dense)         (None, 50)           2550        edge_fn/dense_1[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_3 (Dense)         (None, 10)           510         edge_fn/dense_2[0][0]            \n",
       "==================================================================================================\n",
       "Total params: 7,110\n",
       "Trainable params: 7,110\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "" ], "text/plain": [ "" ] }, "execution_count": 15, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gnns[1]" ] }, { "cell_type": "markdown", "metadata": { "id": "SXWq1OZFj2Ul" }, "source": [ "Each of the GraphNets can easily be saved using the `.save` method. You can also invoke the summary method of any of these GNNs to inspect the inputs and outputs of the GN functions." ] }, { "cell_type": "markdown", "metadata": { "id": "h9nNnRAzkcRh" }, "source": [ "The overloaded `__add__()` operator shown above comes in handy when computing residual connections:" ] }, { "cell_type": "code", "execution_count": 16, "metadata": { "id": "nzBrWl9ZmMQ_" }, "outputs": [], "source": [ "CORE_STEPS = 2\n", "def eval_full_shared_resid_core(G: GraphTuple, core_steps=CORE_STEPS):\n", "    ## The actual computation\n", "    G = gnns[0].graph_tuple_eval(G) # The \"encode\" GraphNet block\n", "\n", "    for _ in range(core_steps):\n", "        # gnns[1] is the (shared) \"core\" GraphNet block\n", "        G += gnns[1].graph_tuple_eval(G) # Overloaded sum operator: assign-add implements the residual connection (no projection needed, since the sizes match)\n", "\n", "    G = gnns[-1].graph_tuple_eval(G) # The \"decode\" GraphNet block\n", "\n", "    return G" ] }, { "cell_type": "code", "execution_count": 17, "metadata": { "id": "Hj3JArLQmRHa" }, "outputs": [], "source": [ "gt_eval = eval_full_shared_resid_core(gt.copy())" ] }, { "cell_type": "code", "execution_count": 18, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "t5tZrJUbnyJV", "outputId": "50004ca9-2dc4-4171-e96a-8a030e6f908f" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "Output GraphTuple:\n", " nodes shape: [25,2]\n", " edges shape: [23,2]\n" ] } ], "source": [ "print(\"Output GraphTuple:\")\n", "edges_shape, nodes_shape = gt_eval.edges.shape.as_list(), gt_eval.nodes.shape.as_list()\n", "print(\" nodes shape: [%i,%i]\"%(nodes_shape[0], nodes_shape[1]))\n", "print(\" edges shape: [%i,%i]\"%(edges_shape[0], edges_shape[1]))\n" ] }, { 
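"cell_type": "markdown", "metadata": {}, "source": [ "The shapes printed above ([25,2] nodes and [23,2] edges) are simply the node and edge counts of `graph_A` (15 nodes, 12 edges) and `graph_B` (10 nodes, 11 edges) stacked along the first axis. The NumPy sketch below illustrates this packing idea, together with a sum aggregation of edge states into their receiver nodes. It is a hypothetical, simplified illustration: `pack_graphs`, `aggregate_edges_to_nodes`, `random_graph` and the dictionary keys are not part of the `tf_gnns` API. Each graph is assumed to be given as node features, edge features and (sender, receiver) index lists, and this is not necessarily how `make_graph_tuple_from_graph_list` is implemented." ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [
"import numpy as np\n",
"\n",
"def pack_graphs(graphs):\n",
"    # Hypothetical helper. Each graph is a tuple\n",
"    # (node_feats [N_i, F], edge_feats [E_i, F], senders [E_i], receivers [E_i]).\n",
"    nodes, edges, senders, receivers, n_nodes, n_edges = [], [], [], [], [], []\n",
"    offset = 0  # number of nodes packed so far\n",
"    for node_feats, edge_feats, snd, rcv in graphs:\n",
"        nodes.append(node_feats)\n",
"        edges.append(edge_feats)\n",
"        # Shift node indices so they point into the concatenated node array.\n",
"        senders.append(np.asarray(snd) + offset)\n",
"        receivers.append(np.asarray(rcv) + offset)\n",
"        n_nodes.append(node_feats.shape[0])\n",
"        n_edges.append(edge_feats.shape[0])\n",
"        offset += node_feats.shape[0]\n",
"    return {'nodes': np.concatenate(nodes), 'edges': np.concatenate(edges),\n",
"            'senders': np.concatenate(senders), 'receivers': np.concatenate(receivers),\n",
"            'n_nodes': np.array(n_nodes), 'n_edges': np.array(n_edges)}\n",
"\n",
"def aggregate_edges_to_nodes(packed):\n",
"    # Sum the states of all edges pointing to each node (a segment sum over receivers).\n",
"    agg = np.zeros((packed['nodes'].shape[0], packed['edges'].shape[1]))\n",
"    np.add.at(agg, packed['receivers'], packed['edges'])\n",
"    return agg\n",
"\n",
"def random_graph(adj, n_nodes, state_size=2):\n",
"    # Random node/edge states for a graph given as (sender, receiver) pairs.\n",
"    snd, rcv = zip(*adj)\n",
"    return np.random.randn(n_nodes, state_size), np.random.randn(len(adj), state_size), snd, rcv\n",
"\n",
"# Same connectivity as graph_A and graph_B above:\n",
"adj_A = [(2,4),(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(4,2)]\n",
"adj_B = [(3,4),(2,2),(2,5),(5,1),(1,2),(2,3),(3,4),(4,5),(6,5),(7,6),(8,9)]\n",
"packed = pack_graphs([random_graph(adj_A, 15), random_graph(adj_B, 10)])\n",
"print(packed['nodes'].shape, packed['edges'].shape)  # (25, 2) (23, 2)\n",
"print(packed['n_nodes'], packed['n_edges'])  # [15 10] [12 11]\n",
"print(aggregate_edges_to_nodes(packed).shape)  # (25, 2)"
] }, { 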
"cell_type": "markdown", "metadata": { "id": "KnIiL_-bo8_Y" }, "source": [ "## `Global` blocks\n", "\n", "The following cell shows how to build a global block manually." ] }, { "cell_type": "code", "execution_count": 19, "metadata": { "id": "E2XSCY_5o_MR" }, "outputs": [], "source": [ "from tf_gnns.graphnet_utils import make_keras_simple_agg\n", "from tensorflow.keras.layers import Input, Dense\n", "from tensorflow.keras import Model\n", "\n", "GN_STATE = 2\n", "GLOB_STATE_OUT = 1\n", "def make_graph_tuple_to_global(insize=GN_STATE, output_size=GLOB_STATE_OUT, agg_type='mean'):\n", "\n", "    agg_fcn = make_keras_simple_agg(insize, agg_type)\n", "    agg_fcn = agg_fcn[1] # The graph tuple version of the aggregator\n", "\n", "    # Construct the (aggregated node + edge) -> global function.\n", "    xx = Input(shape=(insize,))\n", "    out = Dense(insize, 'tanh')(xx)\n", "    out = Dense(insize, 'tanh')(out)\n", "    out = Dense(output_size, activation=None, use_bias=False)(out)\n", "    global_fcn = Model(inputs=xx, outputs=out)\n", "\n", "    def fcn_node_and_edge(gt_):\n", "        # Aggregate the node and edge info.\n", "\n", "        # First, build the segment IDs (the index of the graph that each node/edge belongs to):\n", "        graph_indices_nodes = []\n", "        for k_, k in enumerate(gt_.n_nodes):\n", "            graph_indices_nodes.extend(np.ones(k).astype(\"int\")*k_)\n", "\n", "        graph_indices_edges = []\n", "        for k_, k in enumerate(gt_.n_edges):\n", "            graph_indices_edges.extend(np.ones(k).astype(\"int\")*k_)\n", "\n", "        o1 = agg_fcn(gt_.nodes, graph_indices_nodes, gt_.n_graphs) # node-to-global aggregation\n", "        o2 = agg_fcn(gt_.edges, graph_indices_edges, gt_.n_graphs) # edge-to-global aggregation\n", "        return global_fcn(o1 + o2) # the two aggregates are added (concatenation is an alternative, with a wider input layer)\n", "\n", "    return fcn_node_and_edge, global_fcn # the global fcn is also returned so that its weights can be retrieved later" ] }, { "cell_type": "code", "execution_count": 20, "metadata": 
{ "colab": { "base_uri": "https://localhost:8080/" }, "id": "dquuy1kfs7s_", "outputId": "fbe3a890-c5df-496b-ae3b-3872a9a36242" }, "outputs": [], "source": [ "# Create the GraphTuple -> global function. It is evaluated below on gt_eval,\n", "# which was computed above with: gt_eval = eval_full_shared_resid_core(gt.copy())\n", "graph_tuple_to_global, global_fcn = make_graph_tuple_to_global(insize=GN_STATE, agg_type=\"mean\")" ] }, { "cell_type": "code", "execution_count": 21, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "ihf0xZ8ktaOT", "outputId": "5929fd19-c1f6-49c5-8baa-863ff826bfe2" }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "2 graphs, 1 dimensional output:\n" ] }, { "data": { "text/plain": [ "" ] }, "execution_count": 21, "metadata": {}, "output_type": "execute_result" } ], "source": [ "res = graph_tuple_to_global(gt_eval)\n", "print(\"%i graphs, %i dimensional output:\"%(gt.n_graphs, GLOB_STATE_OUT))\n", "res" ] }, { "cell_type": "markdown", "metadata": { "id": "7V3rMYVfuGYr" }, "source": [ "## All weights of a GN\n", "Gathering the weights (trainable parameters) of `GraphNet` blocks in `tf_gnns` is easy. Here is how to collect all the parameters of the computation above:" ] }, { "cell_type": "code", "execution_count": 22, "metadata": { "id": "QwIBwzbZuzku" }, "outputs": [], "source": [ "all_weights = []\n", "for g in gnns: # the encode/core/decode blocks\n", "    all_weights.extend(g.weights)\n", "\n", "all_weights.extend(global_fcn.weights) # the final \"graph to global\" layer\n" ] }, { "cell_type": "markdown", "metadata": { "id": "S9xwRC0Wu4mS" }, "source": [ "# Taking gradients of the whole operation\n", "The whole operation above is differentiable. 
Here is how to take gradients of it with `tf.GradientTape()`:" ] }, { "cell_type": "code", "execution_count": 23, "metadata": { "id": "P5rm4kfa1VIb" }, "outputs": [], "source": [ "import tensorflow as tf\n", "opt = tf.keras.optimizers.Adam(learning_rate=0.001)" ] }, { "cell_type": "code", "execution_count": 24, "metadata": { "id": "g-lKQHJwybRb" }, "outputs": [], "source": [ "# Just some random targets to work with:\n", "Y = np.random.randn(gt.n_graphs, GLOB_STATE_OUT)\n", "\n", "def full_eval_with_global(G):\n", "    G = eval_full_shared_resid_core(G)\n", "    return graph_tuple_to_global(G)\n", "\n", "def loss(Ypred, Yactual):\n", "    return tf.reduce_mean(tf.square(Ypred - Yactual))\n", "\n", "# Computing the gradients explicitly also gives direct access to them,\n", "# which may be useful in some contexts:\n", "with tf.GradientTape() as tape:\n", "    Yhat = full_eval_with_global(gt.copy())\n", "    loss_current = loss(Yhat, Y)\n", "\n", "# Compute and apply the gradients outside the tape context, so they are not recorded:\n", "grads = tape.gradient(loss_current, all_weights)\n", "opt.apply_gradients(zip(grads, all_weights))\n" ] }, { "cell_type": "code", "execution_count": 25, "metadata": { "colab": { "base_uri": "https://localhost:8080/" }, "id": "I_1K5jH2D3m3", "outputId": "9f67dea7-fb7f-4c4c-8e55-09e344b75676" }, "outputs": [ { "data": { "text/plain": [ "array([ 0, 14, 28, 42, 47])" ] }, "execution_count": 25, "metadata": {}, "output_type": "execute_result" } ], "source": [ "nweights_per_block = [len(g.weights) for g in gnns]\n", "nweights_per_block.append(len(global_fcn.weights))\n", "cumsum_nweights_per_block = np.cumsum([0, *nweights_per_block])\n", "cumsum_nweights_per_block" ] }, { "cell_type": "markdown", "metadata": { "id": "j80LRib_0mVY" }, "source": [ "A plot of the gradient histograms from the different blocks:" ] }, { "cell_type": "code", "execution_count": 27, "metadata": { "colab": { "base_uri": "https://localhost:8080/", "height": 267 }, "id": "_EhF0SfF4kSX", "outputId": "cc86efeb-7101-49ef-ac08-60f02d73a497" }, "outputs": [ { "name": "stderr", 
"output_type": "stream", "text": [ ":8: MatplotlibDeprecationWarning: Adding an axes using the same arguments as a previous axes currently reuses the earlier instance. In a future version, a new instance will always be created and returned. Meanwhile, this warning can be suppressed, and the future behavior ensured, by passing a unique label to each axes instance.\n", " plt.subplot(1,4,1+k)\n" ] }, { "data": { "image/png": "iVBORw0KGgoAAAANSUhEUgAAAlwAAADGCAYAAAAHfl68AAAAOXRFWHRTb2Z0d2FyZQBNYXRwbG90bGliIHZlcnNpb24zLjMuMiwgaHR0cHM6Ly9tYXRwbG90bGliLm9yZy8vihELAAAACXBIWXMAAAsTAAALEwEAmpwYAAArOUlEQVR4nO3de7xVdZ3/8ddbUDEvKSAnBJMs8qeRY8pPazI6pZU5DVreYCwlaNDSKe0yYoXXbKyhssZMdCTtImhZI83PNI1OZuU9uoCZhJoigoAiqIHo5/fH+h5YZ7P3ue591t77vJ+Px4G9rvuz1v6ctT9nre9aX0UEZmZmZlY72xQdgJmZmVmzc8FlZmZmVmMuuMzMzMxqzAWXmZmZWY254DIzMzOrMRdcZmZmZjXmgqtGJD0i6fCi46gVSVMk3VF0HFafJLVKerzoOOqNpKslfaGf39PHogGiJ793fclFSedJ+l5vlq0XReSNC64m4i8564ykMZJC0uCiY7Hm5mOR9Uaz540LLusxf2FXTzPuy2bcJqtPzjXrjaLypikLLkl7SLpB0lOSHpb08dy08yRdL+k7ktZJWiRpfG76npJ+lJZdLenSNH4bSZ+X9KiklWn5V+aW+1CatlrS50ri2UbSDEl/TdOvlzS0Quytkh6X9Kn0PsslfTg3fXtJsyT9TdIKSZdL2kHSjsBPgT0krU8/e0h6QdLwtOznJG2StEsavlDSJen1K9M2PZW24/OStknTpkj6taSvSVoNnFcm7v+UdEd+nzS73uRK7izTNEl/Axak8VMlPSDpaUm3SNqrwnu2L39yyoFV+XzrItduT/8/k/LjLSnGg9KyJ6Z1vyENT5P0P+n19pIukfRE+rlE0vZpWnvOniXpSeDbZeL+uKTFkkb3ecc3EElvknR/OtZcBwwpmf4+SQslPSPpN5L2z03zscjHoq1IOlDS71JO/UDSdapwaVDSvpLaUn4tkjSxZJbhkm5N6/pl/rgj6euSHpP0rKT7JL2tm/E5bypouoIr7eCfAL8HRgGHAWdIek9utonAPGBXYD7QfiAbBPwv8CgwJi0/Ly0zJf28A9gb2Cm33H7At4APAXsAw4D8F8u/AUcDb0/Tnwa+2clmvAp4ZXr/acA3Je2Wpl0MvB44AHhdmueciHgOeC/wRETslH6eAO5J70v6/1HgrbnhX6bX/5Xec+80/iRg8y8JcAiwFGgBLmofmQ7gVwL7A++OiLWdbFfT6G2u5Lwd2Bd4j6SjgM8CHwB2B34FzO0ihEOBfcjy+xxJ+6bxneXahPT/rik/fkv2+bfmYlqamy+fH58D3kyWd/8AHAx8PhfPq4ChwF7A9Hygks5J++PtEdG0lwtKSdoO+B/gu2T75gfAMbnpbwLmAKeQHTNmA/PTF5KPRT4WbSXl1I+Bq8lyai7w/grzbkv2XfgzYATZZ/99SfvkZjsRuBAYDiwEvp+bdg/ZZzsUuBb4gaQOfzB0wnlTTkQ01U/aqX8rGXc28O30+jzgtty0/YAX0uu3AE8Bg8us9+fAx3LD+wAvAoOBc4B5uWk7AhuBw9PwA8Bhuekj2
5ct8z6twAv5acBKsi87Ac8Br81NewvwcG7Zx0vWdyHwjRTnk8AnyBJ+SHqfYcCgFO9+ueVOAdrS6yll9ukU4C7gOuAGYLuiP/t+zrPe5soYIIC9c9N/CkzLDW8DPA/sVWbd7cuPzo27G5jUVa7lls3n1jRgfm7Zj7TnMtmB7cD0+q/Akbnl3gM8ksu7jcCQkjxeBnwVuAN4ZdGfWQE5MgF4AlBu3G+AL6TX3wIuLFnmQbIvCx+LtiznY1HHnFpWklN35HJq834H3pb28za5eecC56XXV5fkyk7AS8CeFd77aeAf0uvzgO9VmM95U+GnGa9/70V2SvKZ3LhBZGcN2j2Ze/08METZNd09gUcjYlOZ9e5B9gXU7lGyBGhJ0x5rnxARz6XTlvmYfizp5dy4l9Kyy8q81+qSGJ4n+2XYHXgFcJ+k9mlK21fJL8m+9A4E/gjcClxFlvxLImK1pBZg2zLbNyo3/Bhbex3pbEdEbOwkhmbU21xpl9+fewFfl/SV3DiR7f/8evJKc3in3Loq5Vo5vwRmSRpJlkfXA+dKGkP21+LCTrZpj9zwUxHx95J170p2tuuEaMKzDd2wB7As0tE9ye/DvYCTJf1bbtx2abmX8LEov30+FmXK5VS5/dE+72MRkf+sK+7LiFgvaU37cpI+TfYH2R5kf6jtQnYmrDucN2U03SVFsp36cETsmvvZOSKO7Oayr1b5BnVPkB2s2r0a2ASsAJaTfQEDIOkVZFV3fr3vLYlpSESUO8B1ZhVZRf+G3HpeGRHtX7ZRZpnfkP0F/H7glxGxOMV+JFtOxa4i+yu3dPvy8ZVb9wNkp2x/WnKaeiDoba60Kz1gnlKSHztExG96GVelXNvqM4yIJWQHw38Dbo+IZ8mKuenAHbmDdblteqLC9rR7Gngf8G1Jby0zvdktB0Yp981Ctt/aPQZcVPJZvSIi5uJjkY9F5ZXLqT0rzPsEsGd7O6akdF/mc2UnssuHT6T2Wv8OHA/sFhG7AmvJiqO+GNB504wF193AOmUNeHeQNEjSOEn/t5vLLgculrSjpCG5L4q5wJmSXpMS84vAdamK/yHwPkmHpmvsF9Bx314OXNTeIFHS7qndTo+kL78rga9JGpHWNSrXPm0FMCzf6C8ingfuA05jS3L+Bji1fTgiXiI7s3GRpJ1TnJ8EunzOSvpy+Cxwm6TX9nSbGlhvc6Wcy4GztaWx+islHdfLuDrLtaeAl8naOOT9EjidLfnRVjLcvk2fT+sbTnbpqjv50UbWTuRHkg7uzQY1sN+SFUIfl7StpA+QtX1rdyVwqqRDlNlR0j9J2hkfi3wsKu+3ZGckT5c0OH12lX6v7iL7Y+rfU/61Av/MlraAAEfmcuVC4M6IeAzYmSx3nwIGp3aYu/Q1+IGeN01XcKUd/z6yBnkPk1W+/012eaQ7y/4z2WnGvwGPAyekyXPIGr/entb7d7KzAkTEIrJkuJbsIPl0Wrbd18ka5/9M0jrgTrK2Zr1xFrAEuFPSs8BtZNU/EfFnsoPxUmV3pbRf8vkl2enWu3PDO7PlrjXStjxH1qjwjrQtc7oTUERcQ3ZgX5AuRTW93uZKhXX9GPgSMC99pn8iazzaGxVzLR24LgJ+nfLjzWmZ0nwolx9fAO4F/kB2Wv/+NK5LEXErMBX4iaQDe7ldDSddovgAWVuRNWT58aPc9HuBfyVr8P402e/1lDTNxyIfi7aSy6lpwDPAB8lurthQYd5/JjuWrAIuA05Kn027a4FzyfLzoLQ+gFuAm4G/kF2a+zuVL1321IDNG3W8FGxmZmaNQtJdwOUR8e2iY7HONd0ZLjMzs2Yl6e2SXpUuKZ5M9jiDm4uOy7rWjHcpmpmZNat9yNor7Uh2+ezYiFhebEjWHb6kaGZmZlZjvqRoZmZmVmMuuHpJ0iOSDu/rvJLeJunB6kZnjcJ5ZH3lHLJqcB7VXlMWXJImSbpL0nPKOs+8S9LHJPX1oW1VFxG/ioguH7imrNPtLp8pY
tXjPLK+cg5ZNTiPmkPTFVySPkX2rJn/JOtAs4XsAWlvJes2o9wynXUr0PRU/mnWA5rzqOecRx05h3rOObQ151HP1W0elXau2Mg/ZA83fQ44pov5ribrOPamNP/hwD8BvwOeJXvA23kly3yI7AFwq4HPAY+QOoTtRlyPAJ8me2jkWrLOModEmc46yR4KtwxYR9aR7WHAEWQdc74IrAd+n+bdg+whhmvIHiT3r7n17ABcQ/bgwwfIuml4vCSms1JMG8juWJ1B1knxOmAx8P7c/FOAXwNfI3vg3lLgH9P4x8g6Jz256BxwHjmP6uHHOeQcch45j7bab0UnVJWT8wiy7ggGdzHf1SlJ3kp2lm9ISpI3puH9yboYODrNv19KignA9mQdaW7qYXLenZJpaEqWU0uTk+x238eAPdLwGFKv6pTpnZ3sKbuXpfgPIOuG4Z1p2sVkT+PdDRidkrA0OReS9aW1Qxp3XIpxG7KnWj8HjMwl5yayfqcGkT1l/G/AN9M+eXdK6p2KzgPnkfOo6B/nkHPIeeQ82mq/FZ1QVU7ODwJPloz7DVn1+gIwIZec3+liXZcAX0uvzwHm5abtSFad9yQ5P5gb/jLZk4FLk/N1ZFX14cC2JevokJwpqV4Cds6N+w/g6vR6KfCe3LSPlEnOqV3EvRA4KpecD+WmvZGsM9CW3LjVwAFF54HzyHlU9I9zyDnkPHIelf40Wxuu1cDw/PXbiPjHyHo6X03HNmsd+oVS1oHsLyQ9JWkt2TXy4WnyHvn5I+K5tL6eeDL3+nlgp9IZImIJcAZZIq6UNC/Xl1SpPYA1EbEuN+5RYFS5mCnfD1bpPjhJ0sLUh9UzwDi27API/kJq90KKuXTcVtvVgJxHzqO+cg45h6rBedREedRsBddvya7dHtWNeaNk+Fqya8d7RsQrgcuB9jtAlpNV3wBIegUwrM/Rlgsq4tqIOBTYK8X4pQrxPgEMlbRzbtyrya6Vt8c8OjdtT7a2eZ3Kele/EjgdGJZ+of/Eln0wkDiPnEd95RxyDlWD86iJ8qipCq6IeAY4H7hM0rGSdpa0jaQDyE6ZdmZnsur675IOBv4lN+2HwPskHSppO7Jexau+7yTtI+mdkrYn6539BeDlNHkFMEbSNgAR8RjZqeX/kDRE0v5kPci332Z7PXC2pN0kjSJLus7sSJasT6VYPkz218CA4zxyHvWVc8g5VA3Oo+bKo6YquAAi4svAJ8nuYFiRfmaT3b3wm04W/RhwgaR1ZNe3r8+tcxFwGtlfDMvJ7pJ4vH26pBMlLapC+NuTNQxcRXa6dgRwdpr2g/T/akn3p9eTyRohPgH8GDg3Im5L0y5IMT4M3Eb2C7ah0htHxGLgK2R/Ua0gu5796ypsU0NyHjmP+so55ByqBudR8+SR+1IcICR9FJgUEW8vOhZrXM4j6yvnkFVDI+ZR053hsoykkZLemk4/7wN8iuwvBrNucx5ZXzmHrBqaIY/q82msVg3bkZ12fg3ZLcTzyJ5vYtYTziPrK+eQVUPD55EvKZqZmZnVmC8pmpmZmdWYCy4zMzOzGqvrNlzDhw+PMWPG1GTdzz33HDvu2NVjTBpbo2zjfffdtyoidq/V+quRR42yL8tp1NhL4964cSMPP/wwmzZtAmD48OG0tLSwadMmfv/7379I9lTqR4DjI+JpSQK+DhxJ9iTsKRFxP4Ckk4HPp1V/ISKu6SqeWh6PStXDZzYQY6iXY9FA3Pf1HEdPYug0h6rVR1Atfg466KColV/84hc1W3e9aJRtBO6NOs+jRtmX5TRq7KVxP/HEE3HfffdFRMSzzz4bY8eOjUWLFsVnPvOZYEvfbTOAL6XXRwI/JXuy9JuBu9L4oWT9sg0l6wh3KbBbFHg8KlUPn9lAjKFejkUDcd9XUg9x9CSGznLIlxTNrCGMHDmSAw88EICdd96Zfffdl2XLlnHjjTfCln7grgGOTq+PIuvQNyLiTmBXSSOB9wC3RsSaiHgauBU4oh83x
cwGIBdcZtZwHnnkEX73u99xyCGHsGLFCoAX06QngZb0ehQdO7N9PI2rNN4MAEl7Kuv4ebGkRZI+UWYeSfqGpCWS/iDpwCJitcZR1224zMxKrV+/nmOOOYZLLrmEXXbZpcO0iAhJVXvWjaTpwHSAlpYW2traqrXqTq1fv77f3ssxlLUJ+FRE3J86U75P0q2RdRfT7r3A2PRzCPCt9L9ZWS64zKxhvPjiixxzzDGceOKJfOADHwCyQmjt2rXbQvY0amBlmn0ZsGdu8dFp3DKgtWR8W7n3i4grgCsAxo8fH62treVmq7q2tjb6670cw9YiYjlZH4NExDpJD5CdBc0XXJsvWQN3StpV0si0rNlWXHA1qYfeNgEuvKDoMKxos/aB8bOLjqIqIoJp06ax77778slPfnLz+IkTJzJr1qxhafBk4Mb0ej5wuqR5ZGce1kbEckm3AF+UtFua791s6VDXrANJY4A3AXeVTKp0abpDwdWbs6RPPbOqw3x/efovvH631/c49r6ohzOc9RJHtWJwwWVmDeHXv/413/3ud3njG9/IAQccAMAXv/hFZsyYwaxZs3aR9BDZoyGOT4vcRHan4hKyx0J8GCAi1ki6ELgnzXdBRKzpx02xBiFpJ+AG4IyIeLY36+jNWdLLbpjNca3Hbh6+4PoLWNC6oDdv32v1cIazXuKoVgwuuMysIRx66KFE5a7I/hIR4/Mj0qWe08rNHBFzgDnVjdCaiaRtyYqt70fEj8rMUumStVlZvkvRrIkdvObcokMwazjpoblXAQ9ExFcrzDYfOCndrfhm0iXrfgvSGo4LLusXU6dOZcSIEYwbN27zuPPOO49Ro0YB7CdpoaQj26dJOjvdbv2gpPfkxh+Rxi2RNKNfN6IBbXq+6AjMGtJbgQ8B70zHpoWSjpR0qqRT0zw3kT00dwlwJfCxgmK1BuGCy/rFlClTuPnmm7caf+aZZwIsjogDIuImAEn7AZOAN5A9kPIySYMkDQK+SXY79n7A5DSvmVnVRMQdEaGI2D8dmw6IiJsi4vKIuDzNExFxWkS8NiLeGBH3Fh231TcXXNYvJkyYwNChQ7s7+1HAvIjYEBEPk/0FeXD6WRIRSyNiIzAvzWtmZlbX3GjeCnXppZdCdklxDtmDBp8mu7X6ztxs+SeBl96GXfZBg9V+YGU93JrcG/960JCGjb1R424WGx58EOrgLjWzZuGCywrz0Y9+lJkzZzJ48ODFZM+u+QowtRrrrvYDK+vh1uTe+OQZ1/LVo4c2ZOyNus/NzMrp8pKipDmSVkr6U27ceZKW5RsT5qa5sbN1S0tLC4MGDWofvJLskiF0/oRw34ZtZmYNpzttuK4ma7hc6mv5xoTgxs7WM8uXd7iD+v1Ae1E/H5gkaXtJryHrq+xusgdVjpX0GknbkeXa/H4M2czMrFe6vKQYEbenrg26Y3NjZ+BhSe2NnSE1dgZIXW0cRcd+qayJTZ48mba2NlatWsXo0aM5//zzaWtrY+HChZAV4e8ATgGIiEWSrifLj03AaRHxEoCk04FbgEHAnIhYVMDmmJmZ9Uhf2nCdLukk4F6q2NjZmtPcuXO3Gjdt2jQAJC2OiIn5aRFxEXBR6TLpbOpNtYnSzMysNnpbcH0LuBCI9H/VGjtX++6ySpr9DqgNHzyRF5t8G83MzBpFrwquiFjR/lrSlcD/psHOGjV3q7Fzte8uq6TZ74B6aOY5LLvwgqbeRjMzs0bRqwefShqZG3RjZzMzM7NOdHmGS9JcoBUYLulx4FygVdIBZJcUH8GNnc3MzMwq6s5dipPLjL6qk/nd2NnMzMwsx30pmpmZmdWYC64m9ucn13HwRbcVHYaZmdmA54LLzMzMrMZccJmZmZnVmAsuMzMzsxpzwWVmZmZWYy64zMzMzGrMBZeZmZlZjbngMjMzM6sxF1xm1hCmTp3KiBEjGDdu3OZx5513HqNGjQLYT9JCSUe2T5N0tqQlkh6U9J7c+CPSuCWSZvTrRpjZg
OWCy8wawpQpU7j55pu3Gn/mmWcCLI6IA1IXYkjaD5gEvAE4ArhM0iBJg4BvAu8F9gMmp3nNzGrKBZeZNYQJEyYwdOjQ7s5+FDAvIjZExMPAEuDg9LMkIpZGxEZgXprXbDNJcyStlPSnCtNbJa1NZ1UXSjqnv2O0xuOCy8wa2qWXXgrZJcU5knZLo0cBj+VmezyNqzTeLO9qsjOjnflVOqt6QERc0A8xWYMbXHQAZma99dGPfpSZM2cyePDgxcBy4CvA1GqtX9J0YDpAS0sLbW1t1Vp1p9avX99v71XJhmHDCo+hqP0QEbdLGtPvb2xNzQWXmTWslpaW/OCVwP+m18uAPXPTRqdxdDJ+KxFxBXAFwPjx46O1tbVvAXdTW1sb/fVeldwyezatxx5baAz1sB868RZJvweeAD4dEYuKDsjqmwsuM2tYy5cvZ+TIke2D7wfa29zMB66V9FVgD2AscDcgYKyk15AVWpOAf+nXoK0Z3A/sFRHr052x/0OWY1vpzVnS3Qd1PLs4adtJ/X6mrx7OstZLHNWKwQWXWRN7eeeiI6ieyZMn09bWxqpVqxg9ejTnn38+bW1tLFy4ELI7Dt8BnAIQEYskXQ8sBjYBp0XESwCSTgduAQYBc3xmwnoqIp7Nvb5J0mWShkfEqjLz9vgs6WU3zOa41i1nFy+4/gIWtC6oRujdVi9nF+shjmrF4ILLzBrC3Llztxo3bdo0ACQtjoiJ+WkRcRFwUeky6dERN9UmShsIJL0KWBERIelgshvQVhccltU5F1xmZmY5kuYCrcBwSY8D5wLbAkTE5cCxwEclbQJeACZFRBQUrjUIF1xmZmY5ETG5i+mXApf2UzjWJPwcLjMzM7Mac8FlZmZmVmMuuMzMzMxqzAWXmZmZWY254DIzMzOrMRdcZmZmZjXmgsv6xdSpUxkxYgTjxo3bPG7NmjW8613vAhgn6VZJuwEo8w1JSyT9QdKB7ctIOlnSQ+nn5H7fEDMzs15wwWX9YsqUKdx8880dxl188cUcdthhkPV/93NgRpr0XrJ+ycaS9UH2LQBJQ8keQHgIcDBwbnuRZmZmVs9ccFm/mDBhAkOHDu0w7sYbb+TkkzefpLoGODq9Pgr4TmTuBHaVNBJ4D3BrRKyJiKeBW4Ej+iN+MzOzvnDBZYVZsWIFI0eObB98EmhJr0cBj+VmfTyNqzTezMysrnXZtY+kOcD7gJURMS6NGwpcB4wBHgGOj4inJQn4OnAk8DwwJSLuT8ucDHw+rfYLEXFNdTfFGlnqBLZqfZFJmk52OZKWlhba2tr6tL7169f3eR1FmL7vkIaNvVHjNjMrpzt9KV5N1mfUd3LjZgA/j4iLJc1Iw2fRse3NIWRtbw7Jtb0ZDwRwn6T56bKQDVAtLS0sX74cgHTJcGWatAzYMzfr6DRuGVmHsvnxbeXWHRFXAFcAjB8/PlpbW8vN1m1tbW30dR1FOGPmtVxy2NCGjL1R97mZWTldXlKMiNuBNSWjjyJrcwNue2O9NHHiRK65ZvOJzpOBG9Pr+cBJ6W7FNwNrI2I5cAvwbkm7pcby707jzMzM6lp3znCV05K+AMFtb6wbJk+eTFtbG6tWrWL06NGcf/75zJgxg+OPPx5gHPAMcHya/Sayy9JLyC5NfxggItZIuhC4J813QUSU/jFgZmZWd3pbcG1W721vKmn29iEbPngiw7Z/mamvfaEutvOUU07hlFNO6TDuj3/8IzNnzmTBggV/iojD28dHRACnlVtPRMwB5tQ0WDMzsyrrbcG1QtLIiFhe721vKmn29iEPzTyHn5/2Geb8dQfuntRadDhmZmYDWm8fCzGfrM0NuO2NmZmZWae681iIuWRnp4ZLepzsbsOLgeslTQMexW1vzMzMzCrqsuCKiMkVJh1WZl63vTEzMzMr4SfNm5mZmdWYCy4zMzOzGnPBZWZmZlZjLrjMrGFMnTqVESNGMG7cuM3j1qxZAzBW0kOSbk13QpPulv6GpCWS/iDpw
PZlJJ2c5n8o9fNqZlZTLrjMrGFMmTKFm2++ucO4iy++GGBdRIwFfk7Wtyt07Nt1OlnfruT6dj0EOBg4t71IMwOQNEfSSkl/qjC9YjFvVokLLjNrGBMmTGDo0KEdxt14440Aq9Og+3a1ariaznOibDFv1hkXXGbW0FasWAHwYhp0367WZxFxO9DZsyIrFfNmFfW5L0Uzs3rRqH27lqqHvl43DBtWeAz1sB8qqFS0Ly8mHGsELria2Msvbyo6BLOaa2lpYe3atdsCNGrfrqXqoa/XW2bPpvXYYwuNoR72Q191p2hfvWw9w0bttHl490Fbit3Vy9Yz6RWTNg9vePDBbr3v9vvs06e4a1bsrlgELW/o9vgiiu7Sfbxh2DBumT17q/l6uo9dcJlZQ5s4cSKzZs0algZL+3Y9XdI8sgbyayNiuaRbgC/mGsq/Gzi7X4O2RlepmN9Kd4r2b591B8eceOjm4ctumM1xrcdunjbvoK+yoHUBAA/NPKdbAY791e3dmq+SmhW7s06BE8oUjRXGF1F0l+7jpR88kb2/9/2t5uvpPnYbria17FM+s23NZ/LkybzlLW/hwQcfZPTo0Vx11VXMmDEDYBdJDwGHk/X1ClnfrkvJ+na9EvgYZH27Au19u96D+3a1npsPnJTuVnwzqZgvOiirbz7DZWYNY+7cuZUm/SUixudHuG9X6y1Jc8kuOw+X9DjZY0S2BYiIy8mK+SPJivnngQ8XE6k1EhdcZmZmORExuYvpFYt5s0p8SdHMzMysxlxwmZmZmdWYCy4zMzOzGnPBZWZmZlZjLrjMzMzMaswFV1OrWg8nZmZm1gcuuMzMzMxqzAWXmZmZWY254DIzMzOrMRdcZmZmZjXmgsus2a1YVHQEZmYDnguuJnXGfecXHYLViT9vGll0CGZmA54LLjMzM7Mac8FlZmZmVmMuuMzMzMxqzAWX1YM3SvqjpIWS7gWQNFTSrZIeSv/vlsZL0jckLZH0B0kHFhu6mZlZ11xwWb14R0QcEBHj0/AM4OcRMRb4eRoGeC8wNv1MB77V75GamZn1UJ8KLkmP+MyE1chRwDXp9TXA0bnx34nMncCuknwbnpmZ1bVqnOHymQmrhp9Juk/S9DTcEhHL0+sngZb0ehTwWG65x9M4MzOzujW4Bus8CmhNr68B2oCzyJ2ZAO6UtKukkbkvVRu4/hwRB0oaAdwq6c/5iRERkqInK0yF23SAlpYW2tra+hTg+vXr+7yOIkzfdwjDhmzTkLE36j43MyunrwVXkJ2ZCGB2RFxBz89MuOCyFwEiYqWkHwMHAyvaC/J0yXBlmncZsGdu2dFpXAcpF68AGD9+fLS2tvYpwLa2Nvq6jiKcMfNaTv0/QziuAWNv1H1uZlZOXwuuQyNiWT2fmaik2f96nr7vEHYfsg1TX/tCXW/nCy+8AOnStqQdgXcDFwDzgZOBi9P/N6ZF5gOnS5oHHAKs9VlSMzOrd30quCJiWfq/bs9MVNLsfz2fMfNapu87hKsf3om7J7UWHU5FS5cuBfg/kn5Plo/XRsTNku4Brpc0DXgUOD4tchNwJLAEeB74cP9H3Vhe8s3IZmaF63XBlc5GbBMR63xmwnpr7733Blicu+kCgIhYDRxWOn9qA3ha/0RnjUTSI8A64CVgU0SMlzQUuA4YAzwCHB8RT0sS8HWy4v15YEpE3F9E3GY2MPTlT98W4I50ZuJu4P9FxM1khda7JD0EHJ6GITszsZTszMSVwMf68N5mZuX4rmmrCklHSHowPcpoRpnpUyQ9lR6LtFDSR4qI0xpHr89wRcRS4B/KjPeZCbM6EVLRIRTNd01bj0kaBHwTeBfZDV73SJofEYtLZr0uIk7v9wCtIblxh5k1i/a7pv08N+urg4ElEbE0IjYC88iKdLNeq8VzuMzMitCwd02Xqoe7qDcMG1Z4DAXuh3IF+SFl5jtG0gTgL8CZEfFY6QzdyaFd9u+4nbsP2rLvd9l/PZO2nbR5eMMHT+zWB
iyr12cP7vlxKLfeCuOLyIHSfbxh2DCWltnvPd3HLrjMrCk08l3TperhLupbZs+m9dhjC42hHvZDJ34CzI2IDZJOIbtk/c7SmbqTQ98+6w6OOfHQzcOX3TCb41qP3Txt3kFfZUHrAgAemnlOt4Ib+6vbe7Y1JWq272edAic82O3xReRA6T5e+sET2ft7399qvp7uY19SNLNmsI2knaHD89z+xJa7pmHru6ZPSn28vhnfNW0ddVmQR8TqiNiQBv8bOKifYrMG5TNcZtYMBpPdNd3+2s9zs764Bxgr6TVkhdYk4F/yM5TcZDEReKB/Q7RG44LLzJrBxtJnuYHvmrbeiYhNkk4HbgEGAXMiYpGkC4B7I2I+8HFJE4FNwBpgSmEBW0NwwWVmZlYiIm4iOxOaH3dO7vXZwNn9HZc1LrfhMjMzM6sxF1xNbuPGVUWHYGZmNuC54DIzMzOrMRdcTejQ7+3PS2P+q+gwzKxBPfS2CUWHYNZ0XHCZmZmZ1ZgLLjMzM7Mac8Fl1sRe3usbRYdgDeip9Rvg5U1Fh2HWVFxwmZnZVl4c2aN+vs2sCy64zMzMzGrMBZeZmZlZjbngMjMzM6sxF1xmZmZmNeaCy8zMzKzGXHCZNamDL7qt6BCsgf3t+T2KDsGsqbjgGgD8xTuwSS8XHYKZ2YDngqvJHHzRbazdsGPRYZhZAxsWa4sOwazpuOAyMzMzqzEXXGbNbruVRUdgDeaJc18sOgSzpuOCy2wAcDs+64kz77+w6BDMmo4LLjMzK+uht00oOgSzpuGCq8ls3Liq44iXX2bj331JaSBatX5D0SGYmVnigsvMzMysxgYXHYBVxzuvfycAL736hS0jt1vJS2OuAeDgi0YAcPfnDu/32KwYO772C0WHYGZmSb+f4ZJ0hKQHJS2RNKO/398an3Oom7Tl5bp167j81JOKi6UOOY/Kc7utTFf5IWl7Sdel6XdJGlNAmNZA+rXgkjQI+CbwXmA/YLKk/fozhmb11PoNPDUA2uw4h7qn9K7EwWO/XFAk9cl5VNlPR+zIS2P+Kxt4buWALMC6mR/TgKcj4nXA14Av9W+U1mj6+5LiwcCSiFgKIGkecBSwuJ/jaArtlxFZv5JhEZ3O+8KrZqZlvsiC4xfUOrRacg51wwsjz4WSlLh0l4mcWkw49ch5VGrWPmxavYoTDtyFH/EKBm+3nF2O2chPln6KO866A4APf+nQgoPsN93Jj6OA89LrHwKXSlJEFwdjG7D6u+AaBTyWG34cOKSfY2hYs2bN4vqh1/PctttlIzZmfeTtXvrNmgzmJYbxbIdxL697ltY5b9g8fN3y55G2XHsa8dmlVY666gZUDn07fdG12+oLb9Y+rPr76s2Dx41sIYBhuVk258E+57Lios90WDz/2QOM2G7bngX46Qd7Nn/9GFB51MGsfTa/XLnxRXg5iPY02GkXTnjVKzZPP+FVO3AdZ2wefvKLMPjldQAc//r9ATjyjk8AMGyHfNY1fHHWnfzYPE9EbJK0luxXr+RWcbNM3TWalzQdmJ4G10uq1RF9OE3+i/FLTu9yG19VOuJzKjdbre1V7RXWII/qIl+mdnllcPVWY/J5sNXn3VefqWm+9GSfVz2HoF+PR6UKz7f2vKmcM79I//6i7NSuc7Vb+ns/FHYsKtlfw0/j1A7brRN6+LumPv9u1m7fVzpulB9f+O8Cd/yqfAzl93HFHOrvgmsZsGdueHQat1lEXAFcUetAJN0bEeNr/T5FatJt7DKHoPp51Mj7slFjr3HcheRRd9XDZzbAY+hOfrTP87ikwcArKfMXT29yaIDv+7qLo1ox9PddivcAYyW9RtJ2wCRgfj/HYI3NOWTV4DyyznQnP+YDJ6fXxwIL3H7LOtOvZ7jSde7TgVuAQcCciFjUnzFYY3MOWTU4j6wzlfJD0gXAvRExH7gK+K6kJcAasqLMrKJ+b8MVETcBN/X3+5bR75cJCtCU21hQDjXyvmzU2Gsadx0di8qph
89sQMdQLj8i4pzc678Dx9Xo7Qf0vi9RD3FUJQb5DKiZmZlZbbkvRTMzM7MaGzAFl6Shkm6V9FD6f7cK870kaWH6aYhGtO6Con9I+k9Jf5b0B0k/lrRr0TF1plG7rpG0p6RfSFosaZGkTxQdUxGKzLcic2egf/5F/95KmiNppaQ/9fd752KoixyQNETS3ZJ+n+I4v0/rGyiXFCV9GVgTERenJN4tIs4qM9/6iNip/yPsndQFxV+Ad5E9nO8eYHJELM7N8zFg/4g4VdIk4P0RcUIhATcwSe8muxNpk6QvAZTLoXrQnbyoV5JGAiMj4n5JOwP3AUc3QuzVVFS+FZ07A/nzL3rfpxgmAOuB70TEuP5635IY6iIHlD0ZeseIWC9pW+AO4BMRcWdv1jdgznCRdcNwTXp9DXB0caFU1eYuKCJiI9DeBUVeftt/CByWEsl6ICJ+FhGb0uCdZM/mqVfdyYu6FBHLI+L+9Hod8ADZU70HlALzrdDcGeCff+G/txFxO9ldl4WplxyIzPo0uG366fVZqoFUcLVExPL0+kmgpcJ8QyTdK+lOSUf3T2h9Uq4LitLE7NAFBdDeBYX13lTgp0UH0Ynu5EXdS5e/3wTcVXAoRevPfKub3BmAn3/d7Pt6UXQOSBokaSGwErg1InodR9117dMXkm6jfO8ln8sPRERIqlSl7hURyyTtDSyQ9MeI+Gu1Y7X61FkORcSNaZ7PAZuA7/dnbAONpJ2AG4AzIuLZruZvRM63ygbC52+dq4cciIiXgANSG8ofSxoXEb1q39ZUBVdEHF5pmqQVkkZGxPJ0fXhlhXUsS/8vldRGVlnXc8FVtS4orPMcApA0BXgfcFidP1W6W13X1KvUXuIG4PsR8aOi46mVOs23wnNnoHz+ZRS+7+tFveVARDwj6RfAEUCvCq6BdEkx3w3DycCNpTNI2k3S9un1cOCtQL031HQXFP1E0hHAvwMTI+L5ouPpQsN2XZPaF14FPBARXy06nqIUmG+F5s4A//wb9ve2muolByTt3n53sKQdyG5m+HOv1zdQvnclDQOuB14NPAocHxFrJI0HTo2Ij0j6R2A28DJZMXpJRFxVWNDdJOlI4BK2dEFxkXJdUEgaAnyX7GzdGmBSRCwtLOAGpawLj+3Zcnbwzog4tcCQOlUuL4qNqHskHQr8Cvgj2e8iwGfTk78HjCLzrcjcGeiff9G/t5LmAq3AcGAFcG5/fw/WSw5I2p/shrNBZDXB9RFxQa/XN1AKLjMzM7OiDKRLimZmZmaFcMFlZmZmVmMuuMzMzMxqzAWXmZmZWY254DIzMzOrMRdcZmZmZjXmgsvMzMysxlxwmZmZmdXY/wcE4RgHdjG1/gAAAABJRU5ErkJggg==", "text/plain": [ "
" ] }, "metadata": { "needs_background": "light" }, "output_type": "display_data" } ], "source": [ "import matplotlib.pyplot as plt\n", "\n", "sc = 0.5\n", "plt.figure(figsize = (20*sc,5*sc))\n", "\n", "# One gradient-histogram panel per block:\n", "for k,name in enumerate(['encode','core','decode','global']):\n", " plt.subplot(1,4,1+k)\n", " # The weight tensors of block k are grads[start_idx:end_idx]:\n", " start_idx, end_idx = cumsum_nweights_per_block[k:k+2]\n", " for g_ in grads[start_idx:end_idx]:\n", " plt.hist(g_.numpy().flatten(), alpha = 0.9)\n", "\n", " plt.grid()\n", " plt.title(\"%s network\\nGrad. histogram\"%name)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "# Support for global variables\n", "Global-to-edge and global-to-node functions are also supported.\n", "\n", "The GN block does nothing special with the global variable; it simply passes it to the inputs of the node and edge blocks. To use the global variable, set `use_global_to_edge = True` and/or `use_global_to_node = True` at construction, and also specify the size/shape of the global variable there. (Global variables with rank > 1 are currently untested and will probably lead to bugs.)\n", "\n", "For more custom computations and more advanced features, check the source of the factory method `make_mlp_graphnet_functions(...)` to understand the naming conventions. 
The global variable is appended to the inputs of the node/edge functions when used.\n", "\n", "For even more control over the functions, you may use factories such as `make_node_mlp(...)` and pass the functions they build directly to the `GraphNet` constructor.\n", "\n", "The code below uses some typical utility GN function factories.\n", "\n" ] }, { "cell_type": "code", "execution_count": 28, "metadata": {}, "outputs": [], "source": [ "from tf_gnns import make_mlp_graphnet_functions, make_graph_to_graph_and_global_functions, make_full_graphnet_functions\n", "from tf_gnns import GraphNet, Node, Edge, Graph, GraphTuple, make_graph_tuple_from_graph_list\n", "\n", "gnn_size = 50 # number of hidden units in every MLP layer\n", "node_input_size_enc, node_output_size_enc = [2, 10]\n", "node_input_size_core, node_output_size_core = [10, 10]\n", "edge_input_size_core, edge_output_size_core = [10,10]\n", "node_input_size_dec, node_output_size_dec = [10,2]\n", "\n", "global_core_state = 10 # size of the global state vector\n", "\n", "edge_input_size_enc = 2\n", "edge_output_size_dec = 2\n", "\n", "graph_fcn_enc = make_graph_to_graph_and_global_functions(gnn_size,\n", " node_or_core_input_size = node_input_size_enc,\n", " node_or_core_output_size = node_output_size_enc,\n", " global_output_size = global_core_state)\n", "\n", "graph_fcn_core = make_full_graphnet_functions(gnn_size,\n", " node_or_core_input_size = node_output_size_enc, \n", " node_or_core_output_size = node_output_size_enc,\n", " global_output_size = global_core_state)\n", "\n", "graph_fcn_dec = make_mlp_graphnet_functions(gnn_size,\n", " node_input_size = node_input_size_dec, \n", " node_output_size = node_output_size_dec, \n", " edge_output_size= edge_output_size_dec,\n", " graph_indep=True)\n", "\n", "\n", "graph_fcn_enc['graph_independent'] = True\n", "graph_fcn_dec['graph_independent'] = True\n", "\n", "graph_tuple_to_global, global_fcn = make_graph_tuple_to_global(insize = node_output_size_core,\n", " output_size=global_core_state,\n", " agg_type =\"mean\")" ] }, { "cell_type": "code", 
"execution_count": 29, "metadata": {}, "outputs": [], "source": [ "gn_enc = GraphNet(**graph_fcn_enc)\n", "gn_core = GraphNet(**graph_fcn_core)\n", "gn_dec = GraphNet(**graph_fcn_dec)\n" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Inspection of the functions:" ] }, { "cell_type": "code", "execution_count": 30, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

Graph Indep. GNN function (@140641199316128)

\n", "
\n", "
Model: \"node_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "node_state (InputLayer)         [(None, 2)]          0                                            \n",
       "__________________________________________________________________________________________________\n",
       "edge_state_agg (InputLayer)     [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate_3 (Concatenate)     (None, 12)           0           node_state[0][0]                 \n",
       "                                                                 edge_state_agg[0][0]             \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_0 (Dense)         (None, 50)           600         concatenate_3[0][0]              \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_1 (Dense)         (None, 50)           2550        node_fn/dense_0[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_2 (Dense)         (None, 50)           2550        node_fn/dense_1[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_3 (Dense)         (None, 10)           510         node_fn/dense_2[0][0]            \n",
       "==================================================================================================\n",
       "Total params: 6,210\n",
       "Trainable params: 6,210\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "\n", "
\n", "
Model: \"edge_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "edge_state (InputLayer)         [(None, 2)]          0                                            \n",
       "__________________________________________________________________________________________________\n",
       "sender_node_state (InputLayer)  [(None, 2)]          0                                            \n",
       "__________________________________________________________________________________________________\n",
       "receiver_node_state (InputLayer [(None, 2)]          0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate_2 (Concatenate)     (None, 6)            0           edge_state[0][0]                 \n",
       "                                                                 sender_node_state[0][0]          \n",
       "                                                                 receiver_node_state[0][0]        \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_0 (Dense)         (None, 50)           300         concatenate_2[0][0]              \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_1 (Dense)         (None, 50)           2550        edge_fn/dense_0[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_2 (Dense)         (None, 50)           2550        edge_fn/dense_1[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_3 (Dense)         (None, 10)           510         edge_fn/dense_2[0][0]            \n",
       "==================================================================================================\n",
       "Total params: 5,910\n",
       "Trainable params: 5,910\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "\n", "
\n", "
Model: \"global_mlp/glob_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "node_state_agg (InputLayer)     [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "edge_state_agg (InputLayer)     [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate_4 (Concatenate)     (None, 20)           0           node_state_agg[0][0]             \n",
       "                                                                 edge_state_agg[0][0]             \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_0 (Den (None, 50)           1000        concatenate_4[0][0]              \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_1 (Den (None, 50)           2550        global_mlp/glob_fn/dense_0[0][0] \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_2 (Den (None, 50)           2550        global_mlp/glob_fn/dense_1[0][0] \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_3 (Den (None, 10)           510         global_mlp/glob_fn/dense_2[0][0] \n",
       "==================================================================================================\n",
       "Total params: 6,610\n",
       "Trainable params: 6,610\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "" ], "text/plain": [ "" ] }, "execution_count": 30, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gn_enc" ] }, { "cell_type": "code", "execution_count": 31, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

GNN function (@140641198037840)

\n", "
\n", "
Model: \"node_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "node_state (InputLayer)         [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "edge_state_agg (InputLayer)     [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "global_state (InputLayer)       [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate_6 (Concatenate)     (None, 30)           0           node_state[0][0]                 \n",
       "                                                                 edge_state_agg[0][0]             \n",
       "                                                                 global_state[0][0]               \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_0 (Dense)         (None, 50)           1500        concatenate_6[0][0]              \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_1 (Dense)         (None, 50)           2550        node_fn/dense_0[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_2 (Dense)         (None, 50)           2550        node_fn/dense_1[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "node_fn/dense_3 (Dense)         (None, 10)           510         node_fn/dense_2[0][0]            \n",
       "==================================================================================================\n",
       "Total params: 7,110\n",
       "Trainable params: 7,110\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "\n", "
\n", "
Model: \"edge_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "edge_state (InputLayer)         [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "sender_node_state (InputLayer)  [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "receiver_node_state (InputLayer [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "global_state (InputLayer)       [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate_5 (Concatenate)     (None, 40)           0           edge_state[0][0]                 \n",
       "                                                                 sender_node_state[0][0]          \n",
       "                                                                 receiver_node_state[0][0]        \n",
       "                                                                 global_state[0][0]               \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_0 (Dense)         (None, 50)           2000        concatenate_5[0][0]              \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_1 (Dense)         (None, 50)           2550        edge_fn/dense_0[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_2 (Dense)         (None, 50)           2550        edge_fn/dense_1[0][0]            \n",
       "__________________________________________________________________________________________________\n",
       "edge_fn/dense_3 (Dense)         (None, 10)           510         edge_fn/dense_2[0][0]            \n",
       "==================================================================================================\n",
       "Total params: 7,610\n",
       "Trainable params: 7,610\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "\n", "
\n", "
Model: \"global_mlp/glob_fn/model\"\n",
       "__________________________________________________________________________________________________\n",
       "Layer (type)                    Output Shape         Param #     Connected to                     \n",
       "==================================================================================================\n",
       "global_state (InputLayer)       [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "node_state_agg (InputLayer)     [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "edge_state_agg (InputLayer)     [(None, 10)]         0                                            \n",
       "__________________________________________________________________________________________________\n",
       "concatenate_7 (Concatenate)     (None, 30)           0           global_state[0][0]               \n",
       "                                                                 node_state_agg[0][0]             \n",
       "                                                                 edge_state_agg[0][0]             \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_0 (Den (None, 50)           1500        concatenate_7[0][0]              \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_1 (Den (None, 50)           2550        global_mlp/glob_fn/dense_0[0][0] \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_2 (Den (None, 50)           2550        global_mlp/glob_fn/dense_1[0][0] \n",
       "__________________________________________________________________________________________________\n",
       "global_mlp/glob_fn/dense_3 (Den (None, 10)           510         global_mlp/glob_fn/dense_2[0][0] \n",
       "==================================================================================================\n",
       "Total params: 7,110\n",
       "Trainable params: 7,110\n",
       "Non-trainable params: 0\n",
       "__________________________________________________________________________________________________
\n", "
\n", "" ], "text/plain": [ "" ] }, "execution_count": 31, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gn_core" ] }, { "cell_type": "code", "execution_count": 32, "metadata": {}, "outputs": [ { "data": { "text/html": [ "

Graph Indep. GNN function (@140641192216272)

\n", "
\n", "
Model: \"node_fn/model\"\n",
       "_________________________________________________________________\n",
       "Layer (type)                 Output Shape              Param #   \n",
       "=================================================================\n",
       "node_state (InputLayer)      [(None, 10)]              0         \n",
       "_________________________________________________________________\n",
       "node_fn/dense_0 (Dense)      (None, 50)                500       \n",
       "_________________________________________________________________\n",
       "node_fn/dense_1 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "node_fn/dense_2 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "node_fn/dense_3 (Dense)      (None, 2)                 102       \n",
       "=================================================================\n",
       "Total params: 5,702\n",
       "Trainable params: 5,702\n",
       "Non-trainable params: 0\n",
       "_________________________________________________________________
\n", "
\n", "\n", "
\n", "
Model: \"edge_fn/model\"\n",
       "_________________________________________________________________\n",
       "Layer (type)                 Output Shape              Param #   \n",
       "=================================================================\n",
       "edge_state (InputLayer)      [(None, 10)]              0         \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_0 (Dense)      (None, 50)                500       \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_1 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_2 (Dense)      (None, 50)                2550      \n",
       "_________________________________________________________________\n",
       "edge_fn/dense_3 (Dense)      (None, 2)                 102       \n",
       "=================================================================\n",
       "Total params: 5,702\n",
       "Trainable params: 5,702\n",
       "Non-trainable params: 0\n",
       "_________________________________________________________________
\n", "
\n", "" ], "text/plain": [ "" ] }, "execution_count": 32, "metadata": {}, "output_type": "execute_result" } ], "source": [ "gn_dec" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Computation using graph tuples:" ] }, { "cell_type": "code", "execution_count": 33, "metadata": {}, "outputs": [], "source": [ "g_ = gn_enc.graph_tuple_eval(gt.copy())\n", "\n", "# The following lines implement the full-GN block:\n", "g_ = gn_core.graph_tuple_eval(g_) # <-- Note that the global variables are used here. \n", "glob_accum = graph_tuple_to_global(g_)\n", "\n", "g_ = gn_core.graph_tuple_eval(g_)" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## Faster computation using `tf.function`\n", "TensorFlow can compile eager code into graph-mode code with `tf.function`. A `GraphNet` can be used inside `tf.function`-compiled code through its `.eval_tensor_dict(...)` method, which operates on tensor dictionaries corresponding to `GraphTuple`s. This should bring performance close to that of other GNN libraries." ] }, { "cell_type": "code", "execution_count": 35, "metadata": {}, "outputs": [], "source": [ "import tensorflow as tf\n", "\n", "# Eager evaluation on a GraphTuple (baseline):\n", "def eval_graph_tuple(gt):\n", " g_ = gn_enc.graph_tuple_eval(gt)\n", " g_ = gn_core.graph_tuple_eval(g_)\n", " return g_\n", "\n", "# The same computation on a tensor dictionary, compiled to graph mode:\n", "@tf.function\n", "def eval_tensordict(td_):\n", " g_ = gn_enc.eval_tensor_dict(td_)\n", " g_ = gn_core.eval_tensor_dict(g_)\n", " return g_" ] }, { "cell_type": "code", "execution_count": 39, "metadata": {}, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "GraphTuple: 0.009[s/rep]\n", "tf.function: 0.001[s/rep]\n", "Perf. 
improvement: 86.149%\n" ] } ], "source": [ "from time import time\n", "nreps = 500\n", "t0 = time()\n", "for i in range(nreps):\n", " eval_graph_tuple(gt.copy())\n", "dt_graph_tuple = (time() - t0)/nreps\n", "\n", "g_ = gt.to_tensor_dict()\n", "eval_tensordict(g_) # warm-up call, so that tracing is not included in the timing\n", "t0 = time()\n", "for i in range(nreps):\n", " eval_tensordict(g_)\n", "dt_tf_function = (time()-t0)/nreps\n", "\n", "print(\"GraphTuple: %2.3f[s/rep]\\ntf.function: %2.3f[s/rep]\"%(dt_graph_tuple, dt_tf_function))\n", "print(\"Perf. improvement: %2.3f\"%((dt_graph_tuple-dt_tf_function)/dt_graph_tuple * 100)+\"%\")" ] } ],
"metadata": { "colab": { "collapsed_sections": [], "name": "01-tf-gnn-basics.ipynb", "provenance": [] }, "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.8.10" } }, "nbformat": 4, "nbformat_minor": 1 }