Files
letta-server/tests/integration_test_cancellation.py

201 lines
6.1 KiB
Python

import asyncio
import json
import os
import threading
from typing import Any, AsyncIterator, List

import pytest
from dotenv import load_dotenv
from letta_client import AsyncLetta
from letta_client.types import MessageCreateParam

from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import AgentType, JobStatus
from letta.schemas.llm_config import LLMConfig
# Module-scoped logger obtained from letta's logging helper, named after this module.
logger = get_logger(__name__)
def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model_configs") -> LLMConfig:
    """
    Load an LLMConfig from a JSON file.

    Args:
        filename: Name of the JSON config file inside ``llm_config_dir``.
        llm_config_dir: Directory holding the config files. The default is a
            relative path, so it is resolved against the current working
            directory (presumably the letta-server root when running tests).

    Returns:
        LLMConfig: Built by passing the file's top-level JSON object as keyword arguments.

    Raises:
        FileNotFoundError: If the config file does not exist.
        json.JSONDecodeError: If the file is not valid JSON.
    """
    config_path = os.path.join(llm_config_dir, filename)
    # JSON text is UTF-8 (RFC 8259); don't depend on the platform/locale default encoding.
    with open(config_path, "r", encoding="utf-8") as f:
        config_data = json.load(f)
    return LLMConfig(**config_data)
# LLM config file names (looked up via get_llm_config's default directory) that the
# cancellation test is parametrized over.
all_configs = [
    "openai-gpt-4o-mini.json",
    "openai-o3.json",
    "openai-gpt-5.json",
    "claude-4-5-sonnet.json",
    "claude-4-1-opus.json",
    "gemini-2.5-flash.json",
]
# Setting LLM_CONFIG_FILE restricts the run to that single config file; an unset
# or empty value falls back to every entry in all_configs.
requested = os.getenv("LLM_CONFIG_FILE")
filenames = [requested] if requested else all_configs
# Configs are loaded eagerly at import time, so a missing or malformed file
# raises while this module is being imported (i.e. during test collection).
TESTED_LLM_CONFIGS: List[LLMConfig] = [get_llm_config(fn) for fn in filenames]
# Tool function registered with the server via client.tools.upsert_from_function
# in the agent_state fixture.
# NOTE(review): the SDK presumably derives the tool schema from this function's
# signature/docstring (and may ship its source to the server), so the docstring
# text and the function-local `import random` are likely intentional — confirm
# before restructuring or rewording them.
def roll_dice(num_sides: int) -> int:
    """
    Returns a random number between 1 and num_sides.
    Args:
        num_sides (int): The number of sides on the die.
    Returns:
        int: A random integer between 1 and num_sides, representing the die roll.
    """
    import random
    return random.randint(1, num_sides)
# Single user turn asking the agent to call roll_dice with 16 sides and report the
# result; used as the input of the streaming cancellation test.
USER_MESSAGE_ROLL_DICE: List[MessageCreateParam] = [
    MessageCreateParam(
        role="user",
        content="This is an automated test message. Call the roll_dice tool with 16 sides and reply back to me with the outcome.",
    )
]
async def accumulate_chunks(chunks: Any) -> List[Any]:
    """
    Collapse an async stream of chunks into one entry per consecutive message type.

    Consecutive chunks sharing the same ``message_type`` are treated as one run,
    and only the LAST chunk of each run is kept. Chunk contents are not merged.
    An empty stream yields an empty list.

    Args:
        chunks: Async iterable of objects exposing a ``message_type`` attribute.

    Returns:
        List[Any]: The last chunk of every run, in stream order.
    """
    collapsed: List[Any] = []
    last_type = None
    async for chunk in chunks:
        kind = chunk.message_type
        if collapsed and kind == last_type:
            # Same run as the previous chunk: keep only the newest one.
            collapsed[-1] = chunk
        else:
            collapsed.append(chunk)
        last_type = kind
    return collapsed
async def cancel_run_after_delay(client: AsyncLetta, agent_id: str, delay: float = 0.5):
    """
    Sleep for ``delay`` seconds, then send a cancel request for ``agent_id``.

    Meant to be scheduled with ``asyncio.create_task`` next to a streaming
    request so the cancellation arrives mid-stream. It calls
    ``client.agents.messages.cancel(agent_id=...)``, which presumably cancels
    the agent's active run(s); confirm against the server API.

    Args:
        client: Async Letta client used to issue the cancel request.
        agent_id: ID of the agent to cancel.
        delay: Seconds to wait before cancelling.
    """
    await asyncio.sleep(delay)
    messages_api = client.agents.messages
    await messages_api.cancel(agent_id=agent_id)
