Files
letta-server/tests/managers/test_run_manager.py
Matthew Zhou 126e51d584 feat: Support returning multiple tool returns [LET-4678] (#5202)
* Support returning multiple tool returns

* Fix test run managers
2025-10-07 17:50:50 -07:00

1149 lines
45 KiB
Python

import json
import logging
import os
import random
import re
import string
import time
import uuid
from datetime import datetime, timedelta, timezone
from typing import List
from unittest.mock import AsyncMock, Mock, patch
import pytest
from _pytest.python_api import approx
from anthropic.types.beta import BetaMessage
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
# Import shared fixtures and constants from conftest
from conftest import (
CREATE_DELAY_SQLITE,
DEFAULT_EMBEDDING_CONFIG,
USING_SQLITE,
)
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall, Function as OpenAIFunction
from sqlalchemy import func, select
from sqlalchemy.exc import IntegrityError, InvalidRequestError
from sqlalchemy.orm.exc import StaleDataError
from letta.config import LettaConfig
from letta.constants import (
BASE_MEMORY_TOOLS,
BASE_SLEEPTIME_TOOLS,
BASE_TOOLS,
BASE_VOICE_SLEEPTIME_CHAT_TOOLS,
BASE_VOICE_SLEEPTIME_TOOLS,
BUILTIN_TOOLS,
DEFAULT_ORG_ID,
DEFAULT_ORG_NAME,
FILES_TOOLS,
LETTA_TOOL_EXECUTION_DIR,
LETTA_TOOL_SET,
LOCAL_ONLY_MULTI_AGENT_TOOLS,
MCP_TOOL_TAG_NAME_PREFIX,
MULTI_AGENT_TOOLS,
)
from letta.data_sources.redis_client import NoopAsyncRedisClient, get_redis_client
from letta.errors import LettaAgentNotFoundError
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.functions.mcp_client.types import MCPTool
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import AsyncTimer
from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo
from letta.orm import Base, Block
from letta.orm.block_history import BlockHistory
from letta.orm.errors import NoResultFound, UniqueConstraintViolationError
from letta.orm.file import FileContent as FileContentModel, FileMetadata as FileMetadataModel
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock, BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import (
ActorType,
AgentStepStatus,
FileProcessingStatus,
JobStatus,
JobType,
MessageRole,
ProviderType,
RunStatus,
SandboxType,
StepStatus,
TagMatchMode,
ToolType,
VectorDBProvider,
)
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate, SandboxEnvironmentVariableUpdate
from letta.schemas.file import FileMetadata, FileMetadata as PydanticFileMetadata
from letta.schemas.identity import IdentityCreate, IdentityProperty, IdentityPropertyType, IdentityType, IdentityUpdate, IdentityUpsert
from letta.schemas.job import Job as PydanticJob, LettaRequestConfig
from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, Message as PydanticMessage, MessageCreate, MessageUpdate, ToolReturn
from letta.schemas.openai.chat_completion_response import UsageStatistics
from letta.schemas.organization import Organization, Organization as PydanticOrganization, OrganizationUpdate
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.pip_requirement import PipRequirement
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.sandbox_config import E2BSandboxConfig, LocalSandboxConfig, SandboxConfigCreate, SandboxConfigUpdate
from letta.schemas.source import Source as PydanticSource, SourceUpdate
from letta.schemas.tool import Tool as PydanticTool, ToolCreate, ToolUpdate
from letta.schemas.tool_rule import InitToolRule
from letta.schemas.user import User as PydanticUser, UserUpdate
from letta.server.db import db_registry
from letta.server.server import SyncServer
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import calculate_base_tools, calculate_multi_agent_tools, validate_agent_exists_async
from letta.services.step_manager import FeedbackType
from letta.settings import settings, tool_settings
from letta.utils import calculate_file_defaults_based_on_context_window
from tests.helpers.utils import comprehensive_agent_checks, validate_context_window_overview
from tests.utils import random_string
# ======================================================================================================================
# RunManager Tests
# ======================================================================================================================
@pytest.mark.asyncio
async def test_create_run(server: SyncServer, sarah_agent, default_user):
    """Create a run and verify the fields reported on the returned object.

    The run is created for ``sarah_agent`` with a small metadata payload. The
    returned run must reference that agent, have a creation timestamp, be in
    the ``created`` status and carry the same metadata.
    """
    payload = PydanticRun(agent_id=sarah_agent.id, metadata={"type": "test"})

    result = await server.run_manager.create_run(pydantic_run=payload, actor=default_user)

    assert result.agent_id == sarah_agent.id
    assert result.created_at
    assert result.status == RunStatus.created
    assert result.metadata == {"type": "test"}
@pytest.mark.asyncio
async def test_get_run_by_id(server: SyncServer, sarah_agent, default_user):
    """Fetch a freshly created run by its ID and compare it with what was stored."""
    run_manager = server.run_manager
    stored = await run_manager.create_run(
        pydantic_run=PydanticRun(agent_id=sarah_agent.id, metadata={"type": "test"}),
        actor=default_user,
    )

    looked_up = await run_manager.get_run_by_id(stored.id, actor=default_user)

    # The fetched run must be the one we created, unchanged.
    assert looked_up.id == stored.id
    assert looked_up.status == RunStatus.created
    assert looked_up.metadata == {"type": "test"}
@pytest.mark.asyncio
async def test_list_runs(server: SyncServer, sarah_agent, default_user):
    """Create three runs for one agent and check that listing returns all of them."""
    run_manager = server.run_manager
    for index in range(3):
        await run_manager.create_run(
            pydantic_run=PydanticRun(agent_id=sarah_agent.id, metadata={"type": f"test-{index}"}),
            actor=default_user,
        )

    listed = await run_manager.list_runs(actor=default_user)

    assert len(listed) == 3
    # Every listed run belongs to the agent ...
    assert all(entry.agent_id == sarah_agent.id for entry in listed)
    # ... and carries one of the "test-N" metadata labels written above.
    assert all(entry.metadata["type"].startswith("test") for entry in listed)
@pytest.mark.asyncio
async def test_list_runs_with_metadata(server: SyncServer, sarah_agent, default_user):
    """Filter listed runs by status after completing one of three runs.

    Three runs are created and only the second one is updated to
    ``RunStatus.completed``. Listing with ``statuses=[RunStatus.completed]``
    must yield exactly one run in that status, while an unfiltered listing
    still yields all three.
    """
    run_manager = server.run_manager
    for position in range(3):
        fresh = await run_manager.create_run(pydantic_run=PydanticRun(agent_id=sarah_agent.id), actor=default_user)
        if position != 1:
            continue
        await run_manager.update_run_by_id_async(fresh.id, RunUpdate(status=RunStatus.completed), actor=default_user)

    completed_only = await run_manager.list_runs(actor=default_user, statuses=[RunStatus.completed])
    assert len(completed_only) == 1
    assert completed_only[0].status == RunStatus.completed

    everything = await run_manager.list_runs(actor=default_user)
    assert len(everything) == 3
@pytest.mark.asyncio
async def test_update_run_by_id(server: SyncServer, sarah_agent, default_user):
    """Update a run's status through its ID and check the returned run reflects it."""
    run_manager = server.run_manager
    original = await run_manager.create_run(
        pydantic_run=PydanticRun(agent_id=sarah_agent.id, metadata={"type": "test"}),
        actor=default_user,
    )

    patch_request = RunUpdate(status=RunStatus.completed)
    after_update = await run_manager.update_run_by_id_async(original.id, patch_request, actor=default_user)

    assert after_update.status == RunStatus.completed
@pytest.mark.asyncio
async def test_delete_run_by_id(server: SyncServer, sarah_agent, default_user):
    """Delete a run by ID and confirm it can no longer be fetched or listed."""
    run_manager = server.run_manager
    doomed = await run_manager.create_run(
        pydantic_run=PydanticRun(agent_id=sarah_agent.id, metadata={"type": "test"}),
        actor=default_user,
    )
    print("created_run to delete", doomed.id)

    await run_manager.delete_run(doomed.id, actor=default_user)

    # A lookup of the deleted run must raise ...
    with pytest.raises(NoResultFound):
        await run_manager.get_run_by_id(doomed.id, actor=default_user)

    # ... and the listing for this actor must now be empty.
    remaining = await run_manager.list_runs(actor=default_user)
    assert len(remaining) == 0
@pytest.mark.asyncio
async def test_update_run_auto_complete(server: SyncServer, default_user, sarah_agent):
    """Moving a run to ``completed`` must populate ``completed_at`` with a datetime."""
    run_manager = server.run_manager
    fresh = await run_manager.create_run(
        pydantic_run=PydanticRun(agent_id=sarah_agent.id, metadata={"type": "test"}),
        actor=default_user,
    )
    # A newly created run has no completion timestamp yet.
    assert fresh.completed_at is None

    finished = await run_manager.update_run_by_id_async(fresh.id, RunUpdate(status=RunStatus.completed), actor=default_user)

    # The update did not pass completed_at explicitly, yet it must be set.
    assert finished.completed_at is not None
    assert isinstance(finished.completed_at, datetime)
@pytest.mark.asyncio
async def test_get_run_not_found(server: SyncServer, default_user):
    """Fetching an ID that was never created must raise ``NoResultFound``."""
    missing_id = "nonexistent-id"

    with pytest.raises(NoResultFound):
        await server.run_manager.get_run_by_id(missing_id, actor=default_user)
@pytest.mark.asyncio
async def test_delete_run_not_found(server: SyncServer, default_user):
    """Deleting an ID that was never created must raise ``NoResultFound``."""
    missing_id = "nonexistent-id"

    with pytest.raises(NoResultFound):
        await server.run_manager.delete_run(missing_id, actor=default_user)
@pytest.mark.asyncio
async def test_list_runs_pagination(server: SyncServer, sarah_agent, default_user):
    """Exercise ``limit``, ``before``/``after`` cursors and ordering of ``list_runs``.

    Ten runs are created (referred to below as R0..R9 in creation order). The
    test then checks a plain limit, the oldest and newest pages of three, the
    window between those two pages in both directions, and everything older
    than the last entry of the newest page.
    """
    run_manager = server.run_manager
    for _ in range(10):
        await run_manager.create_run(pydantic_run=PydanticRun(agent_id=sarah_agent.id), actor=default_user)

    # A bare limit caps the page size.
    limited = await run_manager.list_runs(actor=default_user, limit=5)
    assert len(limited) == 5
    assert all(entry.agent_id == sarah_agent.id for entry in limited)

    # Oldest three, ascending: [R0, R1, R2]
    first_page = await run_manager.list_runs(actor=default_user, limit=3, ascending=True)
    assert len(first_page) == 3
    assert first_page[0].created_at <= first_page[1].created_at <= first_page[2].created_at

    # Newest three, descending: [R9, R8, R7]
    last_page = await run_manager.list_runs(actor=default_user, limit=3, ascending=False)
    assert len(last_page) == 3
    assert last_page[0].created_at >= last_page[1].created_at >= last_page[2].created_at

    first_page_ids = {entry.id for entry in first_page}
    last_page_ids = {entry.id for entry in last_page}
    assert first_page_ids.isdisjoint(last_page_ids)

    # Window strictly between the two pages, ascending: [R3, R4, R5, R6]
    middle_page = await run_manager.list_runs(
        actor=default_user,
        before=last_page[-1].id,
        after=first_page[-1].id,
        ascending=True,
    )
    assert len(middle_page) == 4
    seen_at_edges = first_page_ids | last_page_ids
    assert all(entry.id not in seen_at_edges for entry in middle_page)

    # NOTE: for descending order the roles of the two cursors are swapped
    # relative to the ascending call above: [R6, R5, R4, R3]
    middle_page_desc = await run_manager.list_runs(
        actor=default_user,
        before=first_page[-1].id,
        after=last_page[-1].id,
        ascending=False,
    )
    assert len(middle_page_desc) == 4
    # The descending window is the ascending window reversed.
    assert [entry.id for entry in middle_page_desc] == [entry.id for entry in reversed(middle_page)]

    # Everything before R7, ascending: [R0 .. R6]
    run_7 = last_page[-1].id
    earliest_runs = await run_manager.list_runs(actor=default_user, ascending=True, before=run_7)
    assert len(earliest_runs) == 7
    assert all(entry.id not in last_page_ids for entry in earliest_runs)
    assert all(earlier.created_at <= later.created_at for earlier, later in zip(earliest_runs, earliest_runs[1:]))
@pytest.mark.asyncio
async def test_list_runs_by_status(server: SyncServer, default_user, sarah_agent):
    """Create one run per status and check each status filter returns its own run."""
    run_manager = server.run_manager

    # (status to create the run with, metadata label used to recognise it later)
    specs = [
        (RunStatus.created, "test-created"),
        (RunStatus.running, "test-running"),
        (RunStatus.completed, "test-completed"),
    ]
    requested = [PydanticRun(status=status, metadata={"type": label}, agent_id=sarah_agent.id) for status, label in specs]

    for run_in in requested:
        await run_manager.create_run(pydantic_run=run_in, actor=default_user)

    # Query every status first, then verify, so no assertion runs between the reads.
    listings = [await run_manager.list_runs(actor=default_user, statuses=[status]) for status, _ in specs]

    for matches, run_in in zip(listings, requested):
        assert len(matches) == 1
        assert matches[0].metadata["type"] == run_in.metadata["type"]
@pytest.mark.asyncio
async def test_list_runs_by_stop_reason(server: SyncServer, sarah_agent, default_user):
    """List runs using the ``stop_reason``, ``background`` and ``agent_ids`` filters.

    A single background run with ``StopReasonType.requires_approval`` is
    created; each of the three filters, applied on its own, must return
    exactly that run.
    """
    run_pydantic = PydanticRun(
        agent_id=sarah_agent.id,
        stop_reason=StopReasonType.requires_approval,
        background=True,
    )
    run = await server.run_manager.create_run(pydantic_run=run_pydantic, actor=default_user)

    # The created run echoes the values it was created with.
    assert run.stop_reason == StopReasonType.requires_approval
    assert run.background is True
    assert run.agent_id == sarah_agent.id

    # Each filter is applied independently and must select only our run.
    filter_sets = (
        {"stop_reason": StopReasonType.requires_approval},
        {"background": True},
        {"agent_ids": [sarah_agent.id]},
    )
    for filters in filter_sets:
        runs = await server.run_manager.list_runs(actor=default_user, **filters)
        assert len(runs) == 1
        assert runs[0].id == run.id
@pytest.mark.asyncio
async def test_e2e_run_callback(monkeypatch, server: SyncServer, default_user, sarah_agent):
    """Completing a run that has a ``callback_url`` must dispatch the callback.

    ``AsyncClient`` in ``letta.services.run_manager`` is replaced with a stub
    that records the POSTed URL and JSON body and answers with status 202.
    The test then checks the recorded payload and the callback bookkeeping
    fields on the updated run.
    """
    captured = {}

    class MockAsyncResponse:
        status_code = 202

    class MockAsyncClient:
        """Stand-in for the HTTP client: usable as an async context manager."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, *args):
            pass

        async def post(self, url, json, timeout):
            # Record what the run manager tried to send instead of doing real I/O.
            captured["url"] = url
            captured["json"] = json
            return MockAsyncResponse()

    # Patch the AsyncClient name looked up inside the run manager module.
    import letta.services.run_manager as run_manager_module

    monkeypatch.setattr(run_manager_module, "AsyncClient", MockAsyncClient)

    run_in = PydanticRun(
        status=RunStatus.created,
        metadata={"foo": "bar"},
        agent_id=sarah_agent.id,
        callback_url="http://example.test/webhook/runs",
    )
    created = await server.run_manager.create_run(pydantic_run=run_in, actor=default_user)
    assert created.callback_url == "http://example.test/webhook/runs"

    # Update the run status to completed, which should trigger the callback.
    updated = await server.run_manager.update_run_by_id_async(
        created.id,
        RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn),
        actor=default_user,
    )

    # Verify the callback was triggered with the correct parameters.
    assert captured["url"] == created.callback_url, "Callback URL doesn't match"
    assert captured["json"]["run_id"] == created.id, "Run ID in callback doesn't match"
    assert captured["json"]["status"] == RunStatus.completed.value, "Run status in callback doesn't match"

    # Strip any timezone from the timestamp in the callback payload before
    # subtracting it from updated.completed_at.
    actual_dt = datetime.fromisoformat(captured["json"]["completed_at"]).replace(tzinfo=None)
    assert abs((actual_dt - updated.completed_at).total_seconds()) < 1, "Timestamp difference is too large"

    # The run records when the callback was sent and what the endpoint answered.
    assert isinstance(updated.callback_sent_at, datetime)
    assert updated.callback_status_code == 202
@pytest.mark.asyncio
async def test_run_callback_only_on_terminal_status(server: SyncServer, sarah_agent, default_user, monkeypatch):
    """
    Regression: a non-terminal update (running) must NOT set completed_at or fire the callback,
    while a following terminal update (completed) must fire the callback exactly once.
    """
    # Every POST issued through the stub client is tallied here.
    captured = {"count": 0, "url": None, "json": None}

    class MockAsyncResponse:
        status_code = 202

    class MockAsyncClient:
        """Async-context-manager stub that records POSTs instead of sending them."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, *args):
            pass

        async def post(self, url, json, timeout):
            captured["count"] += 1
            captured["url"] = url
            captured["json"] = json
            return MockAsyncResponse()

    # Swap the AsyncClient used inside the run_manager module for the stub.
    import letta.services.run_manager as run_manager_module

    monkeypatch.setattr(run_manager_module, "AsyncClient", MockAsyncClient)

    run_manager = server.run_manager
    webhook = "http://example.test/webhook/runs"
    created = await run_manager.create_run(
        pydantic_run=PydanticRun(
            status=RunStatus.created,
            metadata={"foo": "bar"},
            agent_id=sarah_agent.id,
            callback_url=webhook,
        ),
        actor=default_user,
    )
    assert created.callback_url == "http://example.test/webhook/runs"

    # Step 1: move to "running" -- no completion timestamp, no callback.
    while_running = await run_manager.update_run_by_id_async(created.id, RunUpdate(status=RunStatus.running), actor=default_user)
    assert while_running.completed_at is None
    assert captured["count"] == 0

    # Step 2: move to "completed" -- exactly one callback with the expected payload.
    when_done = await run_manager.update_run_by_id_async(
        created.id,
        RunUpdate(status=RunStatus.completed, stop_reason=StopReasonType.end_turn),
        actor=default_user,
    )
    assert captured["count"] == 1
    assert captured["url"] == created.callback_url
    body = captured["json"]
    assert body["run_id"] == created.id
    assert body["status"] == RunStatus.completed.value

    # completed_at is set and within a second of the timestamp sent in the callback.
    assert when_done.completed_at is not None
    reported_at = datetime.fromisoformat(body["completed_at"]).replace(tzinfo=None)
    assert abs((reported_at - when_done.completed_at).total_seconds()) < 1
    assert isinstance(when_done.callback_sent_at, datetime)
    assert when_done.callback_status_code == 202
# ======================================================================================================================
# RunManager Tests - Messages
# ======================================================================================================================
@pytest.mark.asyncio
async def test_run_messages_pagination(server: SyncServer, default_run, default_user, sarah_agent):
    """Page through the messages of a run with ``limit``, ``before``/``after`` and both orderings.

    Five user messages (M0..M4 in creation order) are attached to a new run,
    one ``create_many_messages_async`` call per message.
    """
    run = await server.run_manager.create_run(
        pydantic_run=PydanticRun(agent_id=sarah_agent.id, status=RunStatus.created, metadata={"foo": "bar"}),
        actor=default_user,
    )
    assert run.status == RunStatus.created

    async def list_page(**paging):
        """List this run's messages for the default user with the given paging arguments."""
        return await server.message_manager.list_messages(run_id=run.id, actor=default_user, **paging)

    message_ids = []
    for n in range(5):
        stored = await server.message_manager.create_many_messages_async(
            [
                PydanticMessage(
                    agent_id=sarah_agent.id,
                    role=MessageRole.user,
                    content=[TextContent(text=f"Test message {n}")],
                    run_id=run.id,
                )
            ],
            actor=default_user,
        )
        message_ids.append(stored[0].id)

    # Limit only: the first two messages come back.
    messages = await list_page(limit=2)
    assert len(messages) == 2
    assert [m.id for m in messages] == message_ids[:2]

    # Explicit ascending page: [M0, M1]
    first_page = await list_page(limit=2, ascending=True)
    assert len(first_page) == 2
    assert first_page[0].id == message_ids[0]
    assert first_page[1].id == message_ids[1]
    assert first_page[0].created_at <= first_page[1].created_at

    # Descending page: [M4, M3]
    last_page = await list_page(limit=2, ascending=False)
    assert len(last_page) == 2
    assert last_page[0].id == message_ids[4]
    assert last_page[1].id == message_ids[3]
    assert last_page[0].created_at >= last_page[1].created_at

    first_page_ids = {m.id for m in first_page}
    last_page_ids = {m.id for m in last_page}
    assert first_page_ids.isdisjoint(last_page_ids)

    # Between M0 (exclusive) and M3 (exclusive), ascending: [M1, M2]
    middle_page = await list_page(before=last_page[-1].id, after=first_page[0].id, ascending=True)
    assert len(middle_page) == 2
    assert middle_page[0].id == message_ids[1]
    assert middle_page[1].id == message_ids[2]
    seen_at_edges = first_page_ids | last_page_ids
    # M2 appears on neither edge page, while M1 was already on the first page.
    assert middle_page[1].id not in seen_at_edges
    assert middle_page[0].id in first_page_ids

    # Same window, descending: [M2, M1]
    middle_page = await list_page(before=last_page[-1].id, after=first_page[0].id, ascending=False)
    assert len(middle_page) == 2
    assert middle_page[0].id == message_ids[2]
    assert middle_page[1].id == message_ids[1]

    # Everything before M3, newest first: [M2, M1, M0]
    msg_3 = last_page[-1].id
    earliest_msgs = await list_page(ascending=False, before=msg_3)
    assert len(earliest_msgs) == 3
    assert all(m.id not in last_page_ids for m in earliest_msgs)
    assert earliest_msgs[0].created_at > earliest_msgs[1].created_at > earliest_msgs[2].created_at

    # Everything before M3, oldest first: [M0, M1, M2]
    earliest_msgs_ascending = await list_page(ascending=True, before=msg_3)
    assert len(earliest_msgs_ascending) == 3
    assert all(m.id not in last_page_ids for m in earliest_msgs_ascending)
    assert earliest_msgs_ascending[0].created_at < earliest_msgs_ascending[1].created_at < earliest_msgs_ascending[2].created_at
@pytest.mark.asyncio
async def test_run_messages_ordering(server: SyncServer, default_run, default_user, sarah_agent):
    """Messages of a run are listed by ``created_at``, ascending by default.

    Three messages are stored with explicit timestamps two minutes, one minute
    and zero minutes before "now". The default listing must be strictly
    increasing in ``created_at``; ``ascending=False`` must be strictly decreasing.
    """
    base_time = datetime.now(timezone.utc)
    message_times = [
        base_time - timedelta(minutes=2),
        base_time - timedelta(minutes=1),
        base_time,
    ]

    run = await server.run_manager.create_run(pydantic_run=PydanticRun(agent_id=sarah_agent.id), actor=default_user)
    assert run.status == RunStatus.created

    # One insert per message, oldest timestamp first.
    for created_at in message_times:
        message = PydanticMessage(
            role=MessageRole.user,
            content=[TextContent(text="Test message")],
            agent_id=sarah_agent.id,
            created_at=created_at,
            run_id=run.id,
        )
        await server.message_manager.create_many_messages_async([message], actor=default_user)

    # Default listing: chronological order.
    returned_messages = await server.message_manager.list_messages(
        run_id=run.id,
        actor=default_user,
    )
    assert len(returned_messages) == 3
    assert returned_messages[0].created_at < returned_messages[1].created_at
    assert returned_messages[1].created_at < returned_messages[2].created_at

    # Explicit descending listing: reverse chronological order.
    returned_messages = await server.message_manager.list_messages(
        run_id=run.id,
        actor=default_user,
        ascending=False,
    )
    assert len(returned_messages) == 3
    assert returned_messages[0].created_at > returned_messages[1].created_at
    assert returned_messages[1].created_at > returned_messages[2].created_at
@pytest.mark.asyncio
async def test_job_messages_empty(server: SyncServer, default_run, default_user):
    """Listing messages for a run that has none must return an empty result."""
    found = await server.message_manager.list_messages(run_id=default_run.id, actor=default_user)

    assert len(found) == 0
@pytest.mark.asyncio
async def test_job_messages_filter(server: SyncServer, default_run, default_user, sarah_agent):
    """List a run's messages unfiltered, filtered by role, and capped by a limit."""
    message_manager = server.message_manager

    run = await server.run_manager.create_run(pydantic_run=PydanticRun(agent_id=sarah_agent.id), actor=default_user)
    assert run.status == RunStatus.created

    # Note: the messages below are attached to the ``default_run`` fixture,
    # not to the run created just above.
    target_run_id = default_run.id

    user_hello = PydanticMessage(
        role=MessageRole.user,
        content=[TextContent(text="Hello")],
        agent_id=sarah_agent.id,
        run_id=target_run_id,
    )
    assistant_reply = PydanticMessage(
        role=MessageRole.assistant,
        content=[TextContent(text="Hi there!")],
        agent_id=sarah_agent.id,
        run_id=target_run_id,
    )
    # Assistant message that additionally carries one tool call.
    tool_invocation = OpenAIToolCall(
        id="call_1",
        type="function",
        function=OpenAIFunction(name="test_tool", arguments='{"arg1": "value1"}'),
    )
    assistant_with_tool = PydanticMessage(
        role=MessageRole.assistant,
        content=[TextContent(text="Let me help you with that")],
        agent_id=sarah_agent.id,
        tool_calls=[tool_invocation],
        run_id=target_run_id,
    )
    await message_manager.create_many_messages_async([user_hello, assistant_reply, assistant_with_tool], actor=default_user)

    # No filter: all three messages.
    everything = await message_manager.list_messages(run_id=target_run_id, actor=default_user)
    assert len(everything) == 3

    # Role filter: only the single user message.
    from_user = await message_manager.list_messages(run_id=target_run_id, actor=default_user, roles=[MessageRole.user])
    assert len(from_user) == 1
    assert from_user[0].role == MessageRole.user

    # Limit: at most two messages.
    capped = await message_manager.list_messages(run_id=target_run_id, actor=default_user, limit=2)
    assert len(capped) == 2
@pytest.mark.asyncio
async def test_get_run_messages(server: SyncServer, default_user: PydanticUser, sarah_agent):
    """Convert a run's stored messages to letta messages with default conversion arguments.

    The run is created with a request config naming ``custom_tool`` and
    ``use_assistant_message=False``. Two tool-return messages and two
    assistant messages (each calling ``custom_tool``) are stored, then listed
    and passed to ``Message.to_letta_messages_from_list`` without extra
    arguments. The test expects six letta messages, including two tool-call
    messages and two reasoning messages.
    """
    run = await server.run_manager.create_run(
        pydantic_run=PydanticRun(
            agent_id=sarah_agent.id,
            status=RunStatus.created,
            request_config=LettaRequestConfig(
                use_assistant_message=False, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg"
            ),
        ),
        actor=default_user,
    )

    # Alternate tool-return (even i) and assistant-with-tool-call (odd i)
    # messages; each consecutive pair shares the call id "call_<i // 2>".
    messages = []
    for i in range(4):
        call_id = f"call_{i // 2}"
        if i % 2 == 0:
            # tool return message
            messages.append(
                PydanticMessage(
                    agent_id=sarah_agent.id,
                    role=MessageRole.tool,
                    content=[TextContent(text='{"status": "OK"}')],
                    tool_call_id=call_id,
                    tool_returns=[
                        ToolReturn(
                            tool_call_id=call_id,
                            status="success",
                            func_response='{"status": "OK", "message": "Tool executed successfully"}',
                        )
                    ],
                    run_id=run.id,
                )
            )
        else:
            # assistant message with tool call
            messages.append(
                PydanticMessage(
                    agent_id=sarah_agent.id,
                    role=MessageRole.assistant,
                    content=[TextContent(text=f"Test message {i}")],
                    tool_calls=[
                        {
                            "type": "function",
                            "id": call_id,
                            "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'},
                        }
                    ],
                    run_id=run.id,
                )
            )
    await server.message_manager.create_many_messages_async(messages, actor=default_user)

    # List the stored messages and convert them to letta messages.
    result = await server.message_manager.list_messages(run_id=run.id, actor=default_user)
    result = Message.to_letta_messages_from_list(result)

    # Four stored messages expand to six letta messages.
    assert len(result) == 6

    # Each assistant message contributes one tool-call and one reasoning message.
    tool_call_messages = [msg for msg in result if msg.message_type == "tool_call_message"]
    reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"]
    assert len(tool_call_messages) == 2
    assert len(reasoning_messages) == 2
    for msg in tool_call_messages:
        assert msg.tool_call is not None
        assert msg.tool_call.name == "custom_tool"
@pytest.mark.asyncio
async def test_get_run_messages_with_assistant_message(server: SyncServer, default_user: PydanticUser, sarah_agent):
    """Convert a run's messages when ``custom_tool`` is treated as the assistant-message tool.

    Same message fixture as the plain conversion test, but
    ``Message.to_letta_messages_from_list`` is called with
    ``assistant_message_tool_name="custom_tool"`` and
    ``assistant_message_tool_kwarg="custom_arg"``. The test expects four
    letta messages: two assistant messages whose content is the tool
    argument value, and two reasoning messages carrying the original text.
    """
    run = await server.run_manager.create_run(
        pydantic_run=PydanticRun(
            agent_id=sarah_agent.id,
            status=RunStatus.created,
            request_config=LettaRequestConfig(
                use_assistant_message=True, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg"
            ),
        ),
        actor=default_user,
    )

    # Alternate tool-return (even i) and assistant-with-tool-call (odd i)
    # messages; each consecutive pair shares the call id "call_<i // 2>".
    messages = []
    for i in range(4):
        call_id = f"call_{i // 2}"
        if i % 2 == 0:
            # tool return message
            messages.append(
                PydanticMessage(
                    agent_id=sarah_agent.id,
                    role=MessageRole.tool,
                    content=[TextContent(text='{"status": "OK"}')],
                    tool_call_id=call_id,
                    tool_returns=[
                        ToolReturn(
                            tool_call_id=call_id,
                            status="success",
                            func_response='{"status": "OK", "message": "Tool executed successfully"}',
                        )
                    ],
                    run_id=run.id,
                )
            )
        else:
            # assistant message with tool call
            messages.append(
                PydanticMessage(
                    agent_id=sarah_agent.id,
                    role=MessageRole.assistant,
                    content=[TextContent(text=f"Test message {i}")],
                    tool_calls=[
                        {
                            "type": "function",
                            "id": call_id,
                            "function": {"name": "custom_tool", "arguments": '{"custom_arg": "test"}'},
                        }
                    ],
                    run_id=run.id,
                )
            )
    await server.message_manager.create_many_messages_async(messages, actor=default_user)

    # List the stored messages and convert them, naming the assistant-message tool explicitly.
    result = await server.message_manager.list_messages(run_id=run.id, actor=default_user)
    result = Message.to_letta_messages_from_list(
        result, assistant_message_tool_name="custom_tool", assistant_message_tool_kwarg="custom_arg"
    )

    # Four letta messages are expected in total.
    assert len(result) == 4

    assistant_messages = [msg for msg in result if msg.message_type == "assistant_message"]
    reasoning_messages = [msg for msg in result if msg.message_type == "reasoning_message"]
    assert len(assistant_messages) == 2
    assert len(reasoning_messages) == 2
    # The assistant content is the value passed under "custom_arg" in the tool call.
    for msg in assistant_messages:
        assert msg.content == "test"
    # The original assistant text surfaces as reasoning.
    for msg in reasoning_messages:
        assert "Test message" in msg.reasoning
# ======================================================================================================================
# RunManager Tests - Usage Statistics
# ======================================================================================================================
@pytest.mark.asyncio
async def test_run_usage_stats_add_and_get(server: SyncServer, sarah_agent, default_run, default_user):
    """Log one step with token usage for a run and read the usage back.

    The usage reported by ``get_run_usage`` must equal the figures logged on
    the step, and the run must list exactly one step.
    """
    logged_usage = UsageStatistics(
        completion_tokens=100,
        prompt_tokens=50,
        total_tokens=150,
    )

    # Record a single step carrying the usage figures above.
    await server.step_manager.log_step_async(
        agent_id=sarah_agent.id,
        provider_name="openai",
        provider_category="base",
        model="gpt-4o-mini",
        model_endpoint="https://api.openai.com/v1",
        context_window_limit=8192,
        run_id=default_run.id,
        usage=logged_usage,
        actor=default_user,
        project_id=sarah_agent.project_id,
    )

    reported = await server.run_manager.get_run_usage(run_id=default_run.id, actor=default_user)
    assert reported.completion_tokens == 100
    assert reported.prompt_tokens == 50
    assert reported.total_tokens == 150

    # The step itself is listed under the run.
    recorded_steps = await server.step_manager.list_steps_async(run_id=default_run.id, actor=default_user)
    assert len(recorded_steps) == 1
@pytest.mark.asyncio
async def test_run_usage_stats_get_no_stats(server: SyncServer, default_run, default_user):
    """Usage for a run with nothing logged in this test is all zeros, with no steps."""
    run_id = default_run.id

    # No step is logged here, so every token counter is expected to be zero.
    usage = await server.run_manager.get_run_usage(run_id=run_id, actor=default_user)
    for counter in ("completion_tokens", "prompt_tokens", "total_tokens"):
        assert getattr(usage, counter) == 0

    # Likewise, the run should have no steps listed.
    run_steps = await server.step_manager.list_steps_async(run_id=run_id, actor=default_user)
    assert len(run_steps) == 0
@pytest.mark.asyncio
async def test_run_usage_stats_add_multiple(server: SyncServer, sarah_agent, default_run, default_user):
    """Log two steps on one run and check the aggregated usage and step listings.

    Verifies that:
    - ``get_run_usage`` reports the *sum* of both steps' token counts and a
      ``step_count`` of 2 (not just the most recently logged step),
    - both steps are listed when filtering by run and by agent,
    - after adding feedback to one step, ``has_feedback=True`` returns only
      that step while an unfiltered listing still returns both.
    """
    run_manager = server.run_manager
    step_manager = server.step_manager

    # The two steps differ only in their usage numbers.
    step_usages = [
        UsageStatistics(completion_tokens=100, prompt_tokens=50, total_tokens=150),
        UsageStatistics(completion_tokens=200, prompt_tokens=100, total_tokens=300),
    ]
    for usage in step_usages:
        await step_manager.log_step_async(
            agent_id=sarah_agent.id,
            provider_name="openai",
            provider_category="base",
            model="gpt-4o-mini",
            model_endpoint="https://api.openai.com/v1",
            context_window_limit=8192,
            usage=usage,
            actor=default_user,
            project_id=sarah_agent.project_id,
            run_id=default_run.id,
        )

    # Run-level usage is the total across both steps: 100+200, 50+100, 150+300.
    usage_stats = await run_manager.get_run_usage(run_id=default_run.id, actor=default_user)
    assert usage_stats.completion_tokens == 300
    assert usage_stats.prompt_tokens == 150
    assert usage_stats.total_tokens == 450
    assert usage_stats.step_count == 2

    # Both steps are visible when listing by run ...
    run_steps = await step_manager.list_steps_async(run_id=default_run.id, actor=default_user)
    assert len(run_steps) == 2

    # ... and when listing by agent.
    agent_steps = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user)
    assert len(agent_steps) == 2

    # Attach positive feedback to the first step of the agent listing.
    await step_manager.add_feedback_async(step_id=agent_steps[0].id, feedback=FeedbackType.POSITIVE, actor=default_user)

    # has_feedback=True should narrow the listing to the one step with feedback.
    steps_with_feedback = await step_manager.list_steps_async(agent_id=sarah_agent.id, has_feedback=True, actor=default_user)
    assert len(steps_with_feedback) == 1

    # With no has_feedback filter, both steps are still returned.
    all_agent_steps = await step_manager.list_steps_async(agent_id=sarah_agent.id, actor=default_user)
    assert len(all_agent_steps) == 2
