fix: integration tests send message v1 sdk (#5920)
* fix: integration tests send message v1 sdk * early cancellation * fix image * remove * update * update comment * specific prompt
This commit is contained in:
committed by
Caren Thomas
parent
53d2bd0443
commit
7c731feab3
BIN
tests/data/Camponotus_flavomarginatus_ant.jpg
Normal file
BIN
tests/data/Camponotus_flavomarginatus_ant.jpg
Normal file
Binary file not shown.
|
After Width: | Height: | Size: 741 KiB |
@@ -10,7 +10,6 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
|
||||
from typing import Any, Dict, List
|
||||
from unittest.mock import patch
|
||||
|
||||
import httpx
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
@@ -151,9 +150,18 @@ USER_MESSAGE_ROLL_DICE_LONG_THINKING: List[MessageCreateParam] = [
|
||||
otid=USER_MESSAGE_OTID,
|
||||
)
|
||||
]
|
||||
BASE64_IMAGE = base64.standard_b64encode(
|
||||
httpx.get("https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg").content
|
||||
).decode("utf-8")
|
||||
|
||||
|
||||
# Load test image from local file rather than fetching from external URL.
|
||||
# Using a local file avoids network dependencies and makes tests faster and more reliable.
|
||||
def _load_test_image() -> str:
    """Read the bundled ant JPEG and return its contents base64-encoded.

    The image path is resolved relative to this module's directory
    (``../../data/Camponotus_flavomarginatus_ant.jpg``).

    Returns:
        The file's bytes, standard-base64 encoded and decoded as UTF-8 text.
    """
    module_dir = os.path.dirname(__file__)
    image_path = os.path.join(module_dir, "../../data/Camponotus_flavomarginatus_ant.jpg")
    # Binary mode: the JPEG bytes must reach the encoder unmodified.
    with open(image_path, "rb") as image_file:
        raw_bytes = image_file.read()
    encoded = base64.standard_b64encode(raw_bytes)
    return encoded.decode("utf-8")
