chore: Add send_message sdk tests (#1842)
This commit is contained in:
@@ -4,9 +4,10 @@ import time
|
||||
from typing import Any, Dict, List
|
||||
|
||||
import pytest
|
||||
import requests
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import AsyncLetta, Letta, Run, Tool
|
||||
from letta_client.types import AssistantMessage, LettaUsageStatistics, ReasoningMessage, ToolCallMessage, ToolReturnMessage
|
||||
from letta_client import AsyncLetta, Letta, Run
|
||||
from letta_client.types import AssistantMessage, ReasoningMessage
|
||||
|
||||
from letta.schemas.agent import AgentState
|
||||
|
||||
@@ -19,25 +20,35 @@ from letta.schemas.agent import AgentState
|
||||
def server_url() -> str:
|
||||
"""
|
||||
Provides the URL for the Letta server.
|
||||
If the environment variable 'LETTA_SERVER_URL' is not set, this fixture
|
||||
will start the Letta server in a background thread and return the default URL.
|
||||
If LETTA_SERVER_URL is not set, starts the server in a background thread
|
||||
and polls until it’s accepting connections.
|
||||
"""
|
||||
|
||||
def _run_server() -> None:
|
||||
"""Starts the Letta server in a background thread."""
|
||||
load_dotenv() # Load environment variables from .env file
|
||||
load_dotenv()
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
start_server(debug=True)
|
||||
|
||||
# Retrieve server URL from environment, or default to localhost
|
||||
url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
|
||||
|
||||
# If no environment variable is set, start the server in a background thread
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
thread = threading.Thread(target=_run_server, daemon=True)
|
||||
thread.start()
|
||||
time.sleep(5) # Allow time for the server to start
|
||||
|
||||
# Poll until the server is up (or timeout)
|
||||
timeout_seconds = 30
|
||||
deadline = time.time() + timeout_seconds
|
||||
while time.time() < deadline:
|
||||
try:
|
||||
resp = requests.get(url + "/v1/health")
|
||||
if resp.status_code < 500:
|
||||
break
|
||||
except requests.exceptions.RequestException:
|
||||
pass
|
||||
time.sleep(0.1)
|
||||
else:
|
||||
raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")
|
||||
|
||||
return url
|
||||
|
||||
@@ -61,29 +72,7 @@ def async_client(server_url: str) -> AsyncLetta:
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def roll_dice_tool(client: Letta) -> Tool:
|
||||
"""
|
||||
Registers a simple roll dice tool with the provided client.
|
||||
|
||||
The tool simulates rolling a six-sided die but returns a fixed result.
|
||||
"""
|
||||
|
||||
def roll_dice() -> str:
|
||||
"""
|
||||
Simulates rolling a die.
|
||||
|
||||
Returns:
|
||||
str: The roll result.
|
||||
"""
|
||||
# Note: The result here is intentionally incorrect for demonstration purposes.
|
||||
return "Rolled a 10!"
|
||||
|
||||
tool = client.tools.upsert_from_function(func=roll_dice)
|
||||
yield tool
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
def agent_state(client: Letta) -> AgentState:
|
||||
"""
|
||||
Creates and returns an agent state for testing with a pre-configured agent.
|
||||
The agent is named 'supervisor' and is configured with base tools and the roll_dice tool.
|
||||
@@ -91,7 +80,6 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
agent_state_instance = client.agents.create(
|
||||
name="supervisor",
|
||||
include_base_tools=True,
|
||||
tool_ids=[roll_dice_tool.id],
|
||||
model="openai/gpt-4o",
|
||||
embedding="letta/letta-free",
|
||||
tags=["supervisor"],
|
||||
@@ -103,8 +91,8 @@ def agent_state(client: Letta, roll_dice_tool: Tool) -> AgentState:
|
||||
# Helper Functions and Constants
|
||||
# ------------------------------
|
||||
|
||||
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Roll the dice."}]
|
||||
TESTED_MODELS: List[str] = ["openai/gpt-4o"]
|
||||
USER_MESSAGE: List[Dict[str, str]] = [{"role": "user", "content": "Hi there."}]
|
||||
TESTED_MODELS: List[str] = ["openai/gpt-4o", "anthropic/claude-3-5-sonnet-20241022"]
|
||||
|
||||
|
||||
def assert_tool_response_messages(messages: List[Any]) -> None:
|
||||
@@ -114,10 +102,7 @@ def assert_tool_response_messages(messages: List[Any]) -> None:
|
||||
ReasoningMessage -> AssistantMessage.
|
||||
"""
|
||||
assert isinstance(messages[0], ReasoningMessage)
|
||||
assert isinstance(messages[1], ToolCallMessage)
|
||||
assert isinstance(messages[2], ToolReturnMessage)
|
||||
assert isinstance(messages[3], ReasoningMessage)
|
||||
assert isinstance(messages[4], AssistantMessage)
|
||||
assert isinstance(messages[1], AssistantMessage)
|
||||
|
||||
|
||||
def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
|
||||
@@ -130,16 +115,10 @@ def assert_streaming_tool_response_messages(chunks: List[Any]) -> None:
|
||||
return [c for c in chunks if isinstance(c, msg_type)]
|
||||
|
||||
reasoning_msgs = msg_groups(ReasoningMessage)
|
||||
tool_calls = msg_groups(ToolCallMessage)
|
||||
tool_returns = msg_groups(ToolReturnMessage)
|
||||
assistant_msgs = msg_groups(AssistantMessage)
|
||||
usage_stats = msg_groups(LettaUsageStatistics)
|
||||
|
||||
assert len(reasoning_msgs) >= 1
|
||||
assert len(tool_calls) == 1
|
||||
assert len(tool_returns) == 1
|
||||
assert len(reasoning_msgs) == 1
|
||||
assert len(assistant_msgs) == 1
|
||||
assert len(usage_stats) == 1
|
||||
|
||||
|
||||
def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
|
||||
@@ -161,7 +140,7 @@ def wait_for_run_completion(client: Letta, run_id: str, timeout: float = 30.0, i
|
||||
"""
|
||||
start = time.time()
|
||||
while True:
|
||||
run = client.runs.retrieve_run(run_id)
|
||||
run = client.runs.retrieve(run_id)
|
||||
if run.status == "completed":
|
||||
return run
|
||||
if run.status == "failed":
|
||||
@@ -184,13 +163,7 @@ def assert_tool_response_dict_messages(messages: List[Dict[str, Any]]) -> None:
|
||||
"""
|
||||
assert isinstance(messages, list)
|
||||
assert messages[0]["message_type"] == "reasoning_message"
|
||||
assert messages[1]["message_type"] == "tool_call_message"
|
||||
assert messages[2]["message_type"] == "tool_return_message"
|
||||
assert messages[3]["message_type"] == "reasoning_message"
|
||||
assert messages[4]["message_type"] == "assistant_message"
|
||||
|
||||
tool_return = messages[2]
|
||||
assert tool_return["status"] == "success"
|
||||
assert messages[1]["message_type"] == "assistant_message"
|
||||
|
||||
|
||||
# ------------------------------
|
||||
|
||||
Reference in New Issue
Block a user