chore: Add send_message sdk tests (#1842)

This commit is contained in:
Matthew Zhou
2025-04-22 15:31:45 -07:00
committed by GitHub
parent cc8e47779f
commit 68a9e31eb1

View File

@@ -4,9 +4,10 @@ import time
from typing import Any, Dict, List
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta, Run, Tool
from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage
from letta_client import AsyncLetta, Letta, Run
from letta_client.types import AssistantMessage, ReasoningMessage
from letta.schemas.agent import AgentState
@@ -19,25 +20,35 @@ from letta.schemas.agent import AgentState
def server_url() -> str:
"""
Provides the URL for the Letta server.
If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
will start the Letta server in a background thread and return the default URL.
If LETTA_SERVER_URL is not set, starts the server in a background thread
and polls until it's accepting connections.
"""
def _run_server() -> None:
"""Starts the Letta server in a background thread."""
load_dotenv() # Load environment variables from .env file
load_dotenv()
from letta.server.rest_api.app import start_server
start_server(debug=True)
# Retrieve server URL from environment, or default to localhost
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
# If no environment variable is set, start the server in a background thread
if not os.getenv("LETTA_SERVER_URL"):
thread = threading.Thread(target=_run_server, daemon=True)
thread.start()
time.sleep(5) # Allow time for the server to start
# Poll until the server is up (or timeout)
timeout_seconds = 30
deadline = time.time() + timeout_seconds
while time.time() < deadline:
try:
resp = requests.get(url + "/v1/health")
if resp.status_code < 500:
break
except requests.exceptions.RequestException:
pass
time.sleep(0.1)
else:
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
return url
@@ -61,29 +72,7 @@ def async_client(server_url: str) -> AsyncLetta:
@pytest.fixture
def roll_dice_tool(client: Letta) -> Tool:
"""
Registers a simple roll dice tool with the provided client.
The tool simulates rolling a six-sided die but returns a fixed result.
"""
def roll_dice() -> str:
"""
Simulates rolling a die.
Returns:
str: The roll result.
"""
# Note: The result here is intentionally incorrect for demonstration purposes.
return "Rolled a 10!"
tool = client.tools.upsert_from_function(func=roll_dice)
yield tool
@pytest.fixture
def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
def agent_state(client: Letta) -> AgentState:
"""
Creates and returns an agent state for testing with a pre-configured agent.
The agent is named 'supervisor' and is configured with base tools.
@@ -91,7 +80,6 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
agent_state_instance = client.agents.create(
name="supervisor",
include_base_tools=True,
tool_ids=[roll_dice_tool.id],
model="openai/gpt-4o",
embedding="letta/letta-free",
tags=["supervisor"],
@@ -103,8 +91,8 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
# Helper Functions and Constants
# ------------------------------
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Roll the dice."}]
TESTED_MODELS: List[str] = ["openai/gpt-4o"]
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}]
TESTED_MODELS: List[str] = ["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20241022"]
def assert_tool_response_messages(messages: List[Any]) -> None:
@@ -114,10 +102,7 @@ def assert_tool_response_messages(messages: List[Any]) -> None:
ReasoningMessage -> AssistantMessage.
"""
assert isinstance(messages[0], ReasoningMessage)
assert isinstance(messages[1], ToolCallMessage)
assert isinstance(messages[2], ToolReturnMessage)
assert isinstance(messages[3], ReasoningMessage)
assert isinstance(messages[4], AssistantMessage)
assert isinstance(messages[1], AssistantMessage)
def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
@@ -130,16 +115,10 @@ def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
return [c for c in chunks if isinstance(c, msg_type)]
reasoning_msgs = msg_groups(ReasoningMessage)
tool_calls = msg_groups(ToolCallMessage)
tool_returns = msg_groups(ToolReturnMessage)
assistant_msgs = msg_groups(AssistantMessage)
usage_stats = msg_groups(LettaUsageStatistics)
assert len(reasoning_msgs) >= 1
assert len(tool_calls) == 1
assert len(tool_returns) == 1
assert len(reasoning_msgs) == 1
assert len(assistant_msgs) == 1
assert len(usage_stats) == 1
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
@@ -161,7 +140,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i
"""
start = time.time()
while True:
run = client.runs.retrieve_run(run_id)
run = client.runs.retrieve(run_id)
if run.status == "completed":
return run
if run.status == "failed":
@@ -184,13 +163,7 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
"""
assert isinstance(messages, list)
assert messages[0]["message_type"] == "reasoning_message"
assert messages[1]["message_type"] == "tool_call_message"
assert messages[2]["message_type"] == "tool_return_message"
assert messages[3]["message_type"] == "reasoning_message"
assert messages[4]["message_type"] == "assistant_message"
tool_return = messages[2]
assert tool_return["status"] == "success"
assert messages[1]["message_type"] == "assistant_message"
# ------------------------------