Source code for assistant_stream_ce.serialization.data_stream

import json
from typing import Any, AsyncGenerator, Optional

from assistant_stream_ce.assistant_stream_chunk import (
    AssistantStreamChunk,
)
from assistant_stream_ce.serialization.assistant_stream_response import (
    AssistantStreamResponse,
)
from assistant_stream_ce.serialization.stream_encoder import StreamEncoder
from assistant_stream_ce.state_proxy import StateProxy


class StateProxyJSONEncoder(json.JSONEncoder):
    """JSON encoder that unwraps ``StateProxy`` instances during encoding.

    A ``StateProxy`` reaching :meth:`default` is replaced by the result of its
    ``_get_value()`` call. Everything else is handed to the stock
    :class:`json.JSONEncoder` implementation.
    """

    def default(self, obj: Any) -> Any:
        if not isinstance(obj, StateProxy):
            # Not ours to handle: defer to the standard encoder's behaviour.
            return super().default(obj)
        return obj._get_value()
class DataStreamEncoder(StreamEncoder):
    """Encode :class:`AssistantStreamChunk` objects as data-stream lines.

    Every recognised chunk is rendered as one line of the form
    ``<prefix>:<json>`` terminated by ``"\\n"``. JSON payloads are produced
    with :class:`StateProxyJSONEncoder`, so ``StateProxy`` values inside a
    chunk are unwrapped before serialization. Chunk types not handled by
    :meth:`encode_chunk` encode to ``None`` and are skipped by
    :meth:`encode_stream`.
    """

    def __init__(self):
        # The encoder keeps no per-instance state.
        pass

    @staticmethod
    def _dumps(value: Any) -> str:
        """Serialize *value* to JSON, unwrapping ``StateProxy`` objects."""
        return json.dumps(value, cls=StateProxyJSONEncoder)

    @staticmethod
    def _parent_id(chunk: AssistantStreamChunk) -> Any:
        """Return ``chunk.parent_id`` if the attribute exists, else ``None``."""
        return getattr(chunk, "parent_id", None)

    def encode_chunk(self, chunk: AssistantStreamChunk) -> Optional[str]:
        """Encode a single chunk as a newline-terminated data-stream line.

        Args:
            chunk: The stream chunk to encode; dispatched on ``chunk.type``.

        Returns:
            The encoded line, or ``None`` when ``chunk.type`` is not one of
            the types handled here (callers such as :meth:`encode_stream`
            skip ``None`` results).
        """
        dumps = self._dumps
        chunk_type = chunk.type

        if chunk_type == "text-delta":
            parent_id = self._parent_id(chunk)
            if parent_id:
                # With a parent id, emit the "aui-" form whose object payload
                # carries the parent alongside the delta.
                payload = {"textDelta": chunk.text_delta, "parentId": parent_id}
                return f"aui-text-delta:{dumps(payload)}\n"
            return f"0:{dumps(chunk.text_delta)}\n"

        if chunk_type == "reasoning-delta":
            parent_id = self._parent_id(chunk)
            if parent_id:
                payload = {
                    "reasoningDelta": chunk.reasoning_delta,
                    "parentId": parent_id,
                }
                return f"aui-reasoning-delta:{dumps(payload)}\n"
            return f"g:{dumps(chunk.reasoning_delta)}\n"

        if chunk_type == "tool-call-begin":
            data = {"toolCallId": chunk.tool_call_id, "toolName": chunk.tool_name}
            parent_id = self._parent_id(chunk)
            if parent_id:
                data["parentId"] = parent_id
            return f"b:{dumps(data)}\n"

        if chunk_type == "tool-call-delta":
            payload = {
                "toolCallId": chunk.tool_call_id,
                "argsTextDelta": chunk.args_text_delta,
            }
            return f"c:{dumps(payload)}\n"

        if chunk_type == "tool-result":
            res = {"toolCallId": chunk.tool_call_id, "result": chunk.result}
            # Optional fields are only included when meaningful.
            if chunk.artifact is not None:
                res["artifact"] = chunk.artifact
            if chunk.is_error:
                res["isError"] = chunk.is_error
            return f"a:{dumps(res)}\n"

        if chunk_type == "data":
            # The data payload is wrapped in a single-element JSON array.
            return f"2:{dumps([chunk.data])}\n"

        if chunk_type == "error":
            return f"3:{dumps(chunk.error)}\n"

        if chunk_type == "source":
            source_data = {
                "sourceType": chunk.source_type,
                "id": chunk.id,
                "url": chunk.url,
            }
            if chunk.title is not None:
                source_data["title"] = chunk.title
            parent_id = self._parent_id(chunk)
            if parent_id:
                source_data["parentId"] = parent_id
            return f"h:{dumps(source_data)}\n"

        if chunk_type == "update-state":
            return f"aui-state:{dumps(chunk.operations)}\n"

        # Unrecognised chunk type: nothing to emit.
        return None

    def get_media_type(self) -> str:
        """Return the media type used for the encoded stream."""
        return "text/plain"

    async def encode_stream(
        self, stream: AsyncGenerator[AssistantStreamChunk, None]
    ) -> AsyncGenerator[str, None]:
        """Encode each chunk of *stream*, skipping chunks that encode to ``None``.

        Args:
            stream: Async generator producing assistant stream chunks.

        Yields:
            One encoded line per recognised chunk, in input order.
        """
        async for chunk in stream:
            encoded = self.encode_chunk(chunk)
            if encoded is None:
                continue
            yield encoded
class DataStreamResponse(AssistantStreamResponse):
    """Streaming response whose chunks are serialized by ``DataStreamEncoder``.

    Args:
        stream: Async generator of assistant stream chunks to send.
    """

    def __init__(
        self,
        stream: AsyncGenerator[AssistantStreamChunk, None],
    ):
        encoder = DataStreamEncoder()
        super().__init__(stream, encoder)