From 68a9e31eb15fcc4dbe3a9b4a0a2fcfabead9a58f Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 22 Apr 2025 15:31:45 -0700 Subject: [PATCH] chore: Add send_message sdk tests (#1842) --- tests/integration_test_send_message.py | 81 +++++++++----------------- 1 file changed, 27 insertions(+), 54 deletions(-) diff --git a/tests/integration_test_send_message.py b/tests/integration_test_send_message.py index 23f7af4a..67d57261 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -4,9 +4,10 @@ import time from typing import Any, Dict, List import pytest +import requests from dotenv import load_dotenv -from letta_client import AsyncLetta, Letta, Run, Tool -from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage +from letta_client import AsyncLetta, Letta, Run +from letta_client.types import AssistantMessage, ReasoningMessage from letta.schemas.agent import AgentState @@ -19,25 +20,35 @@ from letta.schemas.agent import AgentState def server_url() -> str: """ Provides the URL for the Letta server. - If the environment variable 'LETTA_SERVER_URL' is not set, this fixture - will start the Letta server in a background thread and return the default URL. + If LETTA_SERVER_URL is not set, starts the server in a background thread + and polls until it’s accepting connections. """ def _run_server() -> None: - """Starts the Letta server in a background thread.""" - load_dotenv() # Load environment variables from .env file + load_dotenv() from letta.server.rest_api.app import start_server start_server(debug=True) - # Retrieve server URL from environment, or default to localhost url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - # If no environment variable is set, start the server in a background thread if not os.getenv("LETTA_SERVER_URL"): thread = threading.Thread(target=_run_server, daemon=True) thread.start() - time.sleep(5) # Allow time for the server to start + + # Poll until the server is up (or timeout) + timeout_seconds = 30 + deadline = time.time() + timeout_seconds + while time.time() < deadline: + try: + resp = requests.get(url + "/v1/health") + if resp.status_code < 500: + break + except requests.exceptions.RequestException: + pass + time.sleep(0.1) + else: + raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s") return url @@ -61,29 +72,7 @@ def async_client(server_url: str) -> AsyncLetta: @pytest.fixture -def roll_dice_tool(client: Letta) -> Tool: - """ - Registers a simple roll dice tool with the provided client. - - The tool simulates rolling a six-sided die but returns a fixed result. - """ - - def roll_dice() -> str: - """ - Simulates rolling a die. - - Returns: - str: The roll result. - """ - # Note: The result here is intentionally incorrect for demonstration purposes. - return "Rolled a 10!" - - tool = client.tools.upsert_from_function(func=roll_dice) - yield tool - - -@pytest.fixture -def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: +def agent_state(client: Letta) -> AgentState: """ Creates and returns an agent state for testing with a pre-configured agent. The agent is named 'supervisor' and is configured with base tools and the roll_dice tool. @@ -91,7 +80,6 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: agent_state_instance = client.agents.create( name="supervisor", include_base_tools=True, - tool_ids=[roll_dice_tool.id], model="openai/gpt-4o", embedding="letta/letta-free", tags=["supervisor"], @@ -103,8 +91,8 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState: # Helper Functions and Constants # ------------------------------ -USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Roll the dice."}] -TESTED_MODELS: List[str] = ["openai/gpt-4o"] +USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}] +TESTED_MODELS: List[str] = ["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20241022"] def assert_tool_response_messages(messages: List[Any]) -> None: @@ -114,10 +102,7 @@ def assert_tool_response_messages(messages: List[Any]) -> None: ReasoningMessage -> AssistantMessage. """ assert isinstance(messages[0], ReasoningMessage) - assert isinstance(messages[1], ToolCallMessage) - assert isinstance(messages[2], ToolReturnMessage) - assert isinstance(messages[3], ReasoningMessage) - assert isinstance(messages[4], AssistantMessage) + assert isinstance(messages[1], AssistantMessage) def assert_streaming_tool_response_messages(chunks: List[Any]) -> None: @@ -130,16 +115,10 @@ def assert_streaming_tool_response_messages(chunks: List[Any]) -> None: return [c for c in chunks if isinstance(c, msg_type)] reasoning_msgs = msg_groups(ReasoningMessage) - tool_calls = msg_groups(ToolCallMessage) - tool_returns = msg_groups(ToolReturnMessage) assistant_msgs = msg_groups(AssistantMessage) - usage_stats = msg_groups(LettaUsageStatistics) - assert len(reasoning_msgs) >= 1 - assert len(tool_calls) == 1 - assert len(tool_returns) == 1 + assert len(reasoning_msgs) == 1 assert len(assistant_msgs) == 1 - assert len(usage_stats) == 1 def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run: @@ -161,7 +140,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i """ start = time.time() while True: - run = client.runs.retrieve_run(run_id) + run = client.runs.retrieve(run_id) if run.status == "completed": return run if run.status == "failed": @@ -184,13 +163,7 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None: """ assert isinstance(messages, list) assert messages[0]["message_type"] == "reasoning_message" - assert messages[1]["message_type"] == "tool_call_message" - assert messages[2]["message_type"] == "tool_return_message" - assert messages[3]["message_type"] == "reasoning_message" - assert messages[4]["message_type"] == "assistant_message" - - tool_return = messages[2] - assert tool_return["status"] == "success" + assert messages[1]["message_type"] == "assistant_message" # ------------------------------