@pytest.mark.asyncio
async def test_run_usage_stats_get_nonexistent_run(server: SyncServer, default_user):
    """Requesting usage for an unknown run ID must raise ``NoResultFound``."""
    missing_run_id = "nonexistent_run"
    with pytest.raises(NoResultFound):
        await server.run_manager.get_run_usage(run_id=missing_run_id, actor=default_user)
@pytest.mark.asyncio
async def test_get_run_request_config(server: SyncServer, sarah_agent, default_user):
    """Create a run with a request config and check it round-trips.

    Compares the three assistant-message settings on the retrieved config
    against the config the run was created with.
    """
    run_manager = server.run_manager
    expected_config = LettaRequestConfig(
        use_assistant_message=True,
        assistant_message_tool_name="send_message",
        assistant_message_tool_kwarg="message",
    )

    # Create the run with the config attached, then fetch the config by run ID.
    created_run = await run_manager.create_run(
        pydantic_run=PydanticRun(agent_id=sarah_agent.id, request_config=expected_config),
        actor=default_user,
    )
    actual_config = await run_manager.get_run_request_config(created_run.id, actor=default_user)

    assert actual_config is not None
    compared_fields = (
        "use_assistant_message",
        "assistant_message_tool_name",
        "assistant_message_tool_kwarg",
    )
    for field_name in compared_fields:
        assert getattr(actual_config, field_name) == getattr(expected_config, field_name)
