chore: Clean up upserting base tools (#2274)
This commit is contained in:
@@ -11,7 +11,7 @@ from sqlalchemy import delete
|
||||
|
||||
from letta import create_client
|
||||
from letta.client.client import LocalClient, RESTClient
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
|
||||
from letta.orm import FileMetadata, Source
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -30,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.services.user_manager import UserManager
|
||||
from letta.settings import model_settings
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
@@ -336,9 +335,9 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
|
||||
|
||||
|
||||
def test_list_tools(client: Union[LocalClient, RESTClient]):
|
||||
tools = client.add_base_tools()
|
||||
tools = client.upsert_base_tools()
|
||||
tool_names = [t.name for t in tools]
|
||||
expected = ToolManager.BASE_TOOL_NAMES + ToolManager.BASE_MEMORY_TOOL_NAMES
|
||||
expected = BASE_TOOLS + BASE_MEMORY_TOOLS
|
||||
assert sorted(tool_names) == sorted(expected)
|
||||
|
||||
|
||||
|
||||
@@ -2,28 +2,27 @@ import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
from httpx._transports import default
|
||||
from numpy import source
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.functions.functions import derive_openai_json_schema, parse_source_code
|
||||
from letta.orm import (
|
||||
Agent,
|
||||
AgentPassage,
|
||||
Block,
|
||||
BlocksAgents,
|
||||
FileMetadata,
|
||||
Job,
|
||||
Message,
|
||||
Organization,
|
||||
AgentPassage,
|
||||
SourcePassage,
|
||||
SandboxConfig,
|
||||
SandboxEnvironmentVariable,
|
||||
Source,
|
||||
SourcePassage,
|
||||
SourcesAgents,
|
||||
Tool,
|
||||
ToolsAgents,
|
||||
@@ -202,9 +201,9 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
metadata_={"type": "test"},
|
||||
),
|
||||
actor=default_user
|
||||
actor=default_user,
|
||||
)
|
||||
yield passage
|
||||
|
||||
@@ -220,9 +219,9 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
metadata_={"type": "test"},
|
||||
),
|
||||
actor=default_user
|
||||
actor=default_user,
|
||||
)
|
||||
yield passage
|
||||
|
||||
@@ -240,9 +239,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
metadata_={"type": "test"},
|
||||
),
|
||||
actor=default_user
|
||||
actor=default_user,
|
||||
)
|
||||
passages.append(passage)
|
||||
if USING_SQLITE:
|
||||
@@ -258,9 +257,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata_={"type": "test"}
|
||||
metadata_={"type": "test"},
|
||||
),
|
||||
actor=default_user
|
||||
actor=default_user,
|
||||
)
|
||||
passages.append(passage)
|
||||
if USING_SQLITE:
|
||||
@@ -452,7 +451,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
embedding=[0.1], # Default OpenAI embedding size
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=actor
|
||||
actor=actor,
|
||||
)
|
||||
source_passages.append(passage)
|
||||
|
||||
@@ -467,7 +466,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
embedding=[0.1], # Default OpenAI embedding size
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=actor
|
||||
actor=actor,
|
||||
)
|
||||
agent_passages.append(passage)
|
||||
|
||||
@@ -476,6 +475,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
|
||||
# Cleanup
|
||||
server.source_manager.delete_source(default_source.id, actor=actor)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Basic
|
||||
# ======================================================================================================================
|
||||
@@ -940,32 +940,33 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
|
||||
# Agent Manager - Passages Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test basic listing functionality of agent passages"""
|
||||
|
||||
|
||||
all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
|
||||
assert len(all_passages) == 5 # 3 source + 2 agent passages
|
||||
|
||||
|
||||
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test ordering of agent passages"""
|
||||
"""Test ordering of agent passages"""
|
||||
|
||||
# Test ascending order
|
||||
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
|
||||
assert len(asc_passages) == 5
|
||||
for i in range(1, len(asc_passages)):
|
||||
assert asc_passages[i-1].created_at <= asc_passages[i].created_at
|
||||
assert asc_passages[i - 1].created_at <= asc_passages[i].created_at
|
||||
|
||||
# Test descending order
|
||||
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
|
||||
assert len(desc_passages) == 5
|
||||
for i in range(1, len(desc_passages)):
|
||||
assert desc_passages[i-1].created_at >= desc_passages[i].created_at
|
||||
assert desc_passages[i - 1].created_at >= desc_passages[i].created_at
|
||||
|
||||
|
||||
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test pagination of agent passages"""
|
||||
|
||||
|
||||
# Test limit
|
||||
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
|
||||
assert len(limited_passages) == 3
|
||||
@@ -973,13 +974,9 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
|
||||
# Test cursor-based pagination
|
||||
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
|
||||
assert len(first_page) == 2
|
||||
|
||||
|
||||
second_page = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
cursor=first_page[-1].id,
|
||||
limit=2,
|
||||
ascending=True
|
||||
actor=default_user, agent_id=sarah_agent.id, cursor=first_page[-1].id, limit=2, ascending=True
|
||||
)
|
||||
assert len(second_page) == 2
|
||||
assert first_page[-1].id != second_page[0].id
|
||||
@@ -988,57 +985,38 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
|
||||
|
||||
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test text search functionality of agent passages"""
|
||||
|
||||
|
||||
# Test text search for source passages
|
||||
source_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="Source passage"
|
||||
)
|
||||
source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage")
|
||||
assert len(source_text_passages) == 3
|
||||
|
||||
# Test text search for agent passages
|
||||
agent_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
query_text="Agent passage"
|
||||
)
|
||||
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage")
|
||||
assert len(agent_text_passages) == 2
|
||||
|
||||
|
||||
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
|
||||
"""Test text search functionality of agent passages"""
|
||||
|
||||
|
||||
# Test text search for agent passages
|
||||
agent_text_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
agent_only=True
|
||||
)
|
||||
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
|
||||
assert len(agent_text_passages) == 2
|
||||
|
||||
|
||||
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
|
||||
"""Test filtering functionality of agent passages"""
|
||||
|
||||
|
||||
# Test source filtering
|
||||
source_filtered = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
source_id=default_source.id
|
||||
)
|
||||
source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id)
|
||||
assert len(source_filtered) == 3
|
||||
|
||||
# Test date filtering
|
||||
now = datetime.utcnow()
|
||||
future_date = now + timedelta(days=1)
|
||||
past_date = now - timedelta(days=1)
|
||||
|
||||
|
||||
date_filtered = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
agent_id=sarah_agent.id,
|
||||
start_date=past_date,
|
||||
end_date=future_date
|
||||
actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date
|
||||
)
|
||||
assert len(date_filtered) == 5
|
||||
|
||||
@@ -1049,7 +1027,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
|
||||
|
||||
# Create passages with known embeddings
|
||||
passages = []
|
||||
|
||||
|
||||
# Create passages with different embeddings
|
||||
test_passages = [
|
||||
"I like red",
|
||||
@@ -1058,7 +1036,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
|
||||
]
|
||||
|
||||
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
|
||||
|
||||
|
||||
for i, text in enumerate(test_passages):
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
if i % 2 == 0:
|
||||
@@ -1067,7 +1045,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
|
||||
organization_id=default_user.organization_id,
|
||||
agent_id=sarah_agent.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding
|
||||
embedding=embedding,
|
||||
)
|
||||
else:
|
||||
passage = PydanticPassage(
|
||||
@@ -1075,14 +1053,14 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
|
||||
organization_id=default_user.organization_id,
|
||||
source_id=default_source.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding
|
||||
embedding=embedding,
|
||||
)
|
||||
created_passage = server.passage_manager.create_passage(passage, default_user)
|
||||
passages.append(created_passage)
|
||||
|
||||
# Query vector similar to "red" embedding
|
||||
query_key = "What's my favorite color?"
|
||||
|
||||
|
||||
# Test vector search with all passages
|
||||
results = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
@@ -1091,7 +1069,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
)
|
||||
|
||||
|
||||
# Verify results are ordered by similarity
|
||||
assert len(results) == 3
|
||||
assert results[0].text == "I like red"
|
||||
@@ -1105,9 +1083,9 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
|
||||
query_text=query_key,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embed_query=True,
|
||||
agent_only=True
|
||||
agent_only=True,
|
||||
)
|
||||
|
||||
|
||||
# Verify agent-only results
|
||||
assert len(agent_only_results) == 2
|
||||
assert agent_only_results[0].text == "I like red"
|
||||
@@ -1116,7 +1094,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
|
||||
|
||||
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
|
||||
"""Test listing passages from a source without specifying an agent."""
|
||||
|
||||
|
||||
# List passages by source_id without agent_id
|
||||
source_passages = server.agent_manager.list_passages(
|
||||
actor=default_user,
|
||||
@@ -1180,6 +1158,7 @@ def test_list_organizations_pagination(server: SyncServer):
|
||||
# Passage Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
|
||||
"""Test creating a passage using agent_passage_fixture fixture"""
|
||||
assert agent_passage_fixture.id is not None
|
||||
@@ -1214,7 +1193,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
|
||||
"""Test creating an agent passage."""
|
||||
assert agent_passage_fixture is not None
|
||||
assert agent_passage_fixture.text == "Hello, I am an agent passage"
|
||||
|
||||
|
||||
# Try to create an invalid passage (with both agent_id and source_id)
|
||||
with pytest.raises(AssertionError):
|
||||
server.passage_manager.create_passage(
|
||||
@@ -1226,7 +1205,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
|
||||
embedding=[0.1] * 1024,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
@@ -1243,19 +1222,21 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas
|
||||
assert retrieved.text == source_passage_fixture.text
|
||||
|
||||
|
||||
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
|
||||
def test_passage_cascade_deletion(
|
||||
server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent
|
||||
):
|
||||
"""Test that passages are deleted when their parent (agent or source) is deleted."""
|
||||
# Verify passages exist
|
||||
agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
|
||||
source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||
assert agent_passage is not None
|
||||
assert source_passage is not None
|
||||
|
||||
|
||||
# Delete agent and verify its passages are deleted
|
||||
server.agent_manager.delete_agent(sarah_agent.id, default_user)
|
||||
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
|
||||
assert len(agentic_passages) == 0
|
||||
|
||||
|
||||
# Delete source and verify its passages are deleted
|
||||
server.source_manager.delete_source(default_source.id, default_user)
|
||||
with pytest.raises(NoResultFound):
|
||||
@@ -1320,7 +1301,6 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
|
||||
assert print_tool.organization_id == default_organization.id
|
||||
|
||||
|
||||
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
|
||||
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
|
||||
data = print_tool.model_dump(exclude=["id"])
|
||||
@@ -1481,6 +1461,16 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
|
||||
assert len(tools) == 0
|
||||
|
||||
|
||||
def test_upsert_base_tools(server: SyncServer, default_user):
|
||||
tools = server.tool_manager.upsert_base_tools(actor=default_user)
|
||||
expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS)
|
||||
assert sorted([t.name for t in tools]) == expected_tool_names
|
||||
|
||||
# Call it again to make sure it doesn't create duplicates
|
||||
tools = server.tool_manager.upsert_base_tools(actor=default_user)
|
||||
assert sorted([t.name for t in tools]) == expected_tool_names
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Message Manager Tests
|
||||
# ======================================================================================================================
|
||||
@@ -1889,6 +1879,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
|
||||
# Source Manager Tests - Files
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_get_file_by_id(server: SyncServer, default_user, default_source):
|
||||
"""Test retrieving a file by ID."""
|
||||
file_metadata = PydanticFileMetadata(
|
||||
@@ -1960,6 +1951,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
|
||||
# SandboxConfigManager Tests - Sandbox Configs
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
|
||||
sandbox_config_create = SandboxConfigCreate(
|
||||
config=E2BSandboxConfig(),
|
||||
@@ -2039,6 +2031,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
|
||||
# SandboxConfigManager Tests - Environment Variables
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
|
||||
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
|
||||
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
|
||||
@@ -2111,6 +2104,7 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
|
||||
# JobManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_create_job(server: SyncServer, default_user):
|
||||
"""Test creating a job."""
|
||||
job_data = PydanticJob(
|
||||
|
||||
@@ -272,15 +272,15 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer
|
||||
)
|
||||
|
||||
|
||||
def test_add_base_tools(client, mock_sync_server, add_integers_tool):
|
||||
mock_sync_server.tool_manager.add_base_tools.return_value = [add_integers_tool]
|
||||
def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
|
||||
mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool]
|
||||
|
||||
response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})
|
||||
|
||||
assert response.status_code == 200
|
||||
assert len(response.json()) == 1
|
||||
assert response.json()[0]["id"] == add_integers_tool.id
|
||||
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(
|
||||
mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user