From 3d94adbac381d4d00c171d69dc57e15d0e57b246 Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Wed, 30 Apr 2025 18:07:42 -0700 Subject: [PATCH] fix: user messages on new agent loop are not processed in ADE (includes new json parser) (#1934) --- letta/agent.py | 5 +- letta/agents/letta_agent.py | 72 ++-- letta/functions/helpers.py | 10 +- letta/helpers/message_helper.py | 25 +- .../anthropic_streaming_interface.py | 333 +++++++++--------- ...ai_chat_completions_streaming_interface.py | 2 +- .../rest_api/chat_completions_interface.py | 2 +- letta/server/rest_api/interface.py | 18 +- ...timistic_json_parser.py => json_parser.py} | 88 +++-- letta/server/rest_api/routers/v1/agents.py | 2 +- letta/server/rest_api/utils.py | 29 +- tests/test_optimistic_json_parser.py | 2 +- 12 files changed, 319 insertions(+), 269 deletions(-) rename letta/server/rest_api/{optimistic_json_parser.py => json_parser.py} (70%) diff --git a/letta/agent.py b/letta/agent.py index 84e8d0b4..01587bed 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -28,7 +28,7 @@ from letta.helpers import ToolRulesSolver from letta.helpers.composio_helpers import get_composio_api_key from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.json_helpers import json_dumps, json_loads -from letta.helpers.message_helper import prepare_input_message_create +from letta.helpers.message_helper import convert_message_creates_to_messages from letta.interface import AgentInterface from letta.llm_api.helpers import calculate_summarizer_cutoff, get_token_counts_for_messages, is_context_overflow_error from letta.llm_api.llm_api_tools import create @@ -726,8 +726,7 @@ class Agent(BaseAgent): self.tool_rules_solver.clear_tool_history() # Convert MessageCreate objects to Message objects - message_objects = [prepare_input_message_create(m, self.agent_state.id, True, True) for m in input_messages] - next_input_messages = message_objects + next_input_messages = convert_message_creates_to_messages(input_messages, self.agent_state.id) counter = 0 total_usage = UsageStatistics() step_count = 0 diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 2ac83ec4..79a6f7d0 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -109,7 +109,7 @@ class LettaAgent(BaseAgent): ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) llm_client = LLMClient.create( - llm_config=agent_state.llm_config, + provider=agent_state.llm_config.model_endpoint_type, put_inner_thoughts_first=True, ) @@ -125,7 +125,7 @@ class LettaAgent(BaseAgent): # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED interface = AnthropicStreamingInterface( - use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=llm_client.llm_config.put_inner_thoughts_in_kwargs + use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs ) async for chunk in interface.process(stream): yield f"data: {chunk.model_dump_json()}\n\n" @@ -275,45 +275,49 @@ class LettaAgent(BaseAgent): return persisted_messages, continue_stepping def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: - self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) + try: + self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor) - # TODO: This is a pretty brittle pattern established all over our 
code, need to get rid of this - curr_system_message = in_context_messages[0] - curr_memory_str = agent_state.memory.compile() - curr_system_message_text = curr_system_message.content[0].text - if curr_memory_str in curr_system_message_text: - # NOTE: could this cause issues if a block is removed? (substring match would still work) - logger.debug( - f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" - ) - return in_context_messages + # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this + curr_system_message = in_context_messages[0] + curr_memory_str = agent_state.memory.compile() + curr_system_message_text = curr_system_message.content[0].text + if curr_memory_str in curr_system_message_text: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + logger.debug( + f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" + ) + return in_context_messages - memory_edit_timestamp = get_utc_time() + memory_edit_timestamp = get_utc_time() - num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id) + num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) - new_system_message_str = compile_system_message( - system_prompt=agent_state.system, - in_context_memory=agent_state.memory, - in_context_memory_last_edit=memory_edit_timestamp, - previous_message_count=num_messages, - archival_memory_size=num_archival_memories, - ) - - diff = united_diff(curr_system_message_text, new_system_message_str) - if len(diff) > 0: - logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - - new_system_message = self.message_manager.update_message_by_id( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, ) - # Skip pulling down the agent's memory again to save on a db call - return [new_system_message] + in_context_messages[1:] + diff = united_diff(curr_system_message_text, new_system_message_str) + if len(diff) > 0: + logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}") - else: - return in_context_messages + new_system_message = self.message_manager.update_message_by_id( + curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + ) + + # Skip pulling down the agent's memory again to save on a db call + return [new_system_message] + in_context_messages[1:] + + else: + return in_context_messages + except: + logger.exception(f"Failed to rebuild memory for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name})") + raise @trace_method async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]: diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 705046c4..9797796d 100644 --- a/letta/functions/helpers.py +++ b/letta/functions/helpers.py @@ -39,10 +39,10 @@ def 
generate_langchain_tool_wrapper( ) -> tuple[str, str]: tool_name = tool.__class__.__name__ import_statement = f"from langchain_community.tools import {tool_name}" - extra_module_imports = generate_import_code(additional_imports_module_attr_map) + extra_module_imports = _generate_import_code(additional_imports_module_attr_map) # Safety check that user has passed in all required imports: - assert_all_classes_are_imported(tool, additional_imports_module_attr_map) + _assert_all_classes_are_imported(tool, additional_imports_module_attr_map) tool_instantiation = f"tool = {generate_imported_tool_instantiation_call_str(tool)}" run_call = f"return tool._run(**kwargs)" @@ -71,7 +71,7 @@ def _assert_code_gen_compilable(code_str): print(f"Syntax error in code: {e}") -def assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None: +def _assert_all_classes_are_imported(tool: Union["LangChainBaseTool"], additional_imports_module_attr_map: dict[str, str]) -> None: # Safety check that user has passed in all required imports: tool_name = tool.__class__.__name__ current_class_imports = {tool_name} @@ -193,7 +193,7 @@ def _is_base_model(obj: Any): return isinstance(obj, BaseModel) -def generate_import_code(module_attr_map: Optional[dict]): +def _generate_import_code(module_attr_map: Optional[dict]): if not module_attr_map: return "" @@ -295,7 +295,7 @@ async def _send_message_to_agent_no_stream( return LettaResponse(messages=final_messages, usage=usage_stats) -async def async_send_message_with_retries( +async def _async_send_message_with_retries( server: "SyncServer", sender_agent: "Agent", target_agent_id: str, diff --git a/letta/helpers/message_helper.py b/letta/helpers/message_helper.py index 41d2b8f6..be05b85a 100644 --- a/letta/helpers/message_helper.py +++ b/letta/helpers/message_helper.py @@ -4,7 +4,24 @@ from letta.schemas.letta_message_content import TextContent from letta.schemas.message import Message, MessageCreate -def prepare_input_message_create( +def convert_message_creates_to_messages( + messages: list[MessageCreate], + agent_id: str, + wrap_user_message: bool = True, + wrap_system_message: bool = True, +) -> list[Message]: + return [ + _convert_message_create_to_message( + message=message, + agent_id=agent_id, + wrap_user_message=wrap_user_message, + wrap_system_message=wrap_system_message, + ) + for message in messages + ] + + +def _convert_message_create_to_message( message: MessageCreate, agent_id: str, wrap_user_message: bool = True, @@ -23,12 +40,12 @@ def prepare_input_message_create( raise ValueError("Message content is empty or invalid") # Apply wrapping if needed - if message.role == MessageRole.user and wrap_user_message: + if message.role not in {MessageRole.user, MessageRole.system}: + raise ValueError(f"Invalid message role: {message.role}") + elif message.role == MessageRole.user and wrap_user_message: message_content = system.package_user_message(user_message=message_content) elif message.role == MessageRole.system and wrap_system_message: message_content = system.package_system_message(system_message=message_content) - elif message.role not in {MessageRole.user, MessageRole.system}: - raise ValueError(f"Invalid message role: {message.role}") return Message( agent_id=agent_id, diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 84178932..974673f8 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ 
b/letta/interfaces/anthropic_streaming_interface.py @@ -35,7 +35,7 @@ from letta.schemas.letta_message import ( from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import JSONParser, PydanticJSONParser logger = get_logger(__name__) @@ -56,7 +56,7 @@ class AnthropicStreamingInterface: """ def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False): - self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() + self.json_parser: JSONParser = PydanticJSONParser() self.use_assistant_message = use_assistant_message # Premake IDs for database writes @@ -68,7 +68,7 @@ class AnthropicStreamingInterface: self.accumulated_inner_thoughts = [] self.tool_call_id = None self.tool_call_name = None - self.accumulated_tool_call_args = [] + self.accumulated_tool_call_args = "" self.previous_parse = {} # usage trackers @@ -85,193 +85,200 @@ class AnthropicStreamingInterface: def get_tool_call_object(self) -> ToolCall: """Useful for agent loop""" - return ToolCall( - id=self.tool_call_id, function=FunctionCall(arguments="".join(self.accumulated_tool_call_args), name=self.tool_call_name) - ) + return ToolCall(id=self.tool_call_id, function=FunctionCall(arguments=self.accumulated_tool_call_args, name=self.tool_call_name)) def _check_inner_thoughts_complete(self, combined_args: str) -> bool: """ Check if inner thoughts are complete in the current tool call arguments by looking for a closing quote after the inner_thoughts field """ - if not self.put_inner_thoughts_in_kwarg: - # None of the things should have inner thoughts in kwargs - return True - else: - parsed = self.optimistic_json_parser.parse(combined_args) - # TODO: This will break on tools with 0 input - return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys() + try: + if not self.put_inner_thoughts_in_kwarg: + # None of the things should have inner thoughts in kwargs + return True + else: + parsed = self.json_parser.parse(combined_args) + # TODO: This will break on tools with 0 input + return len(parsed.keys()) > 1 and INNER_THOUGHTS_KWARG in parsed.keys() + except Exception as e: + logger.error("Error checking inner thoughts: %s", e) + raise async def process(self, stream: AsyncStream[BetaRawMessageStreamEvent]) -> AsyncGenerator[LettaMessage, None]: - async with stream: - async for event in stream: - # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock - if isinstance(event, BetaRawContentBlockStartEvent): - content = event.content_block + try: + async with stream: + async for event in stream: + # TODO: Support BetaThinkingBlock, BetaRedactedThinkingBlock + if isinstance(event, BetaRawContentBlockStartEvent): + content = event.content_block - if isinstance(content, BetaTextBlock): - self.anthropic_mode = EventMode.TEXT - # TODO: Can capture citations, etc. - elif isinstance(content, BetaToolUseBlock): - self.anthropic_mode = EventMode.TOOL_USE - self.tool_call_id = content.id - self.tool_call_name = content.name - self.inner_thoughts_complete = False + if isinstance(content, BetaTextBlock): + self.anthropic_mode = EventMode.TEXT + # TODO: Can capture citations, etc. 
+                        elif isinstance(content, BetaToolUseBlock):
+                            self.anthropic_mode = EventMode.TOOL_USE
+                            self.tool_call_id = content.id
+                            self.tool_call_name = content.name
+                            self.inner_thoughts_complete = False

-                        if not self.use_assistant_message:
-                            # Buffer the initial tool call message instead of yielding immediately
-                            tool_call_msg = ToolCallMessage(
-                                id=self.letta_tool_message_id,
-                                tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
+                            if not self.use_assistant_message:
+                                # Buffer the initial tool call message instead of yielding immediately
+                                tool_call_msg = ToolCallMessage(
+                                    id=self.letta_tool_message_id,
+                                    tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
+                                    date=datetime.now(timezone.utc).isoformat(),
+                                )
+                                self.tool_call_buffer.append(tool_call_msg)
+                        elif isinstance(content, BetaThinkingBlock):
+                            self.anthropic_mode = EventMode.THINKING
+                            # TODO: Can capture signature, etc.
+                        elif isinstance(content, BetaRedactedThinkingBlock):
+                            self.anthropic_mode = EventMode.REDACTED_THINKING
+
+                            hidden_reasoning_message = HiddenReasoningMessage(
+                                id=self.letta_assistant_message_id,
+                                state="redacted",
+                                hidden_reasoning=content.data,
                                 date=datetime.now(timezone.utc).isoformat(),
                             )
-                            self.tool_call_buffer.append(tool_call_msg)
-                    elif isinstance(content, BetaThinkingBlock):
-                        self.anthropic_mode = EventMode.THINKING
-                        # TODO: Can capture signature, etc.
-                    elif isinstance(content, BetaRedactedThinkingBlock):
-                        self.anthropic_mode = EventMode.REDACTED_THINKING
+                            self.reasoning_messages.append(hidden_reasoning_message)
+                            yield hidden_reasoning_message

-                        hidden_reasoning_message = HiddenReasoningMessage(
-                            id=self.letta_assistant_message_id,
-                            state="redacted",
-                            hidden_reasoning=content.data,
-                            date=datetime.now(timezone.utc).isoformat(),
-                        )
-                        self.reasoning_messages.append(hidden_reasoning_message)
-                        yield hidden_reasoning_message
+                    elif isinstance(event, BetaRawContentBlockDeltaEvent):
+                        delta = event.delta

-                elif isinstance(event, BetaRawContentBlockDeltaEvent):
-                    delta = event.delta
+                        if isinstance(delta, BetaTextDelta):
+                            # Safety check
+                            if not self.anthropic_mode == EventMode.TEXT:
+                                raise RuntimeError(
+                                    f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
+                                )

-                    if isinstance(delta, BetaTextDelta):
-                        # Safety check
-                        if not self.anthropic_mode == EventMode.TEXT:
-                            raise RuntimeError(
-                                f"Streaming integrity failed - received BetaTextDelta object while not in TEXT EventMode: {delta}"
-                            )
+                            # TODO: Strip out more robustly, this is pretty hacky lol
+                            delta.text = delta.text.replace("<thinking>", "")
+                            self.accumulated_inner_thoughts.append(delta.text)

-                        # TODO: Strip out more robustly, this is pretty hacky lol
-                        delta.text = delta.text.replace("<thinking>", "")
-                        self.accumulated_inner_thoughts.append(delta.text)
-
-                        reasoning_message = ReasoningMessage(
-                            id=self.letta_assistant_message_id,
-                            reasoning=self.accumulated_inner_thoughts[-1],
-                            date=datetime.now(timezone.utc).isoformat(),
-                        )
-                        self.reasoning_messages.append(reasoning_message)
-                        yield reasoning_message
-
-                    elif isinstance(delta, BetaInputJSONDelta):
-                        if not self.anthropic_mode == EventMode.TOOL_USE:
-                            raise RuntimeError(
-                                f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}"
-                            )
-
-                        self.accumulated_tool_call_args.append(delta.partial_json)
-                        combined_args = "".join(self.accumulated_tool_call_args)
-                        current_parsed = self.optimistic_json_parser.parse(combined_args)
-
-                        # Start detecting a 
difference in inner thoughts - previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "") - current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "") - inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :] - - if inner_thoughts_diff: reasoning_message = ReasoningMessage( id=self.letta_assistant_message_id, - reasoning=inner_thoughts_diff, + reasoning=self.accumulated_inner_thoughts[-1], date=datetime.now(timezone.utc).isoformat(), ) self.reasoning_messages.append(reasoning_message) yield reasoning_message - # Check if inner thoughts are complete - if so, flush the buffer - if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(combined_args): - self.inner_thoughts_complete = True - # Flush all buffered tool call messages + elif isinstance(delta, BetaInputJSONDelta): + if not self.anthropic_mode == EventMode.TOOL_USE: + raise RuntimeError( + f"Streaming integrity failed - received BetaInputJSONDelta object while not in TOOL_USE EventMode: {delta}" + ) + + self.accumulated_tool_call_args += delta.partial_json + current_parsed = self.json_parser.parse(self.accumulated_tool_call_args) + + # Start detecting a difference in inner thoughts + previous_inner_thoughts = self.previous_parse.get(INNER_THOUGHTS_KWARG, "") + current_inner_thoughts = current_parsed.get(INNER_THOUGHTS_KWARG, "") + inner_thoughts_diff = current_inner_thoughts[len(previous_inner_thoughts) :] + + if inner_thoughts_diff: + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + reasoning=inner_thoughts_diff, + date=datetime.now(timezone.utc).isoformat(), + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + + # Check if inner thoughts are complete - if so, flush the buffer + if not self.inner_thoughts_complete and self._check_inner_thoughts_complete(self.accumulated_tool_call_args): + self.inner_thoughts_complete = True + # Flush all buffered tool call messages + for buffered_msg in self.tool_call_buffer: + yield buffered_msg + self.tool_call_buffer = [] + + # Start detecting special case of "send_message" + if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message: + previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "") + current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "") + send_message_diff = current_send_message[len(previous_send_message) :] + + # Only stream out if it's not an empty string + if send_message_diff: + yield AssistantMessage( + id=self.letta_assistant_message_id, + content=[TextContent(text=send_message_diff)], + date=datetime.now(timezone.utc).isoformat(), + ) + else: + # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + tool_call=ToolCallDelta(arguments=delta.partial_json), + date=datetime.now(timezone.utc).isoformat(), + ) + + if self.inner_thoughts_complete: + yield tool_call_msg + else: + self.tool_call_buffer.append(tool_call_msg) + + # Set previous parse + self.previous_parse = current_parsed + elif isinstance(delta, BetaThinkingDelta): + # Safety check + if not self.anthropic_mode == EventMode.THINKING: + raise RuntimeError( + f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}" + ) + + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + source="reasoner_model", + reasoning=delta.thinking, + 
date=datetime.now(timezone.utc).isoformat(), + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + elif isinstance(delta, BetaSignatureDelta): + # Safety check + if not self.anthropic_mode == EventMode.THINKING: + raise RuntimeError( + f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}" + ) + + reasoning_message = ReasoningMessage( + id=self.letta_assistant_message_id, + source="reasoner_model", + reasoning="", + date=datetime.now(timezone.utc).isoformat(), + signature=delta.signature, + ) + self.reasoning_messages.append(reasoning_message) + yield reasoning_message + elif isinstance(event, BetaRawMessageStartEvent): + self.message_id = event.message.id + self.input_tokens += event.message.usage.input_tokens + self.output_tokens += event.message.usage.output_tokens + elif isinstance(event, BetaRawMessageDeltaEvent): + self.output_tokens += event.usage.output_tokens + elif isinstance(event, BetaRawMessageStopEvent): + # Don't do anything here! We don't want to stop the stream. + pass + elif isinstance(event, BetaRawContentBlockStopEvent): + # If we're exiting a tool use block and there are still buffered messages, + # we should flush them now + if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer: for buffered_msg in self.tool_call_buffer: yield buffered_msg self.tool_call_buffer = [] - # Start detecting special case of "send_message" - if self.tool_call_name == DEFAULT_MESSAGE_TOOL and self.use_assistant_message: - previous_send_message = self.previous_parse.get(DEFAULT_MESSAGE_TOOL_KWARG, "") - current_send_message = current_parsed.get(DEFAULT_MESSAGE_TOOL_KWARG, "") - send_message_diff = current_send_message[len(previous_send_message) :] - - # Only stream out if it's not an empty string - if send_message_diff: - yield AssistantMessage( - id=self.letta_assistant_message_id, - content=[TextContent(text=send_message_diff)], - date=datetime.now(timezone.utc).isoformat(), - ) - else: - # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status - tool_call_msg = ToolCallMessage( - id=self.letta_tool_message_id, - tool_call=ToolCallDelta(arguments=delta.partial_json), - date=datetime.now(timezone.utc).isoformat(), - ) - - if self.inner_thoughts_complete: - yield tool_call_msg - else: - self.tool_call_buffer.append(tool_call_msg) - - # Set previous parse - self.previous_parse = current_parsed - elif isinstance(delta, BetaThinkingDelta): - # Safety check - if not self.anthropic_mode == EventMode.THINKING: - raise RuntimeError( - f"Streaming integrity failed - received BetaThinkingBlock object while not in THINKING EventMode: {delta}" - ) - - reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, - source="reasoner_model", - reasoning=delta.thinking, - date=datetime.now(timezone.utc).isoformat(), - ) - self.reasoning_messages.append(reasoning_message) - yield reasoning_message - elif isinstance(delta, BetaSignatureDelta): - # Safety check - if not self.anthropic_mode == EventMode.THINKING: - raise RuntimeError( - f"Streaming integrity failed - received BetaSignatureDelta object while not in THINKING EventMode: {delta}" - ) - - reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, - source="reasoner_model", - reasoning="", - date=datetime.now(timezone.utc).isoformat(), - signature=delta.signature, - ) - self.reasoning_messages.append(reasoning_message) - yield reasoning_message - elif isinstance(event, 
BetaRawMessageStartEvent): - self.message_id = event.message.id - self.input_tokens += event.message.usage.input_tokens - self.output_tokens += event.message.usage.output_tokens - elif isinstance(event, BetaRawMessageDeltaEvent): - self.output_tokens += event.usage.output_tokens - elif isinstance(event, BetaRawMessageStopEvent): - # Don't do anything here! We don't want to stop the stream. - pass - elif isinstance(event, BetaRawContentBlockStopEvent): - # If we're exiting a tool use block and there are still buffered messages, - # we should flush them now - if self.anthropic_mode == EventMode.TOOL_USE and self.tool_call_buffer: - for buffered_msg in self.tool_call_buffer: - yield buffered_msg - self.tool_call_buffer = [] - - self.anthropic_mode = None + self.anthropic_mode = None + except Exception as e: + logger.error("Error processing stream: %s", e) + raise + finally: + logger.info("AnthropicStreamingInterface: Stream processing complete.") def get_reasoning_content(self) -> List[Union[TextContent, ReasoningContent, RedactedReasoningContent]]: def _process_group( diff --git a/letta/interfaces/openai_chat_completions_streaming_interface.py b/letta/interfaces/openai_chat_completions_streaming_interface.py index 0f3bd841..6ff38cab 100644 --- a/letta/interfaces/openai_chat_completions_streaming_interface.py +++ b/letta/interfaces/openai_chat_completions_streaming_interface.py @@ -5,7 +5,7 @@ from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, from letta.constants import PRE_EXECUTION_MESSAGE_ARG from letta.interfaces.utils import _format_sse_chunk -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser class OpenAIChatCompletionsStreamingInterface: diff --git a/letta/server/rest_api/chat_completions_interface.py b/letta/server/rest_api/chat_completions_interface.py index 0f684ed7..9b05ca84 100644 --- a/letta/server/rest_api/chat_completions_interface.py +++ b/letta/server/rest_api/chat_completions_interface.py @@ -12,7 +12,7 @@ from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import LettaMessage from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser from letta.streaming_interface import AgentChunkStreamingInterface logger = get_logger(__name__) diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index edf8a233..9a89f907 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -28,7 +28,7 @@ from letta.schemas.letta_message import ( from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser from letta.streaming_interface import AgentChunkStreamingInterface from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor from letta.utils import parse_json @@ -291,7 +291,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.streaming_chat_completion_json_reader = 
FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg) # @matt's changes here, adopting new optimistic json parser - self.current_function_arguments = [] + self.current_function_arguments = "" self.optimistic_json_parser = OptimisticJSONParser() self.current_json_parse_result = {} @@ -387,7 +387,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def stream_start(self): """Initialize streaming by activating the generator and clearing any old chunks.""" self.streaming_chat_completion_mode_function_name = None - self.current_function_arguments = [] + self.current_function_arguments = "" self.current_json_parse_result = {} if not self._active: @@ -398,7 +398,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def stream_end(self): """Clean up the stream by deactivating and clearing chunks.""" self.streaming_chat_completion_mode_function_name = None - self.current_function_arguments = [] + self.current_function_arguments = "" self.current_json_parse_result = {} # if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode: @@ -609,14 +609,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # early exit to turn into content mode return None if tool_call.function.arguments: - self.current_function_arguments.append(tool_call.function.arguments) + self.current_function_arguments += tool_call.function.arguments # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name: # Strip out any extras tokens # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk - combined_args = "".join(self.current_function_arguments) - parsed_args = self.optimistic_json_parser.parse(combined_args) + parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments) if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( self.assistant_message_tool_kwarg @@ -686,7 +685,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # updates_inner_thoughts = "" # else: # OpenAI # updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) - self.current_function_arguments.append(tool_call.function.arguments) + self.current_function_arguments += tool_call.function.arguments updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) # If we have inner thoughts, we should output them as a chunk @@ -805,8 +804,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # TODO: THIS IS HORRIBLE # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE - combined_args = "".join(self.current_function_arguments) - parsed_args = self.optimistic_json_parser.parse(combined_args) + parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments) if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( self.assistant_message_tool_kwarg diff --git a/letta/server/rest_api/optimistic_json_parser.py b/letta/server/rest_api/json_parser.py similarity index 70% rename from letta/server/rest_api/optimistic_json_parser.py rename to letta/server/rest_api/json_parser.py index c3a2f069..27b4f0cf 100644 --- 
a/letta/server/rest_api/optimistic_json_parser.py
+++ b/letta/server/rest_api/json_parser.py
@@ -1,7 +1,43 @@
 import json
+from abc import ABC, abstractmethod
+from typing import Any
+
+from pydantic_core import from_json
+
+from letta.log import get_logger
+
+logger = get_logger(__name__)
 
 
-class OptimisticJSONParser:
+class JSONParser(ABC):
+    @abstractmethod
+    def parse(self, input_str: str) -> Any:
+        raise NotImplementedError()
+
+
+class PydanticJSONParser(JSONParser):
+    """
+    https://docs.pydantic.dev/latest/concepts/json/#json-parsing
+    If `strict` is True, we will not allow for partial parsing of JSON.
+
+    Compared with `OptimisticJSONParser`, this parser is more strict.
+    Note: This will not partially parse strings, which may decrease parsing speed for message strings
+    """
+
+    def __init__(self, strict=False):
+        self.strict = strict
+
+    def parse(self, input_str: str) -> Any:
+        if not input_str:
+            return {}
+        try:
+            return from_json(input_str, allow_partial="trailing-strings" if not self.strict else False)
+        except ValueError as e:
+            logger.error(f"Failed to parse JSON: {e}")
+            raise
+
+
+class OptimisticJSONParser(JSONParser):
     """
     A JSON parser that attempts to parse a given string using `json.loads`,
     and if that fails, it parses as much valid JSON as possible while
@@ -13,25 +49,25 @@
     def __init__(self, strict=False):
         self.strict = strict
         self.parsers = {
-            " ": self.parse_space,
-            "\r": self.parse_space,
-            "\n": self.parse_space,
-            "\t": self.parse_space,
-            "[": self.parse_array,
-            "{": self.parse_object,
-            '"': self.parse_string,
-            "t": self.parse_true,
-            "f": self.parse_false,
-            "n": self.parse_null,
+            " ": self._parse_space,
+            "\r": self._parse_space,
+            "\n": self._parse_space,
+            "\t": self._parse_space,
+            "[": self._parse_array,
+            "{": self._parse_object,
+            '"': self._parse_string,
+            "t": self._parse_true,
+            "f": self._parse_false,
+            "n": self._parse_null,
         }
         # Register number parser for digits and signs
         for char in "0123456789.-":
             self.parsers[char] = self.parse_number
 
         self.last_parse_reminding = None
-        self.on_extra_token = self.default_on_extra_token
+        self.on_extra_token = self._default_on_extra_token
 
-    def default_on_extra_token(self, text, data, reminding):
+    def _default_on_extra_token(self, text, data, reminding):
         print(f"Parsed JSON with extra tokens: {data}, remaining: {reminding}")
 
     def parse(self, input_str):
@@ -45,7 +81,7 @@
         try:
             return json.loads(input_str)
         except json.JSONDecodeError as decode_error:
-            data, reminding = self.parse_any(input_str, decode_error)
+            data, reminding = self._parse_any(input_str, decode_error)
             self.last_parse_reminding = reminding
             if self.on_extra_token and reminding:
                 self.on_extra_token(input_str, data, reminding)
@@ -53,7 +89,7 @@
         else:
             return json.loads("{}")
 
-    def parse_any(self, input_str, decode_error):
+    def _parse_any(self, input_str, decode_error):
         """Determine which parser to use based on the first character."""
         if not input_str:
             raise decode_error
@@ -62,11 +98,11 @@
             raise decode_error
         return parser(input_str, decode_error)
 
-    def parse_space(self, input_str, decode_error):
+    def _parse_space(self, input_str, decode_error):
         """Strip leading whitespace and parse again."""
-        return self.parse_any(input_str.strip(), decode_error)
+        return self._parse_any(input_str.strip(), decode_error)
 
-    def parse_array(self, input_str, decode_error):
+    def _parse_array(self, input_str, decode_error):
         """Parse a JSON array, 
returning the list and remaining string.""" # Skip the '[' input_str = input_str[1:] @@ -77,7 +113,7 @@ class OptimisticJSONParser: # Skip the ']' input_str = input_str[1:] break - value, input_str = self.parse_any(input_str, decode_error) + value, input_str = self._parse_any(input_str, decode_error) array_values.append(value) input_str = input_str.strip() if input_str.startswith(","): @@ -85,7 +121,7 @@ class OptimisticJSONParser: input_str = input_str[1:].strip() return array_values, input_str - def parse_object(self, input_str, decode_error): + def _parse_object(self, input_str, decode_error): """Parse a JSON object, returning the dict and remaining string.""" # Skip the '{' input_str = input_str[1:] @@ -96,7 +132,7 @@ class OptimisticJSONParser: # Skip the '}' input_str = input_str[1:] break - key, input_str = self.parse_any(input_str, decode_error) + key, input_str = self._parse_any(input_str, decode_error) input_str = input_str.strip() if not input_str or input_str[0] == "}": @@ -113,7 +149,7 @@ class OptimisticJSONParser: input_str = input_str[1:] break - value, input_str = self.parse_any(input_str, decode_error) + value, input_str = self._parse_any(input_str, decode_error) obj[key] = value input_str = input_str.strip() if input_str.startswith(","): @@ -121,7 +157,7 @@ class OptimisticJSONParser: input_str = input_str[1:].strip() return obj, input_str - def parse_string(self, input_str, decode_error): + def _parse_string(self, input_str, decode_error): """Parse a JSON string, respecting escaped quotes if present.""" end = input_str.find('"', 1) while end != -1 and input_str[end - 1] == "\\": @@ -166,19 +202,19 @@ class OptimisticJSONParser: return num, remainder - def parse_true(self, input_str, decode_error): + def _parse_true(self, input_str, decode_error): """Parse a 'true' value.""" if input_str.startswith(("t", "T")): return True, input_str[4:] raise decode_error - def parse_false(self, input_str, decode_error): + def _parse_false(self, input_str, decode_error): """Parse a 'false' value.""" if input_str.startswith(("f", "F")): return False, input_str[5:] raise decode_error - def parse_null(self, input_str, decode_error): + def _parse_null(self, input_str, decode_error): """Parse a 'null' value.""" if input_str.startswith("n"): return None, input_str[4:] diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index bd03348e..4dc7819a 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -680,7 +680,7 @@ async def send_message_streaming( server: SyncServer = Depends(get_letta_server), request: LettaStreamingRequest = Body(...), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present -): +) -> StreamingResponse | LettaResponse: """ Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. 
diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 40471eab..2e9b3e9a 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -16,6 +16,7 @@ from pydantic import BaseModel from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.message_helper import convert_message_creates_to_messages from letta.log import get_logger from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent @@ -143,27 +144,15 @@ def log_error_to_sentry(e): def create_input_messages(input_messages: List[MessageCreate], agent_id: str, actor: User) -> List[Message]: """ Converts a user input message into the internal structured format. - """ - new_messages = [] - for input_message in input_messages: - # Construct the Message object - new_message = Message( - id=f"message-{uuid.uuid4()}", - role=input_message.role, - content=input_message.content, - name=input_message.name, - otid=input_message.otid, - sender_id=input_message.sender_id, - organization_id=actor.organization_id, - agent_id=agent_id, - model=None, - tool_calls=None, - tool_call_id=None, - created_at=get_utc_time(), - ) - new_messages.append(new_message) - return new_messages + TODO (cliandy): this effectively duplicates the functionality of `convert_message_creates_to_messages`, + we should unify this when it's clear what message attributes we need. + """ + + messages = convert_message_creates_to_messages(input_messages, agent_id, wrap_user_message=False, wrap_system_message=False) + for message in messages: + message.organization_id = actor.organization_id + return messages def create_letta_messages_from_llm_response( diff --git a/tests/test_optimistic_json_parser.py b/tests/test_optimistic_json_parser.py index f7741f7c..08bb11c1 100644 --- a/tests/test_optimistic_json_parser.py +++ b/tests/test_optimistic_json_parser.py @@ -3,7 +3,7 @@ from unittest.mock import patch import pytest -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.json_parser import OptimisticJSONParser @pytest.fixture
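
Note on the parser swap above: the new PydanticJSONParser delegates partial parsing to pydantic_core.from_json, so its strict/non-strict split can be exercised in isolation. Below is a minimal sketch of the expected behavior on a truncated streaming fragment. It assumes a pydantic version whose pydantic_core.from_json accepts allow_partial="trailing-strings" (as the new parser requires); the fragment itself is a hypothetical example, not taken from the patch.

    from pydantic_core import from_json

    # A tool-call argument fragment cut off mid-string, as it would arrive mid-stream (hypothetical).
    fragment = '{"inner_thoughts": "The user wants'

    # Non-strict mode keeps the incomplete trailing string, mirroring
    # PydanticJSONParser(strict=False).parse(fragment).
    print(from_json(fragment, allow_partial="trailing-strings"))
    # -> {'inner_thoughts': 'The user wants'}

    # Strict mode rejects truncated input with a ValueError, mirroring
    # PydanticJSONParser(strict=True), which logs the error and re-raises.
    try:
        from_json(fragment, allow_partial=False)
    except ValueError as err:
        print(f"strict parse failed: {err}")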