feat: add new agent loop tests to ci (#5049)
@@ -46,8 +46,11 @@ logger = get_logger(__name__)
 
 all_configs = [
     "openai-gpt-4o-mini.json",
+    "openai-o3.json",
     "claude-3-5-sonnet.json",
+    "claude-3-7-sonnet-extended.json",
+    "gemini-2.5-flash.json",
 ]
 
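Below is a minimal sketch, not taken from this commit, of how a list like all_configs is commonly expanded into per-model test cases with a parametrized pytest fixture; the fixture name and loading logic are assumptions, and the real harness in this file may resolve the JSON configs differently.

import pytest

# Hypothetical fixture: runs every test that requests it once per entry in all_configs.
@pytest.fixture(params=all_configs)
def llm_config_file(request) -> str:
    # request.param is one of the JSON config file names listed above
    return request.param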
@@ -109,9 +112,8 @@ def assert_greeting_response(
         msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
     ]
 
-    is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
-    expected_message_count = 3 if streaming else 2 if from_db else 1
-    assert len(messages) == expected_message_count + (1 if is_reasoner_model else 0)
+    expected_message_count = get_expected_message_count(llm_config, streaming=streaming, from_db=from_db)
+    assert len(messages) == expected_message_count
 
     # User message if loaded from db
     index = 0
@@ -122,7 +124,7 @@ def assert_greeting_response(
 
     # Reasoning message if reasoning enabled
     otid_suffix = 0
-    if is_reasoner_model:
+    if LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config):
         if LLMConfig.is_openai_reasoning_model(llm_config):
             assert isinstance(messages[index], HiddenReasoningMessage)
         else:
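For reference, these greeting-path assertions walk the per-model sequences below, taken from the Greeting table in the helper added further down (the user message loaded from the db and the streaming stop-reason/usage entries are omitted); the message classes are the ones this test module already imports, and the dict keys are informal model labels.

GREETING_SEQUENCES = {
    "gpt-4o": [AssistantMessage],
    "gpt-o3": [HiddenReasoningMessage, AssistantMessage],
    "sonnet-3-5": [AssistantMessage],
    "sonnet-3.7-thinking": [ReasoningMessage, AssistantMessage],
    "flash-2.5-thinking": [AssistantMessage],
}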
@@ -168,9 +170,8 @@ def assert_tool_call_response(
         msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
     ]
 
-    is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
-    expected_message_count = 6 if streaming else 5 if from_db else 4
-    assert len(messages) == expected_message_count + (1 if is_reasoner_model else 0)
+    expected_message_count = get_expected_message_count(llm_config, tool_call=True, streaming=streaming, from_db=from_db)
+    assert len(messages) == expected_message_count
 
     # User message if loaded from db
     index = 0
@@ -181,7 +182,7 @@ def assert_tool_call_response(
 
     # Reasoning message if reasoning enabled
     otid_suffix = 0
-    if is_reasoner_model:
+    if LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config):
         if LLMConfig.is_openai_reasoning_model(llm_config):
             assert isinstance(messages[index], HiddenReasoningMessage)
         else:
@@ -191,10 +192,11 @@ def assert_tool_call_response(
         otid_suffix += 1
 
     # Assistant message
-    assert isinstance(messages[index], AssistantMessage)
-    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
-    index += 1
-    otid_suffix += 1
+    if llm_config.model_endpoint_type == "anthropic":
+        assert isinstance(messages[index], AssistantMessage)
+        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
+        index += 1
+        otid_suffix += 1
 
     # Tool call message
     assert isinstance(messages[index], ToolCallMessage)
@@ -207,16 +209,13 @@ def assert_tool_call_response(
     assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
     index += 1
 
-    # Reasoning message if reasoning enabled
+    # Reasoning message if reasoning enabled for openai models
     otid_suffix = 0
-    # if is_reasoner_model:
-    #     if LLMConfig.is_openai_reasoning_model(llm_config):
-    #         assert isinstance(messages[index], HiddenReasoningMessage)
-    #     else:
-    #         assert isinstance(messages[index], ReasoningMessage)
-    #     assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
-    #     index += 1
-    #     otid_suffix += 1
+    if LLMConfig.is_openai_reasoning_model(llm_config):
+        assert isinstance(messages[index], HiddenReasoningMessage)
+        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
+        index += 1
+        otid_suffix += 1
 
     # Assistant message
     assert isinstance(messages[index], AssistantMessage)
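The tool-call assertions above step through the sequences below, matching the Tool Call table in the helper added further down (db and streaming extras again omitted); only the two reasoning models are shown since they exercise the anthropic and openai branches, and the message classes are this module's existing imports.

TOOL_CALL_SEQUENCES = {
    "gpt-o3": [
        HiddenReasoningMessage,  # leading reasoning
        ToolCallMessage,
        ToolReturnMessage,
        HiddenReasoningMessage,  # reasoning again before the final reply
        AssistantMessage,
    ],
    "sonnet-3.7-thinking": [
        ReasoningMessage,        # extended thinking
        AssistantMessage,        # assistant message emitted before the tool call
        ToolCallMessage,
        ToolReturnMessage,
        AssistantMessage,
    ],
}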
@@ -283,6 +282,59 @@ async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: floa
         time.sleep(interval)
 
 
+def get_expected_message_count(llm_config: LLMConfig, tool_call: bool = False, streaming: bool = False, from_db: bool = False) -> int:
+    """
+    Returns the expected number of messages for a given LLM configuration.
+
+    Greeting:
+    ---------------------------------------------------------------------------------------------------------------------------------------
+    | gpt-4o                   | gpt-o3                   | sonnet-3-5               | sonnet-3.7-thinking      | flash-2.5-thinking       |
+    | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
+    | AssistantMessage         | HiddenReasoningMessage   | AssistantMessage         | ReasoningMessage         | AssistantMessage         |
+    |                          | AssistantMessage         |                          | AssistantMessage         |                          |
+
+
+    Tool Call:
+    ---------------------------------------------------------------------------------------------------------------------------------------
+    | gpt-4o                   | gpt-o3                   | sonnet-3-5               | sonnet-3.7-thinking      | flash-2.5-thinking       |
+    | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
+    | ToolCallMessage          | HiddenReasoningMessage   | AssistantMessage         | ReasoningMessage         | ToolCallMessage          |
+    | ToolReturnMessage        | ToolCallMessage          | ToolCallMessage          | AssistantMessage         | ToolReturnMessage        |
+    | AssistantMessage         | ToolReturnMessage        | ToolReturnMessage        | ToolCallMessage          | AssistantMessage         |
+    |                          | HiddenReasoningMessage   | AssistantMessage         | ToolReturnMessage        |                          |
+    |                          | AssistantMessage         |                          | AssistantMessage         |                          |
+
+    """
+    is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
+
+    # assistant message
+    expected_message_count = 1
+
+    if is_reasoner_model:
+        # reasoning message
+        expected_message_count += 1
+
+    if tool_call:
+        # tool call and tool return messages
+        expected_message_count += 2
+        if llm_config.model_endpoint_type == "anthropic":
+            # anthropic models return an assistant message first before the tool call message
+            expected_message_count += 1
+        if LLMConfig.is_openai_reasoning_model(llm_config):
+            # openai reasoning models return an additional reasoning message before final assistant message
+            expected_message_count += 1
+
+    if from_db:
+        # user message
+        expected_message_count += 1
+
+    if streaming:
+        # stop reason and usage statistics
+        expected_message_count += 2
+
+    return expected_message_count
+
+
 # ------------------------------
 # Fixtures
 # ------------------------------
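As a sanity check on the counting rules in get_expected_message_count, here is a dependency-free restatement of the same arithmetic with the LLMConfig checks replaced by plain booleans; the function and flag names are illustrative, not part of the commit.

def expected_count(
    *,
    openai_reasoner: bool = False,
    anthropic_reasoner: bool = False,
    anthropic_endpoint: bool = False,
    tool_call: bool = False,
    streaming: bool = False,
    from_db: bool = False,
) -> int:
    count = 1  # final assistant message
    if openai_reasoner or anthropic_reasoner:
        count += 1  # leading (hidden) reasoning message
    if tool_call:
        count += 2  # tool call + tool return
        if anthropic_endpoint:
            count += 1  # assistant message emitted before the tool call
        if openai_reasoner:
            count += 1  # extra reasoning message before the final assistant message
    if from_db:
        count += 1  # the user message is read back as well
    if streaming:
        count += 2  # stop reason + usage statistics
    return count

# Matches the Tool Call table above: gpt-o3 -> 5, sonnet-3.7-thinking -> 5, gpt-4o -> 3.
assert expected_count(openai_reasoner=True, tool_call=True) == 5
assert expected_count(anthropic_reasoner=True, anthropic_endpoint=True, tool_call=True) == 5
assert expected_count(tool_call=True) == 3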
@@ -476,6 +528,6 @@ async def test_tool_call(
     messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
     assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
 
-    # assert run_id is not None
-    # run = await client.runs.retrieve(run_id=run_id)
-    # assert run.status == JobStatus.completed
+    assert run_id is not None
+    run = await client.runs.retrieve(run_id=run_id)
+    assert run.status == JobStatus.completed
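The newly enabled assertions assume the run has already finished; below is a minimal sketch of the kind of polling a helper like wait_for_run_completion performs, assuming AsyncLetta and JobStatus from this module's existing imports (the real helper may differ).

import asyncio
import time

async def wait_until_completed(client: "AsyncLetta", run_id: str, timeout: float = 60.0, interval: float = 1.0):
    # Poll the run until it reports completed, or give up after `timeout` seconds.
    deadline = time.monotonic() + timeout
    while time.monotonic() < deadline:
        run = await client.runs.retrieve(run_id=run_id)
        if run.status == JobStatus.completed:
            return run
        await asyncio.sleep(interval)
    raise TimeoutError(f"run {run_id} did not reach completed within {timeout}s")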