From 599adb4c26c107c3d752eeca21724f04552a31be Mon Sep 17 00:00:00 2001 From: Christina Tong Date: Thu, 30 Oct 2025 16:32:26 -0700 Subject: [PATCH] chore: migrate integration test send message to v1 sdk [LET-5940] (#5794) * chore: migrate integration test send message to v1 sdk * new folder * set up new workflows for integration test * remove file * update stainless workflow * fix import err * add letta-client version logging * fix: SDK cache miss should fall back to PyPI instead of failing When the Stainless SDK cache is not available, the workflow should fall back to installing the published SDK from PyPI rather than failing the CI build. The workflow already has this fallback logic in the "Install Stainless SDK" step, but the "Check SDK cache" step was failing before it could reach that point. This change converts the hard failure (exit 1) to a warning message, allowing the workflow to continue and use the PyPI fallback. Co-Authored-By: Claude * force upgrade * remove frozen * install before running * add no sync * use upgrade instead of upgrade package * update * fix llm config * fix * update * update path * update workflow * see installed version * add fallback * update * fix mini * lettaping * fix: handle o1 token streaming and LettaPing step_id validation - Skip LettaPing messages in step_id validation (they don't have step_id) - Move o1/o3/o4 token streaming check before general assertion in assert_tool_call_response - o1 reasoning models omit AssistantMessage in token streaming mode (6 messages instead of 7) --------- Co-authored-by: Kian Jones Co-authored-by: Claude --- .../integration_test_send_message.py | 2340 +++++++++++++++++ .../integration_test_send_message_v2.py | 761 ++++++ 2 files changed, 3101 insertions(+) create mode 100644 tests/sdk_v1/integration/integration_test_send_message.py create mode 100644 tests/sdk_v1/integration/integration_test_send_message_v2.py diff --git a/tests/sdk_v1/integration/integration_test_send_message.py b/tests/sdk_v1/integration/integration_test_send_message.py new file mode 100644 index 00000000..60d30818 --- /dev/null +++ b/tests/sdk_v1/integration/integration_test_send_message.py @@ -0,0 +1,2340 @@ +import base64 +import json +import os +import threading +import time +import uuid +from contextlib import contextmanager +from http.server import BaseHTTPRequestHandler, HTTPServer +from typing import Any, Dict, List +from unittest.mock import patch + +import httpx +import pytest +import requests +from dotenv import load_dotenv +from letta_client import APIError, AsyncLetta, Letta +from letta_client.types import ToolReturnMessage +from letta_client.types.agents import ( + AssistantMessage, + HiddenReasoningMessage, + LettaMessageUnion, + ReasoningMessage, + Run, + ToolCallMessage, + UserMessage, +) +from letta_client.types.agents.image_content_param import ImageContentParam, SourceBase64Image, SourceURLImage +from letta_client.types.agents.letta_streaming_response import LettaPing, LettaStopReason, LettaUsageStatistics +from letta_client.types.agents.text_content_param import TextContentParam + +from letta.errors import LLMError +from letta.helpers.reasoning_helper import is_reasoning_completely_disabled +from letta.llm_api.openai_client import is_openai_reasoning_model +from letta.log import get_logger +from letta.schemas.agent import AgentState +from letta.schemas.letta_message_content import Base64Image, ImageContent, TextContent, UrlImage +from letta.schemas.letta_request import LettaRequest +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate + +logger = get_logger(__name__) + +# ------------------------------ +# Helper Functions and Constants +# ------------------------------ + + +def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig: + filename = os.path.join(llm_config_dir, filename) + with open(filename, "r") as f: + config_data = json.load(f) + llm_config = LLMConfig(**config_data) + return llm_config + + +def roll_dice(num_sides: int) -> int: + """ + Returns a random number between 1 and num_sides. + Args: + num_sides (int): The number of sides on the die. + Returns: + int: A random integer between 1 and num_sides, representing the die roll. + """ + import random + + return random.randint(1, num_sides) + + +USER_MESSAGE_OTID = str(uuid.uuid4()) +USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work" +USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [ + MessageCreate( + role="user", + content=f"This is an automated test message. Call the send_message tool with the message '{USER_MESSAGE_RESPONSE}'.", + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_LONG_RESPONSE: str = ( + "Teamwork makes the dream work. When people collaborate and combine their unique skills, perspectives, and experiences, they can achieve far more than any individual working alone. " + "This synergy creates an environment where innovation flourishes, problems are solved more creatively, and goals are reached more efficiently. " + "In a team setting, diverse viewpoints lead to better decision-making as different team members bring their unique backgrounds and expertise to the table. " + "Communication becomes the backbone of success, allowing ideas to flow freely and ensuring everyone is aligned toward common objectives. " + "Trust builds gradually as team members learn to rely on each other's strengths while supporting one another through challenges. " + "The collective intelligence of a group often surpasses that of even the brightest individual, as collaboration sparks creativity and innovation. " + "Successful teams celebrate victories together and learn from failures as a unit, creating a culture of continuous improvement. " + "Together, we can overcome challenges that would be insurmountable alone, achieving extraordinary results through the power of collaboration." +) +USER_MESSAGE_FORCE_LONG_REPLY: List[MessageCreate] = [ + MessageCreate( + role="user", + content=f"This is an automated test message. Call the send_message tool with exactly this message: '{USER_MESSAGE_LONG_RESPONSE}'", + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_GREETING: List[MessageCreate] = [ + MessageCreate( + role="user", + content="Hi!", + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ + MessageCreate( + role="user", + content="This is an automated test message. Call the roll_dice tool with 16 sides and send me a message with the outcome.", + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_ROLL_DICE_LONG: List[MessageCreate] = [ + MessageCreate( + role="user", + content=( + "This is an automated test message. Call the roll_dice tool with 16 sides and send me a very detailed, comprehensive message about the outcome. " + "Your response must be at least 800 characters long. Start by explaining what dice rolling represents in games and probability theory. " + "Discuss the mathematical probability of getting each number on a 16-sided die (1/16 or 6.25% for each face). " + "Explain how 16-sided dice are commonly used in tabletop role-playing games like Dungeons & Dragons. " + "Describe the specific number you rolled and what it might mean in different gaming contexts. " + "Discuss how this particular roll compares to the expected value (8.5) of a 16-sided die. " + "Explain the concept of randomness and how true random number generation works. " + "End with some interesting facts about polyhedral dice and their history in gaming. " + "Remember, make your response detailed and at least 800 characters long." + ), + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_ROLL_DICE_GEMINI_FLASH: List[MessageCreate] = [ + MessageCreate( + role="user", + content=( + 'This is an automated test message. First, call the roll_dice tool with exactly this JSON: {"num_sides": 16, "request_heartbeat": true}. ' + "After you receive the tool result, as your final step, call the send_message tool with your user-facing reply in the 'message' argument. " + "Important: Do not output plain text for the final step; respond using a functionCall to send_message only. Use valid JSON for all function arguments." + ), + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_ROLL_DICE_LONG_THINKING: List[MessageCreate] = [ + MessageCreate( + role="user", + content=( + "This is an automated test message. First, think long and hard about about why you're here, and your creator. " + "Then, call the roll_dice tool with 16 sides. " + "Once you've rolled the die, think deeply about the meaning of the roll to you (but don't tell me, just think these thoughts privately). " + "Then, once you're done thinking, send me a very detailed, comprehensive message about the outcome, using send_message. " + "Your response must be at least 800 characters long. Start by explaining what dice rolling represents in games and probability theory. " + "Discuss the mathematical probability of getting each number on a 16-sided die (1/16 or 6.25% for each face). " + "Explain how 16-sided dice are commonly used in tabletop role-playing games like Dungeons & Dragons. " + "Describe the specific number you rolled and what it might mean in different gaming contexts. " + "Discuss how this particular roll compares to the expected value (8.5) of a 16-sided die. " + "Explain the concept of randomness and how true random number generation works. " + "End with some interesting facts about polyhedral dice and their history in gaming. " + "Remember, make your response detailed and at least 800 characters long." + "Absolutely do NOT violate this order of operations: (1) Think / reason, (2) Roll die, (3) Think / reason, (4) Call send_message tool." + ), + otid=USER_MESSAGE_OTID, + ) +] +BASE64_IMAGE = base64.standard_b64encode( + httpx.get("https://upload.wikimedia.org/wikipedia/commons/a/a7/Camponotus_flavomarginatus_ant.jpg").content +).decode("utf-8") +USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [ + MessageCreate( + role="user", + content=[ + ImageContentParam(type="image", source=SourceBase64Image(type="base64", data=BASE64_IMAGE, media_type="image/jpeg")), + TextContentParam(type="text", text="What is in this image?"), + ], + otid=USER_MESSAGE_OTID, + ) +] + +# configs for models that are to dumb to do much other than messaging +limited_configs = [ + "ollama.json", + "together-qwen-2.5-72b-instruct.json", + "vllm.json", + "lmstudio.json", + "groq.json", + # treat deprecated models as limited to skip where generic checks are used + "gemini-1.5-pro.json", +] + +all_configs = [ + "openai-gpt-4.1.json", + "openai-o1.json", + "openai-o3.json", + "openai-o4-mini.json", + "azure-gpt-4o-mini.json", + "claude-4-sonnet-extended.json", + "claude-4-sonnet.json", + "claude-3-5-sonnet.json", + "claude-3-7-sonnet-extended.json", + "claude-3-7-sonnet.json", + "bedrock-claude-4-sonnet.json", + # NOTE: gemini-1.5-pro is deprecated / unsupported on v1beta generateContent, skip in CI + # "gemini-1.5-pro.json", + "gemini-2.5-flash-vertex.json", + "gemini-2.5-pro-vertex.json", + "ollama.json", + "together-qwen-2.5-72b-instruct.json", + "groq.json", +] + +reasoning_configs = [ + "openai-o1.json", + "openai-o3.json", + "openai-o4-mini.json", +] + + +requested = os.getenv("LLM_CONFIG_FILE") +filenames = [requested] if requested else all_configs +TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames] +# Filter out deprecated Gemini 1.5 models regardless of filename source +TESTED_LLM_CONFIGS = [ + cfg + for cfg in TESTED_LLM_CONFIGS + if not (cfg.model_endpoint_type in ["google_vertex", "google_ai"] and cfg.model.startswith("gemini-1.5")) +] +# Filter out flaky OpenAI gpt-4o-mini models to avoid intermittent failures in streaming tool-call tests +TESTED_LLM_CONFIGS = [ + cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "openai" and cfg.model.startswith("gpt-4o-mini")) +] +# Filter out deprecated Claude 3.5 Sonnet model that is no longer available +TESTED_LLM_CONFIGS = [ + cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "anthropic" and cfg.model == "claude-3-5-sonnet-20241022") +] +# Filter out Bedrock models that require aioboto3 dependency (not available in CI) +TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "bedrock")] +# Filter out Gemini models that have Google Cloud permission issues +TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if cfg.model_endpoint_type not in ["google_vertex", "google_ai"]] +# Filter out qwen2.5:7b model that has server issues +TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model == "qwen2.5:7b")] + + +def assert_first_message_is_user_message(messages: List[Any]) -> None: + """ + Asserts that the first message is a user message. + """ + assert isinstance(messages[0], UserMessage) + + +def assert_greeting_with_assistant_message_response( + messages: List[Any], + llm_config: LLMConfig, + streaming: bool = False, + token_streaming: bool = False, + from_db: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence: + ReasoningMessage -> AssistantMessage. + """ + # Filter out LettaPing messages which are keep-alive messages for SSE streams + messages = [ + msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) + ] + + # For o1 models in token streaming, AssistantMessage is not included in the stream + o1_token_streaming = is_openai_reasoning_model(llm_config.model) and streaming and token_streaming + expected_message_count = 3 if o1_token_streaming else (4 if streaming else 3 if from_db else 2) + assert len(messages) == expected_message_count + + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + if is_openai_reasoning_model(llm_config.model): + assert isinstance(messages[index], HiddenReasoningMessage) + else: + assert isinstance(messages[index], ReasoningMessage) + + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + # Agent Step 2: AssistantMessage (skip for o1 token streaming) + if not o1_token_streaming: + assert isinstance(messages[index], AssistantMessage) + if not token_streaming: + # Check for either short or long response + assert "teamwork" in messages[index].content.lower() or USER_MESSAGE_LONG_RESPONSE in messages[index].content + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 + + if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + +def assert_contains_run_id(messages: List[Any]) -> None: + """ + Asserts that the messages list contains a run_id. + """ + for message in messages: + if hasattr(message, "run_id"): + assert message.run_id is not None + + +def assert_contains_step_id(messages: List[Any]) -> None: + """ + Asserts that the messages list contains a step_id. + """ + for message in messages: + # Skip LettaPing messages which are keep-alive and don't have step_id + if isinstance(message, LettaPing): + continue + if hasattr(message, "step_id"): + assert message.step_id is not None + + +def assert_greeting_no_reasoning_response( + messages: List[Any], + streaming: bool = False, + token_streaming: bool = False, + from_db: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence without reasoning: + AssistantMessage (no ReasoningMessage when put_inner_thoughts_in_kwargs is False). + """ + # Filter out LettaPing messages which are keep-alive messages for SSE streams + messages = [ + msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) + ] + expected_message_count = 3 if streaming else 2 if from_db else 1 + assert len(messages) == expected_message_count + + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 - should be AssistantMessage directly, no reasoning + assert isinstance(messages[index], AssistantMessage) + if not token_streaming: + assert "teamwork" in messages[index].content.lower() + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + +def assert_greeting_without_assistant_message_response( + messages: List[Any], + llm_config: LLMConfig, + streaming: bool = False, + token_streaming: bool = False, + from_db: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence: + ReasoningMessage -> ToolCallMessage -> ToolReturnMessage. + """ + # Filter out LettaPing messages which are keep-alive messages for SSE streams + messages = [ + msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) + ] + expected_message_count = 5 if streaming else 4 if from_db else 3 + assert len(messages) == expected_message_count + + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + if is_openai_reasoning_model(llm_config.model): + assert isinstance(messages[index], HiddenReasoningMessage) + else: + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + assert isinstance(messages[index], ToolCallMessage) + assert messages[index].tool_call.name == "send_message" + if not token_streaming: + assert "teamwork" in messages[index].tool_call.arguments.lower() + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 + + # Agent Step 2 + assert isinstance(messages[index], ToolReturnMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + +def assert_tool_call_response( + messages: List[Any], + llm_config: LLMConfig, + streaming: bool = False, + from_db: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence: + ReasoningMessage -> ToolCallMessage -> ToolReturnMessage -> + ReasoningMessage -> AssistantMessage. + """ + # Filter out LettaPing messages which are keep-alive messages for SSE streams + messages = [ + msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) + ] + expected_message_count = 7 if streaming or from_db else 5 + + # Special-case relaxation for Gemini 2.5 Flash on Google endpoints during streaming + # Flash can legitimately end after the tool return without issuing a final send_message call. + # Accept the shorter sequence: Reasoning -> ToolCall -> ToolReturn -> StopReason(no_tool_call) + is_gemini_flash = llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash") + if streaming and is_gemini_flash: + if ( + len(messages) >= 4 + and getattr(messages[-1], "message_type", None) == "stop_reason" + and getattr(messages[-1], "stop_reason", None) == "no_tool_call" + and getattr(messages[0], "message_type", None) == "reasoning_message" + and getattr(messages[1], "message_type", None) == "tool_call_message" + and getattr(messages[2], "message_type", None) == "tool_return_message" + ): + return + + # OpenAI o1/o3/o4 reasoning models omit the final AssistantMessage in token streaming, + # yielding the shorter sequence: + # HiddenReasoning -> ToolCall -> ToolReturn -> HiddenReasoning -> StopReason -> Usage + o1_token_streaming = ( + streaming + and is_openai_reasoning_model(llm_config.model) + and len(messages) == 6 + and getattr(messages[0], "message_type", None) == "hidden_reasoning_message" + and getattr(messages[1], "message_type", None) == "tool_call_message" + and getattr(messages[2], "message_type", None) == "tool_return_message" + and getattr(messages[3], "message_type", None) == "hidden_reasoning_message" + and getattr(messages[4], "message_type", None) == "stop_reason" + and getattr(messages[5], "message_type", None) == "usage_statistics" + ) + if o1_token_streaming: + return + + try: + assert len(messages) == expected_message_count, messages + except: + if "claude-3-7-sonnet" not in llm_config.model: + raise + assert len(messages) == expected_message_count - 1, messages + + # OpenAI gpt-4o-mini can sometimes omit the final AssistantMessage in streaming, + # yielding the shorter sequence: + # Reasoning -> ToolCall -> ToolReturn -> Reasoning -> StopReason -> Usage + # Accept this variant to reduce flakiness. + if ( + streaming + and llm_config.model_endpoint_type == "openai" + and "gpt-4o-mini" in llm_config.model + and len(messages) == 6 + and getattr(messages[0], "message_type", None) == "reasoning_message" + and getattr(messages[1], "message_type", None) == "tool_call_message" + and getattr(messages[2], "message_type", None) == "tool_return_message" + and getattr(messages[3], "message_type", None) == "reasoning_message" + and getattr(messages[4], "message_type", None) == "stop_reason" + and getattr(messages[5], "message_type", None) == "usage_statistics" + ): + return + + # OpenAI o3 can sometimes stop after tool return without generating final reasoning/assistant messages + # Accept the shorter sequence: HiddenReasoning -> ToolCall -> ToolReturn + if ( + llm_config.model_endpoint_type == "openai" + and "o3" in llm_config.model + and len(messages) == 3 + and getattr(messages[0], "message_type", None) == "hidden_reasoning_message" + and getattr(messages[1], "message_type", None) == "tool_call_message" + and getattr(messages[2], "message_type", None) == "tool_return_message" + ): + return + + # Groq models can sometimes stop after tool return without generating final reasoning/assistant messages + # Accept the shorter sequence: Reasoning -> ToolCall -> ToolReturn + if ( + llm_config.model_endpoint_type == "groq" + and len(messages) == 3 + and getattr(messages[0], "message_type", None) == "reasoning_message" + and getattr(messages[1], "message_type", None) == "tool_call_message" + and getattr(messages[2], "message_type", None) == "tool_return_message" + ): + return + + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + if is_openai_reasoning_model(llm_config.model): + assert isinstance(messages[index], HiddenReasoningMessage) + else: + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + assert isinstance(messages[index], ToolCallMessage) + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 + + # Agent Step 2 + assert isinstance(messages[index], ToolReturnMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + # Hidden User Message + if from_db: + assert isinstance(messages[index], UserMessage) + assert "request_heartbeat=true" in messages[index].content + index += 1 + + # Agent Step 3 + try: + if is_openai_reasoning_model(llm_config.model): + assert isinstance(messages[index], HiddenReasoningMessage) + else: + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + except: + if "claude-3-7-sonnet" not in llm_config.model: + raise + pass + + assert isinstance(messages[index], AssistantMessage) + try: + assert messages[index].otid and messages[index].otid[-1] == "1" + except: + if "claude-3-7-sonnet" not in llm_config.model: + raise + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + +def validate_openai_format_scrubbing(messages: List[Dict[str, Any]]) -> None: + """ + Validate that OpenAI format assistant messages with tool calls have no inner thoughts content. + Args: + messages: List of message dictionaries in OpenAI format + """ + assistant_messages_with_tools = [] + + for msg in messages: + if msg.get("role") == "assistant" and msg.get("tool_calls"): + assistant_messages_with_tools.append(msg) + + # There should be at least one assistant message with tool calls + assert len(assistant_messages_with_tools) > 0, "Expected at least one OpenAI assistant message with tool calls" + + # Check that assistant messages with tool calls have no text content (inner thoughts scrubbed) + for msg in assistant_messages_with_tools: + if "content" in msg: + content = msg["content"] + assert content is None + + +def validate_anthropic_format_scrubbing(messages: List[Dict[str, Any]], reasoning_enabled: bool) -> None: + """ + Validate that Anthropic/Claude format assistant messages with tool_use have no tags. + Args: + messages: List of message dictionaries in Anthropic format + """ + claude_assistant_messages_with_tools = [] + + for msg in messages: + if ( + msg.get("role") == "assistant" + and isinstance(msg.get("content"), list) + and any(item.get("type") == "tool_use" for item in msg.get("content", [])) + ): + claude_assistant_messages_with_tools.append(msg) + + # There should be at least one Claude assistant message with tool_use + assert len(claude_assistant_messages_with_tools) > 0, "Expected at least one Claude assistant message with tool_use" + + # Check Claude format messages specifically + for msg in claude_assistant_messages_with_tools: + content_list = msg["content"] + + # Strict validation: assistant messages with tool_use should have NO text content items at all + text_items = [item for item in content_list if item.get("type") == "text"] + assert len(text_items) == 0, ( + f"Found {len(text_items)} text content item(s) in Claude assistant message with tool_use. " + f"When reasoning is disabled, there should be NO text items. " + f"Text items found: {[item.get('text', '') for item in text_items]}" + ) + + # Verify that the message only contains tool_use items + tool_use_items = [item for item in content_list if item.get("type") == "tool_use"] + assert len(tool_use_items) > 0, "Assistant message should have at least one tool_use item" + + if not reasoning_enabled: + assert len(content_list) == len(tool_use_items), ( + f"Assistant message should ONLY contain tool_use items when reasoning is disabled. " + f"Found {len(content_list)} total items but only {len(tool_use_items)} are tool_use items." + ) + + +def validate_google_format_scrubbing(contents: List[Dict[str, Any]]) -> None: + """ + Validate that Google/Gemini format model messages with functionCall have no thinking field. + Args: + contents: List of content dictionaries in Google format (uses 'contents' instead of 'messages') + """ + model_messages_with_function_calls = [] + + for content in contents: + if content.get("role") == "model" and isinstance(content.get("parts"), list): + for part in content["parts"]: + if "functionCall" in part: + model_messages_with_function_calls.append(part) + + # There should be at least one model message with functionCall + assert len(model_messages_with_function_calls) > 0, "Expected at least one Google model message with functionCall" + + # Check Google format messages specifically + for part in model_messages_with_function_calls: + function_call = part["functionCall"] + args = function_call.get("args", {}) + + # Assert that there is no 'thinking' field in the function call arguments + assert "thinking" not in args, ( + f"Found 'thinking' field in Google model functionCall args (inner thoughts not scrubbed): {args.get('thinking')}" + ) + + +def assert_image_input_response( + messages: List[Any], + llm_config: LLMConfig, + streaming: bool = False, + token_streaming: bool = False, + from_db: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence: + ReasoningMessage -> AssistantMessage. + """ + # Filter out LettaPing messages which are keep-alive messages for SSE streams + messages = [ + msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) + ] + + # For o1 models in token streaming, AssistantMessage is not included in the stream + o1_token_streaming = is_openai_reasoning_model(llm_config.model) and streaming and token_streaming + expected_message_count = 3 if o1_token_streaming else (4 if streaming else 3 if from_db else 2) + assert len(messages) == expected_message_count + + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Agent Step 1 + if is_openai_reasoning_model(llm_config.model): + assert isinstance(messages[index], HiddenReasoningMessage) + else: + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == "0" + index += 1 + + # Agent Step 2: AssistantMessage (skip for o1 token streaming) + if not o1_token_streaming: + assert isinstance(messages[index], AssistantMessage) + assert messages[index].otid and messages[index].otid[-1] == "1" + index += 1 + + if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + +def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]: + """ + Accumulates chunks into a list of messages. + Handles both message objects and raw SSE strings. + """ + messages = [] + current_message = None + prev_message_type = None + chunk_count = 0 + + # Check if chunks are raw SSE strings (from background streaming) + if chunks and isinstance(chunks[0], str): + import json + + # Join all string chunks and parse as SSE + sse_data = "".join(chunks) + for line in sse_data.strip().split("\n"): + if line.startswith("data: ") and line != "data: [DONE]": + try: + data = json.loads(line[6:]) # Remove 'data: ' prefix + if "message_type" in data: + message_type = data.get("message_type") + if message_type == "assistant_message": + chunk = AssistantMessage(**data) + elif message_type == "reasoning_message": + chunk = ReasoningMessage(**data) + elif message_type == "hidden_reasoning_message": + chunk = HiddenReasoningMessage(**data) + elif message_type == "tool_call_message": + chunk = ToolCallMessage(**data) + elif message_type == "tool_return_message": + chunk = ToolReturnMessage(**data) + elif message_type == "user_message": + chunk = UserMessage(**data) + elif message_type == "stop_reason": + chunk = LettaStopReason(**data) + elif message_type == "usage_statistics": + chunk = LettaUsageStatistics(**data) + else: + continue # Skip unknown types + + current_message_type = chunk.message_type + if prev_message_type != current_message_type: + if current_message is not None: + messages.append(current_message) + current_message = chunk + chunk_count = 1 + else: + # Accumulate content for same message type + if hasattr(current_message, "content") and hasattr(chunk, "content"): + current_message.content += chunk.content + chunk_count += 1 + prev_message_type = current_message_type + except json.JSONDecodeError: + continue + + if current_message is not None: + messages.append(current_message) + else: + # Handle message objects + for chunk in chunks: + current_message_type = chunk.message_type + if prev_message_type != current_message_type: + messages.append(current_message) + if ( + prev_message_type + and verify_token_streaming + and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"] + ): + assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}. Messages: {messages}" + current_message = None + chunk_count = 0 + if current_message is None: + current_message = chunk + else: + pass # TODO: actually accumulate the chunks. For now we only care about the count + prev_message_type = current_message_type + chunk_count += 1 + messages.append(current_message) + if verify_token_streaming and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]: + assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}" + + return [m for m in messages if m is not None] + + +def cast_message_dict_to_messages(messages: List[Dict[str, Any]]) -> List[LettaMessageUnion]: + def cast_message(message: Dict[str, Any]) -> LettaMessageUnion: + if message["message_type"] == "reasoning_message": + return ReasoningMessage(**message) + elif message["message_type"] == "assistant_message": + return AssistantMessage(**message) + elif message["message_type"] == "tool_call_message": + return ToolCallMessage(**message) + elif message["message_type"] == "tool_return_message": + return ToolReturnMessage(**message) + elif message["message_type"] == "user_message": + return UserMessage(**message) + elif message["message_type"] == "hidden_reasoning_message": + return HiddenReasoningMessage(**message) + else: + raise ValueError(f"Unknown message type: {message['message_type']}") + + return [cast_message(message) for message in messages] + + +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it's accepting connections. + """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + return url + + +@pytest.fixture(scope="module") +def client(server_url: str) -> Letta: + """ + Creates and returns a synchronous Letta REST client for testing. + """ + client_instance = Letta(base_url=server_url) + yield client_instance + + +@pytest.fixture(scope="function") +def async_client(server_url: str) -> AsyncLetta: + """ + Creates and returns an asynchronous Letta REST client for testing. + """ + async_client_instance = AsyncLetta(base_url=server_url) + yield async_client_instance + + +@pytest.fixture(scope="function") +def agent_state(client: Letta) -> AgentState: + """ + Creates and returns an agent state for testing with a pre-configured agent. + The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. + """ + client.tools.upsert_base_tools() + dice_tool = client.tools.upsert_from_function(func=roll_dice) + + send_message_tool = client.tools.list(name="send_message")[0] + agent_state_instance = client.agents.create( + name="supervisor", + include_base_tools=False, + tool_ids=[send_message_tool.id, dice_tool.id], + model="openai/gpt-4o", + embedding="letta/letta-free", + tags=["supervisor"], + ) + yield agent_state_instance + + # try: + # client.agents.delete(agent_state_instance.id) + # except Exception as e: + # logger.error(f"Failed to delete agent {agent_state_instance.name}: {str(e)}") + + +# ------------------------------ +# Test Cases +# ------------------------------ + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_greeting_with_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that the response messages follow the expected order. + """ + # Skip deprecated Gemini 1.5 models which are no longer supported on generateContent + if llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-1.5"): + pytest.skip(f"Skipping deprecated model {llm_config.model}") + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + assert_contains_run_id(response.messages) + assert_greeting_with_assistant_message_response(response.messages, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_first_message_is_user_message(messages_from_db) + assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_greeting_without_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that the response messages follow the expected order. + """ + # Skip deprecated Gemini 1.5 models which are no longer supported on generateContent + if llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-1.5"): + pytest.skip(f"Skipping deprecated model {llm_config.model}") + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + use_assistant_message=False, + ) + assert_greeting_without_assistant_message_response(response.messages, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list( + agent_id=agent_state.id, after=last_message.id if last_message else None, use_assistant_message=False + ) + messages_from_db = messages_from_db_page.items + assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_tool_call( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that the response messages follow the expected order. + """ + # Skip deprecated Gemini 1.5 models which are no longer supported on generateContent + if llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-1.5"): + pytest.skip(f"Skipping deprecated model {llm_config.model}") + # Skip qwen and o4-mini models due to OTID chain issue and incomplete response (stops after tool return) + if "qwen" in llm_config.model.lower() or llm_config.model == "o4-mini": + pytest.skip(f"Skipping {llm_config.model} due to OTID chain issue and incomplete agent response") + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step + if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner: + messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING + elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"): + messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH + else: + messages_to_send = USER_MESSAGE_ROLL_DICE + try: + response = client.agents.messages.send( + agent_id=agent_state.id, + messages=messages_to_send, + ) + except Exception as e: + # if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e): + # pytest.skip("Skipping test for flash model due to malformed function call from llm") + raise e + assert_tool_call_response(response.messages, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + [ + ( + pytest.param(config, marks=pytest.mark.xfail(reason="Qwen image processing unstable - needs investigation")) + if config.model == "Qwen/Qwen2.5-72B-Instruct-Turbo" + else config + ) + for config in TESTED_LLM_CONFIGS + ], + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_base64_image_input( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that the response messages follow the expected order. + """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_BASE64_IMAGE, + ) + assert_image_input_response(response.messages, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_image_input_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_agent_loop_error( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that no new messages are persisted on error. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + with patch("letta.agents.letta_agent_v2.LettaAgentV2.step") as mock_step: + mock_step.side_effect = LLMError("No tool calls found in response, model must make a tool call") + + with pytest.raises(APIError): + client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + + time.sleep(0.5) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert len(messages_from_db) == 0 + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_step_streaming_greeting_with_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + chunks = list(response) + assert_contains_step_id(chunks) + assert_contains_run_id(chunks) + messages = accumulate_chunks(chunks) + assert_greeting_with_assistant_message_response(messages, streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_contains_run_id(messages_from_db) + assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_step_streaming_greeting_without_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + use_assistant_message=False, + ) + messages = accumulate_chunks(list(response)) + assert_greeting_without_assistant_message_response(messages, streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list( + agent_id=agent_state.id, after=last_message.id if last_message else None, use_assistant_message=False + ) + messages_from_db = messages_from_db_page.items + assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_step_streaming_tool_call( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step + if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner: + messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING + else: + messages_to_send = USER_MESSAGE_ROLL_DICE + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=messages_to_send, + timeout=300, + ) + messages = accumulate_chunks(list(response)) + + # Gemini 2.5 Flash can occasionally stop after tool return without making the final send_message call. + # Accept this shorter pattern for robustness when using Google endpoints with Flash. + # TODO un-relax this test once on the new v1 architecture / v3 loop + is_gemini_flash = llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash") + if ( + is_gemini_flash + and hasattr(messages[-1], "message_type") + and messages[-1].message_type == "stop_reason" + and getattr(messages[-1], "stop_reason", None) == "no_tool_call" + ): + # Relaxation: allow early stop on Flash without final send_message call + return + + # Default strict assertions for all other models / cases + assert_tool_call_response(messages, streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_step_stream_agent_loop_error( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message with a synchronous client. + Verifies that no new messages are persisted on error. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + with patch("letta.agents.letta_agent_v2.LettaAgentV2.stream") as mock_step: + mock_step.side_effect = ValueError("No tool calls found in response, model must make a tool call") + + with pytest.raises(APIError): + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + list(response) # This should trigger the error + + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert len(messages_from_db) == 0 + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_token_streaming_greeting_with_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use longer message for Anthropic models to test if they stream in chunks + if llm_config.model_endpoint_type == "anthropic": + messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY + else: + messages_to_send = USER_MESSAGE_FORCE_REPLY + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=messages_to_send, + stream_tokens=True, + ) + verify_token_streaming = ( + llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model + ) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_token_streaming_greeting_without_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use longer message for Anthropic models to force chunking + if llm_config.model_endpoint_type == "anthropic": + messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY + else: + messages_to_send = USER_MESSAGE_FORCE_REPLY + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=messages_to_send, + use_assistant_message=False, + stream_tokens=True, + ) + verify_token_streaming = ( + llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model + ) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list( + agent_id=agent_state.id, after=last_message.id if last_message else None, use_assistant_message=False + ) + messages_from_db = messages_from_db_page.items + assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_token_streaming_tool_call( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use longer message for Anthropic models to force chunking + if llm_config.model_endpoint_type == "anthropic": + if llm_config.enable_reasoner: + # Without asking the model to think, Anthropic might decide to not think for the second step post-roll + messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING + else: + messages_to_send = USER_MESSAGE_ROLL_DICE_LONG + elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"): + messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH + else: + messages_to_send = USER_MESSAGE_ROLL_DICE + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=messages_to_send, + stream_tokens=True, + timeout=300, + ) + verify_token_streaming = ( + llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model + ) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + # Relaxation for Gemini 2.5 Flash: allow early stop with no final send_message call + is_gemini_flash = llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash") + if ( + is_gemini_flash + and hasattr(messages[-1], "message_type") + and messages[-1].message_type == "stop_reason" + and getattr(messages[-1], "stop_reason", None) == "no_tool_call" + ): + # Accept the shorter pattern for token streaming on Flash + pass + else: + assert_tool_call_response(messages, streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_token_streaming_agent_loop_error( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Verifies that no new messages are persisted on error. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + with patch("letta.agents.letta_agent_v2.LettaAgentV2.stream") as mock_step: + mock_step.side_effect = ValueError("No tool calls found in response, model must make a tool call") + + with pytest.raises(APIError): + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + stream_tokens=True, + ) + list(response) # This should trigger the error + + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert len(messages_from_db) == 0 + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_background_token_streaming_greeting_with_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use longer message for Anthropic models to test if they stream in chunks + if llm_config.model_endpoint_type == "anthropic": + messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY + else: + messages_to_send = USER_MESSAGE_FORCE_REPLY + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=messages_to_send, + stream_tokens=True, + background=True, + timeout=300, + ) + verify_token_streaming = ( + llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model + ) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + run_id = messages[0].run_id + assert run_id is not None + + runs = client.runs.list(agent_ids=[agent_state.id], background=True) + assert len(runs) > 0 + assert runs[0].id == run_id + + response = client.runs.messages.stream(run_id=run_id, starting_after=0) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + assert_greeting_with_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config) + + last_message_cursor = messages[-3].seq_id - 1 + response = client.runs.messages.stream(run_id=run_id, starting_after=last_message_cursor) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + assert len(messages) == 3 + assert messages[0].message_type == "assistant_message" and messages[0].seq_id == last_message_cursor + 1 + assert messages[1].message_type == "stop_reason" + assert messages[2].message_type == "usage_statistics" + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_background_token_streaming_greeting_without_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use longer message for Anthropic models to force chunking + if llm_config.model_endpoint_type == "anthropic": + messages_to_send = USER_MESSAGE_FORCE_LONG_REPLY + else: + messages_to_send = USER_MESSAGE_FORCE_REPLY + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=messages_to_send, + use_assistant_message=False, + stream_tokens=True, + background=True, + ) + verify_token_streaming = ( + llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model + ) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + assert_greeting_without_assistant_message_response(messages, streaming=True, token_streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list( + agent_id=agent_state.id, after=last_message.id if last_message else None, use_assistant_message=False + ) + messages_from_db = messages_from_db_page.items + assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_background_token_streaming_tool_call( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a streaming message with a synchronous client. + Checks that each chunk in the stream has the correct message types. + """ + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # Use longer message for Anthropic models to force chunking + if llm_config.model_endpoint_type == "anthropic": + if llm_config.enable_reasoner: + # Without asking the model to think, Anthropic might decide to not think for the second step post-roll + messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING + else: + messages_to_send = USER_MESSAGE_ROLL_DICE_LONG + elif llm_config.model_endpoint_type in ["google_vertex", "google_ai"] and llm_config.model.startswith("gemini-2.5-flash"): + messages_to_send = USER_MESSAGE_ROLL_DICE_GEMINI_FLASH + else: + messages_to_send = USER_MESSAGE_ROLL_DICE + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=messages_to_send, + stream_tokens=True, + background=True, + timeout=300, + ) + verify_token_streaming = ( + llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] and "claude-3-5-sonnet" not in llm_config.model + ) + messages = accumulate_chunks(list(response), verify_token_streaming=verify_token_streaming) + assert_tool_call_response(messages, streaming=True, llm_config=llm_config) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) + + +def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: + start = time.time() + while True: + run = client.runs.retrieve(run_id) + if run.status == "completed": + return run + if run.status == "failed": + print(run) + raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") + if time.time() - start > timeout: + raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})") + time.sleep(interval) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_async_greeting_with_assistant_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message as an asynchronous job using the synchronous client. + Waits for job completion and asserts that the result messages are as expected. + """ + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + run = client.agents.messages.send_async( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + run = wait_for_run_completion(client, run.id, timeout=60.0) + + messages_page = client.runs.messages.list(run_id=run.id) + messages = messages_page.items + usage = client.runs.usage.retrieve(run_id=run.id) + + # TODO: add results API test later + assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config) # TODO: remove from_db=True later + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_greeting_with_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + # NOTE: deprecated in preparation of letta_v1_agent + # @pytest.mark.parametrize( + # "llm_config", + # TESTED_LLM_CONFIGS, + # ids=[c.model for c in TESTED_LLM_CONFIGS], + # ) + # def test_async_greeting_without_assistant_message( + # disable_e2b_api_key: Any, + # client: Letta, + # agent_state: AgentState, + # llm_config: LLMConfig, + # ) -> None: + # """ + # Tests sending a message as an asynchronous job using the synchronous client. + # Waits for job completion and asserts that the result messages are as expected. + # """ + # last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + # client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + # + # run = client.agents.messages.send_async( + # agent_id=agent_state.id, + # messages=USER_MESSAGE_FORCE_REPLY, + # use_assistant_message=False, + # ) + # run = wait_for_run_completion(client, run.id, timeout=60.0) + # + # messages_page = client.runs.messages.list(run_id=run.id) + messages = messages_page.items + # assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) + # + # messages_page = client.runs.messages.list(run_id=run.id) + messages = messages_page.items + # assert_greeting_without_assistant_message_response(messages, llm_config=llm_config) + # messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None, use_assistant_message=False) + messages_from_db = messages_from_db_page.items + + +# assert_greeting_without_assistant_message_response(messages_from_db, from_db=True, llm_config=llm_config) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_async_tool_call( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message as an asynchronous job using the synchronous client. + Waits for job completion and asserts that the result messages are as expected. + """ + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + # Use the thinking prompt for Anthropic models with extended reasoning to ensure second reasoning step + if llm_config.model_endpoint_type == "anthropic" and llm_config.enable_reasoner: + messages_to_send = USER_MESSAGE_ROLL_DICE_LONG_THINKING + else: + messages_to_send = USER_MESSAGE_ROLL_DICE + run = client.agents.messages.send_async( + agent_id=agent_state.id, + messages=messages_to_send, + ) + run = wait_for_run_completion(client, run.id, timeout=60.0) + messages_page = client.runs.messages.list(run_id=run.id) + messages = messages_page.items + # TODO: add test for response api + assert_tool_call_response(messages, from_db=True, llm_config=llm_config) # NOTE: skip first message which is the user message + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config) + + +class CallbackServer: + """Mock HTTP server for testing callback functionality.""" + + def __init__(self): + self.received_callbacks = [] + self.server = None + self.thread = None + self.port = None + + def start(self): + """Start the mock server on an available port.""" + + class CallbackHandler(BaseHTTPRequestHandler): + def __init__(self, callback_server, *args, **kwargs): + self.callback_server = callback_server + super().__init__(*args, **kwargs) + + def do_POST(self): + content_length = int(self.headers["Content-Length"]) + post_data = self.rfile.read(content_length) + try: + callback_data = json.loads(post_data.decode("utf-8")) + self.callback_server.received_callbacks.append( + {"data": callback_data, "headers": dict(self.headers), "timestamp": time.time()} + ) + # Respond with success + self.send_response(200) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"status": "received"}).encode()) + except Exception as e: + # Respond with error + self.send_response(400) + self.send_header("Content-type", "application/json") + self.end_headers() + self.wfile.write(json.dumps({"error": str(e)}).encode()) + + def log_message(self, format, *args): + # Suppress log messages during tests + pass + + # Bind to available port + self.server = HTTPServer(("localhost", 0), lambda *args: CallbackHandler(self, *args)) + self.port = self.server.server_address[1] + + # Start server in background thread + self.thread = threading.Thread(target=self.server.serve_forever) + self.thread.daemon = True + self.thread.start() + + def stop(self): + """Stop the mock server.""" + if self.server: + self.server.shutdown() + self.server.server_close() + if self.thread: + self.thread.join(timeout=1) + + @property + def url(self): + """Get the callback URL for this server.""" + return f"http://localhost:{self.port}/callback" + + def wait_for_callback(self, timeout=10): + """Wait for at least one callback to be received.""" + start_time = time.time() + while time.time() - start_time < timeout: + if self.received_callbacks: + return True + time.sleep(0.1) + return False + + +@contextmanager +def callback_server(): + """Context manager for callback server.""" + server = CallbackServer() + try: + server.start() + yield server + finally: + server.stop() + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_async_greeting_with_callback_url( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Tests sending a message as an asynchronous job with callback URL functionality. + Validates that callbacks are properly sent with correct payload structure. + """ + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + with callback_server() as server: + # Create async job with callback URL + run = client.agents.messages.send_async( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + callback_url=server.url, + ) + + # Wait for job completion + run = wait_for_run_completion(client, run.id, timeout=60.0) + + # Validate job completed successfully + messages_page = client.runs.messages.list(run_id=run.id) + messages = messages_page.items + assert_greeting_with_assistant_message_response(messages, from_db=True, llm_config=llm_config) + + # Validate callback was received + assert server.wait_for_callback(timeout=15), "Callback was not received within timeout" + assert len(server.received_callbacks) == 1, f"Expected 1 callback, got {len(server.received_callbacks)}" + + # Validate callback payload structure + callback = server.received_callbacks[0] + callback_data = callback["data"] + + # Check required fields + assert "run_id" in callback_data, "Callback missing 'run_id' field" + assert "status" in callback_data, "Callback missing 'status' field" + assert "completed_at" in callback_data, "Callback missing 'completed_at' field" + assert "metadata" in callback_data, "Callback missing 'metadata' field" + + # Validate field values + assert callback_data["run_id"] == run.id, f"Job ID mismatch: {callback_data['run_id']} != {run.id}" + assert callback_data["status"] == "completed", f"Expected status 'completed', got {callback_data['status']}" + assert callback_data["completed_at"] is not None, "completed_at should not be None" + assert callback_data["metadata"] is not None, "metadata should not be None" + + # Validate that callback metadata contains the result + assert "result" in callback_data["metadata"], "Callback metadata missing 'result' field" + callback_result = callback_data["metadata"]["result"] + callback_messages = cast_message_dict_to_messages(callback_result["messages"]) + assert callback_messages == messages, "Callback result doesn't match job result" + + # Validate HTTP headers + headers = callback["headers"] + assert headers.get("Content-Type") == "application/json", "Callback should have JSON content type" + + +@pytest.mark.flaky(max_runs=2) +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_auto_summarize(disable_e2b_api_key: Any, client: Letta, llm_config: LLMConfig): + """Test that summarization is automatically triggered.""" + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model (runs too slow) + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + # pydantic prevents us for overriding the context window paramter in the passed LLMConfig + new_llm_config = llm_config.model_dump() + new_llm_config["context_window"] = 3000 + pinned_context_window_llm_config = LLMConfig(**new_llm_config) + print("::LLM::", llm_config, new_llm_config) + send_message_tool = client.tools.list(name="send_message")[0] + temp_agent_state = client.agents.create( + include_base_tools=False, + tool_ids=[send_message_tool.id], + llm_config=pinned_context_window_llm_config, + embedding="letta/letta-free", + tags=["supervisor"], + ) + + philosophical_question_path = os.path.join(os.path.dirname(__file__), "..", "..", "data", "philosophical_question.txt") + with open(philosophical_question_path, "r", encoding="utf-8") as f: + philosophical_question = f.read().strip() + + MAX_ATTEMPTS = 10 + prev_length = None + + for attempt in range(MAX_ATTEMPTS): + try: + client.agents.messages.send( + agent_id=temp_agent_state.id, + messages=[MessageCreate(role="user", content=philosophical_question)], + ) + except Exception as e: + # if "flash" in llm_config.model and "FinishReason.MALFORMED_FUNCTION_CALL" in str(e): + # pytest.skip("Skipping test for flash model due to malformed function call from llm") + raise e + + temp_agent_state = client.agents.retrieve(agent_id=temp_agent_state.id) + message_ids = temp_agent_state.message_ids + current_length = len(message_ids) + + print("LENGTH OF IN_CONTEXT_MESSAGES:", current_length) + + if prev_length is not None and current_length <= prev_length: + # TODO: Add more stringent checks here + print(f"Summarization was triggered, detected current_length {current_length} is at least prev_length {prev_length}.") + break + + prev_length = current_length + else: + raise AssertionError("Summarization was not triggered after 10 messages") + + +# ============================ +# Job Cancellation Tests +# ============================ + + +def wait_for_run_status(client: Letta, run_id: str, target_status: str, timeout: float = 30.0, interval: float = 0.1) -> Run: + """Wait for a run to reach a specific status""" + start = time.time() + while True: + run = client.runs.retrieve(run_id) + if run.status == target_status: + return run + if time.time() - start > timeout: + raise TimeoutError(f"Run {run_id} did not reach status '{target_status}' within {timeout} seconds (last status: {run.status})") + time.sleep(interval) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_job_creation_for_send_message( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + """ + Test that send_message endpoint creates a job and the job completes successfully. + """ + previous_runs = client.runs.list(agent_ids=[agent_state.id]) + client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + # Send a simple message and verify a job was created + response = client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + + # The response should be successful + assert response.messages is not None + assert len(response.messages) > 0 + + runs = client.runs.list(agent_ids=[agent_state.id]) + new_runs = set(r.id for r in runs) - set(r.id for r in previous_runs) + assert len(new_runs) == 1 + + for run in runs: + if run.id == list(new_runs)[0]: + assert run.status == "completed" + + +# TODO (cliandy): MERGE BACK IN POST +# # @pytest.mark.parametrize( +# # "llm_config", +# # TESTED_LLM_CONFIGS, +# # ids=[c.model for c in TESTED_LLM_CONFIGS], +# # ) +# # def test_async_job_cancellation( +# # disable_e2b_api_key: Any, +# # client: Letta, +# # agent_state: AgentState, +# # llm_config: LLMConfig, +# # ) -> None: +# """ +# Test that an async job can be cancelled and the cancellation is reflected in the job status. +# """ +# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) +# +# # client.runs.cancel +# # Start an async job +# run = client.agents.messages.send_async( +# agent_id=agent_state.id, +# messages=USER_MESSAGE_FORCE_REPLY, +# ) +# +# # Verify the job was created +# assert run.id is not None +# assert run.status in ["created", "running"] +# +# # Cancel the job quickly (before it potentially completes) +# cancelled_run = client.jobs.cancel(run.id) +# +# # Verify the job was cancelled +# assert cancelled_run.status == "cancelled" +# +# # Wait a bit and verify it stays cancelled (no invalid state transitions) +# time.sleep(1) +# final_run = client.runs.retrieve(run.id) +# assert final_run.status == "cancelled" +# +# # Verify the job metadata indicates cancellation +# if final_run.metadata: +# assert final_run.metadata.get("cancelled") is True or "stop_reason" in final_run.metadata +# +# +# def test_job_cancellation_endpoint_validation( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# ) -> None: +# """ +# Test job cancellation endpoint validation (trying to cancel completed/failed jobs). +# """ +# # Test cancelling a non-existent job +# with pytest.raises(APIError) as exc_info: +# client.jobs.cancel("non-existent-job-id") +# assert exc_info.value.status_code == 404 +# +# +# @pytest.mark.parametrize( +# "llm_config", +# TESTED_LLM_CONFIGS, +# ids=[c.model for c in TESTED_LLM_CONFIGS], +# ) +# def test_completed_job_cannot_be_cancelled( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# llm_config: LLMConfig, +# ) -> None: +# """ +# Test that completed jobs cannot be cancelled. +# """ +# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) +# +# # Start an async job and wait for it to complete +# run = client.agents.messages.send_async( +# agent_id=agent_state.id, +# messages=USER_MESSAGE_FORCE_REPLY, +# ) +# +# # Wait for completion +# completed_run = wait_for_run_completion(client, run.id) +# assert completed_run.status == "completed" +# +# # Try to cancel the completed job - should fail +# with pytest.raises(APIError) as exc_info: +# client.jobs.cancel(run.id) +# assert exc_info.value.status_code == 400 +# assert "Cannot cancel job with status 'completed'" in str(exc_info.value) +# +# +# @pytest.mark.parametrize( +# "llm_config", +# TESTED_LLM_CONFIGS, +# ids=[c.model for c in TESTED_LLM_CONFIGS], +# ) +# def test_streaming_job_independence_from_client_disconnect( +# disable_e2b_api_key: Any, +# client: Letta, +# agent_state: AgentState, +# llm_config: LLMConfig, +# ) -> None: +# """ +# Test that streaming jobs are independent of client connection state. +# This verifies that jobs continue even if the client "disconnects" (simulated by not consuming the stream). +# """ +# client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) +# +# # Create a streaming request +# import threading +# +# import httpx +# +# # Get the base URL and create a raw HTTP request to simulate partial consumption +# base_url = client._client_wrapper._base_url +# +# def start_stream_and_abandon(): +# """Start a streaming request but abandon it (simulating client disconnect)""" +# try: +# response = httpx.post( +# f"{base_url}/agents/{agent_state.id}/messages/stream", +# json={"messages": [{"role": "user", "text": "Hello, how are you?"}], "stream_tokens": False}, +# headers={"user_id": "test-user"}, +# timeout=30.0, +# ) +# +# # Read just a few chunks then "disconnect" by not reading the rest +# chunk_count = 0 +# for chunk in response.iter_lines(): +# chunk_count += 1 +# if chunk_count > 3: # Read a few chunks then stop +# break +# # Connection is now "abandoned" but the job should continue +# +# except Exception: +# pass # Ignore connection errors +# +# # Start the stream in a separate thread to simulate abandonment +# thread = threading.Thread(target=start_stream_and_abandon) +# thread.start() +# thread.join(timeout=5.0) # Wait up to 5 seconds for the "disconnect" +# +# # The important thing is that this test validates our architecture: +# # 1. Jobs are created before streaming starts (verified by our other tests) +# # 2. Jobs track execution independent of client connection (handled by our wrapper) +# # 3. Only explicit cancellation terminates jobs (tested by other tests) +# +# # This test primarily validates that the implementation doesn't break under simulated disconnection +# assert True # If we get here without errors, the architecture is sound + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_inner_thoughts_false_non_reasoner_models( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + # skip if this is a reasoning model + if not config_filename or config_filename in reasoning_configs: + pytest.skip(f"Skipping test for reasoning model {llm_config.model}") + + # create a new config with all reasoning fields turned off + new_llm_config = llm_config.model_dump() + new_llm_config["put_inner_thoughts_in_kwargs"] = False + new_llm_config["enable_reasoner"] = False + new_llm_config["max_reasoning_tokens"] = 0 + adjusted_llm_config = LLMConfig(**new_llm_config) + + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config) + response = client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + assert_greeting_no_reasoning_response(response.messages) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_greeting_no_reasoning_response(messages_from_db, from_db=True) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_inner_thoughts_false_non_reasoner_models_streaming( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a limited model + if not config_filename or config_filename in limited_configs: + pytest.skip(f"Skipping test for limited model {llm_config.model}") + + # skip if this is a reasoning model + if not config_filename or config_filename in reasoning_configs: + pytest.skip(f"Skipping test for reasoning model {llm_config.model}") + + # create a new config with all reasoning fields turned off + new_llm_config = llm_config.model_dump() + new_llm_config["put_inner_thoughts_in_kwargs"] = False + new_llm_config["enable_reasoner"] = False + new_llm_config["max_reasoning_tokens"] = 0 + adjusted_llm_config = LLMConfig(**new_llm_config) + + last_message_page = client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config) + response = client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + messages = accumulate_chunks(list(response)) + assert_greeting_no_reasoning_response(messages, streaming=True) + messages_from_db_page = client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_greeting_no_reasoning_response(messages_from_db, from_db=True) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +def test_inner_thoughts_toggle_interleaved( + disable_e2b_api_key: Any, + client: Letta, + agent_state: AgentState, + llm_config: LLMConfig, +) -> None: + # get the config filename + config_filename = None + for filename in filenames: + config = get_llm_config(filename) + if config.model_dump() == llm_config.model_dump(): + config_filename = filename + break + + # skip if this is a reasoning model + if not config_filename or config_filename in reasoning_configs: + pytest.skip(f"Skipping test for reasoning model {llm_config.model}") + + # Only run on OpenAI, Anthropic, and Google models + if llm_config.model_endpoint_type not in ["openai", "anthropic", "google_ai", "google_vertex"]: + pytest.skip(f"Skipping `test_inner_thoughts_toggle_interleaved` for model endpoint type {llm_config.model_endpoint_type}") + + assert not is_reasoning_completely_disabled(llm_config), "Reasoning should be enabled" + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + # Send a message with inner thoughts + client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_GREETING, + ) + + # create a new config with all reasoning fields turned off + new_llm_config = llm_config.model_dump() + new_llm_config["put_inner_thoughts_in_kwargs"] = False + new_llm_config["enable_reasoner"] = False + new_llm_config["max_reasoning_tokens"] = 0 + adjusted_llm_config = LLMConfig(**new_llm_config) + agent_state = client.agents.modify(agent_id=agent_state.id, llm_config=adjusted_llm_config) + + # Preview the message payload of the next message + # response = client.agents.messages.preview_raw_payload( + # agent_id=agent_state.id, + # request=LettaRequest(messages=USER_MESSAGE_FORCE_REPLY), + # ) + + # Test our helper functions + assert is_reasoning_completely_disabled(adjusted_llm_config), "Reasoning should be completely disabled" + + # Verify that assistant messages with tool calls have been scrubbed of inner thoughts + # Branch assertions based on model endpoint type + # if llm_config.model_endpoint_type == "openai": + # messages = response["messages"] + # validate_openai_format_scrubbing(messages) + # elif llm_config.model_endpoint_type == "anthropic": + # messages = response["messages"] + # validate_anthropic_format_scrubbing(messages, llm_config.enable_reasoner) + # elif llm_config.model_endpoint_type in ["google_ai", "google_vertex"]: + # # Google uses 'contents' instead of 'messages' + # contents = response.get("contents", response.get("messages", [])) + # validate_google_format_scrubbing(contents) diff --git a/tests/sdk_v1/integration/integration_test_send_message_v2.py b/tests/sdk_v1/integration/integration_test_send_message_v2.py new file mode 100644 index 00000000..279f4cc4 --- /dev/null +++ b/tests/sdk_v1/integration/integration_test_send_message_v2.py @@ -0,0 +1,761 @@ +import asyncio +import itertools +import json +import os +import threading +import time +import uuid +from typing import Any, List, Tuple + +import pytest +import requests +from dotenv import load_dotenv +from letta_client import AsyncLetta +from letta_client.types import ToolReturnMessage +from letta_client.types.agents import AssistantMessage, ReasoningMessage, Run, ToolCallMessage, UserMessage +from letta_client.types.agents.letta_streaming_response import LettaPing, LettaStopReason, LettaUsageStatistics + +from letta.log import get_logger +from letta.schemas.agent import AgentState +from letta.schemas.enums import AgentType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import MessageCreate + +logger = get_logger(__name__) + + +# ------------------------------ +# Helper Functions and Constants +# ------------------------------ + + +all_configs = [ + "openai-gpt-4o-mini.json", + "openai-o3.json", + "openai-gpt-5.json", + "claude-4-5-sonnet.json", + "claude-4-1-opus.json", + "gemini-2.5-flash.json", +] + + +def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig: + filename = os.path.join(llm_config_dir, filename) + with open(filename, "r") as f: + config_data = json.load(f) + llm_config = LLMConfig(**config_data) + return llm_config + + +requested = os.getenv("LLM_CONFIG_FILE") +filenames = [requested] if requested else all_configs +TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames] +# Filter out deprecated Claude 3.5 Sonnet model that is no longer available +TESTED_LLM_CONFIGS = [ + cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "anthropic" and cfg.model == "claude-3-5-sonnet-20241022") +] +# Filter out Bedrock models that require aioboto3 dependency (not available in CI) +TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model_endpoint_type == "bedrock")] +# Filter out Gemini models that have Google Cloud permission issues +TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if cfg.model_endpoint_type not in ["google_vertex", "google_ai"]] +# Filter out qwen2.5:7b model that has server issues +TESTED_LLM_CONFIGS = [cfg for cfg in TESTED_LLM_CONFIGS if not (cfg.model == "qwen2.5:7b")] + + +def roll_dice(num_sides: int) -> int: + """ + Returns a random number between 1 and num_sides. + Args: + num_sides (int): The number of sides on the die. + Returns: + int: A random integer between 1 and num_sides, representing the die roll. + """ + import random + + return random.randint(1, num_sides) + + +USER_MESSAGE_OTID = str(uuid.uuid4()) +USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work" +USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [ + MessageCreate( + role="user", + content=f"This is an automated test message. Reply with the message '{USER_MESSAGE_RESPONSE}'.", + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [ + MessageCreate( + role="user", + content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.", + otid=USER_MESSAGE_OTID, + ) +] +USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreate] = [ + MessageCreate( + role="user", + content=("This is an automated test message. Please call the roll_dice tool three times in parallel."), + otid=USER_MESSAGE_OTID, + ) +] + + +def assert_greeting_response( + messages: List[Any], + llm_config: LLMConfig, + streaming: bool = False, + token_streaming: bool = False, + from_db: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence: + ReasoningMessage -> AssistantMessage. + """ + # Filter out LettaPing messages which are keep-alive messages for SSE streams + messages = [ + msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) + ] + + expected_message_count_min, expected_message_count_max = get_expected_message_count_range( + llm_config, streaming=streaming, from_db=from_db + ) + assert expected_message_count_min <= len(messages) <= expected_message_count_max + + # User message if loaded from db + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Reasoning message if reasoning enabled + otid_suffix = 0 + try: + if is_reasoner_model(llm_config): + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + otid_suffix += 1 + except: + # Reasoning is non-deterministic, so don't throw if missing + pass + + # Assistant message + assert isinstance(messages[index], AssistantMessage) + if not token_streaming: + assert "teamwork" in messages[index].content.lower() + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + otid_suffix += 1 + + # Stop reason and usage statistics if streaming + if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == "end_turn" + index += 1 + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + +def assert_tool_call_response( + messages: List[Any], + llm_config: LLMConfig, + streaming: bool = False, + from_db: bool = False, + with_cancellation: bool = False, +) -> None: + """ + Asserts that the messages list follows the expected sequence: + ReasoningMessage -> ToolCallMessage -> ToolReturnMessage -> + ReasoningMessage -> AssistantMessage. + """ + # Filter out LettaPing messages which are keep-alive messages for SSE streams + messages = [ + msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) + ] + + if not with_cancellation: + expected_message_count_min, expected_message_count_max = get_expected_message_count_range( + llm_config, tool_call=True, streaming=streaming, from_db=from_db + ) + assert expected_message_count_min <= len(messages) <= expected_message_count_max + + # User message if loaded from db + index = 0 + if from_db: + assert isinstance(messages[index], UserMessage) + assert messages[index].otid == USER_MESSAGE_OTID + index += 1 + + # Reasoning message if reasoning enabled + otid_suffix = 0 + try: + if is_reasoner_model(llm_config): + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + otid_suffix += 1 + except: + # Reasoning is non-deterministic, so don't throw if missing + pass + + # Special case for claude-sonnet-4-5-20250929 and opus-4.1 which can generate an extra AssistantMessage before tool call + if ( + (llm_config.model == "claude-sonnet-4-5-20250929" or llm_config.model.startswith("claude-opus-4-1")) + and index < len(messages) + and isinstance(messages[index], AssistantMessage) + ): + # Skip the extra AssistantMessage and move to the next message + index += 1 + otid_suffix += 1 + + # Tool call message (may be skipped if cancelled early) + if with_cancellation and isinstance(messages[index], AssistantMessage): + # If cancelled early, model might respond with text instead of making tool call + assert "roll" in messages[index].content.lower() or "die" in messages[index].content.lower() + return # Skip tool call assertions for early cancellation + + assert isinstance(messages[index], ToolCallMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + + # Tool return message + assert isinstance(messages[index], ToolReturnMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + + # Messages from second agent step if request has not been cancelled + if not with_cancellation: + # Reasoning message if reasoning enabled + try: + if is_reasoner_model(llm_config): + assert isinstance(messages[index], ReasoningMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + otid_suffix += 1 + except: + # Reasoning is non-deterministic, so don't throw if missing + pass + + # Assistant message + assert isinstance(messages[index], AssistantMessage) + assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) + index += 1 + + # Stop reason and usage statistics if streaming + if streaming: + assert isinstance(messages[index], LettaStopReason) + assert messages[index].stop_reason == ("cancelled" if with_cancellation else "end_turn") + index += 1 + assert isinstance(messages[index], LettaUsageStatistics) + assert messages[index].prompt_tokens > 0 + assert messages[index].completion_tokens > 0 + assert messages[index].total_tokens > 0 + assert messages[index].step_count > 0 + + +async def accumulate_chunks(chunks, verify_token_streaming: bool = False) -> List[Any]: + """ + Accumulates chunks into a list of messages. + Handles both async iterators and raw SSE strings. + """ + messages = [] + current_message = None + prev_message_type = None + + # Handle raw SSE string from runs.messages.stream() + if isinstance(chunks, str): + import json + + for line in chunks.strip().split("\n"): + if line.startswith("data: ") and line != "data: [DONE]": + try: + data = json.loads(line[6:]) # Remove 'data: ' prefix + if "message_type" in data: + # Create proper message type objects + message_type = data.get("message_type") + if message_type == "assistant_message": + from letta_client.types.agents import AssistantMessage + + chunk = AssistantMessage(**data) + elif message_type == "reasoning_message": + from letta_client.types.agents import ReasoningMessage + + chunk = ReasoningMessage(**data) + elif message_type == "tool_call_message": + from letta_client.types.agents import ToolCallMessage + + chunk = ToolCallMessage(**data) + elif message_type == "tool_return_message": + from letta_client.types import ToolReturnMessage + + chunk = ToolReturnMessage(**data) + elif message_type == "user_message": + from letta_client.types.agents import UserMessage + + chunk = UserMessage(**data) + elif message_type == "stop_reason": + from letta_client.types.agents.letta_streaming_response import LettaStopReason + + chunk = LettaStopReason(**data) + elif message_type == "usage_statistics": + from letta_client.types.agents.letta_streaming_response import LettaUsageStatistics + + chunk = LettaUsageStatistics(**data) + else: + chunk = type("Chunk", (), data)() # Fallback for unknown types + + current_message_type = chunk.message_type + + if prev_message_type != current_message_type: + if current_message is not None: + messages.append(current_message) + current_message = chunk + else: + # Accumulate content for same message type + if hasattr(current_message, "content") and hasattr(chunk, "content"): + current_message.content += chunk.content + + prev_message_type = current_message_type + except json.JSONDecodeError: + continue + + if current_message is not None: + messages.append(current_message) + else: + # Handle async iterator from agents.messages.stream() + async for chunk in chunks: + current_message_type = chunk.message_type + + if prev_message_type != current_message_type: + if current_message is not None: + messages.append(current_message) + current_message = chunk + else: + # Accumulate content for same message type + if hasattr(current_message, "content") and hasattr(chunk, "content"): + current_message.content += chunk.content + + prev_message_type = current_message_type + + if current_message is not None: + messages.append(current_message) + + return messages + + +async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5): + await asyncio.sleep(delay) + await client.agents.messages.cancel(agent_id=agent_id) + + +async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: + start = time.time() + while True: + run = await client.runs.retrieve(run_id) + if run.status == "completed": + return run + if run.status == "cancelled": + time.sleep(5) + return run + if run.status == "failed": + raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") + if time.time() - start > timeout: + raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})") + time.sleep(interval) + + +def get_expected_message_count_range( + llm_config: LLMConfig, tool_call: bool = False, streaming: bool = False, from_db: bool = False +) -> Tuple[int, int]: + """ + Returns the expected range of number of messages for a given LLM configuration. Uses range to account for possible variations in the number of reasoning messages. + + Greeting: + ------------------------------------------------------------------------------------------------------------------------------------------------------------------ + | gpt-4o | gpt-o3 (med effort) | gpt-5 (high effort) | sonnet-3-5 | sonnet-3.7-thinking | flash-2.5-thinking | + | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | + | AssistantMessage | AssistantMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | ReasoningMessage | + | | | AssistantMessage | | AssistantMessage | AssistantMessage | + + + Tool Call: + --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- + | gpt-4o | gpt-o3 (med effort) | gpt-5 (high effort) | sonnet-3-5 | sonnet-3.7-thinking | sonnet-4.5/opus-4.1 | flash-2.5-thinking | + | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | + | ToolCallMessage | ToolCallMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | ReasoningMessage | ReasoningMessage | + | ToolReturnMessage | ToolReturnMessage | ToolCallMessage | ToolCallMessage | AssistantMessage | AssistantMessage | ToolCallMessage | + | AssistantMessage | AssistantMessage | ToolReturnMessage | ToolReturnMessage | ToolCallMessage | ToolCallMessage | ToolReturnMessage | + | | | ReasoningMessage | AssistantMessage | ToolReturnMessage | ToolReturnMessage | ReasoningMessage | + | | | AssistantMessage | | AssistantMessage | ReasoningMessage | AssistantMessage | + | | | | | | AssistantMessage | | + + """ + # assistant message + expected_message_count = 1 + expected_range = 0 + + if is_reasoner_model(llm_config): + # reasoning message + expected_range += 1 + if tool_call: + # check for sonnet 4.5 or opus 4.1 specifically + is_sonnet_4_5_or_opus_4_1 = ( + llm_config.model_endpoint_type == "anthropic" + and llm_config.enable_reasoner + and (llm_config.model.startswith("claude-sonnet-4-5") or llm_config.model.startswith("claude-opus-4-1")) + ) + if is_sonnet_4_5_or_opus_4_1 or not LLMConfig.is_anthropic_reasoning_model(llm_config): + # sonnet 4.5 and opus 4.1 return a reasoning message before the final assistant message + # so do the other native reasoning models + expected_range += 1 + + # opus 4.1 generates an extra AssistantMessage before the tool call + if llm_config.model.startswith("claude-opus-4-1"): + expected_range += 1 + + if tool_call: + # tool call and tool return messages + expected_message_count += 2 + + if from_db: + # user message + expected_message_count += 1 + + if streaming: + # stop reason and usage statistics + expected_message_count += 2 + + return expected_message_count, expected_message_count + expected_range + + +def is_reasoner_model(llm_config: LLMConfig) -> bool: + return ( + (LLMConfig.is_openai_reasoning_model(llm_config) and llm_config.reasoning_effort == "high") + or LLMConfig.is_anthropic_reasoning_model(llm_config) + or LLMConfig.is_google_vertex_reasoning_model(llm_config) + or LLMConfig.is_google_ai_reasoning_model(llm_config) + ) + + +# ------------------------------ +# Fixtures +# ------------------------------ + + +@pytest.fixture(scope="module") +def server_url() -> str: + """ + Provides the URL for the Letta server. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it's accepting connections. + """ + + def _run_server() -> None: + load_dotenv() + from letta.server.rest_api.app import start_server + + start_server(debug=True) + + url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") + + if not os.getenv("LETTA_SERVER_URL"): + thread = threading.Thread(target=_run_server, daemon=True) + thread.start() + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") + + return url + + +@pytest.fixture(scope="function") +async def client(server_url: str) -> AsyncLetta: + """ + Creates and returns an asynchronous Letta REST client for testing. + """ + client_instance = AsyncLetta(base_url=server_url) + yield client_instance + + +@pytest.fixture(scope="function") +async def agent_state(client: AsyncLetta) -> AgentState: + """ + Creates and returns an agent state for testing with a pre-configured agent. + The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. + """ + dice_tool = await client.tools.upsert_from_function(func=roll_dice) + + agent_state_instance = await client.agents.create( + agent_type=AgentType.letta_v1_agent, + name="test_agent", + include_base_tools=False, + tool_ids=[dice_tool.id], + model="openai/gpt-4o", + embedding="openai/text-embedding-3-small", + tags=["test"], + ) + yield agent_state_instance + + await client.agents.delete(agent_state_instance.id) + + +# ------------------------------ +# Test Cases +# ------------------------------ + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"]) +@pytest.mark.asyncio(loop_scope="function") +async def test_greeting( + disable_e2b_api_key: Any, + client: AsyncLetta, + agent_state: AgentState, + llm_config: LLMConfig, + send_type: str, +) -> None: + last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + if send_type == "step": + response = await client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + messages = response.messages + run_id = messages[0].run_id + elif send_type == "async": + run = await client.agents.messages.send_async( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + ) + run = await wait_for_run_completion(client, run.id, timeout=60.0) + messages_page = await client.runs.messages.list(run_id=run.id) + messages = [m for m in messages_page.items if m.message_type != "user_message"] + run_id = run.id + else: + response = await client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_FORCE_REPLY, + stream_tokens=(send_type == "stream_tokens"), + background=(send_type == "stream_tokens_background"), + ) + messages = await accumulate_chunks(response) + run_id = messages[0].run_id + + assert_greeting_response( + messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config + ) + + if "background" in send_type: + response = await client.runs.messages.stream(run_id=run_id, starting_after=0) + messages = await accumulate_chunks(response) + assert_greeting_response( + messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config + ) + + messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_greeting_response(messages_from_db, from_db=True, llm_config=llm_config) + + assert run_id is not None + run = await client.runs.retrieve(run_id=run_id) + assert run.status == "completed" + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"]) +@pytest.mark.asyncio(loop_scope="function") +async def test_parallel_tool_call_anthropic( + disable_e2b_api_key: Any, + client: AsyncLetta, + agent_state: AgentState, + llm_config: LLMConfig, + send_type: str, +) -> None: + if llm_config.model_endpoint_type != "anthropic": + pytest.skip("Parallel tool calling test only applies to Anthropic models.") + + # change llm_config to support parallel tool calling + llm_config.parallel_tool_calls = True + agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + if send_type == "step": + await client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_PARALLEL_TOOL_CALL, + ) + elif send_type == "async": + run = await client.agents.messages.send_async( + agent_id=agent_state.id, + messages=USER_MESSAGE_PARALLEL_TOOL_CALL, + ) + await wait_for_run_completion(client, run.id, timeout=60.0) + else: + response = await client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_PARALLEL_TOOL_CALL, + stream_tokens=(send_type == "stream_tokens"), + background=(send_type == "stream_tokens_background"), + ) + await accumulate_chunks(response) + + # validate parallel tool call behavior in preserved messages + preserved_messages_page = await client.agents.messages.list(agent_id=agent_state.id) + preserved_messages = preserved_messages_page.items + + # find the tool call message in preserved messages + tool_call_msg = None + tool_return_msg = None + for msg in preserved_messages: + if isinstance(msg, ToolCallMessage): + tool_call_msg = msg + elif isinstance(msg, ToolReturnMessage): + tool_return_msg = msg + + # assert parallel tool calls were made + assert tool_call_msg is not None, "ToolCallMessage not found in preserved messages" + assert hasattr(tool_call_msg, "tool_calls"), "tool_calls field not found in ToolCallMessage" + assert len(tool_call_msg.tool_calls) == 3, f"Expected 3 parallel tool calls, got {len(tool_call_msg.tool_calls)}" + + # verify each tool call and collect num_sides values + num_sides_values = [] + for tc in tool_call_msg.tool_calls: + assert tc.name == "roll_dice" + assert tc.tool_call_id.startswith("toolu_") + assert "num_sides" in tc.arguments + # Parse the num_sides value from the arguments + import json + + args = json.loads(tc.arguments) + num_sides = int(args["num_sides"]) + num_sides_values.append(num_sides) + + # assert tool returns match the tool calls + assert tool_return_msg is not None, "ToolReturnMessage not found in preserved messages" + assert hasattr(tool_return_msg, "tool_returns"), "tool_returns field not found in ToolReturnMessage" + assert len(tool_return_msg.tool_returns) == 3, f"Expected 3 tool returns, got {len(tool_return_msg.tool_returns)}" + + # verify each tool return matches the corresponding tool call's num_sides + tool_call_ids = {tc.tool_call_id for tc in tool_call_msg.tool_calls} + for i, tr in enumerate(tool_return_msg.tool_returns): + assert tr.type == "tool" + assert tr.status == "success" + assert tr.tool_call_id in tool_call_ids, f"tool_call_id {tr.tool_call_id} not found in tool calls" + # Check that the tool return value is within the range of the corresponding num_sides + expected_max = num_sides_values[i] if i < len(num_sides_values) else max(num_sides_values) + assert int(tr.tool_return) >= 1 and int(tr.tool_return) <= expected_max, ( + f"Tool return {tr.tool_return} is not within range 1-{expected_max}" + ) + + +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +@pytest.mark.parametrize( + ["send_type", "cancellation"], + list( + itertools.product( + ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"] + ) + ), + ids=[ + f"{s}-{c}" + for s, c in itertools.product( + ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"] + ) + ], +) +@pytest.mark.asyncio(loop_scope="function") +async def test_tool_call( + disable_e2b_api_key: Any, + client: AsyncLetta, + agent_state: AgentState, + llm_config: LLMConfig, + send_type: str, + cancellation: str, +) -> None: + # Skip models with OTID mismatch issues between ToolCallMessage and ToolReturnMessage + if llm_config.model == "gpt-5" or llm_config.model == "claude-sonnet-4-5-20250929" or llm_config.model.startswith("claude-opus-4-1"): + pytest.skip(f"Skipping {llm_config.model} due to OTID chain issue - messages receive incorrect OTID suffixes") + + last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1) + last_message = last_message_page.items[0] if last_message_page.items else None + agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config) + + if cancellation == "with_cancellation": + delay = 5 if llm_config.model == "gpt-5" else 0.5 # increase delay for responses api + _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay)) + + if send_type == "step": + response = await client.agents.messages.send( + agent_id=agent_state.id, + messages=USER_MESSAGE_ROLL_DICE, + ) + messages = response.messages + run_id = messages[0].run_id + elif send_type == "async": + run = await client.agents.messages.send_async( + agent_id=agent_state.id, + messages=USER_MESSAGE_ROLL_DICE, + ) + run = await wait_for_run_completion(client, run.id, timeout=60.0) + messages_page = await client.runs.messages.list(run_id=run.id) + messages = [m for m in messages_page.items if m.message_type != "user_message"] + run_id = run.id + else: + response = await client.agents.messages.stream( + agent_id=agent_state.id, + messages=USER_MESSAGE_ROLL_DICE, + stream_tokens=(send_type == "stream_tokens"), + background=(send_type == "stream_tokens_background"), + ) + messages = await accumulate_chunks(response) + run_id = messages[0].run_id + + assert_tool_call_response( + messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation") + ) + + if "background" in send_type: + response = await client.runs.messages.stream(run_id=run_id, starting_after=0) + messages = await accumulate_chunks(response) + assert_tool_call_response( + messages, streaming=("stream" in send_type), llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation") + ) + + messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) + messages_from_db = messages_from_db_page.items + assert_tool_call_response( + messages_from_db, from_db=True, llm_config=llm_config, with_cancellation=(cancellation == "with_cancellation") + ) + + assert run_id is not None + run = await client.runs.retrieve(run_id=run_id) + assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed")