GraphNet utilities

tf_gnns.graphnet_utils.unsorted_segment_min_or_zero(values, indices, num_groups, name='unsorted_segment_min_or_zero')

Aggregates information using elementwise min. Segments with no elements are given a “min” of zero instead of the most positive finite value possible (which is what tf.math.unsorted_segment_min would do).

Parameters:
  • values – A Tensor of per-element features.

  • indices – A 1-D Tensor whose length is equal to the first dimension of values.

  • num_groups – A scalar integer Tensor giving the number of segments.

  • name – (string, optional) A name for the operation.

Returns:

A Tensor of the same type as values.

tf_gnns.graphnet_utils.unsorted_segment_max_or_zero(values, indices, num_groups, name='unsorted_segment_max_or_zero')

Aggregates information using elementwise max. Segments with no elements are given a “max” of zero instead of the most negative finite value possible (which is what tf.math.unsorted_segment_max would do).

Parameters:
  • values – A Tensor of per-element features.

  • indices – A 1-D Tensor whose length is equal to the first dimension of values.

  • num_groups – A scalar integer Tensor giving the number of segments.

  • name – (string, optional) A name for the operation.

Returns:

A Tensor of the same type as values.
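The documented behaviour of both helpers can be checked with a small NumPy sketch. It reproduces only the semantics described here (empty segments yield zero); it is not the TensorFlow implementation in tf_gnns, and the names segment_min_or_zero and segment_max_or_zero are hypothetical.

```python
import numpy as np


def _segment_reduce_or_zero(values, indices, num_groups, reducer, identity):
    """Reduce the rows of `values` per segment; empty segments become zero."""
    values = np.asarray(values, dtype=np.float64)
    indices = np.asarray(indices, dtype=np.int64)
    # Start every segment at the reducer's identity (+inf for min, -inf for max).
    out = np.full((num_groups,) + values.shape[1:], identity)
    reducer.at(out, indices, values)  # unbuffered, so repeated indices accumulate
    # Count members per segment so empty segments are detected explicitly
    # (testing for +/-inf would misfire if the data itself contained inf).
    counts = np.bincount(indices, minlength=num_groups)
    out[counts == 0] = 0.0
    return out


def segment_min_or_zero(values, indices, num_groups):
    return _segment_reduce_or_zero(values, indices, num_groups, np.minimum, np.inf)


def segment_max_or_zero(values, indices, num_groups):
    return _segment_reduce_or_zero(values, indices, num_groups, np.maximum, -np.inf)


values = [[1.0, 5.0], [3.0, 2.0], [4.0, -1.0]]
indices = [0, 0, 2]  # segment 1 receives no element
print(segment_min_or_zero(values, indices, 3).tolist())
# [[1.0, 2.0], [0.0, 0.0], [4.0, -1.0]]
print(segment_max_or_zero(values, indices, 3).tolist())
# [[3.0, 5.0], [0.0, 0.0], [4.0, -1.0]]
```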

class tf_gnns.graphnet_utils.EdgeInput(value)

Bases: Enum

GLOBAL_STATE = 'global_state'
SENDER_NODE_STATE = 'sender_node_state'
RECEIVER_NODE_STATE = 'receiver_node_state'
EDGE_STATE = 'edge_state'
class tf_gnns.graphnet_utils.NodeInput(value)

Bases: Enum

GLOBAL_STATE = 'global_state'
NODE_STATE = 'node_state'
EDGE_AGG_STATE = 'edge_state_agg'
class tf_gnns.graphnet_utils.GlobalInput(value)

Bases: Enum

GLOBAL_STATE = 'global_state'
EDGE_AGG_STATE = 'edge_state_agg'
NODE_AGG_STATE = 'node_state_agg'
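These enums name the graph variables an update function can consume. As an illustration only, the sketch below uses a local copy of the EdgeInput values to route available tensors to an edge function by input name. It assumes each input is named exactly after an enum value, and the helper route_edge_inputs is hypothetical rather than part of tf_gnns.

```python
from enum import Enum


class EdgeInput(Enum):
    # Local copy of the values documented for tf_gnns.graphnet_utils.EdgeInput.
    GLOBAL_STATE = 'global_state'
    SENDER_NODE_STATE = 'sender_node_state'
    RECEIVER_NODE_STATE = 'receiver_node_state'
    EDGE_STATE = 'edge_state'


def route_edge_inputs(input_names, available):
    """Hypothetical router: return the tensors an edge function asks for.

    input_names: names of the edge function's inputs, in its own order.
    available:   dict mapping EdgeInput members to the matching tensors.
    """
    routed = []
    for name in input_names:
        member = EdgeInput(name)  # raises ValueError for an unknown name
        if member not in available:
            raise ValueError(f"edge function needs {name!r}, which is not available")
        routed.append(available[member])
    return routed


available = {
    EdgeInput.EDGE_STATE: "e",
    EdgeInput.SENDER_NODE_STATE: "v_s",
    EdgeInput.RECEIVER_NODE_STATE: "v_r",
}
print(route_edge_inputs(['sender_node_state', 'edge_state'], available))  # ['v_s', 'e']
```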
class tf_gnns.graphnet_utils.GraphNet(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)

Bases: object

One GraphNet computation block.

A GraphNet instance wraps edge, node, and optional global update functions and evaluates them over either object graphs or tensor-dict graph tuples.

__init__(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)

Initialize a GraphNet block.

Parameters:
  • edge_function – Keras model for edge updates.

  • node_function – Keras model for node updates.

  • global_function – Optional Keras model for global updates.

  • edge_aggregation_function – Aggregation used for edge-to-node and edge-to-global reductions in non-segment paths.

  • node_to_global_aggregation_function – Aggregation used for node-to-global reductions.

  • graph_independent – If True, disable message passing and apply independent edge/node/global transforms.

  • use_global_input – Whether global features are provided as model inputs.

  • name – Optional block name.
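The data flow of one full (non-graph-independent) block can be sketched in NumPy. This is a conceptual illustration under stated assumptions, not the tf_gnns evaluation code: it handles a single graph, aggregates updated edges at their receiver nodes with a mean, uses means over all nodes and edges for the global update, and takes plain callables in place of the Keras edge, node and global models.

```python
import numpy as np


def segment_mean(data, segment_ids, num_segments):
    """Per-segment mean of the rows of `data`; empty segments give zeros."""
    sums = np.zeros((num_segments, data.shape[1]))
    np.add.at(sums, segment_ids, data)
    counts = np.bincount(segment_ids, minlength=num_segments)
    return sums / np.maximum(counts, 1)[:, None]


def graphnet_block(nodes, edges, senders, receivers, glob,
                   edge_fn, node_fn, global_fn):
    """One full GraphNet block on a single graph (conceptual sketch)."""
    senders, receivers = np.asarray(senders), np.asarray(receivers)
    n_nodes, n_edges = nodes.shape[0], edges.shape[0]
    # 1) Edge update from edge state, sender/receiver states and the global.
    glob_per_edge = np.repeat(glob[None, :], n_edges, axis=0)
    new_edges = edge_fn(np.concatenate(
        [edges, nodes[senders], nodes[receivers], glob_per_edge], axis=1))
    # 2) Edge-to-node aggregation at the receiving nodes.
    edge_agg = segment_mean(new_edges, receivers, n_nodes)
    # 3) Node update from node state, aggregated messages and the global.
    glob_per_node = np.repeat(glob[None, :], n_nodes, axis=0)
    new_nodes = node_fn(np.concatenate([nodes, edge_agg, glob_per_node], axis=1))
    # 4) Global update from the old global and the node/edge aggregates.
    new_glob = global_fn(np.concatenate(
        [glob, new_nodes.mean(axis=0), new_edges.mean(axis=0)]))
    return new_nodes, new_edges, new_glob


def row_sum(x):
    # Stand-in for an MLP: sum all input features of each row.
    return x.sum(axis=1, keepdims=True)


# Toy graph: 3 nodes, edges 0->1 and 1->2.
nodes = np.array([[1.0], [2.0], [3.0]])
edges = np.array([[10.0], [20.0]])
new_nodes, new_edges, new_glob = graphnet_block(
    nodes, edges, [0, 1], [1, 2], np.array([0.5]),
    edge_fn=row_sum, node_fn=row_sum, global_fn=lambda g: g)
print(new_edges.ravel().tolist())  # [13.5, 25.5]
print(new_nodes.ravel().tolist())  # [1.5, 16.0, 29.0]
print(new_glob.tolist())           # [0.5, 15.5, 19.5]
```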

static make_from_path(path)
scan_edge_to_node_aggregation_function(fn)

Scans the inputs and outputs of the aggregation function and records them for subsequent computation. Raises an error if the aggregation is not compatible with the other GraphNet functions defined in the block.

scan_edge_function()

The edge function signature (which inputs it does or does not take) is inferred from the names of its inputs. This method scans the inputs of the edge function to record which graph variables it uses, and raises an error for input combinations that do not make sense. It then creates a dict that resolves this correspondence for evaluation.

scan_global_function()
scan_node_function()

Basic sanity checks for the node function.

get_graphnet_input_shapes()
get_graphnet_output_shapes()
summary()
graph_tuple_eval(tf_graph_tuple: GraphTuple)
edge_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)
node_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)
global_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)
eval_tensor_dict(d)

For better performance, this method takes a dictionary of all the necessary tensors rather than a GraphTuple. A GraphTuple is easily converted to such a dictionary with the _graphtuple_to_tensor_dict function.

d is an ordered dictionary containing the following keys:

'edges', 'nodes', 'senders', 'receivers', 'n_edges', 'n_nodes', 'global_attr', 'global_reps_for_edges', 'global_reps_for_nodes'
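As a rough illustration of what such a dictionary holds for a batch, the NumPy sketch below concatenates two small graphs. It assumes that senders and receivers are offset into the batched node array and that global_reps_for_edges / global_reps_for_nodes hold the index of the graph each edge / node belongs to; both conventions are assumptions, so check _graphtuple_to_tensor_dict for the exact layout. The helper batch_to_tensor_dict is hypothetical.

```python
from collections import OrderedDict

import numpy as np


def batch_to_tensor_dict(graphs):
    """Concatenate single graphs into one dict with the documented keys.

    Each graph is a dict with 'nodes' [N, F], 'edges' [E, D],
    'senders' [E], 'receivers' [E] and 'global_attr' [G].
    """
    n_nodes = np.array([g['nodes'].shape[0] for g in graphs])
    n_edges = np.array([g['edges'].shape[0] for g in graphs])
    # Position of each graph's first node inside the concatenated node array.
    offsets = np.concatenate([[0], np.cumsum(n_nodes)[:-1]])
    graph_ids = np.arange(len(graphs))
    return OrderedDict([
        ('edges', np.concatenate([g['edges'] for g in graphs])),
        ('nodes', np.concatenate([g['nodes'] for g in graphs])),
        ('senders', np.concatenate(
            [np.asarray(g['senders']) + o for g, o in zip(graphs, offsets)])),
        ('receivers', np.concatenate(
            [np.asarray(g['receivers']) + o for g, o in zip(graphs, offsets)])),
        ('n_edges', n_edges),
        ('n_nodes', n_nodes),
        ('global_attr', np.stack([g['global_attr'] for g in graphs])),
        # Assumed meaning: index of the graph that owns each edge / node.
        ('global_reps_for_edges', np.repeat(graph_ids, n_edges)),
        ('global_reps_for_nodes', np.repeat(graph_ids, n_nodes)),
    ])


g1 = {'nodes': np.zeros((2, 1)), 'edges': np.zeros((1, 1)),
      'senders': [0], 'receivers': [1], 'global_attr': np.array([1.0])}
g2 = {'nodes': np.zeros((3, 1)), 'edges': np.zeros((2, 1)),
      'senders': [0, 1], 'receivers': [1, 2], 'global_attr': np.array([2.0])}
d = batch_to_tensor_dict([g1, g2])
print(d['senders'].tolist(), d['receivers'].tolist())  # [0, 2, 3] [1, 3, 4]
```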

graph_eval(graph: Graph, **kwargs)

Evaluate a single Graph by routing through GraphTuple execution.

The legacy “safe” and “batched” evaluation modes were removed in favor of the unified GraphTuple path, which is the supported runtime route.

save(path)

Save the model. Iterates over the required Keras models and saves them in a folder.

Parameters:

path – the path to save to. Creates it if it does not exist.

static load_graph_functions(path)

Returns a list of loaded graph functions.

load(path)

Load a model from disk. If the model is already initialized, its current GraphNet functions are simply overwritten. If the model is uninitialized, this method is called from a static factory method to create a new object with consistent properties.

tf_gnns.graphnet_utils.make_mlp(units, input_tensor_list, output_shape, activation='relu', **kwargs)

A default helper for building a small MLP. It concatenates the named inputs provided in a list. Although the inputs are passed as a list, each one is named, and the evaluation method (self.graph_tuple_eval) uses these names to route tensors. This is less error-prone than trying to keep a specific ordering of the inputs.
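The benefit of named inputs can be sketched without Keras: if the caller supplies tensors keyed by name, the MLP itself fixes the concatenation order, so callers cannot accidentally swap inputs. The NumPy function below is a hypothetical stand-in for make_mlp (random weights, ReLU on hidden layers, linear output), not its actual implementation.

```python
import numpy as np


def make_named_mlp(input_names, input_sizes, units, output_size, seed=0):
    """Return f({name: [batch, size] array}) -> [batch, output_size] array."""
    rng = np.random.default_rng(seed)
    sizes = [sum(input_sizes)] + list(units) + [output_size]
    weights = [rng.normal(size=(a, b)) for a, b in zip(sizes[:-1], sizes[1:])]

    def mlp(inputs):
        parts = []
        for name, size in zip(input_names, input_sizes):
            x = np.asarray(inputs[name], dtype=float)  # KeyError if a name is missing
            if x.shape[1] != size:
                raise ValueError(f"{name!r}: expected {size} features, got {x.shape[1]}")
            parts.append(x)
        # The concatenation order is fixed here, not by the caller.
        h = np.concatenate(parts, axis=1)
        for i, w in enumerate(weights):
            h = h @ w
            if i < len(weights) - 1:  # ReLU on hidden layers, linear output
                h = np.maximum(h, 0.0)
        return h

    return mlp


mlp = make_named_mlp(["edge_state", "sender_node_state"], [2, 3], units=[8], output_size=4)
e = np.ones((5, 2))
s = np.arange(15.0).reshape(5, 3)
a = mlp({"edge_state": e, "sender_node_state": s})
b = mlp({"sender_node_state": s, "edge_state": e})  # different dict order
print(a.shape, np.allclose(a, b))  # (5, 4) True
```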

tf_gnns.graphnet_utils.make_node_mlp(units, edge_message_input_shape=None, node_state_input_shape=None, global_state_input_shape=None, node_emb_size=None, use_global_input=False, use_edge_state_agg_input=True, graph_indep=False, use_node_state_input=True, activation='relu', **kwargs)
tf_gnns.graphnet_utils.make_global_mlp(units, global_in_size=None, global_emb_size=None, node_in_size=None, edge_in_size=None, use_node_agg_input=True, use_edge_agg_input=True, use_global_state_input=True, node_to_global_agg=None, graph_indep=False, activation='relu', **kwargs)

Always uses the global state input. May also use the node and edge aggregators (full GraphNet case).

tf_gnns.graphnet_utils.make_edge_mlp(units, edge_state_input_shape=None, sender_node_state_output_shape=None, global_to_edge_state_size=None, receiver_node_state_shape=None, edge_output_state_size=None, use_edge_state=True, use_sender_out_state=True, use_receiver_state=True, use_global_state=False, graph_indep=False, activation='relu', **kwargs)

When this is a graph-independent edge function, the sender and receiver node states are not used. Like make_node_mlp, it uses a list of named keras.Input layers.
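The flag handling can be summarised as an input-selection rule. The pure-Python sketch below is hypothetical (edge_input_names is not part of tf_gnns) and assumes the named inputs use the EdgeInput values documented above; it only encodes the stated rule that a graph-independent edge function ignores the sender and receiver node states.

```python
def edge_input_names(use_edge_state=True, use_sender_out_state=True,
                     use_receiver_state=True, use_global_state=False,
                     graph_indep=False):
    """Names of the inputs an edge MLP would be built with (sketch)."""
    if graph_indep:
        # Graph-independent edge functions never see node states.
        use_sender_out_state = use_receiver_state = False
    names = []
    if use_edge_state:
        names.append('edge_state')
    if use_sender_out_state:
        names.append('sender_node_state')
    if use_receiver_state:
        names.append('receiver_node_state')
    if use_global_state:
        names.append('global_state')
    return names


print(edge_input_names())
# ['edge_state', 'sender_node_state', 'receiver_node_state']
print(edge_input_names(graph_indep=True))
# ['edge_state']
```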

class tf_gnns.graphnet_utils.SimpleAggLayerDense(*args, **kwargs)

Bases: Layer

Creates a simple aggregator layer (dense)

call(x)
class tf_gnns.graphnet_utils.SimpleAggLayerSparse(*args, **kwargs)

Bases: Layer

Creates a sparse aggregator layer (for use with GraphTuples)

call(x)
tf_gnns.graphnet_utils.make_keras_simple_agg(input_size, agg_type)

For consistency, this is built as a Keras model (which makes saving and loading easier). It is intended for the “naive” GraphNet evaluators; the fully batched aggregator based on segment sums should be preferred.

Parameters:
  • input_size – the size of the expected input (mainly useful to enforce consistency)

  • agg_type – one of 'mean', 'sum', 'min', 'max'
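In the naive (dense) setting, the incoming messages of one node can be stacked into a single array, so every agg_type is a reduction over the message axis. The NumPy sketch below assumes messages shaped [num_incoming, input_size] for a single node; dense_aggregate is a hypothetical name and the code does not reproduce the Keras model itself.

```python
import numpy as np

_REDUCERS = {'mean': np.mean, 'sum': np.sum, 'min': np.min, 'max': np.max}


def dense_aggregate(messages, input_size, agg_type):
    """Reduce the incoming messages of one node to a single feature vector."""
    messages = np.asarray(messages, dtype=float)
    if messages.shape[-1] != input_size:  # enforce the declared input size
        raise ValueError(f"expected {input_size} features, got {messages.shape[-1]}")
    if agg_type not in _REDUCERS:
        raise ValueError(f"unknown agg_type {agg_type!r}")
    return _REDUCERS[agg_type](messages, axis=0)


msgs = [[1.0, 4.0], [3.0, 0.0]]  # two incoming messages with 2 features each
for agg in ['mean', 'sum', 'min', 'max']:
    print(agg, dense_aggregate(msgs, 2, agg).tolist())
# mean [2.0, 2.0]
# sum [4.0, 4.0]
# min [1.0, 0.0]
# max [3.0, 4.0]
```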

tf_gnns.graphnet_utils.make_mean_max_agg(input_size)

A mean and a max aggregator appended together. This is useful for some special use cases.

Inspired by:

Corso, Gabriele, et al. “Principal neighbourhood aggregation for graph nets.” arXiv preprint arXiv:2004.05718 (2020).

tf_gnns.graphnet_utils.make_mean_max_min_agg(input_size)

A mean, a max and a min aggregator appended together. This is useful for some special use cases.

Inspired by:

Corso, Gabriele, et al. “Principal neighbourhood aggregation for graph nets.” arXiv preprint arXiv:2004.05718 (2020).
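Assuming the aggregator outputs are concatenated along the feature axis, appending aggregators widens the aggregated message: with F-dimensional messages, mean/max yields 2F features per node and mean/max/min yields 3F. The NumPy sketch below illustrates the mean/max/min case with segment reductions over receiver indices; the block order and the zero fill for nodes without incoming edges are assumptions, not necessarily what the tf_gnns layers do.

```python
import numpy as np


def mean_max_min_agg(messages, receivers, num_nodes):
    """Append per-node mean, max and min of incoming messages along features."""
    messages = np.asarray(messages, dtype=float)
    receivers = np.asarray(receivers)
    feat = messages.shape[1]
    counts = np.bincount(receivers, minlength=num_nodes)
    sums = np.zeros((num_nodes, feat))
    np.add.at(sums, receivers, messages)
    mean = sums / np.maximum(counts, 1)[:, None]
    mx = np.full((num_nodes, feat), -np.inf)
    np.maximum.at(mx, receivers, messages)
    mn = np.full((num_nodes, feat), np.inf)
    np.minimum.at(mn, receivers, messages)
    empty = counts == 0
    mx[empty] = 0.0  # nodes without incoming edges get zeros
    mn[empty] = 0.0
    return np.concatenate([mean, mx, mn], axis=1)  # shape [num_nodes, 3 * F]


out = mean_max_min_agg([[1.0], [3.0], [5.0]], receivers=[0, 0, 1], num_nodes=3)
print(out.tolist())  # [[2.0, 3.0, 1.0], [5.0, 5.0, 5.0], [0.0, 0.0, 0.0]]
```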

tf_gnns.graphnet_utils.make_mean_max_min_sum_agg(input_size)
tf_gnns.graphnet_utils.make_mlp_graphnet_functions(units, node_input_size, node_output_size, edge_input_size=None, edge_output_size=None, create_global_function=False, global_input_size=None, global_output_size=None, use_global_input=False, use_global_to_edge=False, use_global_to_node=False, node_mlp_use_edge_state_agg_input=True, graph_indep=False, message_size='auto', aggregation_function='mean', node_to_global_aggr_fn=None, edge_to_global_aggr_fn=None, activation='relu', activate_last_layer=False, **kwargs)

Build callables required to construct a GraphNet block.

Parameters:
  • units – Width specification used when creating edge/node/global MLPs.

  • node_input_size – Input node feature size.

  • node_output_size – Output node feature size.

  • edge_input_size – Optional input edge feature size. Defaults to node_input_size.

  • edge_output_size – Optional output edge feature size. Defaults to node_output_size.

  • create_global_function – If True, create a global update model.

  • global_input_size – Optional input global feature size.

  • global_output_size – Optional output global feature size.

  • use_global_input – If True, consume global_attr from inputs.

  • use_global_to_edge – If True, include input global state in the edge model inputs.

  • use_global_to_node – If True, include input global state in the node model inputs.

  • node_mlp_use_edge_state_agg_input – If True, node MLP consumes aggregated edge messages.

  • graph_indep – If True, create a graph-independent block (no message passing).

  • message_size – Edge message size, or "auto" to infer it from the selected aggregation.

  • aggregation_function – Aggregation mode used for message passing.

  • node_to_global_aggr_fn – Optional override for node-to-global aggregation.

  • edge_to_global_aggr_fn – Optional override for edge-to-global aggregation.

  • activation – Activation used for created MLPs.

  • activate_last_layer – Whether to apply activation on final MLP layer.

  • **kwargs – Forwarded to low-level MLP factory helpers.

Returns:

Dictionary of GraphNet constructor keyword arguments, including edge_function, node_function, optional global_function, and aggregation callables.

Raises:

ValueError – If global state is requested by edge/node models while use_global_input is False.
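The size defaults and the documented ValueError can be summarised in a few lines. The function below is a hypothetical helper, not part of tf_gnns; it encodes only what this docstring states: edge sizes fall back to the node sizes, and requesting the global state in the edge or node model while use_global_input is False is an error.

```python
def resolve_edge_sizes(node_input_size, node_output_size,
                       edge_input_size=None, edge_output_size=None,
                       use_global_input=False,
                       use_global_to_edge=False,
                       use_global_to_node=False):
    """Return (edge_input_size, edge_output_size) after applying the defaults."""
    if (use_global_to_edge or use_global_to_node) and not use_global_input:
        raise ValueError("global state requested by the edge/node model "
                         "but use_global_input is False")
    if edge_input_size is None:
        edge_input_size = node_input_size
    if edge_output_size is None:
        edge_output_size = node_output_size
    return edge_input_size, edge_output_size


print(resolve_edge_sizes(16, 32))                     # (16, 32)
print(resolve_edge_sizes(16, 32, edge_input_size=4))  # (4, 32)
```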

tf_gnns.graphnet_utils.make_graph_to_graph_and_global_functions(units, node_or_core_input_size, global_output_size, node_or_core_output_size=None, edge_output_size=None, edge_input_size=None, aggregation_function='mean', **kwargs)

Create a graph-to-graph-plus-global GraphNet function bundle.

This wrapper builds a block that does not consume input globals but still produces global outputs by aggregating node and edge states.

Parameters:
  • units – Width specification used for constructed MLPs.

  • node_or_core_input_size – Input node feature size.

  • global_output_size – Output global feature size.

  • node_or_core_output_size – Optional output node feature size.

  • edge_output_size – Optional output edge feature size.

  • edge_input_size – Optional input edge feature size.

  • aggregation_function – Aggregation mode used in message passing.

  • **kwargs – Forwarded to make_mlp_graphnet_functions().

Returns:

Dictionary of constructor arguments for GraphNet.
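Producing a global output without a global input amounts to a per-graph readout over a batch. The NumPy sketch below assumes mean aggregation, consecutive node and edge blocks per graph (sized by n_nodes and n_edges), and an arbitrary callable standing in for the global MLP; per_graph_mean and graph_to_global are hypothetical names.

```python
import numpy as np


def per_graph_mean(data, counts):
    """Mean of consecutive row blocks of `data`, block sizes given by `counts`."""
    counts = np.asarray(counts)
    graph_ids = np.repeat(np.arange(len(counts)), counts)
    sums = np.zeros((len(counts), data.shape[1]))
    np.add.at(sums, graph_ids, data)
    return sums / np.maximum(counts, 1)[:, None]


def graph_to_global(nodes, edges, n_nodes, n_edges, global_fn):
    """Global output per graph from aggregated node and edge states only."""
    pooled = np.concatenate(
        [per_graph_mean(nodes, n_nodes), per_graph_mean(edges, n_edges)], axis=1)
    return global_fn(pooled)  # shape [n_graphs, global_output_size]


nodes = np.array([[1.0], [3.0], [10.0]])  # graph 0: 2 nodes, graph 1: 1 node
edges = np.array([[4.0], [6.0], [8.0]])   # graph 0: 1 edge, graph 1: 2 edges
print(graph_to_global(nodes, edges, [2, 1], [1, 2], lambda x: x).tolist())
# [[2.0, 4.0], [10.0, 7.0]]
```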

tf_gnns.graphnet_utils.make_graph_indep_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', create_global_function=True, use_global_input=True, **kwargs)

Create GraphNet callables for a graph-independent block.

Parameters:
  • units – Width specification used for created MLPs.

  • node_or_core_input_size – Input node feature size.

  • node_or_core_output_size – Optional output node feature size.

  • edge_input_size – Optional input edge feature size.

  • edge_output_size – Optional output edge feature size.

  • global_input_size – Optional input global feature size.

  • global_output_size – Optional output global feature size.

  • aggregation_function – Aggregation mode (retained for interface consistency).

  • create_global_function – If True, create global output function.

  • use_global_input – If True, consume global input state.

  • **kwargs – Forwarded to make_mlp_graphnet_functions().

Returns:

Dictionary of constructor arguments for GraphNet configured in graph-independent mode.
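In a graph-independent block no information crosses between graph elements: connectivity is passed through untouched and edges, nodes and globals are transformed separately. The NumPy sketch below stands in a single ReLU layer for each created MLP; the function name and the graph-dict layout are assumptions for illustration only.

```python
import numpy as np


def graph_independent_block(graph, edge_w, node_w, global_w=None):
    """graph: dict with 'nodes', 'edges', 'senders', 'receivers', 'global_attr'."""
    out = dict(graph)  # senders/receivers are copied through, never consulted
    out['edges'] = np.maximum(graph['edges'] @ edge_w, 0.0)
    out['nodes'] = np.maximum(graph['nodes'] @ node_w, 0.0)
    if global_w is not None:
        out['global_attr'] = np.maximum(graph['global_attr'] @ global_w, 0.0)
    return out


graph = {'nodes': np.array([[1.0, -1.0], [2.0, 0.0]]), 'edges': np.array([[3.0]]),
         'senders': [0], 'receivers': [1], 'global_attr': np.array([[1.0]])}
out = graph_independent_block(graph, edge_w=np.array([[2.0]]),
                              node_w=np.array([[1.0], [1.0]]),
                              global_w=np.array([[-1.0]]))
print(out['nodes'].tolist(), out['edges'].tolist())  # [[0.0], [2.0]] [[6.0]]
```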

tf_gnns.graphnet_utils.make_mpnn_graphnet_noglobal_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, aggregation_function='mean', **kwargs)

Create GraphNet callables for MPNN-style blocks without globals.

Parameters:
  • units – Width specification used for created MLPs.

  • node_or_core_input_size – Input node feature size.

  • node_or_core_output_size – Optional output node feature size.

  • edge_input_size – Optional input edge feature size.

  • edge_output_size – Optional output edge feature size.

  • aggregation_function – Aggregation mode used for edge-to-node updates.

  • **kwargs – Forwarded to make_mlp_graphnet_functions().

Returns:

Dictionary of constructor arguments for GraphNet with no global input/output functions.

tf_gnns.graphnet_utils.make_full_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', **kwargs)

Create GraphNet callables for a full message-passing block.

Parameters:
  • units – Width specification used for created MLPs.

  • node_or_core_input_size – Input node feature size.

  • node_or_core_output_size – Optional output node feature size.

  • edge_input_size – Optional input edge feature size.

  • edge_output_size – Optional output edge feature size.

  • global_input_size – Optional input global feature size.

  • global_output_size – Optional output global feature size.

  • aggregation_function – Aggregation mode for message passing.

  • **kwargs – Forwarded to make_mlp_graphnet_functions().

Returns:

Dictionary of constructor arguments for GraphNet configured with edge, node, and global update functions.