From f235dfb356c696751ac6aa4a856d86bca0a117b8 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 30 Sep 2025 16:23:03 -0700 Subject: [PATCH] feat: add tool call test for new agent loop (#5034) --- .../anthropic_streaming_interface.py | 18 ++- tests/integration_test_send_message_v2.py | 149 ++++++++++++++++++ 2 files changed, 164 insertions(+), 3 deletions(-) diff --git a/letta/interfaces/anthropic_streaming_interface.py b/letta/interfaces/anthropic_streaming_interface.py index 1ddf7a02..7352aea7 100644 --- a/letta/interfaces/anthropic_streaming_interface.py +++ b/letta/interfaces/anthropic_streaming_interface.py @@ -272,11 +272,14 @@ class AnthropicStreamingInterface: if not self.use_assistant_message: # Only buffer the initial tool call message if it doesn't require approval # For approval-required tools, we'll create the ApprovalRequestMessage later + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 if self.tool_call_name not in self.requires_approval_tools: tool_call_msg = ToolCallMessage( id=self.letta_message_id, tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), date=datetime.now(timezone.utc).isoformat(), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) self.tool_call_buffer.append(tool_call_msg) elif isinstance(content, BetaThinkingBlock): @@ -737,20 +740,23 @@ class SimpleAnthropicStreamingInterface: self.tool_call_id = content.id self.tool_call_name = content.name - if prev_message_type and prev_message_type != "tool_call_message": - message_index += 1 - if self.tool_call_name in self.requires_approval_tools: + if prev_message_type and prev_message_type != "approval_request_message": + message_index += 1 tool_call_msg = ApprovalRequestMessage( id=self.letta_message_id, tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), date=datetime.now(timezone.utc).isoformat(), + otid=Message.generate_otid_from_id(self.letta_message_id, 
message_index), ) else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 tool_call_msg = ToolCallMessage( id=self.letta_message_id, tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id), date=datetime.now(timezone.utc).isoformat(), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) prev_message_type = tool_call_msg.message_type yield tool_call_msg @@ -809,16 +815,22 @@ class SimpleAnthropicStreamingInterface: self.accumulated_tool_call_args += delta.partial_json if self.tool_call_name in self.requires_approval_tools: + if prev_message_type and prev_message_type != "approval_request_message": + message_index += 1 tool_call_msg = ApprovalRequestMessage( id=self.letta_message_id, tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json), date=datetime.now(timezone.utc).isoformat(), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 tool_call_msg = ToolCallMessage( id=self.letta_message_id, tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json), date=datetime.now(timezone.utc).isoformat(), + otid=Message.generate_otid_from_id(self.letta_message_id, message_index), ) yield tool_call_msg diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index 96590c62..dbabbd0f 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -84,6 +84,13 @@ USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [ otid=USER_MESSAGE_OTID, ) ] +USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ + MessageCreate( + role="user", + content="This is an automated test message. 
def assert_tool_call_response(
    messages: List[Any],
    llm_config: LLMConfig,
    streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence for a single tool-calling turn:
    [UserMessage, only when from_db] ->
    [ReasoningMessage or HiddenReasoningMessage, only for reasoning models] ->
    AssistantMessage -> ToolCallMessage -> ToolReturnMessage -> AssistantMessage ->
    [LettaStopReason -> LettaUsageStatistics, only when streaming].

    Ping / keep-alive messages are dropped before any check is made.

    Args:
        messages: Messages returned by the API call, accumulated from a stream, or listed from the DB.
        llm_config: Model config; decides whether a leading reasoning message is expected and of which type.
        streaming: When True, a trailing stop reason and usage statistics message are expected.
        from_db: When True, a leading UserMessage carrying USER_MESSAGE_OTID is expected.

    Raises:
        AssertionError: If the message count, a message type, or an otid suffix does not match.
    """
    # Filter out LettaPing messages which are keep-alive messages for SSE streams
    messages = [
        msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
    ]

    is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
    expected_message_count = 6 if streaming else 5 if from_db else 4
    expected_message_count += 1 if is_reasoner_model else 0
    assert len(messages) == expected_message_count, (
        f"expected {expected_message_count} messages, got {len(messages)}: {[type(msg).__name__ for msg in messages]}"
    )

    # User message if loaded from db
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1

    # The last character of each otid is compared against a counter that grows across the
    # reasoning / assistant / tool-call messages that precede the tool return.
    otid_suffix = 0

    # Reasoning message if reasoning enabled
    if is_reasoner_model:
        if LLMConfig.is_openai_reasoning_model(llm_config):
            assert isinstance(messages[index], HiddenReasoningMessage)
        else:
            assert isinstance(messages[index], ReasoningMessage)
        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
        index += 1
        otid_suffix += 1

    # Assistant message emitted before the tool call
    assert isinstance(messages[index], AssistantMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1
    otid_suffix += 1

    # Tool call message
    assert isinstance(messages[index], ToolCallMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1

    # Tool return message: the expected otid suffix restarts at 0 here
    otid_suffix = 0
    assert isinstance(messages[index], ToolReturnMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1

    # NOTE(review): no reasoning message is checked between the tool return and the final
    # assistant message, and the expected count above budgets for only one reasoning message.
    # Confirm whether reasoning models can emit a second one in this position.

    # Final assistant message: the expected otid suffix is 0 again
    otid_suffix = 0
    assert isinstance(messages[index], AssistantMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1

    # Stop reason and usage statistics if streaming
    if streaming:
        assert isinstance(messages[index], LettaStopReason)
        assert messages[index].stop_reason == "end_turn"
        index += 1
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_tool_call(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
    send_type: str,
) -> None:
    """
    Sends the roll_dice prompt through each send path and validates the resulting messages.

    The same message-shape assertion is applied to the direct response, to the replayed run
    stream (background variant only), and to the messages listed back from the agent afterwards.
    """
    # Grab one existing message first; its id is the `after` cursor for the final listing.
    last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    is_streaming = "stream" in send_type
    is_background = "background" in send_type

    if send_type == "async":
        run = await client.agents.messages.create_async(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
        )
        run = await wait_for_run_completion(client, run.id)
        run_messages = await client.runs.messages.list(run_id=run.id)
        # The run listing includes the user message; drop it before asserting.
        messages = [m for m in run_messages if m.message_type != "user_message"]
        run_id = run.id
    elif send_type == "step":
        response = await client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
        )
        messages = response.messages
        run_id = messages[0].run_id
    else:
        # NOTE(review): stream_tokens is False for "stream_tokens_background" because of the
        # exact-match comparison - confirm that this is intended.
        chunk_stream = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
            stream_tokens=send_type == "stream_tokens",
            background=send_type == "stream_tokens_background",
        )
        messages = await accumulate_chunks(chunk_stream)
        run_id = messages[0].run_id

    assert_tool_call_response(messages, streaming=is_streaming, llm_config=llm_config)

    if is_background:
        # Re-read the run's stream from the start and expect the same message shape.
        replayed_stream = client.runs.stream(run_id=run_id, starting_after=0)
        messages = await accumulate_chunks(replayed_stream)
        assert_tool_call_response(messages, streaming=is_streaming, llm_config=llm_config)

    messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)

    # NOTE(review): the run-status check (retrieve the run by run_id and assert it is
    # JobStatus.completed) is not performed in this test - confirm whether it should be.