import asyncio import json from collections.abc import AsyncGenerator from datetime import datetime, timezone from enum import Enum from typing import Optional from anthropic import AsyncStream from anthropic.types.beta import ( BetaInputJSONDelta, BetaRawContentBlockDeltaEvent, BetaRawContentBlockStartEvent, BetaRawContentBlockStopEvent, BetaRawMessageDeltaEvent, BetaRawMessageStartEvent, BetaRawMessageStopEvent, BetaRawMessageStreamEvent, BetaRedactedThinkingBlock, BetaSignatureDelta, BetaTextBlock, BetaTextDelta, BetaThinkingBlock, BetaThinkingDelta, BetaToolUseBlock, ) from letta.log import get_logger from letta.schemas.letta_message import ( ApprovalRequestMessage, AssistantMessage, HiddenReasoningMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage, ) from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser from letta.server.rest_api.streaming_response import RunCancelledException from letta.server.rest_api.utils import decrement_message_uuid logger = get_logger(__name__) # TODO: These modes aren't used right now - but can be useful we do multiple sequential tool calling within one Claude message class EventMode(Enum): TEXT = "TEXT" TOOL_USE = "TOOL_USE" THINKING = "THINKING" REDACTED_THINKING = "REDACTED_THINKING" # TODO: There's a duplicate version of this in anthropic_streaming_interface class SimpleAnthropicStreamingInterface: """ A simpler version of AnthropicStreamingInterface focused on streaming assistant text and tool call deltas. Updated to support parallel tool calling by collecting completed ToolUse blocks (from content_block stop events) and exposing all finalized tool calls via get_tool_call_objects(). Notes: - We keep emitting the stream (text and tool-call deltas) as before for latency. - We no longer rely on accumulating partial JSON to build the final tool call; instead we read the finalized ToolUse input from the stop event and store it. - Multiple tool calls within a single message (parallel tool use) are collected and can be returned to the agent as a list. """ def __init__( self, requires_approval_tools: list = [], run_id: str | None = None, step_id: str | None = None, ): self.json_parser: JSONParser = PydanticJSONParser() self.run_id = run_id self.step_id = step_id # Premake IDs for database writes self.letta_message_id = Message.generate_id() self.anthropic_mode = None self.message_id = None self.accumulated_inner_thoughts = [] self.tool_call_id = None self.tool_call_name = None self.accumulated_tool_call_args = "" self.previous_parse = {} self.thinking_signature = None # usage trackers self.input_tokens = 0 self.output_tokens = 0 self.model = None # cache tracking (Anthropic-specific) self.cache_read_tokens = 0 self.cache_creation_tokens = 0 # Raw usage from provider (for transparent logging in provider trace) self.raw_usage: dict | None = None # reasoning object trackers self.reasoning_messages = [] # assistant object trackers self.assistant_messages: list[AssistantMessage] = [] # Buffer to hold tool call messages until inner thoughts are complete self.tool_call_buffer = [] self.inner_thoughts_complete = False # Buffer to handle partial XML tags across chunks self.partial_tag_buffer = "" self.requires_approval_tools = requires_approval_tools # Collected finalized tool calls (supports parallel tool use) self.collected_tool_calls: list[ToolCall] = [] # Track active tool_use blocks by stream index for parallel tool calling # { index: {"id": str, "name": str, "args_parts": list[str]} } self.active_tool_uses: dict[int, dict[str, object]] = {} # Maintain start order and indexed collection for stable ordering self._tool_use_start_order: list[int] = [] self._collected_indexed: list[tuple[int, ToolCall]] = [] def get_tool_call_objects(self) -> list[ToolCall]: """Return all finalized tool calls collected during this message (parallel supported).""" # Prefer indexed ordering if available if self._collected_indexed: return [ call for _, call in sorted( self._collected_indexed, key=lambda x: self._tool_use_start_order.index(x[0]) if x[0] in self._tool_use_start_order else x[0], ) ] return self.collected_tool_calls # This exists for legacy compatibility def get_tool_call_object(self) -> Optional[ToolCall]: tool_calls = self.get_tool_call_objects() if tool_calls: return tool_calls[0] return None def get_usage_statistics(self) -> "LettaUsageStatistics": """Extract usage statistics from accumulated streaming data. Returns: LettaUsageStatistics with token counts from the stream. """ from letta.schemas.usage import LettaUsageStatistics # Anthropic: input_tokens is NON-cached only, must add cache tokens for total actual_input_tokens = (self.input_tokens or 0) + (self.cache_read_tokens or 0) + (self.cache_creation_tokens or 0) return LettaUsageStatistics( prompt_tokens=actual_input_tokens, completion_tokens=self.output_tokens or 0, total_tokens=actual_input_tokens + (self.output_tokens or 0), cached_input_tokens=self.cache_read_tokens if self.cache_read_tokens else None, cache_write_tokens=self.cache_creation_tokens if self.cache_creation_tokens else None, reasoning_tokens=None, # Anthropic doesn't report reasoning tokens separately ) def get_reasoning_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]: def _process_group( group: list[ReasoningMessage | HiddenReasoningMessage | AssistantMessage], group_type: str, ) -> TextContent | ReasoningContent | RedactedReasoningContent: if group_type == "reasoning": reasoning_text = "".join(chunk.reasoning for chunk in group).strip() is_native = any(chunk.source == "reasoner_model" for chunk in group) signature = next((chunk.signature for chunk in group if chunk.signature is not None), None) if is_native: return ReasoningContent(is_native=is_native, reasoning=reasoning_text, signature=signature) else: return TextContent(text=reasoning_text) elif group_type == "redacted": redacted_text = "".join(chunk.hidden_reasoning for chunk in group if chunk.hidden_reasoning is not None) return RedactedReasoningContent(data=redacted_text) elif group_type == "text": parts: list[str] = [] for chunk in group: if isinstance(chunk.content, list): parts.append("".join([c.text for c in chunk.content])) else: parts.append(chunk.content) return TextContent(text="".join(parts)) else: raise ValueError("Unexpected group type") merged = [] current_group = [] current_group_type = None # "reasoning" or "redacted" for msg in self.reasoning_messages: # Determine the type of the current message if isinstance(msg, HiddenReasoningMessage): msg_type = "redacted" elif isinstance(msg, ReasoningMessage): msg_type = "reasoning" elif isinstance(msg, AssistantMessage): msg_type = "text" else: raise ValueError("Unexpected message type") # Initialize group type if not set if current_group_type is None: current_group_type = msg_type # If the type changes, process the current group if msg_type != current_group_type: merged.append(_process_group(current_group, current_group_type)) current_group = [] current_group_type = msg_type current_group.append(msg) # Process the final group, if any. if current_group: merged.append(_process_group(current_group, current_group_type)) return merged def get_content(self) -> list[TextContent | ReasoningContent | RedactedReasoningContent]: return self.get_reasoning_content() async def process( self, stream: AsyncStream[BetaRawMessageStreamEvent], ttft_span: Optional["Span"] = None, ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: prev_message_type = None message_index = 0 event = None try: async with stream: async for event in stream: try: async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: if prev_message_type != None: message_index += 1 prev_message_type = new_message_type # print(f"Yielding message: {message}") yield message except (asyncio.CancelledError, RunCancelledException) as e: import traceback logger.info("Cancelled stream attempt but overriding (%s) %s: %s", type(e).__name__, e, traceback.format_exc()) async for message in self._process_event(event, ttft_span, prev_message_type, message_index): new_message_type = message.message_type if new_message_type != prev_message_type: if prev_message_type != None: message_index += 1 prev_message_type = new_message_type yield message # Don't raise the exception here continue except Exception as e: import traceback logger.error("Error processing stream: %s\n%s", e, traceback.format_exc()) if ttft_span: ttft_span.add_event( name="stop_reason", attributes={"stop_reason": StopReasonType.error.value, "error": str(e), "stacktrace": traceback.format_exc()}, ) yield LettaStopReason(stop_reason=StopReasonType.error) # Transform Anthropic errors into our custom error types for consistent handling from letta.llm_api.anthropic_client import AnthropicClient client = AnthropicClient() transformed_error = client.handle_llm_error(e) raise transformed_error finally: logger.info("AnthropicStreamingInterface: Stream processing complete.") async def _process_event( self, event: BetaRawMessageStreamEvent, ttft_span: Optional["Span"] = None, prev_message_type: Optional[str] = None, message_index: int = 0, ) -> AsyncGenerator[LettaMessage | LettaStopReason, None]: """Process a single event from the Anthropic stream and yield any resulting messages. Args: event: The event to process Yields: Messages generated from processing this event """ if isinstance(event, BetaRawContentBlockStartEvent): content = event.content_block if isinstance(content, BetaTextBlock): self.anthropic_mode = EventMode.TEXT # TODO: Can capture citations, etc. elif isinstance(content, BetaToolUseBlock): # New tool_use block started at this index self.anthropic_mode = EventMode.TOOL_USE self.active_tool_uses[event.index] = {"id": content.id, "name": content.name, "args_parts": []} if event.index not in self._tool_use_start_order: self._tool_use_start_order.append(event.index) # Emit an initial tool call delta for this new block name = content.name call_id = content.id # Initialize arguments from the start event's input (often {}) to avoid undefined in UIs if name in self.requires_approval_tools: tool_call_msg = ApprovalRequestMessage( id=decrement_message_uuid(self.letta_message_id), # Do not emit placeholder arguments here to avoid UI duplicates tool_call=ToolCallDelta(name=name, tool_call_id=call_id), date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1), run_id=self.run_id, step_id=self.step_id, ) else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 tool_call_msg = ToolCallMessage( id=self.letta_message_id, # Do not emit placeholder arguments here to avoid UI duplicates tool_call=ToolCallDelta(name=name, tool_call_id=call_id), tool_calls=ToolCallDelta(name=name, tool_call_id=call_id), date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, step_id=self.step_id, ) prev_message_type = tool_call_msg.message_type yield tool_call_msg elif isinstance(content, BetaThinkingBlock): self.anthropic_mode = EventMode.THINKING # TODO: Can capture signature, etc. elif isinstance(content, BetaRedactedThinkingBlock): self.anthropic_mode = EventMode.REDACTED_THINKING if prev_message_type and prev_message_type != "hidden_reasoning_message": message_index += 1 hidden_reasoning_message = HiddenReasoningMessage( id=self.letta_message_id, state="redacted", hidden_reasoning=content.data, date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, step_id=self.step_id, ) self.reasoning_messages.append(hidden_reasoning_message) prev_message_type = hidden_reasoning_message.message_type yield hidden_reasoning_message elif isinstance(event, BetaRawContentBlockDeltaEvent): delta = event.delta if isinstance(delta, BetaTextDelta): # Safety check if not self.anthropic_mode == EventMode.TEXT: raise RuntimeError(f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}") if prev_message_type and prev_message_type != "assistant_message": message_index += 1 assistant_msg = AssistantMessage( id=self.letta_message_id, content=delta.text, date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, step_id=self.step_id, ) # self.assistant_messages.append(assistant_msg) self.reasoning_messages.append(assistant_msg) prev_message_type = assistant_msg.message_type yield assistant_msg elif isinstance(delta, BetaInputJSONDelta): # Append partial JSON for the specific tool_use block at this index if not self.anthropic_mode == EventMode.TOOL_USE: raise RuntimeError( f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}" ) ctx = self.active_tool_uses.get(event.index) if ctx is None: # Defensive: initialize if missing self.active_tool_uses[event.index] = { "id": self.tool_call_id or "", "name": self.tool_call_name or "", "args_parts": [], } ctx = self.active_tool_uses[event.index] # Append only non-empty partials if delta.partial_json: # Append fragment to args_parts to avoid O(n^2) string growth args_parts = ctx.get("args_parts") if isinstance(ctx.get("args_parts"), list) else None if args_parts is None: args_parts = [] ctx["args_parts"] = args_parts args_parts.append(delta.partial_json) else: # Skip streaming a no-op delta to prevent duplicate placeholders in UI return name = ctx.get("name") call_id = ctx.get("id") if name in self.requires_approval_tools: tool_call_msg = ApprovalRequestMessage( id=decrement_message_uuid(self.letta_message_id), tool_call=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json), date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(decrement_message_uuid(self.letta_message_id), -1), run_id=self.run_id, step_id=self.step_id, ) else: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 tool_call_msg = ToolCallMessage( id=self.letta_message_id, tool_call=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json), tool_calls=ToolCallDelta(name=name, tool_call_id=call_id, arguments=delta.partial_json), date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, step_id=self.step_id, ) prev_message_type = tool_call_msg.message_type yield tool_call_msg elif isinstance(delta, BetaThinkingDelta): # Safety check if not self.anthropic_mode == EventMode.THINKING: raise RuntimeError( f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}" ) # Only emit reasoning message if we have actual content if delta.thinking and delta.thinking.strip(): if prev_message_type and prev_message_type != "reasoning_message": message_index += 1 reasoning_message = ReasoningMessage( id=self.letta_message_id, source="reasoner_model", reasoning=delta.thinking, signature=self.thinking_signature, date=datetime.now(timezone.utc).isoformat(), otid=Message.generate_otid_from_id(self.letta_message_id, message_index), run_id=self.run_id, step_id=self.step_id, ) self.reasoning_messages.append(reasoning_message) prev_message_type = reasoning_message.message_type yield reasoning_message elif isinstance(delta, BetaSignatureDelta): # Safety check if not self.anthropic_mode == EventMode.THINKING: raise RuntimeError( f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}" ) # Store signature but don't emit empty reasoning message # Signature will be attached when actual thinking content arrives self.thinking_signature = delta.signature # Update the last reasoning message with the signature so it gets persisted if self.reasoning_messages: last_msg = self.reasoning_messages[-1] if isinstance(last_msg, ReasoningMessage): last_msg.signature = delta.signature elif isinstance(event, BetaRawMessageStartEvent): self.message_id = event.message.id self.input_tokens += event.message.usage.input_tokens self.output_tokens += event.message.usage.output_tokens self.model = event.message.model # Capture cache data if available usage = event.message.usage if hasattr(usage, "cache_read_input_tokens") and usage.cache_read_input_tokens: self.cache_read_tokens += usage.cache_read_input_tokens if hasattr(usage, "cache_creation_input_tokens") and usage.cache_creation_input_tokens: self.cache_creation_tokens += usage.cache_creation_input_tokens # Store raw usage for transparent provider trace logging try: self.raw_usage = usage.model_dump(exclude_none=True) except Exception as e: logger.error(f"Failed to capture raw_usage from Anthropic: {e}") self.raw_usage = None elif isinstance(event, BetaRawMessageDeltaEvent): # Per Anthropic docs: "The token counts shown in the usage field of the # message_delta event are *cumulative*." So we assign, not accumulate. self.output_tokens = event.usage.output_tokens elif isinstance(event, BetaRawMessageStopEvent): # Update raw_usage with final accumulated values for accurate provider trace logging if self.raw_usage: self.raw_usage["input_tokens"] = self.input_tokens self.raw_usage["output_tokens"] = self.output_tokens if self.cache_read_tokens: self.raw_usage["cache_read_input_tokens"] = self.cache_read_tokens if self.cache_creation_tokens: self.raw_usage["cache_creation_input_tokens"] = self.cache_creation_tokens elif isinstance(event, BetaRawContentBlockStopEvent): # Finalize the tool_use block at this index using accumulated deltas ctx = self.active_tool_uses.pop(event.index, None) if ctx is not None and ctx.get("id") and ctx.get("name") is not None: parts = ctx.get("args_parts") if isinstance(ctx.get("args_parts"), list) else None raw_args = "".join(parts) if parts else "" try: # Prefer strict JSON load, fallback to permissive parser tool_input = json.loads(raw_args) if raw_args else {} except json.JSONDecodeError: try: tool_input = self.json_parser.parse(raw_args) if raw_args else {} except Exception: tool_input = {} arguments = json.dumps(tool_input) finalized = ToolCall(id=ctx["id"], function=FunctionCall(arguments=arguments, name=ctx["name"])) # Keep both raw list and indexed list for compatibility self.collected_tool_calls.append(finalized) self._collected_indexed.append((event.index, finalized)) # Reset mode when a content block ends self.anthropic_mode = None