diff --git a/letta/server/rest_api/routers/v1/organizations.py b/letta/server/rest_api/routers/v1/organizations.py index dec21187..04b69e94 100644 --- a/letta/server/rest_api/routers/v1/organizations.py +++ b/letta/server/rest_api/routers/v1/organizations.py @@ -13,7 +13,7 @@ router = APIRouter(prefix="/orgs", tags=["organization", "admin"]) @router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs") -def get_all_orgs( +async def get_all_orgs( after: Optional[str] = Query(None), limit: Optional[int] = Query(50), server: "SyncServer" = Depends(get_letta_server), @@ -22,7 +22,7 @@ def get_all_orgs( Get a list of all orgs in the database """ try: - orgs = server.organization_manager.list_organizations(after=after, limit=limit) + orgs = await server.organization_manager.list_organizations_async(after=after, limit=limit) except HTTPException: raise except Exception as e: @@ -31,7 +31,7 @@ def get_all_orgs( @router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization") -def create_org( +async def create_org( request: OrganizationCreate = Body(...), server: "SyncServer" = Depends(get_letta_server), ): @@ -39,21 +39,21 @@ def create_org( Create a new org in the database """ org = Organization(**request.model_dump()) - org = server.organization_manager.create_organization(pydantic_org=org) + org = await server.organization_manager.create_organization_async(pydantic_org=org) return org @router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization_by_id") -def delete_org( +async def delete_org( org_id: str = Query(..., description="The org_id key to be deleted."), server: "SyncServer" = Depends(get_letta_server), ): # TODO make a soft deletion, instead of a hard deletion try: - org = server.organization_manager.get_organization_by_id(org_id=org_id) + org = await server.organization_manager.get_organization_by_id_async(org_id=org_id) if org is None: raise HTTPException(status_code=404, detail=f"Organization does not exist") - server.organization_manager.delete_organization_by_id(org_id=org_id) + await server.organization_manager.delete_organization_by_id_async(org_id=org_id) except HTTPException: raise except Exception as e: @@ -62,16 +62,16 @@ def delete_org( @router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization") -def update_org( +async def update_org( org_id: str = Query(..., description="The org_id key to be updated."), request: OrganizationUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), ): try: - org = server.organization_manager.get_organization_by_id(org_id=org_id) + org = await server.organization_manager.get_organization_by_id_async(org_id=org_id) if org is None: raise HTTPException(status_code=404, detail=f"Organization does not exist") - org = server.organization_manager.update_organization(org_id=org_id, name=request.name) + org = await server.organization_manager.update_organization_async(org_id=org_id, name=request.name) except HTTPException: raise except Exception as e: diff --git a/letta/server/server.py b/letta/server/server.py index 3d543fad..1916e6aa 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -49,7 +49,6 @@ from letta.schemas.letta_response import LettaResponse from letta.schemas.llm_config import LLMConfig from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from letta.schemas.message import Message, MessageCreate, MessageUpdate -from letta.schemas.organization import Organization from letta.schemas.passage import Passage, PassageUpdate from letta.schemas.providers import ( AnthropicBedrockProvider, @@ -1424,16 +1423,6 @@ class SyncServer(Server): # Get the current message return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor) - def get_organization_or_default(self, org_id: Optional[str]) -> Organization: - """Get the organization object for org_id if it exists, otherwise return the default organization object""" - if org_id is None: - org_id = self.organization_manager.DEFAULT_ORG_ID - - try: - return self.organization_manager.get_organization_by_id(org_id=org_id) - except NoResultFound: - raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found") - def list_llm_models( self, actor: User, diff --git a/letta/services/organization_manager.py b/letta/services/organization_manager.py index 715f57aa..08f8f70a 100644 --- a/letta/services/organization_manager.py +++ b/letta/services/organization_manager.py @@ -17,9 +17,9 @@ class OrganizationManager: @enforce_types @trace_method - def get_default_organization(self) -> PydanticOrganization: + async def get_default_organization_async(self) -> PydanticOrganization: """Fetch the default organization.""" - return self.get_organization_by_id(self.DEFAULT_ORG_ID) + return await self.get_organization_by_id_async(self.DEFAULT_ORG_ID) @enforce_types @trace_method @@ -31,50 +31,78 @@ class OrganizationManager: @enforce_types @trace_method - def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: - """Create a new organization.""" - try: - org = self.get_organization_by_id(pydantic_org.id) - return org - except NoResultFound: - return self._create_organization(pydantic_org=pydantic_org) + async def get_organization_by_id_async(self, org_id: str) -> Optional[PydanticOrganization]: + """Fetch an organization by ID.""" + async with db_registry.async_session() as session: + organization = await OrganizationModel.read_async(db_session=session, identifier=org_id) + return organization.to_pydantic() @enforce_types @trace_method - def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: + def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: + """Create the default organization.""" with db_registry.session() as session: + try: + organization = OrganizationModel.read(db_session=session, identifier=pydantic_org.id) + return organization.to_pydantic() + except: + organization = OrganizationModel(**pydantic_org.model_dump(to_orm=True)) + organization = organization.create(session) + return organization.to_pydantic() + + @enforce_types + @trace_method + async def create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: + """Create a new organization.""" + try: + org = await self.get_organization_by_id_async(pydantic_org.id) + return org + except NoResultFound: + return await self._create_organization_async(pydantic_org=pydantic_org) + + @enforce_types + @trace_method + async def _create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization: + async with db_registry.async_session() as session: org = OrganizationModel(**pydantic_org.model_dump(to_orm=True)) - org.create(session) + await org.create_async(session) return org.to_pydantic() @enforce_types @trace_method def create_default_organization(self) -> PydanticOrganization: """Create the default organization.""" - return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)) + pydantic_org = PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID) + return self.create_organization(pydantic_org) @enforce_types @trace_method - def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: + async def create_default_organization_async(self) -> PydanticOrganization: + """Create the default organization.""" + return await self.create_organization_async(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)) + + @enforce_types + @trace_method + async def update_organization_name_using_id_async(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization: """Update an organization.""" - with db_registry.session() as session: - org = OrganizationModel.read(db_session=session, identifier=org_id) + async with db_registry.async_session() as session: + org = await OrganizationModel.read_async(db_session=session, identifier=org_id) if name: org.name = name - org.update(session) + await org.update_async(session) return org.to_pydantic() @enforce_types @trace_method - def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization: + async def update_organization_async(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization: """Update an organization.""" - with db_registry.session() as session: - org = OrganizationModel.read(db_session=session, identifier=org_id) + async with db_registry.async_session() as session: + org = await OrganizationModel.read_async(db_session=session, identifier=org_id) if org_update.name: org.name = org_update.name if org_update.privileged_tools: org.privileged_tools = org_update.privileged_tools - org.update(session) + await org.update_async(session) return org.to_pydantic() @enforce_types @@ -87,10 +115,18 @@ class OrganizationManager: @enforce_types @trace_method - def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: + async def delete_organization_by_id_async(self, org_id: str): + """Delete an organization by marking it as deleted.""" + async with db_registry.async_session() as session: + organization = await OrganizationModel.read_async(db_session=session, identifier=org_id) + await organization.hard_delete_async(session) + + @enforce_types + @trace_method + async def list_organizations_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]: """List all organizations with optional pagination.""" - with db_registry.session() as session: - organizations = OrganizationModel.list( + async with db_registry.async_session() as session: + organizations = await OrganizationModel.list_async( db_session=session, after=after, limit=limit, diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 55c493be..9601484e 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -38,6 +38,27 @@ class UserManager: return user.to_pydantic() + @enforce_types + @trace_method + async def create_default_actor_async(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser: + """Create the default user.""" + async with db_registry.async_session() as session: + # Make sure the org id exists + try: + await OrganizationModel.read_async(db_session=session, identifier=org_id) + except NoResultFound: + raise ValueError(f"No organization with {org_id} exists in the organization table.") + + # Try to retrieve the user + try: + actor = await UserModel.read_async(db_session=session, identifier=self.DEFAULT_USER_ID) + except NoResultFound: + # If it doesn't exist, make it + actor = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id) + await actor.create_async(session) + + return actor.to_pydantic() + @enforce_types @trace_method def create_user(self, pydantic_user: PydanticUser) -> PydanticUser: @@ -154,8 +175,7 @@ class UserManager: try: return await self.get_actor_by_id_async(self.DEFAULT_USER_ID) except NoResultFound: - # Fall back to synchronous version since create_default_user isn't async yet - return self.create_default_user(org_id=self.DEFAULT_ORG_ID) + return await self.create_default_actor_async(org_id=self.DEFAULT_ORG_ID) @enforce_types @trace_method diff --git a/tests/test_managers.py b/tests/test_managers.py index c767355c..2449d4cc 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -105,14 +105,14 @@ async def _clear_tables(): @pytest.fixture -def default_organization(server: SyncServer): +async def default_organization(server: SyncServer): """Fixture to create and return the default organization.""" org = server.organization_manager.create_default_organization() yield org @pytest.fixture -def other_organization(server: SyncServer): +async def other_organization(server: SyncServer): """Fixture to create and return the default organization.""" org = server.organization_manager.create_organization(pydantic_org=Organization(name="letta")) yield org @@ -2002,55 +2002,61 @@ async def test_list_source_passages_only(server: SyncServer, default_user, defau # ====================================================================================================================== # Organization Manager Tests # ====================================================================================================================== -def test_list_organizations(server: SyncServer): +@pytest.mark.asyncio +async def test_list_organizations(server: SyncServer, event_loop): # Create a new org and confirm that it is created correctly org_name = "test" - org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name)) + org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name)) - orgs = server.organization_manager.list_organizations() + orgs = await server.organization_manager.list_organizations_async() assert len(orgs) == 1 assert orgs[0].name == org_name # Delete it after - server.organization_manager.delete_organization_by_id(org.id) - assert len(server.organization_manager.list_organizations()) == 0 + await server.organization_manager.delete_organization_by_id_async(org.id) + orgs = await server.organization_manager.list_organizations_async() + assert len(orgs) == 0 -def test_create_default_organization(server: SyncServer): - server.organization_manager.create_default_organization() - retrieved = server.organization_manager.get_default_organization() +@pytest.mark.asyncio +async def test_create_default_organization(server: SyncServer, event_loop): + await server.organization_manager.create_default_organization_async() + retrieved = await server.organization_manager.get_default_organization_async() assert retrieved.name == server.organization_manager.DEFAULT_ORG_NAME -def test_update_organization_name(server: SyncServer): +@pytest.mark.asyncio +async def test_update_organization_name(server: SyncServer, event_loop): org_name_a = "a" org_name_b = "b" - org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name_a)) + org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name_a)) assert org.name == org_name_a - org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b) + org = await server.organization_manager.update_organization_name_using_id_async(org_id=org.id, name=org_name_b) assert org.name == org_name_b -def test_update_organization_privileged_tools(server: SyncServer): +@pytest.mark.asyncio +async def test_update_organization_privileged_tools(server: SyncServer, event_loop): org_name = "test" - org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name)) + org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name)) assert org.privileged_tools == False - org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True)) + org = await server.organization_manager.update_organization_async(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True)) assert org.privileged_tools == True -def test_list_organizations_pagination(server: SyncServer): - server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a")) - server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b")) +@pytest.mark.asyncio +async def test_list_organizations_pagination(server: SyncServer, event_loop): + await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name="a")) + await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name="b")) - orgs_x = server.organization_manager.list_organizations(limit=1) + orgs_x = await server.organization_manager.list_organizations_async(limit=1) assert len(orgs_x) == 1 - orgs_y = server.organization_manager.list_organizations(after=orgs_x[0].id, limit=1) + orgs_y = await server.organization_manager.list_organizations_async(after=orgs_x[0].id, limit=1) assert len(orgs_y) == 1 assert orgs_y[0].name != orgs_x[0].name - orgs = server.organization_manager.list_organizations(after=orgs_y[0].id, limit=1) + orgs = await server.organization_manager.list_organizations_async(after=orgs_y[0].id, limit=1) assert len(orgs) == 0 @@ -2150,7 +2156,7 @@ async def test_passage_cascade_deletion( @pytest.mark.asyncio async def test_list_users(server: SyncServer, event_loop): # Create default organization - org = server.organization_manager.create_default_organization() + org = await server.organization_manager.create_default_organization_async() user_name = "user" user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id)) @@ -2164,10 +2170,11 @@ async def test_list_users(server: SyncServer, event_loop): assert len(await server.user_manager.list_actors_async()) == 0 -def test_create_default_user(server: SyncServer): - org = server.organization_manager.create_default_organization() - server.user_manager.create_default_user(org_id=org.id) - retrieved = server.user_manager.get_default_user() +@pytest.mark.asyncio +async def test_create_default_user(server: SyncServer, event_loop): + org = await server.organization_manager.create_default_organization_async() + await server.user_manager.create_default_actor_async(org_id=org.id) + retrieved = await server.user_manager.get_default_actor_async() assert retrieved.name == server.user_manager.DEFAULT_USER_NAME