Swap the order of the @trace_method and @raise_on_invalid_id decorators across all service managers so that @trace_method is always the first wrapper applied to the function (positioned directly above the method definition). Decorators are applied bottom-up, so @raise_on_invalid_id now wraps the traced function; this ensures the ID validation happens before tracing begins, which is the intended execution order.

Files modified:
- agent_manager.py (23 occurrences)
- archive_manager.py (11 occurrences)
- block_manager.py (7 occurrences)
- file_manager.py (6 occurrences)
- group_manager.py (9 occurrences)
- identity_manager.py (10 occurrences)
- job_manager.py (7 occurrences)
- message_manager.py (2 occurrences)
- provider_manager.py (3 occurrences)
- sandbox_config_manager.py (7 occurrences)
- source_manager.py (5 occurrences)
- step_manager.py (13 occurrences)
443 lines
19 KiB
Python
443 lines
19 KiB
Python
import asyncio
|
|
from typing import List, Optional
|
|
|
|
from fastapi import HTTPException
|
|
from sqlalchemy import select
|
|
from sqlalchemy.exc import NoResultFound
|
|
|
|
from letta.orm.agent import Agent as AgentModel
|
|
from letta.orm.block import Block as BlockModel
|
|
from letta.orm.errors import UniqueConstraintViolationError
|
|
from letta.orm.identity import Identity as IdentityModel
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.agent import AgentState
|
|
from letta.schemas.block import Block
|
|
from letta.schemas.enums import PrimitiveType
|
|
from letta.schemas.identity import (
|
|
Identity as PydanticIdentity,
|
|
IdentityCreate,
|
|
IdentityProperty,
|
|
IdentityType,
|
|
IdentityUpdate,
|
|
IdentityUpsert,
|
|
)
|
|
from letta.schemas.user import User as PydanticUser
|
|
from letta.server.db import db_registry
|
|
from letta.settings import DatabaseChoice, settings
|
|
from letta.utils import enforce_types
|
|
from letta.validators import raise_on_invalid_id
|
|
|
|
|
|
class IdentityManager:
    """Service-layer manager for identity records and their agent/block links.

    Every public coroutine opens its own session through
    ``db_registry.async_session()``. The underscore-prefixed helpers instead
    reuse a session handed in by the caller, so several steps can share one
    session.

    Methods that take an ``identity_id`` (and ``agent_id`` / ``block_id``) are
    wrapped by ``raise_on_invalid_id`` with the matching ``PrimitiveType``
    prefix. ``trace_method`` sits directly above each ``def`` so that it is the
    innermost wrapper and the ID check runs before the traced call.
    """

    @enforce_types
    @trace_method
    async def list_identities_async(
        self,
        name: Optional[str] = None,
        project_id: Optional[str] = None,
        identifier_key: Optional[str] = None,
        identity_type: Optional[IdentityType] = None,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        ascending: bool = False,
        actor: PydanticUser = None,
    ) -> tuple[list[PydanticIdentity], Optional[str], bool]:
        """
        List identities with pagination metadata.

        Args:
            name: Forwarded to ``IdentityModel.list_async`` as ``query_text``.
            project_id: Added to the keyword filters when truthy.
            identifier_key: Added to the keyword filters when truthy.
            identity_type: Added to the keyword filters when truthy.
            before: Forwarded unchanged to ``IdentityModel.list_async``.
            after: Forwarded unchanged to ``IdentityModel.list_async``.
            limit: Page size. A falsy value (``None`` or ``0``) means no limit
                is passed down and ``has_more`` is always ``False``.
            ascending: Forwarded unchanged to ``IdentityModel.list_async``.
            actor: User whose ``organization_id`` scopes the query. It is
                dereferenced unconditionally despite the ``None`` default.

        Returns:
            Tuple of (identities, next_cursor, has_more)
        """
        async with db_registry.async_session() as session:
            filters = {"organization_id": actor.organization_id}
            optional_filters = {
                "project_id": project_id,
                "identifier_key": identifier_key,
                "identity_type": identity_type,
            }
            # Only truthy optional filters are forwarded to the model layer.
            filters.update({key: value for key, value in optional_filters.items() if value})

            # Request one more than limit to check if there are more pages
            query_limit = limit + 1 if limit else None

            identities = await IdentityModel.list_async(
                db_session=session,
                query_text=name,
                before=before,
                after=after,
                limit=query_limit,
                ascending=ascending,
                **filters,
            )

            # Check if we got more records than requested (meaning there are more pages)
            has_more = len(identities) > limit if limit else False
            if has_more:
                # Trim back to the requested limit
                identities = identities[:limit]

            # Get cursor for next page (ID of last item in current page)
            next_cursor = identities[-1].id if identities else None

            return [identity.to_pydantic() for identity in identities], next_cursor, has_more

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @trace_method
    async def get_identity_async(self, identity_id: str, actor: PydanticUser) -> PydanticIdentity:
        """
        Read a single identity by ID and return it as a Pydantic schema.

        Args:
            identity_id: ID passed to ``IdentityModel.read_async`` as ``identifier``.
            actor: User forwarded to ``IdentityModel.read_async``.

        Returns:
            The identity converted with ``to_pydantic()``.
        """
        async with db_registry.async_session() as session:
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
            return identity.to_pydantic()

    @enforce_types
    @trace_method
    async def create_identity_async(self, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
        """
        Create a new identity inside a fresh session.

        Args:
            identity: Creation payload.
            actor: User performing the creation; supplies the organization.

        Returns:
            The newly created identity as a Pydantic schema.
        """
        async with db_registry.async_session() as session:
            return await self._create_identity_async(db_session=session, identity=identity, actor=actor)

    async def _create_identity_async(self, db_session, identity: IdentityCreate, actor: PydanticUser) -> PydanticIdentity:
        """
        Build and persist an ``IdentityModel`` using the caller's session.

        Args:
            db_session: Session to run queries and the insert on.
            identity: Creation payload. ``agent_ids`` and ``block_ids`` are
                excluded from the model constructor and resolved separately.
            actor: User whose ``organization_id`` is stamped on the new row.

        Returns:
            The created identity as a Pydantic schema.

        Raises:
            UniqueConstraintViolationError: On SQLite, when a row with the same
                identifier_key, project_id and organization_id already exists.
            NoResultFound: When any referenced agent or block ID is not found.
        """
        new_identity = IdentityModel(**identity.model_dump(exclude={"agent_ids", "block_ids"}, exclude_unset=True))
        new_identity.organization_id = actor.organization_id

        # For SQLite compatibility: check for unique constraint violation manually
        # since SQLite doesn't support postgresql_nulls_not_distinct=True
        if settings.database_engine is DatabaseChoice.SQLITE:
            # Check if an identity with the same identifier_key, project_id, and organization_id exists
            query = select(IdentityModel).where(
                IdentityModel.identifier_key == new_identity.identifier_key,
                IdentityModel.project_id == new_identity.project_id,
                IdentityModel.organization_id == new_identity.organization_id,
            )
            result = await db_session.execute(query)
            existing_identity = result.scalar_one_or_none()
            if existing_identity is not None:
                raise UniqueConstraintViolationError(
                    f"A unique constraint was violated for Identity. "
                    f"An identity with identifier_key='{new_identity.identifier_key}', "
                    f"project_id='{new_identity.project_id}', and "
                    f"organization_id='{new_identity.organization_id}' already exists."
                )

        # Resolve both relationships strictly: every referenced ID must exist.
        await self._process_relationship_async(
            db_session=db_session,
            identity=new_identity,
            relationship_name="agents",
            model_class=AgentModel,
            item_ids=identity.agent_ids,
            allow_partial=False,
        )
        await self._process_relationship_async(
            db_session=db_session,
            identity=new_identity,
            relationship_name="blocks",
            model_class=BlockModel,
            item_ids=identity.block_ids,
            allow_partial=False,
        )
        await new_identity.create_async(db_session=db_session, actor=actor)
        return new_identity.to_pydantic()

    @enforce_types
    @trace_method
    async def upsert_identity_async(self, identity: IdentityUpsert, actor: PydanticUser) -> PydanticIdentity:
        """
        Create the identity if no match exists, otherwise overwrite the match.

        The lookup uses ``identifier_key``, ``project_id`` and the actor's
        ``organization_id``. An existing row is updated with ``replace=True``.

        Args:
            identity: Upsert payload.
            actor: User performing the operation.

        Returns:
            The created or updated identity as a Pydantic schema.
        """
        async with db_registry.async_session() as session:
            # NOTE(review): update_identity_async treats read_async as raising
            # NoResultFound on a miss; confirm this keyword lookup really
            # returns None, otherwise the create branch below is unreachable.
            existing_identity = await IdentityModel.read_async(
                db_session=session,
                identifier_key=identity.identifier_key,
                project_id=identity.project_id,
                organization_id=actor.organization_id,
                actor=actor,
            )

            if existing_identity is None:
                return await self._create_identity_async(db_session=session, identity=IdentityCreate(**identity.model_dump()), actor=actor)

            # NOTE(review): block_ids is not forwarded here, so the update path
            # presumably leaves block links untouched -- confirm this is intended.
            identity_update = IdentityUpdate(
                name=identity.name,
                identifier_key=identity.identifier_key,
                identity_type=identity.identity_type,
                agent_ids=identity.agent_ids,
                properties=identity.properties,
            )
            return await self._update_identity_async(
                db_session=session, existing_identity=existing_identity, identity=identity_update, actor=actor, replace=True
            )

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @trace_method
    async def update_identity_async(
        self, identity_id: str, identity: IdentityUpdate, actor: PydanticUser, replace: bool = False
    ) -> PydanticIdentity:
        """
        Update an existing identity by ID.

        Args:
            identity_id: ID of the identity to update.
            identity: Fields to apply; ``None`` fields are left unchanged.
            actor: User performing the update.
            replace: When true, properties and relationships are replaced
                instead of merged/extended.

        Returns:
            The updated identity as a Pydantic schema.

        Raises:
            HTTPException: 404 when the read raises ``NoResultFound``; 403 when
                the identity belongs to a different organization than the actor.
        """
        async with db_registry.async_session() as session:
            try:
                existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
            except NoResultFound as exc:
                # Chain explicitly so the original lookup failure stays attached.
                raise HTTPException(status_code=404, detail="Identity not found") from exc
            if existing_identity.organization_id != actor.organization_id:
                raise HTTPException(status_code=403, detail="Forbidden")

            return await self._update_identity_async(
                db_session=session, existing_identity=existing_identity, identity=identity, actor=actor, replace=replace
            )

    async def _update_identity_async(
        self,
        db_session,
        existing_identity: IdentityModel,
        identity: IdentityUpdate,
        actor: PydanticUser,
        replace: bool = False,
    ) -> PydanticIdentity:
        """
        Apply an ``IdentityUpdate`` to a loaded ORM row and persist it.

        Args:
            db_session: Session to run relationship lookups and the update on.
            existing_identity: ORM row to mutate in place.
            identity: Update payload; any field that is ``None`` is skipped.
            actor: User forwarded to ``update_async``.
            replace: When true, ``properties`` overwrite the stored list and
                relationships are replaced. When false, properties are merged
                by ``key`` (incoming wins) and relationships are extended.

        Returns:
            The updated identity as a Pydantic schema.

        Raises:
            NoResultFound: When any referenced agent or block ID is not found.
        """
        if identity.identifier_key is not None:
            existing_identity.identifier_key = identity.identifier_key
        if identity.name is not None:
            existing_identity.name = identity.name
        if identity.identity_type is not None:
            existing_identity.identity_type = identity.identity_type
        if identity.properties is not None:
            if replace:
                existing_identity.properties = [prop.model_dump() for prop in identity.properties]
            else:
                # Merge by property key; the right-hand dict wins on collisions,
                # so incoming properties override stored ones.
                new_properties = {old_prop["key"]: old_prop for old_prop in existing_identity.properties} | {
                    new_prop.key: new_prop.model_dump() for new_prop in identity.properties
                }
                existing_identity.properties = list(new_properties.values())

        if identity.agent_ids is not None:
            await self._process_relationship_async(
                db_session=db_session,
                identity=existing_identity,
                relationship_name="agents",
                model_class=AgentModel,
                item_ids=identity.agent_ids,
                allow_partial=False,
                replace=replace,
            )
        if identity.block_ids is not None:
            await self._process_relationship_async(
                db_session=db_session,
                identity=existing_identity,
                relationship_name="blocks",
                model_class=BlockModel,
                item_ids=identity.block_ids,
                allow_partial=False,
                replace=replace,
            )
        await existing_identity.update_async(db_session=db_session, actor=actor)
        return existing_identity.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @trace_method
    async def upsert_identity_properties_async(
        self, identity_id: str, properties: List[IdentityProperty], actor: PydanticUser
    ) -> PydanticIdentity:
        """
        Overwrite the properties of an existing identity.

        The update is issued with ``replace=True``, so the stored property list
        is replaced by ``properties`` rather than merged.

        Args:
            identity_id: ID of the identity to modify.
            properties: New property list.
            actor: User performing the update.

        Returns:
            The updated identity as a Pydantic schema.

        Raises:
            HTTPException: 404 when ``read_async`` returns ``None``.
        """
        async with db_registry.async_session() as session:
            existing_identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
            if existing_identity is None:
                raise HTTPException(status_code=404, detail="Identity not found")
            return await self._update_identity_async(
                db_session=session,
                existing_identity=existing_identity,
                identity=IdentityUpdate(properties=properties),
                actor=actor,
                replace=True,
            )

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @trace_method
    async def delete_identity_async(self, identity_id: str, actor: PydanticUser) -> None:
        """
        Delete an identity and commit the session.

        Args:
            identity_id: ID of the identity to delete.
            actor: User performing the deletion.

        Raises:
            HTTPException: 404 when ``read_async`` returns ``None``; 403 when
                the identity belongs to a different organization than the actor.
        """
        async with db_registry.async_session() as session:
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
            if identity is None:
                raise HTTPException(status_code=404, detail="Identity not found")
            if identity.organization_id != actor.organization_id:
                raise HTTPException(status_code=403, detail="Forbidden")
            await session.delete(identity)
            await session.commit()

    @enforce_types
    @trace_method
    async def size_async(
        self,
        actor: PydanticUser,
    ) -> int:
        """
        Get the total count of identities for the given user.

        Args:
            actor: User forwarded to ``IdentityModel.size_async``.

        Returns:
            The count reported by ``IdentityModel.size_async``.
        """
        async with db_registry.async_session() as session:
            return await IdentityModel.size_async(db_session=session, actor=actor)

    async def _process_relationship_async(
        self,
        db_session,
        identity: IdentityModel,
        relationship_name: str,
        model_class,
        item_ids: List[str],
        allow_partial=False,
        replace=True,
    ):
        """
        Resolve ``item_ids`` to ORM rows and attach them to ``identity``.

        Args:
            db_session: Session used to look the rows up.
            identity: ORM identity whose relationship attribute is modified.
            relationship_name: Attribute name on ``identity`` (the callers in
                this class pass ``"agents"`` or ``"blocks"``).
            model_class: ORM class queried by ``id``.
            item_ids: IDs to attach. A falsy value (``None`` or empty) clears
                the relationship when ``replace`` is true and is otherwise a
                no-op.
            allow_partial: When false, every requested ID must be found.
            replace: When true, the relationship is set to exactly the found
                rows; when false, only rows not already present are appended.

        Raises:
            NoResultFound: When ``allow_partial`` is false and at least one
                requested ID has no matching row.
        """
        # Read up front, exactly where the original did.
        # NOTE(review): on an ORM object this access may load the relationship;
        # it is kept first so the ordering of any such load is unchanged.
        current_relationship = getattr(identity, relationship_name, [])
        if not item_ids:
            if replace:
                setattr(identity, relationship_name, [])
            return

        # Retrieve models for the provided IDs
        found_items = (await db_session.execute(select(model_class).where(model_class.id.in_(item_ids)))).scalars().all()

        # Validate all items are found if allow_partial is False. Compare by
        # set difference rather than by length so that duplicate IDs in
        # item_ids do not trigger a spurious "not found" error.
        missing = set(item_ids) - {item.id for item in found_items}
        if not allow_partial and missing:
            raise NoResultFound(f"Items not found in {relationship_name}: {missing}")

        if replace:
            # Replace the relationship
            setattr(identity, relationship_name, found_items)
        else:
            # Extend the relationship (only add new items)
            current_ids = {item.id for item in current_relationship}
            new_items = [item for item in found_items if item.id not in current_ids]
            current_relationship.extend(new_items)

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @trace_method
    async def list_agents_for_identity_async(
        self,
        identity_id: str,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        ascending: bool = False,
        include: Optional[List[str]] = None,
        actor: PydanticUser = None,
    ) -> List[AgentState]:
        """
        Get all agents associated with the specified identity.

        Args:
            identity_id: ID of the identity whose agents are listed.
            before: Forwarded unchanged to ``AgentModel.list_async``.
            after: Forwarded unchanged to ``AgentModel.list_async``.
            limit: Forwarded unchanged to ``AgentModel.list_async``.
            ascending: Forwarded unchanged to ``AgentModel.list_async``.
            include: Forwarded to ``to_pydantic_async`` for every agent.
                ``None`` (the default) is treated as an empty list.
            actor: User forwarded to ``IdentityModel.read_async``.

        Returns:
            One ``to_pydantic_async`` result per agent, in listing order.

        Raises:
            HTTPException: 404 when ``read_async`` returns ``None``.
        """
        # Use a fresh list per call instead of a shared mutable default.
        include = [] if include is None else include

        async with db_registry.async_session() as session:
            # First verify the identity exists and belongs to the user
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
            if identity is None:
                raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found")

            # Get agents associated with this identity with pagination
            agents = await AgentModel.list_async(
                db_session=session,
                before=before,
                after=after,
                limit=limit,
                ascending=ascending,
                identity_id=identity.id,
            )
            return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=[], include=include) for agent in agents])

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @trace_method
    async def list_blocks_for_identity_async(
        self,
        identity_id: str,
        before: Optional[str] = None,
        after: Optional[str] = None,
        limit: Optional[int] = 50,
        ascending: bool = False,
        actor: PydanticUser = None,
    ) -> List[Block]:
        """
        Get all blocks associated with the specified identity.

        Args:
            identity_id: ID of the identity whose blocks are listed.
            before: Forwarded unchanged to ``BlockModel.list_async``.
            after: Forwarded unchanged to ``BlockModel.list_async``.
            limit: Forwarded unchanged to ``BlockModel.list_async``.
            ascending: Forwarded unchanged to ``BlockModel.list_async``.
            actor: User forwarded to ``IdentityModel.read_async``.

        Returns:
            The blocks converted with ``to_pydantic()``.

        Raises:
            HTTPException: 404 when ``read_async`` returns ``None``.
        """
        async with db_registry.async_session() as session:
            # First verify the identity exists and belongs to the user
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)
            if identity is None:
                raise HTTPException(status_code=404, detail=f"Identity with id={identity_id} not found")

            # Get blocks associated with this identity with pagination
            blocks = await BlockModel.list_async(
                db_session=session,
                before=before,
                after=after,
                limit=limit,
                ascending=ascending,
                identity_id=identity.id,
            )
            return [block.to_pydantic() for block in blocks]

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
    @trace_method
    async def attach_agent_async(self, identity_id: str, agent_id: str, actor: PydanticUser) -> None:
        """
        Attach an agent to an identity.

        Nothing is written when the agent is already in ``identity.agents``.

        Args:
            identity_id: ID of the identity to modify.
            agent_id: ID of the agent to attach.
            actor: User forwarded to the reads and the update.
        """
        async with db_registry.async_session() as session:
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)

            agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)

            # Add agent to identity if not already attached
            if agent not in identity.agents:
                identity.agents.append(agent)
                await identity.update_async(db_session=session, actor=actor)

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @raise_on_invalid_id(param_name="agent_id", expected_prefix=PrimitiveType.AGENT)
    @trace_method
    async def detach_agent_async(self, identity_id: str, agent_id: str, actor: PydanticUser) -> None:
        """
        Detach an agent from an identity.

        Nothing is written when the agent is not in ``identity.agents``.

        Args:
            identity_id: ID of the identity to modify.
            agent_id: ID of the agent to detach.
            actor: User forwarded to the reads and the update.
        """
        async with db_registry.async_session() as session:
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)

            agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)

            # Remove agent from identity if attached
            if agent in identity.agents:
                identity.agents.remove(agent)
                await identity.update_async(db_session=session, actor=actor)

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def attach_block_async(self, identity_id: str, block_id: str, actor: PydanticUser) -> None:
        """
        Attach a block to an identity.

        Nothing is written when the block is already in ``identity.blocks``.

        Args:
            identity_id: ID of the identity to modify.
            block_id: ID of the block to attach.
            actor: User forwarded to the reads and the update.
        """
        async with db_registry.async_session() as session:
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)

            block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)

            # Add block to identity if not already attached
            if block not in identity.blocks:
                identity.blocks.append(block)
                await identity.update_async(db_session=session, actor=actor)

    @enforce_types
    @raise_on_invalid_id(param_name="identity_id", expected_prefix=PrimitiveType.IDENTITY)
    @raise_on_invalid_id(param_name="block_id", expected_prefix=PrimitiveType.BLOCK)
    @trace_method
    async def detach_block_async(self, identity_id: str, block_id: str, actor: PydanticUser) -> None:
        """
        Detach a block from an identity.

        Nothing is written when the block is not in ``identity.blocks``.

        Args:
            identity_id: ID of the identity to modify.
            block_id: ID of the block to detach.
            actor: User forwarded to the reads and the update.
        """
        async with db_registry.async_session() as session:
            identity = await IdentityModel.read_async(db_session=session, identifier=identity_id, actor=actor)

            block = await BlockModel.read_async(db_session=session, identifier=block_id, actor=actor)

            # Remove block from identity if attached
            if block in identity.blocks:
                identity.blocks.remove(block)
                await identity.update_async(db_session=session, actor=actor)
|