298 lines
12 KiB
Python
298 lines
12 KiB
Python
import json
|
|
import os
|
|
import shutil
|
|
import uuid
|
|
import warnings
|
|
from typing import List, Tuple
|
|
from unittest.mock import patch
|
|
|
|
import pytest
|
|
from sqlalchemy import delete
|
|
|
|
import letta.utils as utils
|
|
from letta.agents.agent_loop import AgentLoop
|
|
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, LETTA_DIR, LETTA_TOOL_EXECUTION_DIR
|
|
from letta.orm import Provider, Step
|
|
from letta.schemas.block import CreateBlock
|
|
from letta.schemas.enums import MessageRole, ProviderType
|
|
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, SystemMessage, ToolCallMessage, ToolReturnMessage, UserMessage
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.providers import Provider as PydanticProvider, ProviderCreate
|
|
from letta.schemas.sandbox_config import SandboxType
|
|
from letta.schemas.user import User
|
|
|
|
utils.DEBUG = True
|
|
from letta.config import LettaConfig
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.schemas.agent import CreateAgent, UpdateAgent
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.job import Job as PydanticJob
|
|
from letta.schemas.message import Message, MessageCreate
|
|
from letta.schemas.run import Run as PydanticRun
|
|
from letta.schemas.source import Source as PydanticSource
|
|
from letta.server.server import SyncServer
|
|
from letta.system import unpack_message
|
|
|
|
from .utils import DummyDataConnector
|
|
|
|
|
|
@pytest.fixture
|
|
async def server():
|
|
config = LettaConfig.load()
|
|
config.save()
|
|
server = SyncServer(init_with_default_org_and_user=True)
|
|
await server.init_async()
|
|
await server.tool_manager.upsert_base_tools_async(actor=server.default_user)
|
|
|
|
yield server
|
|
|
|
|
|
@pytest.fixture
|
|
async def org_id(server):
|
|
# create org
|
|
org = await server.organization_manager.create_default_organization_async()
|
|
|
|
yield org.id
|
|
|
|
# cleanup
|
|
await server.organization_manager.delete_organization_by_id_async(org.id)
|
|
|
|
|
|
@pytest.fixture
|
|
async def user(server, org_id):
|
|
user = await server.user_manager.create_default_actor_async(org_id=org_id)
|
|
yield user
|
|
|
|
|
|
@pytest.fixture
|
|
def user_id(server, user):
|
|
# create user
|
|
yield user.id
|
|
|
|
|
|
provider_name = "custom-anthropic29"
|
|
|
|
|
|
@pytest.fixture
|
|
async def custom_anthropic_provider(server: SyncServer, user_id: str):
|
|
actor = await server.user_manager.get_actor_or_default_async()
|
|
|
|
# check if provider already exists
|
|
existing_providers = await server.provider_manager.list_providers_async(actor=actor)
|
|
for provider in existing_providers:
|
|
if provider.name == provider_name:
|
|
# delete provider
|
|
await server.provider_manager.delete_provider_by_id_async(provider.id, actor=actor)
|
|
|
|
provider = await server.provider_manager.create_provider_async(
|
|
ProviderCreate(
|
|
name=provider_name,
|
|
provider_type=ProviderType.anthropic,
|
|
api_key=os.getenv("ANTHROPIC_API_KEY"),
|
|
),
|
|
actor=actor,
|
|
)
|
|
yield provider
|
|
# Try to delete provider if it still exists (test may have already deleted it)
|
|
try:
|
|
await server.provider_manager.delete_provider_by_id_async(provider.id, actor=actor)
|
|
except NoResultFound:
|
|
pass # Provider was already deleted in the test
|
|
|
|
|
|
@pytest.fixture
|
|
async def agent(server: SyncServer, user: User):
|
|
actor = await server.user_manager.get_actor_or_default_async()
|
|
agent = await server.create_agent_async(
|
|
CreateAgent(
|
|
agent_type="memgpt_v2_agent",
|
|
),
|
|
)
|
|
return agent
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_messages_with_provider_override(server: SyncServer, custom_anthropic_provider: PydanticProvider, user):
|
|
# list the models
|
|
models = await server.list_llm_models_async(actor=user)
|
|
for model in models:
|
|
if model.provider_name == provider_name:
|
|
print(model.model)
|
|
|
|
actor = await server.user_manager.get_actor_or_default_async()
|
|
agent = await server.create_agent_async(
|
|
CreateAgent(
|
|
agent_type="letta_v1_agent",
|
|
memory_blocks=[],
|
|
model=f"{provider_name}/claude-sonnet-4-5-20250929",
|
|
context_window_limit=100000,
|
|
embedding="openai/text-embedding-ada-002",
|
|
include_base_tools=False,
|
|
),
|
|
actor=actor,
|
|
)
|
|
|
|
existing_messages = await server.message_manager.list_messages(agent_id=agent.id, actor=actor)
|
|
|
|
# send a message
|
|
run = await server.run_manager.create_run(
|
|
pydantic_run=PydanticRun(
|
|
agent_id=agent.id,
|
|
background=False,
|
|
),
|
|
actor=actor,
|
|
)
|
|
agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
|
|
response = await agent_loop.step(
|
|
input_messages=[MessageCreate(role=MessageRole.user, content="Test message")],
|
|
run_id=run.id,
|
|
)
|
|
usage = response.usage
|
|
messages = response.messages
|
|
|
|
get_messages_response = await server.message_manager.list_messages(agent_id=agent.id, actor=actor, after=existing_messages[-1].id)
|
|
|
|
# usage = await server.message_manager.create_message(user_id=actor.id, agent_id=agent.id, message="Test message")
|
|
# assert usage, "Sending message failed"
|
|
|
|
# get_messages_response = await server.message_manager.list_messages_for_agent_async(agent_id=agent.id, actor=actor, after=existing_messages[-1].id)
|
|
# assert len(get_messages_response) > 0, "Retrieving messages failed"
|
|
|
|
step_ids = set([msg.step_id for msg in get_messages_response])
|
|
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
|
for step_id in step_ids:
|
|
step = await server.step_manager.get_step_async(step_id=step_id, actor=actor)
|
|
assert step, "Step was not logged correctly"
|
|
# assert step.provider_id == custom_anthropic_provider.id
|
|
assert step.provider_name == agent.llm_config.model_endpoint_type
|
|
assert step.model == agent.llm_config.model
|
|
assert step.context_window_limit == agent.llm_config.context_window
|
|
completion_tokens += int(step.completion_tokens)
|
|
prompt_tokens += int(step.prompt_tokens)
|
|
total_tokens += int(step.total_tokens)
|
|
|
|
assert completion_tokens == usage.completion_tokens
|
|
assert prompt_tokens == usage.prompt_tokens
|
|
assert total_tokens == usage.total_tokens
|
|
|
|
# await server.provider_manager.delete_provider_by_id_async(custom_anthropic_provider.id, actor=actor)
|
|
|
|
# existing_messages = await server.message_manager.list_messages(agent_id=agent.id, actor=actor)
|
|
|
|
## with pytest.raises(NoResultFound):
|
|
# agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
|
|
# response = await agent_loop.step(
|
|
# input_messages=[MessageCreate(role=MessageRole.user, content="Test message")],
|
|
# run_id=run.id,
|
|
# )
|
|
# print("RESULT", response)
|
|
|
|
# usage = await server.message_manager.create_user_message_async(user_id=actor.id, agent_id=agent.id, message="Test message")
|
|
# assert usage, "Sending message failed"
|
|
|
|
# get_messages_response = await server.message_manager.list_messages_for_agent_async(agent_id=agent.id, actor=actor, after=existing_messages[-1].id)
|
|
# assert len(get_messages_response) > 0, "Retrieving messages failed"
|
|
|
|
# step_ids = set([msg.step_id for msg in get_messages_response])
|
|
# completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
|
# for step_id in step_ids:
|
|
# step = await server.step_manager.get_step_async(step_id=step_id, actor=actor)
|
|
# assert step, "Step was not logged correctly"
|
|
# assert step.provider_id == None
|
|
# assert step.provider_name == agent.llm_config.model_endpoint_type
|
|
# assert step.model == agent.llm_config.model
|
|
# assert step.context_window_limit == agent.llm_config.context_window
|
|
# completion_tokens += int(step.completion_tokens)
|
|
# prompt_tokens += int(step.prompt_tokens)
|
|
# total_tokens += int(step.total_tokens)
|
|
|
|
# assert completion_tokens == usage.completion_tokens
|
|
# assert prompt_tokens == usage.prompt_tokens
|
|
# assert total_tokens == usage.total_tokens
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_messages_with_provider_override_legacy_agent(server: SyncServer, custom_anthropic_provider: PydanticProvider, user):
|
|
# list the models
|
|
models = await server.list_llm_models_async(actor=user)
|
|
for model in models:
|
|
if model.provider_name == provider_name:
|
|
print(model.model)
|
|
|
|
actor = await server.user_manager.get_actor_or_default_async()
|
|
agent = await server.create_agent_async(
|
|
CreateAgent(
|
|
agent_type="memgpt_v2_agent",
|
|
memory_blocks=[],
|
|
model=f"{provider_name}/claude-sonnet-4-5-20250929",
|
|
context_window_limit=100000,
|
|
embedding="openai/text-embedding-ada-002",
|
|
),
|
|
actor=actor,
|
|
)
|
|
|
|
existing_messages = await server.message_manager.list_messages(agent_id=agent.id, actor=actor)
|
|
|
|
# send a message
|
|
run = await server.run_manager.create_run(
|
|
pydantic_run=PydanticRun(
|
|
agent_id=agent.id,
|
|
background=False,
|
|
),
|
|
actor=actor,
|
|
)
|
|
agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
|
|
response = await agent_loop.step(
|
|
input_messages=[MessageCreate(role=MessageRole.user, content="Test message")],
|
|
run_id=run.id,
|
|
)
|
|
usage = response.usage
|
|
messages = response.messages
|
|
|
|
get_messages_response = await server.message_manager.list_messages(agent_id=agent.id, actor=actor, after=existing_messages[-1].id)
|
|
|
|
# usage = await server.message_manager.create_message(user_id=actor.id, agent_id=agent.id, message="Test message")
|
|
# assert usage, "Sending message failed"
|
|
|
|
# get_messages_response = await server.message_manager.list_messages_for_agent_async(agent_id=agent.id, actor=actor, after=existing_messages[-1].id)
|
|
# assert len(get_messages_response) > 0, "Retrieving messages failed"
|
|
|
|
step_ids = set([msg.step_id for msg in get_messages_response])
|
|
completion_tokens, prompt_tokens, total_tokens = 0, 0, 0
|
|
for step_id in step_ids:
|
|
step = await server.step_manager.get_step_async(step_id=step_id, actor=actor)
|
|
assert step, "Step was not logged correctly"
|
|
# assert step.provider_id == custom_anthropic_provider.id
|
|
assert step.provider_name == agent.llm_config.model_endpoint_type
|
|
assert step.model == agent.llm_config.model
|
|
assert step.context_window_limit == agent.llm_config.context_window
|
|
completion_tokens += int(step.completion_tokens)
|
|
prompt_tokens += int(step.prompt_tokens)
|
|
total_tokens += int(step.total_tokens)
|
|
|
|
assert completion_tokens == usage.completion_tokens
|
|
assert prompt_tokens == usage.prompt_tokens
|
|
assert total_tokens == usage.total_tokens
|
|
|
|
# await server.provider_manager.delete_provider_by_id_async(custom_anthropic_provider.id, actor=actor)
|
|
|
|
# existing_messages = await server.message_manager.list_messages(agent_id=agent.id, actor=actor)
|
|
|
|
## with pytest.raises(NoResultFound):
|
|
# agent_loop = AgentLoop.load(agent_state=agent, actor=actor)
|
|
# response = await agent_loop.step(
|
|
# input_messages=[MessageCreate(role=MessageRole.user, content="Test message")],
|
|
# run_id=run.id,
|
|
# )
|
|
# print("RESULT", response)
|
|
|
|
|
|
@pytest.mark.asyncio
|
|
async def test_unique_handles_for_provider_configs(server: SyncServer, user: User):
|
|
models = await server.list_llm_models_async(actor=user)
|
|
model_handles = [model.handle for model in models]
|
|
assert sorted(model_handles) == sorted(list(set(model_handles))), "All models should have unique handles"
|
|
embeddings = await server.list_embedding_models_async(actor=user)
|
|
embedding_handles = [embedding.handle for embedding in embeddings]
|
|
assert sorted(embedding_handles) == sorted(list(set(embedding_handles))), "All embeddings should have unique handles"
|