From 318a7c769b2f43f9078d39c4937d1eb9e6a1bb09 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 6 Jun 2025 15:34:03 -0700 Subject: [PATCH] feat: Search files returns citations of the filenames that were searched (#2689) --- .gitignore | 2 + ...433aef_add_file_name_to_source_passages.py | 40 ++ letta/agent.py | 4 +- letta/agents/base_agent.py | 2 +- letta/agents/letta_agent.py | 2 +- letta/agents/voice_agent.py | 2 +- letta/orm/passage.py | 2 + letta/schemas/passage.py | 1 + letta/services/agent_manager.py | 6 +- .../context_window_calculator.py | 2 +- .../services/file_processor/file_processor.py | 4 +- .../services/helpers/agent_manager_helper.py | 39 +- letta/services/passage_manager.py | 621 +++++++++++++++++- .../tool_executor/files_tool_executor.py | 11 +- tests/test_managers.py | 437 +++++++++++- 15 files changed, 1132 insertions(+), 43 deletions(-) create mode 100644 alembic/versions/c96263433aef_add_file_name_to_source_passages.py diff --git a/.gitignore b/.gitignore index baaeabfb..920d0eb9 100644 --- a/.gitignore +++ b/.gitignore @@ -5,6 +5,8 @@ openapi_letta.json openapi_openai.json +CLAUDE.md + ### Eclipse ### .metadata bin/ diff --git a/alembic/versions/c96263433aef_add_file_name_to_source_passages.py b/alembic/versions/c96263433aef_add_file_name_to_source_passages.py new file mode 100644 index 00000000..9f24dcfb --- /dev/null +++ b/alembic/versions/c96263433aef_add_file_name_to_source_passages.py @@ -0,0 +1,40 @@ +"""Add file name to source passages + +Revision ID: c96263433aef +Revises: 9792f94e961d +Create Date: 2025-06-06 12:06:57.328127 +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "c96263433aef" +down_revision: Union[str, None] = "9792f94e961d" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # Add the new column + op.add_column("source_passages", sa.Column("file_name", sa.String(), nullable=True)) + + # Backfill file_name using SQL UPDATE JOIN + op.execute( + """ + UPDATE source_passages + SET file_name = files.file_name + FROM files + WHERE source_passages.file_id = files.id + """ + ) + + # Enforce non-null constraint after backfill + op.alter_column("source_passages", "file_name", nullable=False) + + +def downgrade() -> None: + op.drop_column("source_passages", "file_name") diff --git a/letta/agent.py b/letta/agent.py index f81e0f69..153f6f16 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -1292,7 +1292,7 @@ class Agent(BaseAgent): # conversion of messages to OpenAI dict format, which is passed to the token counter (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user), - self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id), self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), ) in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages] @@ -1414,7 +1414,7 @@ class Agent(BaseAgent): # conversion of messages to anthropic dict format, which is passed to the token counter (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user), - self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id), + self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id), self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id), ) in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages] diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index dbc5b2fa..95cb00df 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -104,7 +104,7 @@ class BaseAgent(ABC): if num_messages is None: num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) if num_archival_memories is None: - num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) + num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id) new_system_message_str = compile_system_message( system_prompt=agent_state.system, diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index b259373f..93a575c1 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -763,7 +763,7 @@ class LettaAgent(BaseAgent): else asyncio.sleep(0, result=self.num_messages) ), ( - self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) + self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id) if self.num_archival_memories is None else asyncio.sleep(0, result=self.num_archival_memories) ), diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index b2479eed..4f904350 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -305,7 +305,7 @@ class VoiceAgent(BaseAgent): else asyncio.sleep(0, result=self.num_messages) ), ( - self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) + self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id) if self.num_archival_memories is None else asyncio.sleep(0, result=self.num_archival_memories) ), diff --git a/letta/orm/passage.py b/letta/orm/passage.py index eb38b691..d3431e63 100644 --- a/letta/orm/passage.py +++ b/letta/orm/passage.py @@ -47,6 +47,8 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin): __tablename__ = "source_passages" + file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from") + @declared_attr def file(cls) -> Mapped["FileMetadata"]: """Relationship to file""" diff --git a/letta/schemas/passage.py b/letta/schemas/passage.py index becdd3c3..da87dd0f 100644 --- a/letta/schemas/passage.py +++ b/letta/schemas/passage.py @@ -23,6 +23,7 @@ class PassageBase(OrmMetadataBase): # file association file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.") + file_name: Optional[str] = Field(None, description="The name of the file (only for source passages).") metadata: Optional[Dict] = Field({}, validation_alias="metadata_", description="The metadata of the passage.") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 8f9ec195..a126ff25 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1483,7 +1483,7 @@ class AgentManager: memory_edit_timestamp = curr_system_message.created_at num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id) - num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id) + num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id) # update memory (TODO: potentially update recall/archival stats separately) new_system_message_str = compile_system_message( @@ -2075,6 +2075,7 @@ class AgentManager: # This is an AgentPassage - remove source fields data.pop("source_id", None) data.pop("file_id", None) + data.pop("file_name", None) passage = AgentPassage(**data) else: # This is a SourcePassage - remove agent field @@ -2135,6 +2136,7 @@ class AgentManager: # This is an AgentPassage - remove source fields data.pop("source_id", None) data.pop("file_id", None) + data.pop("file_name", None) passage = AgentPassage(**data) else: # This is a SourcePassage - remove agent field @@ -2198,14 +2200,12 @@ class AgentManager: self, actor: PydanticUser, agent_id: Optional[str] = None, - file_id: Optional[str] = None, limit: Optional[int] = 50, query_text: Optional[str] = None, start_date: Optional[datetime] = None, end_date: Optional[datetime] = None, before: Optional[str] = None, after: Optional[str] = None, - source_id: Optional[str] = None, embed_query: bool = False, ascending: bool = True, embedding_config: Optional[EmbeddingConfig] = None, diff --git a/letta/services/context_window_calculator/context_window_calculator.py b/letta/services/context_window_calculator/context_window_calculator.py index 4f5c7a82..47a9aacd 100644 --- a/letta/services/context_window_calculator/context_window_calculator.py +++ b/letta/services/context_window_calculator/context_window_calculator.py @@ -63,7 +63,7 @@ class ContextWindowCalculator: # Fetch data concurrently (in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather( message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor), - passage_manager.size_async(actor=actor, agent_id=agent_state.id), + passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id), message_manager.size_async(actor=actor, agent_id=agent_state.id), ) diff --git a/letta/services/file_processor/file_processor.py b/letta/services/file_processor/file_processor.py index 533723f1..c82786cc 100644 --- a/letta/services/file_processor/file_processor.py +++ b/letta/services/file_processor/file_processor.py @@ -111,7 +111,9 @@ class FileProcessor: ) all_passages.extend(passages) - all_passages = await self.passage_manager.create_many_passages_async(all_passages, self.actor) + all_passages = await self.passage_manager.create_many_source_passages_async( + passages=all_passages, file_metadata=file_metadata, actor=self.actor + ) logger.info(f"Successfully processed {filename}: {len(all_passages)} passages") diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 7d83b1ab..b4935ef8 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -607,15 +607,45 @@ def build_passage_query( if not agent_only: # Include source passages if agent_id is not None: source_passages = ( - select(SourcePassage, literal(None).label("agent_id")) + select( + SourcePassage.file_name, + SourcePassage.id, + SourcePassage.text, + SourcePassage.embedding_config, + SourcePassage.metadata_, + SourcePassage.embedding, + SourcePassage.created_at, + SourcePassage.updated_at, + SourcePassage.is_deleted, + SourcePassage._created_by_id, + SourcePassage._last_updated_by_id, + SourcePassage.organization_id, + SourcePassage.file_id, + SourcePassage.source_id, + literal(None).label("agent_id"), + ) .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id) .where(SourcesAgents.agent_id == agent_id) .where(SourcePassage.organization_id == actor.organization_id) ) else: - source_passages = select(SourcePassage, literal(None).label("agent_id")).where( - SourcePassage.organization_id == actor.organization_id - ) + source_passages = select( + SourcePassage.file_name, + SourcePassage.id, + SourcePassage.text, + SourcePassage.embedding_config, + SourcePassage.metadata_, + SourcePassage.embedding, + SourcePassage.created_at, + SourcePassage.updated_at, + SourcePassage.is_deleted, + SourcePassage._created_by_id, + SourcePassage._last_updated_by_id, + SourcePassage.organization_id, + SourcePassage.file_id, + SourcePassage.source_id, + literal(None).label("agent_id"), + ).where(SourcePassage.organization_id == actor.organization_id) if source_id: source_passages = source_passages.where(SourcePassage.source_id == source_id) @@ -627,6 +657,7 @@ def build_passage_query( if agent_id is not None: agent_passages = ( select( + literal(None).label("file_name"), AgentPassage.id, AgentPassage.text, AgentPassage.embedding_config, diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index b60e2e1b..25b87a11 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -13,6 +13,7 @@ from letta.orm.errors import NoResultFound from letta.orm.passage import AgentPassage, SourcePassage from letta.otel.tracing import trace_method from letta.schemas.agent import AgentState +from letta.schemas.file import FileMetadata as PydanticFileMetadata from letta.schemas.passage import Passage as PydanticPassage from letta.schemas.user import User as PydanticUser from letta.server.db import db_registry @@ -42,10 +43,65 @@ async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> Li class PassageManager: """Manager class to handle business logic related to Passages.""" + # AGENT PASSAGE METHODS + @enforce_types + @trace_method + def get_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: + """Fetch an agent passage by ID.""" + with db_registry.session() as session: + try: + passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() + except NoResultFound: + raise NoResultFound(f"Agent passage with id {passage_id} not found in database.") + + @enforce_types + @trace_method + async def get_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: + """Fetch an agent passage by ID.""" + async with db_registry.async_session() as session: + try: + passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() + except NoResultFound: + raise NoResultFound(f"Agent passage with id {passage_id} not found in database.") + + # SOURCE PASSAGE METHODS + @enforce_types + @trace_method + def get_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: + """Fetch a source passage by ID.""" + with db_registry.session() as session: + try: + passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() + except NoResultFound: + raise NoResultFound(f"Source passage with id {passage_id} not found in database.") + + @enforce_types + @trace_method + async def get_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: + """Fetch a source passage by ID.""" + async with db_registry.async_session() as session: + try: + passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor) + return passage.to_pydantic() + except NoResultFound: + raise NoResultFound(f"Source passage with id {passage_id} not found in database.") + + # DEPRECATED - Use specific methods above @enforce_types @trace_method def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: - """Fetch a passage by ID.""" + """DEPRECATED: Use get_agent_passage_by_id() or get_source_passage_by_id() instead.""" + import warnings + + warnings.warn( + "get_passage_by_id is deprecated. Use get_agent_passage_by_id() or get_source_passage_by_id() instead.", + DeprecationWarning, + stacklevel=2, + ) + with db_registry.session() as session: # Try source passages first try: @@ -62,7 +118,15 @@ class PassageManager: @enforce_types @trace_method async def get_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]: - """Fetch a passage by ID.""" + """DEPRECATED: Use get_agent_passage_by_id_async() or get_source_passage_by_id_async() instead.""" + import warnings + + warnings.warn( + "get_passage_by_id_async is deprecated. Use get_agent_passage_by_id_async() or get_source_passage_by_id_async() instead.", + DeprecationWarning, + stacklevel=2, + ) + async with db_registry.async_session() as session: # Try source passages first try: @@ -76,10 +140,137 @@ class PassageManager: except NoResultFound: raise NoResultFound(f"Passage with id {passage_id} not found in database.") + @enforce_types + @trace_method + def create_agent_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: + """Create a new agent passage.""" + if not pydantic_passage.agent_id: + raise ValueError("Agent passage must have agent_id") + if pydantic_passage.source_id: + raise ValueError("Agent passage cannot have source_id") + + data = pydantic_passage.model_dump(to_orm=True) + common_fields = { + "id": data.get("id"), + "text": data["text"], + "embedding": data["embedding"], + "embedding_config": data["embedding_config"], + "organization_id": data["organization_id"], + "metadata_": data.get("metadata", {}), + "is_deleted": data.get("is_deleted", False), + "created_at": data.get("created_at", datetime.now(timezone.utc)), + } + agent_fields = {"agent_id": data["agent_id"]} + passage = AgentPassage(**common_fields, **agent_fields) + + with db_registry.session() as session: + passage.create(session, actor=actor) + return passage.to_pydantic() + + @enforce_types + @trace_method + async def create_agent_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: + """Create a new agent passage.""" + if not pydantic_passage.agent_id: + raise ValueError("Agent passage must have agent_id") + if pydantic_passage.source_id: + raise ValueError("Agent passage cannot have source_id") + + data = pydantic_passage.model_dump(to_orm=True) + common_fields = { + "id": data.get("id"), + "text": data["text"], + "embedding": data["embedding"], + "embedding_config": data["embedding_config"], + "organization_id": data["organization_id"], + "metadata_": data.get("metadata", {}), + "is_deleted": data.get("is_deleted", False), + "created_at": data.get("created_at", datetime.now(timezone.utc)), + } + agent_fields = {"agent_id": data["agent_id"]} + passage = AgentPassage(**common_fields, **agent_fields) + + async with db_registry.async_session() as session: + passage = await passage.create_async(session, actor=actor) + return passage.to_pydantic() + + @enforce_types + @trace_method + def create_source_passage( + self, pydantic_passage: PydanticPassage, file_metadata: PydanticFileMetadata, actor: PydanticUser + ) -> PydanticPassage: + """Create a new source passage.""" + if not pydantic_passage.source_id: + raise ValueError("Source passage must have source_id") + if pydantic_passage.agent_id: + raise ValueError("Source passage cannot have agent_id") + + data = pydantic_passage.model_dump(to_orm=True) + common_fields = { + "id": data.get("id"), + "text": data["text"], + "embedding": data["embedding"], + "embedding_config": data["embedding_config"], + "organization_id": data["organization_id"], + "metadata_": data.get("metadata", {}), + "is_deleted": data.get("is_deleted", False), + "created_at": data.get("created_at", datetime.now(timezone.utc)), + } + source_fields = { + "source_id": data["source_id"], + "file_id": data.get("file_id"), + "file_name": file_metadata.file_name, + } + passage = SourcePassage(**common_fields, **source_fields) + + with db_registry.session() as session: + passage.create(session, actor=actor) + return passage.to_pydantic() + + @enforce_types + @trace_method + async def create_source_passage_async( + self, pydantic_passage: PydanticPassage, file_metadata: PydanticFileMetadata, actor: PydanticUser + ) -> PydanticPassage: + """Create a new source passage.""" + if not pydantic_passage.source_id: + raise ValueError("Source passage must have source_id") + if pydantic_passage.agent_id: + raise ValueError("Source passage cannot have agent_id") + + data = pydantic_passage.model_dump(to_orm=True) + common_fields = { + "id": data.get("id"), + "text": data["text"], + "embedding": data["embedding"], + "embedding_config": data["embedding_config"], + "organization_id": data["organization_id"], + "metadata_": data.get("metadata", {}), + "is_deleted": data.get("is_deleted", False), + "created_at": data.get("created_at", datetime.now(timezone.utc)), + } + source_fields = { + "source_id": data["source_id"], + "file_id": data.get("file_id"), + "file_name": file_metadata.file_name, + } + passage = SourcePassage(**common_fields, **source_fields) + + async with db_registry.async_session() as session: + passage = await passage.create_async(session, actor=actor) + return passage.to_pydantic() + + # DEPRECATED - Use specific methods above @enforce_types @trace_method def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: - """Create a new passage in the appropriate table based on whether it has agent_id or source_id.""" + """DEPRECATED: Use create_agent_passage() or create_source_passage() instead.""" + import warnings + + warnings.warn( + "create_passage is deprecated. Use create_agent_passage() or create_source_passage() instead.", DeprecationWarning, stacklevel=2 + ) + passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage) with db_registry.session() as session: @@ -89,7 +280,15 @@ class PassageManager: @enforce_types @trace_method async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage: - """Create a new passage in the appropriate table based on whether it has agent_id or source_id.""" + """DEPRECATED: Use create_agent_passage_async() or create_source_passage_async() instead.""" + import warnings + + warnings.warn( + "create_passage_async is deprecated. Use create_agent_passage_async() or create_source_passage_async() instead.", + DeprecationWarning, + stacklevel=2, + ) + # Common fields for both passage types passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage) async with db_registry.async_session() as session: @@ -128,16 +327,110 @@ class PassageManager: return passage + @enforce_types + @trace_method + def create_many_agent_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: + """Create multiple agent passages.""" + return [self.create_agent_passage(p, actor) for p in passages] + + @enforce_types + @trace_method + async def create_many_agent_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: + """Create multiple agent passages.""" + agent_passages = [] + for p in passages: + if not p.agent_id: + raise ValueError("Agent passage must have agent_id") + if p.source_id: + raise ValueError("Agent passage cannot have source_id") + + data = p.model_dump(to_orm=True) + common_fields = { + "id": data.get("id"), + "text": data["text"], + "embedding": data["embedding"], + "embedding_config": data["embedding_config"], + "organization_id": data["organization_id"], + "metadata_": data.get("metadata", {}), + "is_deleted": data.get("is_deleted", False), + "created_at": data.get("created_at", datetime.now(timezone.utc)), + } + agent_fields = {"agent_id": data["agent_id"]} + agent_passages.append(AgentPassage(**common_fields, **agent_fields)) + + async with db_registry.async_session() as session: + agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor) + return [p.to_pydantic() for p in agent_created] + + @enforce_types + @trace_method + def create_many_source_passages( + self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser + ) -> List[PydanticPassage]: + """Create multiple source passages.""" + return [self.create_source_passage(p, file_metadata, actor) for p in passages] + + @enforce_types + @trace_method + async def create_many_source_passages_async( + self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser + ) -> List[PydanticPassage]: + """Create multiple source passages.""" + source_passages = [] + for p in passages: + if not p.source_id: + raise ValueError("Source passage must have source_id") + if p.agent_id: + raise ValueError("Source passage cannot have agent_id") + + data = p.model_dump(to_orm=True) + common_fields = { + "id": data.get("id"), + "text": data["text"], + "embedding": data["embedding"], + "embedding_config": data["embedding_config"], + "organization_id": data["organization_id"], + "metadata_": data.get("metadata", {}), + "is_deleted": data.get("is_deleted", False), + "created_at": data.get("created_at", datetime.now(timezone.utc)), + } + source_fields = { + "source_id": data["source_id"], + "file_id": data.get("file_id"), + "file_name": file_metadata.file_name, + } + source_passages.append(SourcePassage(**common_fields, **source_fields)) + + async with db_registry.async_session() as session: + source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor) + return [p.to_pydantic() for p in source_created] + + # DEPRECATED - Use specific methods above @enforce_types @trace_method def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: - """Create multiple passages.""" + """DEPRECATED: Use create_many_agent_passages() or create_many_source_passages() instead.""" + import warnings + + warnings.warn( + "create_many_passages is deprecated. Use create_many_agent_passages() or create_many_source_passages() instead.", + DeprecationWarning, + stacklevel=2, + ) return [self.create_passage(p, actor) for p in passages] @enforce_types @trace_method async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]: - """Create multiple passages.""" + """DEPRECATED: Use create_many_agent_passages_async() or create_many_source_passages_async() instead.""" + import warnings + + warnings.warn( + "create_many_passages_async is deprecated. Use create_many_agent_passages_async() or create_many_source_passages_async() instead.", + DeprecationWarning, + stacklevel=2, + ) + async with db_registry.async_session() as session: agent_passages = [] source_passages = [] @@ -203,7 +496,7 @@ class PassageManager: raise TypeError( f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}" ) - passage = self.create_passage( + passage = self.create_agent_passage( PydanticPassage( organization_id=actor.organization_id, agent_id=agent_id, @@ -251,7 +544,7 @@ class PassageManager: for chunk_text, embedding in zip(text_chunks, embeddings) ] - passages = await self.create_many_passages_async(passages=passages, actor=actor) + passages = await self.create_many_agent_passages_async(passages=passages, actor=actor) return passages @@ -292,10 +585,191 @@ class PassageManager: return processed_embeddings + @enforce_types + @trace_method + def update_agent_passage_by_id( + self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs + ) -> Optional[PydanticPassage]: + """Update an agent passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + with db_registry.session() as session: + try: + curr_passage = AgentPassage.read( + db_session=session, + identifier=passage_id, + actor=actor, + ) + except NoResultFound: + raise ValueError(f"Agent passage with id {passage_id} does not exist.") + + # Update the database record with values from the provided record + update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(curr_passage, key, value) + + # Commit changes + curr_passage.update(session, actor=actor) + return curr_passage.to_pydantic() + + @enforce_types + @trace_method + async def update_agent_passage_by_id_async( + self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs + ) -> Optional[PydanticPassage]: + """Update an agent passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + async with db_registry.async_session() as session: + try: + curr_passage = await AgentPassage.read_async( + db_session=session, + identifier=passage_id, + actor=actor, + ) + except NoResultFound: + raise ValueError(f"Agent passage with id {passage_id} does not exist.") + + # Update the database record with values from the provided record + update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(curr_passage, key, value) + + # Commit changes + await curr_passage.update_async(session, actor=actor) + return curr_passage.to_pydantic() + + @enforce_types + @trace_method + def update_source_passage_by_id( + self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs + ) -> Optional[PydanticPassage]: + """Update a source passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + with db_registry.session() as session: + try: + curr_passage = SourcePassage.read( + db_session=session, + identifier=passage_id, + actor=actor, + ) + except NoResultFound: + raise ValueError(f"Source passage with id {passage_id} does not exist.") + + # Update the database record with values from the provided record + update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(curr_passage, key, value) + + # Commit changes + curr_passage.update(session, actor=actor) + return curr_passage.to_pydantic() + + @enforce_types + @trace_method + async def update_source_passage_by_id_async( + self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs + ) -> Optional[PydanticPassage]: + """Update a source passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + async with db_registry.async_session() as session: + try: + curr_passage = await SourcePassage.read_async( + db_session=session, + identifier=passage_id, + actor=actor, + ) + except NoResultFound: + raise ValueError(f"Source passage with id {passage_id} does not exist.") + + # Update the database record with values from the provided record + update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(curr_passage, key, value) + + # Commit changes + await curr_passage.update_async(session, actor=actor) + return curr_passage.to_pydantic() + + @enforce_types + @trace_method + def delete_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: + """Delete an agent passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + with db_registry.session() as session: + try: + passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor) + passage.hard_delete(session, actor=actor) + return True + except NoResultFound: + raise NoResultFound(f"Agent passage with id {passage_id} not found.") + + @enforce_types + @trace_method + async def delete_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool: + """Delete an agent passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + async with db_registry.async_session() as session: + try: + passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor) + await passage.hard_delete_async(session, actor=actor) + return True + except NoResultFound: + raise NoResultFound(f"Agent passage with id {passage_id} not found.") + + @enforce_types + @trace_method + def delete_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: + """Delete a source passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + with db_registry.session() as session: + try: + passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor) + passage.hard_delete(session, actor=actor) + return True + except NoResultFound: + raise NoResultFound(f"Source passage with id {passage_id} not found.") + + @enforce_types + @trace_method + async def delete_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool: + """Delete a source passage.""" + if not passage_id: + raise ValueError("Passage ID must be provided.") + + async with db_registry.async_session() as session: + try: + passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor) + await passage.hard_delete_async(session, actor=actor) + return True + except NoResultFound: + raise NoResultFound(f"Source passage with id {passage_id} not found.") + + # DEPRECATED - Use specific methods above @enforce_types @trace_method def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]: - """Update a passage.""" + """DEPRECATED: Use update_agent_passage_by_id() or update_source_passage_by_id() instead.""" + import warnings + + warnings.warn( + "update_passage_by_id is deprecated. Use update_agent_passage_by_id() or update_source_passage_by_id() instead.", + DeprecationWarning, + stacklevel=2, + ) + if not passage_id: raise ValueError("Passage ID must be provided.") @@ -330,7 +804,15 @@ class PassageManager: @enforce_types @trace_method def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool: - """Delete a passage from either source or archival passages.""" + """DEPRECATED: Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead.""" + import warnings + + warnings.warn( + "delete_passage_by_id is deprecated. Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead.", + DeprecationWarning, + stacklevel=2, + ) + if not passage_id: raise ValueError("Passage ID must be provided.") @@ -352,7 +834,15 @@ class PassageManager: @enforce_types @trace_method async def delete_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool: - """Delete a passage from either source or archival passages.""" + """DEPRECATED: Use delete_agent_passage_by_id_async() or delete_source_passage_by_id_async() instead.""" + import warnings + + warnings.warn( + "delete_passage_by_id_async is deprecated. Use delete_agent_passage_by_id_async() or delete_source_passage_by_id_async() instead.", + DeprecationWarning, + stacklevel=2, + ) + if not passage_id: raise ValueError("Passage ID must be provided.") @@ -373,15 +863,42 @@ class PassageManager: @enforce_types @trace_method - def delete_passages( + def delete_agent_passages( self, actor: PydanticUser, passages: List[PydanticPassage], ) -> bool: + """Delete multiple agent passages.""" # TODO: This is very inefficient # TODO: We should have a base `delete_all_matching_filters`-esque function for passage in passages: - self.delete_passage_by_id(passage_id=passage.id, actor=actor) + self.delete_agent_passage_by_id(passage_id=passage.id, actor=actor) + return True + + @enforce_types + @trace_method + async def delete_agent_passages_async( + self, + actor: PydanticUser, + passages: List[PydanticPassage], + ) -> bool: + """Delete multiple agent passages.""" + async with db_registry.async_session() as session: + await AgentPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor) + return True + + @enforce_types + @trace_method + def delete_source_passages( + self, + actor: PydanticUser, + passages: List[PydanticPassage], + ) -> bool: + """Delete multiple source passages.""" + # TODO: This is very inefficient + # TODO: We should have a base `delete_all_matching_filters`-esque function + for passage in passages: + self.delete_source_passage_by_id(passage_id=passage.id, actor=actor) return True @enforce_types @@ -395,14 +912,36 @@ class PassageManager: await SourcePassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor) return True + # DEPRECATED - Use specific methods above @enforce_types @trace_method - def size( + def delete_passages( + self, + actor: PydanticUser, + passages: List[PydanticPassage], + ) -> bool: + """DEPRECATED: Use delete_agent_passages() or delete_source_passages() instead.""" + import warnings + + warnings.warn( + "delete_passages is deprecated. Use delete_agent_passages() or delete_source_passages() instead.", + DeprecationWarning, + stacklevel=2, + ) + # TODO: This is very inefficient + # TODO: We should have a base `delete_all_matching_filters`-esque function + for passage in passages: + self.delete_passage_by_id(passage_id=passage.id, actor=actor) + return True + + @enforce_types + @trace_method + def agent_passage_size( self, actor: PydanticUser, agent_id: Optional[str] = None, ) -> int: - """Get the total count of messages with optional filters. + """Get the total count of agent passages with optional filters. Args: actor: The user requesting the count @@ -411,14 +950,29 @@ class PassageManager: with db_registry.session() as session: return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id) + # DEPRECATED - Use agent_passage_size() instead since this only counted agent passages anyway @enforce_types @trace_method - async def size_async( + def size( self, actor: PydanticUser, agent_id: Optional[str] = None, ) -> int: - """Get the total count of messages with optional filters. + """DEPRECATED: Use agent_passage_size() instead (this only counted agent passages anyway).""" + import warnings + + warnings.warn("size is deprecated. Use agent_passage_size() instead.", DeprecationWarning, stacklevel=2) + with db_registry.session() as session: + return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id) + + @enforce_types + @trace_method + async def agent_passage_size_async( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + ) -> int: + """Get the total count of agent passages with optional filters. Args: actor: The user requesting the count agent_id: The agent ID of the messages @@ -426,6 +980,37 @@ class PassageManager: async with db_registry.async_session() as session: return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id) + @enforce_types + @trace_method + def source_passage_size( + self, + actor: PydanticUser, + source_id: Optional[str] = None, + ) -> int: + """Get the total count of source passages with optional filters. + + Args: + actor: The user requesting the count + source_id: The source ID of the passages + """ + with db_registry.session() as session: + return SourcePassage.size(db_session=session, actor=actor, source_id=source_id) + + @enforce_types + @trace_method + async def source_passage_size_async( + self, + actor: PydanticUser, + source_id: Optional[str] = None, + ) -> int: + """Get the total count of source passages with optional filters. + Args: + actor: The user requesting the count + source_id: The source ID of the passages + """ + async with db_registry.async_session() as session: + return await SourcePassage.size_async(db_session=session, actor=actor, source_id=source_id) + @enforce_types @trace_method async def estimate_embeddings_size_async( @@ -448,7 +1033,7 @@ class PassageManager: raise ValueError(f"Invalid storage unit: {storage_unit}. Must be one of {list(BYTES_PER_STORAGE_UNIT.keys())}.") BYTES_PER_EMBEDDING_DIM = 4 GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM - return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING + return await self.agent_passage_size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING @enforce_types @trace_method diff --git a/letta/services/tool_executor/files_tool_executor.py b/letta/services/tool_executor/files_tool_executor.py index 15d8f00a..9017a21e 100644 --- a/letta/services/tool_executor/files_tool_executor.py +++ b/letta/services/tool_executor/files_tool_executor.py @@ -126,6 +126,13 @@ class LettaFileToolExecutor(ToolExecutor): # TODO: Make this paginated? async def search_files(self, agent_state: AgentState, query: str) -> List[str]: - """Stub for search_files tool.""" + """Search for text within attached files and return passages with their source filenames.""" passages = await self.agent_manager.list_source_passages_async(actor=self.actor, agent_id=agent_state.id, query_text=query) - return [p.text for p in passages] + formatted_results = [] + for p in passages: + if p.file_name: + formatted_result = f"[{p.file_name}]:\n{p.text}" + else: + formatted_result = p.text + formatted_results.append(formatted_result) + return formatted_results diff --git a/tests/test_managers.py b/tests/test_managers.py index f467bdad..12acdce5 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -12,6 +12,7 @@ import httpx # tests/test_file_content_flow.py import pytest +from _pytest.python_api import approx from anthropic.types.beta import BetaMessage from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall @@ -280,7 +281,7 @@ async def default_run(server: SyncServer, default_user): @pytest.fixture def agent_passage_fixture(server: SyncServer, default_user, sarah_agent): """Fixture to create an agent passage.""" - passage = server.passage_manager.create_passage( + passage = server.passage_manager.create_agent_passage( PydanticPassage( text="Hello, I am an agent passage", agent_id=sarah_agent.id, @@ -297,7 +298,7 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent): @pytest.fixture def source_passage_fixture(server: SyncServer, default_user, default_file, default_source): """Fixture to create a source passage.""" - passage = server.passage_manager.create_passage( + passage = server.passage_manager.create_source_passage( PydanticPassage( text="Hello, I am a source passage", source_id=default_source.id, @@ -307,6 +308,7 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau embedding_config=DEFAULT_EMBEDDING_CONFIG, metadata={"type": "test"}, ), + file_metadata=default_file, actor=default_user, ) yield passage @@ -318,7 +320,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a # Create agent passages passages = [] for i in range(5): - passage = server.passage_manager.create_passage( + passage = server.passage_manager.create_agent_passage( PydanticPassage( text=f"Agent passage {i}", agent_id=sarah_agent.id, @@ -335,7 +337,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a # Create source passages for i in range(5): - passage = server.passage_manager.create_passage( + passage = server.passage_manager.create_source_passage( PydanticPassage( text=f"Source passage {i}", source_id=default_source.id, @@ -345,6 +347,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a embedding_config=DEFAULT_EMBEDDING_CONFIG, metadata={"type": "test"}, ), + file_metadata=default_file, actor=default_user, ) passages.append(passage) @@ -525,7 +528,7 @@ def server(): @pytest.fixture @pytest.mark.asyncio -async def agent_passages_setup(server, default_source, default_user, sarah_agent, event_loop): +async def agent_passages_setup(server, default_source, default_file, default_user, sarah_agent, event_loop): """Setup fixture for agent passages tests""" agent_id = sarah_agent.id actor = default_user @@ -535,14 +538,16 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent # Create some source passages source_passages = [] for i in range(3): - passage = await server.passage_manager.create_passage_async( + passage = await server.passage_manager.create_source_passage_async( PydanticPassage( organization_id=actor.organization_id, source_id=default_source.id, + file_id=default_file.id, text=f"Source passage {i}", embedding=[0.1], # Default OpenAI embedding size embedding_config=DEFAULT_EMBEDDING_CONFIG, ), + file_metadata=default_file, actor=actor, ) source_passages.append(passage) @@ -550,7 +555,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent # Create some agent passages agent_passages = [] for i in range(2): - passage = await server.passage_manager.create_passage_async( + passage = await server.passage_manager.create_agent_passage_async( PydanticPassage( organization_id=actor.organization_id, agent_id=agent_id, @@ -2022,7 +2027,7 @@ async def test_agent_list_passages_filtering(server, default_user, sarah_agent, @pytest.mark.asyncio -async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, event_loop): +async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, event_loop): """Test vector search functionality of agent passages""" embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG) @@ -2041,6 +2046,7 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age for i, text in enumerate(test_passages): embedding = embed_model.get_text_embedding(text) if i % 2 == 0: + # Create agent passage passage = PydanticPassage( text=text, organization_id=default_user.organization_id, @@ -2048,15 +2054,18 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age embedding_config=DEFAULT_EMBEDDING_CONFIG, embedding=embedding, ) + created_passage = await server.passage_manager.create_agent_passage_async(passage, default_user) else: + # Create source passage passage = PydanticPassage( text=text, organization_id=default_user.organization_id, source_id=default_source.id, + file_id=default_file.id, embedding_config=DEFAULT_EMBEDDING_CONFIG, embedding=embedding, ) - created_passage = await server.passage_manager.create_passage_async(passage, default_user) + created_passage = await server.passage_manager.create_source_passage_async(passage, default_file, default_user) passages.append(created_passage) # Query vector similar to "red" embedding @@ -2261,6 +2270,416 @@ async def test_passage_cascade_deletion( server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user) +def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent): + """Test creating an agent passage using the new agent-specific method.""" + passage = server.passage_manager.create_agent_passage( + PydanticPassage( + text="Test agent passage via specific method", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + metadata={"type": "test_specific"}, + ), + actor=default_user, + ) + + assert passage.id is not None + assert passage.text == "Test agent passage via specific method" + assert passage.agent_id == sarah_agent.id + assert passage.source_id is None + + +def test_create_source_passage_specific(server: SyncServer, default_user, default_file, default_source): + """Test creating a source passage using the new source-specific method.""" + passage = server.passage_manager.create_source_passage( + PydanticPassage( + text="Test source passage via specific method", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + metadata={"type": "test_specific"}, + ), + file_metadata=default_file, + actor=default_user, + ) + + assert passage.id is not None + assert passage.text == "Test source passage via specific method" + assert passage.source_id == default_source.id + assert passage.agent_id is None + + +def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent): + """Test that agent passage creation validates inputs correctly.""" + # Should fail if agent_id is missing + with pytest.raises(ValueError, match="Agent passage must have agent_id"): + server.passage_manager.create_agent_passage( + PydanticPassage( + text="Invalid agent passage", + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + # Should fail if source_id is present + with pytest.raises(ValueError, match="Agent passage cannot have source_id"): + server.passage_manager.create_agent_passage( + PydanticPassage( + text="Invalid agent passage", + agent_id=sarah_agent.id, + source_id=default_source.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + +def test_create_source_passage_validation(server: SyncServer, default_user, default_file, default_source, sarah_agent): + """Test that source passage creation validates inputs correctly.""" + # Should fail if source_id is missing + with pytest.raises(ValueError, match="Source passage must have source_id"): + server.passage_manager.create_source_passage( + PydanticPassage( + text="Invalid source passage", + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + file_metadata=default_file, + actor=default_user, + ) + + # Should fail if agent_id is present + with pytest.raises(ValueError, match="Source passage cannot have agent_id"): + server.passage_manager.create_source_passage( + PydanticPassage( + text="Invalid source passage", + source_id=default_source.id, + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + file_metadata=default_file, + actor=default_user, + ) + + +def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent): + """Test retrieving an agent passage using the new agent-specific method.""" + # Create an agent passage + passage = server.passage_manager.create_agent_passage( + PydanticPassage( + text="Agent passage for retrieval test", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + # Retrieve it using the specific method + retrieved = server.passage_manager.get_agent_passage_by_id(passage.id, actor=default_user) + assert retrieved is not None + assert retrieved.id == passage.id + assert retrieved.text == passage.text + assert retrieved.agent_id == sarah_agent.id + + +def test_get_source_passage_by_id_specific(server: SyncServer, default_user, default_file, default_source): + """Test retrieving a source passage using the new source-specific method.""" + # Create a source passage + passage = server.passage_manager.create_source_passage( + PydanticPassage( + text="Source passage for retrieval test", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + file_metadata=default_file, + actor=default_user, + ) + + # Retrieve it using the specific method + retrieved = server.passage_manager.get_source_passage_by_id(passage.id, actor=default_user) + assert retrieved is not None + assert retrieved.id == passage.id + assert retrieved.text == passage.text + assert retrieved.source_id == default_source.id + + +def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sarah_agent, default_file, default_source): + """Test that trying to get the wrong passage type with specific methods fails.""" + # Create an agent passage + agent_passage = server.passage_manager.create_agent_passage( + PydanticPassage( + text="Agent passage", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + # Create a source passage + source_passage = server.passage_manager.create_source_passage( + PydanticPassage( + text="Source passage", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + file_metadata=default_file, + actor=default_user, + ) + + # Trying to get agent passage with source method should fail + with pytest.raises(NoResultFound): + server.passage_manager.get_source_passage_by_id(agent_passage.id, actor=default_user) + + # Trying to get source passage with agent method should fail + with pytest.raises(NoResultFound): + server.passage_manager.get_agent_passage_by_id(source_passage.id, actor=default_user) + + +def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent): + """Test updating an agent passage using the new agent-specific method.""" + # Create an agent passage + passage = server.passage_manager.create_agent_passage( + PydanticPassage( + text="Original agent passage text", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + # Update it + updated_passage = server.passage_manager.update_agent_passage_by_id( + passage.id, + PydanticPassage( + text="Updated agent passage text", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.2], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + assert updated_passage.text == "Updated agent passage text" + assert updated_passage.embedding[0] == approx(0.2) + assert updated_passage.id == passage.id + + +def test_update_source_passage_specific(server: SyncServer, default_user, default_file, default_source): + """Test updating a source passage using the new source-specific method.""" + # Create a source passage + passage = server.passage_manager.create_source_passage( + PydanticPassage( + text="Original source passage text", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + file_metadata=default_file, + actor=default_user, + ) + + # Update it + updated_passage = server.passage_manager.update_source_passage_by_id( + passage.id, + PydanticPassage( + text="Updated source passage text", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.2], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + assert updated_passage.text == "Updated source passage text" + assert updated_passage.embedding[0] == approx(0.2) + assert updated_passage.id == passage.id + + +def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent): + """Test deleting an agent passage using the new agent-specific method.""" + # Create an agent passage + passage = server.passage_manager.create_agent_passage( + PydanticPassage( + text="Agent passage to delete", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + # Verify it exists + retrieved = server.passage_manager.get_agent_passage_by_id(passage.id, actor=default_user) + assert retrieved is not None + + # Delete it + result = server.passage_manager.delete_agent_passage_by_id(passage.id, actor=default_user) + assert result is True + + # Verify it's gone + with pytest.raises(NoResultFound): + server.passage_manager.get_agent_passage_by_id(passage.id, actor=default_user) + + +def test_delete_source_passage_specific(server: SyncServer, default_user, default_file, default_source): + """Test deleting a source passage using the new source-specific method.""" + # Create a source passage + passage = server.passage_manager.create_source_passage( + PydanticPassage( + text="Source passage to delete", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + file_metadata=default_file, + actor=default_user, + ) + + # Verify it exists + retrieved = server.passage_manager.get_source_passage_by_id(passage.id, actor=default_user) + assert retrieved is not None + + # Delete it + result = server.passage_manager.delete_source_passage_by_id(passage.id, actor=default_user) + assert result is True + + # Verify it's gone + with pytest.raises(NoResultFound): + server.passage_manager.get_source_passage_by_id(passage.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent, event_loop): + """Test creating multiple agent passages using the new batch method.""" + passages = [ + PydanticPassage( + text=f"Batch agent passage {i}", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1 * i], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + for i in range(3) + ] + + created_passages = await server.passage_manager.create_many_agent_passages_async(passages, actor=default_user) + + assert len(created_passages) == 3 + for i, passage in enumerate(created_passages): + assert passage.text == f"Batch agent passage {i}" + assert passage.agent_id == sarah_agent.id + assert passage.source_id is None + + +@pytest.mark.asyncio +async def test_create_many_source_passages_async(server: SyncServer, default_user, default_file, default_source, event_loop): + """Test creating multiple source passages using the new batch method.""" + passages = [ + PydanticPassage( + text=f"Batch source passage {i}", + source_id=default_source.id, + file_id=default_file.id, + organization_id=default_user.organization_id, + embedding=[0.1 * i], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ) + for i in range(3) + ] + + created_passages = await server.passage_manager.create_many_source_passages_async( + passages, file_metadata=default_file, actor=default_user + ) + + assert len(created_passages) == 3 + for i, passage in enumerate(created_passages): + assert passage.text == f"Batch source passage {i}" + assert passage.source_id == default_source.id + assert passage.agent_id is None + + +def test_agent_passage_size(server: SyncServer, default_user, sarah_agent): + """Test counting agent passages using the new agent-specific size method.""" + initial_size = server.passage_manager.agent_passage_size(actor=default_user, agent_id=sarah_agent.id) + + # Create some agent passages + for i in range(3): + server.passage_manager.create_agent_passage( + PydanticPassage( + text=f"Agent passage {i} for size test", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + final_size = server.passage_manager.agent_passage_size(actor=default_user, agent_id=sarah_agent.id) + assert final_size == initial_size + 3 + + +def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sarah_agent): + """Test that deprecated methods show deprecation warnings.""" + import warnings + + with warnings.catch_warnings(record=True) as w: + warnings.simplefilter("always") + + # Test deprecated create_passage + passage = server.passage_manager.create_passage( + PydanticPassage( + text="Test deprecated method", + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + # Test deprecated get_passage_by_id + server.passage_manager.get_passage_by_id(passage.id, actor=default_user) + + # Test deprecated size + server.passage_manager.size(actor=default_user, agent_id=sarah_agent.id) + + # Check that deprecation warnings were issued + assert len(w) >= 3 + assert any("create_passage is deprecated" in str(warning.message) for warning in w) + assert any("get_passage_by_id is deprecated" in str(warning.message) for warning in w) + assert any("size is deprecated" in str(warning.message) for warning in w) + + # ====================================================================================================================== # User Manager Tests # ======================================================================================================================