From 904ccd65b6123f69db0bfb38edabe9dc0fc39c84 Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 4 Jun 2025 17:35:55 -0700 Subject: [PATCH] fix: remove separate tool call id in streaming path (#2641) --- letta/agents/letta_agent.py | 5 +-- letta/agents/letta_agent_batch.py | 1 - .../anthropic_streaming_interface.py | 35 +++++++++---------- letta/server/rest_api/utils.py | 3 -- tests/integration_test_pinecone_tool.py | 3 ++ 5 files changed, 21 insertions(+), 26 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 56d8121a..348ed669 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -515,8 +515,7 @@ class LettaAgent(BaseAgent): total_tokens=interface.input_tokens + interface.output_tokens, ), reasoning_content=reasoning_content, - pre_computed_assistant_message_id=interface.letta_assistant_message_id, - pre_computed_tool_message_id=interface.letta_tool_message_id, + pre_computed_assistant_message_id=interface.letta_message_id, step_id=step_id, agent_step_span=agent_step_span, ) @@ -811,7 +810,6 @@ class LettaAgent(BaseAgent): usage: UsageStatistics, reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, - pre_computed_tool_message_id: Optional[str] = None, step_id: str | None = None, new_in_context_messages: Optional[List[Message]] = None, agent_step_span: Optional["Span"] = None, @@ -927,7 +925,6 @@ class LettaAgent(BaseAgent): add_heartbeat_request_system_message=continue_stepping, reasoning_content=reasoning_content, pre_computed_assistant_message_id=pre_computed_assistant_message_id, - pre_computed_tool_message_id=pre_computed_tool_message_id, step_id=logged_step.id if logged_step else None, # TODO (cliandy): eventually move over other agent loops ) diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 6df672df..dd1c71d0 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -551,7 +551,6 @@ class LettaAgentBatch(BaseAgent): add_heartbeat_request_system_message=False, reasoning_content=reasoning_content, pre_computed_assistant_message_id=None, - pre_computed_tool_message_id=None, llm_batch_item_id=llm_batch_item_id, ) diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 48c57f48..8854d9a5 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -62,8 +62,7 @@ class AnthropicStreamingInterface: self.use_assistant_message = use_assistant_message # Premake IDs for database writes - self.letta_assistant_message_id = Message.generate_id() - self.letta_tool_message_id = Message.generate_id() + self.letta_message_id = Message.generate_id() self.anthropic_mode = None self.message_id = None @@ -152,7 +151,7 @@ class AnthropicStreamingInterface: if not self.use_assistant_message: # Buffer the initial tool call message instead of yielding immediately tool_call_msg = ToolCallMessage( - id=self.letta_tool_message_id, + id=self.letta_message_id, tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), date=datetime.now(timezone.utc).isoformat(), ) @@ -165,11 +164,11 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "hidden_reasoning_message": message_index += 1 hidden_reasoning_message = HiddenReasoningMessage( - id=self.letta_assistant_message_id, + id=self.letta_message_id, state="redacted", hidden_reasoning=content.data, date=datetime.now(timezone.utc).isoformat(), - otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) self.reasoning_messages.append(hidden_reasoning_message) prev_message_type = hidden_reasoning_message.message_type @@ -206,10 +205,10 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "reasoning_message": message_index += 1 reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, + id=self.letta_message_id, reasoning=self.accumulated_inner_thoughts[-1], date=datetime.now(timezone.utc).isoformat(), - otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) self.reasoning_messages.append(reasoning_message) prev_message_type = reasoning_message.message_type @@ -233,10 +232,10 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "reasoning_message": message_index += 1 reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, + id=self.letta_message_id, reasoning=inner_thoughts_diff, date=datetime.now(timezone.utc).isoformat(), - otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) self.reasoning_messages.append(reasoning_message) prev_message_type = reasoning_message.message_type @@ -250,7 +249,7 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 for buffered_msg in self.tool_call_buffer: - buffered_msg.otid = Message.generate_otid_from_id(self.letta_tool_message_id, message_index) + buffered_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index) prev_message_type = buffered_msg.message_type yield buffered_msg self.tool_call_buffer = [] @@ -266,17 +265,17 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "assistant_message": message_index += 1 assistant_msg = AssistantMessage( - id=self.letta_assistant_message_id, + id=self.letta_message_id, content=[TextContent(text=send_message_diff)], date=datetime.now(timezone.utc).isoformat(), - otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) prev_message_type = assistant_msg.message_type yield assistant_msg else: # Otherwise, it is a normal tool call - buffer or yield based on inner thoughts status tool_call_msg = ToolCallMessage( - id=self.letta_tool_message_id, + id=self.letta_message_id, tool_call=ToolCallDelta( name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json ), @@ -285,7 +284,7 @@ class AnthropicStreamingInterface: if self.inner_thoughts_complete: if prev_message_type and prev_message_type != "tool_call_message": message_index += 1 - tool_call_msg.otid = Message.generate_otid_from_id(self.letta_tool_message_id, message_index) + tool_call_msg.otid = Message.generate_otid_from_id(self.letta_message_id, message_index) prev_message_type = tool_call_msg.message_type yield tool_call_msg else: @@ -303,11 +302,11 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "reasoning_message": message_index += 1 reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, + id=self.letta_message_id, source="reasoner_model", reasoning=delta.thinking, date=datetime.now(timezone.utc).isoformat(), - otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) self.reasoning_messages.append(reasoning_message) prev_message_type = reasoning_message.message_type @@ -322,12 +321,12 @@ class AnthropicStreamingInterface: if prev_message_type and prev_message_type != "reasoning_message": message_index += 1 reasoning_message = ReasoningMessage( - id=self.letta_assistant_message_id, + id=self.letta_message_id, source="reasoner_model", reasoning="", date=datetime.now(timezone.utc).isoformat(), signature=delta.signature, - otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) self.reasoning_messages.append(reasoning_message) prev_message_type = reasoning_message.message_type diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index d12b100f..cd00af35 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -190,7 +190,6 @@ def create_letta_messages_from_llm_response( add_heartbeat_request_system_message: bool = False, reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, pre_computed_assistant_message_id: Optional[str] = None, - pre_computed_tool_message_id: Optional[str] = None, llm_batch_item_id: Optional[str] = None, step_id: str | None = None, ) -> List[Message]: @@ -245,8 +244,6 @@ def create_letta_messages_from_llm_response( ) ], ) - if pre_computed_tool_message_id: - tool_message.id = pre_computed_tool_message_id messages.append(tool_message) if add_heartbeat_request_system_message: diff --git a/tests/integration_test_pinecone_tool.py b/tests/integration_test_pinecone_tool.py index 9ce7589d..e8f84d4a 100644 --- a/tests/integration_test_pinecone_tool.py +++ b/tests/integration_test_pinecone_tool.py @@ -182,6 +182,9 @@ async def test_pinecone_tool(client: AsyncLetta) -> None: stream_message = response_messages_from_stream[idx] db_message = response_messages_from_db[idx] assert stream_message.message_type == db_message.message_type + print("message type:", stream_message.message_type) + print("stream message:", stream_message.model_dump_json(indent=4)) + print("db message:", db_message.model_dump_json(indent=4)) assert stream_message.id == db_message.id assert stream_message.otid == db_message.otid