feat: add new agent loop tests to ci (#5049)

This commit is contained in:
cthomas
2025-09-30 20:18:07 -07:00
committed by Caren Thomas
parent 7565fc4a00
commit ad42c886b7

View File

@@ -46,8 +46,11 @@ logger = get_logger(__name__)
all_configs = [
"openai-gpt-4o-mini.json",
"openai-o3.json",
"claude-3-5-sonnet.json",
"claude-3-7-sonnet-extended.json",
"gemini-2.5-flash.json",
]
@@ -109,9 +112,8 @@ def assert_greeting_response(
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
]
is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
expected_message_count = 3 if streaming else 2 if from_db else 1
assert len(messages) == expected_message_count + (1 if is_reasoner_model else 0)
expected_message_count = get_expected_message_count(llm_config, streaming=streaming, from_db=from_db)
assert len(messages) == expected_message_count
# User message if loaded from db
index = 0
@@ -122,7 +124,7 @@ def assert_greeting_response(
# Reasoning message if reasoning enabled
otid_suffix = 0
if is_reasoner_model:
if LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config):
if LLMConfig.is_openai_reasoning_model(llm_config):
assert isinstance(messages[index], HiddenReasoningMessage)
else:
@@ -168,9 +170,8 @@ def assert_tool_call_response(
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
]
is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
expected_message_count = 6 if streaming else 5 if from_db else 4
assert len(messages) == expected_message_count + (1 if is_reasoner_model else 0)
expected_message_count = get_expected_message_count(llm_config, tool_call=True, streaming=streaming, from_db=from_db)
assert len(messages) == expected_message_count
# User message if loaded from db
index = 0
@@ -181,7 +182,7 @@ def assert_tool_call_response(
# Reasoning message if reasoning enabled
otid_suffix = 0
if is_reasoner_model:
if LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config):
if LLMConfig.is_openai_reasoning_model(llm_config):
assert isinstance(messages[index], HiddenReasoningMessage)
else:
@@ -191,10 +192,11 @@ def assert_tool_call_response(
otid_suffix += 1
# Assistant message
assert isinstance(messages[index], AssistantMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
if llm_config.model_endpoint_type == "anthropic":
assert isinstance(messages[index], AssistantMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
# Tool call message
assert isinstance(messages[index], ToolCallMessage)
@@ -207,16 +209,13 @@ def assert_tool_call_response(
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
# Reasoning message if reasoning enabled
# Reasoning message if reasoning enabled for openai models
otid_suffix = 0
# if is_reasoner_model:
# if LLMConfig.is_openai_reasoning_model(llm_config):
# assert isinstance(messages[index], HiddenReasoningMessage)
# else:
# assert isinstance(messages[index], ReasoningMessage)
# assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
# index += 1
# otid_suffix += 1
if LLMConfig.is_openai_reasoning_model(llm_config):
assert isinstance(messages[index], HiddenReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
# Assistant message
assert isinstance(messages[index], AssistantMessage)
@@ -283,6 +282,59 @@ async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: floa
time.sleep(interval)
def get_expected_message_count(llm_config: LLMConfig, tool_call: bool = False, streaming: bool = False, from_db: bool = False) -> int:
"""
Returns the expected number of messages for a given LLM configuration.
Greeting:
---------------------------------------------------------------------------------------------------------------------------------------
| gpt-4o | gpt-o3 | sonnet-3-5 | sonnet-3.7-thinking | flash-2.5-thinking |
| ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
| AssistantMessage | HiddenReasoningMessage | AssistantMessage | ReasoningMessage | AssistantMessage |
| | AssistantMessage | | AssistantMessage | |
Tool Call:
---------------------------------------------------------------------------------------------------------------------------------------
| gpt-4o | gpt-o3 | sonnet-3-5 | sonnet-3.7-thinking | flash-2.5-thinking |
| ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
| ToolCallMessage | HiddenReasoningMessage | AssistantMessage | ReasoningMessage | ToolCallMessage |
| ToolReturnMessage | ToolCallMessage | ToolCallMessage | AssistantMessage | ToolReturnMessage |
| AssistantMessage | ToolReturnMessage | ToolReturnMessage | ToolCallMessage | AssistantMessage |
| | HiddenReasoningMessage | AssistantMessage | ToolReturnMessage | |
| | AssistantMessage | | AssistantMessage | |
"""
is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
# assistant message
expected_message_count = 1
if is_reasoner_model:
# reasoning message
expected_message_count += 1
if tool_call:
# tool call and tool return messages
expected_message_count += 2
if llm_config.model_endpoint_type == "anthropic":
# anthropic models return an assistant message first before the tool call message
expected_message_count += 1
if LLMConfig.is_openai_reasoning_model(llm_config):
# openai reasoning models return an additional reasoning message before final assistant message
expected_message_count += 1
if from_db:
# user message
expected_message_count += 1
if streaming:
# stop reason and usage statistics
expected_message_count += 2
return expected_message_count
# ------------------------------
# Fixtures
# ------------------------------
@@ -476,6 +528,6 @@ async def test_tool_call(
messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
# assert run_id is not None
# run = await client.runs.retrieve(run_id=run_id)
# assert run.status == JobStatus.completed
assert run_id is not None
run = await client.runs.retrieve(run_id=run_id)
assert run.status == JobStatus.completed