Files
letta-server/letta/services/archive_manager.py
Matthew Zhou d797296032 feat: Support basic upload/querying on tpuf [LET-3465] (#4255)
* wip implementing turbopuffer

* Move imports up

* Add type of archive

* Integrate turbopuffer functionality

* Debug turbopuffer tests failing

* Fix turbopuffer

* Run fern

* Fix multiple heads
2025-08-28 10:39:16 -07:00

280 lines
9.8 KiB
Python

from typing import List, Optional
from sqlalchemy import select
from letta.helpers.tpuf_client import should_use_tpuf
from letta.log import get_logger
from letta.orm import ArchivalPassage
from letta.orm import Archive as ArchiveModel
from letta.orm import ArchivesAgents
from letta.schemas.archive import Archive as PydanticArchive
from letta.schemas.enums import VectorDBProvider
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types
logger = get_logger(__name__)
class ArchiveManager:
"""Manager class to handle business logic related to Archives."""
@enforce_types
def create_archive(
self,
name: str,
description: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticArchive:
"""Create a new archive."""
try:
with db_registry.session() as session:
# determine vector db provider based on settings
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
archive = ArchiveModel(
name=name,
description=description,
organization_id=actor.organization_id,
vector_db_provider=vector_db_provider,
)
archive.create(session, actor=actor)
return archive.to_pydantic()
except Exception as e:
logger.exception(f"Failed to create archive {name}. error={e}")
raise
@enforce_types
async def create_archive_async(
self,
name: str,
description: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticArchive:
"""Create a new archive."""
try:
async with db_registry.async_session() as session:
# determine vector db provider based on settings
vector_db_provider = VectorDBProvider.TPUF if should_use_tpuf() else VectorDBProvider.NATIVE
archive = ArchiveModel(
name=name,
description=description,
organization_id=actor.organization_id,
vector_db_provider=vector_db_provider,
)
await archive.create_async(session, actor=actor)
return archive.to_pydantic()
except Exception as e:
logger.exception(f"Failed to create archive {name}. error={e}")
raise
@enforce_types
async def get_archive_by_id_async(
self,
archive_id: str,
actor: PydanticUser,
) -> PydanticArchive:
"""Get an archive by ID."""
async with db_registry.async_session() as session:
archive = await ArchiveModel.read_async(
db_session=session,
identifier=archive_id,
actor=actor,
)
return archive.to_pydantic()
@enforce_types
def attach_agent_to_archive(
self,
agent_id: str,
archive_id: str,
is_owner: bool,
actor: PydanticUser,
) -> None:
"""Attach an agent to an archive."""
with db_registry.session() as session:
# Check if already attached
existing = session.query(ArchivesAgents).filter_by(agent_id=agent_id, archive_id=archive_id).first()
if existing:
# Update ownership if needed
if existing.is_owner != is_owner:
existing.is_owner = is_owner
session.commit()
return
# Create new relationship
archives_agents = ArchivesAgents(
agent_id=agent_id,
archive_id=archive_id,
is_owner=is_owner,
)
session.add(archives_agents)
session.commit()
@enforce_types
async def attach_agent_to_archive_async(
self,
agent_id: str,
archive_id: str,
is_owner: bool = False,
actor: PydanticUser = None,
) -> None:
"""Attach an agent to an archive."""
async with db_registry.async_session() as session:
# Check if relationship already exists
existing = await session.execute(
select(ArchivesAgents).where(
ArchivesAgents.agent_id == agent_id,
ArchivesAgents.archive_id == archive_id,
)
)
existing_record = existing.scalar_one_or_none()
if existing_record:
# Update ownership if needed
if existing_record.is_owner != is_owner:
existing_record.is_owner = is_owner
await session.commit()
return
# Create the relationship
archives_agents = ArchivesAgents(
agent_id=agent_id,
archive_id=archive_id,
is_owner=is_owner,
)
session.add(archives_agents)
await session.commit()
@enforce_types
async def get_or_create_default_archive_for_agent_async(
self,
agent_id: str,
agent_name: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticArchive:
"""Get the agent's default archive, creating one if it doesn't exist."""
# First check if agent has any archives
from letta.services.agent_manager import AgentManager
agent_manager = AgentManager()
archive_ids = await agent_manager.get_agent_archive_ids_async(
agent_id=agent_id,
actor=actor,
)
if archive_ids:
# TODO: Remove this check once we support multiple archives per agent
if len(archive_ids) > 1:
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported")
# Get the archive
archive = await self.get_archive_by_id_async(
archive_id=archive_ids[0],
actor=actor,
)
return archive
# Create a default archive for this agent
archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive"
archive = await self.create_archive_async(
name=archive_name,
description="Default archive created automatically",
actor=actor,
)
# Attach the agent to the archive as owner
await self.attach_agent_to_archive_async(
agent_id=agent_id,
archive_id=archive.id,
is_owner=True,
actor=actor,
)
return archive
@enforce_types
def get_or_create_default_archive_for_agent(
self,
agent_id: str,
agent_name: Optional[str] = None,
actor: PydanticUser = None,
) -> PydanticArchive:
"""Get the agent's default archive, creating one if it doesn't exist."""
with db_registry.session() as session:
# First check if agent has any archives
query = select(ArchivesAgents.archive_id).where(ArchivesAgents.agent_id == agent_id)
result = session.execute(query)
archive_ids = [row[0] for row in result.fetchall()]
if archive_ids:
# TODO: Remove this check once we support multiple archives per agent
if len(archive_ids) > 1:
raise ValueError(f"Agent {agent_id} has multiple archives, which is not yet supported")
# Get the archive
archive = ArchiveModel.read(db_session=session, identifier=archive_ids[0], actor=actor)
return archive.to_pydantic()
# Create a default archive for this agent
archive_name = f"{agent_name or f'Agent {agent_id}'}'s Archive"
# Create the archive
archive_model = ArchiveModel(
name=archive_name,
description="Default archive created automatically",
organization_id=actor.organization_id,
)
archive_model.create(session, actor=actor)
# Attach the agent to the archive as owner
self.attach_agent_to_archive(
agent_id=agent_id,
archive_id=archive_model.id,
is_owner=True,
actor=actor,
)
return archive_model.to_pydantic()
@enforce_types
async def get_agents_for_archive_async(
self,
archive_id: str,
actor: PydanticUser,
) -> List[str]:
"""Get all agent IDs that have access to an archive."""
async with db_registry.async_session() as session:
result = await session.execute(select(ArchivesAgents.agent_id).where(ArchivesAgents.archive_id == archive_id))
return [row[0] for row in result.fetchall()]
@enforce_types
async def get_agent_from_passage_async(
self,
passage_id: str,
actor: PydanticUser,
) -> Optional[str]:
"""Get the agent ID that owns a passage (through its archive).
Returns the first agent found (for backwards compatibility).
Returns None if no agent found.
"""
async with db_registry.async_session() as session:
# First get the passage to find its archive_id
passage = await ArchivalPassage.read_async(
db_session=session,
identifier=passage_id,
actor=actor,
)
# Then find agents connected to that archive
result = await session.execute(select(ArchivesAgents.agent_id).where(ArchivesAgents.archive_id == passage.archive_id))
agent_ids = [row[0] for row in result.fetchall()]
if not agent_ids:
return None
# For now, return the first agent (backwards compatibility)
return agent_ids[0]