feat: Search files returns citations of the filenames that were searched (#2689)
This commit is contained in:
2
.gitignore
vendored
2
.gitignore
vendored
@@ -5,6 +5,8 @@
|
||||
openapi_letta.json
|
||||
openapi_openai.json
|
||||
|
||||
CLAUDE.md
|
||||
|
||||
### Eclipse ###
|
||||
.metadata
|
||||
bin/
|
||||
|
||||
@@ -0,0 +1,40 @@
|
||||
"""Add file name to source passages
|
||||
|
||||
Revision ID: c96263433aef
|
||||
Revises: 9792f94e961d
|
||||
Create Date: 2025-06-06 12:06:57.328127
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "c96263433aef"
|
||||
down_revision: Union[str, None] = "9792f94e961d"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# Add the new column
|
||||
op.add_column("source_passages", sa.Column("file_name", sa.String(), nullable=True))
|
||||
|
||||
# Backfill file_name using SQL UPDATE JOIN
|
||||
op.execute(
|
||||
"""
|
||||
UPDATE source_passages
|
||||
SET file_name = files.file_name
|
||||
FROM files
|
||||
WHERE source_passages.file_id = files.id
|
||||
"""
|
||||
)
|
||||
|
||||
# Enforce non-null constraint after backfill
|
||||
op.alter_column("source_passages", "file_name", nullable=False)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
op.drop_column("source_passages", "file_name")
|
||||
@@ -1292,7 +1292,7 @@ class Agent(BaseAgent):
|
||||
# conversion of messages to OpenAI dict format, which is passed to the token counter
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
|
||||
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
)
|
||||
in_context_messages_openai = [m.to_openai_dict() for m in in_context_messages]
|
||||
@@ -1414,7 +1414,7 @@ class Agent(BaseAgent):
|
||||
# conversion of messages to anthropic dict format, which is passed to the token counter
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
self.message_manager.get_messages_by_ids_async(message_ids=self.agent_state.message_ids, actor=self.user),
|
||||
self.passage_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.passage_manager.agent_passage_size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
self.message_manager.size_async(actor=self.user, agent_id=self.agent_state.id),
|
||||
)
|
||||
in_context_messages_anthropic = [m.to_anthropic_dict() for m in in_context_messages]
|
||||
|
||||
@@ -104,7 +104,7 @@ class BaseAgent(ABC):
|
||||
if num_messages is None:
|
||||
num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if num_archival_memories is None:
|
||||
num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
|
||||
@@ -763,7 +763,7 @@ class LettaAgent(BaseAgent):
|
||||
else asyncio.sleep(0, result=self.num_messages)
|
||||
),
|
||||
(
|
||||
self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if self.num_archival_memories is None
|
||||
else asyncio.sleep(0, result=self.num_archival_memories)
|
||||
),
|
||||
|
||||
@@ -305,7 +305,7 @@ class VoiceAgent(BaseAgent):
|
||||
else asyncio.sleep(0, result=self.num_messages)
|
||||
),
|
||||
(
|
||||
self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
if self.num_archival_memories is None
|
||||
else asyncio.sleep(0, result=self.num_archival_memories)
|
||||
),
|
||||
|
||||
@@ -47,6 +47,8 @@ class SourcePassage(BasePassage, FileMixin, SourceMixin):
|
||||
|
||||
__tablename__ = "source_passages"
|
||||
|
||||
file_name: Mapped[str] = mapped_column(doc="The name of the file that this passage was derived from")
|
||||
|
||||
@declared_attr
|
||||
def file(cls) -> Mapped["FileMetadata"]:
|
||||
"""Relationship to file"""
|
||||
|
||||
@@ -23,6 +23,7 @@ class PassageBase(OrmMetadataBase):
|
||||
|
||||
# file association
|
||||
file_id: Optional[str] = Field(None, description="The unique identifier of the file associated with the passage.")
|
||||
file_name: Optional[str] = Field(None, description="The name of the file (only for source passages).")
|
||||
metadata: Optional[Dict] = Field({}, validation_alias="metadata_", description="The metadata of the passage.")
|
||||
|
||||
|
||||
|
||||
@@ -1483,7 +1483,7 @@ class AgentManager:
|
||||
memory_edit_timestamp = curr_system_message.created_at
|
||||
|
||||
num_messages = await self.message_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories = await self.passage_manager.size_async(actor=actor, agent_id=agent_id)
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_id)
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
new_system_message_str = compile_system_message(
|
||||
@@ -2075,6 +2075,7 @@ class AgentManager:
|
||||
# This is an AgentPassage - remove source fields
|
||||
data.pop("source_id", None)
|
||||
data.pop("file_id", None)
|
||||
data.pop("file_name", None)
|
||||
passage = AgentPassage(**data)
|
||||
else:
|
||||
# This is a SourcePassage - remove agent field
|
||||
@@ -2135,6 +2136,7 @@ class AgentManager:
|
||||
# This is an AgentPassage - remove source fields
|
||||
data.pop("source_id", None)
|
||||
data.pop("file_id", None)
|
||||
data.pop("file_name", None)
|
||||
passage = AgentPassage(**data)
|
||||
else:
|
||||
# This is a SourcePassage - remove agent field
|
||||
@@ -2198,14 +2200,12 @@ class AgentManager:
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
file_id: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
before: Optional[str] = None,
|
||||
after: Optional[str] = None,
|
||||
source_id: Optional[str] = None,
|
||||
embed_query: bool = False,
|
||||
ascending: bool = True,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
|
||||
@@ -63,7 +63,7 @@ class ContextWindowCalculator:
|
||||
# Fetch data concurrently
|
||||
(in_context_messages, passage_manager_size, message_manager_size) = await asyncio.gather(
|
||||
message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor),
|
||||
passage_manager.size_async(actor=actor, agent_id=agent_state.id),
|
||||
passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id),
|
||||
message_manager.size_async(actor=actor, agent_id=agent_state.id),
|
||||
)
|
||||
|
||||
|
||||
@@ -111,7 +111,9 @@ class FileProcessor:
|
||||
)
|
||||
all_passages.extend(passages)
|
||||
|
||||
all_passages = await self.passage_manager.create_many_passages_async(all_passages, self.actor)
|
||||
all_passages = await self.passage_manager.create_many_source_passages_async(
|
||||
passages=all_passages, file_metadata=file_metadata, actor=self.actor
|
||||
)
|
||||
|
||||
logger.info(f"Successfully processed {filename}: {len(all_passages)} passages")
|
||||
|
||||
|
||||
@@ -607,15 +607,45 @@ def build_passage_query(
|
||||
if not agent_only: # Include source passages
|
||||
if agent_id is not None:
|
||||
source_passages = (
|
||||
select(SourcePassage, literal(None).label("agent_id"))
|
||||
select(
|
||||
SourcePassage.file_name,
|
||||
SourcePassage.id,
|
||||
SourcePassage.text,
|
||||
SourcePassage.embedding_config,
|
||||
SourcePassage.metadata_,
|
||||
SourcePassage.embedding,
|
||||
SourcePassage.created_at,
|
||||
SourcePassage.updated_at,
|
||||
SourcePassage.is_deleted,
|
||||
SourcePassage._created_by_id,
|
||||
SourcePassage._last_updated_by_id,
|
||||
SourcePassage.organization_id,
|
||||
SourcePassage.file_id,
|
||||
SourcePassage.source_id,
|
||||
literal(None).label("agent_id"),
|
||||
)
|
||||
.join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
|
||||
.where(SourcesAgents.agent_id == agent_id)
|
||||
.where(SourcePassage.organization_id == actor.organization_id)
|
||||
)
|
||||
else:
|
||||
source_passages = select(SourcePassage, literal(None).label("agent_id")).where(
|
||||
SourcePassage.organization_id == actor.organization_id
|
||||
)
|
||||
source_passages = select(
|
||||
SourcePassage.file_name,
|
||||
SourcePassage.id,
|
||||
SourcePassage.text,
|
||||
SourcePassage.embedding_config,
|
||||
SourcePassage.metadata_,
|
||||
SourcePassage.embedding,
|
||||
SourcePassage.created_at,
|
||||
SourcePassage.updated_at,
|
||||
SourcePassage.is_deleted,
|
||||
SourcePassage._created_by_id,
|
||||
SourcePassage._last_updated_by_id,
|
||||
SourcePassage.organization_id,
|
||||
SourcePassage.file_id,
|
||||
SourcePassage.source_id,
|
||||
literal(None).label("agent_id"),
|
||||
).where(SourcePassage.organization_id == actor.organization_id)
|
||||
|
||||
if source_id:
|
||||
source_passages = source_passages.where(SourcePassage.source_id == source_id)
|
||||
@@ -627,6 +657,7 @@ def build_passage_query(
|
||||
if agent_id is not None:
|
||||
agent_passages = (
|
||||
select(
|
||||
literal(None).label("file_name"),
|
||||
AgentPassage.id,
|
||||
AgentPassage.text,
|
||||
AgentPassage.embedding_config,
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.orm.errors import NoResultFound
|
||||
from letta.orm.passage import AgentPassage, SourcePassage
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.file import FileMetadata as PydanticFileMetadata
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
@@ -42,10 +43,65 @@ async def get_openai_embedding_async(text: str, model: str, endpoint: str) -> Li
|
||||
class PassageManager:
|
||||
"""Manager class to handle business logic related to Passages."""
|
||||
|
||||
# AGENT PASSAGE METHODS
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch an agent passage by ID."""
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch an agent passage by ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Agent passage with id {passage_id} not found in database.")
|
||||
|
||||
# SOURCE PASSAGE METHODS
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch a source passage by ID."""
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Source passage with id {passage_id} not found in database.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch a source passage by ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Source passage with id {passage_id} not found in database.")
|
||||
|
||||
# DEPRECATED - Use specific methods above
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_passage_by_id(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch a passage by ID."""
|
||||
"""DEPRECATED: Use get_agent_passage_by_id() or get_source_passage_by_id() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"get_passage_by_id is deprecated. Use get_agent_passage_by_id() or get_source_passage_by_id() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
with db_registry.session() as session:
|
||||
# Try source passages first
|
||||
try:
|
||||
@@ -62,7 +118,15 @@ class PassageManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def get_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> Optional[PydanticPassage]:
|
||||
"""Fetch a passage by ID."""
|
||||
"""DEPRECATED: Use get_agent_passage_by_id_async() or get_source_passage_by_id_async() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"get_passage_by_id_async is deprecated. Use get_agent_passage_by_id_async() or get_source_passage_by_id_async() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# Try source passages first
|
||||
try:
|
||||
@@ -76,10 +140,137 @@ class PassageManager:
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Passage with id {passage_id} not found in database.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_agent_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new agent passage."""
|
||||
if not pydantic_passage.agent_id:
|
||||
raise ValueError("Agent passage must have agent_id")
|
||||
if pydantic_passage.source_id:
|
||||
raise ValueError("Agent passage cannot have source_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
agent_fields = {"agent_id": data["agent_id"]}
|
||||
passage = AgentPassage(**common_fields, **agent_fields)
|
||||
|
||||
with db_registry.session() as session:
|
||||
passage.create(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_agent_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new agent passage."""
|
||||
if not pydantic_passage.agent_id:
|
||||
raise ValueError("Agent passage must have agent_id")
|
||||
if pydantic_passage.source_id:
|
||||
raise ValueError("Agent passage cannot have source_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
agent_fields = {"agent_id": data["agent_id"]}
|
||||
passage = AgentPassage(**common_fields, **agent_fields)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
passage = await passage.create_async(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_source_passage(
|
||||
self, pydantic_passage: PydanticPassage, file_metadata: PydanticFileMetadata, actor: PydanticUser
|
||||
) -> PydanticPassage:
|
||||
"""Create a new source passage."""
|
||||
if not pydantic_passage.source_id:
|
||||
raise ValueError("Source passage must have source_id")
|
||||
if pydantic_passage.agent_id:
|
||||
raise ValueError("Source passage cannot have agent_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
source_fields = {
|
||||
"source_id": data["source_id"],
|
||||
"file_id": data.get("file_id"),
|
||||
"file_name": file_metadata.file_name,
|
||||
}
|
||||
passage = SourcePassage(**common_fields, **source_fields)
|
||||
|
||||
with db_registry.session() as session:
|
||||
passage.create(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_source_passage_async(
|
||||
self, pydantic_passage: PydanticPassage, file_metadata: PydanticFileMetadata, actor: PydanticUser
|
||||
) -> PydanticPassage:
|
||||
"""Create a new source passage."""
|
||||
if not pydantic_passage.source_id:
|
||||
raise ValueError("Source passage must have source_id")
|
||||
if pydantic_passage.agent_id:
|
||||
raise ValueError("Source passage cannot have agent_id")
|
||||
|
||||
data = pydantic_passage.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
source_fields = {
|
||||
"source_id": data["source_id"],
|
||||
"file_id": data.get("file_id"),
|
||||
"file_name": file_metadata.file_name,
|
||||
}
|
||||
passage = SourcePassage(**common_fields, **source_fields)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
passage = await passage.create_async(session, actor=actor)
|
||||
return passage.to_pydantic()
|
||||
|
||||
# DEPRECATED - Use specific methods above
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_passage(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
||||
"""DEPRECATED: Use create_agent_passage() or create_source_passage() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"create_passage is deprecated. Use create_agent_passage() or create_source_passage() instead.", DeprecationWarning, stacklevel=2
|
||||
)
|
||||
|
||||
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
|
||||
|
||||
with db_registry.session() as session:
|
||||
@@ -89,7 +280,15 @@ class PassageManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_passage_async(self, pydantic_passage: PydanticPassage, actor: PydanticUser) -> PydanticPassage:
|
||||
"""Create a new passage in the appropriate table based on whether it has agent_id or source_id."""
|
||||
"""DEPRECATED: Use create_agent_passage_async() or create_source_passage_async() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"create_passage_async is deprecated. Use create_agent_passage_async() or create_source_passage_async() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
# Common fields for both passage types
|
||||
passage = self._preprocess_passage_for_creation(pydantic_passage=pydantic_passage)
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -128,16 +327,110 @@ class PassageManager:
|
||||
|
||||
return passage
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_many_agent_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
||||
"""Create multiple agent passages."""
|
||||
return [self.create_agent_passage(p, actor) for p in passages]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_many_agent_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
||||
"""Create multiple agent passages."""
|
||||
agent_passages = []
|
||||
for p in passages:
|
||||
if not p.agent_id:
|
||||
raise ValueError("Agent passage must have agent_id")
|
||||
if p.source_id:
|
||||
raise ValueError("Agent passage cannot have source_id")
|
||||
|
||||
data = p.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
agent_fields = {"agent_id": data["agent_id"]}
|
||||
agent_passages.append(AgentPassage(**common_fields, **agent_fields))
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
agent_created = await AgentPassage.batch_create_async(items=agent_passages, db_session=session, actor=actor)
|
||||
return [p.to_pydantic() for p in agent_created]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_many_source_passages(
|
||||
self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser
|
||||
) -> List[PydanticPassage]:
|
||||
"""Create multiple source passages."""
|
||||
return [self.create_source_passage(p, file_metadata, actor) for p in passages]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_many_source_passages_async(
|
||||
self, passages: List[PydanticPassage], file_metadata: PydanticFileMetadata, actor: PydanticUser
|
||||
) -> List[PydanticPassage]:
|
||||
"""Create multiple source passages."""
|
||||
source_passages = []
|
||||
for p in passages:
|
||||
if not p.source_id:
|
||||
raise ValueError("Source passage must have source_id")
|
||||
if p.agent_id:
|
||||
raise ValueError("Source passage cannot have agent_id")
|
||||
|
||||
data = p.model_dump(to_orm=True)
|
||||
common_fields = {
|
||||
"id": data.get("id"),
|
||||
"text": data["text"],
|
||||
"embedding": data["embedding"],
|
||||
"embedding_config": data["embedding_config"],
|
||||
"organization_id": data["organization_id"],
|
||||
"metadata_": data.get("metadata", {}),
|
||||
"is_deleted": data.get("is_deleted", False),
|
||||
"created_at": data.get("created_at", datetime.now(timezone.utc)),
|
||||
}
|
||||
source_fields = {
|
||||
"source_id": data["source_id"],
|
||||
"file_id": data.get("file_id"),
|
||||
"file_name": file_metadata.file_name,
|
||||
}
|
||||
source_passages.append(SourcePassage(**common_fields, **source_fields))
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
source_created = await SourcePassage.batch_create_async(items=source_passages, db_session=session, actor=actor)
|
||||
return [p.to_pydantic() for p in source_created]
|
||||
|
||||
# DEPRECATED - Use specific methods above
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_many_passages(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
||||
"""Create multiple passages."""
|
||||
"""DEPRECATED: Use create_many_agent_passages() or create_many_source_passages() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"create_many_passages is deprecated. Use create_many_agent_passages() or create_many_source_passages() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
return [self.create_passage(p, actor) for p in passages]
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_many_passages_async(self, passages: List[PydanticPassage], actor: PydanticUser) -> List[PydanticPassage]:
|
||||
"""Create multiple passages."""
|
||||
"""DEPRECATED: Use create_many_agent_passages_async() or create_many_source_passages_async() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"create_many_passages_async is deprecated. Use create_many_agent_passages_async() or create_many_source_passages_async() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
agent_passages = []
|
||||
source_passages = []
|
||||
@@ -203,7 +496,7 @@ class PassageManager:
|
||||
raise TypeError(
|
||||
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
||||
)
|
||||
passage = self.create_passage(
|
||||
passage = self.create_agent_passage(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
@@ -251,7 +544,7 @@ class PassageManager:
|
||||
for chunk_text, embedding in zip(text_chunks, embeddings)
|
||||
]
|
||||
|
||||
passages = await self.create_many_passages_async(passages=passages, actor=actor)
|
||||
passages = await self.create_many_agent_passages_async(passages=passages, actor=actor)
|
||||
|
||||
return passages
|
||||
|
||||
@@ -292,10 +585,191 @@ class PassageManager:
|
||||
|
||||
return processed_embeddings
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_agent_passage_by_id(
|
||||
self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
|
||||
) -> Optional[PydanticPassage]:
|
||||
"""Update an agent passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
curr_passage = AgentPassage.read(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Agent passage with id {passage_id} does not exist.")
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(curr_passage, key, value)
|
||||
|
||||
# Commit changes
|
||||
curr_passage.update(session, actor=actor)
|
||||
return curr_passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_agent_passage_by_id_async(
|
||||
self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
|
||||
) -> Optional[PydanticPassage]:
|
||||
"""Update an agent passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
curr_passage = await AgentPassage.read_async(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Agent passage with id {passage_id} does not exist.")
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(curr_passage, key, value)
|
||||
|
||||
# Commit changes
|
||||
await curr_passage.update_async(session, actor=actor)
|
||||
return curr_passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_source_passage_by_id(
|
||||
self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
|
||||
) -> Optional[PydanticPassage]:
|
||||
"""Update a source passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
curr_passage = SourcePassage.read(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Source passage with id {passage_id} does not exist.")
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(curr_passage, key, value)
|
||||
|
||||
# Commit changes
|
||||
curr_passage.update(session, actor=actor)
|
||||
return curr_passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_source_passage_by_id_async(
|
||||
self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs
|
||||
) -> Optional[PydanticPassage]:
|
||||
"""Update a source passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
curr_passage = await SourcePassage.read_async(
|
||||
db_session=session,
|
||||
identifier=passage_id,
|
||||
actor=actor,
|
||||
)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Source passage with id {passage_id} does not exist.")
|
||||
|
||||
# Update the database record with values from the provided record
|
||||
update_data = passage.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(curr_passage, key, value)
|
||||
|
||||
# Commit changes
|
||||
await curr_passage.update_async(session, actor=actor)
|
||||
return curr_passage.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_agent_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete an agent passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
passage = AgentPassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
passage.hard_delete(session, actor=actor)
|
||||
return True
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Agent passage with id {passage_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete an agent passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
passage = await AgentPassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
||||
await passage.hard_delete_async(session, actor=actor)
|
||||
return True
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Agent passage with id {passage_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_source_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete a source passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
passage = SourcePassage.read(db_session=session, identifier=passage_id, actor=actor)
|
||||
passage.hard_delete(session, actor=actor)
|
||||
return True
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Source passage with id {passage_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_source_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete a source passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
passage = await SourcePassage.read_async(db_session=session, identifier=passage_id, actor=actor)
|
||||
await passage.hard_delete_async(session, actor=actor)
|
||||
return True
|
||||
except NoResultFound:
|
||||
raise NoResultFound(f"Source passage with id {passage_id} not found.")
|
||||
|
||||
# DEPRECATED - Use specific methods above
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_passage_by_id(self, passage_id: str, passage: PydanticPassage, actor: PydanticUser, **kwargs) -> Optional[PydanticPassage]:
|
||||
"""Update a passage."""
|
||||
"""DEPRECATED: Use update_agent_passage_by_id() or update_source_passage_by_id() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"update_passage_by_id is deprecated. Use update_agent_passage_by_id() or update_source_passage_by_id() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
@@ -330,7 +804,15 @@ class PassageManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_passage_by_id(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete a passage from either source or archival passages."""
|
||||
"""DEPRECATED: Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"delete_passage_by_id is deprecated. Use delete_agent_passage_by_id() or delete_source_passage_by_id() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
@@ -352,7 +834,15 @@ class PassageManager:
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
"""Delete a passage from either source or archival passages."""
|
||||
"""DEPRECATED: Use delete_agent_passage_by_id_async() or delete_source_passage_by_id_async() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"delete_passage_by_id_async is deprecated. Use delete_agent_passage_by_id_async() or delete_source_passage_by_id_async() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
|
||||
@@ -373,15 +863,42 @@ class PassageManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_passages(
|
||||
def delete_agent_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
passages: List[PydanticPassage],
|
||||
) -> bool:
|
||||
"""Delete multiple agent passages."""
|
||||
# TODO: This is very inefficient
|
||||
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
||||
for passage in passages:
|
||||
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
||||
self.delete_agent_passage_by_id(passage_id=passage.id, actor=actor)
|
||||
return True
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_agent_passages_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
passages: List[PydanticPassage],
|
||||
) -> bool:
|
||||
"""Delete multiple agent passages."""
|
||||
async with db_registry.async_session() as session:
|
||||
await AgentPassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
|
||||
return True
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def delete_source_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
passages: List[PydanticPassage],
|
||||
) -> bool:
|
||||
"""Delete multiple source passages."""
|
||||
# TODO: This is very inefficient
|
||||
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
||||
for passage in passages:
|
||||
self.delete_source_passage_by_id(passage_id=passage.id, actor=actor)
|
||||
return True
|
||||
|
||||
@enforce_types
|
||||
@@ -395,14 +912,36 @@ class PassageManager:
|
||||
await SourcePassage.bulk_hard_delete_async(db_session=session, identifiers=[p.id for p in passages], actor=actor)
|
||||
return True
|
||||
|
||||
# DEPRECATED - Use specific methods above
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def size(
|
||||
def delete_passages(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
passages: List[PydanticPassage],
|
||||
) -> bool:
|
||||
"""DEPRECATED: Use delete_agent_passages() or delete_source_passages() instead."""
|
||||
import warnings
|
||||
|
||||
warnings.warn(
|
||||
"delete_passages is deprecated. Use delete_agent_passages() or delete_source_passages() instead.",
|
||||
DeprecationWarning,
|
||||
stacklevel=2,
|
||||
)
|
||||
# TODO: This is very inefficient
|
||||
# TODO: We should have a base `delete_all_matching_filters`-esque function
|
||||
for passage in passages:
|
||||
self.delete_passage_by_id(passage_id=passage.id, actor=actor)
|
||||
return True
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def agent_passage_size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get the total count of messages with optional filters.
|
||||
"""Get the total count of agent passages with optional filters.
|
||||
|
||||
Args:
|
||||
actor: The user requesting the count
|
||||
@@ -411,14 +950,29 @@ class PassageManager:
|
||||
with db_registry.session() as session:
|
||||
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
|
||||
|
||||
# DEPRECATED - Use agent_passage_size() instead since this only counted agent passages anyway
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def size_async(
|
||||
def size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get the total count of messages with optional filters.
|
||||
"""DEPRECATED: Use agent_passage_size() instead (this only counted agent passages anyway)."""
|
||||
import warnings
|
||||
|
||||
warnings.warn("size is deprecated. Use agent_passage_size() instead.", DeprecationWarning, stacklevel=2)
|
||||
with db_registry.session() as session:
|
||||
return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def agent_passage_size_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get the total count of agent passages with optional filters.
|
||||
Args:
|
||||
actor: The user requesting the count
|
||||
agent_id: The agent ID of the messages
|
||||
@@ -426,6 +980,37 @@ class PassageManager:
|
||||
async with db_registry.async_session() as session:
|
||||
return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def source_passage_size(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
source_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get the total count of source passages with optional filters.
|
||||
|
||||
Args:
|
||||
actor: The user requesting the count
|
||||
source_id: The source ID of the passages
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
return SourcePassage.size(db_session=session, actor=actor, source_id=source_id)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def source_passage_size_async(
|
||||
self,
|
||||
actor: PydanticUser,
|
||||
source_id: Optional[str] = None,
|
||||
) -> int:
|
||||
"""Get the total count of source passages with optional filters.
|
||||
Args:
|
||||
actor: The user requesting the count
|
||||
source_id: The source ID of the passages
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
return await SourcePassage.size_async(db_session=session, actor=actor, source_id=source_id)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def estimate_embeddings_size_async(
|
||||
@@ -448,7 +1033,7 @@ class PassageManager:
|
||||
raise ValueError(f"Invalid storage unit: {storage_unit}. Must be one of {list(BYTES_PER_STORAGE_UNIT.keys())}.")
|
||||
BYTES_PER_EMBEDDING_DIM = 4
|
||||
GB_PER_EMBEDDING = BYTES_PER_EMBEDDING_DIM / BYTES_PER_STORAGE_UNIT[storage_unit] * MAX_EMBEDDING_DIM
|
||||
return await self.size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
|
||||
return await self.agent_passage_size_async(actor=actor, agent_id=agent_id) * GB_PER_EMBEDDING
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -126,6 +126,13 @@ class LettaFileToolExecutor(ToolExecutor):
|
||||
|
||||
# TODO: Make this paginated?
|
||||
async def search_files(self, agent_state: AgentState, query: str) -> List[str]:
|
||||
"""Stub for search_files tool."""
|
||||
"""Search for text within attached files and return passages with their source filenames."""
|
||||
passages = await self.agent_manager.list_source_passages_async(actor=self.actor, agent_id=agent_state.id, query_text=query)
|
||||
return [p.text for p in passages]
|
||||
formatted_results = []
|
||||
for p in passages:
|
||||
if p.file_name:
|
||||
formatted_result = f"[{p.file_name}]:\n{p.text}"
|
||||
else:
|
||||
formatted_result = p.text
|
||||
formatted_results.append(formatted_result)
|
||||
return formatted_results
|
||||
|
||||
@@ -12,6 +12,7 @@ import httpx
|
||||
|
||||
# tests/test_file_content_flow.py
|
||||
import pytest
|
||||
from _pytest.python_api import approx
|
||||
from anthropic.types.beta import BetaMessage
|
||||
from anthropic.types.beta.messages import BetaMessageBatchIndividualResponse, BetaMessageBatchSucceededResult
|
||||
from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
|
||||
@@ -280,7 +281,7 @@ async def default_run(server: SyncServer, default_user):
|
||||
@pytest.fixture
|
||||
def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
"""Fixture to create an agent passage."""
|
||||
passage = server.passage_manager.create_passage(
|
||||
passage = server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Hello, I am an agent passage",
|
||||
agent_id=sarah_agent.id,
|
||||
@@ -297,7 +298,7 @@ def agent_passage_fixture(server: SyncServer, default_user, sarah_agent):
|
||||
@pytest.fixture
|
||||
def source_passage_fixture(server: SyncServer, default_user, default_file, default_source):
|
||||
"""Fixture to create a source passage."""
|
||||
passage = server.passage_manager.create_passage(
|
||||
passage = server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Hello, I am a source passage",
|
||||
source_id=default_source.id,
|
||||
@@ -307,6 +308,7 @@ def source_passage_fixture(server: SyncServer, default_user, default_file, defau
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata={"type": "test"},
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
yield passage
|
||||
@@ -318,7 +320,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
|
||||
# Create agent passages
|
||||
passages = []
|
||||
for i in range(5):
|
||||
passage = server.passage_manager.create_passage(
|
||||
passage = server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text=f"Agent passage {i}",
|
||||
agent_id=sarah_agent.id,
|
||||
@@ -335,7 +337,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
|
||||
|
||||
# Create source passages
|
||||
for i in range(5):
|
||||
passage = server.passage_manager.create_passage(
|
||||
passage = server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text=f"Source passage {i}",
|
||||
source_id=default_source.id,
|
||||
@@ -345,6 +347,7 @@ def create_test_passages(server: SyncServer, default_file, default_user, sarah_a
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata={"type": "test"},
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
passages.append(passage)
|
||||
@@ -525,7 +528,7 @@ def server():
|
||||
|
||||
@pytest.fixture
|
||||
@pytest.mark.asyncio
|
||||
async def agent_passages_setup(server, default_source, default_user, sarah_agent, event_loop):
|
||||
async def agent_passages_setup(server, default_source, default_file, default_user, sarah_agent, event_loop):
|
||||
"""Setup fixture for agent passages tests"""
|
||||
agent_id = sarah_agent.id
|
||||
actor = default_user
|
||||
@@ -535,14 +538,16 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
|
||||
# Create some source passages
|
||||
source_passages = []
|
||||
for i in range(3):
|
||||
passage = await server.passage_manager.create_passage_async(
|
||||
passage = await server.passage_manager.create_source_passage_async(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
text=f"Source passage {i}",
|
||||
embedding=[0.1], # Default OpenAI embedding size
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=actor,
|
||||
)
|
||||
source_passages.append(passage)
|
||||
@@ -550,7 +555,7 @@ async def agent_passages_setup(server, default_source, default_user, sarah_agent
|
||||
# Create some agent passages
|
||||
agent_passages = []
|
||||
for i in range(2):
|
||||
passage = await server.passage_manager.create_passage_async(
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
agent_id=agent_id,
|
||||
@@ -2022,7 +2027,7 @@ async def test_agent_list_passages_filtering(server, default_user, sarah_agent,
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, event_loop):
|
||||
async def test_agent_list_passages_vector_search(server, default_user, sarah_agent, default_source, default_file, event_loop):
|
||||
"""Test vector search functionality of agent passages"""
|
||||
embed_model = embedding_model(DEFAULT_EMBEDDING_CONFIG)
|
||||
|
||||
@@ -2041,6 +2046,7 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age
|
||||
for i, text in enumerate(test_passages):
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
if i % 2 == 0:
|
||||
# Create agent passage
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
@@ -2048,15 +2054,18 @@ async def test_agent_list_passages_vector_search(server, default_user, sarah_age
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding,
|
||||
)
|
||||
created_passage = await server.passage_manager.create_agent_passage_async(passage, default_user)
|
||||
else:
|
||||
# Create source passage
|
||||
passage = PydanticPassage(
|
||||
text=text,
|
||||
organization_id=default_user.organization_id,
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
embedding=embedding,
|
||||
)
|
||||
created_passage = await server.passage_manager.create_passage_async(passage, default_user)
|
||||
created_passage = await server.passage_manager.create_source_passage_async(passage, default_file, default_user)
|
||||
passages.append(created_passage)
|
||||
|
||||
# Query vector similar to "red" embedding
|
||||
@@ -2261,6 +2270,416 @@ async def test_passage_cascade_deletion(
|
||||
server.passage_manager.get_passage_by_id(source_passage_fixture.id, default_user)
|
||||
|
||||
|
||||
def test_create_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test creating an agent passage using the new agent-specific method."""
|
||||
passage = server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Test agent passage via specific method",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata={"type": "test_specific"},
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert passage.id is not None
|
||||
assert passage.text == "Test agent passage via specific method"
|
||||
assert passage.agent_id == sarah_agent.id
|
||||
assert passage.source_id is None
|
||||
|
||||
|
||||
def test_create_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
|
||||
"""Test creating a source passage using the new source-specific method."""
|
||||
passage = server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Test source passage via specific method",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
metadata={"type": "test_specific"},
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert passage.id is not None
|
||||
assert passage.text == "Test source passage via specific method"
|
||||
assert passage.source_id == default_source.id
|
||||
assert passage.agent_id is None
|
||||
|
||||
|
||||
def test_create_agent_passage_validation(server: SyncServer, default_user, default_source, sarah_agent):
|
||||
"""Test that agent passage creation validates inputs correctly."""
|
||||
# Should fail if agent_id is missing
|
||||
with pytest.raises(ValueError, match="Agent passage must have agent_id"):
|
||||
server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Invalid agent passage",
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should fail if source_id is present
|
||||
with pytest.raises(ValueError, match="Agent passage cannot have source_id"):
|
||||
server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Invalid agent passage",
|
||||
agent_id=sarah_agent.id,
|
||||
source_id=default_source.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
def test_create_source_passage_validation(server: SyncServer, default_user, default_file, default_source, sarah_agent):
|
||||
"""Test that source passage creation validates inputs correctly."""
|
||||
# Should fail if source_id is missing
|
||||
with pytest.raises(ValueError, match="Source passage must have source_id"):
|
||||
server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Invalid source passage",
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Should fail if agent_id is present
|
||||
with pytest.raises(ValueError, match="Source passage cannot have agent_id"):
|
||||
server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Invalid source passage",
|
||||
source_id=default_source.id,
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
|
||||
def test_get_agent_passage_by_id_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test retrieving an agent passage using the new agent-specific method."""
|
||||
# Create an agent passage
|
||||
passage = server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Agent passage for retrieval test",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Retrieve it using the specific method
|
||||
retrieved = server.passage_manager.get_agent_passage_by_id(passage.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == passage.id
|
||||
assert retrieved.text == passage.text
|
||||
assert retrieved.agent_id == sarah_agent.id
|
||||
|
||||
|
||||
def test_get_source_passage_by_id_specific(server: SyncServer, default_user, default_file, default_source):
|
||||
"""Test retrieving a source passage using the new source-specific method."""
|
||||
# Create a source passage
|
||||
passage = server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Source passage for retrieval test",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Retrieve it using the specific method
|
||||
retrieved = server.passage_manager.get_source_passage_by_id(passage.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
assert retrieved.id == passage.id
|
||||
assert retrieved.text == passage.text
|
||||
assert retrieved.source_id == default_source.id
|
||||
|
||||
|
||||
def test_get_wrong_passage_type_fails(server: SyncServer, default_user, sarah_agent, default_file, default_source):
|
||||
"""Test that trying to get the wrong passage type with specific methods fails."""
|
||||
# Create an agent passage
|
||||
agent_passage = server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Agent passage",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Create a source passage
|
||||
source_passage = server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Source passage",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Trying to get agent passage with source method should fail
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_source_passage_by_id(agent_passage.id, actor=default_user)
|
||||
|
||||
# Trying to get source passage with agent method should fail
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_agent_passage_by_id(source_passage.id, actor=default_user)
|
||||
|
||||
|
||||
def test_update_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test updating an agent passage using the new agent-specific method."""
|
||||
# Create an agent passage
|
||||
passage = server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Original agent passage text",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Update it
|
||||
updated_passage = server.passage_manager.update_agent_passage_by_id(
|
||||
passage.id,
|
||||
PydanticPassage(
|
||||
text="Updated agent passage text",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.2],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert updated_passage.text == "Updated agent passage text"
|
||||
assert updated_passage.embedding[0] == approx(0.2)
|
||||
assert updated_passage.id == passage.id
|
||||
|
||||
|
||||
def test_update_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
|
||||
"""Test updating a source passage using the new source-specific method."""
|
||||
# Create a source passage
|
||||
passage = server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Original source passage text",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Update it
|
||||
updated_passage = server.passage_manager.update_source_passage_by_id(
|
||||
passage.id,
|
||||
PydanticPassage(
|
||||
text="Updated source passage text",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.2],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
assert updated_passage.text == "Updated source passage text"
|
||||
assert updated_passage.embedding[0] == approx(0.2)
|
||||
assert updated_passage.id == passage.id
|
||||
|
||||
|
||||
def test_delete_agent_passage_specific(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test deleting an agent passage using the new agent-specific method."""
|
||||
# Create an agent passage
|
||||
passage = server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text="Agent passage to delete",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify it exists
|
||||
retrieved = server.passage_manager.get_agent_passage_by_id(passage.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
|
||||
# Delete it
|
||||
result = server.passage_manager.delete_agent_passage_by_id(passage.id, actor=default_user)
|
||||
assert result is True
|
||||
|
||||
# Verify it's gone
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_agent_passage_by_id(passage.id, actor=default_user)
|
||||
|
||||
|
||||
def test_delete_source_passage_specific(server: SyncServer, default_user, default_file, default_source):
|
||||
"""Test deleting a source passage using the new source-specific method."""
|
||||
# Create a source passage
|
||||
passage = server.passage_manager.create_source_passage(
|
||||
PydanticPassage(
|
||||
text="Source passage to delete",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
file_metadata=default_file,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify it exists
|
||||
retrieved = server.passage_manager.get_source_passage_by_id(passage.id, actor=default_user)
|
||||
assert retrieved is not None
|
||||
|
||||
# Delete it
|
||||
result = server.passage_manager.delete_source_passage_by_id(passage.id, actor=default_user)
|
||||
assert result is True
|
||||
|
||||
# Verify it's gone
|
||||
with pytest.raises(NoResultFound):
|
||||
server.passage_manager.get_source_passage_by_id(passage.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_agent_passages_async(server: SyncServer, default_user, sarah_agent, event_loop):
|
||||
"""Test creating multiple agent passages using the new batch method."""
|
||||
passages = [
|
||||
PydanticPassage(
|
||||
text=f"Batch agent passage {i}",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1 * i],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
created_passages = await server.passage_manager.create_many_agent_passages_async(passages, actor=default_user)
|
||||
|
||||
assert len(created_passages) == 3
|
||||
for i, passage in enumerate(created_passages):
|
||||
assert passage.text == f"Batch agent passage {i}"
|
||||
assert passage.agent_id == sarah_agent.id
|
||||
assert passage.source_id is None
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_many_source_passages_async(server: SyncServer, default_user, default_file, default_source, event_loop):
|
||||
"""Test creating multiple source passages using the new batch method."""
|
||||
passages = [
|
||||
PydanticPassage(
|
||||
text=f"Batch source passage {i}",
|
||||
source_id=default_source.id,
|
||||
file_id=default_file.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1 * i],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
)
|
||||
for i in range(3)
|
||||
]
|
||||
|
||||
created_passages = await server.passage_manager.create_many_source_passages_async(
|
||||
passages, file_metadata=default_file, actor=default_user
|
||||
)
|
||||
|
||||
assert len(created_passages) == 3
|
||||
for i, passage in enumerate(created_passages):
|
||||
assert passage.text == f"Batch source passage {i}"
|
||||
assert passage.source_id == default_source.id
|
||||
assert passage.agent_id is None
|
||||
|
||||
|
||||
def test_agent_passage_size(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test counting agent passages using the new agent-specific size method."""
|
||||
initial_size = server.passage_manager.agent_passage_size(actor=default_user, agent_id=sarah_agent.id)
|
||||
|
||||
# Create some agent passages
|
||||
for i in range(3):
|
||||
server.passage_manager.create_agent_passage(
|
||||
PydanticPassage(
|
||||
text=f"Agent passage {i} for size test",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
final_size = server.passage_manager.agent_passage_size(actor=default_user, agent_id=sarah_agent.id)
|
||||
assert final_size == initial_size + 3
|
||||
|
||||
|
||||
def test_deprecated_methods_show_warnings(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test that deprecated methods show deprecation warnings."""
|
||||
import warnings
|
||||
|
||||
with warnings.catch_warnings(record=True) as w:
|
||||
warnings.simplefilter("always")
|
||||
|
||||
# Test deprecated create_passage
|
||||
passage = server.passage_manager.create_passage(
|
||||
PydanticPassage(
|
||||
text="Test deprecated method",
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Test deprecated get_passage_by_id
|
||||
server.passage_manager.get_passage_by_id(passage.id, actor=default_user)
|
||||
|
||||
# Test deprecated size
|
||||
server.passage_manager.size(actor=default_user, agent_id=sarah_agent.id)
|
||||
|
||||
# Check that deprecation warnings were issued
|
||||
assert len(w) >= 3
|
||||
assert any("create_passage is deprecated" in str(warning.message) for warning in w)
|
||||
assert any("get_passage_by_id is deprecated" in str(warning.message) for warning in w)
|
||||
assert any("size is deprecated" in str(warning.message) for warning in w)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# User Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user