@pytest.fixture(scope="module")
def server_url() -> str:
    """
    Provide the base URL of the Letta server used by this module's tests.

    If LETTA_SERVER_URL is set, that URL is returned as-is and no server is
    started. Otherwise a server is started in a background daemon thread. The
    fixture then polls ``<url>/v1/health`` on http://localhost:8283 until it
    answers with a non-5xx status, giving up after 30 seconds.

    Returns:
        str: Base URL of the server.

    Raises:
        TimeoutError: If the locally started server does not become ready in
            time. The last connection error, if any, is chained as the cause.
    """

    def _run_server() -> None:
        # .env is loaded before the app import, presumably so settings read at
        # import/startup time see those values.
        load_dotenv()
        from letta.server.rest_api.app import start_server

        start_server(debug=True)

    url: str = os.getenv("LETTA_SERVER_URL", "http://localhost:8283")
    if not os.getenv("LETTA_SERVER_URL"):
        if os.getenv("LETTA_REDIS_HOST"):
            print(f"Redis is configured at {os.getenv('LETTA_REDIS_HOST')}:{os.getenv('LETTA_REDIS_PORT', '6379')}")
        # Daemon thread: a still-running server does not block interpreter exit.
        thread = threading.Thread(target=_run_server, daemon=True)
        thread.start()

        import time

        import httpx

        timeout_seconds = 30
        # Monotonic clock so wall-clock adjustments cannot stretch or cut the wait.
        deadline = time.monotonic() + timeout_seconds
        last_error = None
        while time.monotonic() < deadline:
            try:
                response = httpx.get(url + "/v1/health", timeout=1.0, follow_redirects=True)
            except httpx.HTTPError as exc:
                # Typically connection refused / timeout while the server is still booting.
                last_error = exc
            else:
                # Any non-5xx answer means the server is accepting connections.
                if response.status_code < 500:
                    break
            time.sleep(0.5)
        else:
            raise TimeoutError(f"Server at {url} did not become ready in {timeout_seconds}s") from last_error
    return url
@pytest.fixture(scope="function")
async def client(server_url: str) -> AsyncIterator[AsyncLetta]:
    """
    Yield an asynchronous Letta REST client bound to ``server_url``.

    The client is closed on teardown so its HTTP connection pool does not leak
    across tests, each of which runs on its own function-scoped event loop.
    """
    client_instance = AsyncLetta(base_url=server_url)
    yield client_instance
    # NOTE(review): assumes the (Stainless-generated) AsyncLetta exposes an async
    # ``close()``; confirm against the installed letta_client version.
    await client_instance.close()
@pytest.fixture(scope="function")
async def agent_state(client: AsyncLetta) -> AgentState:
    """
    Yield a freshly created ``letta_v1_agent`` with the roll_dice tool attached.

    Base tools are excluded. The agent starts on ``openai/gpt-4o`` with
    ``openai/text-embedding-3-small`` embeddings and is tagged ``test``. Tests
    may switch its model afterwards. The agent is deleted on teardown.
    """
    # Upsert roll_dice on the server so its ID can be attached to the agent.
    roll_dice_tool = await client.tools.upsert_from_function(func=roll_dice)
    created_agent = await client.agents.create(
        agent_type=AgentType.letta_v1_agent,
        name="test_agent",
        include_base_tools=False,
        tool_ids=[roll_dice_tool.id],
        model="openai/gpt-4o",
        embedding="openai/text-embedding-3-small",
        tags=["test"],
    )
    yield created_agent
    # Teardown: remove the agent created above.
    await client.agents.delete(created_agent.id)
@pytest.mark.skipif(not os.getenv("LETTA_REDIS_HOST"), reason="Redis is required for background streaming (set LETTA_REDIS_HOST to enable)")
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
@pytest.mark.asyncio(loop_scope="function")
async def test_background_streaming_cancellation(
    disable_e2b_api_key: Any,
    client: AsyncLetta,
    agent_state: AgentState,
    llm_config: LLMConfig,
) -> None:
    """
    Cancel a background token stream mid-flight and verify the run ends up cancelled.

    Steps:
      1. Switch the agent to ``llm_config``.
      2. Schedule a cancel request ``delay`` seconds out, then start a background,
         token-streaming request asking the agent to roll a die.
      3. Check that the run's status is ``cancelled``. The run is found via the
         run_id on the first streamed chunk, or else by listing the agent's
         cancelled runs.
      4. Re-open the run's stream from the start and check it yields at least
         one chunk.
    """
    agent_state = await client.agents.update(agent_id=agent_state.id, llm_config=llm_config)

    delay = 1.5
    cancellation_task = asyncio.create_task(cancel_run_after_delay(client, agent_state.id, delay=delay))
    try:
        response = await client.agents.messages.stream(
            agent_id=agent_state.id,
            messages=USER_MESSAGE_ROLL_DICE,
            stream_tokens=True,
            background=True,
        )
        messages = await accumulate_chunks(response)
    finally:
        # Always await the cancel task so it is never left pending if the stream fails.
        await cancellation_task

    # The stream may end before any chunk arrives; in that case fall back to listing runs.
    run_id = getattr(messages[0], "run_id", None) if messages else None

    if run_id:
        run = await client.runs.retrieve(run_id=run_id)
        assert run.status == JobStatus.cancelled
    else:
        runs = await client.runs.list(agent_id=agent_state.id, stop_reason="cancelled", limit=1)
        # NOTE(review): list(runs) depends on how the page object itself iterates and
        # may not yield runs. runs.items is what gets indexed below, so count that.
        assert len(runs.items) == 1
        run_id = runs.items[0].id

    replay = await client.runs.messages.stream(run_id=run_id, starting_after=0)
    messages_from_stream = await accumulate_chunks(replay)
    assert len(messages_from_stream) > 0