@pytest.mark.asyncio
async def test_get_run_request_config_none(server: SyncServer, sarah_agent, default_user):
    """A run created without a request config yields ``None`` when the config is fetched."""
    run_manager = server.run_manager

    # Only the agent ID is supplied; no request_config is set on the run.
    bare_run = await run_manager.create_run(pydantic_run=PydanticRun(agent_id=sarah_agent.id), actor=default_user)

    fetched_config = await run_manager.get_run_request_config(bare_run.id, actor=default_user)
    assert fetched_config is None
@pytest.mark.asyncio
async def test_get_run_request_config_nonexistent_run(server: SyncServer, default_user):
    """Fetching the request config of an unknown run ID must raise ``NoResultFound``."""
    run_manager = server.run_manager
    unknown_run_id = "nonexistent_run"
    with pytest.raises(NoResultFound):
        await run_manager.get_run_request_config(unknown_run_id, actor=default_user)
# TODO: add back once metrics are added
# @pytest.mark.asyncio
# async def test_record_ttft(server: SyncServer, default_user):
# """Test recording time to first token for a job."""
# # Create a job
# job_data = PydanticJob(
# status=RunStatus.created,
# metadata={"type": "test_timing"},
# )
# created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
#
# # Record TTFT
# ttft_ns = 1_500_000_000 # 1.5 seconds in nanoseconds
# await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user)
#
# # Fetch the job and verify TTFT was recorded
# updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user)
# assert updated_job.ttft_ns == ttft_ns
#
#
# @pytest.mark.asyncio
# async def test_record_response_duration(server: SyncServer, default_user):
# """Test recording total response duration for a job."""
# # Create a job
# job_data = PydanticJob(
# status=RunStatus.created,
# metadata={"type": "test_timing"},
# )
# created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
#
# # Record response duration
# duration_ns = 5_000_000_000 # 5 seconds in nanoseconds
# await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user)
#
# # Fetch the job and verify duration was recorded
# updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user)
# assert updated_job.total_duration_ns == duration_ns
#
#
# @pytest.mark.asyncio
# async def test_record_timing_metrics_together(server: SyncServer, default_user):
# """Test recording both TTFT and response duration for a job."""
# # Create a job
# job_data = PydanticJob(
# status=RunStatus.created,
# metadata={"type": "test_timing_combined"},
# )
# created_job = await server.job_manager.create_job_async(pydantic_job=job_data, actor=default_user)
#
# # Record both metrics
# ttft_ns = 2_000_000_000 # 2 seconds in nanoseconds
# duration_ns = 8_500_000_000 # 8.5 seconds in nanoseconds
#
# await server.job_manager.record_ttft(created_job.id, ttft_ns, default_user)
# await server.job_manager.record_response_duration(created_job.id, duration_ns, default_user)
#
# # Fetch the job and verify both metrics were recorded
# updated_job = await server.job_manager.get_job_by_id_async(created_job.id, default_user)
# assert updated_job.ttft_ns == ttft_ns
# assert updated_job.total_duration_ns == duration_ns
#
#
# @pytest.mark.asyncio
# async def test_record_timing_invalid_job(server: SyncServer, default_user):
# """Test recording timing metrics for non-existent job fails gracefully."""
# # Try to record TTFT for non-existent job - should not raise exception but log warning
# await server.job_manager.record_ttft("nonexistent_job_id", 1_000_000_000, default_user)
#
# # Try to record response duration for non-existent job - should not raise exception but log warning
# await server.job_manager.record_response_duration("nonexistent_job_id", 2_000_000_000, default_user)
#