Files
letta-server/letta/services/source_manager.py

425 lines
17 KiB
Python

import asyncio
from typing import List, Optional, Union
from sqlalchemy import and_, exists, select
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.source import Source as SourceModel
from letta.orm.sources_agents import SourcesAgents
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd
class SourceManager:
    """Manager class to handle business logic related to Sources."""

    @trace_method
    async def _validate_source_exists_async(self, session, source_id: str, actor: PydanticUser) -> None:
        """
        Validate that a source exists and belongs to the actor's organization, using a single EXISTS query.

        Soft-deleted sources (``is_deleted``) are treated as missing.

        Args:
            session: Database session
            source_id: ID of the source to validate
            actor: User performing the action

        Raises:
            NoResultFound: If source doesn't exist or user doesn't have access
        """
        source_exists_query = select(
            exists().where(
                and_(SourceModel.id == source_id, SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False)
            )
        )
        result = await session.execute(source_exists_query)
        if not result.scalar():
            raise NoResultFound(f"Source with ID {source_id} not found")

    @enforce_types
    @trace_method
    async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
        """
        Create a new source based on the PydanticSource schema.

        If a source with the same ID is already visible to the actor, that existing source is returned
        unchanged instead of creating a new one.

        Args:
            source: Source to create. Its organization_id is overwritten with the actor's organization.
            actor: User performing the action

        Returns:
            PydanticSource: The newly created source, or the pre-existing source with the same ID.
        """
        existing_source = await self.get_source_by_id(source.id, actor=actor)
        if existing_source:
            return existing_source

        async with db_registry.async_session() as session:
            # Sources always belong to the creating actor's organization, whatever the input says.
            source.organization_id = actor.organization_id
            db_source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
            await db_source.create_async(session, actor=actor)
            return db_source.to_pydantic()

    @enforce_types
    @trace_method
    async def bulk_upsert_sources_async(self, pydantic_sources: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
        """
        Bulk create or update multiple sources.

        Uses a single PostgreSQL upsert statement when a PostgreSQL URI is configured, and otherwise falls
        back to upserting sources one at a time (SQLite). This is much more efficient than calling
        create_source in a loop on PostgreSQL.

        IMPORTANT BEHAVIOR NOTES:
        - Sources are matched by name within the actor's organization, NOT by ID.
        - If a source with the same name already exists for the organization, it is updated
          regardless of any ID provided in the input source.
        - The existing source's ID is preserved during updates.
        - If you provide a source with an explicit ID but a name that matches an existing source,
          the existing source is updated and the provided ID is ignored.
        - This differs from create_source, which looks for an existing source by ID rather than name.
        - If several input sources share a name, the last one wins.

        PostgreSQL path:
        - Uses native ON CONFLICT (name, organization_id) DO UPDATE for atomic upserts.
        - All (name-deduplicated) sources are written in a single SQL statement.

        SQLite fallback:
        - Looks each source up by name, then calls update_source or create_source for it.
        - Each of those calls opens its own session, so the batch is not one transaction.

        Args:
            pydantic_sources: List of sources to create or update
            actor: User performing the action

        Returns:
            List of created/updated sources
        """
        if not pydantic_sources:
            return []

        from letta.settings import settings

        if settings.letta_pg_uri_no_default:
            # use optimized postgresql bulk upsert
            async with db_registry.async_session() as session:
                return await self._bulk_upsert_postgresql(session, pydantic_sources, actor)
        # fallback to individual upserts for sqlite
        return await self._upsert_sources_individually(pydantic_sources, actor)

    @trace_method
    async def _bulk_upsert_postgresql(self, session, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
        """
        Upsert sources on PostgreSQL with one INSERT ... ON CONFLICT (name, organization_id) DO UPDATE.

        Args:
            session: Open async database session. It is committed by this method.
            source_data_list: Sources to insert or update, matched by (name, organization_id).
            actor: User performing the action. Supplies the audit fields and the organization.

        Returns:
            The upserted sources, ordered by the first appearance of their name in the input.
        """
        from sqlalchemy import func
        from sqlalchemy.dialects.postgresql import insert

        # PostgreSQL rejects an ON CONFLICT DO UPDATE statement whose proposed rows hit the same
        # conflict key twice ("command cannot affect row a second time"). Collapse inputs sharing
        # a name; the last occurrence wins, like the sequential SQLite fallback.
        unique_sources = {}
        for source in source_data_list:
            unique_sources[source.name] = source
        if not unique_sources:
            return []

        # prepare data for bulk insert
        table = SourceModel.__table__
        valid_columns = {col.name for col in table.columns}
        insert_data = []
        for source in unique_sources.values():
            source_dict = source.model_dump(to_orm=True)
            # set created/updated by fields
            if actor:
                source_dict["_created_by_id"] = actor.id
                source_dict["_last_updated_by_id"] = actor.id
                source_dict["organization_id"] = actor.organization_id
            # filter to only include columns that exist in the table
            insert_data.append({k: v for k, v in source_dict.items() if k in valid_columns})

        stmt = insert(table).values(insert_data)

        # on conflict, update all columns except id, created_at, and _created_by_id
        excluded = stmt.excluded
        update_dict = {}
        for col in table.columns:
            if col.name in ("id", "created_at", "_created_by_id"):
                continue
            update_dict[col.name] = func.now() if col.name == "updated_at" else excluded[col.name]

        upsert_stmt = stmt.on_conflict_do_update(index_elements=["name", "organization_id"], set_=update_dict)
        await session.execute(upsert_stmt)
        await session.commit()

        # fetch results
        source_names = list(unique_sources)
        result_query = select(SourceModel).where(
            SourceModel.name.in_(source_names), SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False
        )
        result = await session.execute(result_query)
        # Return rows in input order (like the SQLite fallback) rather than arbitrary database order.
        position = {name: index for index, name in enumerate(source_names)}
        upserted = sorted(result.scalars(), key=lambda row: position.get(row.name, len(position)))
        return [row.to_pydantic() for row in upserted]

    @trace_method
    async def _upsert_sources_individually(self, source_data_list: List[PydanticSource], actor: PydanticUser) -> List[PydanticSource]:
        """
        Fallback to individual upserts for SQLite.

        Each source is looked up by name. An existing source is updated in place, keeping its ID.
        Otherwise the source is created via create_source.

        Returns:
            The resulting sources, in input order.
        """
        sources = []
        for source in source_data_list:
            # try to get existing source by name
            existing_source = await self.get_source_by_name(source.name, actor)
            if existing_source:
                # update existing source, never overwriting its ID
                update_data = source.model_dump(exclude={"id"}, exclude_none=True)
                updated_source = await self.update_source(existing_source.id, SourceUpdate(**update_data), actor)
                sources.append(updated_source)
            else:
                # create new source
                created_source = await self.create_source(source, actor)
                sources.append(created_source)
        return sources

    @enforce_types
    @trace_method
    async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
        """
        Update a source by its ID with the given SourceUpdate object.

        Only fields that are explicitly set, non-None, and different from the stored value are written.
        If nothing changed, no write is issued.

        Args:
            source_id: ID of the source to update
            source_update: Fields to change
            actor: User performing the action

        Returns:
            PydanticSource: The (possibly unchanged) source.
        """
        async with db_registry.async_session() as session:
            source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)

            # get update dictionary
            update_data = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
            # Remove redundant update fields
            update_data = {key: value for key, value in update_data.items() if getattr(source, key) != value}

            if update_data:
                for key, value in update_data.items():
                    setattr(source, key, value)
                await source.update_async(db_session=session, actor=actor)
            else:
                printd(
                    f"`update_source` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={source.name}, but found existing source with nothing to update."
                )

            return source.to_pydantic()

    @enforce_types
    @trace_method
    async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
        """
        Hard-delete a source by its ID.

        Args:
            source_id: ID of the source to delete
            actor: User performing the action

        Returns:
            PydanticSource: The source as it was just before deletion.
        """
        async with db_registry.async_session() as session:
            # Pass the actor so the lookup is access-checked, as update_source and get_source_by_id do.
            # Without it, the read is not scoped to the caller.
            source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
            # Snapshot before deleting so the returned schema does not depend on the ORM
            # instance's state once its row is gone.
            deleted_source = source.to_pydantic()
            await source.hard_delete_async(db_session=session, actor=actor)
            return deleted_source

    @enforce_types
    @trace_method
    async def list_sources(
        self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs
    ) -> List[PydanticSource]:
        """
        List sources in the actor's organization with optional cursor pagination.

        Args:
            actor: User performing the action
            after: Cursor; return sources after this one
            limit: Maximum number of sources to return
            **kwargs: Extra filters forwarded to SourceModel.list_async

        Returns:
            List[PydanticSource]: The matching sources.
        """
        async with db_registry.async_session() as session:
            sources = await SourceModel.list_async(
                db_session=session,
                after=after,
                limit=limit,
                organization_id=actor.organization_id,
                **kwargs,
            )
            return [source.to_pydantic() for source in sources]

    @enforce_types
    @trace_method
    async def size_async(self, actor: PydanticUser) -> int:
        """
        Get the total count of sources for the given user, as computed by SourceModel.size_async.
        """
        async with db_registry.async_session() as session:
            return await SourceModel.size_async(db_session=session, actor=actor)

    @enforce_types
    @trace_method
    async def list_attached_agents(
        self, source_id: str, actor: PydanticUser, ids_only: bool = False
    ) -> Union[List[PydanticAgentState], List[str]]:
        """
        Lists all agents that have the specified source attached.

        Only non-deleted agents in the actor's organization are returned, newest first (ties broken by ID).

        Args:
            source_id: ID of the source to find attached agents for
            actor: User performing the action
            ids_only: If True, return only agent IDs instead of full agent states

        Returns:
            List[PydanticAgentState] | List[str]: List of agents or agent IDs that have this source attached

        Raises:
            NoResultFound: If the source doesn't exist or isn't accessible to the actor.
        """
        async with db_registry.async_session() as session:
            # Verify source exists and user has permission to access it
            await self._validate_source_exists_async(session, source_id, actor)

            # Query the junction table directly instead of the relationship to avoid performance issues.
            # Selecting only the ID column when ids_only is set avoids loading full agent rows.
            selected = AgentModel.id if ids_only else AgentModel
            query = (
                select(selected)
                .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
                .where(
                    SourcesAgents.source_id == source_id,
                    AgentModel.organization_id == actor.organization_id,
                    AgentModel.is_deleted == False,
                )
                .order_by(AgentModel.created_at.desc(), AgentModel.id)
            )
            result = await session.execute(query)

            if ids_only:
                return list(result.scalars().all())

            agents_orm = result.scalars().all()
            return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])

    @enforce_types
    @trace_method
    async def get_agents_for_source_id(self, source_id: str, actor: PydanticUser) -> List[str]:
        """
        Get all agent IDs associated with a given source ID.

        Args:
            source_id: ID of the source to find agents for
            actor: User performing the action

        Returns:
            List[str]: List of agent IDs that have this source attached

        Raises:
            NoResultFound: If the source doesn't exist or isn't accessible to the actor.
        """
        async with db_registry.async_session() as session:
            # Verify source exists and user has permission to access it
            await self._validate_source_exists_async(session, source_id, actor)

            # Query the junction table directly for performance.
            # NOTE(review): unlike list_attached_agents(ids_only=True), this does not filter out
            # soft-deleted agents or order the results. Confirm callers expect that.
            query = select(SourcesAgents.agent_id).where(SourcesAgents.source_id == source_id)
            result = await session.execute(query)
            return list(result.scalars().all())

    # TODO: We make actor optional for now, but should most likely be enforced due to security reasons
    @enforce_types
    @trace_method
    async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
        """
        Retrieve a source by its ID.

        Returns:
            Optional[PydanticSource]: The source, or None if SourceModel.read_async raises NoResultFound.
        """
        async with db_registry.async_session() as session:
            try:
                source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
            except NoResultFound:
                return None
            return source.to_pydantic()

    @enforce_types
    @trace_method
    async def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
        """
        Retrieve a source by its name within the actor's organization.

        Returns:
            Optional[PydanticSource]: The first matching source, or None if there is none.
        """
        async with db_registry.async_session() as session:
            sources = await SourceModel.list_async(
                db_session=session,
                name=source_name,
                organization_id=actor.organization_id,
                limit=1,
            )
            return sources[0].to_pydantic() if sources else None

    @enforce_types
    @trace_method
    async def get_sources_by_ids_async(self, source_ids: List[str], actor: PydanticUser) -> List[PydanticSource]:
        """
        Get multiple sources by their IDs in a single query.

        Only non-deleted sources in the actor's organization are returned. Order is not guaranteed.

        Args:
            source_ids: List of source IDs to retrieve
            actor: User performing the action

        Returns:
            List[PydanticSource]: List of sources (may be fewer than requested if some don't exist)
        """
        if not source_ids:
            return []

        async with db_registry.async_session() as session:
            query = select(SourceModel).where(
                SourceModel.id.in_(source_ids), SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False
            )
            result = await session.execute(query)
            return [source.to_pydantic() for source in result.scalars().all()]

    @enforce_types
    @trace_method
    async def get_sources_for_agents_async(self, agent_ids: List[str], actor: PydanticUser) -> List[PydanticSource]:
        """
        Get all sources associated with the given agents via sources-agents relationships.

        Only non-deleted sources in the actor's organization are returned, each at most once.

        Args:
            agent_ids: List of agent IDs to find sources for
            actor: User performing the action

        Returns:
            List[PydanticSource]: List of unique sources associated with these agents
        """
        if not agent_ids:
            return []

        async with db_registry.async_session() as session:
            # Resolve the attached source IDs in a subquery and filter with IN. Each source row then
            # appears at most once without a DISTINCT over every source column.
            # (A DISTINCT over full rows is costlier, and on PostgreSQL it fails for json-typed
            # columns. NOTE(review): confirm SourceModel column types.)
            attached_source_ids = select(SourcesAgents.source_id).where(SourcesAgents.agent_id.in_(agent_ids))
            query = select(SourceModel).where(
                SourceModel.id.in_(attached_source_ids),
                SourceModel.organization_id == actor.organization_id,
                SourceModel.is_deleted == False,
            )
            result = await session.execute(query)
            return [source.to_pydantic() for source in result.scalars().all()]

    @enforce_types
    @trace_method
    async def get_existing_source_names(self, source_names: List[str], actor: PydanticUser) -> set[str]:
        """
        Fast batch check to see which source names already exist for the organization.

        Soft-deleted sources are not counted.

        Args:
            source_names: List of source names to check
            actor: User performing the action

        Returns:
            Set of source names that already exist
        """
        if not source_names:
            return set()

        async with db_registry.async_session() as session:
            query = select(SourceModel.name).where(
                SourceModel.name.in_(source_names), SourceModel.organization_id == actor.organization_id, SourceModel.is_deleted == False
            )
            result = await session.execute(query)
            return set(result.scalars().all())