GraphNet utilities
- tf_gnns.graphnet_utils.unsorted_segment_min_or_zero(values, indices, num_groups, name='unsorted_segment_min_or_zero')[source]
Aggregates information using an elementwise min. Segments with no elements are given a “min” of zero instead of the most positive finite value possible (which is what tf.math.unsorted_segment_min would do).
- Parameters:
values – A Tensor of per-element features.
indices – A 1-D Tensor whose length is equal to the first dimension of values.
num_groups – A Tensor.
name – (string, optional) A name for the operation.
- Returns:
A Tensor of the same type as values.
- tf_gnns.graphnet_utils.unsorted_segment_max_or_zero(values, indices, num_groups, name='unsorted_segment_max_or_zero')[source]
Aggregates information using an elementwise max. Segments with no elements are given a “max” of zero instead of the most negative finite value possible (which is what tf.math.unsorted_segment_max would do).
- Parameters:
values – A Tensor of per-element features.
indices – A 1-D Tensor whose length is equal to the first dimension of values.
num_groups – A Tensor.
name – (string, optional) A name for the operation.
- Returns:
A Tensor of the same type as values.
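The "or zero" behaviour of the two functions above can be illustrated with a minimal pure-Python sketch. It is not the tf_gnns implementation: the helper name segment_reduce_or_zero is hypothetical, and it handles only 1-D lists of floats rather than Tensors.

```python
# Pure-Python sketch of the documented "min/max or zero" semantics.
# Assumption: 1-D values only; the real functions operate on Tensors.

def segment_reduce_or_zero(values, indices, num_groups, reducer):
    """Group values by segment index, reduce each group, use 0.0 for empty groups."""
    groups = [[] for _ in range(num_groups)]
    for value, index in zip(values, indices):
        groups[index].append(value)
    # An empty segment gets 0.0 instead of a huge positive/negative sentinel.
    return [reducer(group) if group else 0.0 for group in groups]


values = [3.0, -1.0, 5.0, 2.0]
indices = [0, 0, 2, 2]  # segment 1 receives no elements

seg_min = segment_reduce_or_zero(values, indices, 3, min)
seg_max = segment_reduce_or_zero(values, indices, 3, max)
print(seg_min)  # [-1.0, 0.0, 2.0]
print(seg_max)  # [3.0, 0.0, 5.0]
```

The middle entry is zero in both results because no element is assigned to segment 1.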
- class tf_gnns.graphnet_utils.EdgeInput(value)[source]
Bases: Enum
- GLOBAL_STATE = 'global_state'
- SENDER_NODE_STATE = 'sender_node_state'
- RECEIVER_NODE_STATE = 'receiver_node_state'
- EDGE_STATE = 'edge_state'
- class tf_gnns.graphnet_utils.NodeInput(value)[source]
Bases: Enum
- GLOBAL_STATE = 'global_state'
- NODE_STATE = 'node_state'
- EDGE_AGG_STATE = 'edge_state_agg'
- class tf_gnns.graphnet_utils.GlobalInput(value)[source]
Bases: Enum
- GLOBAL_STATE = 'global_state'
- EDGE_AGG_STATE = 'edge_state_agg'
- NODE_AGG_STATE = 'node_state_agg'
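Because EdgeInput, NodeInput, and GlobalInput are Enum classes with string values, an input name can be resolved to a member by value lookup. The sketch below re-declares EdgeInput locally with the documented members instead of importing tf_gnns, and the helper classify_edge_inputs is hypothetical; whether the library resolves names this way internally is an assumption.

```python
from enum import Enum


# Local re-declaration mirroring the documented members (not imported from tf_gnns).
class EdgeInput(Enum):
    GLOBAL_STATE = 'global_state'
    SENDER_NODE_STATE = 'sender_node_state'
    RECEIVER_NODE_STATE = 'receiver_node_state'
    EDGE_STATE = 'edge_state'


def classify_edge_inputs(input_names):
    """Map each input name to its EdgeInput member (hypothetical helper).

    Enum lookup by value raises ValueError for a name that is not a member value.
    """
    return {name: EdgeInput(name) for name in input_names}


resolved = classify_edge_inputs(['edge_state', 'sender_node_state'])
print(resolved['edge_state'])  # EdgeInput.EDGE_STATE
```

An unrecognized name such as 'bogus' fails immediately with a ValueError.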
- class tf_gnns.graphnet_utils.GraphNet(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)[source]
Bases: object
One GraphNet computation block.
A GraphNet instance wraps edge, node, and optional global update functions and evaluates them over either object graphs or tensor-dict graph tuples.
- __init__(edge_function=None, node_function=None, global_function=None, edge_aggregation_function=None, node_to_global_aggregation_function=None, graph_independent=False, use_global_input=False, name=None)[source]
Initialize a GraphNet block.
- Parameters:
edge_function – Keras model for edge updates.
node_function – Keras model for node updates.
global_function – Optional Keras model for global updates.
edge_aggregation_function – Aggregation used for edge-to-node and edge-to-global reductions in non-segment paths.
node_to_global_aggregation_function – Aggregation used for node-to-global reductions.
graph_independent – If True, disable message passing and apply independent edge/node/global transforms.
use_global_input – Whether global features are provided as model inputs.
name – Optional block name.
- scan_edge_to_node_aggregation_function(fn)[source]
Scans the inputs and outputs of the aggregation function and keeps track of them for subsequent computation. Raises an error if the aggregation is not compatible with the rest of the defined GN functions.
- scan_edge_function()[source]
The edge function signature (whether or not it takes a given input) is inferred from the names of its inputs. Scans the inputs of the edge function to keep track of which graph variables the edge function uses, and raises errors for cases that don’t make sense. Creates a dict that resolves this correspondence for the evaluation.
- graph_tuple_eval(tf_graph_tuple: GraphTuple)[source]
- edge_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)[source]
- node_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)[source]
- global_block(edges=None, nodes=None, senders=None, receivers=None, n_edges=None, n_nodes=None, global_attr=None, global_reps_for_edges=None, global_reps_for_nodes=None, n_graphs=None)[source]
- eval_tensor_dict(d)[source]
For better performance, this uses a dictionary of all the necessary tensors rather than a GraphTuple. A GraphTuple is easily transformed into a dictionary of tensors by the _graphtuple_to_tensor_dict function.
- d is an ordered dictionary containing the following:
'edges', 'nodes', 'senders', 'receivers', 'n_edges', 'n_nodes', 'global_attr', 'global_reps_for_edges', 'global_reps_for_nodes'
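As a minimal sketch of the key layout listed above, the snippet below builds an ordered dictionary with the documented keys and checks another dictionary against them. The key names are copied from this page; the helper missing_tensor_dict_keys is hypothetical and not part of tf_gnns, and None stands in for the real tensors.

```python
from collections import OrderedDict

# Key names copied from the documentation above, in the documented order.
EXPECTED_KEYS = (
    'edges', 'nodes', 'senders', 'receivers', 'n_edges', 'n_nodes',
    'global_attr', 'global_reps_for_edges', 'global_reps_for_nodes',
)


def missing_tensor_dict_keys(d):
    """Return the documented keys absent from d, in documented order (hypothetical helper)."""
    return [key for key in EXPECTED_KEYS if key not in d]


# Placeholder values (None) stand in for the real tensors.
d = OrderedDict((key, None) for key in EXPECTED_KEYS)
incomplete = OrderedDict(edges=None, nodes=None)

print(missing_tensor_dict_keys(d))           # []
print(missing_tensor_dict_keys(incomplete))  # the seven keys after 'nodes'
```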
- graph_eval(graph: Graph, **kwargs)[source]
Evaluate a single Graph by routing through GraphTuple execution.
Legacy safe / batched modes were removed in favor of the unified GraphTuple path, which is the supported runtime route.
- tf_gnns.graphnet_utils.make_mlp(units, input_tensor_list, output_shape, activation='relu', **kwargs)[source]
A default method for making a small MLP. Concatenates named inputs provided in a list. The inputs are named even though they are provided in a list. This naming is used by the evaluation method (self.graph_tuple_eval), which is less error-prone than trying to keep a specific ordering for the inputs.
- tf_gnns.graphnet_utils.make_node_mlp(units, edge_message_input_shape=None, node_state_input_shape=None, global_state_input_shape=None, node_emb_size=None, use_global_input=False, use_edge_state_agg_input=True, graph_indep=False, use_node_state_input=True, activation='relu', **kwargs)[source]
- tf_gnns.graphnet_utils.make_global_mlp(units, global_in_size=None, global_emb_size=None, node_in_size=None, edge_in_size=None, use_node_agg_input=True, use_edge_agg_input=True, use_global_state_input=True, node_to_global_agg=None, graph_indep=False, activation='relu', **kwargs)[source]
Always uses the global state input. May use node/edge aggregator (full GraphNet case)
- tf_gnns.graphnet_utils.make_edge_mlp(units, edge_state_input_shape=None, sender_node_state_output_shape=None, global_to_edge_state_size=None, receiver_node_state_shape=None, edge_output_state_size=None, use_edge_state=True, use_sender_out_state=True, use_receiver_state=True, use_global_state=False, graph_indep=False, activation='relu', **kwargs)[source]
When this is a graph-independent edge function, the node states from the sender and receiver are not used. Like make_node_mlp, it uses a list of named keras.Input layers.
- class tf_gnns.graphnet_utils.SimpleAggLayerDense(*args, **kwargs)[source]
Bases: Layer
Creates a simple aggregator layer (dense).
- class tf_gnns.graphnet_utils.SimpleAggLayerSparse(*args, **kwargs)[source]
Bases: Layer
Creates a sparse aggregator layer (for use with GraphTuples).
- tf_gnns.graphnet_utils.make_keras_simple_agg(input_size, agg_type)[source]
For consistency, this is made a Keras model (easier saving/loading). It is intended for the “naive” GraphNet evaluators. There is a fully batched aggregator with segment sums that should be preferred.
- Parameters:
input_size – the size of the expected input (mainly useful to enforce consistency)
agg_type – one of 'mean', 'sum', 'min', 'max'
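The four agg_type options can be illustrated with a pure-Python sketch that reduces a set of messages feature by feature. This is not the Keras model returned by make_keras_simple_agg; the names AGGREGATORS and aggregate are hypothetical, and messages are assumed to be plain lists of equal-length feature vectors.

```python
from statistics import mean

# Hypothetical dispatch table from agg_type string to a Python reducer.
AGGREGATORS = {'mean': mean, 'sum': sum, 'min': min, 'max': max}


def aggregate(messages, agg_type):
    """Reduce a list of equal-length feature vectors feature by feature."""
    if agg_type not in AGGREGATORS:
        raise ValueError(f"unknown agg_type: {agg_type!r}")
    reducer = AGGREGATORS[agg_type]
    # zip(*messages) yields one tuple per feature column.
    return [reducer(column) for column in zip(*messages)]


messages = [[1.0, 4.0], [3.0, 2.0]]
print(aggregate(messages, 'mean'))  # [2.0, 3.0]
print(aggregate(messages, 'sum'))   # [4.0, 6.0]
print(aggregate(messages, 'min'))   # [1.0, 2.0]
print(aggregate(messages, 'max'))   # [3.0, 4.0]
```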
- tf_gnns.graphnet_utils.make_mean_max_agg(input_size)[source]
A mean and a max aggregator appended together. This is useful for some special use cases.
- Inspired by:
Corso, Gabriele, et al. “Principal neighbourhood aggregation for graph nets.” arXiv preprint arXiv:2004.05718 (2020).
- tf_gnns.graphnet_utils.make_mean_max_min_agg(input_size)[source]
A mean, a max, and a min aggregator appended together. This is useful for some special use cases.
- Inspired by:
Corso, Gabriele, et al. “Principal neighbourhood aggregation for graph nets.” arXiv preprint arXiv:2004.05718 (2020).
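A minimal pure-Python sketch of what appending aggregators means: each statistic is computed per feature and the results are concatenated, so the output is three times as wide as the input. The function name mean_max_min_concat is hypothetical, and the concatenation order (mean, max, min) is an assumption taken from the function name, not from the tf_gnns source.

```python
def mean_max_min_concat(messages):
    """Concatenate per-feature mean, max, and min of a list of feature vectors.

    Assumption: the order mean, max, min follows the name make_mean_max_min_agg.
    """
    columns = list(zip(*messages))  # one tuple per feature
    means = [sum(column) / len(column) for column in columns]
    maxes = [max(column) for column in columns]
    mins = [min(column) for column in columns]
    return means + maxes + mins


messages = [[1.0, 4.0], [3.0, 2.0]]
combined = mean_max_min_concat(messages)
print(combined)  # [2.0, 3.0, 3.0, 4.0, 1.0, 2.0]
```

Dropping the mins term gives the two-aggregator variant, with an output twice as wide as the input.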
- tf_gnns.graphnet_utils.make_mlp_graphnet_functions(units, node_input_size, node_output_size, edge_input_size=None, edge_output_size=None, create_global_function=False, global_input_size=None, global_output_size=None, use_global_input=False, use_global_to_edge=False, use_global_to_node=False, node_mlp_use_edge_state_agg_input=True, graph_indep=False, message_size='auto', aggregation_function='mean', node_to_global_aggr_fn=None, edge_to_global_aggr_fn=None, activation='relu', activate_last_layer=False, **kwargs)[source]
Build callables required to construct a GraphNet block.
- Parameters:
units – Width specification used when creating edge/node/global MLPs.
node_input_size – Input node feature size.
node_output_size – Output node feature size.
edge_input_size – Optional input edge feature size. Defaults to node_input_size.
edge_output_size – Optional output edge feature size. Defaults to node_output_size.
create_global_function – If True, create a global update model.
global_input_size – Optional input global feature size.
global_output_size – Optional output global feature size.
use_global_input – If True, consume global_attr from inputs.
use_global_to_edge – If True, include input global state in the edge model inputs.
use_global_to_node – If True, include input global state in the node model inputs.
node_mlp_use_edge_state_agg_input – If True, node MLP consumes aggregated edge messages.
graph_indep – If True, create a graph-independent block (no message passing).
message_size – Edge message size, or "auto" to infer from selected aggregation.
aggregation_function – Aggregation mode used for message passing.
node_to_global_aggr_fn – Optional override for node-to-global aggregation.
edge_to_global_aggr_fn – Optional override for edge-to-global aggregation.
activation – Activation used for created MLPs.
activate_last_layer – Whether to apply activation on final MLP layer.
**kwargs – Forwarded to low-level MLP factory helpers.
- Returns:
Dictionary of GraphNet constructor keyword arguments, including edge_function, node_function, optional global_function, and aggregation callables.
- Raises:
ValueError – If global state is requested by edge/node models while use_global_input is False.
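The documented ValueError condition can be sketched as a small stand-alone check. The function check_global_usage is hypothetical and only mirrors the rule as stated above; the actual validation inside make_mlp_graphnet_functions may differ in detail and in its error message.

```python
def check_global_usage(use_global_input=False,
                       use_global_to_edge=False,
                       use_global_to_node=False):
    """Raise ValueError if edge/node models request global state that is not provided.

    Hypothetical helper mirroring the documented rule, not the library code.
    """
    if (use_global_to_edge or use_global_to_node) and not use_global_input:
        raise ValueError(
            "use_global_to_edge/use_global_to_node require use_global_input=True"
        )


# Consistent configuration: global input is provided and consumed by the edge model.
check_global_usage(use_global_input=True, use_global_to_edge=True)

# Inconsistent configuration: the node model asks for a global state that is absent.
try:
    check_global_usage(use_global_input=False, use_global_to_node=True)
    rejected = False
except ValueError:
    rejected = True
print(rejected)  # True
```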
- tf_gnns.graphnet_utils.make_graph_to_graph_and_global_functions(units, node_or_core_input_size, global_output_size, node_or_core_output_size=None, edge_output_size=None, edge_input_size=None, aggregation_function='mean', **kwargs)[source]
Create a graph-to-graph-plus-global GraphNet function bundle.
This wrapper builds a block that does not consume input globals but still produces global outputs by aggregating node and edge states.
- Parameters:
units – Width specification used for constructed MLPs.
node_or_core_input_size – Input node feature size.
global_output_size – Output global feature size.
node_or_core_output_size – Optional output node feature size.
edge_output_size – Optional output edge feature size.
edge_input_size – Optional input edge feature size.
aggregation_function – Aggregation mode used in message passing.
**kwargs – Forwarded to make_mlp_graphnet_functions().
- Returns:
Dictionary of constructor arguments for GraphNet.
- tf_gnns.graphnet_utils.make_graph_indep_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', create_global_function=True, use_global_input=True, **kwargs)[source]
Create GraphNet callables for a graph-independent block.
- Parameters:
units – Width specification used for created MLPs.
node_or_core_input_size – Input node feature size.
node_or_core_output_size – Optional output node feature size.
edge_input_size – Optional input edge feature size.
edge_output_size – Optional output edge feature size.
global_input_size – Optional input global feature size.
global_output_size – Optional output global feature size.
aggregation_function – Aggregation mode (retained for interface consistency).
create_global_function – If True, create global output function.
use_global_input – If True, consume global input state.
**kwargs – Forwarded to make_mlp_graphnet_functions().
- Returns:
Dictionary of constructor arguments for GraphNet configured in graph-independent mode.
- tf_gnns.graphnet_utils.make_mpnn_graphnet_noglobal_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, aggregation_function='mean', **kwargs)[source]
Create GraphNet callables for MPNN-style blocks without globals.
- Parameters:
units – Width specification used for created MLPs.
node_or_core_input_size – Input node feature size.
node_or_core_output_size – Optional output node feature size.
edge_input_size – Optional input edge feature size.
edge_output_size – Optional output edge feature size.
aggregation_function – Aggregation mode used for edge-to-node updates.
**kwargs – Forwarded to make_mlp_graphnet_functions().
- Returns:
Dictionary of constructor arguments for GraphNet with no global input/output functions.
- tf_gnns.graphnet_utils.make_full_graphnet_functions(units, node_or_core_input_size, node_or_core_output_size=None, edge_input_size=None, edge_output_size=None, global_input_size=None, global_output_size=None, aggregation_function='mean', **kwargs)[source]
Create GraphNet callables for a full message-passing block.
- Parameters:
units – Width specification used for created MLPs.
node_or_core_input_size – Input node feature size.
node_or_core_output_size – Optional output node feature size.
edge_input_size – Optional input edge feature size.
edge_output_size – Optional output edge feature size.
global_input_size – Optional input global feature size.
global_output_size – Optional output global feature size.
aggregation_function – Aggregation mode for message passing.
**kwargs – Forwarded to make_mlp_graphnet_functions().
- Returns:
Dictionary of constructor arguments for GraphNet configured with edge, node, and global update functions.