import asyncio import itertools import json import logging import os import threading import time import uuid from typing import Any, List, Tuple import pytest import requests from dotenv import load_dotenv from letta_client import AsyncLetta from letta_client.types import ( AgentState, AnthropicModelSettings, JsonSchemaResponseFormat, MessageCreateParam, OpenAIModelSettings, ToolReturnMessage, ) from letta_client.types.agents import AssistantMessage, ReasoningMessage, Run, ToolCallMessage, UserMessage from letta_client.types.agents.letta_streaming_response import LettaPing, LettaStopReason, LettaUsageStatistics logger = logging.getLogger(__name__) # ------------------------------ # Helper Functions and Constants # ------------------------------ all_configs = [ "openai-gpt-4o-mini.json", "openai-gpt-4.1.json", "openai-gpt-5.json", "claude-4-5-sonnet.json", "gemini-2.5-pro.json", "zai-glm-4.6.json", ] def get_model_config(filename: str, model_settings_dir: str = "tests/model_settings") -> Tuple[str, dict]: """Load a model_settings file and return the handle and settings dict.""" filename = os.path.join(model_settings_dir, filename) with open(filename, "r") as f: config_data = json.load(f) return config_data["handle"], config_data.get("model_settings", {}) requested = os.getenv("LLM_CONFIG_FILE") filenames = [requested] if requested else all_configs TESTED_MODEL_CONFIGS: List[Tuple[str, dict]] = [get_model_config(fn) for fn in filenames] def roll_dice(num_sides: int) -> int: """ Returns a random number between 1 and num_sides. Args: num_sides (int): The number of sides on the die. Returns: int: A random integer between 1 and num_sides, representing the die roll. """ import random return random.randint(1, num_sides) USER_MESSAGE_OTID = str(uuid.uuid4()) USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work" USER_MESSAGE_FORCE_REPLY: List[MessageCreateParam] = [ MessageCreateParam( role="user", content=f"This is an automated test message. Reply with the message '{USER_MESSAGE_RESPONSE}'.", otid=USER_MESSAGE_OTID, ) ] USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [ MessageCreateParam( role="user", content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.", otid=USER_MESSAGE_OTID, ) ] USER_MESSAGE_PARALLEL_TOOL_CALL: List[MessageCreateParam] = [ MessageCreateParam( role="user", content=( "This is an automated test message. Please call the roll_dice tool EXACTLY three times in parallel - no more, no less. " "Call it with num_sides=6, num_sides=12, and num_sides=20. Make all three calls at the same time in a single response." ), otid=USER_MESSAGE_OTID, ) ] def assert_greeting_response( messages: List[Any], model_handle: str, model_settings: dict, streaming: bool = False, token_streaming: bool = False, from_db: bool = False, ) -> None: """ Asserts that the messages list follows the expected sequence: ReasoningMessage -> AssistantMessage. """ # Filter out LettaPing messages which are keep-alive messages for SSE streams messages = [ msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) ] expected_message_count_min, expected_message_count_max = get_expected_message_count_range( model_handle, model_settings, streaming=streaming, from_db=from_db ) assert expected_message_count_min <= len(messages) <= expected_message_count_max # User message if loaded from db index = 0 if from_db: assert isinstance(messages[index], UserMessage) assert messages[index].otid == USER_MESSAGE_OTID index += 1 # Reasoning message if reasoning enabled otid_suffix = 0 try: if is_reasoner_model(model_handle, model_settings): assert isinstance(messages[index], ReasoningMessage) assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 otid_suffix += 1 except: # Reasoning is non-deterministic, so don't throw if missing pass # Assistant message assert isinstance(messages[index], AssistantMessage) if not token_streaming: assert "teamwork" in messages[index].content.lower() assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 otid_suffix += 1 # Stop reason and usage statistics if streaming if streaming: assert isinstance(messages[index], LettaStopReason) assert messages[index].stop_reason == "end_turn" index += 1 assert isinstance(messages[index], LettaUsageStatistics) assert messages[index].prompt_tokens > 0 assert messages[index].completion_tokens > 0 assert messages[index].total_tokens > 0 assert messages[index].step_count > 0 def assert_tool_call_response( messages: List[Any], model_handle: str, model_settings: dict, streaming: bool = False, from_db: bool = False, with_cancellation: bool = False, ) -> None: """ Asserts that the messages list follows the expected sequence: ReasoningMessage -> ToolCallMessage -> ToolReturnMessage -> ReasoningMessage -> AssistantMessage. """ # Filter out LettaPing messages which are keep-alive messages for SSE streams messages = [ msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping")) ] # If cancellation happened and no messages were persisted (early cancellation), return early if with_cancellation and len(messages) == 0: return if not with_cancellation: expected_message_count_min, expected_message_count_max = get_expected_message_count_range( model_handle, model_settings, tool_call=True, streaming=streaming, from_db=from_db ) assert expected_message_count_min <= len(messages) <= expected_message_count_max # User message if loaded from db index = 0 if from_db: assert isinstance(messages[index], UserMessage) assert messages[index].otid == USER_MESSAGE_OTID index += 1 # If cancellation happened after user message but before any response, return early if with_cancellation and index >= len(messages): return # Reasoning message if reasoning enabled otid_suffix = 0 try: if is_reasoner_model(model_handle, model_settings): assert isinstance(messages[index], ReasoningMessage) assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 otid_suffix += 1 except: # Reasoning is non-deterministic, so don't throw if missing pass # Special case for claude-sonnet-4-5-20250929, opus-4.1, and zai which can generate an extra AssistantMessage before tool call if ( ("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle or model_settings.get("provider_type") == "zai") and index < len(messages) and isinstance(messages[index], AssistantMessage) ): # Skip the extra AssistantMessage and move to the next message index += 1 otid_suffix += 1 # Tool call message (may be skipped if cancelled early) if with_cancellation and index < len(messages) and isinstance(messages[index], AssistantMessage): # If cancelled early, model might respond with text instead of making tool call assert "roll" in messages[index].content.lower() or "die" in messages[index].content.lower() return # Skip tool call assertions for early cancellation # If cancellation happens before tool call, we might get LettaStopReason directly if with_cancellation and index < len(messages) and isinstance(messages[index], LettaStopReason): assert messages[index].stop_reason == "cancelled" return # Skip remaining assertions for very early cancellation assert isinstance(messages[index], ToolCallMessage) assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 # If cancellation happens before tool return, we might get LettaStopReason directly if with_cancellation and index < len(messages) and isinstance(messages[index], LettaStopReason): assert messages[index].stop_reason == "cancelled" return # Skip remaining assertions for very early cancellation # Tool return message otid_suffix = 0 assert isinstance(messages[index], ToolReturnMessage) assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 # Messages from second agent step if request has not been cancelled if not with_cancellation: # Reasoning message if reasoning enabled otid_suffix = 0 try: if is_reasoner_model(model_handle, model_settings): assert isinstance(messages[index], ReasoningMessage) assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 otid_suffix += 1 except: # Reasoning is non-deterministic, so don't throw if missing pass # Assistant message assert isinstance(messages[index], AssistantMessage) assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix) index += 1 # Stop reason and usage statistics if streaming if streaming: assert isinstance(messages[index], LettaStopReason) assert messages[index].stop_reason == ("cancelled" if with_cancellation else "end_turn") index += 1 assert isinstance(messages[index], LettaUsageStatistics) assert messages[index].prompt_tokens > 0 assert messages[index].completion_tokens > 0 assert messages[index].total_tokens > 0 assert messages[index].step_count > 0 async def accumulate_chunks(chunks, verify_token_streaming: bool = False) -> List[Any]: """ Accumulates chunks into a list of messages. Handles both async iterators and raw SSE strings. """ messages = [] current_message = None prev_message_type = None # Handle raw SSE string from runs.messages.stream() if isinstance(chunks, str): import json for line in chunks.strip().split("\n"): if line.startswith("data: ") and line != "data: [DONE]": try: data = json.loads(line[6:]) # Remove 'data: ' prefix if "message_type" in data: # Create proper message type objects message_type = data.get("message_type") if message_type == "assistant_message": from letta_client.types.agents import AssistantMessage chunk = AssistantMessage(**data) elif message_type == "reasoning_message": from letta_client.types.agents import ReasoningMessage chunk = ReasoningMessage(**data) elif message_type == "tool_call_message": from letta_client.types.agents import ToolCallMessage chunk = ToolCallMessage(**data) elif message_type == "tool_return_message": from letta_client.types import ToolReturnMessage chunk = ToolReturnMessage(**data) elif message_type == "user_message": from letta_client.types.agents import UserMessage chunk = UserMessage(**data) elif message_type == "stop_reason": from letta_client.types.agents.letta_streaming_response import LettaStopReason chunk = LettaStopReason(**data) elif message_type == "usage_statistics": from letta_client.types.agents.letta_streaming_response import LettaUsageStatistics chunk = LettaUsageStatistics(**data) else: chunk = type("Chunk", (), data)() # Fallback for unknown types current_message_type = chunk.message_type if prev_message_type != current_message_type: if current_message is not None: messages.append(current_message) current_message = chunk else: # Accumulate content for same message type if hasattr(current_message, "content") and hasattr(chunk, "content"): current_message.content += chunk.content prev_message_type = current_message_type except json.JSONDecodeError: continue if current_message is not None: messages.append(current_message) else: # Handle async iterator from agents.messages.stream() async for chunk in chunks: current_message_type = chunk.message_type if prev_message_type != current_message_type: if current_message is not None: messages.append(current_message) current_message = chunk else: # Accumulate content for same message type if hasattr(current_message, "content") and hasattr(chunk, "content"): current_message.content += chunk.content prev_message_type = current_message_type if current_message is not None: messages.append(current_message) return messages async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5): await asyncio.sleep(delay) await client.agents.messages.cancel(agent_id=agent_id) async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: start = time.time() while True: run = await client.runs.retrieve(run_id) if run.status == "completed": return run if run.status == "cancelled": time.sleep(5) return run if run.status == "failed": raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}") if time.time() - start > timeout: raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})") time.sleep(interval) def get_expected_message_count_range( model_handle: str, model_settings: dict, tool_call: bool = False, streaming: bool = False, from_db: bool = False ) -> Tuple[int, int]: """ Returns the expected range of number of messages for a given LLM configuration. Uses range to account for possible variations in the number of reasoning messages. Greeting: ------------------------------------------------------------------------------------------------------------------------------------------------------------------ | gpt-4o | gpt-o3 (med effort) | gpt-5 (high effort) | sonnet-3-5 | sonnet-3.7-thinking | flash-2.5-thinking | | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | | AssistantMessage | AssistantMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | ReasoningMessage | | | | AssistantMessage | | AssistantMessage | AssistantMessage | Tool Call: --------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------------- | gpt-4o | gpt-o3 (med effort) | gpt-5 (high effort) | sonnet-3-5 | sonnet-3.7-thinking | sonnet-4.5/opus-4.1 | flash-2.5-thinking | | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ | | ToolCallMessage | ToolCallMessage | ReasoningMessage | AssistantMessage | ReasoningMessage | ReasoningMessage | ReasoningMessage | | ToolReturnMessage | ToolReturnMessage | ToolCallMessage | ToolCallMessage | AssistantMessage | AssistantMessage | ToolCallMessage | | AssistantMessage | AssistantMessage | ToolReturnMessage | ToolReturnMessage | ToolCallMessage | ToolCallMessage | ToolReturnMessage | | | | ReasoningMessage | AssistantMessage | ToolReturnMessage | ToolReturnMessage | ReasoningMessage | | | | AssistantMessage | | AssistantMessage | ReasoningMessage | AssistantMessage | | | | | | | AssistantMessage | | """ # assistant message expected_message_count = 1 expected_range = 0 if is_reasoner_model(model_handle, model_settings): # reasoning message expected_range += 1 if tool_call: # check for sonnet 4.5 or opus 4.1 specifically is_sonnet_4_5_or_opus_4_1 = ( model_settings.get("provider_type") == "anthropic" and model_settings.get("thinking", {}).get("type") == "enabled" and ("claude-sonnet-4-5" in model_handle or "claude-opus-4-1" in model_handle) ) is_anthropic_reasoning = ( model_settings.get("provider_type") == "anthropic" and model_settings.get("thinking", {}).get("type") == "enabled" ) if is_sonnet_4_5_or_opus_4_1 or not is_anthropic_reasoning: # sonnet 4.5 and opus 4.1 return a reasoning message before the final assistant message # so do the other native reasoning models expected_range += 1 # opus 4.1 generates an extra AssistantMessage before the tool call if "claude-opus-4-1" in model_handle: expected_range += 1 # Z.ai models output an AssistantMessage with each ReasoningMessage (not just the final one) if model_settings.get("provider_type") == "zai": expected_range += 1 if tool_call: # tool call and tool return messages expected_message_count += 2 if from_db: # user message expected_message_count += 1 if streaming: # stop reason and usage statistics expected_message_count += 2 return expected_message_count, expected_message_count + expected_range def is_reasoner_model(model_handle: str, model_settings: dict) -> bool: """Check if the model is a reasoning model based on its handle and settings.""" # OpenAI reasoning models with high reasoning effort is_openai_reasoning = ( model_settings.get("provider_type") == "openai" and ( "gpt-5" in model_handle or "o1" in model_handle or "o3" in model_handle or "o4-mini" in model_handle or "gpt-4.1" in model_handle ) and model_settings.get("reasoning", {}).get("reasoning_effort") == "high" ) # Anthropic models with thinking enabled is_anthropic_reasoning = ( model_settings.get("provider_type") == "anthropic" and model_settings.get("thinking", {}).get("type") == "enabled" ) # Google Vertex models with thinking config is_google_vertex_reasoning = ( model_settings.get("provider_type") == "google_vertex" and model_settings.get("thinking_config", {}).get("include_thoughts") is True ) # Google AI models with thinking config is_google_ai_reasoning = ( model_settings.get("provider_type") == "google_ai" and model_settings.get("thinking_config", {}).get("include_thoughts") is True ) # Z.ai models output reasoning by default is_zai_reasoning = model_settings.get("provider_type") == "zai" return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning # ------------------------------ # Fixtures # ------------------------------ @pytest.fixture(scope="module") def server_url() -> str: """ Provides the URL for the Letta server. If LETTA_SERVER_URL is not set, starts the server in a background thread and polls until it's accepting connections. """ def _run_server() -> None: load_dotenv() from letta.server.rest_api.app import start_server start_server(debug=True) url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") if not os.getenv("LETTA_SERVER_URL"): thread = threading.Thread(target=_run_server, daemon=True) thread.start() # Poll until the server is up (or timeout) timeout_seconds = 60 deadline = time.time() + timeout_seconds while time.time() < deadline: try: resp = requests.get(url + "/v1/health") if resp.status_code < 500: break except requests.exceptions.RequestException: pass time.sleep(0.1) else: raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") return url @pytest.fixture(scope="function") async def client(server_url: str) -> AsyncLetta: """ Creates and returns an asynchronous Letta REST client for testing. """ client_instance = AsyncLetta(base_url=server_url) yield client_instance @pytest.fixture(scope="function") async def agent_state(client: AsyncLetta) -> AgentState: """ Creates and returns an agent state for testing with a pre-configured agent. The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. """ dice_tool = await client.tools.upsert_from_function(func=roll_dice) agent_state_instance = await client.agents.create( agent_type="letta_v1_agent", name="test_agent", include_base_tools=False, tool_ids=[dice_tool.id], model="openai/gpt-4o", embedding="openai/text-embedding-3-small", tags=["test"], ) yield agent_state_instance await client.agents.delete(agent_state_instance.id) # ------------------------------ # Test Cases # ------------------------------ @pytest.mark.parametrize( "model_config", TESTED_MODEL_CONFIGS, ids=[handle for handle, _ in TESTED_MODEL_CONFIGS], ) @pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"]) @pytest.mark.asyncio(loop_scope="function") async def test_greeting( disable_e2b_api_key: Any, client: AsyncLetta, agent_state: AgentState, model_config: Tuple[str, dict], send_type: str, ) -> None: model_handle, model_settings = model_config last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1) last_message = last_message_page.items[0] if last_message_page.items else None agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings) if send_type == "step": response = await client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE_FORCE_REPLY, ) messages = response.messages run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None) elif send_type == "async": run = await client.agents.messages.create_async( agent_id=agent_state.id, messages=USER_MESSAGE_FORCE_REPLY, ) run = await wait_for_run_completion(client, run.id, timeout=60.0) messages_page = await client.runs.messages.list(run_id=run.id) messages = [m for m in messages_page.items if m.message_type != "user_message"] run_id = run.id else: response = await client.agents.messages.stream( agent_id=agent_state.id, messages=USER_MESSAGE_FORCE_REPLY, stream_tokens=(send_type == "stream_tokens"), background=(send_type == "stream_tokens_background"), ) messages = await accumulate_chunks(response) run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None) # If run_id is not in messages (e.g., due to early cancellation), get the most recent run if run_id is None: runs = await client.runs.list(agent_ids=[agent_state.id]) run_id = runs.items[0].id if runs.items else None assert_greeting_response( messages, model_handle, model_settings, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens") ) if "background" in send_type: response = await client.runs.messages.stream(run_id=run_id, starting_after=0) messages = await accumulate_chunks(response) assert_greeting_response( messages, model_handle, model_settings, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens") ) messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) messages_from_db = messages_from_db_page.items assert_greeting_response(messages_from_db, model_handle, model_settings, from_db=True) assert run_id is not None run = await client.runs.retrieve(run_id=run_id) assert run.status == "completed" @pytest.mark.parametrize( "model_config", TESTED_MODEL_CONFIGS, ids=[handle for handle, _ in TESTED_MODEL_CONFIGS], ) @pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"]) @pytest.mark.asyncio(loop_scope="function") async def test_parallel_tool_calls( disable_e2b_api_key: Any, client: AsyncLetta, agent_state: AgentState, model_config: Tuple[str, dict], send_type: str, ) -> None: model_handle, model_settings = model_config provider_type = model_settings.get("provider_type", "") if provider_type not in ["anthropic", "openai", "google_ai", "google_vertex"]: pytest.skip("Parallel tool calling test only applies to Anthropic, OpenAI, and Gemini models.") if "gpt-5" in model_handle or "o3" in model_handle: pytest.skip("GPT-5 takes too long to test, o3 is bad at this task.") # Skip Gemini models due to issues with parallel tool calling if provider_type in ["google_ai", "google_vertex"]: pytest.skip("Gemini models are flaky for this test so we disable them for now") # Update model_settings to enable parallel tool calling modified_model_settings = model_settings.copy() modified_model_settings["parallel_tool_calls"] = True agent_state = await client.agents.update( agent_id=agent_state.id, model=model_handle, model_settings=modified_model_settings, ) if send_type == "step": await client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE_PARALLEL_TOOL_CALL, ) elif send_type == "async": run = await client.agents.messages.create_async( agent_id=agent_state.id, messages=USER_MESSAGE_PARALLEL_TOOL_CALL, ) await wait_for_run_completion(client, run.id, timeout=60.0) else: response = await client.agents.messages.stream( agent_id=agent_state.id, messages=USER_MESSAGE_PARALLEL_TOOL_CALL, stream_tokens=(send_type == "stream_tokens"), background=(send_type == "stream_tokens_background"), ) await accumulate_chunks(response) # validate parallel tool call behavior in preserved messages preserved_messages_page = await client.agents.messages.list(agent_id=agent_state.id) preserved_messages = preserved_messages_page.items # collect all ToolCallMessage and ToolReturnMessage instances tool_call_messages = [] tool_return_messages = [] for msg in preserved_messages: if isinstance(msg, ToolCallMessage): tool_call_messages.append(msg) elif isinstance(msg, ToolReturnMessage): tool_return_messages.append(msg) # Check if tool calls are grouped in a single message (parallel) or separate messages (sequential) total_tool_calls = 0 for i, tcm in enumerate(tool_call_messages): if hasattr(tcm, "tool_calls") and tcm.tool_calls: num_calls = len(tcm.tool_calls) if isinstance(tcm.tool_calls, list) else 1 total_tool_calls += num_calls elif hasattr(tcm, "tool_call"): total_tool_calls += 1 # Check tool returns structure total_tool_returns = 0 for i, trm in enumerate(tool_return_messages): if hasattr(trm, "tool_returns") and trm.tool_returns: num_returns = len(trm.tool_returns) if isinstance(trm.tool_returns, list) else 1 total_tool_returns += num_returns elif hasattr(trm, "tool_return"): total_tool_returns += 1 # CRITICAL: For TRUE parallel tool calling with letta_v1_agent, there should be exactly ONE ToolCallMessage # containing multiple tool calls, not multiple ToolCallMessages # Verify we have exactly 3 tool calls total assert total_tool_calls == 3, f"Expected exactly 3 tool calls total, got {total_tool_calls}" assert total_tool_returns == 3, f"Expected exactly 3 tool returns total, got {total_tool_returns}" # Check if we have true parallel tool calling is_parallel = False if len(tool_call_messages) == 1: # Check if the single message contains multiple tool calls tcm = tool_call_messages[0] if hasattr(tcm, "tool_calls") and isinstance(tcm.tool_calls, list) and len(tcm.tool_calls) == 3: is_parallel = True # IMPORTANT: Assert that parallel tool calling is actually working # This test should FAIL if parallel tool calling is not working properly assert is_parallel, ( f"Parallel tool calling is NOT working for {provider_type}! " f"Got {len(tool_call_messages)} ToolCallMessage(s) instead of 1 with 3 parallel calls. " f"When using letta_v1_agent with parallel_tool_calls=True, all tool calls should be in a single message." ) # Collect all tool calls and their details for validation all_tool_calls = [] tool_call_ids = set() num_sides_by_id = {} for tcm in tool_call_messages: if hasattr(tcm, "tool_calls") and tcm.tool_calls and isinstance(tcm.tool_calls, list): # Message has multiple tool calls for tc in tcm.tool_calls: all_tool_calls.append(tc) tool_call_ids.add(tc.tool_call_id) # Parse arguments import json args = json.loads(tc.arguments) num_sides_by_id[tc.tool_call_id] = int(args["num_sides"]) elif hasattr(tcm, "tool_call") and tcm.tool_call: # Message has single tool call tc = tcm.tool_call all_tool_calls.append(tc) tool_call_ids.add(tc.tool_call_id) # Parse arguments import json args = json.loads(tc.arguments) num_sides_by_id[tc.tool_call_id] = int(args["num_sides"]) # Verify each tool call for tc in all_tool_calls: assert tc.name == "roll_dice", f"Expected tool call name 'roll_dice', got '{tc.name}'" # Support Anthropic (toolu_), OpenAI (call_), and Gemini (UUID) tool call ID formats # Gemini uses UUID format which could start with any alphanumeric character valid_id_format = ( tc.tool_call_id.startswith("toolu_") or tc.tool_call_id.startswith("call_") or (len(tc.tool_call_id) > 0 and tc.tool_call_id[0].isalnum()) # UUID format for Gemini ) assert valid_id_format, f"Unexpected tool call ID format: {tc.tool_call_id}" # Collect all tool returns for validation all_tool_returns = [] for trm in tool_return_messages: if hasattr(trm, "tool_returns") and trm.tool_returns and isinstance(trm.tool_returns, list): # Message has multiple tool returns all_tool_returns.extend(trm.tool_returns) elif hasattr(trm, "tool_return") and trm.tool_return: # Message has single tool return (create a mock object if needed) # Since ToolReturnMessage might not have individual tool_return, check the structure pass # If all_tool_returns is empty, it means returns are structured differently # Let's check the actual structure if not all_tool_returns: print("Note: Tool returns may be structured differently than expected") # For now, just verify we got the right number of messages assert len(tool_return_messages) > 0, "No tool return messages found" # Verify tool returns if we have them in the expected format for tr in all_tool_returns: assert tr.type == "tool", f"Tool return type should be 'tool', got '{tr.type}'" assert tr.status == "success", f"Tool return status should be 'success', got '{tr.status}'" assert tr.tool_call_id in tool_call_ids, f"Tool return ID '{tr.tool_call_id}' not found in tool call IDs: {tool_call_ids}" # Verify the dice roll result is within the valid range dice_result = int(tr.tool_return) expected_max = num_sides_by_id[tr.tool_call_id] assert 1 <= dice_result <= expected_max, ( f"Dice roll result {dice_result} is not within valid range 1-{expected_max} for tool call {tr.tool_call_id}" ) @pytest.mark.parametrize( "model_config", TESTED_MODEL_CONFIGS, ids=[handle for handle, _ in TESTED_MODEL_CONFIGS], ) @pytest.mark.parametrize( ["send_type", "cancellation"], list( itertools.product( ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"] ) ), ids=[ f"{s}-{c}" for s, c in itertools.product( ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"], ["with_cancellation", "no_cancellation"] ) ], ) @pytest.mark.asyncio(loop_scope="function") async def test_tool_call( disable_e2b_api_key: Any, client: AsyncLetta, agent_state: AgentState, model_config: Tuple[str, dict], send_type: str, cancellation: str, ) -> None: model_handle, model_settings = model_config # Skip models with OTID mismatch issues between ToolCallMessage and ToolReturnMessage if "gpt-5" in model_handle or "claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle: pytest.skip(f"Skipping {model_handle} due to OTID chain issue - messages receive incorrect OTID suffixes") last_message_page = await client.agents.messages.list(agent_id=agent_state.id, limit=1) last_message = last_message_page.items[0] if last_message_page.items else None agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings) if cancellation == "with_cancellation": delay = 5 if "gpt-5" in model_handle else 0.5 # increase delay for responses api _cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay)) if send_type == "step": response = await client.agents.messages.create( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, ) messages = response.messages run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None) elif send_type == "async": run = await client.agents.messages.create_async( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, ) run = await wait_for_run_completion(client, run.id, timeout=60.0) messages_page = await client.runs.messages.list(run_id=run.id) messages = [m for m in messages_page.items if m.message_type != "user_message"] run_id = run.id else: response = await client.agents.messages.stream( agent_id=agent_state.id, messages=USER_MESSAGE_ROLL_DICE, stream_tokens=(send_type == "stream_tokens"), background=(send_type == "stream_tokens_background"), ) messages = await accumulate_chunks(response) run_id = next((msg.run_id for msg in messages if hasattr(msg, "run_id")), None) # If run_id is not in messages (e.g., due to early cancellation), get the most recent run if run_id is None: runs = await client.runs.list(agent_ids=[agent_state.id]) run_id = runs.items[0].id if runs.items else None assert_tool_call_response( messages, model_handle, model_settings, streaming=("stream" in send_type), with_cancellation=(cancellation == "with_cancellation") ) if "background" in send_type: response = await client.runs.messages.stream(run_id=run_id, starting_after=0) messages = await accumulate_chunks(response) assert_tool_call_response( messages, model_handle, model_settings, streaming=("stream" in send_type), with_cancellation=(cancellation == "with_cancellation"), ) messages_from_db_page = await client.agents.messages.list(agent_id=agent_state.id, after=last_message.id if last_message else None) messages_from_db = messages_from_db_page.items assert_tool_call_response( messages_from_db, model_handle, model_settings, from_db=True, with_cancellation=(cancellation == "with_cancellation") ) assert run_id is not None run = await client.runs.retrieve(run_id=run_id) assert run.status == ("cancelled" if cancellation == "with_cancellation" else "completed") @pytest.mark.parametrize( "model_config", TESTED_MODEL_CONFIGS, ids=[handle for handle, _ in TESTED_MODEL_CONFIGS], ) @pytest.mark.asyncio(loop_scope="function") async def test_conversation_streaming_raw_http( disable_e2b_api_key: Any, client: AsyncLetta, server_url: str, agent_state: AgentState, model_config: Tuple[str, dict], ) -> None: """ Test conversation-based streaming functionality using raw HTTP requests. This test verifies that: 1. A conversation can be created for an agent 2. Messages can be sent to the conversation via streaming 3. The streaming response contains the expected message types 4. Messages are properly persisted in the conversation Uses raw HTTP requests instead of SDK until SDK is regenerated with conversations support. """ import httpx model_handle, model_settings = model_config agent_state = await client.agents.update(agent_id=agent_state.id, model=model_handle, model_settings=model_settings) async with httpx.AsyncClient(base_url=server_url, timeout=60.0) as http_client: # Create a conversation for the agent create_response = await http_client.post( "/v1/conversations/", params={"agent_id": agent_state.id}, json={}, ) assert create_response.status_code == 200, f"Failed to create conversation: {create_response.text}" conversation = create_response.json() assert conversation["id"] is not None assert conversation["agent_id"] == agent_state.id # Send a message to the conversation using streaming stream_response = await http_client.post( f"/v1/conversations/{conversation['id']}/messages", json={ "messages": [{"role": "user", "content": f"Reply with the message '{USER_MESSAGE_RESPONSE}'."}], "stream_tokens": True, }, ) assert stream_response.status_code == 200, f"Failed to send message: {stream_response.text}" # Parse SSE response and accumulate messages messages = await accumulate_chunks(stream_response.text) print("MESSAGES:", messages) # Verify the response contains expected message types assert_greeting_response(messages, model_handle, model_settings, streaming=True, token_streaming=True) # Verify the conversation can be retrieved retrieve_response = await http_client.get(f"/v1/conversations/{conversation['id']}") assert retrieve_response.status_code == 200, f"Failed to retrieve conversation: {retrieve_response.text}" retrieved_conversation = retrieve_response.json() assert retrieved_conversation["id"] == conversation["id"] print("RETRIEVED CONVERSATION:", retrieved_conversation) # Verify conversations can be listed for the agent list_response = await http_client.get("/v1/conversations/", params={"agent_id": agent_state.id}) assert list_response.status_code == 200, f"Failed to list conversations: {list_response.text}" conversations_list = list_response.json() assert any(c["id"] == conversation["id"] for c in conversations_list) # Verify messages can be listed from the conversation messages_response = await http_client.get(f"/v1/conversations/{conversation['id']}/messages") assert messages_response.status_code == 200, f"Failed to list conversation messages: {messages_response.text}" conversation_messages = messages_response.json() print("CONVERSATION MESSAGES:", conversation_messages) # Verify we have at least the user message and assistant message assert len(conversation_messages) >= 2, f"Expected at least 2 messages, got {len(conversation_messages)}" # Check message types are present message_types = [msg.get("message_type") for msg in conversation_messages] assert "user_message" in message_types, f"Expected user_message in {message_types}" assert "assistant_message" in message_types, f"Expected assistant_message in {message_types}" @pytest.mark.parametrize( "model_handle,provider_type", [ ("openai/gpt-4o", "openai"), ("openai/gpt-5", "openai"), # ("anthropic/claude-sonnet-4-5-20250929", "anthropic"), ], ) @pytest.mark.asyncio(loop_scope="function") async def test_json_schema_response_format( disable_e2b_api_key: Any, client: AsyncLetta, model_handle: str, provider_type: str, ) -> None: """ Test JsonSchemaResponseFormat with OpenAI and Anthropic models. This test verifies that: 1. Agents can be created with json_schema response_format via model_settings 2. The schema is properly stored in the agent's model_settings 3. Messages sent to the agent produce responses conforming to the schema 4. Both OpenAI and Anthropic handle structured outputs correctly """ # Define the structured output schema response_schema = { "name": "capital_response", "strict": True, "schema": { "type": "object", "properties": { "response": {"type": "string", "description": "The answer to the question"}, "justification": {"type": "string", "description": "Why this is the answer"}, }, "required": ["response", "justification"], "additionalProperties": False, }, } # Create model settings with json_schema response format based on provider if provider_type == "openai": model_settings = OpenAIModelSettings( provider_type="openai", response_format=JsonSchemaResponseFormat(type="json_schema", json_schema=response_schema) ) else: model_settings = AnthropicModelSettings( provider_type="anthropic", response_format=JsonSchemaResponseFormat(type="json_schema", json_schema=response_schema) ) # Create agent with structured output configuration agent_state = await client.agents.create( name=f"test_structured_agent_{model_handle.replace('/', '_')}", model=model_handle, model_settings=model_settings, embedding="openai/text-embedding-3-small", agent_type="letta_v1_agent", ) try: # Send a message to the agent message_response = await client.agents.messages.create( agent_id=agent_state.id, messages=[MessageCreateParam(role="user", content="What is the capital of France?")] ) # Verify we got a response assert len(message_response.messages) > 0, "Should have received at least one message" # Find the assistant message and verify it contains valid JSON matching the schema assistant_message = None for msg in message_response.messages: if isinstance(msg, AssistantMessage): assistant_message = msg break assert assistant_message is not None, "Should have received an AssistantMessage" # Parse the content as JSON parsed_content = json.loads(assistant_message.content) # Verify the JSON has the required fields from our schema assert "response" in parsed_content, "JSON should contain 'response' field" assert "justification" in parsed_content, "JSON should contain 'justification' field" assert isinstance(parsed_content["response"], str), "'response' field should be a string" assert isinstance(parsed_content["justification"], str), "'justification' field should be a string" assert len(parsed_content["response"]) > 0, "'response' field should not be empty" assert len(parsed_content["justification"]) > 0, "'justification' field should not be empty" finally: # Cleanup await client.agents.delete(agent_state.id)