feat: Search files returns citations of the filenames that were searched (#2689)

This commit is contained in:
Matthew Zhou
2025-06-06 15:34:03 -07:00
committed by GitHub
parent 22c1f6e70a
commit 318a7c769b
15 changed files with 1132 additions and 43 deletions

2
.gitignore vendored
View File

@@ -5,6 +5,8 @@
openapi_letta.json
openapi_openai.json
CLAUDE.md
### Eclipse ###
.metadata
bin/

View File

@@ -0,0 +1,40 @@
"""Add file name to source passages
Revision ID: c96263433aef
Revises: 9792f94e961d
Create Date: 2025-06-06 12:06:57.328127
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# Revision identifiers, used by Alembic.
# `revision` is this migration's own id; `down_revision` is the id of the
# migration it revises (see the "Revises:" line in the module docstring).
revision: str = "c96263433aef"
down_revision: Union[str, None] = "9792f94e961d"
# No branch labels or cross-revision dependencies are declared here.
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Add ``source_passages.file_name``, backfill it, then make it NOT NULL.

    The column is created nullable first so existing rows can be populated
    from ``files.file_name`` before the constraint is tightened.
    """
    # Step 1: create the column as nullable so existing rows stay valid.
    file_name_column = sa.Column("file_name", sa.String(), nullable=True)
    op.add_column("source_passages", file_name_column)

    # Step 2: copy the name over from the owning file row (UPDATE ... FROM join).
    # NOTE(review): only rows whose file_id matches a files.id are updated; any
    # other row keeps file_name NULL and would make step 3 fail. Confirm no
    # such rows exist before running.
    backfill_sql = """
        UPDATE source_passages
        SET file_name = files.file_name
        FROM files
        WHERE source_passages.file_id = files.id
        """
    op.execute(backfill_sql)

    # Step 3: every row should now have a value, so enforce NOT NULL.
    op.alter_column("source_passages", "file_name", nullable=False)
def downgrade() -> None:
    """Revert the migration by dropping ``source_passages.file_name``."""
    op.drop_column(table_name="source_passages", column_name="file_name")

View File

@@ -1292,7 +1292,7 @@ class Agent(BaseAgent):
# conversion of messages to OpenAI dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
)
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
@@ -1414,7 +1414,7 @@ class Agent(BaseAgent):
# conversion of messages to anthropic dict format, which is passed to the token counter
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
)
in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]

View File

@@ -104,7 +104,7 @@ class BaseAgent(ABC):
if num_messages is None:
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
if num_archival_memories is None:
num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
system_prompt=agent_state.system,

View File

@@ -763,7 +763,7 @@ class LettaAgent(BaseAgent):
else asyncio.sleep(0, result=self.num_messages)
),
(
self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
if self.num_archival_memories is None
else asyncio.sleep(0, result=self.num_archival_memories)
),

View File

@@ -305,7 +305,7 @@ class VoiceAgent(BaseAgent):
else asyncio.sleep(0, result=self.num_messages)
),
(
self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
if self.num_archival_memories is None
else asyncio.sleep(0, result=self.num_archival_memories)
),

View File

@@ -47,6 +47,8 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
__tablename__ = "source_passages"
file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from")
@declared_attr
def file(cls) -> Mapped["FileMetadata"]:
"""Relationship to file"""

View File

@@ -23,6 +23,7 @@ class PassageBase(OrmMetadataBase):
# file association
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
file_name: Optional[str] = Field(None, description="The name of the file (only for source passages).")
metadata: Optional[Dict] = Field({}, validation_alias="metadata_", description="The metadata of the passage.")

View File

@@ -1483,7 +1483,7 @@ class AgentManager:
memory_edit_timestamp = curr_system_message.created_at
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id)
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = compile_system_message(
@@ -2075,6 +2075,7 @@ class AgentManager:
# This is an AgentPassage - remove source fields
data.pop("source_id", None)
data.pop("file_id", None)
data.pop("file_name", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
@@ -2135,6 +2136,7 @@ class AgentManager:
# This is an AgentPassage - remove source fields
data.pop("source_id", None)
data.pop("file_id", None)
data.pop("file_name", None)
passage = AgentPassage(**data)
else:
# This is a SourcePassage - remove agent field
@@ -2198,14 +2200,12 @@ class AgentManager:
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
file_id: Optional[str] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
before: Optional[str] = None,
after: Optional[str] = None,
source_id: Optional[str] = None,
embed_query: bool = False,
ascending: bool = True,
embedding_config: Optional[EmbeddingConfig] = None,

View File

@@ -63,7 +63,7 @@ class ContextWindowCalculator:
# Fetch data concurrently
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
passage_manager.size_async(actor=actor, agent_id=agent_state.id),
passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id),
message_manager.size_async(actor=actor, agent_id=agent_state.id),
)

View File

@@ -111,7 +111,9 @@ class FileProcessor:
)
all_passages.extend(passages)
all_passages = await self.passage_manager.create_many_passages_async(all_passages, self.actor)
all_passages = await self.passage_manager.create_many_source_passages_async(
passages=all_passages, file_metadata=file_metadata, actor=self.actor
)
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")

View File

