feat(asyncify): convert org manager (#2426)

This commit is contained in:
cthomas
2025-05-25 21:23:06 -07:00
committed by GitHub
parent 82659a0915
commit e8dfa67920
5 changed files with 125 additions and 73 deletions

View File

@@ -13,7 +13,7 @@ router = APIRouter(prefix="/orgs", tags=["organization", "admin"])
@router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs")
def get_all_orgs(
async def get_all_orgs(
after: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_letta_server),
@@ -22,7 +22,7 @@ def get_all_orgs(
Get a list of all orgs in the database
"""
try:
orgs = server.organization_manager.list_organizations(after=after, limit=limit)
orgs = await server.organization_manager.list_organizations_async(after=after, limit=limit)
except HTTPException:
raise
except Exception as e:
@@ -31,7 +31,7 @@ def get_all_orgs(
@router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization")
def create_org(
async def create_org(
request: OrganizationCreate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
@@ -39,21 +39,21 @@ def create_org(
Create a new org in the database
"""
org = Organization(**request.model_dump())
org = server.organization_manager.create_organization(pydantic_org=org)
org = await server.organization_manager.create_organization_async(pydantic_org=org)
return org
@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization_by_id")
def delete_org(
async def delete_org(
org_id: str = Query(..., description="The org_id key to be deleted."),
server: "SyncServer" = Depends(get_letta_server),
):
# TODO make a soft deletion, instead of a hard deletion
try:
org = server.organization_manager.get_organization_by_id(org_id=org_id)
org = await server.organization_manager.get_organization_by_id_async(org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail=f"Organization does not exist")
server.organization_manager.delete_organization_by_id(org_id=org_id)
await server.organization_manager.delete_organization_by_id_async(org_id=org_id)
except HTTPException:
raise
except Exception as e:
@@ -62,16 +62,16 @@ def delete_org(
@router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization")
def update_org(
async def update_org(
org_id: str = Query(..., description="The org_id key to be updated."),
request: OrganizationUpdate = Body(...),
server: "SyncServer" = Depends(get_letta_server),
):
try:
org = server.organization_manager.get_organization_by_id(org_id=org_id)
org = await server.organization_manager.get_organization_by_id_async(org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail=f"Organization does not exist")
org = server.organization_manager.update_organization(org_id=org_id, name=request.name)
org = await server.organization_manager.update_organization_async(org_id=org_id, name=request.name)
except HTTPException:
raise
except Exception as e:

View File

@@ -49,7 +49,6 @@ from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage, PassageUpdate
from letta.schemas.providers import (
AnthropicBedrockProvider,
@@ -1424,16 +1423,6 @@ class SyncServer(Server):
# Get the current message
return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
"""Get the organization object for org_id if it exists, otherwise return the default organization object"""
if org_id is None:
org_id = self.organization_manager.DEFAULT_ORG_ID
try:
return self.organization_manager.get_organization_by_id(org_id=org_id)
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(
self,
actor: User,

View File

@@ -17,9 +17,9 @@ class OrganizationManager:
@enforce_types
@trace_method
def get_default_organization(self) -> PydanticOrganization:
async def get_default_organization_async(self) -> PydanticOrganization:
"""Fetch the default organization."""
return self.get_organization_by_id(self.DEFAULT_ORG_ID)
return await self.get_organization_by_id_async(self.DEFAULT_ORG_ID)
@enforce_types
@trace_method
@@ -31,50 +31,78 @@ class OrganizationManager:
@enforce_types
@trace_method
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
"""Create a new organization."""
try:
org = self.get_organization_by_id(pydantic_org.id)
return org
except NoResultFound:
return self._create_organization(pydantic_org=pydantic_org)
async def get_organization_by_id_async(self, org_id: str) -> Optional[PydanticOrganization]:
"""Fetch an organization by ID."""
async with db_registry.async_session() as session:
organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
return organization.to_pydantic()
@enforce_types
@trace_method
def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
"""Create the default organization."""
with db_registry.session() as session:
try:
organization = OrganizationModel.read(db_session=session, identifier=pydantic_org.id)
return organization.to_pydantic()
except:
organization = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
organization = organization.create(session)
return organization.to_pydantic()
@enforce_types
@trace_method
async def create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
"""Create a new organization."""
try:
org = await self.get_organization_by_id_async(pydantic_org.id)
return org
except NoResultFound:
return await self._create_organization_async(pydantic_org=pydantic_org)
@enforce_types
@trace_method
async def _create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
async with db_registry.async_session() as session:
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
org.create(session)
await org.create_async(session)
return org.to_pydantic()
@enforce_types
@trace_method
def create_default_organization(self) -> PydanticOrganization:
"""Create the default organization."""
return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
pydantic_org = PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
return self.create_organization(pydantic_org)
@enforce_types
@trace_method
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
async def create_default_organization_async(self) -> PydanticOrganization:
"""Create the default organization."""
return await self.create_organization_async(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
@enforce_types
@trace_method
async def update_organization_name_using_id_async(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
"""Update an organization."""
with db_registry.session() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
async with db_registry.async_session() as session:
org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
if name:
org.name = name
org.update(session)
await org.update_async(session)
return org.to_pydantic()
@enforce_types
@trace_method
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
async def update_organization_async(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
"""Update an organization."""
with db_registry.session() as session:
org = OrganizationModel.read(db_session=session, identifier=org_id)
async with db_registry.async_session() as session:
org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
if org_update.name:
org.name = org_update.name
if org_update.privileged_tools:
org.privileged_tools = org_update.privileged_tools
org.update(session)
await org.update_async(session)
return org.to_pydantic()
@enforce_types
@@ -87,10 +115,18 @@ class OrganizationManager:
@enforce_types
@trace_method
def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
async def delete_organization_by_id_async(self, org_id: str):
"""Delete an organization by marking it as deleted."""
async with db_registry.async_session() as session:
organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
await organization.hard_delete_async(session)
@enforce_types
@trace_method
async def list_organizations_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
"""List all organizations with optional pagination."""
with db_registry.session() as session:
organizations = OrganizationModel.list(
async with db_registry.async_session() as session:
organizations = await OrganizationModel.list_async(
db_session=session,
after=after,
limit=limit,

View File

@@ -38,6 +38,27 @@ class UserManager:
return user.to_pydantic()
@enforce_types
@trace_method
async def create_default_actor_async(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
"""Create the default user."""
async with db_registry.async_session() as session:
# Make sure the org id exists
try:
await OrganizationModel.read_async(db_session=session, identifier=org_id)
except NoResultFound:
raise ValueError(f"No organization with {org_id} exists in the organization table.")
# Try to retrieve the user
try:
actor = await UserModel.read_async(db_session=session, identifier=self.DEFAULT_USER_ID)
except NoResultFound:
# If it doesn't exist, make it
actor = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
await actor.create_async(session)
return actor.to_pydantic()
@enforce_types
@trace_method
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
@@ -154,8 +175,7 @@ class UserManager:
try:
return await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
except NoResultFound:
# Fall back to synchronous version since create_default_user isn't async yet
return self.create_default_user(org_id=self.DEFAULT_ORG_ID)
return await self.create_default_actor_async(org_id=self.DEFAULT_ORG_ID)
@enforce_types
@trace_method

View File

@@ -105,14 +105,14 @@ async def _clear_tables():
@pytest.fixture
def default_organization(server: SyncServer):
async def default_organization(server: SyncServer):
"""Fixture to create and return the default organization."""
org = server.organization_manager.create_default_organization()
yield org
@pytest.fixture
def other_organization(server: SyncServer):
async def other_organization(server: SyncServer):
"""Fixture to create and return the default organization."""
org = server.organization_manager.create_organization(pydantic_org=Organization(name="letta"))
yield org
@@ -2002,55 +2002,61 @@ async def test_list_source_passages_only(server: SyncServer, default_user, defau
# ======================================================================================================================
# Organization Manager Tests
# ======================================================================================================================
def test_list_organizations(server: SyncServer):
@pytest.mark.asyncio
async def test_list_organizations(server: SyncServer, event_loop):
# Create a new org and confirm that it is created correctly
org_name = "test"
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name))
orgs = server.organization_manager.list_organizations()
orgs = await server.organization_manager.list_organizations_async()
assert len(orgs) == 1
assert orgs[0].name == org_name
# Delete it after
server.organization_manager.delete_organization_by_id(org.id)
assert len(server.organization_manager.list_organizations()) == 0
await server.organization_manager.delete_organization_by_id_async(org.id)
orgs = await server.organization_manager.list_organizations_async()
assert len(orgs) == 0
def test_create_default_organization(server: SyncServer):
server.organization_manager.create_default_organization()
retrieved = server.organization_manager.get_default_organization()
@pytest.mark.asyncio
async def test_create_default_organization(server: SyncServer, event_loop):
await server.organization_manager.create_default_organization_async()
retrieved = await server.organization_manager.get_default_organization_async()
assert retrieved.name == server.organization_manager.DEFAULT_ORG_NAME
def test_update_organization_name(server: SyncServer):
@pytest.mark.asyncio
async def test_update_organization_name(server: SyncServer, event_loop):
org_name_a = "a"
org_name_b = "b"
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name_a))
org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name_a))
assert org.name == org_name_a
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
org = await server.organization_manager.update_organization_name_using_id_async(org_id=org.id, name=org_name_b)
assert org.name == org_name_b
def test_update_organization_privileged_tools(server: SyncServer):
@pytest.mark.asyncio
async def test_update_organization_privileged_tools(server: SyncServer, event_loop):
org_name = "test"
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name))
assert org.privileged_tools == False
org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
org = await server.organization_manager.update_organization_async(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
assert org.privileged_tools == True
def test_list_organizations_pagination(server: SyncServer):
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))
@pytest.mark.asyncio
async def test_list_organizations_pagination(server: SyncServer, event_loop):
await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name="a"))
await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name="b"))
orgs_x = server.organization_manager.list_organizations(limit=1)
orgs_x = await server.organization_manager.list_organizations_async(limit=1)
assert len(orgs_x) == 1
orgs_y = server.organization_manager.list_organizations(after=orgs_x[0].id, limit=1)
orgs_y = await server.organization_manager.list_organizations_async(after=orgs_x[0].id, limit=1)
assert len(orgs_y) == 1
assert orgs_y[0].name != orgs_x[0].name
orgs = server.organization_manager.list_organizations(after=orgs_y[0].id, limit=1)
orgs = await server.organization_manager.list_organizations_async(after=orgs_y[0].id, limit=1)
assert len(orgs) == 0
@@ -2150,7 +2156,7 @@ async def test_passage_cascade_deletion(
@pytest.mark.asyncio
async def test_list_users(server: SyncServer, event_loop):
# Create default organization
org = server.organization_manager.create_default_organization()
org = await server.organization_manager.create_default_organization_async()
user_name = "user"
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id))
@@ -2164,10 +2170,11 @@ async def test_list_users(server: SyncServer, event_loop):
assert len(await server.user_manager.list_actors_async()) == 0
def test_create_default_user(server: SyncServer):
org = server.organization_manager.create_default_organization()
server.user_manager.create_default_user(org_id=org.id)
retrieved = server.user_manager.get_default_user()
@pytest.mark.asyncio
async def test_create_default_user(server: SyncServer, event_loop):
org = await server.organization_manager.create_default_organization_async()
await server.user_manager.create_default_actor_async(org_id=org.id)
retrieved = await server.user_manager.get_default_actor_async()
assert retrieved.name == server.user_manager.DEFAULT_USER_NAME