|
||||
|
||||
|
||||
BASE64_IMAGE = _load_test_image()
|
||||
USER_MESSAGE_BASE64_IMAGE: List[MessageCreateParam] = [
|
||||
MessageCreateParam(
|
||||
role="user",
|
||||
@@ -213,20 +221,10 @@ TESTED_LLM_CONFIGS = [
|
||||
for cfg in TESTED_LLM_CONFIGS
|
||||
if not (cfg.model_endpoint_type in ["google_vertex", "google_ai"] and cfg.model.startswith("gemini-1.5"))
|
||||
]
|
||||
# Filter out flaky OpenAI gpt-4o-mini models to avoid intermittent failures in streaming tool-call tests
|
||||
TESTED_LLM_CONFIGS = [
|
||||
cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "openai" and cfg.model.startswith("gpt-4o-mini"))
|
||||
]
|
||||
# Filter out deprecated Claude 3.5 Sonnet model that is no longer available
|
||||
TESTED_LLM_CONFIGS = [
|
||||
cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "anthropic" and cfg.model == "claude-3-5-sonnet-20241022")
|
||||
]
|
||||
# Filter out Bedrock models that require aioboto3 dependency (not available in CI)
|
||||
TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "bedrock")]
|
||||
# Filter out Gemini models that have Google Cloud permission issues
|
||||
TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if cfg.model_endpoint_type not in ["google_vertex", "google_ai"]]
|
||||
# Filter out qwen2.5:7b model that has server issues
|
||||
TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model == "qwen2.5:7b")]
|
||||
|
||||
|
||||
def assert_first_message_is_user_message(messages: List[Any]) -> None:
|
||||
@@ -1016,8 +1014,13 @@ def test_tool_call(
|
||||
# pytest.skip("Skipping test for flash model due to malformed function call from llm")
|
||||
raise e
|
||||
assert_tool_call_response(response.messages, llm_config=llm_config)
|
||||
|
||||
# Get the run_id from the response to filter messages by this specific run
|
||||
# This handles cases where retries create multiple runs (e.g., Google Vertex 504 DEADLINE_EXCEEDED)
|
||||
run_id = response.messages[0].run_id if response.messages else None
|
||||
|
||||
messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None)
|
||||
messages_from_db = messages_from_db_page.items
|
||||
messages_from_db = [msg for msg in messages_from_db_page.items if msg.run_id == run_id] if run_id else messages_from_db_page.items
|
||||
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
|
||||
|
||||
|
||||
@@ -1200,6 +1203,8 @@ def test_step_streaming_tool_call(
|
||||
# Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step
|
||||
if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner:
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING
|
||||
elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"):
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH
|
||||
else:
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE
|
||||
response = client.agents.messages.stream(
|
||||
@@ -1727,6 +1732,8 @@ def test_async_tool_call(
|
||||
# Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step
|
||||
if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner:
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING
|
||||
elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"):
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH
|
||||
else:
|
||||
messages_to_send = USER_MESSAGE_ROLL_DICE
|
||||
run = client.agents.messages.send_async(
|
||||
|
||||
@@ -174,6 +174,10 @@ def assert_tool_call_response(
|
||||
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
|
||||
]
|
||||
|
||||
# If cancellation happened and no messages were persisted (early cancellation), return early
|
||||
if with_cancellation and len(messages) == 0:
|
||||
return
|
||||
|
||||
if not with_cancellation:
|
||||
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
|
||||
llm_config, tool_call=True, streaming=streaming, from_db=from_db
|
||||
@@ -187,6 +191,10 @@ def assert_tool_call_response(
|
||||
assert messages[index].otid == USER_MESSAGE_OTID
|
||||
index += 1
|
||||
|
||||
# If cancellation happened after user message but before any response, return early
|
||||
if with_cancellation and index >= len(messages):
|
||||
return
|
||||
|
||||
# Reasoning message if reasoning enabled
|
||||
otid_suffix = 0
|
||||
try:
|
||||
@@ -210,11 +218,16 @@ def assert_tool_call_response(
|
||||
otid_suffix += 1
|
||||
|
||||
# Tool call message (may be skipped if cancelled early)
|
||||
if with_cancellation and isinstance(messages[index], AssistantMessage):
|
||||
if with_cancellation and index < len(messages) and isinstance(messages[index], AssistantMessage):
|
||||
# If cancelled early, model might respond with text instead of making tool call
|
||||
assert "roll" in messages[index].content.lower() or "die" in messages[index].content.lower()
|
||||
return # Skip tool call assertions for early cancellation
|
||||
|
||||
# If cancellation happens before tool call, we might get LettaStopReason directly
|
||||
if with_cancellation and index < len(messages) and isinstance(messages[index], LettaStopReason):
|
||||
assert messages[index].stop_reason == "cancelled"
|
||||
return # Skip remaining assertions for very early cancellation
|
||||
|
||||
assert isinstance(messages[index], ToolCallMessage)
|
||||
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
|
||||
index += 1
|
||||
@@ -540,7 +553,7 @@ async def test_greeting(
|
||||
messages=USER_MESSAGE_FORCE_REPLY,
|
||||
)
|
||||
messages = response.messages
|
||||
run_id = messages[0].run_id
|
||||
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
|
||||
elif send_type == "async":
|
||||
run = await client.agents.messages.send_async(
|
||||
agent_id=agent_state.id,
|
||||
@@ -558,7 +571,12 @@ async def test_greeting(
|
||||
background=(send_type == "stream_tokens_background"),
|
||||
)
|
||||
messages = await accumulate_chunks(response)
|
||||
run_id = messages[0].run_id
|
||||
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
|
||||
|
||||
# If run_id is not in messages (e.g., due to early cancellation), get the most recent run
|
||||
if run_id is None:
|
||||
runs = await client.runs.list(agent_ids=[agent_state.id])
|
||||
run_id = runs[0].id if runs else None
|
||||
|
||||
assert_greeting_response(
|
||||
messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
|
||||
@@ -716,7 +734,7 @@ async def test_tool_call(
|
||||
messages=USER_MESSAGE_ROLL_DICE,
|
||||
)
|
||||
messages = response.messages
|
||||
run_id = messages[0].run_id
|
||||
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
|
||||
elif send_type == "async":
|
||||
run = await client.agents.messages.send_async(
|
||||
agent_id=agent_state.id,
|
||||
@@ -734,7 +752,12 @@ async def test_tool_call(
|
||||
background=(send_type == "stream_tokens_background"),
|
||||
)
|
||||
messages = await accumulate_chunks(response)
|
||||
run_id = messages[0].run_id
|
||||
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
|
||||
|
||||
# If run_id is not in messages (e.g., due to early cancellation), get the most recent run
|
||||
if run_id is None:
|
||||
runs = await client.runs.list(agent_ids=[agent_state.id])
|
||||
run_id = runs[0].id if runs else None
|
||||
|
||||
assert_tool_call_response(
|
||||
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
|
||||
|
||||
Reference in New Issue
Block a user