Files
letta-server/letta/services/source_manager.py
cthomas 33eaabb04a chore: bump version 0.8.14 (#2720)
Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
Co-authored-by: Sarah Wooders <sarahwooders@gmail.com>
Co-authored-by: Matthew Zhou <mattzh1314@gmail.com>
Co-authored-by: Andy Li <55300002+cliandy@users.noreply.github.com>
Co-authored-by: jnjpng <jin@letta.com>
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
Co-authored-by: cpacker <packercharles@gmail.com>
Co-authored-by: Shubham Naik <shub@letta.com>
Co-authored-by: Shubham Naik <shub@memgpt.ai>
Co-authored-by: Kevin Lin <klin5061@gmail.com>
2025-07-14 11:03:15 -07:00

155 lines
6.5 KiB
Python

import asyncio
from typing import List, Optional
from sqlalchemy import select
from letta.orm import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.source import Source as SourceModel
from letta.orm.sources_agents import SourcesAgents
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.source import Source as PydanticSource
from letta.schemas.source import SourceUpdate
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.utils import enforce_types, printd
class SourceManager:
    """Manager class to handle business logic related to Sources.

    Every method opens its own async session from ``db_registry`` and returns
    Pydantic schema objects rather than ORM rows.
    """

    @enforce_types
    @trace_method
    async def create_source(self, source: PydanticSource, actor: PydanticUser) -> PydanticSource:
        """Create a new source based on the PydanticSource schema.

        If a source with the same ID is already readable by ``actor``, that
        existing source is returned and nothing is written.

        Args:
            source: Schema describing the source to create. Its
                ``organization_id`` is overwritten with the actor's.
            actor: User performing the action.

        Returns:
            The existing source, or the newly created one.
        """
        existing = await self.get_source_by_id(source.id, actor=actor)
        if existing:
            return existing

        async with db_registry.async_session() as session:
            # A new source always belongs to the creating actor's organization.
            source.organization_id = actor.organization_id
            db_source = SourceModel(**source.model_dump(to_orm=True, exclude_none=True))
            await db_source.create_async(session, actor=actor)
            return db_source.to_pydantic()

    @enforce_types
    @trace_method
    async def update_source(self, source_id: str, source_update: SourceUpdate, actor: PydanticUser) -> PydanticSource:
        """Update a source by its ID with the given SourceUpdate object.

        Only fields that were explicitly set, are not ``None``, and differ from
        the stored value are written. When nothing differs, no update is issued
        and a debug message is printed instead.

        Args:
            source_id: ID of the source to update.
            source_update: Fields to change.
            actor: User performing the action.

        Returns:
            The source as it stands after the (possibly empty) update.
        """
        async with db_registry.async_session() as session:
            source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)

            requested = source_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
            # Drop fields whose requested value already matches the stored one.
            changes = {key: value for key, value in requested.items() if getattr(source, key) != value}

            if changes:
                for key, value in changes.items():
                    setattr(source, key, value)
                await source.update_async(db_session=session, actor=actor)
            else:
                printd(
                    f"`update_source` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={source.name}, but found existing source with nothing to update."
                )

            return source.to_pydantic()

    @enforce_types
    @trace_method
    async def delete_source(self, source_id: str, actor: PydanticUser) -> PydanticSource:
        """Delete a source by its ID.

        Args:
            source_id: ID of the source to delete.
            actor: User performing the action.

        Returns:
            The deleted source, converted to its Pydantic schema.

        Raises:
            NoResultFound: If the actor-scoped lookup finds no such source.
        """
        async with db_registry.async_session() as session:
            # Scope the lookup to the actor, as every other read in this class
            # does, so the row is resolved under the same access rules before
            # it is removed.
            source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
            await source.hard_delete_async(db_session=session, actor=actor)
            return source.to_pydantic()

    @enforce_types
    @trace_method
    async def list_sources(
        self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50, **kwargs
    ) -> List[PydanticSource]:
        """List all sources with optional pagination.

        Args:
            actor: User performing the action; results are restricted to the
                actor's organization.
            after: Pagination cursor forwarded to ``SourceModel.list_async``.
            limit: Maximum number of sources to return.
            **kwargs: Extra filters forwarded unchanged to
                ``SourceModel.list_async``.

        Returns:
            The matching sources as Pydantic schemas.
        """
        async with db_registry.async_session() as session:
            sources = await SourceModel.list_async(
                db_session=session,
                after=after,
                limit=limit,
                organization_id=actor.organization_id,
                **kwargs,
            )
            return [source.to_pydantic() for source in sources]

    @enforce_types
    @trace_method
    async def size_async(self, actor: PydanticUser) -> int:
        """Get the total count of sources for the given user."""
        async with db_registry.async_session() as session:
            return await SourceModel.size_async(db_session=session, actor=actor)

    @enforce_types
    @trace_method
    async def list_attached_agents(self, source_id: str, actor: Optional[PydanticUser] = None) -> List[PydanticAgentState]:
        """List all agents that have the specified source attached.

        Args:
            source_id: ID of the source to find attached agents for.
            actor: User performing the action (optional for now, following
                the existing pattern). When given, only agents in the actor's
                organization are returned.

        Returns:
            Agents that have this source attached, newest first, with the
            agent ID as a tie-breaker.
        """
        async with db_registry.async_session() as session:
            # Called only for its side effect: it raises if the source cannot
            # be read, so the query below never runs for an unknown source.
            await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)

            conditions = [
                SourcesAgents.source_id == source_id,
                # `==` is required here: it builds a SQL expression, `is` would not.
                AgentModel.is_deleted == False,  # noqa: E712
            ]
            if actor is not None:
                conditions.append(AgentModel.organization_id == actor.organization_id)

            # Use junction table query instead of relationship to avoid performance issues
            query = (
                select(AgentModel)
                .join(SourcesAgents, AgentModel.id == SourcesAgents.agent_id)
                .where(*conditions)
                .order_by(AgentModel.created_at.desc(), AgentModel.id)
            )

            result = await session.execute(query)
            agents_orm = result.scalars().all()

            return await asyncio.gather(*[agent.to_pydantic_async() for agent in agents_orm])

    # TODO: We make actor optional for now, but should most likely be enforced due to security reasons
    @enforce_types
    @trace_method
    async def get_source_by_id(self, source_id: str, actor: Optional[PydanticUser] = None) -> Optional[PydanticSource]:
        """Retrieve a source by its ID.

        Args:
            source_id: ID of the source to fetch.
            actor: User performing the action, forwarded to the lookup.

        Returns:
            The source, or ``None`` when the lookup raises ``NoResultFound``.
        """
        async with db_registry.async_session() as session:
            try:
                source = await SourceModel.read_async(db_session=session, identifier=source_id, actor=actor)
            except NoResultFound:
                return None
            return source.to_pydantic()

    @enforce_types
    @trace_method
    async def get_source_by_name(self, source_name: str, actor: PydanticUser) -> Optional[PydanticSource]:
        """Retrieve a source by its name.

        Args:
            source_name: Name to match.
            actor: User performing the action; the search is restricted to the
                actor's organization.

        Returns:
            The first matching source, or ``None`` when there is no match.
        """
        async with db_registry.async_session() as session:
            sources = await SourceModel.list_async(
                db_session=session,
                name=source_name,
                organization_id=actor.organization_id,
                limit=1,
            )
            return sources[0].to_pydantic() if sources else None