feat: handle flaky reasoning in v2 tests (#5133)

This commit is contained in:
cthomas
2025-10-03 15:52:00 -07:00
committed by Caren Thomas
parent 6016ac0f33
commit 89321ff29a
3 changed files with 89 additions and 100 deletions

View File

@@ -6,7 +6,7 @@ import time
import uuid
from contextlib import contextmanager
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List
from typing import Any, Dict, List, Tuple
from unittest.mock import patch
import httpx
@@ -115,19 +115,10 @@ def assert_greeting_response(
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
]
expected_message_count = get_expected_message_count(llm_config, streaming=streaming, from_db=from_db)
try:
assert len(messages) == expected_message_count
except:
# Reasoning summary in responses API when effort is high is still flaky, so don't throw if missing
if (
LLMConfig.is_openai_reasoning_model(llm_config)
or LLMConfig.is_google_vertex_reasoning_model(llm_config)
or LLMConfig.is_google_ai_reasoning_model(llm_config)
):
assert len(messages) == expected_message_count - 1
else:
raise
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
llm_config, streaming=streaming, from_db=from_db
)
assert expected_message_count_min <= len(messages) <= expected_message_count_max
# User message if loaded from db
index = 0
@@ -139,26 +130,14 @@ def assert_greeting_response(
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if (
(LLMConfig.is_openai_reasoning_model(llm_config) and llm_config.reasoning_effort == "high")
or LLMConfig.is_anthropic_reasoning_model(llm_config)
or LLMConfig.is_google_vertex_reasoning_model(llm_config)
or LLMConfig.is_google_ai_reasoning_model(llm_config)
):
if is_reasoner_model(llm_config):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
except:
# Reasoning summary in responses API when effort is high is still flaky, so don't throw if missing
if (
LLMConfig.is_openai_reasoning_model(llm_config)
or LLMConfig.is_google_vertex_reasoning_model(llm_config)
or LLMConfig.is_google_ai_reasoning_model(llm_config)
):
pass
else:
raise
# Reasoning is non-deterministic, so don't throw if missing
pass
# Assistant message
assert isinstance(messages[index], AssistantMessage)
@@ -196,15 +175,10 @@ def assert_tool_call_response(
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
]
expected_message_count = get_expected_message_count(llm_config, tool_call=True, streaming=streaming, from_db=from_db)
try:
assert len(messages) == expected_message_count
except:
# Reasoning summary in responses API when effort is high is still flaky, so don't throw if missing
if LLMConfig.is_openai_reasoning_model(llm_config):
assert len(messages) == expected_message_count - 1
else:
raise
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
llm_config, tool_call=True, streaming=streaming, from_db=from_db
)
assert expected_message_count_min <= len(messages) <= expected_message_count_max
# User message if loaded from db
index = 0
@@ -216,19 +190,14 @@ def assert_tool_call_response(
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if (
LLMConfig.is_openai_reasoning_model(llm_config) and llm_config.reasoning_effort == "high"
) or LLMConfig.is_anthropic_reasoning_model(llm_config):
if is_reasoner_model(llm_config):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
except:
# Reasoning summary in responses API when effort is high is still flaky, so don't throw if missing
if LLMConfig.is_openai_reasoning_model(llm_config):
pass
else:
raise
# Reasoning is non-deterministic, so don't throw if missing
pass
# Assistant message
if llm_config.model_endpoint_type == "anthropic":
@@ -248,6 +217,18 @@ def assert_tool_call_response(
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
if is_reasoner_model(llm_config):
assert isinstance(messages[index], ReasoningMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
otid_suffix += 1
except:
# Reasoning is non-deterministic, so don't throw if missing
pass
# Assistant message
assert isinstance(messages[index], AssistantMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
@@ -312,42 +293,41 @@ async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: floa
time.sleep(interval)
def get_expected_message_count(llm_config: LLMConfig, tool_call: bool = False, streaming: bool = False, from_db: bool = False) -> int:
def get_expected_message_count_range(
llm_config: LLMConfig, tool_call: bool = False, streaming: bool = False, from_db: bool = False
) -> Tuple[int, int]:
"""
Returns the expected number of messages for a given LLM configuration.
    Returns the expected range of the number of messages for a given LLM configuration. A range is used to account for possible variation in the number of reasoning messages.
Greeting:
------------------------------------------------------------------------------------------------------------------------------------------------------------------
| gpt-4o | gpt-o3 (med effort) | gpt-5 (high effort) | sonnet-3-5 | sonnet-3.7-thinking | flash-2.5-thinking |
| ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
| AssistantMessage | AssistantMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | AssistantMessage |
| | | AssistantMessage | | AssistantMessage | |
| AssistantMessage | AssistantMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | ReasoningMessage |
| | | AssistantMessage | | AssistantMessage | AssistantMessage |
Tool Call:
------------------------------------------------------------------------------------------------------------------------------------------------------------------
| gpt-4o | gpt-o3 (med effort) | gpt-5 (high effort) | sonnet-3-5 | sonnet-3.7-thinking | flash-2.5-thinking |
| ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
| ToolCallMessage | ToolCallMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | ToolCallMessage |
| ToolReturnMessage | ToolReturnMessage | ToolCallMessage | ToolCallMessage | AssistantMessage | ToolReturnMessage |
| AssistantMessage | AssistantMessage | ToolReturnMessage | ToolReturnMessage | ToolCallMessage | AssistantMessage |
| | | AssistantMessage | AssistantMessage | ToolReturnMessage | |
| | | | | AssistantMessage | |
| ToolCallMessage | ToolCallMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | ReasoningMessage |
| ToolReturnMessage | ToolReturnMessage | ToolCallMessage | ToolCallMessage | AssistantMessage | ToolCallMessage |
| AssistantMessage | AssistantMessage | ToolReturnMessage | ToolReturnMessage | ToolCallMessage | ToolReturnMessage |
| | | ReasoningMessage | AssistantMessage | ToolReturnMessage | ReasoningMessage |
| | | AssistantMessage | | AssistantMessage | AssistantMessage |
"""
is_reasoner_model = (
(LLMConfig.is_openai_reasoning_model(llm_config) and llm_config.reasoning_effort == "high")
or LLMConfig.is_anthropic_reasoning_model(llm_config)
or LLMConfig.is_google_vertex_reasoning_model(llm_config)
or LLMConfig.is_google_ai_reasoning_model(llm_config)
)
# assistant message
expected_message_count = 1
expected_range = 0
if is_reasoner_model:
if is_reasoner_model(llm_config):
# reasoning message
expected_message_count += 1
expected_range += 1
if tool_call and not LLMConfig.is_anthropic_reasoning_model(llm_config):
# reasoning message for additional turn, only for openai and google models
expected_range += 1
if tool_call:
# tool call and tool return messages
@@ -364,7 +344,16 @@ def get_expected_message_count(llm_config: LLMConfig, tool_call: bool = False, s
# stop reason and usage statistics
expected_message_count += 2
return expected_message_count
return expected_message_count, expected_message_count + expected_range
def is_reasoner_model(llm_config: LLMConfig) -> bool:
    """Return True when this config belongs to a model expected to emit reasoning messages.

    OpenAI reasoning models only count when reasoning effort is set to "high";
    Anthropic and Google (Vertex / AI) reasoning models always count.
    """
    # OpenAI check is gated on effort level, so handle it separately.
    openai_high_effort = LLMConfig.is_openai_reasoning_model(llm_config) and llm_config.reasoning_effort == "high"
    # `or` short-circuits, so the remaining provider checks only run when needed.
    return openai_high_effort or any(
        check(llm_config)
        for check in (
            LLMConfig.is_anthropic_reasoning_model,
            LLMConfig.is_google_vertex_reasoning_model,
            LLMConfig.is_google_ai_reasoning_model,
        )
    )
# ------------------------------