chore: Clean up upserting base tools (#2274)
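At the client level this commit is a rename with idempotent semantics: add_base_tools() becomes upsert_base_tools(), and calling it repeatedly should update the existing base tools rather than duplicate them. A minimal usage sketch, assuming a locally configured Letta installation (the no-argument create_client() call is an assumption for illustration):

    from letta import create_client

    client = create_client()  # assumes default local configuration

    # upsert is safe to repeat: a second call updates rather than duplicates
    tools = client.upsert_base_tools()
    tools_again = client.upsert_base_tools()
    assert sorted(t.name for t in tools) == sorted(t.name for t in tools_again)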
@@ -233,7 +233,7 @@ class AbstractClient(object):
def get_tool_id(self, name: str) -> Optional[str]:
raise NotImplementedError

def add_base_tools(self) -> List[Tool]:
def upsert_base_tools(self) -> List[Tool]:
raise NotImplementedError

def load_data(self, connector: DataConnector, source_name: str):

@@ -1466,7 +1466,7 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to get tool: {response.text}")
return response.json()

def add_base_tools(self) -> List[Tool]:
def upsert_base_tools(self) -> List[Tool]:
response = requests.post(f"{self.base_url}/{self.api_prefix}/tools/add-base-tools/", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to add base tools: {response.text}")

@@ -61,60 +61,6 @@ def conversation_search(self: "Agent", query: str, page: Optional[int] = 0) -> O
return results_str

def conversation_search_date(self: "Agent", start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using a date range.

Args:
start_date (str): The start of the date range to search, in the format 'YYYY-MM-DD'.
end_date (str): The end of the date range to search, in the format 'YYYY-MM-DD'.
page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).

Returns:
str: Query result string
"""
import math
from datetime import datetime

from letta.constants import RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
from letta.utils import json_dumps

if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
page = 0
try:
page = int(page)
if page < 0:
raise ValueError
except:
raise ValueError(f"'page' argument must be an integer")

# Convert date strings to datetime objects
try:
start_datetime = datetime.strptime(start_date, "%Y-%m-%d").replace(hour=0, minute=0, second=0, microsecond=0)
end_datetime = datetime.strptime(end_date, "%Y-%m-%d").replace(hour=23, minute=59, second=59, microsecond=999999)
except ValueError:
raise ValueError("Dates must be in the format 'YYYY-MM-DD'")

count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
results = self.message_manager.list_user_messages_for_agent(
# TODO: add paging by page number. currently cursor only works with strings.
agent_id=self.agent_state.id,
actor=self.user,
start_date=start_datetime,
end_date=end_datetime,
limit=count,
)
total = len(results)
num_pages = math.ceil(total / count) - 1 # 0 index
if len(results) == 0:
results_str = f"No results found."
else:
results_pref = f"Showing {len(results)} of {total} results (page {page}/{num_pages}):"
results_formatted = [f"timestamp: {d['timestamp']}, {d['message']['role']} - {d['message']['content']}" for d in results]
results_str = f"{results_pref} {json_dumps(results_formatted)}"
return results_str

def archival_memory_insert(self: "Agent", content: str) -> Optional[str]:
"""
Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.

@@ -152,30 +152,15 @@ def update_tool(

@router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools")
def add_base_tools(
def upsert_base_tools(
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Add base tools
Upsert base tools
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
return server.tool_manager.add_base_tools(actor=actor)

# NOTE: can re-enable if needed
# @router.post("/{tool_id}/run", response_model=FunctionReturn, operation_id="run_tool")
# def run_tool(
# server: SyncServer = Depends(get_letta_server),
# request: ToolRun = Body(...),
# user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
# ):
# """
# Run an existing tool on provided arguments
# """
# actor = server.user_manager.get_user_or_default(user_id=user_id)

# return server.run_tool(tool_id=request.tool_id, tool_args=request.tool_args, user_id=actor.id)
return server.tool_manager.upsert_base_tools(actor=actor)

@router.post("/run", response_model=FunctionReturn, operation_id="run_tool_from_source")
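The HTTP surface itself does not change: the route is still POST /v1/tools/add-base-tools and still returns the list of base tools; only the handler and manager method names move to upsert_base_tools. A rough sketch of calling the endpoint directly, where the base URL, port, and user_id header value are assumptions for illustration:

    import requests

    resp = requests.post(
        "http://localhost:8283/v1/tools/add-base-tools/",  # assumed local server address
        headers={"user_id": "test_user"},  # optional; the server falls back to the default user
    )
    if resp.status_code != 200:
        raise ValueError(f"Failed to add base tools: {resp.text}")
    base_tools = resp.json()  # JSON list of Tool objects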
@@ -301,7 +301,7 @@ class SyncServer(Server):
self.default_org = self.organization_manager.create_default_organization()
self.default_user = self.user_manager.create_default_user()
self.block_manager.add_default_blocks(actor=self.default_user)
self.tool_manager.add_base_tools(actor=self.default_user)
self.tool_manager.upsert_base_tools(actor=self.default_user)

# If there is a default org/user
# This logic may have to change in the future

@@ -1,18 +1,18 @@
from typing import Dict, List, Optional
from datetime import datetime
import numpy as np
from typing import Dict, List, Optional

from sqlalchemy import select, union_all, literal, func, Select
import numpy as np
from sqlalchemy import Select, func, literal, select, union_all

from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import AgentPassage
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import SourcePassage, SourcesAgents
from letta.orm import Tool as ToolModel
from letta.orm import AgentPassage, SourcePassage
from letta.orm import SourcesAgents
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState

@@ -77,6 +77,8 @@ class AgentManager:
tool_names.extend(BASE_TOOLS + BASE_MEMORY_TOOLS)
if agent_create.tools:
tool_names.extend(agent_create.tools)
# Remove duplicates
tool_names = list(set(tool_names))

tool_ids = agent_create.tool_ids or []
for tool_name in tool_names:

@@ -431,7 +433,7 @@ class AgentManager:
agent_only: bool = False,
) -> Select:
"""Helper function to build the base passage query with all filters applied.

Returns the query before any limit or count operations are applied.
"""
embedded_text = None

@@ -448,21 +450,14 @@ class AgentManager:
if not agent_only: # Include source passages
if agent_id is not None:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
select(SourcePassage, literal(None).label("agent_id"))
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
else:
source_passages = (
select(
SourcePassage,
literal(None).label('agent_id')
)
.where(SourcePassage.organization_id == actor.organization_id)
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
SourcePassage.organization_id == actor.organization_id
)

if source_id:

@@ -486,9 +481,9 @@ class AgentManager:
AgentPassage._created_by_id,
AgentPassage._last_updated_by_id,
AgentPassage.organization_id,
literal(None).label('file_id'),
literal(None).label('source_id'),
AgentPassage.agent_id
literal(None).label("file_id"),
literal(None).label("source_id"),
AgentPassage.agent_id,
)
.where(AgentPassage.agent_id == agent_id)
.where(AgentPassage.organization_id == actor.organization_id)

@@ -496,11 +491,11 @@ class AgentManager:

# Combine queries
if source_passages is not None and agent_passages is not None:
combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
combined_query = union_all(source_passages, agent_passages).cte("combined_passages")
elif agent_passages is not None:
combined_query = agent_passages.cte('combined_passages')
combined_query = agent_passages.cte("combined_passages")
elif source_passages is not None:
combined_query = source_passages.cte('combined_passages')
combined_query = source_passages.cte("combined_passages")
else:
raise ValueError("No passages found")

@@ -521,9 +516,7 @@ class AgentManager:
if embedded_text:
if settings.letta_pg_uri_no_default:
# PostgreSQL with pgvector
main_query = main_query.order_by(
combined_query.c.embedding.cosine_distance(embedded_text).asc()
)
main_query = main_query.order_by(combined_query.c.embedding.cosine_distance(embedded_text).asc())
else:
# SQLite with custom vector type
query_embedding_binary = adapt_array(embedded_text)

@@ -531,13 +524,13 @@ class AgentManager:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.asc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
main_query = main_query.order_by(
func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
combined_query.c.created_at.desc(),
combined_query.c.id.asc()
combined_query.c.id.asc(),
)
else:
if query_text:

@@ -545,18 +538,12 @@ class AgentManager:

# Handle cursor-based pagination
if cursor:
cursor_query = select(combined_query.c.created_at).where(
combined_query.c.id == cursor
).scalar_subquery()

cursor_query = select(combined_query.c.created_at).where(combined_query.c.id == cursor).scalar_subquery()

if ascending:
main_query = main_query.where(
combined_query.c.created_at > cursor_query
)
main_query = main_query.where(combined_query.c.created_at > cursor_query)
else:
main_query = main_query.where(
combined_query.c.created_at < cursor_query
)
main_query = main_query.where(combined_query.c.created_at < cursor_query)

# Add ordering if not already ordered by similarity
if not embed_query:
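The cursor handling above is what the manager-level pagination relies on: the caller passes the id of the last row it saw, and the query resumes strictly after (or before) that row's created_at. A short sketch of paging through an agent's passages; the server, default_user, and sarah_agent names mirror the test fixtures further down in this commit:

    # first page, oldest first
    first_page = server.agent_manager.list_passages(
        actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True
    )

    # resume after the last row of the previous page by passing its id as the cursor
    second_page = server.agent_manager.list_passages(
        actor=default_user, agent_id=sarah_agent.id, cursor=first_page[-1].id, limit=2, ascending=True
    )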
@@ -588,7 +575,7 @@ class AgentManager:
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> List[PydanticPassage]:
"""Lists all passages attached to an agent."""
with self.session_maker() as session:

@@ -617,19 +604,18 @@ class AgentManager:
passages = []
for row in results:
data = dict(row._mapping)
if data['agent_id'] is not None:
if data["agent_id"] is not None:
# This is an AgentPassage - remove source fields
data.pop('source_id', None)
data.pop('file_id', None)
data.pop("source_id", None)
data.pop("file_id", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
data.pop('agent_id', None)
data.pop("agent_id", None)
passage = SourcePassage(**data)
passages.append(passage)

return [p.to_pydantic() for p in passages]

return [p.to_pydantic() for p in passages]

@enforce_types
def passage_size(

@@ -645,7 +631,7 @@ class AgentManager:
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,
agent_only: bool = False
agent_only: bool = False,
) -> int:
"""Returns the count of passages matching the given criteria."""
with self.session_maker() as session:

@@ -663,7 +649,7 @@ class AgentManager:
embedding_config=embedding_config,
agent_only=agent_only,
)

# Convert to count query
count_query = select(func.count()).select_from(main_query.subquery())
return session.scalar(count_query) or 0

@@ -3,6 +3,7 @@ import inspect
import warnings
from typing import List, Optional

from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.functions.functions import derive_openai_json_schema, load_function_set

# TODO: Remove this once we translate all of these to the ORM

@@ -20,7 +21,6 @@ class ToolManager:
BASE_TOOL_NAMES = [
"send_message",
"conversation_search",
"conversation_search_date",
"archival_memory_insert",
"archival_memory_search",
]

@@ -133,7 +133,7 @@ class ToolManager:
raise ValueError(f"Tool with id {tool_id} not found.")

@enforce_types
def add_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
def upsert_base_tools(self, actor: PydanticUser) -> List[PydanticTool]:
"""Add default tools in base.py"""
module_name = "base"
full_module_name = f"letta.functions.function_sets.{module_name}"

@@ -154,7 +154,7 @@ class ToolManager:
# create tool in db
tools = []
for name, schema in functions_to_schema.items():
if name in self.BASE_TOOL_NAMES + self.BASE_MEMORY_TOOL_NAMES:
if name in BASE_TOOLS + BASE_MEMORY_TOOLS:
# print([str(inspect.getsource(line)) for line in schema["imports"]])
source_code = inspect.getsource(schema["python_function"])
tags = [module_name]
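Taken together, upsert_base_tools loads the "base" function set, derives a schema per function, keeps only the names listed in BASE_TOOLS + BASE_MEMORY_TOOLS, and writes each one back through the manager, so repeated calls do not create duplicates (the test near the end of this commit checks exactly that). A manager-level usage sketch, assuming a configured Letta database and that the no-argument manager constructors work as in the migration script below:

    from letta.services.tool_manager import ToolManager
    from letta.services.user_manager import UserManager

    actor = UserManager().get_user_or_default(user_id=None)  # assumption: resolve the default user
    tools = ToolManager().upsert_base_tools(actor=actor)
    print(sorted(t.name for t in tools))

    # repeating the call updates existing base tools in place instead of duplicating them
    ToolManager().upsert_base_tools(actor=actor)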
@@ -1,5 +1,5 @@
from letta.functions.functions import parse_source_code
from letta.schemas.tool import Tool
from tqdm import tqdm

from letta.schemas.user import User
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager

@@ -10,33 +10,8 @@ def deprecated_tool():

orgs = OrganizationManager().list_organizations(cursor=None, limit=5000)
for org in orgs:
for org in tqdm(orgs):
if org.name != "default":
fake_user = User(id="user-00000000-0000-4000-8000-000000000000", name="fake", organization_id=org.id)

ToolManager().add_base_tools(actor=fake_user)

source_code = parse_source_code(deprecated_tool)
source_type = "python"
description = "deprecated"
tags = ["deprecated"]

ToolManager().create_or_update_tool(
Tool(
name="core_memory_append",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)

ToolManager().create_or_update_tool(
Tool(
name="core_memory_replace",
source_code=source_code,
source_type=source_type,
description=description,
),
actor=fake_user,
)
ToolManager().upsert_base_tools(actor=fake_user)

@@ -11,7 +11,7 @@ from sqlalchemy import delete

from letta import create_client
from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, DEFAULT_PRESET
from letta.orm import FileMetadata, Source
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig

@@ -30,7 +30,6 @@ from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import MessageCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.services.organization_manager import OrganizationManager
from letta.services.tool_manager import ToolManager
from letta.services.user_manager import UserManager
from letta.settings import model_settings
from tests.helpers.client_helper import upload_file_using_client

@@ -336,9 +335,9 @@ def test_list_tools_pagination(client: Union[LocalClient, RESTClient]):

def test_list_tools(client: Union[LocalClient, RESTClient]):
tools = client.add_base_tools()
tools = client.upsert_base_tools()
tool_names = [t.name for t in tools]
expected = ToolManager.BASE_TOOL_NAMES + ToolManager.BASE_MEMORY_TOOL_NAMES
expected = BASE_TOOLS + BASE_MEMORY_TOOLS
assert sorted(tool_names) == sorted(expected)

@@ -2,28 +2,27 @@ import os
import time
from datetime import datetime, timedelta

from httpx._transports import default
from numpy import source
import pytest
from sqlalchemy import delete
from sqlalchemy.exc import IntegrityError

from letta.config import LettaConfig
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from letta.embeddings import embedding_model
from letta.functions.functions import derive_openai_json_schema, parse_source_code
from letta.orm import (
Agent,
AgentPassage,
Block,
BlocksAgents,
FileMetadata,
Job,
Message,
Organization,
AgentPassage,
SourcePassage,
SandboxConfig,
SandboxEnvironmentVariable,
Source,
SourcePassage,
SourcesAgents,
Tool,
ToolsAgents,

@@ -202,9 +201,9 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
yield passage

@@ -220,9 +219,9 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
yield passage

@@ -240,9 +239,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
passages.append(passage)
if USING_SQLITE:

@@ -258,9 +257,9 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
organization_id=default_user.organization_id,
embedding=[0.1],
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata_={"type": "test"}
metadata_={"type": "test"},
),
actor=default_user
actor=default_user,
)
passages.append(passage)
if USING_SQLITE:

@@ -452,7 +451,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
embedding=[0.1], # Default OpenAI embedding size
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=actor
actor=actor,
)
source_passages.append(passage)

@@ -467,7 +466,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
embedding=[0.1], # Default OpenAI embedding size
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=actor
actor=actor,
)
agent_passages.append(passage)

@@ -476,6 +475,7 @@ def agent_passages_setup(server, default_source, default_user, sarah_agent):
# Cleanup
server.source_manager.delete_source(default_source.id, actor=actor)

# ======================================================================================================================
# AgentManager Tests - Basic
# ======================================================================================================================

@@ -940,32 +940,33 @@ def test_get_block_with_label(server: SyncServer, sarah_agent, default_block, de
# Agent Manager - Passages Tests
# ======================================================================================================================

def test_agent_list_passages_basic(server, default_user, sarah_agent, agent_passages_setup):
"""Test basic listing functionality of agent passages"""

all_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id)
assert len(all_passages) == 5 # 3 source + 2 agent passages

def test_agent_list_passages_ordering(server, default_user, sarah_agent, agent_passages_setup):
"""Test ordering of agent passages"""
"""Test ordering of agent passages"""

# Test ascending order
asc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=True)
assert len(asc_passages) == 5
for i in range(1, len(asc_passages)):
assert asc_passages[i-1].created_at <= asc_passages[i].created_at
assert asc_passages[i - 1].created_at <= asc_passages[i].created_at

# Test descending order
desc_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, ascending=False)
assert len(desc_passages) == 5
for i in range(1, len(desc_passages)):
assert desc_passages[i-1].created_at >= desc_passages[i].created_at
assert desc_passages[i - 1].created_at >= desc_passages[i].created_at

def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent_passages_setup):
"""Test pagination of agent passages"""

# Test limit
limited_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=3)
assert len(limited_passages) == 3

@@ -973,13 +974,9 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent
# Test cursor-based pagination
first_page = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, limit=2, ascending=True)
assert len(first_page) == 2

second_page = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
cursor=first_page[-1].id,
limit=2,
ascending=True
actor=default_user, agent_id=sarah_agent.id, cursor=first_page[-1].id, limit=2, ascending=True
)
assert len(second_page) == 2
assert first_page[-1].id != second_page[0].id

@@ -988,57 +985,38 @@ def test_agent_list_passages_pagination(server, default_user, sarah_agent, agent

def test_agent_list_passages_text_search(server, default_user, sarah_agent, agent_passages_setup):
"""Test text search functionality of agent passages"""

# Test text search for source passages
source_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
query_text="Source passage"
)
source_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Source passage")
assert len(source_text_passages) == 3

# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
query_text="Agent passage"
)
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, query_text="Agent passage")
assert len(agent_text_passages) == 2

def test_agent_list_passages_agent_only(server, default_user, sarah_agent, agent_passages_setup):
"""Test text search functionality of agent passages"""

# Test text search for agent passages
agent_text_passages = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
agent_only=True
)
agent_text_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agent_text_passages) == 2

def test_agent_list_passages_filtering(server, default_user, sarah_agent, default_source, agent_passages_setup):
"""Test filtering functionality of agent passages"""

# Test source filtering
source_filtered = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
source_id=default_source.id
)
source_filtered = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, source_id=default_source.id)
assert len(source_filtered) == 3

# Test date filtering
now = datetime.utcnow()
future_date = now + timedelta(days=1)
past_date = now - timedelta(days=1)

date_filtered = server.agent_manager.list_passages(
actor=default_user,
agent_id=sarah_agent.id,
start_date=past_date,
end_date=future_date
actor=default_user, agent_id=sarah_agent.id, start_date=past_date, end_date=future_date
)
assert len(date_filtered) == 5

@@ -1049,7 +1027,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de

# Create passages with known embeddings
passages = []

# Create passages with different embeddings
test_passages = [
"I like red",

@@ -1058,7 +1036,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
]

server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)

for i, text in enumerate(test_passages):
embedding = embed_model.get_text_embedding(text)
if i % 2 == 0:

@@ -1067,7 +1045,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
organization_id=default_user.organization_id,
agent_id=sarah_agent.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding
embedding=embedding,
)
else:
passage = PydanticPassage(

@@ -1075,14 +1053,14 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
organization_id=default_user.organization_id,
source_id=default_source.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding
embedding=embedding,
)
created_passage = server.passage_manager.create_passage(passage, default_user)
passages.append(created_passage)

# Query vector similar to "red" embedding
query_key = "What's my favorite color?"

# Test vector search with all passages
results = server.agent_manager.list_passages(
actor=default_user,

@@ -1091,7 +1069,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embed_query=True,
)

# Verify results are ordered by similarity
assert len(results) == 3
assert results[0].text == "I like red"

@@ -1105,9 +1083,9 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de
query_text=query_key,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embed_query=True,
agent_only=True
agent_only=True,
)

# Verify agent-only results
assert len(agent_only_results) == 2
assert agent_only_results[0].text == "I like red"

@@ -1116,7 +1094,7 @@ def test_agent_list_passages_vector_search(server, default_user, sarah_agent, de

def test_list_source_passages_only(server: SyncServer, default_user, default_source, agent_passages_setup):
"""Test listing passages from a source without specifying an agent."""

# List passages by source_id without agent_id
source_passages = server.agent_manager.list_passages(
actor=default_user,

@@ -1180,6 +1158,7 @@ def test_list_organizations_pagination(server: SyncServer):
# Passage Manager Tests
# ======================================================================================================================

def test_passage_create_agentic(server: SyncServer, agent_passage_fixture, default_user):
"""Test creating a passage using agent_passage_fixture fixture"""
assert agent_passage_fixture.id is not None

@@ -1214,7 +1193,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
"""Test creating an agent passage."""
assert agent_passage_fixture is not None
assert agent_passage_fixture.text == "Hello, I am an agent passage"

# Try to create an invalid passage (with both agent_id and source_id)
with pytest.raises(AssertionError):
server.passage_manager.create_passage(

@@ -1226,7 +1205,7 @@ def test_passage_create_invalid(server: SyncServer, agent_passage_fixture, defau
embedding=[0.1] * 1024,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
actor=default_user
actor=default_user,
)

@@ -1243,19 +1222,21 @@ def test_passage_get_by_id(server: SyncServer, agent_passage_fixture, source_pas
assert retrieved.text == source_passage_fixture.text

def test_passage_cascade_deletion(server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent):
def test_passage_cascade_deletion(
server: SyncServer, agent_passage_fixture, source_passage_fixture, default_user, default_source, sarah_agent
):
"""Test that passages are deleted when their parent (agent or source) is deleted."""
# Verify passages exist
agent_passage = server.passage_manager.get_passage_by_id(agent_passage_fixture.id, default_user)
source_passage = server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
assert agent_passage is not None
assert source_passage is not None

# Delete agent and verify its passages are deleted
server.agent_manager.delete_agent(sarah_agent.id, default_user)
agentic_passages = server.agent_manager.list_passages(actor=default_user, agent_id=sarah_agent.id, agent_only=True)
assert len(agentic_passages) == 0

# Delete source and verify its passages are deleted
server.source_manager.delete_source(default_source.id, default_user)
with pytest.raises(NoResultFound):

@@ -1320,7 +1301,6 @@ def test_create_tool(server: SyncServer, print_tool, default_user, default_organ
assert print_tool.organization_id == default_organization.id

@pytest.mark.skipif(USING_SQLITE, reason="Test not applicable when using SQLite.")
def test_create_tool_duplicate_name(server: SyncServer, print_tool, default_user, default_organization):
data = print_tool.model_dump(exclude=["id"])

@@ -1481,6 +1461,16 @@ def test_delete_tool_by_id(server: SyncServer, print_tool, default_user):
assert len(tools) == 0

def test_upsert_base_tools(server: SyncServer, default_user):
tools = server.tool_manager.upsert_base_tools(actor=default_user)
expected_tool_names = sorted(BASE_TOOLS + BASE_MEMORY_TOOLS)
assert sorted([t.name for t in tools]) == expected_tool_names

# Call it again to make sure it doesn't create duplicates
tools = server.tool_manager.upsert_base_tools(actor=default_user)
assert sorted([t.name for t in tools]) == expected_tool_names

# ======================================================================================================================
# Message Manager Tests
# ======================================================================================================================

@@ -1889,6 +1879,7 @@ def test_update_source_no_changes(server: SyncServer, default_user):
# Source Manager Tests - Files
# ======================================================================================================================

def test_get_file_by_id(server: SyncServer, default_user, default_source):
"""Test retrieving a file by ID."""
file_metadata = PydanticFileMetadata(

@@ -1960,6 +1951,7 @@ def test_delete_file(server: SyncServer, default_user, default_source):
# SandboxConfigManager Tests - Sandbox Configs
# ======================================================================================================================

def test_create_or_update_sandbox_config(server: SyncServer, default_user):
sandbox_config_create = SandboxConfigCreate(
config=E2BSandboxConfig(),

@@ -2039,6 +2031,7 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
# SandboxConfigManager Tests - Environment Variables
# ======================================================================================================================

def test_create_sandbox_env_var(server: SyncServer, sandbox_config_fixture, default_user):
env_var_create = SandboxEnvironmentVariableCreate(key="TEST_VAR", value="test_value", description="A test environment variable.")
created_env_var = server.sandbox_config_manager.create_sandbox_env_var(

@@ -2111,6 +2104,7 @@ def test_get_sandbox_env_var_by_key(server: SyncServer, sandbox_env_var_fixture,
# JobManager Tests
# ======================================================================================================================

def test_create_job(server: SyncServer, default_user):
"""Test creating a job."""
job_data = PydanticJob(

@@ -272,15 +272,15 @@ def test_update_tool(client, mock_sync_server, update_integers_tool, add_integer
)

def test_add_base_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.add_base_tools.return_value = [add_integers_tool]
def test_upsert_base_tools(client, mock_sync_server, add_integers_tool):
mock_sync_server.tool_manager.upsert_base_tools.return_value = [add_integers_tool]

response = client.post("/v1/tools/add-base-tools", headers={"user_id": "test_user"})

assert response.status_code == 200
assert len(response.json()) == 1
assert response.json()[0]["id"] == add_integers_tool.id
mock_sync_server.tool_manager.add_base_tools.assert_called_once_with(
mock_sync_server.tool_manager.upsert_base_tools.assert_called_once_with(
actor=mock_sync_server.user_manager.get_user_or_default.return_value
)