import asyncio
import base64
import json
import os
import threading
import time
import uuid
from contextlib import contextmanager
from http.server import BaseHTTPRequestHandler, HTTPServer
from typing import Any, Dict, List
from unittest.mock import patch

import httpx
import pytest
import requests
from dotenv import load_dotenv
from letta_client import AsyncLetta, Letta, LettaRequest, MessageCreate, Run
from letta_client.core.api_error import ApiError
from letta_client.types import (
    AssistantMessage,
    Base64Image,
    HiddenReasoningMessage,
    ImageContent,
    LettaMessageUnion,
    LettaStopReason,
    LettaUsageStatistics,
    ReasoningMessage,
    TextContent,
    ToolCallMessage,
    ToolReturnMessage,
    UrlImage,
    UserMessage,
)

from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType, JobStatus
from letta.schemas.letta_ping import LettaPing
from letta.schemas.llm_config import LLMConfig

logger = get_logger(__name__)


# ------------------------------
# Helper Functions and Constants
# ------------------------------


all_configs = [
    "openai-gpt-4o-mini.json",
    "openai-o3.json",
    "claude-3-5-sonnet.json",
    "claude-3-7-sonnet-extended.json",
    "gemini-2.5-flash.json",
]


def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    filename = os.path.join(llm_config_dir, filename)
    with open(filename, "r") as f:
        config_data = json.load(f)
    llm_config = LLMConfig(**config_data)
    return llm_config


TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in all_configs]
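

# Illustrative sketch only: each file under tests/configs/llm_model_configs is a JSON document
# whose keys map directly onto LLMConfig fields (the exact field set is defined by
# letta.schemas.llm_config.LLMConfig, not by this test). Assuming the common fields, a minimal
# config might look like:
#
#     {
#         "model": "gpt-4o-mini",
#         "model_endpoint_type": "openai",
#         "model_endpoint": "https://api.openai.com/v1",
#         "context_window": 128000
#     }
#
# get_llm_config simply reads the file and constructs LLMConfig(**config_data).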


def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.

    Args:
        num_sides (int): The number of sides on the die.

    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    import random

    return random.randint(1, num_sides)


USER_MESSAGE_OTID = str(uuid.uuid4())
USER_MESSAGE_RESPONSE: str = "Teamwork makes the dream work"
USER_MESSAGE_FORCE_REPLY: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content=f"This is an automated test message. Reply with the message '{USER_MESSAGE_RESPONSE}'.",
        otid=USER_MESSAGE_OTID,
    )
]
USER_MESSAGE_ROLL_DICE: List[MessageCreate] = [
    MessageCreate(
        role="user",
        content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
        otid=USER_MESSAGE_OTID,
    )
]


def assert_greeting_response(
    messages: List[Any],
    llm_config: LLMConfig,
    streaming: bool = False,
    token_streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected greeting sequence:
    an optional ReasoningMessage (or HiddenReasoningMessage for OpenAI reasoning models)
    followed by an AssistantMessage. A UserMessage is prepended when loaded from the
    database, and LettaStopReason and LettaUsageStatistics are appended when streaming.
    """
    # Filter out LettaPing messages, which are keep-alive messages for SSE streams
    messages = [
        msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
    ]

    expected_message_count = get_expected_message_count(llm_config, streaming=streaming, from_db=from_db)
    assert len(messages) == expected_message_count

    # User message if loaded from db
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1

    # Reasoning message if reasoning enabled
    otid_suffix = 0
    if LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config):
        if LLMConfig.is_openai_reasoning_model(llm_config):
            assert isinstance(messages[index], HiddenReasoningMessage)
        else:
            assert isinstance(messages[index], ReasoningMessage)

        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
        index += 1
        otid_suffix += 1

    # Assistant message
    assert isinstance(messages[index], AssistantMessage)
    if not token_streaming:
        assert "teamwork" in messages[index].content.lower()
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1
    otid_suffix += 1

    # Stop reason and usage statistics if streaming
    if streaming:
        assert isinstance(messages[index], LettaStopReason)
        assert messages[index].stop_reason == "end_turn"
        index += 1
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0
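

# For example (derived from the assertions above, not an additional contract): a streamed
# greeting against claude-3-7-sonnet-extended is expected to arrive as
#     [ReasoningMessage, AssistantMessage, LettaStopReason, LettaUsageStatistics]
# while the same exchange read back from the database is
#     [UserMessage, ReasoningMessage, AssistantMessage].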


def assert_tool_call_response(
    messages: List[Any],
    llm_config: LLMConfig,
    streaming: bool = False,
    from_db: bool = False,
) -> None:
    """
    Asserts that the messages list follows the expected tool-call sequence:
    ReasoningMessage -> ToolCallMessage -> ToolReturnMessage ->
    ReasoningMessage -> AssistantMessage.

    The reasoning messages only appear for reasoning models (the second one only for
    OpenAI reasoning models), and Anthropic models emit an extra AssistantMessage before
    the tool call. A UserMessage is prepended when loaded from the database, and
    LettaStopReason and LettaUsageStatistics are appended when streaming.
    """
    # Filter out LettaPing messages, which are keep-alive messages for SSE streams
    messages = [
        msg for msg in messages if not (isinstance(msg, LettaPing) or (hasattr(msg, "message_type") and msg.message_type == "ping"))
    ]

    expected_message_count = get_expected_message_count(llm_config, tool_call=True, streaming=streaming, from_db=from_db)
    assert len(messages) == expected_message_count

    # User message if loaded from db
    index = 0
    if from_db:
        assert isinstance(messages[index], UserMessage)
        assert messages[index].otid == USER_MESSAGE_OTID
        index += 1

    # Reasoning message if reasoning enabled
    otid_suffix = 0
    if LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config):
        if LLMConfig.is_openai_reasoning_model(llm_config):
            assert isinstance(messages[index], HiddenReasoningMessage)
        else:
            assert isinstance(messages[index], ReasoningMessage)
        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
        index += 1
        otid_suffix += 1

    # Assistant message (Anthropic models emit one before the tool call)
    if llm_config.model_endpoint_type == "anthropic":
        assert isinstance(messages[index], AssistantMessage)
        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
        index += 1
        otid_suffix += 1

    # Tool call message
    assert isinstance(messages[index], ToolCallMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1

    # Tool return message
    otid_suffix = 0
    assert isinstance(messages[index], ToolReturnMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1

    # Reasoning message if reasoning enabled for openai models
    otid_suffix = 0
    if LLMConfig.is_openai_reasoning_model(llm_config):
        assert isinstance(messages[index], HiddenReasoningMessage)
        assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
        index += 1
        otid_suffix += 1

    # Assistant message
    assert isinstance(messages[index], AssistantMessage)
    assert messages[index].otid and messages[index].otid[-1] == str(otid_suffix)
    index += 1
    otid_suffix += 1

    # Stop reason and usage statistics if streaming
    if streaming:
        assert isinstance(messages[index], LettaStopReason)
        assert messages[index].stop_reason == "end_turn"
        index += 1
        assert isinstance(messages[index], LettaUsageStatistics)
        assert messages[index].prompt_tokens > 0
        assert messages[index].completion_tokens > 0
        assert messages[index].total_tokens > 0
        assert messages[index].step_count > 0


async def accumulate_chunks(chunks: List[Any], verify_token_streaming: bool = False) -> List[Any]:
    """
    Accumulates streamed chunks into a list of messages by collapsing consecutive chunks
    that share a message_type into a single message. If verify_token_streaming is set,
    asserts that reasoning, assistant, and tool call messages each arrived in more than
    one chunk.
    """
    messages = []
    current_message = None
    prev_message_type = None
    chunk_count = 0
    async for chunk in chunks:
        current_message_type = chunk.message_type
        if prev_message_type != current_message_type:
            messages.append(current_message)
            if (
                prev_message_type
                and verify_token_streaming
                and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]
            ):
                assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}. Messages: {messages}"
            current_message = None
            chunk_count = 0
        if current_message is None:
            current_message = chunk
        else:
            pass  # TODO: actually accumulate the chunks. For now we only care about the count
        prev_message_type = current_message_type
        chunk_count += 1
    messages.append(current_message)
    if verify_token_streaming and current_message.message_type in ["reasoning_message", "assistant_message", "tool_call_message"]:
        assert chunk_count > 1, f"Expected more than one chunk for {current_message.message_type}"

    return [m for m in messages if m is not None]
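

# Illustration (not an additional assertion): a token stream that yields two
# reasoning_message chunks, two assistant_message chunks, then one stop-reason and one
# usage-statistics chunk is collapsed by accumulate_chunks into four messages, one per run
# of equal message_type; with verify_token_streaming=True it also checks that the reasoning
# and assistant runs each spanned more than one chunk.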


async def wait_for_run_completion(client: AsyncLetta, run_id: str, timeout: float = 30.0, interval: float = 0.5) -> Run:
    start = time.time()
    while True:
        run = await client.runs.retrieve(run_id)
        if run.status == "completed":
            return run
        if run.status == "failed":
            print(run)
            raise RuntimeError(f"Run {run_id} did not complete: status = {run.status}")
        if time.time() - start > timeout:
            raise TimeoutError(f"Run {run_id} did not complete within {timeout} seconds (last status: {run.status})")
        # Sleep without blocking the event loop between polls
        await asyncio.sleep(interval)


def get_expected_message_count(llm_config: LLMConfig, tool_call: bool = False, streaming: bool = False, from_db: bool = False) -> int:
    """
    Returns the expected number of messages for a given LLM configuration.

    Greeting:
    ---------------------------------------------------------------------------------------------------------------------------------------
    | gpt-4o                   | gpt-o3                   | sonnet-3-5               | sonnet-3.7-thinking      | flash-2.5-thinking       |
    | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
    | AssistantMessage         | HiddenReasoningMessage   | AssistantMessage         | ReasoningMessage         | AssistantMessage         |
    |                          | AssistantMessage         |                          | AssistantMessage         |                          |

    Tool Call:
    ---------------------------------------------------------------------------------------------------------------------------------------
    | gpt-4o                   | gpt-o3                   | sonnet-3-5               | sonnet-3.7-thinking      | flash-2.5-thinking       |
    | ------------------------ | ------------------------ | ------------------------ | ------------------------ | ------------------------ |
    | ToolCallMessage          | HiddenReasoningMessage   | AssistantMessage         | ReasoningMessage         | ToolCallMessage          |
    | ToolReturnMessage        | ToolCallMessage          | ToolCallMessage          | AssistantMessage         | ToolReturnMessage        |
    | AssistantMessage         | ToolReturnMessage        | ToolReturnMessage        | ToolCallMessage          | AssistantMessage         |
    |                          | HiddenReasoningMessage   | AssistantMessage         | ToolReturnMessage        |                          |
    |                          | AssistantMessage         |                          | AssistantMessage         |                          |
    """
    is_reasoner_model = LLMConfig.is_openai_reasoning_model(llm_config) or LLMConfig.is_anthropic_reasoning_model(llm_config)

    # assistant message
    expected_message_count = 1

    if is_reasoner_model:
        # reasoning message
        expected_message_count += 1

    if tool_call:
        # tool call and tool return messages
        expected_message_count += 2
        if llm_config.model_endpoint_type == "anthropic":
            # anthropic models return an assistant message first before the tool call message
            expected_message_count += 1
        if LLMConfig.is_openai_reasoning_model(llm_config):
            # openai reasoning models return an additional reasoning message before final assistant message
            expected_message_count += 1

    if from_db:
        # user message
        expected_message_count += 1

    if streaming:
        # stop reason and usage statistics
        expected_message_count += 2

    return expected_message_count
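

# Worked example (follows directly from the arithmetic above, stated here for reference):
# for gpt-o3, a tool-call exchange counts 1 assistant + 1 reasoning + 2 tool call/return
# + 1 extra reasoning = 5 messages; reading it back from the database adds the user
# message (6), while streaming it instead appends a stop reason and usage statistics (7).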


# ------------------------------
# Fixtures
# ------------------------------


@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provides the URL for the Letta server.
    If LETTA_SERVER_URL is not set, starts the server in a background thread
    and polls until it's accepting connections.
    """

    def _run_server() -> None:
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")

    if not os.getenv("LETTA_SERVER_URL"):
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        # Poll until the server is up (or timeout)
        timeout_seconds = 30
        deadline = time.time() + timeout_seconds
        while time.time() < deadline:
            try:
                resp = requests.get(url + "/v1/health")
                if resp.status_code < 500:
                    break
            except requests.exceptions.RequestException:
                pass
            time.sleep(0.1)
        else:
            raise RuntimeError(f"Could not reach {url} within {timeout_seconds}s")

    return url
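

# Usage note (an assumption based on the fixture above, not on any CLI documented here): to run
# these tests against an already-running server instead of the in-process one, point them at it
# via the environment, e.g. `LETTA_SERVER_URL=http://localhost:8283 pytest <this file>`.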


@pytest.fixture(scope="function")
async def client(server_url: str) -> AsyncLetta:
    """
    Creates and returns an asynchronous Letta REST client for testing.
    """
    client_instance = AsyncLetta(base_url=server_url)
    yield client_instance


@pytest.fixture(scope="function")
async def agent_state(client: AsyncLetta) -> AgentState:
    """
    Creates and returns an agent state for testing with a pre-configured agent.
    The agent is named 'test_agent' and is configured with the roll_dice tool only
    (base tools are excluded).
    """
    dice_tool = await client.tools.upsert_from_function(func=roll_dice)

    agent_state_instance = await client.agents.create(
        agent_type=AgentType.letta_v1_agent,
        name="test_agent",
        include_base_tools=False,
        tool_ids=[dice_tool.id],
        model="openai/gpt-4o",
        embedding="openai/text-embedding-3-small",
        tags=["test"],
    )
    yield agent_state_instance

    await client.agents.delete(agent_state_instance.id)


# ------------------------------
# Test Cases
# ------------------------------


@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_greeting(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
    send_type: str,
) -> None:
    last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    if send_type == "step":
        response = await client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
        )
        messages = response.messages
        run_id = messages[0].run_id
    elif send_type == "async":
        run = await client.agents.messages.create_async(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
        )
        run = await wait_for_run_completion(client, run.id)
        messages = await client.runs.messages.list(run_id=run.id)
        messages = [m for m in messages if m.message_type != "user_message"]
        run_id = run.id
    else:
        response = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_FORCE_REPLY,
            stream_tokens=(send_type == "stream_tokens"),
            background=(send_type == "stream_tokens_background"),
        )
        messages = await accumulate_chunks(response)
        run_id = messages[0].run_id

    assert_greeting_response(
        messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
    )

    if "background" in send_type:
        # Re-stream the background run from the beginning and verify the same sequence
        response = client.runs.stream(run_id=run_id, starting_after=0)
        messages = await accumulate_chunks(response)
        assert_greeting_response(
            messages, streaming=("stream" in send_type), token_streaming=(send_type == "stream_tokens"), llm_config=llm_config
        )

    messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_greeting_response(messages_from_db, from_db=True, llm_config=llm_config)

    assert run_id is not None
    run = await client.runs.retrieve(run_id=run_id)
    assert run.status == JobStatus.completed


@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.parametrize("send_type", ["step", "stream_steps", "stream_tokens", "stream_tokens_background", "async"])
@pytest.mark.asyncio(loop_scope="function")
async def test_tool_call(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
    send_type: str,
) -> None:
    last_message = await client.agents.messages.list(agent_id=agent_state.id, limit=1)
    agent_state = await client.agents.modify(agent_id=agent_state.id, llm_config=llm_config)

    if send_type == "step":
        response = await client.agents.messages.create(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
        )
        messages = response.messages
        run_id = messages[0].run_id
    elif send_type == "async":
        run = await client.agents.messages.create_async(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
        )
        run = await wait_for_run_completion(client, run.id)
        messages = await client.runs.messages.list(run_id=run.id)
        messages = [m for m in messages if m.message_type != "user_message"]
        run_id = run.id
    else:
        response = client.agents.messages.create_stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
            stream_tokens=(send_type == "stream_tokens"),
            background=(send_type == "stream_tokens_background"),
        )
        messages = await accumulate_chunks(response)
        run_id = messages[0].run_id

    assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)

    if "background" in send_type:
        response = client.runs.stream(run_id=run_id, starting_after=0)
        messages = await accumulate_chunks(response)
        assert_tool_call_response(messages, streaming=("stream" in send_type), llm_config=llm_config)

    messages_from_db = await client.agents.messages.list(agent_id=agent_state.id, after=last_message[0].id)
    assert_tool_call_response(messages_from_db, from_db=True, llm_config=llm_config)

    assert run_id is not None
    run = await client.runs.retrieve(run_id=run_id)
    assert run.status == JobStatus.completed