@@ -607,15 +607,45 @@ def build_passage_query(
if not agent_only: # Include source passages
if agent_id is not None:
source_passages = (
select(SourcePassage, literal(None).label("agent_id"))
select(
SourcePassage.file_name,
SourcePassage.id,
SourcePassage.text,
SourcePassage.embedding_config,
SourcePassage.metadata_,
SourcePassage.embedding,
SourcePassage.created_at,
SourcePassage.updated_at,
SourcePassage.is_deleted,
SourcePassage._created_by_id,
SourcePassage._last_updated_by_id,
SourcePassage.organization_id,
SourcePassage.file_id,
SourcePassage.source_id,
literal(None).label("agent_id"),
)
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
.where(SourcesAgents.agent_id == agent_id)
.where(SourcePassage.organization_id == actor.organization_id)
)
else:
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
SourcePassage.organization_id == actor.organization_id
)
source_passages = select(
SourcePassage.file_name,
SourcePassage.id,
SourcePassage.text,
SourcePassage.embedding_config,
SourcePassage.metadata_,
SourcePassage.embedding,
SourcePassage.created_at,
SourcePassage.updated_at,
SourcePassage.is_deleted,
SourcePassage._created_by_id,
SourcePassage._last_updated_by_id,
SourcePassage.organization_id,
SourcePassage.file_id,
SourcePassage.source_id,
literal(None).label("agent_id"),
).where(SourcePassage.organization_id == actor.organization_id)
if source_id:
source_passages = source_passages.where(SourcePassage.source_id == source_id)
@@ -627,6 +657,7 @@ def build_passage_query(
if agent_id is not None:
agent_passages = (
select(
literal(None).label("file_name"),
AgentPassage.id,
AgentPassage.text,
AgentPassage.embedding_config,

View File

@@ -13,6 +13,7 @@ from letta.orm.errors import NoResultFound
from letta.orm.passage import AgentPassage, SourcePassage
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.file import FileMetadata as PydanticFileMetadata
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
@@ -42,10 +43,65 @@ async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> Li
class PassageManager:
"""Manager class to handle business logic related to Passages."""
# AGENT PASSAGE METHODS
@enforce_types
@trace_method
def get_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
    """Look up a single agent passage by its identifier.

    Args:
        passage_id: Identifier of the agent passage to read.
        actor: User on whose behalf the read is performed.

    Returns:
        The matching passage, converted with ``to_pydantic()``.

    Raises:
        NoResultFound: If the read finds no such agent passage; re-raised
            with a message that names ``passage_id``.
    """
    with db_registry.session() as db:
        try:
            record = AgentPassage.read(db_session=db, identifier=passage_id, actor=actor)
            return record.to_pydantic()
        except NoResultFound:
            # Re-raise with a message that says which table and id were looked up.
            raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")
@enforce_types
@trace_method
async def get_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
    """Look up a single agent passage by its identifier (async session).

    Args:
        passage_id: Identifier of the agent passage to read.
        actor: User on whose behalf the read is performed.

    Returns:
        The matching passage, converted with ``to_pydantic()``.

    Raises:
        NoResultFound: If the read finds no such agent passage; re-raised
            with a message that names ``passage_id``.
    """
    async with db_registry.async_session() as db:
        try:
            record = await AgentPassage.read_async(db_session=db, identifier=passage_id, actor=actor)
            return record.to_pydantic()
        except NoResultFound:
            # Re-raise with a message that says which table and id were looked up.
            raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")
# SOURCE PASSAGE METHODS
@enforce_types
@trace_method
def get_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
    """Look up a single source passage by its identifier.

    Args:
        passage_id: Identifier of the source passage to read.
        actor: User on whose behalf the read is performed.

    Returns:
        The matching passage, converted with ``to_pydantic()``.

    Raises:
        NoResultFound: If the read finds no such source passage; re-raised
            with a message that names ``passage_id``.
    """
    with db_registry.session() as db:
        try:
            record = SourcePassage.read(db_session=db, identifier=passage_id, actor=actor)
            return record.to_pydantic()
        except NoResultFound:
            # Re-raise with a message that says which table and id were looked up.
            raise NoResultFound(f"Source passage with id {passage_id} not found in database.")
@enforce_types
@trace_method
async def get_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
    """Look up a single source passage by its identifier (async session).

    Args:
        passage_id: Identifier of the source passage to read.
        actor: User on whose behalf the read is performed.

    Returns:
        The matching passage, converted with ``to_pydantic()``.

    Raises:
        NoResultFound: If the read finds no such source passage; re-raised
            with a message that names ``passage_id``.
    """
    async with db_registry.async_session() as db:
        try:
            record = await SourcePassage.read_async(db_session=db, identifier=passage_id, actor=actor)
            return record.to_pydantic()
        except NoResultFound:
            # Re-raise with a message that says which table and id were looked up.
            raise NoResultFound(f"Source passage with id {passage_id} not found in database.")
# DEPRECATED - Use specific methods above
@enforce_types
@trace_method
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
"""Fetch a passage by ID."""
"""DEPRECATED: Use get_agent_passage_by_id() or get_source_passage_by_id() instead."""
import warnings
warnings.warn(
"get_passage_by_id is deprecated. Use get_agent_passage_by_id() or get_source_passage_by_id() instead.",
DeprecationWarning,
stacklevel=2,
)
with db_registry.session() as session:
# Try source passages first
try:
@@ -62,7 +118,15 @@ class PassageManager:
@enforce_types
@trace_method
async def get_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
"""Fetch a passage by ID."""
"""DEPRECATED: Use get_agent_passage_by_id_async() or get_source_passage_by_id_async() instead."""
import warnings
warnings.warn(
"get_passage_by_id_async is deprecated. Use get_agent_passage_by_id_async() or get_source_passage_by_id_async() instead.",
DeprecationWarning,
stacklevel=2,
)
async with db_registry.async_session() as session:
# Try source passages first
try:
@@ -76,10 +140,137 @@ class PassageManager:
except NoResultFound:
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
@enforce_types
@trace_method
def create_agent_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
    """Create a new agent passage.

    Args:
        pydantic_passage: Passage to persist. Must carry an ``agent_id`` and
            must not carry a ``source_id``.
        actor: User on whose behalf the row is created.

    Returns:
        The created passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``agent_id`` is missing or ``source_id`` is set.
    """
    if not pydantic_passage.agent_id:
        raise ValueError("Agent passage must have agent_id")
    if pydantic_passage.source_id:
        raise ValueError("Agent passage cannot have source_id")

    data = pydantic_passage.model_dump(to_orm=True)

    # NOTE(review): the update methods setattr() this same dump's keys straight
    # onto ORM rows, so to_orm=True appears to emit the metadata under the ORM
    # attribute name ``metadata_``. Reading only ``metadata`` would then always
    # miss and silently drop caller metadata. Prefer ``metadata_`` and fall
    # back to the previous lookup so behaviour is unchanged if the key is absent.
    metadata = data.get("metadata_")
    if metadata is None:
        metadata = data.get("metadata", {})

    common_fields = {
        "id": data.get("id"),
        "text": data["text"],
        "embedding": data["embedding"],
        "embedding_config": data["embedding_config"],
        "organization_id": data["organization_id"],
        "metadata_": metadata,
        "is_deleted": data.get("is_deleted", False),
        "created_at": data.get("created_at", datetime.now(timezone.utc)),
    }
    agent_fields = {"agent_id": data["agent_id"]}
    passage = AgentPassage(**common_fields, **agent_fields)

    with db_registry.session() as session:
        passage.create(session, actor=actor)
        return passage.to_pydantic()
@enforce_types
@trace_method
async def create_agent_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
    """Create a new agent passage using an async session.

    Args:
        pydantic_passage: Passage to persist. Must carry an ``agent_id`` and
            must not carry a ``source_id``.
        actor: User on whose behalf the row is created.

    Returns:
        The created passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``agent_id`` is missing or ``source_id`` is set.
    """
    if not pydantic_passage.agent_id:
        raise ValueError("Agent passage must have agent_id")
    if pydantic_passage.source_id:
        raise ValueError("Agent passage cannot have source_id")

    data = pydantic_passage.model_dump(to_orm=True)

    # NOTE(review): the update methods setattr() this same dump's keys straight
    # onto ORM rows, so to_orm=True appears to emit the metadata under the ORM
    # attribute name ``metadata_``. Reading only ``metadata`` would then always
    # miss and silently drop caller metadata. Prefer ``metadata_`` and fall
    # back to the previous lookup so behaviour is unchanged if the key is absent.
    metadata = data.get("metadata_")
    if metadata is None:
        metadata = data.get("metadata", {})

    common_fields = {
        "id": data.get("id"),
        "text": data["text"],
        "embedding": data["embedding"],
        "embedding_config": data["embedding_config"],
        "organization_id": data["organization_id"],
        "metadata_": metadata,
        "is_deleted": data.get("is_deleted", False),
        "created_at": data.get("created_at", datetime.now(timezone.utc)),
    }
    agent_fields = {"agent_id": data["agent_id"]}
    passage = AgentPassage(**common_fields, **agent_fields)

    async with db_registry.async_session() as session:
        passage = await passage.create_async(session, actor=actor)
        return passage.to_pydantic()
@enforce_types
@trace_method
def create_source_passage(
    self, pydantic_passage: PydanticPassage, file_metadata: PydanticFileMetadata, actor: PydanticUser
) -> PydanticPassage:
    """Create a new source passage.

    Args:
        pydantic_passage: Passage to persist. Must carry a ``source_id`` and
            must not carry an ``agent_id``.
        file_metadata: Metadata of the originating file; its ``file_name`` is
            stored on the passage row.
        actor: User on whose behalf the row is created.

    Returns:
        The created passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``source_id`` is missing or ``agent_id`` is set.
    """
    if not pydantic_passage.source_id:
        raise ValueError("Source passage must have source_id")
    if pydantic_passage.agent_id:
        raise ValueError("Source passage cannot have agent_id")

    data = pydantic_passage.model_dump(to_orm=True)

    # NOTE(review): the update methods setattr() this same dump's keys straight
    # onto ORM rows, so to_orm=True appears to emit the metadata under the ORM
    # attribute name ``metadata_``. Reading only ``metadata`` would then always
    # miss and silently drop caller metadata. Prefer ``metadata_`` and fall
    # back to the previous lookup so behaviour is unchanged if the key is absent.
    metadata = data.get("metadata_")
    if metadata is None:
        metadata = data.get("metadata", {})

    common_fields = {
        "id": data.get("id"),
        "text": data["text"],
        "embedding": data["embedding"],
        "embedding_config": data["embedding_config"],
        "organization_id": data["organization_id"],
        "metadata_": metadata,
        "is_deleted": data.get("is_deleted", False),
        "created_at": data.get("created_at", datetime.now(timezone.utc)),
    }
    source_fields = {
        "source_id": data["source_id"],
        "file_id": data.get("file_id"),
        # The file name is taken from file_metadata, not from the passage itself.
        "file_name": file_metadata.file_name,
    }
    passage = SourcePassage(**common_fields, **source_fields)

    with db_registry.session() as session:
        passage.create(session, actor=actor)
        return passage.to_pydantic()
@enforce_types
@trace_method
async def create_source_passage_async(
    self, pydantic_passage: PydanticPassage, file_metadata: PydanticFileMetadata, actor: PydanticUser
) -> PydanticPassage:
    """Create a new source passage using an async session.

    Args:
        pydantic_passage: Passage to persist. Must carry a ``source_id`` and
            must not carry an ``agent_id``.
        file_metadata: Metadata of the originating file; its ``file_name`` is
            stored on the passage row.
        actor: User on whose behalf the row is created.

    Returns:
        The created passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``source_id`` is missing or ``agent_id`` is set.
    """
    if not pydantic_passage.source_id:
        raise ValueError("Source passage must have source_id")
    if pydantic_passage.agent_id:
        raise ValueError("Source passage cannot have agent_id")

    data = pydantic_passage.model_dump(to_orm=True)

    # NOTE(review): the update methods setattr() this same dump's keys straight
    # onto ORM rows, so to_orm=True appears to emit the metadata under the ORM
    # attribute name ``metadata_``. Reading only ``metadata`` would then always
    # miss and silently drop caller metadata. Prefer ``metadata_`` and fall
    # back to the previous lookup so behaviour is unchanged if the key is absent.
    metadata = data.get("metadata_")
    if metadata is None:
        metadata = data.get("metadata", {})

    common_fields = {
        "id": data.get("id"),
        "text": data["text"],
        "embedding": data["embedding"],
        "embedding_config": data["embedding_config"],
        "organization_id": data["organization_id"],
        "metadata_": metadata,
        "is_deleted": data.get("is_deleted", False),
        "created_at": data.get("created_at", datetime.now(timezone.utc)),
    }
    source_fields = {
        "source_id": data["source_id"],
        "file_id": data.get("file_id"),
        # The file name is taken from file_metadata, not from the passage itself.
        "file_name": file_metadata.file_name,
    }
    passage = SourcePassage(**common_fields, **source_fields)

    async with db_registry.async_session() as session:
        passage = await passage.create_async(session, actor=actor)
        return passage.to_pydantic()
# DEPRECATED - Use specific methods above
@enforce_types
@trace_method
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
"""DEPRECATED: Use create_agent_passage() or create_source_passage() instead."""
import warnings
warnings.warn(
"create_passage is deprecated. Use create_agent_passage() or create_source_passage() instead.", DeprecationWarning, stacklevel=2
)
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
with db_registry.session() as session:
@@ -89,7 +280,15 @@ class PassageManager:
@enforce_types
@trace_method
async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
"""DEPRECATED: Use create_agent_passage_async() or create_source_passage_async() instead."""
import warnings
warnings.warn(
"create_passage_async is deprecated. Use create_agent_passage_async() or create_source_passage_async() instead.",
DeprecationWarning,
stacklevel=2,
)
# Common fields for both passage types
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
async with db_registry.async_session() as session:
@@ -128,16 +327,110 @@ class PassageManager:
return passage
@enforce_types
@trace_method
def create_many_agent_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
    """Create several agent passages, one ``create_agent_passage`` call each.

    Args:
        passages: Passages to persist, in order.
        actor: User on whose behalf the rows are created.

    Returns:
        The created passages, in the same order as ``passages``.
    """
    return [
        self.create_agent_passage(single_passage, actor)
        for single_passage in passages
    ]
@enforce_types
@trace_method
async def create_many_agent_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
    """Create multiple agent passages in one batch insert.

    Every passage is validated and converted to an ``AgentPassage`` before
    the session is opened, so an invalid entry aborts the batch up front.

    Args:
        passages: Passages to persist. Each must carry an ``agent_id`` and
            must not carry a ``source_id``.
        actor: User on whose behalf the rows are created.

    Returns:
        The created passages, converted with ``to_pydantic()``.

    Raises:
        ValueError: If any passage lacks ``agent_id`` or has ``source_id`` set.
    """
    agent_passages = []
    for p in passages:
        if not p.agent_id:
            raise ValueError("Agent passage must have agent_id")
        if p.source_id:
            raise ValueError("Agent passage cannot have source_id")

        data = p.model_dump(to_orm=True)

        # NOTE(review): the update methods setattr() this same dump's keys
        # straight onto ORM rows, so to_orm=True appears to emit the metadata
        # under the ORM attribute name ``metadata_``. Reading only ``metadata``
        # would then always miss and silently drop caller metadata. Prefer
        # ``metadata_`` and fall back to the previous lookup otherwise.
        metadata = data.get("metadata_")
        if metadata is None:
            metadata = data.get("metadata", {})

        common_fields = {
            "id": data.get("id"),
            "text": data["text"],
            "embedding": data["embedding"],
            "embedding_config": data["embedding_config"],
            "organization_id": data["organization_id"],
            "metadata_": metadata,
            "is_deleted": data.get("is_deleted", False),
            "created_at": data.get("created_at", datetime.now(timezone.utc)),
        }
        agent_fields = {"agent_id": data["agent_id"]}
        agent_passages.append(AgentPassage(**common_fields, **agent_fields))

    async with db_registry.async_session() as session:
        agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
        return [p.to_pydantic() for p in agent_created]
@enforce_types
@trace_method
def create_many_source_passages(
    self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser
) -> List[PydanticPassage]:
    """Create several source passages, one ``create_source_passage`` call each.

    Args:
        passages: Passages to persist, in order.
        file_metadata: File metadata forwarded unchanged to every call.
        actor: User on whose behalf the rows are created.

    Returns:
        The created passages, in the same order as ``passages``.
    """
    return [
        self.create_source_passage(single_passage, file_metadata, actor)
        for single_passage in passages
    ]
@enforce_types
@trace_method
async def create_many_source_passages_async(
    self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser
) -> List[PydanticPassage]:
    """Create multiple source passages in one batch insert.

    Every passage is validated and converted to a ``SourcePassage`` before
    the session is opened, so an invalid entry aborts the batch up front.

    Args:
        passages: Passages to persist. Each must carry a ``source_id`` and
            must not carry an ``agent_id``.
        file_metadata: Metadata of the originating file; its ``file_name`` is
            stored on every created row.
        actor: User on whose behalf the rows are created.

    Returns:
        The created passages, converted with ``to_pydantic()``.

    Raises:
        ValueError: If any passage lacks ``source_id`` or has ``agent_id`` set.
    """
    source_passages = []
    for p in passages:
        if not p.source_id:
            raise ValueError("Source passage must have source_id")
        if p.agent_id:
            raise ValueError("Source passage cannot have agent_id")

        data = p.model_dump(to_orm=True)

        # NOTE(review): the update methods setattr() this same dump's keys
        # straight onto ORM rows, so to_orm=True appears to emit the metadata
        # under the ORM attribute name ``metadata_``. Reading only ``metadata``
        # would then always miss and silently drop caller metadata. Prefer
        # ``metadata_`` and fall back to the previous lookup otherwise.
        metadata = data.get("metadata_")
        if metadata is None:
            metadata = data.get("metadata", {})

        common_fields = {
            "id": data.get("id"),
            "text": data["text"],
            "embedding": data["embedding"],
            "embedding_config": data["embedding_config"],
            "organization_id": data["organization_id"],
            "metadata_": metadata,
            "is_deleted": data.get("is_deleted", False),
            "created_at": data.get("created_at", datetime.now(timezone.utc)),
        }
        source_fields = {
            "source_id": data["source_id"],
            "file_id": data.get("file_id"),
            # The file name is taken from file_metadata, not from the passage itself.
            "file_name": file_metadata.file_name,
        }
        source_passages.append(SourcePassage(**common_fields, **source_fields))

    async with db_registry.async_session() as session:
        source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor)
        return [p.to_pydantic() for p in source_created]
# DEPRECATED - Use specific methods above
@enforce_types
@trace_method
def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
"""Create multiple passages."""
"""DEPRECATED: Use create_many_agent_passages() or create_many_source_passages() instead."""
import warnings
warnings.warn(
"create_many_passages is deprecated. Use create_many_agent_passages() or create_many_source_passages() instead.",
DeprecationWarning,
stacklevel=2,
)
return [self.create_passage(p, actor) for p in passages]
@enforce_types
@trace_method
async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
"""Create multiple passages."""
"""DEPRECATED: Use create_many_agent_passages_async() or create_many_source_passages_async() instead."""
import warnings
warnings.warn(
"create_many_passages_async is deprecated. Use create_many_agent_passages_async() or create_many_source_passages_async() instead.",
DeprecationWarning,
stacklevel=2,
)
async with db_registry.async_session() as session:
agent_passages = []
source_passages = []
@@ -203,7 +496,7 @@ class PassageManager:
raise TypeError(
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
)
passage = self.create_passage(
passage = self.create_agent_passage(
PydanticPassage(
organization_id=actor.organization_id,
agent_id=agent_id,
@@ -251,7 +544,7 @@ class PassageManager:
for chunk_text, embedding in zip(text_chunks, embeddings)
]
passages = await self.create_many_passages_async(passages=passages, actor=actor)
passages = await self.create_many_agent_passages_async(passages=passages, actor=actor)
return passages
@@ -292,10 +585,191 @@ class PassageManager:
return processed_embeddings
@enforce_types
@trace_method
def update_agent_passage_by_id(
    self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
) -> Optional[PydanticPassage]:
    """Apply the set, non-None fields of ``passage`` to an existing agent passage.

    Args:
        passage_id: Identifier of the agent passage to modify.
        passage: Carrier of the new values; only fields that were explicitly
            set and are not None are copied onto the stored row.
        actor: User on whose behalf the update is performed.
        **kwargs: Accepted but not used by this method.

    Returns:
        The updated passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``passage_id`` is empty or no such agent passage exists.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    with db_registry.session() as db:
        try:
            existing = AgentPassage.read(db_session=db, identifier=passage_id, actor=actor)
        except NoResultFound:
            raise ValueError(f"Agent passage with id {passage_id} does not exist.")

        # Copy each provided field onto the stored row.
        changes = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for field_name, new_value in changes.items():
            setattr(existing, field_name, new_value)

        # Persist the modified row.
        existing.update(db, actor=actor)
        return existing.to_pydantic()
@enforce_types
@trace_method
async def update_agent_passage_by_id_async(
    self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
) -> Optional[PydanticPassage]:
    """Apply the set, non-None fields of ``passage`` to an agent passage (async).

    Args:
        passage_id: Identifier of the agent passage to modify.
        passage: Carrier of the new values; only fields that were explicitly
            set and are not None are copied onto the stored row.
        actor: User on whose behalf the update is performed.
        **kwargs: Accepted but not used by this method.

    Returns:
        The updated passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``passage_id`` is empty or no such agent passage exists.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    async with db_registry.async_session() as db:
        try:
            existing = await AgentPassage.read_async(db_session=db, identifier=passage_id, actor=actor)
        except NoResultFound:
            raise ValueError(f"Agent passage with id {passage_id} does not exist.")

        # Copy each provided field onto the stored row.
        changes = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for field_name, new_value in changes.items():
            setattr(existing, field_name, new_value)

        # Persist the modified row.
        await existing.update_async(db, actor=actor)
        return existing.to_pydantic()
@enforce_types
@trace_method
def update_source_passage_by_id(
    self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
) -> Optional[PydanticPassage]:
    """Apply the set, non-None fields of ``passage`` to an existing source passage.

    Args:
        passage_id: Identifier of the source passage to modify.
        passage: Carrier of the new values; only fields that were explicitly
            set and are not None are copied onto the stored row.
        actor: User on whose behalf the update is performed.
        **kwargs: Accepted but not used by this method.

    Returns:
        The updated passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``passage_id`` is empty or no such source passage exists.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    with db_registry.session() as db:
        try:
            existing = SourcePassage.read(db_session=db, identifier=passage_id, actor=actor)
        except NoResultFound:
            raise ValueError(f"Source passage with id {passage_id} does not exist.")

        # Copy each provided field onto the stored row.
        changes = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for field_name, new_value in changes.items():
            setattr(existing, field_name, new_value)

        # Persist the modified row.
        existing.update(db, actor=actor)
        return existing.to_pydantic()
@enforce_types
@trace_method
async def update_source_passage_by_id_async(
    self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
) -> Optional[PydanticPassage]:
    """Apply the set, non-None fields of ``passage`` to a source passage (async).

    Args:
        passage_id: Identifier of the source passage to modify.
        passage: Carrier of the new values; only fields that were explicitly
            set and are not None are copied onto the stored row.
        actor: User on whose behalf the update is performed.
        **kwargs: Accepted but not used by this method.

    Returns:
        The updated passage, converted with ``to_pydantic()``.

    Raises:
        ValueError: If ``passage_id`` is empty or no such source passage exists.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    async with db_registry.async_session() as db:
        try:
            existing = await SourcePassage.read_async(db_session=db, identifier=passage_id, actor=actor)
        except NoResultFound:
            raise ValueError(f"Source passage with id {passage_id} does not exist.")

        # Copy each provided field onto the stored row.
        changes = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
        for field_name, new_value in changes.items():
            setattr(existing, field_name, new_value)

        # Persist the modified row.
        await existing.update_async(db, actor=actor)
        return existing.to_pydantic()
@enforce_types
@trace_method
def delete_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
    """Hard-delete one agent passage.

    Args:
        passage_id: Identifier of the agent passage to remove.
        actor: User on whose behalf the deletion is performed.

    Returns:
        True once the row has been hard-deleted.

    Raises:
        ValueError: If ``passage_id`` is empty.
        NoResultFound: If no such agent passage is found.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    with db_registry.session() as db:
        try:
            target = AgentPassage.read(db_session=db, identifier=passage_id, actor=actor)
            target.hard_delete(db, actor=actor)
            return True
        except NoResultFound:
            raise NoResultFound(f"Agent passage with id {passage_id} not found.")
@enforce_types
@trace_method
async def delete_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
    """Hard-delete one agent passage using an async session.

    Args:
        passage_id: Identifier of the agent passage to remove.
        actor: User on whose behalf the deletion is performed.

    Returns:
        True once the row has been hard-deleted.

    Raises:
        ValueError: If ``passage_id`` is empty.
        NoResultFound: If no such agent passage is found.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    async with db_registry.async_session() as db:
        try:
            target = await AgentPassage.read_async(db_session=db, identifier=passage_id, actor=actor)
            await target.hard_delete_async(db, actor=actor)
            return True
        except NoResultFound:
            raise NoResultFound(f"Agent passage with id {passage_id} not found.")
@enforce_types
@trace_method
def delete_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
    """Hard-delete one source passage.

    Args:
        passage_id: Identifier of the source passage to remove.
        actor: User on whose behalf the deletion is performed.

    Returns:
        True once the row has been hard-deleted.

    Raises:
        ValueError: If ``passage_id`` is empty.
        NoResultFound: If no such source passage is found.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    with db_registry.session() as db:
        try:
            target = SourcePassage.read(db_session=db, identifier=passage_id, actor=actor)
            target.hard_delete(db, actor=actor)
            return True
        except NoResultFound:
            raise NoResultFound(f"Source passage with id {passage_id} not found.")
@enforce_types
@trace_method
async def delete_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
    """Hard-delete one source passage using an async session.

    Args:
        passage_id: Identifier of the source passage to remove.
        actor: User on whose behalf the deletion is performed.

    Returns:
        True once the row has been hard-deleted.

    Raises:
        ValueError: If ``passage_id`` is empty.
        NoResultFound: If no such source passage is found.
    """
    if not passage_id:
        raise ValueError("Passage ID must be provided.")

    async with db_registry.async_session() as db:
        try:
            target = await SourcePassage.read_async(db_session=db, identifier=passage_id, actor=actor)
            await target.hard_delete_async(db, actor=actor)
            return True
        except NoResultFound:
            raise NoResultFound(f"Source passage with id {passage_id} not found.")
# DEPRECATED - Use specific methods above
@enforce_types
@trace_method
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
"""Update a passage."""
"""DEPRECATED: Use update_agent_passage_by_id() or update_source_passage_by_id() instead."""
import warnings
warnings.warn(
"update_passage_by_id is deprecated. Use update_agent_passage_by_id() or update_source_passage_by_id() instead.",
DeprecationWarning,
stacklevel=2,
)
if not passage_id:
raise ValueError("Passage ID must be provided.")
@@ -330,7 +804,15 @@ class PassageManager:
@enforce_types
@trace_method
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
"""Delete a passage from either source or archival passages."""
"""DEPRECATED: Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead."""
import warnings
warnings.warn(
"delete_passage_by_id is deprecated. Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead.",
DeprecationWarning,
stacklevel=2,
)
if not passage_id:
raise ValueError("Passage ID must be provided.")
@@ -352,7 +834,15 @@ class PassageManager:
@enforce_types
@trace_method
async def delete_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
"""Delete a passage from either source or archival passages."""
"""DEPRECATED: Use delete_agent_passage_by_id_async() or delete_source_passage_by_id_async() instead."""
import warnings
warnings.warn(
"delete_passage_by_id_async is deprecated. Use delete_agent_passage_by_id_async() or delete_source_passage_by_id_async() instead.",
DeprecationWarning,
stacklevel=2,
)
if not passage_id:
raise ValueError("Passage ID must be provided.")
@@ -373,15 +863,42 @@ class PassageManager:
@enforce_types
@trace_method
def delete_passages(
def delete_agent_passages(
self,
actor: PydanticUser,
passages: List[PydanticPassage],
) -> bool:
"""Delete multiple agent passages."""
# TODO: This is very inefficient
# TODO: We should have a base `delete_all_matching_filters`-esque function
for passage in passages:
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
self.delete_agent_passage_by_id(passage_id=passage.id, actor=actor)
return True
@enforce_types
@trace_method
async def delete_agent_passages_async(
    self,
    actor: PydanticUser,
    passages: List[PydanticPassage],
) -> bool:
    """Hard-delete a batch of agent passages with a single bulk call.

    Args:
        actor: User on whose behalf the deletion is performed.
        passages: Passages to remove; only their ``id`` values are used.

    Returns:
        True after the bulk delete call completes.
    """
    async with db_registry.async_session() as db:
        passage_ids = [item.id for item in passages]
        await AgentPassage.bulk_hard_delete_async(db_session=db, identifiers=passage_ids, actor=actor)
        return True
@enforce_types
@trace_method
def delete_source_passages(
    self,
    actor: PydanticUser,
    passages: List[PydanticPassage],
) -> bool:
    """Hard-delete source passages one at a time.

    Args:
        actor: User on whose behalf the deletions are performed.
        passages: Passages to remove; only their ``id`` values are used.

    Returns:
        True after every passage has been deleted.
    """
    # TODO: This is very inefficient
    # TODO: We should have a base `delete_all_matching_filters`-esque function
    for doomed in passages:
        self.delete_source_passage_by_id(passage_id=doomed.id, actor=actor)
    return True
@enforce_types
@@ -395,14 +912,36 @@ class PassageManager:
await SourcePassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
return True
# DEPRECATED - Use specific methods above
@enforce_types
@trace_method
def size(
def delete_passages(
self,
actor: PydanticUser,
passages: List[PydanticPassage],
) -> bool:
"""DEPRECATED: Use delete_agent_passages() or delete_source_passages() instead."""
import warnings
warnings.warn(
"delete_passages is deprecated. Use delete_agent_passages() or delete_source_passages() instead.",
DeprecationWarning,
stacklevel=2,
)
# TODO: This is very inefficient
# TODO: We should have a base `delete_all_matching_filters`-esque function
for passage in passages:
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
return True
@enforce_types
@trace_method
def agent_passage_size(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
) -> int:
"""Get the total count of messages with optional filters.
"""Get the total count of agent passages with optional filters.
Args:
actor: The user requesting the count
@@ -411,14 +950,29 @@ class PassageManager:
with db_registry.session() as session:
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
# DEPRECATED - Use agent_passage_size() instead since this only counted agent passages anyway
@enforce_types
@trace_method
async def size_async(
def size(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
) -> int:
    """DEPRECATED: Use agent_passage_size() instead (this only counted agent passages anyway).

    Emits a ``DeprecationWarning`` and returns the result of
    ``AgentPassage.size`` for the given filters.

    Args:
        actor: The user requesting the count.
        agent_id: Optional agent ID forwarded to ``AgentPassage.size``.

    Returns:
        The count reported by ``AgentPassage.size``.
    """
    import warnings

    # stacklevel=2 attributes the warning to the caller rather than to this method.
    warnings.warn("size is deprecated. Use agent_passage_size() instead.", DeprecationWarning, stacklevel=2)

    with db_registry.session() as session:
        return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
@enforce_types
@trace_method
async def agent_passage_size_async(
self,
actor: PydanticUser,
agent_id: Optional[str] = None,
) -> int:
"""Get the total count of agent passages with optional filters.
Args:
actor: The user requesting the count
agent_id: The agent ID of the messages
@@ -426,6 +980,37 @@ class PassageManager:
async with db_registry.async_session() as session:
return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id)
@enforce_types
@trace_method
def source_passage_size(
    self,
    actor: PydanticUser,
    source_id: Optional[str] = None,
) -> int:
    """Count source passages, optionally narrowed to a single source.

    Args:
        actor: The user requesting the count
        source_id: The source ID of the passages

    Returns:
        The count reported by ``SourcePassage.size``.
    """
    with db_registry.session() as db_session:
        passage_count = SourcePassage.size(
            db_session=db_session,
            actor=actor,
            source_id=source_id,
        )
        return passage_count
@enforce_types
@trace_method
async def source_passage_size_async(
    self,
    actor: PydanticUser,
    source_id: Optional[str] = None,
) -> int:
    """Asynchronously count source passages, optionally narrowed to one source.

    Args:
        actor: The user requesting the count
        source_id: The source ID of the passages

    Returns:
        The count reported by ``SourcePassage.size_async``.
    """
    async with db_registry.async_session() as db_session:
        passage_count = await SourcePassage.size_async(
            db_session=db_session,
            actor=actor,
            source_id=source_id,
        )
        return passage_count
@enforce_types
@trace_method
async def estimate_embeddings_size_async(
@@ -448,7 +1033,7 @@ class PassageManager:
raise ValueError(f"Invalid storage unit: {storage_unit}. Must be one of {list(BYTES_PER_STORAGE_UNIT.keys())}.")
BYTES_PER_EMBEDDING_DIM = 4
GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM
return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
return await self.agent_passage_size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
@enforce_types
@trace_method

View File

@@ -126,6 +126,13 @@ class LettaFileToolExecutor(ToolExecutor):
# TODO: Make this paginated?
async def search_files(self, agent_state: AgentState, query: str) -> List[str]:
    """Search for text within attached files and return passages with their source filenames.

    Args:
        agent_state: State of the agent whose source passages are searched;
            only its ``id`` is read here.
        query: Text forwarded as ``query_text`` to the passage lookup.

    Returns:
        One string per passage returned by
        ``agent_manager.list_source_passages_async``. A passage with a truthy
        ``file_name`` is rendered as the bracketed file name, a colon, a
        newline, and then the passage text; any other passage is returned as
        its bare text.
    """
    passages = await self.agent_manager.list_source_passages_async(actor=self.actor, agent_id=agent_state.id, query_text=query)

    # Prefix each hit with its file name so the caller can cite where the text came from.
    return [f"[{p.file_name}]:\n{p.text}" if p.file_name else p.text for p in passages]

View File

@@ -12,6 +12,7 @@ import httpx
# tests/test_file_content_flow.py
import pytest
from _pytest.python_api import approx
from anthropic.types.beta import BetaMessage
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
@@ -280,7 +281,7 @@ async def default_run(server: SyncServer, default_user):
@pytest.fixture
def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
"""Fixture to create an agent passage."""
passage = server.passage_manager.create_passage(
passage = server.passage_manager.create_agent_passage(
PydanticPassage(
text="Hello, I am an agent passage",
agent_id=sarah_agent.id,
@@ -297,7 +298,7 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
@pytest.fixture
def source_passage_fixture(server: SyncServer, default_user, default_file, default_source):
"""Fixture to create a source passage."""
passage = server.passage_manager.create_passage(
passage = server.passage_manager.create_source_passage(
PydanticPassage(
text="Hello, I am a source passage",
source_id=default_source.id,
@@ -307,6 +308,7 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata={"type": "test"},
),
file_metadata=default_file,
actor=default_user,
)
yield passage
@@ -318,7 +320,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
# Create agent passages
passages = []
for i in range(5):
passage = server.passage_manager.create_passage(
passage = server.passage_manager.create_agent_passage(
PydanticPassage(
text=f"Agent passage {i}",
agent_id=sarah_agent.id,
@@ -335,7 +337,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
# Create source passages
for i in range(5):
passage = server.passage_manager.create_passage(
passage = server.passage_manager.create_source_passage(
PydanticPassage(
text=f"Source passage {i}",
source_id=default_source.id,
@@ -345,6 +347,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
embedding_config=DEFAULT_EMBEDDING_CONFIG,
metadata={"type": "test"},
),
file_metadata=default_file,
actor=default_user,
)
passages.append(passage)
@@ -525,7 +528,7 @@ def server():
@pytest.fixture
@pytest.mark.asyncio
async def agent_passages_setup(server, default_source, default_user, sarah_agent, event_loop):
async def agent_passages_setup(server, default_source, default_file, default_user, sarah_agent, event_loop):
"""Setup fixture for agent passages tests"""
agent_id = sarah_agent.id
actor = default_user
@@ -535,14 +538,16 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
# Create some source passages
source_passages = []
for i in range(3):
passage = await server.passage_manager.create_passage_async(
passage = await server.passage_manager.create_source_passage_async(
PydanticPassage(
organization_id=actor.organization_id,
source_id=default_source.id,
file_id=default_file.id,
text=f"Source passage {i}",
embedding=[0.1], # Default OpenAI embedding size
embedding_config=DEFAULT_EMBEDDING_CONFIG,
),
file_metadata=default_file,
actor=actor,
)
source_passages.append(passage)
@@ -550,7 +555,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
# Create some agent passages
agent_passages = []
for i in range(2):
passage = await server.passage_manager.create_passage_async(
passage = await server.passage_manager.create_agent_passage_async(
PydanticPassage(
organization_id=actor.organization_id,
agent_id=agent_id,
@@ -2022,7 +2027,7 @@ async def test_agent_list_passages_filtering(server, default_user, sarah_agent,
@pytest.mark.asyncio
async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, event_loop):
async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, event_loop):
"""Test vector search functionality of agent passages"""
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
@@ -2041,6 +2046,7 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age
for i, text in enumerate(test_passages):
embedding = embed_model.get_text_embedding(text)
if i % 2 == 0:
# Create agent passage
passage = PydanticPassage(
text=text,
organization_id=default_user.organization_id,
@@ -2048,15 +2054,18 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding,
)
created_passage = await server.passage_manager.create_agent_passage_async(passage, default_user)
else:
# Create source passage
passage = PydanticPassage(
text=text,
organization_id=default_user.organization_id,
source_id=default_source.id,
file_id=default_file.id,
embedding_config=DEFAULT_EMBEDDING_CONFIG,
embedding=embedding,
)
created_passage = await server.passage_manager.create_passage_async(passage, default_user)
created_passage = await server.passage_manager.create_source_passage_async(passage, default_file, default_user)
passages.append(created_passage)
# Query vector similar to "red" embedding
@@ -2261,6 +2270,416 @@ async def test_passage_cascade_deletion(
server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
    """Create an agent passage via ``create_agent_passage`` and check the returned fields."""
    expected_text = "Test agent passage via specific method"
    draft = PydanticPassage(
        text=expected_text,
        agent_id=sarah_agent.id,
        organization_id=default_user.organization_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        metadata={"type": "test_specific"},
    )

    created = server.passage_manager.create_agent_passage(draft, actor=default_user)

    # The result carries an id and is scoped to the agent only.
    assert created.id is not None
    assert created.text == expected_text
    assert created.agent_id == sarah_agent.id
    assert created.source_id is None
def test_create_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
    """Create a source passage via ``create_source_passage`` and check the returned fields."""
    expected_text = "Test source passage via specific method"
    draft = PydanticPassage(
        text=expected_text,
        source_id=default_source.id,
        file_id=default_file.id,
        organization_id=default_user.organization_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        metadata={"type": "test_specific"},
    )

    created = server.passage_manager.create_source_passage(
        draft,
        file_metadata=default_file,
        actor=default_user,
    )

    # The result carries an id and is scoped to the source only.
    assert created.id is not None
    assert created.text == expected_text
    assert created.source_id == default_source.id
    assert created.agent_id is None
def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent):
    """Check that ``create_agent_passage`` rejects passages with the wrong ownership fields."""

    def build_invalid(**ownership):
        # Fresh passage per call; only the ownership fields vary between cases.
        return PydanticPassage(
            text="Invalid agent passage",
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            **ownership,
        )

    manager = server.passage_manager

    # Should fail if agent_id is missing
    with pytest.raises(ValueError, match="Agent passage must have agent_id"):
        manager.create_agent_passage(build_invalid(), actor=default_user)

    # Should fail if source_id is present
    with pytest.raises(ValueError, match="Agent passage cannot have source_id"):
        manager.create_agent_passage(
            build_invalid(agent_id=sarah_agent.id, source_id=default_source.id),
            actor=default_user,
        )
def test_create_source_passage_validation(server: SyncServer, default_user, default_file, default_source, sarah_agent):
    """Check that ``create_source_passage`` rejects passages with the wrong ownership fields."""

    def build_invalid(**ownership):
        # Fresh passage per call; only the ownership fields vary between cases.
        return PydanticPassage(
            text="Invalid source passage",
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
            **ownership,
        )

    manager = server.passage_manager

    # Should fail if source_id is missing
    with pytest.raises(ValueError, match="Source passage must have source_id"):
        manager.create_source_passage(
            build_invalid(),
            file_metadata=default_file,
            actor=default_user,
        )

    # Should fail if agent_id is present
    with pytest.raises(ValueError, match="Source passage cannot have agent_id"):
        manager.create_source_passage(
            build_invalid(source_id=default_source.id, agent_id=sarah_agent.id),
            file_metadata=default_file,
            actor=default_user,
        )
def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent):
    """Create an agent passage and read it back with ``get_agent_passage_by_id``."""
    manager = server.passage_manager

    # Create an agent passage
    draft = PydanticPassage(
        text="Agent passage for retrieval test",
        agent_id=sarah_agent.id,
        organization_id=default_user.organization_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    created = manager.create_agent_passage(draft, actor=default_user)

    # Retrieve it using the specific method
    fetched = manager.get_agent_passage_by_id(created.id, actor=default_user)

    assert fetched is not None
    assert fetched.id == created.id
    assert fetched.text == created.text
    assert fetched.agent_id == sarah_agent.id
def test_get_source_passage_by_id_specific(server: SyncServer, default_user, default_file, default_source):
    """Create a source passage and read it back with ``get_source_passage_by_id``."""
    manager = server.passage_manager

    # Create a source passage
    draft = PydanticPassage(
        text="Source passage for retrieval test",
        source_id=default_source.id,
        file_id=default_file.id,
        organization_id=default_user.organization_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    created = manager.create_source_passage(
        draft,
        file_metadata=default_file,
        actor=default_user,
    )

    # Retrieve it using the specific method
    fetched = manager.get_source_passage_by_id(created.id, actor=default_user)

    assert fetched is not None
    assert fetched.id == created.id
    assert fetched.text == created.text
    assert fetched.source_id == default_source.id
def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sarah_agent, default_file, default_source):
    """Check that each type-specific getter raises ``NoResultFound`` for the other passage type."""
    manager = server.passage_manager
    org_id = default_user.organization_id

    # Create an agent passage
    agent_draft = PydanticPassage(
        text="Agent passage",
        agent_id=sarah_agent.id,
        organization_id=org_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    agent_passage = manager.create_agent_passage(agent_draft, actor=default_user)

    # Create a source passage
    source_draft = PydanticPassage(
        text="Source passage",
        source_id=default_source.id,
        file_id=default_file.id,
        organization_id=org_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    source_passage = manager.create_source_passage(
        source_draft,
        file_metadata=default_file,
        actor=default_user,
    )

    # Trying to get agent passage with source method should fail
    with pytest.raises(NoResultFound):
        manager.get_source_passage_by_id(agent_passage.id, actor=default_user)

    # Trying to get source passage with agent method should fail
    with pytest.raises(NoResultFound):
        manager.get_agent_passage_by_id(source_passage.id, actor=default_user)
def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
    """Update an agent passage with ``update_agent_passage_by_id`` and check the new values."""
    manager = server.passage_manager
    org_id = default_user.organization_id

    # Create an agent passage
    original = manager.create_agent_passage(
        PydanticPassage(
            text="Original agent passage text",
            agent_id=sarah_agent.id,
            organization_id=org_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        actor=default_user,
    )

    # Update it
    replacement = PydanticPassage(
        text="Updated agent passage text",
        agent_id=sarah_agent.id,
        organization_id=org_id,
        embedding=[0.2],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    updated = manager.update_agent_passage_by_id(original.id, replacement, actor=default_user)

    # Text and embedding change while the id stays the same.
    assert updated.text == "Updated agent passage text"
    assert updated.embedding[0] == approx(0.2)
    assert updated.id == original.id
def test_update_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
    """Update a source passage with ``update_source_passage_by_id`` and check the new values."""
    manager = server.passage_manager
    org_id = default_user.organization_id

    # Create a source passage
    original = manager.create_source_passage(
        PydanticPassage(
            text="Original source passage text",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=org_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        ),
        file_metadata=default_file,
        actor=default_user,
    )

    # Update it
    replacement = PydanticPassage(
        text="Updated source passage text",
        source_id=default_source.id,
        file_id=default_file.id,
        organization_id=org_id,
        embedding=[0.2],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    updated = manager.update_source_passage_by_id(original.id, replacement, actor=default_user)

    # Text and embedding change while the id stays the same.
    assert updated.text == "Updated source passage text"
    assert updated.embedding[0] == approx(0.2)
    assert updated.id == original.id
def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
    """Delete an agent passage with ``delete_agent_passage_by_id`` and confirm it can no longer be fetched."""
    manager = server.passage_manager

    # Create an agent passage
    draft = PydanticPassage(
        text="Agent passage to delete",
        agent_id=sarah_agent.id,
        organization_id=default_user.organization_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    created = manager.create_agent_passage(draft, actor=default_user)

    # Verify it exists
    assert manager.get_agent_passage_by_id(created.id, actor=default_user) is not None

    # Delete it
    deleted = manager.delete_agent_passage_by_id(created.id, actor=default_user)
    assert deleted is True

    # Verify it's gone
    with pytest.raises(NoResultFound):
        manager.get_agent_passage_by_id(created.id, actor=default_user)
def test_delete_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
    """Delete a source passage with ``delete_source_passage_by_id`` and confirm it can no longer be fetched."""
    manager = server.passage_manager

    # Create a source passage
    draft = PydanticPassage(
        text="Source passage to delete",
        source_id=default_source.id,
        file_id=default_file.id,
        organization_id=default_user.organization_id,
        embedding=[0.1],
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
    )
    created = manager.create_source_passage(
        draft,
        file_metadata=default_file,
        actor=default_user,
    )

    # Verify it exists
    assert manager.get_source_passage_by_id(created.id, actor=default_user) is not None

    # Delete it
    deleted = manager.delete_source_passage_by_id(created.id, actor=default_user)
    assert deleted is True

    # Verify it's gone
    with pytest.raises(NoResultFound):
        manager.get_source_passage_by_id(created.id, actor=default_user)
@pytest.mark.asyncio
async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent, event_loop):
    """Create three agent passages with ``create_many_agent_passages_async`` and check each result."""
    batch_size = 3
    drafts = [
        PydanticPassage(
            text=f"Batch agent passage {idx}",
            agent_id=sarah_agent.id,
            organization_id=default_user.organization_id,
            embedding=[0.1 * idx],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        )
        for idx in range(batch_size)
    ]

    created_batch = await server.passage_manager.create_many_agent_passages_async(drafts, actor=default_user)

    assert len(created_batch) == 3
    # Results are expected in the same order as the drafts.
    for idx, created in enumerate(created_batch):
        assert created.text == f"Batch agent passage {idx}"
        assert created.agent_id == sarah_agent.id
        assert created.source_id is None
@pytest.mark.asyncio
async def test_create_many_source_passages_async(server: SyncServer, default_user, default_file, default_source, event_loop):
    """Create three source passages with ``create_many_source_passages_async`` and check each result."""
    batch_size = 3
    drafts = [
        PydanticPassage(
            text=f"Batch source passage {idx}",
            source_id=default_source.id,
            file_id=default_file.id,
            organization_id=default_user.organization_id,
            embedding=[0.1 * idx],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        )
        for idx in range(batch_size)
    ]

    created_batch = await server.passage_manager.create_many_source_passages_async(
        drafts,
        file_metadata=default_file,
        actor=default_user,
    )

    assert len(created_batch) == 3
    # Results are expected in the same order as the drafts.
    for idx, created in enumerate(created_batch):
        assert created.text == f"Batch source passage {idx}"
        assert created.source_id == default_source.id
        assert created.agent_id is None
def test_agent_passage_size(server: SyncServer, default_user, sarah_agent):
    """Check that ``agent_passage_size`` grows by the number of agent passages created."""
    manager = server.passage_manager
    count_before = manager.agent_passage_size(actor=default_user, agent_id=sarah_agent.id)

    # Create some agent passages
    for idx in range(3):
        draft = PydanticPassage(
            text=f"Agent passage {idx} for size test",
            agent_id=sarah_agent.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        )
        manager.create_agent_passage(draft, actor=default_user)

    count_after = manager.agent_passage_size(actor=default_user, agent_id=sarah_agent.id)
    assert count_after == count_before + 3
def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sarah_agent):
    """Call the deprecated passage methods and check that each one emits its deprecation message."""
    import warnings

    manager = server.passage_manager

    with warnings.catch_warnings(record=True) as recorded:
        # Record every warning, including repeats that the default filter would drop.
        warnings.simplefilter("always")

        # Test deprecated create_passage
        legacy_draft = PydanticPassage(
            text="Test deprecated method",
            agent_id=sarah_agent.id,
            organization_id=default_user.organization_id,
            embedding=[0.1],
            embedding_config=DEFAULT_EMBEDDING_CONFIG,
        )
        legacy_passage = manager.create_passage(legacy_draft, actor=default_user)

        # Test deprecated get_passage_by_id
        manager.get_passage_by_id(legacy_passage.id, actor=default_user)

        # Test deprecated size
        manager.size(actor=default_user, agent_id=sarah_agent.id)

        # Check that deprecation warnings were issued
        assert len(recorded) >= 3
        messages = [str(entry.message) for entry in recorded]
        assert any("create_passage is deprecated" in text for text in messages)
        assert any("get_passage_by_id is deprecated" in text for text in messages)
        assert any("size is deprecated" in text for text in messages)
# ======================================================================================================================
# User Manager Tests
# ======================================================================================================================