Files
letta-server/letta/services/user_manager.py
2025-06-10 14:27:01 -07:00

258 lines
10 KiB
Python

from typing import List, Optional
from async_lru import alru_cache
from sqlalchemy import select, text
from letta.constants import DEFAULT_ORG_ID
from letta.orm.errors import NoResultFound
from letta.orm.organization import Organization as OrganizationModel
from letta.orm.user import User as UserModel
from letta.otel.tracing import trace_method
from letta.schemas.user import User as PydanticUser
from letta.schemas.user import UserUpdate
from letta.server.db import db_registry
from letta.utils import enforce_types
class UserManager:
"""Manager class to handle business logic related to Users."""
DEFAULT_USER_NAME = "default_user"
DEFAULT_USER_ID = "user-00000000-0000-4000-8000-000000000000"
@alru_cache(maxsize=16384) # =16384 cached actor IDs per process
async def _cached_actor_by_id(self, actor_id: str) -> "PydanticUser":
"""LOW-level fetch that the cache wraps around."""
return await self.get_actor_by_id_async(actor_id)
@trace_method
async def _get_actor_cached(self, actor_id: str) -> "PydanticUser":
"""Public helper that returns cached actor."""
return await self._cached_actor_by_id(actor_id)
@trace_method
def _invalidate_actor_cache(self, actor_id: str | None = None) -> None:
"""Call whenever you create/update/delete an actor."""
if actor_id:
try:
self._cached_actor_by_id.cache_invalidate(actor_id)
except KeyError:
pass
else:
self._cached_actor_by_id.cache_clear()
@enforce_types
@trace_method
def create_default_user(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
with db_registry.session() as session:
# Make sure the org id exists
try:
OrganizationModel.read(db_session=session, identifier=org_id)
except NoResultFound:
raise ValueError(f"No organization with {org_id} exists in the organization table.")
# Try to retrieve the user
try:
user = UserModel.read(db_session=session, identifier=self.DEFAULT_USER_ID)
except NoResultFound:
# If it doesn't exist, make it
user = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
user.create(session)
self._invalidate_actor_cache(self.DEFAULT_USER_ID)
return user.to_pydantic()
@enforce_types
@trace_method
async def create_default_actor_async(self, org_id: str = DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
async with db_registry.async_session() as session:
# Make sure the org id exists
try:
await OrganizationModel.read_async(db_session=session, identifier=org_id)
except NoResultFound:
raise ValueError(f"No organization with {org_id} exists in the organization table.")
# Try to retrieve the user
try:
actor = await UserModel.read_async(db_session=session, identifier=self.DEFAULT_USER_ID)
except NoResultFound:
# If it doesn't exist, make it
actor = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
await actor.create_async(session)
self._invalidate_actor_cache(self.DEFAULT_USER_ID)
return actor.to_pydantic()
@enforce_types
@trace_method
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
"""Create a new user if it doesn't already exist."""
with db_registry.session() as session:
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
new_user.create(session)
self._invalidate_actor_cache(new_user.id)
return new_user.to_pydantic()
@enforce_types
@trace_method
async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser:
"""Create a new user if it doesn't already exist (async version)."""
async with db_registry.async_session() as session:
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
await new_user.create_async(session)
self._invalidate_actor_cache(new_user.id)
return new_user.to_pydantic()
@enforce_types
@trace_method
def update_user(self, user_update: UserUpdate) -> PydanticUser:
"""Update user details."""
with db_registry.session() as session:
# Retrieve the existing user by ID
existing_user = UserModel.read(db_session=session, identifier=user_update.id)
# Update only the fields that are provided in UserUpdate
update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(existing_user, key, value)
# Commit the updated user
existing_user.update(session)
self._invalidate_actor_cache(user_update.id)
return existing_user.to_pydantic()
@enforce_types
@trace_method
async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser:
"""Update user details (async version)."""
async with db_registry.async_session() as session:
# Retrieve the existing user by ID
existing_user = await UserModel.read_async(db_session=session, identifier=user_update.id)
# Update only the fields that are provided in UserUpdate
update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
for key, value in update_data.items():
setattr(existing_user, key, value)
# Commit the updated user
await existing_user.update_async(session)
self._invalidate_actor_cache(user_update.id)
return existing_user.to_pydantic()
@enforce_types
@trace_method
def delete_user_by_id(self, user_id: str):
"""Delete a user and their associated records (agents, sources, mappings)."""
with db_registry.session() as session:
# Delete from user table
user = UserModel.read(db_session=session, identifier=user_id)
user.hard_delete(session)
self._invalidate_actor_cache(user_id)
session.commit()
@enforce_types
@trace_method
async def delete_actor_by_id_async(self, user_id: str):
"""Delete a user and their associated records (agents, sources, mappings) asynchronously."""
async with db_registry.async_session() as session:
# Delete from user table
user = await UserModel.read_async(db_session=session, identifier=user_id)
await user.hard_delete_async(session)
self._invalidate_actor_cache(user_id)
@enforce_types
@trace_method
def get_user_by_id(self, user_id: str) -> PydanticUser:
"""Fetch a user by ID."""
with db_registry.session() as session:
user = UserModel.read(db_session=session, identifier=user_id)
return user.to_pydantic()
@enforce_types
@trace_method
async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser:
"""Fetch a user by ID asynchronously."""
async with db_registry.async_session() as session:
# Turn off seqscan to force use pk index
await session.execute(text("SET LOCAL enable_seqscan = OFF"))
try:
stmt = select(UserModel).where(UserModel.id == actor_id)
result = await session.execute(stmt)
user = result.scalar_one_or_none()
finally:
await session.execute(text("SET LOCAL enable_seqscan = ON"))
if not user:
raise NoResultFound(f"User not found with id={actor_id}")
return user.to_pydantic()
@enforce_types
@trace_method
def get_default_user(self) -> PydanticUser:
"""Fetch the default user. If it doesn't exist, create it."""
try:
return self.get_user_by_id(self.DEFAULT_USER_ID)
except NoResultFound:
return self.create_default_user()
@enforce_types
@trace_method
def get_user_or_default(self, user_id: Optional[str] = None):
"""Fetch the user or default user."""
if not user_id:
return self.get_default_user()
try:
return self.get_user_by_id(user_id=user_id)
except NoResultFound:
return self.get_default_user()
@enforce_types
@trace_method
async def get_default_actor_async(self) -> PydanticUser:
"""Fetch the default user asynchronously. If it doesn't exist, create it."""
try:
return await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
except NoResultFound:
return await self.create_default_actor_async(org_id=DEFAULT_ORG_ID)
@enforce_types
@trace_method
async def get_actor_or_default_async(self, actor_id: Optional[str] = None):
"""Fetch the user or default user asynchronously."""
target_id = actor_id or self.DEFAULT_USER_ID
try:
return await self._get_actor_cached(target_id)
except NoResultFound:
user = await self.create_default_actor_async(org_id=DEFAULT_ORG_ID)
return user
@enforce_types
@trace_method
def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
"""List all users with optional pagination."""
with db_registry.session() as session:
users = UserModel.list(
db_session=session,
after=after,
limit=limit,
)
return [user.to_pydantic() for user in users]
@enforce_types
@trace_method
async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
"""List all users with optional pagination (async version)."""
async with db_registry.async_session() as session:
users = await UserModel.list_async(
db_session=session,
after=after,
limit=limit,
)
return [user.to_pydantic() for user in users]