feat(asyncify): migrate sources to async (#2332)

This commit is contained in:
Sarah Wooders
2025-05-22 19:39:44 -07:00
committed by GitHub
parent 658594c4c8
commit 1d2f8d86e9
11 changed files with 350 additions and 456 deletions

View File

@@ -37,7 +37,9 @@ class DataConnector:
"""
def load_data(connector: DataConnector, source: Source, passage_manager: PassageManager, source_manager: SourceManager, actor: "User"):
async def load_data(
connector: DataConnector, source: Source, passage_manager: PassageManager, source_manager: SourceManager, actor: "User"
):
"""Load data from a connector (generates file and passages) into a specified source_id, associated with a user_id."""
embedding_config = source.embedding_config
@@ -51,7 +53,7 @@ def load_data(connector: DataConnector, source: Source, passage_manager: Passage
file_count = 0
for file_metadata in connector.find_files(source):
file_count += 1
source_manager.create_file(file_metadata, actor)
await source_manager.create_file(file_metadata, actor)
# generate passages
for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):

View File

@@ -13,7 +13,6 @@ from sqlalchemy.orm import sessionmaker
from letta.config import LettaConfig
from letta.log import get_logger
from letta.settings import settings
from letta.tracing import trace_method
logger = get_logger(__name__)
@@ -203,7 +202,6 @@ class DatabaseRegistry:
self.initialize_async()
return self._async_session_factories.get(name)
@trace_method
@contextmanager
def session(self, name: str = "default") -> Generator[Any, None, None]:
"""Context manager for database sessions."""
@@ -217,7 +215,6 @@ class DatabaseRegistry:
finally:
session.close()
@trace_method
@asynccontextmanager
async def async_session(self, name: str = "default") -> AsyncGenerator[AsyncSession, None]:
"""Async context manager for database sessions."""

View File

@@ -297,7 +297,7 @@ def detach_tool(
@router.patch("/{agent_id}/sources/attach/{source_id}", response_model=AgentState, operation_id="attach_source_to_agent")
def attach_source(
async def attach_source(
agent_id: str,
source_id: str,
background_tasks: BackgroundTasks,
@@ -310,7 +310,7 @@ def attach_source(
actor = server.user_manager.get_user_or_default(user_id=actor_id)
agent = server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor)
if agent.enable_sleeptime:
source = server.source_manager.get_source_by_id(source_id=source_id)
source = await server.source_manager.get_source_by_id_async(source_id=source_id)
background_tasks.add_task(server.sleeptime_document_ingest, agent, source, actor)
return agent

View File

@@ -1,3 +1,4 @@
import asyncio
import os
import tempfile
from typing import List, Optional
@@ -21,18 +22,18 @@ router = APIRouter(prefix="/sources", tags=["sources"])
@router.get("/count", response_model=int, operation_id="count_sources")
def count_sources(
async def count_sources(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
"""
Count all data sources created by a user.
"""
return server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
return await server.source_manager.size(actor=server.user_manager.get_user_or_default(user_id=actor_id))
@router.get("/{source_id}", response_model=Source, operation_id="retrieve_source")
def retrieve_source(
async def retrieve_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -42,14 +43,14 @@ def retrieve_source(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
if not source:
raise HTTPException(status_code=404, detail=f"Source with id={source_id} not found.")
return source
@router.get("/name/{source_name}", response_model=str, operation_id="get_source_id_by_name")
def get_source_id_by_name(
async def get_source_id_by_name(
source_name: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -59,14 +60,14 @@ def get_source_id_by_name(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
source = await server.source_manager.get_source_by_name(source_name=source_name, actor=actor)
if not source:
raise HTTPException(status_code=404, detail=f"Source with name={source_name} not found.")
return source.id
@router.get("/", response_model=List[Source], operation_id="list_sources")
def list_sources(
async def list_sources(
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@@ -74,12 +75,11 @@ def list_sources(
List all data sources created by a user.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.list_all_sources(actor=actor)
return await server.source_manager.list_sources(actor=actor)
@router.post("/", response_model=Source, operation_id="create_source")
def create_source(
async def create_source(
source_create: SourceCreate,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -88,6 +88,8 @@ def create_source(
Create a new data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
# TODO: need to asyncify this
if not source_create.embedding_config:
if not source_create.embedding:
# TODO: modify error type
@@ -104,11 +106,11 @@ def create_source(
instructions=source_create.instructions,
metadata=source_create.metadata,
)
return server.source_manager.create_source(source=source, actor=actor)
return await server.source_manager.create_source(source=source, actor=actor)
@router.patch("/{source_id}", response_model=Source, operation_id="modify_source")
def modify_source(
async def modify_source(
source_id: str,
source: SourceUpdate,
server: "SyncServer" = Depends(get_letta_server),
@@ -119,13 +121,13 @@ def modify_source(
"""
# TODO: allow updating the handle/embedding config
actor = server.user_manager.get_user_or_default(user_id=actor_id)
if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
if not await server.source_manager.get_source_by_id(source_id=source_id, actor=actor):
raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.")
return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
return await server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor)
@router.delete("/{source_id}", response_model=None, operation_id="delete_source")
def delete_source(
async def delete_source(
source_id: str,
server: "SyncServer" = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
@@ -134,20 +136,21 @@ def delete_source(
Delete a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id)
agents = server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
source = await server.source_manager.get_source_by_id(source_id=source_id)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:
try:
# TODO: make async
block = server.agent_manager.get_block_with_label(agent_id=agent.id, block_label=source.name, actor=actor)
server.block_manager.delete_block(block.id, actor)
except:
pass
server.delete_source(source_id=source_id, actor=actor)
await server.delete_source(source_id=source_id, actor=actor)
@router.post("/{source_id}/upload", response_model=Job, operation_id="upload_file_to_source")
def upload_file_to_source(
async def upload_file_to_source(
file: UploadFile,
source_id: str,
background_tasks: BackgroundTasks,
@@ -159,7 +162,7 @@ def upload_file_to_source(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
source = await server.source_manager.get_source_by_id(source_id=source_id, actor=actor)
assert source is not None, f"Source with id={source_id} not found."
bytes = file.file.read()
@@ -173,8 +176,8 @@ def upload_file_to_source(
server.job_manager.create_job(job, actor=actor)
# create background tasks
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor)
background_tasks.add_task(sleeptime_document_ingest_async, server, source_id, actor)
asyncio.create_task(load_file_to_source_async(server, source_id=source.id, file=file, job_id=job.id, bytes=bytes, actor=actor))
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor))
# return job information
# Is this necessary? Can we just return the job from create_job?
@@ -184,8 +187,11 @@ def upload_file_to_source(
@router.get("/{source_id}/passages", response_model=List[Passage], operation_id="list_source_passages")
def list_source_passages(
async def list_source_passages(
source_id: str,
after: Optional[str] = Query(None, description="Message after which to retrieve the returned messages."),
before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."),
limit: int = Query(100, description="Maximum number of messages to retrieve."),
server: SyncServer = Depends(get_letta_server),
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
):
@@ -193,12 +199,17 @@ def list_source_passages(
List all passages associated with a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id)
return passages
return await server.agent_manager.list_passages_async(
actor=actor,
source_id=source_id,
after=after,
before=before,
limit=limit,
)
@router.get("/{source_id}/files", response_model=List[FileMetadata], operation_id="list_source_files")
def list_source_files(
async def list_source_files(
source_id: str,
limit: int = Query(1000, description="Number of files to return"),
after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"),
@@ -209,13 +220,13 @@ def list_source_files(
List paginated files associated with a data source.
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
return server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor)
return await server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor)
# it's redundant to include /delete in the URL path. The HTTP verb DELETE already implies that action.
# it's still good practice to return a status indicating the success or failure of the deletion
@router.delete("/{source_id}/{file_id}", status_code=204, operation_id="delete_file_from_source")
def delete_file_from_source(
async def delete_file_from_source(
source_id: str,
file_id: str,
background_tasks: BackgroundTasks,
@@ -227,13 +238,15 @@ def delete_file_from_source(
"""
actor = server.user_manager.get_user_or_default(user_id=actor_id)
deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor)
background_tasks.add_task(sleeptime_document_ingest_async, server, source_id, actor, clear_history=True)
deleted_file = await server.source_manager.delete_file(file_id=file_id, actor=actor)
# TODO: make async
asyncio.create_task(sleeptime_document_ingest_async(server, source_id, actor, clear_history=True))
if deleted_file is None:
raise HTTPException(status_code=404, detail=f"File with id={file_id} not found.")
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User):
async def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes, actor: User):
# Create a temporary directory (deleted after the context manager exits)
with tempfile.TemporaryDirectory() as tmpdirname:
# Sanitize the filename
@@ -245,12 +258,12 @@ def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, f
buffer.write(bytes)
# Pass the file to load_file_to_source
server.load_file_to_source(source_id, file_path, job_id, actor)
await server.load_file_to_source(source_id, file_path, job_id, actor)
def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = server.source_manager.get_source_by_id(source_id=source_id)
agents = server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
async def sleeptime_document_ingest_async(server: SyncServer, source_id: str, actor: User, clear_history: bool = False):
source = await server.source_manager.get_source_by_id(source_id=source_id)
agents = await server.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent in agents:
if agent.enable_sleeptime:
server.sleeptime_document_ingest(agent, source, actor, clear_history)
server.sleeptime_document_ingest(agent, source, actor, clear_history) # TODO: make async

View File

@@ -1169,17 +1169,20 @@ class SyncServer(Server):
# rebuild system prompt for agent, potentially changed
return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory
def delete_source(self, source_id: str, actor: User):
async def delete_source(self, source_id: str, actor: User):
"""Delete a data source"""
self.source_manager.delete_source(source_id=source_id, actor=actor)
await self.source_manager.delete_source(source_id=source_id, actor=actor)
# delete data from passage store
# TODO: make async
passages_to_be_deleted = self.agent_manager.list_passages(actor=actor, source_id=source_id, limit=None)
# TODO: make this async
self.passage_manager.delete_passages(actor=actor, passages=passages_to_be_deleted)
# TODO: delete data from agent passage stores (?)
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
async def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
# update job
job = self.job_manager.get_job_by_id(job_id, actor=actor)
@@ -1189,21 +1192,22 @@ class SyncServer(Server):
# try:
from letta.data_sources.connectors import DirectoryConnector
source = self.source_manager.get_source_by_id(source_id=source_id)
# TODO: move this into a thread
source = await self.source_manager.get_source_by_id(source_id=source_id)
if source is None:
raise ValueError(f"Source {source_id} does not exist")
connector = DirectoryConnector(input_files=[file_path])
num_passages, num_documents = self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
num_passages, num_documents = await self.load_data(user_id=source.created_by_id, source_name=source.name, connector=connector)
# update all agents who have this source attached
agent_states = self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
agent_states = await self.source_manager.list_attached_agents(source_id=source_id, actor=actor)
for agent_state in agent_states:
agent_id = agent_state.id
# Attach source to agent
curr_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
curr_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
agent_state = self.agent_manager.attach_source(agent_id=agent_state.id, source_id=source_id, actor=actor)
new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
new_passage_size = await self.agent_manager.passage_size_async(actor=actor, agent_id=agent_id)
assert new_passage_size >= curr_passage_size # in case empty files are added
# rebuild system prompt and force
@@ -1266,7 +1270,7 @@ class SyncServer(Server):
actor=actor,
)
def load_data(
async def load_data(
self,
user_id: str,
connector: DataConnector,
@@ -1277,12 +1281,12 @@ class SyncServer(Server):
# load data from a data source into the document store
user = self.user_manager.get_user_by_id(user_id=user_id)
source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
source = await self.source_manager.get_source_by_name(source_name=source_name, actor=user)
if source is None:
raise ValueError(f"Data source {source_name} does not exist for user {user_id}")
# load data into the document store
passage_count, document_count = load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
passage_count, document_count = await load_data(connector, source, self.passage_manager, self.source_manager, actor=user)
return passage_count, document_count
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
@@ -1290,6 +1294,7 @@ class SyncServer(Server):
return self.agent_manager.list_passages(actor=self.user_manager.get_user_or_default(user_id=user_id), source_id=source_id)
def list_all_sources(self, actor: User) -> List[Source]:
# TODO: legacy: remove
"""List all sources (w/ extra metadata) belonging to a user"""
sources = self.source_manager.list_sources(actor=actor)

View File

@@ -2127,6 +2127,44 @@ class AgentManager:
count_query = select(func.count()).select_from(main_query.subquery())
return session.scalar(count_query) or 0
@enforce_types
async def passage_size_async(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    before: Optional[str] = None,
    after: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> int:
    """Return the number of passages matching the given filters (async).

    Builds the same filtered passage query used for listing, then wraps it
    in a COUNT(*) so only the total is fetched, never the rows themselves.
    All filter parameters are optional and are forwarded unchanged to
    ``_build_passage_query``; an empty result set yields 0.
    """
    async with db_registry.async_session() as session:
        # Reuse the shared query builder so count semantics always match
        # the listing endpoints' filter semantics.
        filtered = self._build_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            before=before,
            after=after,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        # Wrap the filtered query as a subquery and count its rows.
        total_stmt = select(func.count()).select_from(filtered.subquery())
        result = await session.execute(total_stmt)
        # scalar() is None when the subquery is empty on some drivers;
        # normalize to 0 so callers can rely on an int.
        return result.scalar() or 0
# ======================================================================================================================
# Tool Management
# ======================================================================================================================

View File

@@ -1,3 +1,4 @@
import asyncio
from typing import List, Optional
from letta.orm.errors import NoResultFound
@@ -18,26 +19,26 @@ class SourceManager:
@enforce_types
@trace_method
def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
"""Create a new source based on the PydanticSource schema."""
# Try getting the source first by id
db_source = self.get_source_by_id(source.id, actor=actor)
db_source = await self.get_source_by_id(source.id, actor=actor)
if db_source:
return db_source
else:
with db_registry.session() as session:
async with db_registry.async_session() as session:
# Provide default embedding config if not given
source.organization_id = actor.organization_id
source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
source.create(session, actor=actor)
await source.create_async(session, actor=actor)
return source.to_pydantic()
@enforce_types
@trace_method
def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
"""Update a source by its ID with the given SourceUpdate object."""
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
async with db_registry.async_session() as session:
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
# get update dictionary
update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
@@ -57,19 +58,21 @@ class SourceManager:
@enforce_types
@trace_method
def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
"""Delete a source by its ID."""
with db_registry.session() as session:
source = SourceModel.read(db_session=session, identifier=source_id)
source.hard_delete(db_session=session, actor=actor)
async with db_registry.async_session() as session:
source = await SourceModel.read_async(db_session=session, identifier=source_id)
await source.hard_delete_async(db_session=session, actor=actor)
return source.to_pydantic()
@enforce_types
@trace_method
def list_sources(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs) -> List[PydanticSource]:
async def list_sources(
self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs
) -> List[PydanticSource]:
"""List all sources with optional pagination."""
with db_registry.session() as session:
sources = SourceModel.list(
async with db_registry.async_session() as session:
sources = await SourceModel.list_async(
db_session=session,
after=after,
limit=limit,
@@ -80,19 +83,16 @@ class SourceManager:
@enforce_types
@trace_method
def size(
self,
actor: PydanticUser,
) -> int:
async def size(self, actor: PydanticUser) -> int:
"""
Get the total count of sources for the given user.
"""
with db_registry.session() as session:
return SourceModel.size(db_session=session, actor=actor)
async with db_registry.async_session() as session:
return await SourceModel.size_async(db_session=session, actor=actor)
@enforce_types
@trace_method
def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
"""
Lists all agents that have the specified source attached.
@@ -103,32 +103,33 @@ class SourceManager:
Returns:
List[PydanticAgentState]: List of agents that have this source attached
"""
with db_registry.session() as session:
async with db_registry.async_session() as session:
# Verify source exists and user has permission to access it
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
# The agents relationship is already loaded due to lazy="selectin" in the Source model
# and will be properly filtered by organization_id due to the OrganizationMixin
return [agent.to_pydantic() for agent in source.agents]
agents_orm = source.agents
return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@trace_method
def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
"""Retrieve a source by its ID."""
with db_registry.session() as session:
async with db_registry.async_session() as session:
try:
source = SourceModel.read(db_session=session, identifier=source_id, actor=actor)
source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
return source.to_pydantic()
except NoResultFound:
return None
@enforce_types
@trace_method
def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
async def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
"""Retrieve a source by its name."""
with db_registry.session() as session:
sources = SourceModel.list(
async with db_registry.async_session() as session:
sources = await SourceModel.list_async(
db_session=session,
name=source_name,
organization_id=actor.organization_id,
@@ -141,47 +142,48 @@ class SourceManager:
@enforce_types
@trace_method
def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata:
async def create_file(self, file_metadata: PydanticFileMetadata, actor: PydanticUser) -> PydanticFileMetadata:
"""Create a new file based on the PydanticFileMetadata schema."""
db_file = self.get_file_by_id(file_metadata.id, actor=actor)
db_file = await self.get_file_by_id(file_metadata.id, actor=actor)
if db_file:
return db_file
else:
with db_registry.session() as session:
async with db_registry.async_session() as session:
file_metadata.organization_id = actor.organization_id
file_metadata = FileMetadataModel(**file_metadata.model_dump(to_orm=True, exclude_none=True))
file_metadata.create(session, actor=actor)
await file_metadata.create_async(session, actor=actor)
return file_metadata.to_pydantic()
# TODO: We make actor optional for now, but should most likely be enforced due to security reasons
@enforce_types
@trace_method
def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]:
async def get_file_by_id(self, file_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticFileMetadata]:
"""Retrieve a file by its ID."""
with db_registry.session() as session:
async with db_registry.async_session() as session:
try:
file = FileMetadataModel.read(db_session=session, identifier=file_id, actor=actor)
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id, actor=actor)
return file.to_pydantic()
except NoResultFound:
return None
@enforce_types
@trace_method
def list_files(
async def list_files(
self, source_id: str, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50
) -> List[PydanticFileMetadata]:
"""List all files with optional pagination."""
with db_registry.session() as session:
files = FileMetadataModel.list(
async with db_registry.async_session() as session:
files_all = await FileMetadataModel.list_async(db_session=session, organization_id=actor.organization_id, source_id=source_id)
files = await FileMetadataModel.list_async(
db_session=session, after=after, limit=limit, organization_id=actor.organization_id, source_id=source_id
)
return [file.to_pydantic() for file in files]
@enforce_types
@trace_method
def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
async def delete_file(self, file_id: str, actor: PydanticUser) -> PydanticFileMetadata:
"""Delete a file by its ID."""
with db_registry.session() as session:
file = FileMetadataModel.read(db_session=session, identifier=file_id)
file.hard_delete(db_session=session, actor=actor)
async with db_registry.async_session() as session:
file = await FileMetadataModel.read_async(db_session=session, identifier=file_id)
await file.hard_delete_async(db_session=session, actor=actor)
return file.to_pydantic()

View File

@@ -426,95 +426,6 @@ def test_load_file(client: RESTClient, agent: AgentState):
assert file.source_id == source.id
def test_sources(client: RESTClient, agent: AgentState):
# _reset_config()
# clear sources
for source in client.list_sources():
client.delete_source(source.id)
# clear jobs
for job in client.list_jobs():
client.delete_job(job.id)
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources) == 0
# create a source
source = client.create_source(name="test_source")
# list sources
sources = client.list_sources()
print("listed sources", sources)
assert len(sources) == 1
# TODO: add back?
assert sources[0].metadata["num_passages"] == 0
assert sources[0].metadata["num_documents"] == 0
# update the source
original_id = source.id
original_name = source.name
new_name = original_name + "_new"
client.update_source(source_id=source.id, name=new_name)
# get the source name (check that it's been updated)
source = client.get_source(source_id=source.id)
assert source.name == new_name
assert source.id == original_id
# get the source id (make sure that it's the same)
assert str(original_id) == client.get_source_id(source_name=new_name)
# check agent archival memory size
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0
# load a file into a source (non-blocking job)
filename = "tests/data/memgpt_paper.pdf"
upload_job = upload_file_using_client(client, source, filename)
job = client.get_job(upload_job.id)
created_passages = job.metadata["num_passages"]
# TODO: add test for blocking job
# TODO: make sure things run in the right order
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0
# attach a source
client.attach_source(source_id=source.id, agent_id=agent.id)
# list attached sources
attached_sources = client.list_attached_sources(agent_id=agent.id)
print("attached sources", attached_sources)
assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}"
# list archival memory
archival_memories = client.get_archival_memory(agent_id=agent.id)
# print(archival_memories)
assert len(archival_memories) == created_passages, f"Mismatched length {len(archival_memories)} vs. {created_passages}"
# check number of passages
sources = client.list_sources()
# TODO: add back?
# assert sources.sources[0].metadata["num_passages"] > 0
# assert sources.sources[0].metadata["num_documents"] == 0 # TODO: fix this once document store added
print(sources)
# detach the source
assert len(client.get_archival_memory(agent_id=agent.id)) > 0, "No archival memory"
client.detach_source(source_id=source.id, agent_id=agent.id)
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0, f"Failed to detach source: {len(archival_memories)}"
assert source.id not in [s.id for s in client.list_attached_sources(agent.id)]
# delete the source
client.delete_source(source.id)
def test_organization(client: RESTClient):
# create an organization
org_name = "test-org"

View File

@@ -140,32 +140,32 @@ async def other_user_different_org(server: SyncServer, other_organization):
@pytest.fixture
def default_source(server: SyncServer, default_user):
async def default_source(server: SyncServer, default_user):
source_pydantic = PydanticSource(
name="Test Source",
description="This is a test source.",
metadata={"type": "test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
yield source
@pytest.fixture
def other_source(server: SyncServer, default_user):
async def other_source(server: SyncServer, default_user):
source_pydantic = PydanticSource(
name="Another Test Source",
description="This is yet another test source.",
metadata={"type": "another_test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
yield source
@pytest.fixture
def default_file(server: SyncServer, default_source, default_user, default_organization):
file = server.source_manager.create_file(
async def default_file(server: SyncServer, default_source, default_user, default_organization):
file = await server.source_manager.create_file(
PydanticFileMetadata(file_name="test_file", organization_id=default_organization.id, source_id=default_source.id),
actor=default_user,
)
@@ -1175,17 +1175,18 @@ async def test_list_attached_source_ids_nonexistent_agent(server: SyncServer, de
await server.agent_manager.list_attached_sources_async(agent_id="nonexistent-agent-id", actor=default_user)
def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user):
@pytest.mark.asyncio
async def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, default_source, default_user, event_loop):
"""Test listing agents that have a particular source attached."""
# Initially should have no attached agents
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 0
# Attach source to first agent
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
# Verify one agent is now attached
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 1
assert sarah_agent.id in [a.id for a in attached_agents]
@@ -1193,7 +1194,7 @@ def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, de
server.agent_manager.attach_source(agent_id=charles_agent.id, source_id=default_source.id, actor=default_user)
# Verify both agents are now attached
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 2
attached_agent_ids = [a.id for a in attached_agents]
assert sarah_agent.id in attached_agent_ids
@@ -1203,15 +1204,16 @@ def test_list_attached_agents(server: SyncServer, sarah_agent, charles_agent, de
server.agent_manager.detach_source(agent_id=sarah_agent.id, source_id=default_source.id, actor=default_user)
# Verify only second agent remains attached
attached_agents = server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
attached_agents = await server.source_manager.list_attached_agents(source_id=default_source.id, actor=default_user)
assert len(attached_agents) == 1
assert charles_agent.id in [a.id for a in attached_agents]
def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_attached_agents_nonexistent_source(server: SyncServer, default_user):
"""Test listing agents for a nonexistent source."""
with pytest.raises(NoResultFound):
server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user)
await server.source_manager.list_attached_agents(source_id="nonexistent-source-id", actor=default_user)
# ======================================================================================================================
@@ -2137,7 +2139,7 @@ async def test_passage_cascade_deletion(
assert len(agentic_passages) == 0
# Delete source and verify its passages are deleted
server.source_manager.delete_source(default_source.id, default_user)
await server.source_manager.delete_source(default_source.id, default_user)
with pytest.raises(NoResultFound):
server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
@@ -3807,7 +3809,10 @@ async def test_upsert_properties(server: SyncServer, default_user, event_loop):
# ======================================================================================================================
# SourceManager Tests - Sources
# ======================================================================================================================
def test_create_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_source(server: SyncServer, default_user, event_loop):
"""Test creating a new source."""
source_pydantic = PydanticSource(
name="Test Source",
@@ -3815,7 +3820,7 @@ def test_create_source(server: SyncServer, default_user):
metadata={"type": "test"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Assertions to check the created source
assert source.name == source_pydantic.name
@@ -3824,7 +3829,8 @@ def test_create_source(server: SyncServer, default_user):
assert source.organization_id == default_user.organization_id
def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_create_sources_with_same_name_does_not_error(server: SyncServer, default_user):
"""Test creating a new source."""
name = "Test Source"
source_pydantic = PydanticSource(
@@ -3833,27 +3839,28 @@ def test_create_sources_with_same_name_does_not_error(server: SyncServer, defaul
metadata={"type": "medical"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
source_pydantic = PydanticSource(
name=name,
description="This is a different test source.",
metadata={"type": "legal"},
embedding_config=DEFAULT_EMBEDDING_CONFIG,
)
same_source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
same_source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
assert source.name == same_source.name
assert source.id != same_source.id
def test_update_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_update_source(server: SyncServer, default_user):
"""Test updating an existing source."""
source_pydantic = PydanticSource(name="Original Source", description="Original description", embedding_config=DEFAULT_EMBEDDING_CONFIG)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Update the source
update_data = SourceUpdate(name="Updated Source", description="Updated description", metadata={"type": "updated"})
updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
updated_source = await server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
# Assertions to verify update
assert updated_source.name == update_data.name
@@ -3861,21 +3868,22 @@ def test_update_source(server: SyncServer, default_user):
assert updated_source.metadata == update_data.metadata
def test_delete_source(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_delete_source(server: SyncServer, default_user):
"""Test deleting a source."""
source_pydantic = PydanticSource(
name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Delete the source
deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user)
deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user)
# Assertions to verify deletion
assert deleted_source.id == source.id
# Verify that the source no longer appears in list_sources
sources = server.source_manager.list_sources(actor=default_user)
sources = await server.source_manager.list_sources(actor=default_user)
assert len(sources) == 0
@@ -3885,18 +3893,18 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u
source_pydantic = PydanticSource(
name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user)
# Delete the source
deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user)
deleted_source = await server.source_manager.delete_source(source_id=source.id, actor=default_user)
# Assertions to verify deletion
assert deleted_source.id == source.id
# Verify that the source no longer appears in list_sources
sources = server.source_manager.list_sources(actor=default_user)
sources = await server.source_manager.list_sources(actor=default_user)
assert len(sources) == 0
# Verify that agent is not deleted
@@ -3904,37 +3912,43 @@ async def test_delete_attached_source(server: SyncServer, sarah_agent, default_u
assert agent is not None
def test_list_sources(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_list_sources(server: SyncServer, default_user):
"""Test listing sources with pagination."""
# Create multiple sources
server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
await server.source_manager.create_source(
PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user
)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
await server.source_manager.create_source(
PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user
)
# List sources without pagination
sources = server.source_manager.list_sources(actor=default_user)
sources = await server.source_manager.list_sources(actor=default_user)
assert len(sources) == 2
# List sources with pagination
paginated_sources = server.source_manager.list_sources(actor=default_user, limit=1)
paginated_sources = await server.source_manager.list_sources(actor=default_user, limit=1)
assert len(paginated_sources) == 1
# Ensure cursor-based pagination works
next_page = server.source_manager.list_sources(actor=default_user, after=paginated_sources[-1].id, limit=1)
next_page = await server.source_manager.list_sources(actor=default_user, after=paginated_sources[-1].id, limit=1)
assert len(next_page) == 1
assert next_page[0].name != paginated_sources[0].name
def test_get_source_by_id(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_get_source_by_id(server: SyncServer, default_user):
"""Test retrieving a source by ID."""
source_pydantic = PydanticSource(
name="Retrieve by ID", description="Test source for ID retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Retrieve the source by ID
retrieved_source = server.source_manager.get_source_by_id(source_id=source.id, actor=default_user)
retrieved_source = await server.source_manager.get_source_by_id(source_id=source.id, actor=default_user)
# Assertions to verify the retrieved source matches the created one
assert retrieved_source.id == source.id
@@ -3942,29 +3956,31 @@ def test_get_source_by_id(server: SyncServer, default_user):
assert retrieved_source.description == source.description
def test_get_source_by_name(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_get_source_by_name(server: SyncServer, default_user):
"""Test retrieving a source by name."""
source_pydantic = PydanticSource(
name="Unique Source", description="Test source for name retrieval", embedding_config=DEFAULT_EMBEDDING_CONFIG
)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Retrieve the source by name
retrieved_source = server.source_manager.get_source_by_name(source_name=source.name, actor=default_user)
retrieved_source = await server.source_manager.get_source_by_name(source_name=source.name, actor=default_user)
# Assertions to verify the retrieved source matches the created one
assert retrieved_source.name == source.name
assert retrieved_source.description == source.description
def test_update_source_no_changes(server: SyncServer, default_user):
@pytest.mark.asyncio
async def test_update_source_no_changes(server: SyncServer, default_user):
"""Test update_source with no actual changes to verify logging and response."""
source_pydantic = PydanticSource(name="No Change Source", description="No changes", embedding_config=DEFAULT_EMBEDDING_CONFIG)
source = server.source_manager.create_source(source=source_pydantic, actor=default_user)
source = await server.source_manager.create_source(source=source_pydantic, actor=default_user)
# Attempt to update the source with identical data
update_data = SourceUpdate(name="No Change Source", description="No changes")
updated_source = server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
updated_source = await server.source_manager.update_source(source_id=source.id, source_update=update_data, actor=default_user)
# Assertions to ensure the update returned the source but made no modifications
assert updated_source.id == source.id
@@ -3977,7 +3993,8 @@ def test_update_source_no_changes(server: SyncServer, default_user):
# ======================================================================================================================
def test_get_file_by_id(server: SyncServer, default_user, default_source):
@pytest.mark.asyncio
async def test_get_file_by_id(server: SyncServer, default_user, default_source):
"""Test retrieving a file by ID."""
file_metadata = PydanticFileMetadata(
file_name="Retrieve File",
@@ -3986,10 +4003,10 @@ def test_get_file_by_id(server: SyncServer, default_user, default_source):
file_size=2048,
source_id=default_source.id,
)
created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
created_file = await server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
# Retrieve the file by ID
retrieved_file = server.source_manager.get_file_by_id(file_id=created_file.id, actor=default_user)
retrieved_file = await server.source_manager.get_file_by_id(file_id=created_file.id, actor=default_user)
# Assertions to verify the retrieved file matches the created one
assert retrieved_file.id == created_file.id
@@ -3998,49 +4015,53 @@ def test_get_file_by_id(server: SyncServer, default_user, default_source):
assert retrieved_file.file_type == created_file.file_type
def test_list_files(server: SyncServer, default_user, default_source):
@pytest.mark.asyncio
async def test_list_files(server: SyncServer, default_user, default_source):
"""Test listing files with pagination."""
# Create multiple files
server.source_manager.create_file(
await server.source_manager.create_file(
PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_file(
await server.source_manager.create_file(
PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
)
# List files without pagination
files = server.source_manager.list_files(source_id=default_source.id, actor=default_user)
files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user)
assert len(files) == 2
# List files with pagination
paginated_files = server.source_manager.list_files(source_id=default_source.id, actor=default_user, limit=1)
paginated_files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user, limit=1)
assert len(paginated_files) == 1
# Ensure cursor-based pagination works
next_page = server.source_manager.list_files(source_id=default_source.id, actor=default_user, after=paginated_files[-1].id, limit=1)
next_page = await server.source_manager.list_files(
source_id=default_source.id, actor=default_user, after=paginated_files[-1].id, limit=1
)
assert len(next_page) == 1
assert next_page[0].file_name != paginated_files[0].file_name
def test_delete_file(server: SyncServer, default_user, default_source):
@pytest.mark.asyncio
async def test_delete_file(server: SyncServer, default_user, default_source):
"""Test deleting a file."""
file_metadata = PydanticFileMetadata(
file_name="Delete File", file_path="/path/to/delete_file.txt", file_type="text/plain", source_id=default_source.id
)
created_file = server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
created_file = await server.source_manager.create_file(file_metadata=file_metadata, actor=default_user)
# Delete the file
deleted_file = server.source_manager.delete_file(file_id=created_file.id, actor=default_user)
deleted_file = await server.source_manager.delete_file(file_id=created_file.id, actor=default_user)
# Assertions to verify deletion
assert deleted_file.id == created_file.id
# Verify that the file no longer appears in list_files
files = server.source_manager.list_files(source_id=default_source.id, actor=default_user)
files = await server.source_manager.list_files(source_id=default_source.id, actor=default_user)
assert len(files) == 0

View File

@@ -680,3 +680,72 @@ def test_many_blocks(client: LettaSDKClient):
client.agents.delete(agent1.id)
client.agents.delete(agent2.id)
def test_sources(client: LettaSDKClient, agent: AgentState):
# Clear existing sources
for source in client.sources.list():
client.sources.delete(source_id=source.id)
# Clear existing jobs
for job in client.jobs.list():
client.jobs.delete(job_id=job.id)
# Create a new source
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
assert len(client.sources.list()) == 1
# delete the source
client.sources.delete(source_id=source.id)
assert len(client.sources.list()) == 0
source = client.sources.create(name="test_source", embedding="openai/text-embedding-ada-002")
# Load files into the source
file_a_path = "tests/data/memgpt_paper.pdf"
file_b_path = "tests/data/test.txt"
# Upload the files
with open(file_a_path, "rb") as f:
job_a = client.sources.files.upload(source_id=source.id, file=f)
with open(file_b_path, "rb") as f:
job_b = client.sources.files.upload(source_id=source.id, file=f)
# Wait for the jobs to complete
while job_a.status != "completed" or job_b.status != "completed":
time.sleep(1)
job_a = client.jobs.retrieve(job_id=job_a.id)
job_b = client.jobs.retrieve(job_id=job_b.id)
print("Waiting for jobs to complete...", job_a.status, job_b.status)
# Get the first file with pagination
files_a = client.sources.files.list(source_id=source.id, limit=1)
assert len(files_a) == 1
assert files_a[0].source_id == source.id
# Use the cursor from files_a to get the remaining file
files_b = client.sources.files.list(source_id=source.id, limit=1, after=files_a[-1].id)
assert len(files_b) == 1
assert files_b[0].source_id == source.id
# Check files are different to ensure the cursor works
assert files_a[0].file_name != files_b[0].file_name
# Use the cursor from files_b to list files, should be empty
files = client.sources.files.list(source_id=source.id, limit=1, after=files_b[-1].id)
assert len(files) == 0 # Should be empty
# list passages
passages = client.sources.passages.list(source_id=source.id)
assert len(passages) > 0
# attach to an agent
assert len(client.agents.passages.list(agent_id=agent.id)) == 0
client.agents.sources.attach(source_id=source.id, agent_id=agent.id)
assert len(client.agents.passages.list(agent_id=agent.id)) > 0
assert len(client.agents.sources.list(agent_id=agent.id)) == 1
# detach from agent
client.agents.sources.detach(source_id=source.id, agent_id=agent.id)
assert len(client.agents.passages.list(agent_id=agent.id)) == 0

View File

@@ -24,15 +24,10 @@ from letta.server.db import db_registry
utils.DEBUG = True
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent, UpdateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.job import Job as PydanticJob
from letta.schemas.message import Message
from letta.schemas.source import Source as PydanticSource
from letta.server.server import SyncServer
from letta.system import unpack_message
from .utils import DummyDataConnector
WAR_AND_PEACE = """BOOK ONE: 1805
CHAPTER I
@@ -390,40 +385,6 @@ def test_user_message_memory(server, user, agent_id):
server.run_command(user_id=user.id, agent_id=agent_id, command="/memory")
@pytest.mark.order(3)
def test_load_data(server, user, agent_id):
# create source
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, after=None, limit=10000)
assert len(passages_before) == 0
source = server.source_manager.create_source(
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
)
# load data
archival_memories = [
"alpha",
"Cinderella wore a blue dress",
"Dog eat dog",
"ZZZ",
"Shishir loves indian food",
]
connector = DummyDataConnector(archival_memories)
server.load_data(user.id, connector, source.name)
# attach source
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user)
# check archival memory size
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, after=None, limit=10000)
assert len(passages_after) == 5
def test_save_archival_memory(server, user_id, agent_id):
# TODO: insert into archival memory
pass
@pytest.mark.order(4)
def test_user_message(server, user, agent_id):
# add data into recall memory
@@ -456,54 +417,54 @@ def test_get_recall_memory(server, org_id, user, agent_id):
assert message_id in message_ids, f"{message_id} not in {message_ids}"
@pytest.mark.order(6)
def test_get_archival_memory(server, user, agent_id):
# test archival memory cursor pagination
actor = user
# List latest 2 passages
passages_1 = server.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
ascending=False,
limit=2,
)
assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
ascending=False,
before=cursor1,
)
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.agent_manager.list_passages(
actor=actor,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
limit=1000,
)
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
latest = passages_1[0]
earliest = passages_2[-1]
# test archival memory
passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0
# @pytest.mark.order(6)
# def test_get_archival_memory(server, user, agent_id):
# # test archival memory cursor pagination
# actor = user
#
# # List latest 2 passages
# passages_1 = server.agent_manager.list_passages(
# actor=actor,
# agent_id=agent_id,
# ascending=False,
# limit=2,
# )
# assert len(passages_1) == 2, f"Returned {[p.text for p in passages_1]}, not equal to 2"
#
# # List next 3 passages (earliest 3)
# cursor1 = passages_1[-1].id
# passages_2 = server.agent_manager.list_passages(
# actor=actor,
# agent_id=agent_id,
# ascending=False,
# before=cursor1,
# )
#
# # List all 5
# cursor2 = passages_1[0].created_at
# passages_3 = server.agent_manager.list_passages(
# actor=actor,
# agent_id=agent_id,
# ascending=False,
# end_date=cursor2,
# limit=1000,
# )
# assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
# assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
#
# latest = passages_1[0]
# earliest = passages_2[-1]
#
# # test archival memory
# passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
# assert len(passage_1) == 1
# assert passage_1[0].text == "alpha"
# passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=earliest.id, limit=1000, ascending=True)
# assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
# assert all("alpha" not in passage.text for passage in passage_2)
# # test safe empty return
# passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, after=latest.id, limit=1000, ascending=True)
# assert len(passage_none) == 0
def test_get_context_window_overview(server: SyncServer, user, agent_id):
@@ -985,131 +946,6 @@ async def test_memory_rebuild_count(server, user, disable_e2b_api_key, base_tool
server.agent_manager.delete_agent(agent_state.id, actor=actor)
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
actor = server.user_manager.get_user_or_default(user_id)
existing_sources = server.source_manager.list_sources(actor=actor)
if len(existing_sources) > 0:
for source in existing_sources:
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0
# Create a source
source = server.source_manager.create_source(
PydanticSource(
name="timber_source",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
created_by_id=user_id,
),
actor=actor,
)
assert source.created_by_id == user_id
# Create a test file with some content
test_file = tmp_path / "test.txt"
test_content = "We have a dog called Timber. He likes to sleep and eat chicken."
test_file.write_text(test_content)
# Attach source to agent first
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
# Create a job for loading the first file
job = server.job_manager.create_job(
PydanticJob(
user_id=user_id,
metadata={"type": "embedding", "filename": test_file.name, "source_id": source.id},
),
actor=actor,
)
# Load the first file to source
server.load_file_to_source(
source_id=source.id,
file_path=str(test_file),
job_id=job.id,
actor=actor,
)
# Verify job completed successfully
job = server.job_manager.get_job_by_id(job_id=job.id, actor=actor)
assert job.status == "completed"
assert job.metadata["num_passages"] == 1
assert job.metadata["num_documents"] == 1
# Verify passages were added
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert first_file_passage_count > initial_passage_count
# Create a second test file with different content
test_file2 = tmp_path / "test2.txt"
test_file2.write_text(WAR_AND_PEACE)
# Create a job for loading the second file
job2 = server.job_manager.create_job(
PydanticJob(
user_id=user_id,
metadata={"type": "embedding", "filename": test_file2.name, "source_id": source.id},
),
actor=actor,
)
# Load the second file to source
server.load_file_to_source(
source_id=source.id,
file_path=str(test_file2),
job_id=job2.id,
actor=actor,
)
# Verify second job completed successfully
job2 = server.job_manager.get_job_by_id(job_id=job2.id, actor=actor)
assert job2.status == "completed"
assert job2.metadata["num_passages"] >= 10
assert job2.metadata["num_documents"] == 1
# Verify passages were appended (not replaced)
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert final_passage_count > first_file_passage_count
# Verify both old and new content is searchable
passages = server.agent_manager.list_passages(
agent_id=agent_id,
actor=actor,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages) == final_passage_count
assert any("chicken" in passage.text.lower() for passage in passages)
assert any("Anna".lower() in passage.text.lower() for passage in passages)
# Initially should have no passages
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
assert initial_agent2_passages == 0
# Attach source to second agent
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
# Verify second agent has same number of passages as first agent
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
assert agent2_passages == agent1_passages
# Verify second agent can query the same content
passages2 = server.agent_manager.list_passages(
actor=actor,
agent_id=other_agent_id,
source_id=source.id,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)
def test_add_nonexisting_tool(server: SyncServer, user_id: str, base_tools):
actor = server.user_manager.get_user_or_default(user_id)