"""Source code for ``assistant_stream_ce.modules.langgraph``: LangGraph stream-event helpers."""

import copy
from typing import Any, Callable, Dict, List, Optional, Tuple, Union

from langchain_core.messages.ai import AIMessageChunk, add_ai_message_chunks
from langchain_core.messages.tool import ToolMessage

from assistant_stream_ce.create_run import RunController


def _unwrap_state_value(value: Any) -> Any:
    """Return the plain value behind ``value``.

    Items read from a proxied state object expose ``_get_value()``; plain
    dicts (and any other values) are returned unchanged.
    """
    getter = getattr(value, "_get_value", None)
    return getter() if callable(getter) else value


def _find_message_index(messages: Any, message_dict: Dict[str, Any]) -> Optional[int]:
    """Return the index of the stored message that ``message_dict`` updates.

    A stored message matches when it has the same non-``None`` ``id`` or the
    same non-``None`` ``tool_call_id``. Returns ``None`` if nothing matches.
    """
    message_id = message_dict.get("id")
    tool_call_id = message_dict.get("tool_call_id")
    for index, existing in enumerate(messages):
        # A missing id (None) is not an identity: comparing None == None would
        # make unrelated id-less messages overwrite each other.
        if message_id is not None and existing.get("id") == message_id:
            return index
        if tool_call_id is not None and existing.get("tool_call_id") == tool_call_id:
            return index
    return None


def append_langgraph_event(
    state: Dict[str, Any], _namespace: str, type: str, payload: Any
) -> None:
    """Merge one LangGraph stream event into ``state`` in place.

    ``"messages"`` events: ``payload[0]`` is dumped with ``model_dump()`` and
    stored as a dict in ``state["messages"]`` (created if absent). If a stored
    message has the same non-``None`` ``id`` or ``tool_call_id``, it is
    replaced, or, for an ``AIMessageChunk``, merged with the new chunk via
    ``add_ai_message_chunks``. Otherwise the message is appended. Chunks are
    stored with type ``"ai"``.

    ``"updates"`` events: for every node whose update is a dict, each channel
    except ``"messages"`` is copied into ``state`` under the channel name.

    Any other event type is ignored.

    Args:
        state: Mapping to update. Either a plain dict or a state proxy whose
            items expose ``_get_value()``.
        _namespace: Event namespace (currently unused).
        type: Event type (``"messages"`` or ``"updates"``).
        payload: For ``"messages"``, a sequence whose first element has
            ``model_dump()``; for ``"updates"``, a mapping of node name to
            that node's channel updates.
    """
    if type == "messages":
        if "messages" not in state:
            state["messages"] = []
        messages = state["messages"]

        message_dict = payload[0].model_dump()
        is_ai_message_chunk = message_dict.get("type") == "AIMessageChunk"
        if is_ai_message_chunk:
            # Streamed chunks are stored under the regular "ai" type.
            message_dict["type"] = "ai"

        index = _find_message_index(messages, message_dict)
        if index is None:
            messages.append(message_dict)
        elif is_ai_message_chunk:
            # Re-hydrate both sides as chunks so their contents are concatenated
            # rather than the new chunk replacing the accumulated message.
            existing = _unwrap_state_value(messages[index])
            merged = add_ai_message_chunks(
                AIMessageChunk(**{**existing, "type": "AIMessageChunk"}),
                AIMessageChunk(**{**message_dict, "type": "AIMessageChunk"}),
            ).model_dump()
            merged["type"] = "ai"
            messages[index] = merged
        else:
            messages[index] = message_dict

    elif type == "updates":
        for channels in payload.values():
            # Non-dict node updates (e.g. None) carry no channels to copy.
            if not isinstance(channels, dict):
                continue
            for channel_name, channel_value in channels.items():
                # Messages are maintained by the "messages" branch above.
                if channel_name == "messages":
                    continue
                state[channel_name] = channel_value
