Source code for assistant_stream_ce.state_manager
import asyncio
from typing import Any, Callable, Dict, List
from assistant_stream_ce.assistant_stream_chunk import (
ObjectStreamOperation,
UpdateStateChunk,
)
from assistant_stream_ce.state_proxy import StateProxy
[docs]
class StateManager:
"""Manages state operations with efficient batching and local updates."""
def __init__(
self,
put_chunk_callback: Callable[[UpdateStateChunk], None],
state_data: Any | None = None,
):
"""Initialize with callback for sending state updates."""
self._state_data = state_data
self._pending_operations = []
self._update_scheduled = False
self._put_chunk_callback = put_chunk_callback
self._loop = asyncio.get_running_loop()
self._state_proxy = StateProxy(self, [])
@property
def state(self) -> Any:
"""Access the state proxy object for making state updates.
If state is None, returns None directly instead of a proxy.
Otherwise returns a proxy object for the state.
"""
if self._state_data is None:
return None
return self._state_proxy
@property
def state_data(self) -> Dict[str, Any]:
"""Current state data."""
return self._state_data
[docs]
def add_operations(self, operations: List[ObjectStreamOperation]) -> None:
"""Add operations to pending batch and apply locally."""
# Apply to local state immediately
for operation in operations:
self._apply_operation_to_local_state(operation)
# Add to pending operations
self._pending_operations.extend(operations)
# Schedule batch update if needed
if not self._update_scheduled:
self._update_scheduled = True
self._loop.call_soon_threadsafe(self._flush_updates)
def _flush_updates(self) -> None:
"""Send pending operations as a batch."""
if self._pending_operations:
operations_to_send = self._pending_operations.copy()
self._pending_operations.clear()
self._put_chunk_callback(UpdateStateChunk(operations=operations_to_send))
self._update_scheduled = False
[docs]
def flush(self) -> None:
"""Explicitly flush any pending operations.
This should be called before the run completes to ensure all state updates are sent.
"""
if self._pending_operations:
self._flush_updates()
def _apply_operation_to_local_state(self, operation: ObjectStreamOperation) -> None:
"""Apply operation to local state."""
op_type = operation["type"]
if op_type == "set":
self._update_path(operation["path"], lambda _: operation["value"])
elif op_type == "append-text":
def append_text(current):
if not isinstance(current, str):
path_str = ", ".join(operation["path"])
raise TypeError(f"Expected string at path [{path_str}]")
return current + operation["value"]
self._update_path(operation["path"], append_text)
else:
raise TypeError(f"Invalid operation type: {op_type}")
[docs]
def get_value_at_path(self, path: List[str]) -> Any:
"""Get value at path, raising KeyError for invalid paths."""
if not path:
return self._state_data
# If state is None, we can't navigate further
if self._state_data is None:
raise KeyError(path[0] if path else "")
current = self._state_data
for key in path:
try:
if isinstance(current, list):
idx = int(key)
if idx < 0 or idx >= len(current):
raise KeyError(key)
current = current[idx]
elif isinstance(current, dict):
current = current[key]
else:
raise KeyError(key)
except (ValueError, KeyError, IndexError):
raise KeyError(key)
return current
def _update_path(self, path: List[str], updater: Callable[[Any], Any]) -> None:
"""Update value at path without creating parent objects."""
# Handle empty path (update root state)
if not path:
self._state_data = updater(self._state_data)
return
# Initialize state as empty object if it's null
if self._state_data is None:
self._state_data = {}
if not isinstance(self._state_data, (dict, list)):
raise KeyError(f"Invalid path: [{', '.join(path)}]")
key, *rest = path
# Handle list access
if isinstance(self._state_data, list):
try:
idx = int(key)
if idx < 0 or idx > len(self._state_data):
raise KeyError(key)
if not rest:
# For direct update
if idx == len(self._state_data): # Append case
value = updater(None)
if value is not None:
self._state_data.append(value)
else: # Update existing element
self._state_data[idx] = updater(self._state_data[idx])
else:
# For nested update
if idx == len(self._state_data):
raise KeyError(key)
# Create a copy for the nested update
next_state = self._state_data.copy()
# Create a temporary manager for the nested path
temp_manager = type(self)(lambda _: None)
temp_manager._state_data = next_state[idx]
# Update nested path
temp_manager._update_path(rest, updater)
next_state[idx] = temp_manager._state_data
self._state_data = next_state
except ValueError:
raise KeyError(key)
else: # Handle dict access
if not rest:
# For direct update
if key not in self._state_data and updater(None) is None:
return
self._state_data[key] = updater(self._state_data.get(key))
else:
# For nested update
if key not in self._state_data:
raise KeyError(key)
# Create a copy for the nested update
next_state = dict(self._state_data)
# Create a temporary manager for the nested path
temp_manager = type(self)(lambda _: None)
temp_manager._state_data = next_state[key]
# Update nested path
temp_manager._update_path(rest, updater)
next_state[key] = temp_manager._state_data
self._state_data = next_state