fix: integration tests send message v1 sdk (#5920)

* fix: integration tests send message v1 sdk

* early cancellation

* fix image

* remove

* update

* update comment

* specific prompt
This commit is contained in:
Christina Tong
2025-11-04 14:05:51 -08:00
committed by Caren Thomas
parent 53d2bd0443
commit 7c731feab3
3 changed files with 50 additions and 20 deletions

Binary file not shown.

After

Width:  |  Height:  |  Size: 741 KiB

View File

@@ -10,7 +10,6 @@ from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List
from unittest.mock import patch
import httpx
import pytest
import requests
from dotenv import load_dotenv
@@ -151,9 +150,18 @@ USER_MESSAGE_ROLL_DICE_LONG_THINKING: List[MessageCreateParam] = [
otid=USER_MESSAGE_OTID,
)
]
BASE64_IMAGE = base64.standard_b64encode(
httpx.get("https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg").content
).decode("utf-8")
# Load test image from local file rather than fetching from external URL.
# Using a local file avoids network dependencies and makes tests faster and more reliable.
def _load_test_image() -> str:
    """Read the bundled ant test image from the data directory and return it base64-encoded.

    Returns:
        The JPEG file contents encoded as a UTF-8 base64 string.
    """
    # Resolve the image relative to this test module so the test is CWD-independent.
    here = os.path.dirname(__file__)
    image_path = os.path.join(here, "../../data/Camponotus_flavomarginatus_ant.jpg")
    with open(image_path, "rb") as image_file:
        raw = image_file.read()
    return base64.standard_b64encode(raw).decode("utf-8")
BASE64_IMAGE = _load_test_image()
USER_MESSAGE_BASE64_IMAGE: List[MessageCreateParam] = [
MessageCreateParam(
role="user",
@@ -213,20 +221,10 @@ TESTED_LLM_CONFIGS = [
for cfg in TESTED_LLM_CONFIGS
if not (cfg.model_endpoint_type in ["google_vertex", "google_ai"] and cfg.model.startswith("gemini-1.5"))
]
# Filter out flaky OpenAI gpt-4o-mini models to avoid intermittent failures in streaming tool-call tests
TESTED_LLM_CONFIGS = [
cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "openai" and cfg.model.startswith("gpt-4o-mini"))
]
# Filter out deprecated Claude 3.5 Sonnet model that is no longer available
TESTED_LLM_CONFIGS = [
cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "anthropic" and cfg.model == "claude-3-5-sonnet-20241022")
]
# Filter out Bedrock models that require aioboto3 dependency (not available in CI)
TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "bedrock")]
# Filter out Gemini models that have Google Cloud permission issues
TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if cfg.model_endpoint_type not in ["google_vertex", "google_ai"]]
# Filter out qwen2.5:7b model that has server issues
TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model == "qwen2.5:7b")]
def assert_first_message_is_user_message(messages: List[Any]) -> None:
@@ -1016,8 +1014,13 @@ def test_tool_call(
# pytest.skip("Skipping test for flash model due to malformed function call from llm")
raise e
assert_tool_call_response(response.messages, llm_config=llm_config)
# Get the run_id from the response to filter messages by this specific run
# This handles cases where retries create multiple runs (e.g., Google Vertex 504 DEADLINE_EXCEEDED)
run_id = response.messages[0].run_id if response.messages else None
messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None)
messages_from_db = messages_from_db_page.items
messages_from_db = [msg for msg in messages_from_db_page.items if msg.run_id == run_id] if run_id else messages_from_db_page.items
assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
@@ -1200,6 +1203,8 @@ def test_step_streaming_tool_call(
# Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step
if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner:
messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING
elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"):
messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH
else:
messages_to_send = USER_MESSAGE_ROLL_DICE
response = client.agents.messages.stream(
@@ -1727,6 +1732,8 @@ def test_async_tool_call(
# Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step
if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner:
messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING
elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"):
messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH
else:
messages_to_send = USER_MESSAGE_ROLL_DICE
run = client.agents.messages.send_async(

View File

@@ -174,6 +174,10 @@ def assert_tool_call_response(
msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
]
# If cancellation happened and no messages were persisted (early cancellation), return early
if with_cancellation and len(messages) == 0:
return
if not with_cancellation:
expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
llm_config, tool_call=True, streaming=streaming, from_db=from_db
@@ -187,6 +191,10 @@ def assert_tool_call_response(
assert messages[index].otid == USER_MESSAGE_OTID
index += 1
# If cancellation happened after user message but before any response, return early
if with_cancellation and index >= len(messages):
return
# Reasoning message if reasoning enabled
otid_suffix = 0
try:
@@ -210,11 +218,16 @@ def assert_tool_call_response(
otid_suffix += 1
# Tool call message (may be skipped if cancelled early)
if with_cancellation and isinstance(messages[index], AssistantMessage):
if with_cancellation and index < len(messages) and isinstance(messages[index], AssistantMessage):
# If cancelled early, model might respond with text instead of making tool call
assert "roll" in messages[index].content.lower() or "die" in messages[index].content.lower()
return # Skip tool call assertions for early cancellation
# If cancellation happens before tool call, we might get LettaStopReason directly
if with_cancellation and index < len(messages) and isinstance(messages[index], LettaStopReason):
assert messages[index].stop_reason == "cancelled"
return # Skip remaining assertions for very early cancellation
assert isinstance(messages[index], ToolCallMessage)
assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
index += 1
@@ -540,7 +553,7 @@ async def test_greeting(
messages=USER_MESSAGE_FORCE_REPLY,
)
messages = response.messages
run_id = messages[0].run_id
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
elif send_type == "async":
run = await client.agents.messages.send_async(
agent_id=agent_state.id,
@@ -558,7 +571,12 @@ async def test_greeting(
background=(send_type == "stream_tokens_background"),
)
messages = await accumulate_chunks(response)
run_id = messages[0].run_id
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
# If run_id is not in messages (e.g., due to early cancellation), get the most recent run
if run_id is None:
runs = await client.runs.list(agent_ids=[agent_state.id])
run_id = runs[0].id if runs else None
assert_greeting_response(
messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
@@ -716,7 +734,7 @@ async def test_tool_call(
messages=USER_MESSAGE_ROLL_DICE,
)
messages = response.messages
run_id = messages[0].run_id
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
elif send_type == "async":
run = await client.agents.messages.send_async(
agent_id=agent_state.id,
@@ -734,7 +752,12 @@ async def test_tool_call(
background=(send_type == "stream_tokens_background"),
)
messages = await accumulate_chunks(response)
run_id = messages[0].run_id
run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
# If run_id is not in messages (e.g., due to early cancellation), get the most recent run
if run_id is None:
runs = await client.runs.list(agent_ids=[agent_state.id])
run_id = runs[0].id if runs else None
assert_tool_call_response(
messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")