Files
letta-server/tests/test_server.py
2025-11-13 15:36:00 -08:00

298 lines
12 KiB
Python

import json
import os
import shutil
import uuid
import warnings
from typing import List, Tuple
from unittest.mock import patch
import pytest
from sqlalchemy import delete
import letta.utils as utils
from letta.agents.agent_loop import AgentLoop
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
from letta.orm import Provider, Step
from letta.schemas.block import CreateBlock
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers import Provider as PydanticProvider, ProviderCreate
from letta.schemas.sandbox_config import SandboxType
from letta.schemas.user import User
# Turn on the DEBUG flag in letta.utils before the remaining letta imports below.
# NOTE(review): presumably set here so modules imported afterwards observe it - confirm.
utils.DEBUG = True
from letta.config import LettaConfig
from letta.orm.errors import NoResultFound
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import Job as PydanticJob
from letta.schemas.message import Message, MessageCreate
from letta.schemas.run import Run as PydanticRun
from letta.schemas.source import Source as PydanticSource
from letta.server.server import SyncServer
from letta.system import unpack_message
from .utils import DummyDataConnector
@pytest.fixture
async def server():
    """Yield an initialized ``SyncServer`` with the base tools upserted.

    The Letta config is loaded and saved back before the server is
    constructed with the default organization and user.
    """
    LettaConfig.load().save()

    sync_server = SyncServer(init_with_default_org_and_user=True)
    await sync_server.init_async()
    # Upsert the base tools on behalf of the server's default user.
    await sync_server.tool_manager.upsert_base_tools_async(actor=sync_server.default_user)
    yield sync_server
@pytest.fixture
async def org_id(server):
    """Yield the id of the default organization and delete it on teardown."""
    default_org = await server.organization_manager.create_default_organization_async()
    yield default_org.id
    # Teardown: remove the organization that was created above.
    await server.organization_manager.delete_organization_by_id_async(default_org.id)
@pytest.fixture
async def user(server, org_id):
    """Yield the default actor created under the organization from ``org_id``."""
    default_actor = await server.user_manager.create_default_actor_async(org_id=org_id)
    yield default_actor
@pytest.fixture
def user_id(server, user):
    """Yield the id of the user supplied by the ``user`` fixture."""
    # ``server`` is not referenced in the body; it is only requested as a fixture.
    actor_id = user.id
    yield actor_id
# Name given to the provider created by the ``custom_anthropic_provider`` fixture;
# the tests below also use it as the prefix of the model handle they request.
provider_name = "custom-anthropic29"
@pytest.fixture
async def custom_anthropic_provider(server: SyncServer, user_id: str):
    """Yield a freshly created Anthropic provider named ``provider_name``.

    Any provider that already carries that name is deleted before the new one
    is created. On teardown the provider is deleted again; ``NoResultFound``
    is ignored because a test may already have removed it.
    """
    actor = await server.user_manager.get_actor_or_default_async()
    providers = server.provider_manager

    # Remove providers with the same name before creating a new one.
    for existing in await providers.list_providers_async(actor=actor):
        if existing.name != provider_name:
            continue
        await providers.delete_provider_by_id_async(existing.id, actor=actor)

    request = ProviderCreate(
        name=provider_name,
        provider_type=ProviderType.anthropic,
        api_key=os.getenv("ANTHROPIC_API_KEY"),
    )
    created = await providers.create_provider_async(request, actor=actor)
    yield created

    # Best-effort cleanup: the test itself may have deleted the provider.
    try:
        await providers.delete_provider_by_id_async(created.id, actor=actor)
    except NoResultFound:
        pass
@pytest.fixture
async def agent(server: SyncServer, user: User):
    """Return a ``memgpt_v2_agent`` agent created on behalf of the default actor.

    The actor is passed to ``create_agent_async`` explicitly, matching the
    other agent-creation calls in this module.
    """
    actor = await server.user_manager.get_actor_or_default_async()
    return await server.create_agent_async(
        CreateAgent(agent_type="memgpt_v2_agent"),
        actor=actor,
    )
