diff --git a/tests/data/Camponotus_flavomarginatus_ant.jpg b/tests/data/Camponotus_flavomarginatus_ant.jpg
new file mode 100644
index 00000000..37f5c02c
Binary files /dev/null and b/tests/data/Camponotus_flavomarginatus_ant.jpg differ
diff --git a/tests/sdk_v1/integration/integration_test_send_message.py b/tests/sdk_v1/integration/integration_test_send_message.py
index b9a35d6f..2d023b4d 100644
--- a/tests/sdk_v1/integration/integration_test_send_message.py
+++ b/tests/sdk_v1/integration/integration_test_send_message.py
@@ -10,7 +10,6 @@
 from http.server import BaseHTTPRequestHandler, HTTPServer
 from typing import Any, Dict, List
 from unittest.mock import patch
-import httpx
 import pytest
 import requests
 from dotenv import load_dotenv
@@ -151,9 +150,18 @@ USER_MESSAGE_ROLL_DICE_LONG_THINKING: List[MessageCreateParam] = [
         otid=USER_MESSAGE_OTID,
     )
 ]
-BASE64_IMAGE = base64.standard_b64encode(
-    httpx.get("https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg").content
-).decode("utf-8")
+
+
+# Load test image from local file rather than fetching from external URL.
+# Using a local file avoids network dependencies and makes tests faster and more reliable.
+def _load_test_image() -> str:
+    """Loads the test image from the data folder and returns it as base64."""
+    image_path = os.path.join(os.path.dirname(__file__), "../../data/Camponotus_flavomarginatus_ant.jpg")
+    with open(image_path, "rb") as f:
+        return base64.standard_b64encode(f.read()).decode("utf-8")
+
+
+BASE64_IMAGE = _load_test_image()
 USER_MESSAGE_BASE64_IMAGE: List[MessageCreateParam] = [
     MessageCreateParam(
         role="user",
@@ -213,20 +221,10 @@ TESTED_LLM_CONFIGS = [
     for cfg in TESTED_LLM_CONFIGS
     if not (cfg.model_endpoint_type in ["google_vertex", "google_ai"] and cfg.model.startswith("gemini-1.5"))
 ]
-# Filter out flaky OpenAI gpt-4o-mini models to avoid intermittent failures in streaming tool-call tests
-TESTED_LLM_CONFIGS = [
-    cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "openai" and cfg.model.startswith("gpt-4o-mini"))
-]
 # Filter out deprecated Claude 3.5 Sonnet model that is no longer available
 TESTED_LLM_CONFIGS = [
     cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "anthropic" and cfg.model == "claude-3-5-sonnet-20241022")
 ]
-# Filter out Bedrock models that require aioboto3 dependency (not available in CI)
-TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "bedrock")]
-# Filter out Gemini models that have Google Cloud permission issues
-TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if cfg.model_endpoint_type not in ["google_vertex", "google_ai"]]
-# Filter out qwen2.5:7b model that has server issues
-TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model == "qwen2.5:7b")]
 
 
 def assert_first_message_is_user_message(messages: List[Any]) -> None:
@@ -1016,8 +1014,13 @@ def test_tool_call(
             # pytest.skip("Skipping test for flash model due to malformed function call from llm")
             raise e
     assert_tool_call_response(response.messages, llm_config=llm_config)
+
+    # Get the run_id from the response to filter messages by this specific run
+    # This handles cases where retries create multiple runs (e.g., Google Vertex 504 DEADLINE_EXCEEDED)
+    run_id = response.messages[0].run_id if response.messages else None
+
     messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None)
-    messages_from_db = messages_from_db_page.items
+    messages_from_db = [msg for msg in messages_from_db_page.items if msg.run_id == run_id] if run_id else messages_from_db_page.items
     assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)
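The hunk above scopes the DB assertions to a single run, so that provider-side retries (which create extra runs) do not pollute the persisted message list. A minimal standalone sketch of the same filter; `Message` here is a hypothetical stand-in for the SDK's message model, and the only assumption is that persisted messages carry a `run_id` attribute, as they do in the test:

```python
# Sketch of the run-scoped filtering used in test_tool_call above.
# `Message` is a hypothetical stand-in for the SDK's message model.
from dataclasses import dataclass
from typing import List, Optional


@dataclass
class Message:
    id: str
    run_id: Optional[str] = None


def filter_by_run(items: List[Message], run_id: Optional[str]) -> List[Message]:
    """Keep only messages from the given run; with no known run_id, keep everything."""
    if run_id is None:
        return items
    return [msg for msg in items if msg.run_id == run_id]
```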
@@ -1200,6 +1203,8 @@ def test_step_streaming_tool_call(
     # Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step
     if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner:
         messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING
+    elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"):
+        messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH
     else:
         messages_to_send = USER_MESSAGE_ROLL_DICE
     response = client.agents.messages.stream(
@@ -1727,6 +1732,8 @@ def test_async_tool_call(
     # Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step
     if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner:
         messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING
+    elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"):
+        messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH
     else:
         messages_to_send = USER_MESSAGE_ROLL_DICE
     run = client.agents.messages.send_async(
diff --git a/tests/sdk_v1/integration/integration_test_send_message_v2.py b/tests/sdk_v1/integration/integration_test_send_message_v2.py
index 2fc7040d..a863e94e 100644
--- a/tests/sdk_v1/integration/integration_test_send_message_v2.py
+++ b/tests/sdk_v1/integration/integration_test_send_message_v2.py
@@ -174,6 +174,10 @@ def assert_tool_call_response(
         msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
     ]
 
+    # If cancellation happened and no messages were persisted (early cancellation), return early
+    if with_cancellation and len(messages) == 0:
+        return
+
     if not with_cancellation:
         expected_message_count_min, expected_message_count_max = get_expected_message_count_range(
             llm_config, tool_call=True, streaming=streaming, from_db=from_db
@@ -187,6 +191,10 @@
     assert messages[index].otid == USER_MESSAGE_OTID
     index += 1
 
+    # If cancellation happened after user message but before any response, return early
+    if with_cancellation and index >= len(messages):
+        return
+
     # Reasoning message if reasoning enabled
     otid_suffix = 0
     try:
@@ -210,11 +218,16 @@
         otid_suffix += 1
 
     # Tool call message (may be skipped if cancelled early)
-    if with_cancellation and isinstance(messages[index], AssistantMessage):
+    if with_cancellation and index < len(messages) and isinstance(messages[index], AssistantMessage):
         # If cancelled early, model might respond with text instead of making tool call
         assert "roll" in messages[index].content.lower() or "die" in messages[index].content.lower()
         return  # Skip tool call assertions for early cancellation
 
+    # If cancellation happens before tool call, we might get LettaStopReason directly
+    if with_cancellation and index < len(messages) and isinstance(messages[index], LettaStopReason):
+        assert messages[index].stop_reason == "cancelled"
+        return  # Skip remaining assertions for very early cancellation
+
     assert isinstance(messages[index], ToolCallMessage)
     assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
     index += 1
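The guards added to `assert_tool_call_response` all follow one shape: when the run was cancelled, the transcript may legitimately stop at any point, so each positional assertion first checks that `index` is still in range. A compact sketch of that pattern in isolation; `remaining` is a hypothetical helper, not part of the test suite:

```python
# Sketch of the cancellation-aware assertion pattern above.
# `remaining` is hypothetical: it reports whether it is safe to keep
# asserting on messages[index], treating a truncated transcript as
# acceptable only when the run was cancelled.
from typing import Any, List


def remaining(messages: List[Any], index: int, with_cancellation: bool) -> bool:
    if index < len(messages):
        return True  # Safe to inspect messages[index].
    # Out of messages: fine for a cancelled run, a failure otherwise.
    assert with_cancellation, "completed runs must persist the full transcript"
    return False
```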
@@ -540,7 +553,7 @@ async def test_greeting(
             messages=USER_MESSAGE_FORCE_REPLY,
         )
         messages = response.messages
-        run_id = messages[0].run_id
+        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
     elif send_type == "async":
         run = await client.agents.messages.send_async(
             agent_id=agent_state.id,
@@ -558,7 +571,12 @@
             background=(send_type == "stream_tokens_background"),
         )
         messages = await accumulate_chunks(response)
-        run_id = messages[0].run_id
+        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
+
+        # If run_id is not in messages (e.g., due to early cancellation), get the most recent run
+        if run_id is None:
+            runs = await client.runs.list(agent_ids=[agent_state.id])
+            run_id = runs[0].id if runs else None
 
     assert_greeting_response(
         messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
@@ -716,7 +734,7 @@ async def test_tool_call(
             messages=USER_MESSAGE_ROLL_DICE,
         )
         messages = response.messages
-        run_id = messages[0].run_id
+        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
     elif send_type == "async":
         run = await client.agents.messages.send_async(
             agent_id=agent_state.id,
@@ -734,7 +752,12 @@
             background=(send_type == "stream_tokens_background"),
         )
         messages = await accumulate_chunks(response)
-        run_id = messages[0].run_id
+        run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
+
+        # If run_id is not in messages (e.g., due to early cancellation), get the most recent run
+        if run_id is None:
+            runs = await client.runs.list(agent_ids=[agent_state.id])
+            run_id = runs[0].id if runs else None
 
     assert_tool_call_response(
         messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation")
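Both `test_greeting` and `test_tool_call` now resolve `run_id` the same way: take it from the first message that carries one, and fall back to the latest run for the agent when early cancellation left no messages behind. A sketch of that resolution as a standalone helper; `resolve_run_id` is hypothetical, and it assumes `client.runs.list` returns runs most-recent-first, as the "get the most recent run" comment in the diff implies:

```python
# Hypothetical helper capturing the run_id resolution used in both tests.
# Assumes client.runs.list returns runs most-recent-first, per the
# "get the most recent run" comment in the diff.
from typing import Any, List, Optional


async def resolve_run_id(client: Any, agent_id: str, messages: List[Any]) -> Optional[str]:
    # Prefer the run_id carried on the response messages themselves.
    run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None)
    if run_id is None:
        # Early cancellation can leave no messages; fall back to the newest run.
        runs = await client.runs.list(agent_ids=[agent_id])
        run_id = runs[0].id if runs else None
    return run_id
```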