chore: Clean up upserting base tools (#2274)

This commit is contained in:
Matthew Zhou
2024-12-18 14:33:29 -08:00
committed by GitHub
parent 8644f2016a
commit b1ce8b4e8a
10 changed files with 113 additions and 228 deletions

View File

@@ -233,7 +233,7 @@ class AbstractClient(object):
def get_tool_id(self, name: str) -> Optional[str]:
raise NotImplementedError
def add_base_tools(self) -> List[Tool]:
def upsert_base_tools(self) -> List[Tool]:
raise NotImplementedError
def load_data(self, connector: DataConnector, source_name: str):
@@ -1466,7 +1466,7 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to get tool: {response.text}")
return response.json()
def add_base_tools(self) -> List[Tool]:
def upsert_base_tools(self) -> List[Tool]:
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to add base tools: {response.text}")

View File

@@ -61,60 +61,6 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
return results_str
def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using a date range.
Args:
start_date (str): The start of the date range to search, in the format 'YYYY-MM-DD'.
end_date (str): The end of the date range to search, in the format 'YYYY-MM-DD'.
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
Returns:
str: Query result string
"""
import math
from datetime import datetime
from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.utils import json_dumps
if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
if page < 0:
raise ValueError
except:
raise ValueError(f"'page' argument must be an integer")
# Convert date strings to datetime objects
try:
start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(hour=0, minute=0, second=0, microsecond=0)
end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise ValueError("Dates must be in the format 'YYYY-MM-DD'")
count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results = self.message_manager.list_user_messages_for_agent(
# TODO: add paging by page number. currently cursor only works with strings.
agent_id=self.agent_state.id,
actor=self.user,
start_date=start_datetime,
end_date=end_datetime,
limit=count,
)
total = len(results)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str
def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.

View File

@@ -152,30 +152,15 @@ def update_tool(
@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
def add_base_tools(
def upsert_base_tools(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Add base tools
Upsert base tools
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.tool_manager.add_base_tools(actor=actor)
# NOTE: can re-enable if needed
# @router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool")
# def run_tool(
# server: SyncServer = Depends(get_letta_server),
# request: ToolRun = Body(...),
# user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
# ):
# """
# Run an existing tool on provided arguments
# """
# actor = server.user_manager.get_user_or_default(user_id=user_id)
# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)
return server.tool_manager.upsert_base_tools(actor=actor)
@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")

View File

@@ -301,7 +301,7 @@ class SyncServer(Server):
self.default_org = self.organization_manager.create_default_organization()
self.default_user = self.user_manager.create_default_user()
self.block_manager.add_default_blocks(actor=self.default_user)
self.tool_manager.add_base_tools(actor=self.default_user)
self.tool_manager.upsert_base_tools(actor=self.default_user)
# If there is a default org/user
# This logic may have to change in the future

View File

@@ -1,18 +1,18 @@
from typing import Dict, List, Optional
from datetime import datetime
import numpy as np
from typing import Dict, List, Optional
from sqlalchemy import select, union_all, literal, func, Select
import numpy as np
from sqlalchemy import Select, func, literal, select, union_all
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
from letta.orm import AgentPassage, SourcePassage
from letta.orm import SourcesAgents
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
@@ -77,6 +77,8 @@ class AgentManager:
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
if agent_create.tools:
tool_names.extend(agent_create.tools)
# Remove duplicates
tool_names = list(set(tool_names))
tool_ids = agent_create.tool_ids or []
for tool_name in tool_names:
@@ -431,7 +433,7 @@ class AgentManager:
agent_only: bool = False,
) -> Select:
"""Helper function to build the base passage query with all filters applied.
Returns the query before any limit or count operations are applied.
"""
embedded_text = None
@@ -448,21 +450,14 @@ class AgentManager:
if not agent_only: # Include source passages
if agent_id is not None:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
select(SourcePassage, literal(None).label("agent_id"))
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
else:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
.where(SourcePassage.organization_id == actor.organization_id)
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
SourcePassage.organization_id == actor.organization_id
)
if source_id:
@@ -486,9 +481,9 @@ class AgentManager:
AgentPassage._created_by_id,
AgentPassage._last_updated_by_id,
AgentPassage.organization_id,
literal(None).label('file_id'),
literal(None).label('source_id'),
AgentPassage.agent_id
literal(None).label("file_id"),
literal(None).label("source_id"),
AgentPassage.agent_id,
)
.where(AgentPassage.agent_id == agent_id)
.where(AgentPassage.organization_id == actor.organization_id)
@@ -496,11 +491,11 @@ class AgentManager:
# Combine queries
if source_passages is not None and agent_passages is not None:
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
elif agent_passages is not None:
combined_query = agent_passages.cte('combined_passages')
combined_query = agent_passages.cte("combined_passages")
elif source_passages is not None:
combined_query = source_passages.cte('combined_passages')
combined_query = source_passages.cte("combined_passages")
else:
raise ValueError("No passages found")
@@ -521,9 +516,7 @@ class AgentManager:
if embedded_text:
if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
main_query = main_query.order_by(
combined_query.c.embedding.cosine_distance(embedded_text).asc()
)
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
else:
# SQLite with custom vector type
query_embedding_binary = adapt_array(embedded_text)
@@ -531,13 +524,13 @@ class AgentManager:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.asc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.desc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
if query_text:
@@ -545,18 +538,12 @@ class AgentManager:
# Handle cursor-based pagination
if cursor:
cursor_query = select(combined_query.c.created_at).where(
combined_query.c.id == cursor
).scalar_subquery()
cursor_query = select(combined_query.c.created_at).where(combined_query.c.id == cursor).scalar_subquery()
if ascending:
main_query = main_query.where(
combined_query.c.created_at > cursor_query
)
main_query = main_query.where(combined_query.c.created_at > cursor_query)
else:
main_query = main_query.where(
combined_query.c.created_at < cursor_query
)
main_query = main_query.where(combined_query.c.created_at < cursor_query)
# Add ordering if not already ordered by similarity
if not embed_query:
@@ -588,7 +575,7 @@ class AgentManager:
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
with self.session_maker() as session:
@@ -617,19 +604,18 @@ class AgentManager:
passages = []
for row in results:
data = dict(row._mapping)
if data['agent_id'] is not None:
if data["agent_id"] is not None:
# This is an AgentPassage - remove source fields
data.pop('source_id', None)
data.pop('file_id', None)
data.pop("source_id", None)
data.pop("file_id", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop('agent_id', None)
data.pop("agent_id", None)
passage = SourcePassage(**data)
passages.append(passage)
return [p.to_pydantic() for p in passages]
return [p.to_pydantic() for p in passages]
@enforce_types
def passage_size(
@@ -645,7 +631,7 @@ class AgentManager:
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> int:
"""Returns the count of passages matching the given criteria."""
with self.session_maker() as session:
@@ -663,7 +649,7 @@ class AgentManager:
embedding_config=embedding_config,
agent_only=agent_only,
)
# Convert to count query
count_query = select(func.count()).select_from(main_query.subquery())
return session.scalar(count_query) or 0

View File

@@ -3,6 +3,7 @@ import inspect
import warnings
from typing import List, Optional
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.functions.functions import derive_openai_json_schema, load_function_set
# TODO: Remove this once we translate all of these to the ORM
@@ -20,7 +21,6 @@ class ToolManager:
BASE_TOOL_NAMES = [
"send_message",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",
"archival_memory_search",
]
@@ -133,7 +133,7 @@ class ToolManager:
raise ValueError(f"Tool with id {tool_id} not found.")
@enforce_types
def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py"""
module_name = "base"
full_module_name = f"letta.functions.function_sets.{module_name}"
@@ -154,7 +154,7 @@ class ToolManager:
# create tool in db
tools = []
for name, schema in functions_to_schema.items():
if name in self.BASE_TOOL_NAMES + self.BASE_MEMORY_TOOL_NAMES:
if name in BASE_TOOLS + BASE_MEMORY_TOOLS:
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]

View File

@@ -1,5 +1,5 @@
from letta.functions.functions import parse_source_code
from letta.schemas.tool import Tool
from tqdm import tqdm
from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
@@ -10,33 +10,8 @@ def deprecated_tool():
orgs = OrganizationManager().list_organizations(cursor=None, limit=5000)
for org in orgs:
for org in tqdm(orgs):
if org.name != "default":
fake_user = User(id="user-00000000-0000-4000-8000-000000000000", name="fake", organization_id=org.id)
ToolManager().add_base_tools(actor=fake_user)
source_code = parse_source_code(deprecated_tool)
source_type = "python"
description = "deprecated"
tags = ["deprecated"]
ToolManager().create_or_update_tool(
Tool(
name="core_memory_append",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)
ToolManager().create_or_update_tool(
Tool(
name="core_memory_replace",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)
ToolManager().upsert_base_tools(actor=fake_user)

View File

@@ -11,7 +11,7 @@ from sqlalchemy import delete
from letta import create_client
from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
from letta.orm import FileMetadata, Source
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
@@ -30,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client
@@ -336,9 +335,9 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):
def test_list_tools(client: Union[LocalClient, RESTClient]):
tools = client.add_base_tools()
tools = client.upsert_base_tools()
tool_names = [t.name for t in tools]
expected = ToolManager.BASE_TOOL_NAMES + ToolManager.BASE_MEMORY_TOOL_NAMES
expected = BASE_TOOLS + BASE_MEMORY_TOOLS
assert sorted(tool_names) == sorted(expected)

View File

@@ -2,28 +2,27 @@ import os
import time
from datetime import datetime, timedelta
from httpx._transports import default
from numpy import source
import pytest
from sqlalchemy import delete
from sqlalchemy.exc import IntegrityError
from letta.config import LettaConfig
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.orm import (
Agent,
AgentPassage,
Block,
BlocksAgents,
FileMetadata,
Job,
Message,
Organization,
AgentPassage,
SourcePassage,
SandboxConfig,
SandboxEnvironmentVariable,
Source,
SourcePassage,
SourcesAgents,
Tool,
ToolsAgents,
@@ -202,9 +201,9 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
yield passage
@@ -220,9 +219,9 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
yield passage
@@ -240,9 +239,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
passages.append(passage)
if USING_SQLITE:
@@ -258,9 +257,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
passages.append(passage)
if USING_SQLITE:
@@ -452,7 +451,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
embedding=[0.1], # Default OpenAI embedding size
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=actor
actor=actor,
)
source_passages.append(passage)
@@ -467,7 +466,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
embedding=[0.1], # Default OpenAI embedding size
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=actor
actor=actor,
)
agent_passages.append(passage)
@@ -476,6 +475,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
# Cleanup
server.source_manager.delete_source(default_source.id, actor=actor)
# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================
@@ -940,32 +940,33 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
# Agent Manager - Passages Tests
# ======================================================================================================================
def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
"""Test basic listing functionality of agent passages"""
all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
assert len(all_passages) == 5 # 3 source + 2 agent passages
def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
"""Test ordering of agent passages"""
"""Test ordering of agent passages"""
# Test ascending order
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
assert len(asc_passages) == 5
for i in range(1, len(asc_passages)):
assert asc_passages[i-1].created_at <= asc_passages[i].created_at
assert asc_passages[i - 1].created_at <= asc_passages[i].created_at
# Test descending order
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
assert len(desc_passages) == 5
for i in range(1, len(desc_passages)):
assert desc_passages[i-1].created_at >= desc_passages[i].created_at
assert desc_passages[i - 1].created_at >= desc_passages[i].created_at
def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
"""Test pagination of agent passages"""
# Test limit
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
assert len(limited_passages) == 3
@@ -973,13 +974,9 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
# Test cursor-based pagination
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
assert len(first_page) == 2
second_page = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
cursor=first_page[-1].id,
limit=2,
ascending=True
actor=default_user, agent_id=sarah_agent.id, cursor=first_page[-1].id, limit=2, ascending=True
)
assert len(second_page) == 2
assert first_page[-1].id != second_page[0].id
@@ -988,57 +985,38 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
"""Test text search functionality of agent passages"""
# Test text search for source passages
source_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
query_text="Source passage"
)
source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage")
assert len(source_text_passages) == 3
# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
query_text="Agent passage"
)
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage")
assert len(agent_text_passages) == 2
def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
"""Test text search functionality of agent passages"""
# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
agent_only=True
)
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agent_text_passages) == 2
def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
"""Test filtering functionality of agent passages"""
# Test source filtering
source_filtered = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
source_id=default_source.id
)
source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id)
assert len(source_filtered) == 3
# Test date filtering
now = datetime.utcnow()
future_date = now + timedelta(days=1)
past_date = now - timedelta(days=1)
date_filtered = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
start_date=past_date,
end_date=future_date
actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date
)
assert len(date_filtered) == 5
@@ -1049,7 +1027,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
# Create passages with known embeddings
passages = []
# Create passages with different embeddings
test_passages = [
"I like red",
@@ -1058,7 +1036,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
]
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
for i, text in enumerate(test_passages):
embedding = embed_model.get_text_embedding(text)
if i % 2 == 0:
@@ -1067,7 +1045,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding
embedding=embedding,
)
else:
passage = PydanticPassage(
@@ -1075,14 +1053,14 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
organization_id=default_user.organization_id,
source_id=default_source.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding
embedding=embedding,
)
created_passage = server.passage_manager.create_passage(passage, default_user)
passages.append(created_passage)
# Query vector similar to "red" embedding
query_key = "What's my favorite color?"
# Test vector search with all passages
results = server.agent_manager.list_passages(
actor=default_user,
@@ -1091,7 +1069,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embed_query=True,
)
# Verify results are ordered by similarity
assert len(results) == 3
assert results[0].text == "I like red"
@@ -1105,9 +1083,9 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
query_text=query_key,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embed_query=True,
agent_only=True
agent_only=True,
)
# Verify agent-only results
assert len(agent_only_results) == 2
assert agent_only_results[0].text == "I like red"
@@ -1116,7 +1094,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
"""Test listing passages from a source without specifying an agent."""
# List passages by source_id without agent_id
source_passages = server.agent_manager.list_passages(
actor=default_user,
@@ -1180,6 +1158,7 @@ def test_list_organizations_pagination(server: SyncServer):
# Passage Manager Tests
# ======================================================================================================================
def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
"""Test creating a passage using agent_passage_fixture fixture"""
assert agent_passage_fixture.id is not None
@@ -1214,7 +1193,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
"""Test creating an agent passage."""
assert agent_passage_fixture is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"
# Try to create an invalid passage (with both agent_id and source_id)
with pytest.raises(AssertionError):
server.passage_manager.create_passage(
@@ -1226,7 +1205,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
embedding=[0.1] * 1024,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=default_user
actor=default_user,
)
@@ -1243,19 +1222,21 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas
assert retrieved.text == source_passage_fixture.text
def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
def test_passage_cascade_deletion(
server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent
):
"""Test that passages are deleted when their parent (agent or source) is deleted."""
# Verify passages exist
agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
assert agent_passage is not None
assert source_passage is not None
# Delete agent and verify its passages are deleted
server.agent_manager.delete_agent(sarah_agent.id, default_user)
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agentic_passages) == 0
# Delete source and verify its passages are deleted
server.source_manager.delete_source(default_source.id, default_user)
with pytest.raises(NoResultFound):
@@ -1320,7 +1301,6 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
assert print_tool.organization_id == default_organization.id
@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
data = print_tool.model_dump(exclude=["id"])
@@ -1481,6 +1461,16 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
assert len(tools) == 0
def test_upsert_base_tools(server: SyncServer, default_user):
tools = server.tool_manager.upsert_base_tools(actor=default_user)
expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS)
assert sorted([t.name for t in tools]) == expected_tool_names
# Call it again to make sure it doesn't create duplicates
tools = server.tool_manager.upsert_base_tools(actor=default_user)
assert sorted([t.name for t in tools]) == expected_tool_names
# ======================================================================================================================
# Message Manager Tests
# ======================================================================================================================
@@ -1889,6 +1879,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
# Source Manager Tests - Files
# ======================================================================================================================
def test_get_file_by_id(server: SyncServer, default_user, default_source):
"""Test retrieving a file by ID."""
file_metadata = PydanticFileMetadata(
@@ -1960,6 +1951,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
# SandboxConfigManager Tests - Sandbox Configs
# ======================================================================================================================
def test_create_or_update_sandbox_config(server: SyncServer, default_user):
sandbox_config_create = SandboxConfigCreate(
config=E2BSandboxConfig(),
@@ -2039,6 +2031,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
# SandboxConfigManager Tests - Environment Variables
# ======================================================================================================================
def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(
@@ -2111,6 +2104,7 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# JobManager Tests
# ======================================================================================================================
def test_create_job(server: SyncServer, default_user):
"""Test creating a job."""
job_data = PydanticJob(

View File

@@ -272,15 +272,15 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer
)
def test_add_base_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.add_base_tools.return_value = [add_integers_tool]
def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool]
response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})
assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == add_integers_tool.id
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(
mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with(
actor=mock_sync_server.user_manager.get_user_or_default.return_value
)