@pytest.mark.asyncio
async def test_messages_with_provider_override(server: SyncServer, custom_anthropic_provider: PydanticProvider, user):
    """Check step logging and token accounting for a ``letta_v1_agent`` on the custom provider.

    Creates an agent whose model handle is prefixed with ``provider_name``,
    sends one user message through ``AgentLoop``, and then verifies that every
    step referenced by the newly listed messages matches the agent's LLM
    config and that the summed step token counts equal the usage reported by
    the agent loop.
    """
    # Print the models listed under the custom provider (debugging aid).
    models = await server.list_llm_models_async(actor=user)
    for model in models:
        if model.provider_name == provider_name:
            print(model.model)

    actor = await server.user_manager.get_actor_or_default_async()
    agent = await server.create_agent_async(
        CreateAgent(
            agent_type="letta_v1_agent",
            memory_blocks=[],
            model=f"{provider_name}/claude-sonnet-4-5-20250929",
            context_window_limit=100000,
            embedding="openai/text-embedding-ada-002",
            include_base_tools=False,
        ),
        actor=actor,
    )

    # The last pre-existing message is the cursor for listing what the step adds.
    existing_messages = await server.message_manager.list_messages(agent_id=agent.id, actor=actor)
    assert existing_messages, "Expected the new agent to have messages to page after"

    # Send a message.
    run = await server.run_manager.create_run(
        pydantic_run=PydanticRun(
            agent_id=agent.id,
            background=False,
        ),
        actor=actor,
    )
    agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
    response = await agent_loop.step(
        input_messages=[MessageCreate(role=MessageRole.user, content="Test message")],
        run_id=run.id,
    )
    usage = response.usage

    get_messages_response = await server.message_manager.list_messages(
        agent_id=agent.id,
        actor=actor,
        after=existing_messages[-1].id,
    )

    # Several messages can share one step, so de-duplicate the step ids.
    step_ids = {msg.step_id for msg in get_messages_response}
    completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
    for step_id in step_ids:
        step = await server.step_manager.get_step_async(step_id=step_id, actor=actor)
        assert step, "Step was not logged correctly"
        # TODO: also assert step.provider_id == custom_anthropic_provider.id (check currently disabled).
        assert step.provider_name == agent.llm_config.model_endpoint_type
        assert step.model == agent.llm_config.model
        assert step.context_window_limit == agent.llm_config.context_window
        completion_tokens += int(step.completion_tokens)
        prompt_tokens += int(step.prompt_tokens)
        total_tokens += int(step.total_tokens)

    assert completion_tokens == usage.completion_tokens
    assert prompt_tokens == usage.prompt_tokens
    assert total_tokens == usage.total_tokens

    # TODO: re-enable the follow-up scenario that deletes the custom provider,
    # sends another message, and re-checks the logged steps (expecting
    # step.provider_id to be None). It was previously kept here as
    # commented-out code.
@pytest.mark.asyncio
async def test_messages_with_provider_override_legacy_agent(server: SyncServer, custom_anthropic_provider: PydanticProvider, user):
    """Check step logging and token accounting for a ``memgpt_v2_agent`` on the custom provider.

    Creates an agent whose model handle is prefixed with ``provider_name``,
    sends one user message through ``AgentLoop``, and then verifies that every
    step referenced by the newly listed messages matches the agent's LLM
    config and that the summed step token counts equal the usage reported by
    the agent loop.
    """
    # Print the models listed under the custom provider (debugging aid).
    models = await server.list_llm_models_async(actor=user)
    for model in models:
        if model.provider_name == provider_name:
            print(model.model)

    actor = await server.user_manager.get_actor_or_default_async()
    agent = await server.create_agent_async(
        CreateAgent(
            agent_type="memgpt_v2_agent",
            memory_blocks=[],
            model=f"{provider_name}/claude-sonnet-4-5-20250929",
            context_window_limit=100000,
            embedding="openai/text-embedding-ada-002",
        ),
        actor=actor,
    )

    # The last pre-existing message is the cursor for listing what the step adds.
    existing_messages = await server.message_manager.list_messages(agent_id=agent.id, actor=actor)
    assert existing_messages, "Expected the new agent to have messages to page after"

    # Send a message.
    run = await server.run_manager.create_run(
        pydantic_run=PydanticRun(
            agent_id=agent.id,
            background=False,
        ),
        actor=actor,
    )
    agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
    response = await agent_loop.step(
        input_messages=[MessageCreate(role=MessageRole.user, content="Test message")],
        run_id=run.id,
    )
    usage = response.usage

    get_messages_response = await server.message_manager.list_messages(
        agent_id=agent.id,
        actor=actor,
        after=existing_messages[-1].id,
    )

    # Several messages can share one step, so de-duplicate the step ids.
    step_ids = {msg.step_id for msg in get_messages_response}
    completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
    for step_id in step_ids:
        step = await server.step_manager.get_step_async(step_id=step_id, actor=actor)
        assert step, "Step was not logged correctly"
        # TODO: also assert step.provider_id == custom_anthropic_provider.id (check currently disabled).
        assert step.provider_name == agent.llm_config.model_endpoint_type
        assert step.model == agent.llm_config.model
        assert step.context_window_limit == agent.llm_config.context_window
        completion_tokens += int(step.completion_tokens)
        prompt_tokens += int(step.prompt_tokens)
        total_tokens += int(step.total_tokens)

    assert completion_tokens == usage.completion_tokens
    assert prompt_tokens == usage.prompt_tokens
    assert total_tokens == usage.total_tokens

    # TODO: re-enable the follow-up scenario that deletes the custom provider
    # and sends another message through the agent loop. It was previously kept
    # here as commented-out code.
@pytest.mark.asyncio
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
    """Verify that neither the LLM model handles nor the embedding handles repeat."""
    llm_models = await server.list_llm_models_async(actor=user)
    llm_handles = [entry.handle for entry in llm_models]
    # Sorting the de-duplicated handles yields the same list only when nothing repeats.
    assert sorted(llm_handles) == sorted(set(llm_handles)), "All models should have unique handles"

    embedding_models = await server.list_embedding_models_async(actor=user)
    emb_handles = [entry.handle for entry in embedding_models]
    assert sorted(emb_handles) == sorted(set(emb_handles)), "All embeddings should have unique handles"