def get_tool_call_subgraph_state(
    controller: "RunController",
    namespace: Tuple[str, ...],
    subgraph_node: Union[str, List[str], Callable[[List[str]], bool]],
    default_state: Dict[str, Any],
    *,
    artifact_field_name: Optional[str] = None,
    tool_name: Optional[Union[str, List[str]]] = None,
) -> Dict[str, Any]:
    """Return the state object a tool-call subgraph should write into.

    If any ``namespace`` entry names a subgraph node, the last message in the
    controller's state decides the result:

    * ``"tool"`` message (from a matching tool): its ``artifact`` is returned,
      or ``artifact[artifact_field_name]`` when a field name is given. A
      missing or ``None`` artifact/field is initialised from ``default_state``.
    * ``"ai"`` message whose last tool call matches ``tool_name``: a streaming
      ``ToolMessage`` for that call is appended, and its artifact (field) is
      returned as above.
    * Anything else: the controller's top-level state is returned.

    If no namespace entry names a subgraph node, the top-level state is
    returned.

    Args:
        controller: Run controller whose ``state`` holds the message list.
            If ``controller.state`` is ``None`` it is initialised from a copy
            of ``default_state``.
        namespace: Tuple of ``"node_name:task_id"`` strings.
        subgraph_node: Node name, list of node names, or a predicate called
            with a one-element list ``[node_name]``.
        default_state: Initial value for a missing state, artifact, or
            artifact field. A deep copy is inserted each time, so the caller's
            dict is never aliased into (or mutated through) the state.
        artifact_field_name: If given, the subgraph state lives under this key
            of the tool message's ``artifact``. Otherwise the artifact itself
            is the subgraph state.
        tool_name: Tool name or list of tool names that run the subgraph.
            ``None`` accepts every tool.

    Returns:
        The artifact (or artifact field) of the matching tool message, or the
        top-level state when no subgraph tool message applies.
    """

    def is_subgraph_node(node_name: str) -> bool:
        if isinstance(subgraph_node, str):
            return node_name == subgraph_node
        if isinstance(subgraph_node, list):
            return node_name in subgraph_node
        if callable(subgraph_node):
            return subgraph_node([node_name])
        return False

    def is_subgraph_tool(name: Any) -> bool:
        if isinstance(tool_name, str):
            return name == tool_name
        if isinstance(tool_name, list):
            return name in tool_name
        return True

    def fresh_default() -> Dict[str, Any]:
        # Deep copy so one dict is never shared between the root state,
        # several artifacts, or successive runs.
        return copy.deepcopy(default_state)

    if controller.state is None:
        controller.state = fresh_default()
    current_state = controller.state

    # Entries look like "node_name:task_id"; only the node name is compared.
    # The checks below depend only on current_state, so the first match decides.
    if not any(is_subgraph_node(part.split(":")[0]) for part in namespace):
        return current_state

    if "messages" not in current_state:
        return current_state
    messages = current_state["messages"]
    if not messages:
        return current_state

    last_message = messages[-1]

    if last_message["type"] == "ai":
        tool_calls = last_message.get("tool_calls", [])
        if not tool_calls:
            return current_state
        last_tool_call = tool_calls[-1]
        if not is_subgraph_tool(last_tool_call["name"]):
            return current_state
        # Placeholder the subgraph streams into. A later message with the same
        # tool_call_id replaces it in append_langgraph_event.
        messages.append(
            ToolMessage(
                tool_call_id=last_tool_call["id"],
                name=last_tool_call["name"],
                artifact={} if artifact_field_name else fresh_default(),
                content="",
                additional_kwargs={"streaming": True},
            ).model_dump()
        )
        # Re-read through the container rather than keeping the local dict, so
        # that with a proxied state the returned artifact is the stored one.
        last_message = messages[-1]

    if last_message["type"] != "tool":
        return current_state

    # Do not write into the result of an unrelated tool (the AI path above
    # applies the same filter). A missing name is accepted.
    message_tool_name = last_message.get("name")
    if message_tool_name is not None and not is_subgraph_tool(message_tool_name):
        return current_state

    # model_dump() always emits an "artifact" key, possibly None.
    if last_message.get("artifact") is None:
        last_message["artifact"] = {} if artifact_field_name else fresh_default()
    artifact = last_message["artifact"]

    if not artifact_field_name:
        return artifact
    if artifact_field_name not in artifact or artifact[artifact_field_name] is None:
        artifact[artifact_field_name] = fresh_default()
    return artifact[artifact_field_name]