feat: add tool call test for new agent loop (#5034)

This commit is contained in:
cthomas
2025-09-30 16:23:03 -07:00
committed by Caren Thomas
parent d3c5d0c330
commit f235dfb356
2 changed files with 164 additions and 3 deletions

View File

@@ -272,11 +272,14 @@ class AnthropicStreamingInterface:
if not self.use_assistant_message:
# Only buffer the initial tool call message if it doesn't require approval
# For approval-required tools, we'll create the ApprovalRequestMessage later
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
if self.tool_call_name not in self.requires_approval_tools:
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
self.tool_call_buffer.append(tool_call_msg)
elif isinstance(content, BetaThinkingBlock):
@@ -737,20 +740,23 @@ class SimpleAnthropicStreamingInterface:
self.tool_call_id = content.id
self.tool_call_name = content.name
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
if self.tool_call_name in self.requires_approval_tools:
if prev_message_type and prev_message_type != "approval_request_message":
message_index += 1
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
prev_message_type = tool_call_msg.message_type
yield tool_call_msg
@@ -809,16 +815,22 @@ class SimpleAnthropicStreamingInterface:
self.accumulated_tool_call_args += delta.partial_json
if self.tool_call_name in self.requires_approval_tools:
if prev_message_type and prev_message_type != "approval_request_message":
message_index += 1
tool_call_msg = ApprovalRequestMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
else:
if prev_message_type and prev_message_type != "tool_call_message":
message_index += 1
tool_call_msg = ToolCallMessage(
id=self.letta_message_id,
tool_call=ToolCallDelta(name=self.tool_call_name, tool_call_id=self.tool_call_id, arguments=delta.partial_json),
date=datetime.now(timezone.utc).isoformat(),
otid=Message.generate_otid_from_id(self.letta_message_id, message_index),
)
yield tool_call_msg

View File

@@ -84,6 +84,13 @@ USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
otid=USER_MESSAGE_OTID,
)
]
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
MessageCreate(
role="user",
content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
otid=USER_MESSAGE_OTID,
)
]
def assert_greeting_response(
@@ -145,6 +152,90 @@ def assert_greeting_response(
assert messages[index].step_count > 0
def assert_tool_call_response(
messages: List[Any],
llm_config: LLMConfig,
streaming: bool = False,
from_db: bool = False,
) -> None:
"""
Asserts that the messages list follows the expected sequence:
ReasoningMessage -> ToolCallMessage -> ToolReturnMessage ->
ReasoningMessage -> AssistantMessage.
"""
# Filter out LettaPing messages which are keep-alive messages for SSE streams
messages = [
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
]
is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
expected_message_count = 6 if streaming else 5 if from_db else 4
assert len(messages) == expected_message_count + (1 if is_reasoner_model else 0)
# User message if loaded from db
index = 0
if from_db:
assert isinstance(messages[index], UserMessage)
assert messages[index].otid == USER_MESSAGE_OTID
index += 1
# Reasoning message if reasoning enabled
otid_suffix = 0
if is_reasoner_model:
if LLMConfig.is_openai_reasoning_model(llm_config):
assert isinstance(messages[index], HiddenReasoningMessage)
else:
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
# Assistant message
assert isinstance(messages[index], AssistantMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
# Tool call message
assert isinstance(messages[index], ToolCallMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
# Tool return message
otid_suffix = 0
assert isinstance(messages[index], ToolReturnMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
# Reasoning message if reasoning enabled
otid_suffix = 0
# if is_reasoner_model:
# if LLMConfig.is_openai_reasoning_model(llm_config):
# assert isinstance(messages[index], HiddenReasoningMessage)
# else:
# assert isinstance(messages[index], ReasoningMessage)
# assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
# index += 1
# otid_suffix += 1
# Assistant message
assert isinstance(messages[index], AssistantMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
# Stop reason and usage statistics if streaming
if streaming:
assert isinstance(messages[index], LettaStopReason)
assert messages[index].stop_reason == "end_turn"
index += 1
assert isinstance(messages[index], LettaUsageStatistics)
assert messages[index].prompt_tokens > 0
assert messages[index].completion_tokens > 0
assert messages[index].total_tokens > 0
assert messages[index].step_count > 0
async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]:
"""
Accumulates chunks into a list of messages.
@@ -330,3 +421,61 @@ async def test_greeting(
assert run_id is not None
run = await client.runs.retrieve(run_id=run_id)
assert run.status == JobStatus.completed
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_tool_call(
disable_e2b_api_key: Any,
client: AsyncLetta,
agent_state: AgentState,
llm_config: LLMConfig,
send_type: str,
) -> None:
last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)
if send_type == "step":
response = await client.agents.messages.create(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
)
messages = response.messages
run_id = messages[0].run_id
elif send_type == "async":
run = await client.agents.messages.create_async(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
)
run = await wait_for_run_completion(client, run.id)
messages = await client.runs.messages.list(run_id=run.id)
messages = [m for m in messages if m.message_type != "user_message"]
run_id = run.id
else:
response = client.agents.messages.create_stream(
agent_id=agent_state.id,
messages=USER_MESSAGE_ROLL_DICE,
stream_tokens=(send_type == "stream_tokens"),
background=(send_type == "stream_tokens_background"),
)
messages = await accumulate_chunks(response)
run_id = messages[0].run_id
assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)
if "background" in send_type:
response = client.runs.stream(run_id=run_id, starting_after=0)
messages = await accumulate_chunks(response)
assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
# assert run_id is not None
# run = await client.runs.retrieve(run_id=run_id)
# assert run.status == JobStatus.completed