Source code for assistant_stream_ce.modules.tool_call

import asyncio
from typing import Any, AsyncGenerator
from assistant_stream_ce.assistant_stream_chunk import (
    AssistantStreamChunk,
    ToolCallBeginChunk,
    ToolCallDeltaChunk,
    ToolResultChunk,
)
import string
import random


def generate_openai_style_tool_call_id():
    """Build a random tool-call id in the ``call_<24 alphanumerics>`` shape.

    The 24 suffix characters are drawn with replacement from ASCII letters
    and digits via the ``random`` module (not ``secrets``), so the value is
    an identifier, not a security token.
    """
    pool = f"{string.ascii_letters}{string.digits}"
    return "call_" + "".join(random.choices(pool, k=24))
[docs] class ToolCallController: def __init__(self, queue, tool_name: str, tool_call_id: str, parent_id: str = None): self.tool_name = tool_name self.tool_call_id = tool_call_id self.queue = queue self.loop = asyncio.get_running_loop() begin_chunk = ToolCallBeginChunk( tool_call_id=self.tool_call_id, tool_name=self.tool_name, parent_id=parent_id, ) self.queue.put_nowait(begin_chunk)
[docs] def append_args_text(self, args_text_delta: str) -> None: """Append an args text delta to the stream.""" chunk = ToolCallDeltaChunk( tool_call_id=self.tool_call_id, args_text_delta=args_text_delta, ) self.loop.call_soon_threadsafe(self.queue.put_nowait, chunk)
[docs] def set_result(self, result: Any) -> None: """ Set the result of the tool call. Deprecated: Use set_response() instead. """ import warnings warnings.warn( "set_result() is deprecated. Use set_response() instead.", DeprecationWarning, stacklevel=2, ) return self.set_response(result)
[docs] def set_response( self, result: Any, *, artifact: Any | None = None, is_error: bool = False ) -> None: """Set the result of the tool call.""" chunk = ToolResultChunk( tool_call_id=self.tool_call_id, result=result, artifact=artifact, is_error=is_error, ) self.loop.call_soon_threadsafe(self.queue.put_nowait, chunk) self.close()
[docs] def close(self) -> None: """Close the stream.""" self.loop.call_soon_threadsafe(self.queue.put_nowait, None)
async def create_tool_call(
    tool_name: str,
    tool_call_id: str,
    parent_id: str | None = None,
) -> tuple[AsyncGenerator[AssistantStreamChunk, None], ToolCallController]:
    """Create a chunk stream for one tool call together with its controller.

    Must be awaited inside a running event loop, because constructing the
    ``ToolCallController`` captures the running loop.

    Args:
        tool_name: Name of the tool being called.
        tool_call_id: Identifier attached to every chunk of this call.
        parent_id: Optional parent identifier forwarded to the controller.

    Returns:
        ``(stream, controller)``. ``stream`` is an async generator that yields
        queued chunks until it dequeues the ``None`` sentinel.
    """
    queue: asyncio.Queue = asyncio.Queue()
    controller = ToolCallController(queue, tool_name, tool_call_id, parent_id)

    async def stream() -> AsyncGenerator[AssistantStreamChunk, None]:
        while True:
            chunk = await controller.queue.get()
            try:
                if chunk is None:
                    return
                yield chunk
            finally:
                # Pair every get() with task_done(), including the sentinel and
                # the item pending when the generator is closed early, so
                # queue.join() cannot hang on unfinished work.
                controller.queue.task_done()

    return stream(), controller