feat: separate Passages tables (#2245)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -1,17 +1,26 @@
|
||||
from typing import Dict, List, Optional
|
||||
from datetime import datetime
|
||||
import numpy as np
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
|
||||
from sqlalchemy import select, union_all, literal, func, Select
|
||||
|
||||
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import embedding_model
|
||||
from letta.log import get_logger
|
||||
from letta.orm import Agent as AgentModel
|
||||
from letta.orm import Block as BlockModel
|
||||
from letta.orm import Source as SourceModel
|
||||
from letta.orm import Tool as ToolModel
|
||||
from letta.orm import AgentPassage, SourcePassage
|
||||
from letta.orm import SourcesAgents
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.sqlite_functions import adapt_array
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
@@ -21,9 +30,9 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
_process_tags,
|
||||
derive_system_message,
|
||||
)
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -229,13 +238,6 @@ class AgentManager:
|
||||
with self.session_maker() as session:
|
||||
# Retrieve the agent
|
||||
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
|
||||
|
||||
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
|
||||
# TODO: This is done very hacky on purpose
|
||||
# TODO: 1000 limit is also wack
|
||||
passage_manager = PassageManager()
|
||||
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
|
||||
|
||||
agent_state = agent.to_pydantic()
|
||||
agent.hard_delete(session)
|
||||
return agent_state
|
||||
@@ -407,6 +409,262 @@ class AgentManager:
|
||||
agent.update(session, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Passage Management
|
||||
# ======================================================================================================================
|
||||
def _build_passage_query(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> Select:
    """Build the passage SELECT with every filter and ordering applied.

    The statement is only constructed here, never executed, so no database
    session is opened; callers run it inside their own session. No LIMIT or
    COUNT is applied.

    Args:
        actor: User whose ``organization_id`` scopes both passage tables.
        agent_id: When given, agent passages for this agent are included and
            source passages are restricted to sources joined to the agent
            through ``SourcesAgents``.
        file_id: Keep only rows whose ``file_id`` equals this value.
        query_text: Text to embed when ``embed_query`` is true; otherwise a
            case-insensitive substring filter on the passage text.
        start_date: Inclusive lower bound on ``created_at``.
        end_date: Inclusive upper bound on ``created_at``.
        cursor: Passage id; rows are restricted to those created strictly
            after (ascending) or before (descending) that passage.
        source_id: Keep only rows whose ``source_id`` equals this value.
        embed_query: Order by cosine distance to the embedded ``query_text``
            instead of filtering on text.
        ascending: Direction of the ``created_at`` ordering.
        embedding_config: Embedding model configuration; required when
            ``embed_query`` is true.
        agent_only: Skip source passages entirely.

    Returns:
        A ``Select`` over the ``combined_passages`` CTE.

    Raises:
        AssertionError: ``embed_query`` is true but ``embedding_config`` or
            ``query_text`` is missing.
        ValueError: Neither source nor agent passages were selected
            (``agent_only`` is true and ``agent_id`` is ``None``).
    """
    embedded_text = None
    if embed_query:
        assert embedding_config is not None, "embedding_config must be specified for vector search"
        assert query_text is not None, "query_text must be specified for vector search"
        embedded_text = np.array(embedding_model(embedding_config).get_text_embedding(query_text))
        # Zero-pad the query vector on the right up to MAX_EMBEDDING_DIM.
        embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

    # Source passages: every SourcePassage column plus a NULL agent_id column.
    source_passages = None
    if not agent_only:
        source_passages = select(SourcePassage, literal(None).label('agent_id'))
        if agent_id is not None:
            # Only sources linked to this agent through the association table.
            source_passages = source_passages.join(
                SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id
            ).where(SourcesAgents.agent_id == agent_id)
        source_passages = source_passages.where(SourcePassage.organization_id == actor.organization_id)

        if source_id:
            source_passages = source_passages.where(SourcePassage.source_id == source_id)
        if file_id:
            source_passages = source_passages.where(SourcePassage.file_id == file_id)

    # Agent passages: explicit column list with NULL file_id/source_id so the
    # row shape can be UNIONed with the source SELECT.
    # NOTE(review): assumes this order matches SourcePassage's column order - confirm.
    agent_passages = None
    if agent_id is not None:
        agent_passages = (
            select(
                AgentPassage.id,
                AgentPassage.text,
                AgentPassage.embedding_config,
                AgentPassage.metadata_,
                AgentPassage.embedding,
                AgentPassage.created_at,
                AgentPassage.updated_at,
                AgentPassage.is_deleted,
                AgentPassage._created_by_id,
                AgentPassage._last_updated_by_id,
                AgentPassage.organization_id,
                literal(None).label('file_id'),
                literal(None).label('source_id'),
                AgentPassage.agent_id,
            )
            .where(AgentPassage.agent_id == agent_id)
            .where(AgentPassage.organization_id == actor.organization_id)
        )

    # Combine whichever sides were built into a single CTE.
    if source_passages is not None and agent_passages is not None:
        combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
    elif agent_passages is not None:
        combined_query = agent_passages.cte('combined_passages')
    elif source_passages is not None:
        combined_query = source_passages.cte('combined_passages')
    else:
        raise ValueError("No passages found")

    columns = combined_query.c
    main_query = select(combined_query)

    # Filters on the combined rows. Agent rows carry NULL source_id/file_id,
    # so they never match a source_id or file_id filter.
    if start_date:
        main_query = main_query.where(columns.created_at >= start_date)
    if end_date:
        main_query = main_query.where(columns.created_at <= end_date)
    if source_id:
        main_query = main_query.where(columns.source_id == source_id)
    if file_id:
        main_query = main_query.where(columns.file_id == file_id)

    if embedded_text is not None:
        # Vector search: nearest passages first.
        if settings.letta_pg_uri_no_default:
            # PostgreSQL with pgvector
            main_query = main_query.order_by(columns.embedding.cosine_distance(embedded_text).asc())
        else:
            # SQLite with custom vector type; ties on distance fall back to
            # created_at (direction follows `ascending`) and then id.
            query_embedding_binary = adapt_array(embedded_text)
            created_at_order = columns.created_at.asc() if ascending else columns.created_at.desc()
            main_query = main_query.order_by(
                func.cosine_distance(columns.embedding, query_embedding_binary).asc(),
                created_at_order,
                columns.id.asc(),
            )
    elif query_text:
        # Plain text search: case-insensitive substring match.
        main_query = main_query.where(func.lower(columns.text).contains(func.lower(query_text)))

    # Cursor-based pagination keyed on the cursor passage's created_at.
    # NOTE(review): the comparison is strict on created_at alone, so rows that
    # share the cursor row's timestamp are not returned on the next page.
    if cursor:
        cursor_query = select(columns.created_at).where(columns.id == cursor).scalar_subquery()
        if ascending:
            main_query = main_query.where(columns.created_at > cursor_query)
        else:
            main_query = main_query.where(columns.created_at < cursor_query)

    # Chronological ordering when not already ordered by similarity.
    if not embed_query:
        if ascending:
            main_query = main_query.order_by(columns.created_at.asc(), columns.id.asc())
        else:
            main_query = main_query.order_by(columns.created_at.desc(), columns.id.asc())

    return main_query
|
||||
|
||||
@enforce_types
def list_passages(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False
) -> List[PydanticPassage]:
    """Return the passages matching the given filters as pydantic objects.

    All filter arguments are forwarded unchanged to `_build_passage_query`;
    `limit`, when truthy, caps the number of rows fetched. Rows with a
    non-NULL `agent_id` are rebuilt as `AgentPassage`, all others as
    `SourcePassage`, before conversion with `to_pydantic()`.
    """
    with self.session_maker() as session:
        statement = self._build_passage_query(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            cursor=cursor,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )

        # Cap the result size only when a limit was requested.
        if limit:
            statement = statement.limit(limit)

        orm_passages = []
        for record in session.execute(statement):
            fields = dict(record._mapping)
            if fields['agent_id'] is None:
                # Source row: the agent column does not belong to SourcePassage.
                fields.pop('agent_id', None)
                orm_passages.append(SourcePassage(**fields))
                continue
            # Agent row: the source columns do not belong to AgentPassage.
            fields.pop('source_id', None)
            fields.pop('file_id', None)
            orm_passages.append(AgentPassage(**fields))

        return [orm_passage.to_pydantic() for orm_passage in orm_passages]
|
||||
|
||||
|
||||
@enforce_types
def passage_size(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False
) -> int:
    """Count the passages that match the given filters.

    The filters are forwarded unchanged to `_build_passage_query`; the
    resulting statement is wrapped as a subquery and counted. A falsy
    scalar result is reported as 0.
    """
    filters = dict(
        actor=actor,
        agent_id=agent_id,
        file_id=file_id,
        query_text=query_text,
        start_date=start_date,
        end_date=end_date,
        cursor=cursor,
        source_id=source_id,
        embed_query=embed_query,
        ascending=ascending,
        embedding_config=embedding_config,
        agent_only=agent_only,
    )
    with self.session_maker() as session:
        matching = self._build_passage_query(**filters).subquery()
        # SELECT count(*) FROM (<filtered passages>)
        total = session.scalar(select(func.count()).select_from(matching))
        return total or 0
|
||||
|
||||
# ======================================================================================================================
|
||||
# Tool Management
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user