chore: Clean up upserting base tools (#2274)

This commit is contained in:
Matthew Zhou
2024-12-18 14:33:29 -08:00
committed by GitHub
parent 8644f2016a
commit b1ce8b4e8a
10 changed files with 113 additions and 228 deletions

View File

@@ -11,7 +11,7 @@ from sqlalchemy import delete
from letta import create_client
from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
from letta.orm import FileMetadata, Source
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
@@ -30,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client
@@ -336,9 +335,9 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
def test_list_tools(client: Union[LocalClient, RESTClient]):
tools = client.add_base_tools()
tools = client.upsert_base_tools()
tool_names = [t.name for t in tools]
expected = ToolManager.BASE_TOOL_NAMES + ToolManager.BASE_MEMORY_TOOL_NAMES
expected = BASE_TOOLS + BASE_MEMORY_TOOLS
assert sorted(tool_names) == sorted(expected)

View File

@@ -2,28 +2,27 @@ import os
import time
from datetime import datetime, timedelta
from httpx._transports import default
from numpy import source
import pytest
from sqlalchemy import delete
from sqlalchemy.exc import IntegrityError
from letta.config import LettaConfig
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.orm import (
Agent,
AgentPassage,
Block,
BlocksAgents,
FileMetadata,
Job,
Message,
Organization,
AgentPassage,
SourcePassage,
SandboxConfig,
SandboxEnvironmentVariable,
Source,
SourcePassage,
SourcesAgents,
Tool,
ToolsAgents,
@@ -202,9 +201,9 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
yield passage
@@ -220,9 +219,9 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
yield passage
@@ -240,9 +239,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
passages.append(passage)
if USING_SQLITE:
@@ -258,9 +257,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
passages.append(passage)
if USING_SQLITE:
@@ -452,7 +451,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
embedding=[0.1], # Default OpenAI embedding size
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=actor
actor=actor,
)
source_passages.append(passage)
@@ -467,7 +466,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
embedding=[0.1], # Default OpenAI embedding size
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=actor
actor=actor,
)
agent_passages.append(passage)
@@ -476,6 +475,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
# Cleanup
server.source_manager.delete_source(default_source.id, actor=actor)
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -940,32 +940,33 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
# Agent Manager - Passages Tests
# ======================================================================================================================
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
"""Test basic listing functionality of agent passages"""
all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
assert len(all_passages) == 5 # 3 source + 2 agent passages
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
"""Test ordering of agent passages"""
"""Test ordering of agent passages"""
# Test ascending order
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
assert len(asc_passages) == 5
for i in range(1, len(asc_passages)):
assert asc_passages[i-1].created_at <= asc_passages[i].created_at
assert asc_passages[i - 1].created_at <= asc_passages[i].created_at
# Test descending order
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
assert len(desc_passages) == 5
for i in range(1, len(desc_passages)):
assert desc_passages[i-1].created_at >= desc_passages[i].created_at
assert desc_passages[i - 1].created_at >= desc_passages[i].created_at
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
"""Test pagination of agent passages"""
# Test limit
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
assert len(limited_passages) == 3
@@ -973,13 +974,9 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
# Test cursor-based pagination
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
assert len(first_page) == 2
second_page = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
cursor=first_page[-1].id,
limit=2,
ascending=True
actor=default_user, agent_id=sarah_agent.id, cursor=first_page[-1].id, limit=2, ascending=True
)
assert len(second_page) == 2
assert first_page[-1].id != second_page[0].id
@@ -988,57 +985,38 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
"""Test text search functionality of agent passages"""
# Test text search for source passages
source_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
query_text="Source passage"
)
source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage")
assert len(source_text_passages) == 3
# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
query_text="Agent passage"
)
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage")
assert len(agent_text_passages) == 2
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
"""Test text search functionality of agent passages"""
# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
agent_only=True
)
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agent_text_passages) == 2
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
"""Test filtering functionality of agent passages"""
# Test source filtering
source_filtered = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
source_id=default_source.id
)
source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id)
assert len(source_filtered) == 3
# Test date filtering
now = datetime.utcnow()
future_date = now + timedelta(days=1)
past_date = now - timedelta(days=1)
date_filtered = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
start_date=past_date,
end_date=future_date
actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date
)
assert len(date_filtered) == 5
@@ -1049,7 +1027,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
# Create passages with known embeddings
passages = []
# Create passages with different embeddings
test_passages = [
"I like red",
@@ -1058,7 +1036,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
]
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
for i, text in enumerate(test_passages):
embedding = embed_model.get_text_embedding(text)
if i % 2 == 0:
@@ -1067,7 +1045,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding
embedding=embedding,
)
else:
passage = PydanticPassage(
@@ -1075,14 +1053,14 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
organization_id=default_user.organization_id,
source_id=default_source.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding
embedding=embedding,
)
created_passage = server.passage_manager.create_passage(passage, default_user)
passages.append(created_passage)
# Query vector similar to "red" embedding
query_key = "What's my favorite color?"
# Test vector search with all passages
results = server.agent_manager.list_passages(
actor=default_user,
@@ -1091,7 +1069,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embed_query=True,
)
# Verify results are ordered by similarity
assert len(results) == 3
assert results[0].text == "I like red"
@@ -1105,9 +1083,9 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
query_text=query_key,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embed_query=True,
agent_only=True
agent_only=True,
)
# Verify agent-only results
assert len(agent_only_results) == 2
assert agent_only_results[0].text == "I like red"
@@ -1116,7 +1094,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
"""Test listing passages from a source without specifying an agent."""
# List passages by source_id without agent_id
source_passages = server.agent_manager.list_passages(
actor=default_user,
@@ -1180,6 +1158,7 @@ def test_list_organizations_pagination(server: SyncServer):
# Passage Manager Tests
# ======================================================================================================================
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
"""Test creating a passage using agent_passage_fixture fixture"""
assert agent_passage_fixture.id is not None
@@ -1214,7 +1193,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
"""Test creating an agent passage."""
assert agent_passage_fixture is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"
# Try to create an invalid passage (with both agent_id and source_id)
with pytest.raises(AssertionError):
server.passage_manager.create_passage(
@@ -1226,7 +1205,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
embedding=[0.1] * 1024,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=default_user
actor=default_user,
)
@@ -1243,19 +1222,21 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas
assert retrieved.text == source_passage_fixture.text
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
def test_passage_cascade_deletion(
server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent
):
"""Test that passages are deleted when their parent (agent or source) is deleted."""
# Verify passages exist
agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
assert agent_passage is not None
assert source_passage is not None
# Delete agent and verify its passages are deleted
server.agent_manager.delete_agent(sarah_agent.id, default_user)
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agentic_passages) == 0
# Delete source and verify its passages are deleted
server.source_manager.delete_source(default_source.id, default_user)
with pytest.raises(NoResultFound):
@@ -1320,7 +1301,6 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
assert print_tool.organization_id == default_organization.id
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
data = print_tool.model_dump(exclude=["id"])
@@ -1481,6 +1461,16 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
assert len(tools) == 0
def test_upsert_base_tools(server: SyncServer, default_user):
tools = server.tool_manager.upsert_base_tools(actor=default_user)
expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS)
assert sorted([t.name for t in tools]) == expected_tool_names
# Call it again to make sure it doesn't create duplicates
tools = server.tool_manager.upsert_base_tools(actor=default_user)
assert sorted([t.name for t in tools]) == expected_tool_names
# ======================================================================================================================
# Message Manager Tests
# ======================================================================================================================
@@ -1889,6 +1879,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
# Source Manager Tests - Files
# ======================================================================================================================
def test_get_file_by_id(server: SyncServer, default_user, default_source):
"""Test retrieving a file by ID."""
file_metadata = PydanticFileMetadata(
@@ -1960,6 +1951,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
# SandboxConfigManager Tests - Sandbox Configs
# ======================================================================================================================
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
sandbox_config_create = SandboxConfigCreate(
config=E2BSandboxConfig(),
@@ -2039,6 +2031,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
# SandboxConfigManager Tests - Environment Variables
# ======================================================================================================================
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
@@ -2111,6 +2104,7 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# JobManager Tests
# ======================================================================================================================
def test_create_job(server: SyncServer, default_user):
"""Test creating a job."""
job_data = PydanticJob(

View File

@@ -272,15 +272,15 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer
)
def test_add_base_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.add_base_tools.return_value = [add_integers_tool]
def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool]
response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == add_integers_tool.id
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(
mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with(
actor=mock_sync_server.user_manager.get_user_or_default.return_value
)