feat: separate Passages tables (#2245)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-16 15:24:20 -08:00
committed by GitHub
parent 10e610bb95
commit e2d916148e
19 changed files with 1026 additions and 546 deletions

View File

@@ -1,17 +1,26 @@
from typing import Dict, List, Optional
from datetime import datetime
import numpy as np
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS
from sqlalchemy import select, union_all, literal, func, Select
from letta.constants import BASE_MEMORY_TOOLS, BASE_TOOLS, MAX_EMBEDDING_DIM
from letta.embeddings import embedding_model
from letta.log import get_logger
from letta.orm import Agent as AgentModel
from letta.orm import Block as BlockModel
from letta.orm import Source as SourceModel
from letta.orm import Tool as ToolModel
from letta.orm import AgentPassage, SourcePassage
from letta.orm import SourcesAgents
from letta.orm.errors import NoResultFound
from letta.orm.sqlite_functions import adapt_array
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source as PydanticSource
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
from letta.schemas.user import User as PydanticUser
@@ -21,9 +30,9 @@ from letta.services.helpers.agent_manager_helper import (
_process_tags,
derive_system_message,
)
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
from letta.settings import settings
from letta.utils import enforce_types
logger = get_logger(__name__)
@@ -229,13 +238,6 @@ class AgentManager:
with self.session_maker() as session:
# Retrieve the agent
agent = AgentModel.read(db_session=session, identifier=agent_id, actor=actor)
# TODO: @mindy delete this piece when we have a proper passages/sources implementation
# TODO: This is done very hacky on purpose
# TODO: 1000 limit is also wack
passage_manager = PassageManager()
passage_manager.delete_passages(actor=actor, agent_id=agent_id, limit=1000)
agent_state = agent.to_pydantic()
agent.hard_delete(session)
return agent_state
@@ -407,6 +409,262 @@ class AgentManager:
agent.update(session, actor=actor)
return agent.to_pydantic()
# ======================================================================================================================
# Passage Management
# ======================================================================================================================
def _build_passage_query(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False,
) -> Select:
    """Build the passage query with every filter and ordering applied.

    The statement selects from a ``combined_passages`` CTE made of source
    passages and/or agent passages. It is returned before any limit or count
    is applied, and nothing is executed here.

    Args:
        actor: User whose ``organization_id`` scopes both passage selects.
        agent_id: When given, the agent's own passages are included and the
            source passages are restricted to sources linked to this agent
            through ``SourcesAgents``.
        file_id: Keep only passages with this file id. Agent passages carry a
            NULL ``file_id`` in the CTE, so they never match this filter.
        query_text: Text to embed when ``embed_query`` is true; otherwise a
            case-insensitive substring filter on the passage text.
        start_date: Inclusive lower bound on ``created_at``.
        end_date: Inclusive upper bound on ``created_at``.
        cursor: Id of the passage after which results should start.
        source_id: Keep only passages with this source id. Agent passages
            carry a NULL ``source_id`` in the CTE, so they never match.
        embed_query: Order by cosine distance to the embedded ``query_text``.
        ascending: Direction of the ``created_at`` ordering and of the cursor
            comparison. Not used by the PostgreSQL vector ordering.
        embedding_config: Embedding model configuration; required when
            ``embed_query`` is true.
        agent_only: Leave source passages out of the query entirely.

    Returns:
        The ``Select`` over the combined CTE.

    Raises:
        AssertionError: ``embed_query`` is true but ``embedding_config`` or
            ``query_text`` is missing.
        ValueError: Neither sub-select could be built (``agent_only`` is true
            and ``agent_id`` is None).
    """
    embedded_text = None
    if embed_query:
        assert embedding_config is not None, "embedding_config must be specified for vector search"
        assert query_text is not None, "query_text must be specified for vector search"
        embedded_text = embedding_model(embedding_config).get_text_embedding(query_text)
        embedded_text = np.array(embedded_text)
        # Zero-pad the query vector up to MAX_EMBEDDING_DIM.
        embedded_text = np.pad(embedded_text, (0, MAX_EMBEDDING_DIM - embedded_text.shape[0]), mode="constant").tolist()

    # NOTE: no database session is opened here; this function only assembles
    # a statement and the caller executes it in its own session.

    # Source passages, with a NULL agent_id column so the rows line up with
    # the agent passage select in the UNION ALL below.
    source_passages = None
    if not agent_only:
        source_passages = select(
            SourcePassage,
            literal(None).label('agent_id'),
        )
        if agent_id is not None:
            # Only sources linked to this agent.
            source_passages = (
                source_passages
                .join(SourcesAgents, SourcesAgents.source_id == SourcePassage.source_id)
                .where(SourcesAgents.agent_id == agent_id)
            )
        source_passages = source_passages.where(SourcePassage.organization_id == actor.organization_id)

        if source_id:
            source_passages = source_passages.where(SourcePassage.source_id == source_id)
        if file_id:
            source_passages = source_passages.where(SourcePassage.file_id == file_id)

    # Agent passages, with NULL file_id / source_id placeholders.
    agent_passages = None
    if agent_id is not None:
        agent_passages = (
            select(
                AgentPassage.id,
                AgentPassage.text,
                AgentPassage.embedding_config,
                AgentPassage.metadata_,
                AgentPassage.embedding,
                AgentPassage.created_at,
                AgentPassage.updated_at,
                AgentPassage.is_deleted,
                AgentPassage._created_by_id,
                AgentPassage._last_updated_by_id,
                AgentPassage.organization_id,
                literal(None).label('file_id'),
                literal(None).label('source_id'),
                AgentPassage.agent_id,
            )
            .where(AgentPassage.agent_id == agent_id)
            .where(AgentPassage.organization_id == actor.organization_id)
        )

    # Combine whichever selects were built into a single CTE.
    if source_passages is not None and agent_passages is not None:
        combined_query = union_all(source_passages, agent_passages).cte('combined_passages')
    elif agent_passages is not None:
        combined_query = agent_passages.cte('combined_passages')
    elif source_passages is not None:
        combined_query = source_passages.cte('combined_passages')
    else:
        raise ValueError("No passages found")

    # Build main query from combined CTE
    main_query = select(combined_query)

    # Apply filters
    if start_date:
        main_query = main_query.where(combined_query.c.created_at >= start_date)
    if end_date:
        main_query = main_query.where(combined_query.c.created_at <= end_date)
    if source_id:
        main_query = main_query.where(combined_query.c.source_id == source_id)
    if file_id:
        main_query = main_query.where(combined_query.c.file_id == file_id)

    # Vector search
    if embedded_text:
        if settings.letta_pg_uri_no_default:
            # PostgreSQL with pgvector
            main_query = main_query.order_by(
                combined_query.c.embedding.cosine_distance(embedded_text).asc()
            )
        else:
            # SQLite with custom vector type
            query_embedding_binary = adapt_array(embedded_text)
            created_at_order = (
                combined_query.c.created_at.asc() if ascending else combined_query.c.created_at.desc()
            )
            main_query = main_query.order_by(
                func.cosine_distance(combined_query.c.embedding, query_embedding_binary).asc(),
                created_at_order,
                combined_query.c.id.asc(),
            )
    elif query_text:
        # Plain text search: case-insensitive substring match.
        main_query = main_query.where(func.lower(combined_query.c.text).contains(func.lower(query_text)))

    # Handle cursor-based pagination
    if cursor:
        cursor_created_at = select(combined_query.c.created_at).where(
            combined_query.c.id == cursor
        ).scalar_subquery()

        # Rows are ordered by (created_at, id ASC), so rows that share the
        # cursor's timestamp but sort after it by id must be kept as well;
        # comparing created_at alone would drop them between pages.
        same_timestamp_after_cursor = (
            (combined_query.c.created_at == cursor_created_at)
            & (combined_query.c.id > cursor)
        )
        if ascending:
            main_query = main_query.where(
                (combined_query.c.created_at > cursor_created_at) | same_timestamp_after_cursor
            )
        else:
            main_query = main_query.where(
                (combined_query.c.created_at < cursor_created_at) | same_timestamp_after_cursor
            )

    # Add ordering if not already ordered by similarity
    if not embed_query:
        if ascending:
            main_query = main_query.order_by(
                combined_query.c.created_at.asc(),
                combined_query.c.id.asc(),
            )
        else:
            main_query = main_query.order_by(
                combined_query.c.created_at.desc(),
                combined_query.c.id.asc(),
            )

    return main_query
@enforce_types
def list_passages(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    limit: Optional[int] = 50,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False
) -> List[PydanticPassage]:
    """Return the passages matching the given filters as pydantic objects.

    The statement comes from ``_build_passage_query``; ``limit`` is applied
    on top of it when truthy. Each result row is rebuilt as an
    ``AgentPassage`` when its ``agent_id`` is set and as a ``SourcePassage``
    otherwise, then converted with ``to_pydantic()``.
    """
    with self.session_maker() as session:
        filters = dict(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            cursor=cursor,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        statement = self._build_passage_query(**filters)

        # A falsy limit (None or 0) leaves the statement unbounded.
        if limit:
            statement = statement.limit(limit)

        rows = list(session.execute(statement))

        orm_passages = []
        for row in rows:
            fields = dict(row._mapping)
            if fields['agent_id'] is None:
                # Source passage row: the agent column is only a placeholder.
                fields.pop('agent_id', None)
                orm_passages.append(SourcePassage(**fields))
                continue
            # Agent passage row: the source/file columns are only placeholders.
            for placeholder in ('source_id', 'file_id'):
                fields.pop(placeholder, None)
            orm_passages.append(AgentPassage(**fields))

        return [orm_passage.to_pydantic() for orm_passage in orm_passages]
@enforce_types
def passage_size(
    self,
    actor: PydanticUser,
    agent_id: Optional[str] = None,
    file_id: Optional[str] = None,
    query_text: Optional[str] = None,
    start_date: Optional[datetime] = None,
    end_date: Optional[datetime] = None,
    cursor: Optional[str] = None,
    source_id: Optional[str] = None,
    embed_query: bool = False,
    ascending: bool = True,
    embedding_config: Optional[EmbeddingConfig] = None,
    agent_only: bool = False
) -> int:
    """Count the passages that match the given criteria.

    The filtered statement from ``_build_passage_query`` is wrapped as a
    subquery and counted with ``COUNT(*)``; a ``None`` result is reported
    as 0.
    """
    with self.session_maker() as session:
        criteria = dict(
            actor=actor,
            agent_id=agent_id,
            file_id=file_id,
            query_text=query_text,
            start_date=start_date,
            end_date=end_date,
            cursor=cursor,
            source_id=source_id,
            embed_query=embed_query,
            ascending=ascending,
            embedding_config=embedding_config,
            agent_only=agent_only,
        )
        filtered = self._build_passage_query(**criteria)

        # Count the rows of the filtered statement.
        counting_stmt = select(func.count()).select_from(filtered.subquery())
        total = session.scalar(counting_stmt)
        return total or 0
# ======================================================================================================================
# Tool Management
# ======================================================================================================================