"""Integration test for the new (v2) agent loop's send-message path.

Exercises a letta_v1_agent over three transports — blocking step, step
streaming, and token streaming — against each configured LLM, and verifies
the resulting message sequence both from the live response and from the DB.

NOTE(review): this file was recovered from a mail-format git patch
(2af3130, "feat: add integration test for new agent loop (#5020)"); the
mail/diff wrapper has been stripped so the module is valid Python again.
"""

import base64
import json
import os
import threading
import time
import uuid
from contextlib import contextmanager
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List
from unittest.mock import patch

import httpx
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta, LettaRequest, MessageCreate, Run
from letta_client.core.api_error import ApiError
from letta_client.types import (
    AssistantMessage,
    Base64Image,
    HiddenReasoningMessage,
    ImageContent,
    LettaMessageUnion,
    LettaStopReason,
    LettaUsageStatistics,
    ReasoningMessage,
    TextContent,
    ToolCallMessage,
    ToolReturnMessage,
    UrlImage,
    UserMessage,
)

from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType
from letta.schemas.letta_ping import LettaPing
from letta.schemas.llm_config import LLMConfig

logger = get_logger(__name__)


# ------------------------------
# Helper Functions and Constants
# ------------------------------


# LLM config files (under tests/configs/llm_model_configs) to parametrize over.
all_configs = [
    "claude-3-5-sonnet.json",
    "claude-3-7-sonnet-extended.json",
]


def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """Load an LLMConfig from a JSON file.

    Args:
        filename: Bare config file name (e.g. "claude-3-5-sonnet.json").
        llm_config_dir: Directory containing the config files, relative to
            the repo root (tests are expected to run from there).

    Returns:
        The parsed LLMConfig.
    """
    filename = os.path.join(llm_config_dir, filename)
    with open(filename, "r") as f:
        config_data = json.load(f)
    llm_config = LLMConfig(**config_data)
    return llm_config


TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in all_configs]


def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.
    Args:
        num_sides (int): The number of sides on the die.
    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    import random

    return random.randint(1, num_sides)


# Fixed OTID so the user message can be re-identified when read back from the DB.
USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content=f"This is an automated test message. Reply with the message '{USER_MESSAGE_RESPONSE}'.",
        otid=USER_MESSAGE_OTID,
    )
]


def assert_greeting_response(
    messages: List[Any],
    llm_config: LLMConfig,
    streaming: bool = False,
    token_streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected sequence:
    ReasoningMessage -> AssistantMessage.

    Args:
        messages: Messages returned by the API (or accumulated from a stream).
        llm_config: Config used for the run; determines whether a reasoning
            message is expected, and of which type.
        streaming: Whether a stop reason + usage statistics trailer is expected.
        token_streaming: Whether chunks were partial tokens (content text is
            not checked in that case — only message structure).
        from_db: Whether the list was read back from the DB (then it includes
            the original user message first).
    """
    # Filter out LettaPing messages which are keep-alive messages for SSE streams
    messages = [
        msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
    ]

    is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)
    # Base count: assistant msg (1); +user msg when from_db (2); streaming adds
    # stop reason + usage stats (3). Reasoner models add one reasoning message.
    expected_message_count = 3 if streaming else 2 if from_db else 1
    assert len(messages) == expected_message_count + (1 if is_reasoner_model else 0)

    # User message if loaded from db
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1

    # Reasoning message if reasoning enabled
    otid_suffix = 0
    if is_reasoner_model:
        if LLMConfig.is_openai_reasoning_model(llm_config):
            # OpenAI reasoning models hide their chain of thought
            assert isinstance(messages[index], HiddenReasoningMessage)
        else:
            assert isinstance(messages[index], ReasoningMessage)

        # OTIDs within a step end in a per-message sequence digit
        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
        index += 1
        otid_suffix += 1

    # Assistant message
    assert isinstance(messages[index], AssistantMessage)
    if not token_streaming:
        assert "teamwork" in messages[index].content.lower()
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1
    otid_suffix += 1

    # Stop reason and usage statistics if streaming
    if streaming:
        assert isinstance(messages[index], LettaStopReason)
        assert messages[index].stop_reason == "end_turn"
        index += 1
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0


async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]:
    """
    Accumulates chunks into a list of messages.

    Consecutive chunks sharing a message_type are collapsed into one message
    (represented by the first chunk of the run — content is not merged yet).

    Args:
        chunks: Async-iterable stream of message chunks.
        verify_token_streaming: When True, assert that each content-bearing
            message type arrived as more than one chunk (i.e. real token
            streaming happened).

    Returns:
        The accumulated messages, one per run of same-typed chunks.
    """
    messages = []
    current_message = None
    prev_message_type = None
    chunk_count = 0
    async for chunk in chunks:
        current_message_type = chunk.message_type
        if prev_message_type != current_message_type:
            # Type changed: close out the previous run (first append is None
            # and is filtered at the end).
            messages.append(current_message)
            if (
                prev_message_type
                and verify_token_streaming
                and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
            ):
                assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}. Messages: {messages}"
            current_message = None
            chunk_count = 0
        if current_message is None:
            current_message = chunk
        else:
            pass  # TODO: actually accumulate the chunks. For now we only care about the count
        prev_message_type = current_message_type
        chunk_count += 1
    # Flush the final run (the loop only closes runs on a type change).
    messages.append(current_message)
    if verify_token_streaming and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]:
        assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}"

    return [m for m in messages if m is not None]


# ------------------------------
# Fixtures
# ------------------------------


@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.
    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.
    """

    def _run_server() -> None:
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        # Daemon thread so the server dies with the test process.
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout)
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            try:
                resp = requests.get(url + "/v1/health")
                # Any non-5xx response means the server is listening.
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                pass
            time.sleep(0.1)
        else:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    return url


@pytest.fixture(scope="function")
async def client(server_url: str) -> AsyncLetta:
    """
    Creates and returns an asynchronous Letta REST client for testing.
    """
    client_instance = AsyncLetta(base_url=server_url)
    yield client_instance


@pytest.fixture(scope="function")
async def agent_state(client: AsyncLetta) -> AgentState:
    """
    Creates and returns an agent state for testing with a pre-configured agent.
    The agent is named 'test_agent' and is configured with only the roll_dice
    tool (base tools are excluded). The agent is deleted on teardown.
    """
    dice_tool = await client.tools.upsert_from_function(func=roll_dice)

    agent_state_instance = await client.agents.create(
        agent_type=AgentType.letta_v1_agent,
        name="test_agent",
        include_base_tools=False,
        tool_ids=[dice_tool.id],
        model="openai/gpt-4o",
        embedding="openai/text-embedding-3-small",
        tags=["test"],
    )
    yield agent_state_instance

    await client.agents.delete(agent_state_instance.id)


# ------------------------------
# Test Cases
# ------------------------------


@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens"])
@pytest.mark.asyncio(scope="function")
async def test_greeting(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
    send_type: str,
) -> None:
    """Send a forced-reply user message via step / step-stream / token-stream
    and assert the response shape, both live and as persisted."""
    # Remember the newest pre-existing message so the DB read below can fetch
    # only what this test produced.
    last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    if send_type == "step":
        response = await client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
        )
        messages = response.messages
    else:
        response = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
            stream_tokens=(send_type == "stream_tokens"),
        )
        messages = await accumulate_chunks(response)

    assert_greeting_response(
        messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
    )
    # Verify the persisted view matches (includes the original user message).
    messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_greeting_response(messages_from_db, from_db=True, llm_config=llm_config)