feat(asyncify): convert org manager (#2426)
This commit is contained in:
@@ -13,7 +13,7 @@ router = APIRouter(prefix="/orgs", tags=["organization", "admin"])
|
||||
|
||||
|
||||
@router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs")
|
||||
def get_all_orgs(
|
||||
async def get_all_orgs(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -22,7 +22,7 @@ def get_all_orgs(
|
||||
Get a list of all orgs in the database
|
||||
"""
|
||||
try:
|
||||
orgs = server.organization_manager.list_organizations(after=after, limit=limit)
|
||||
orgs = await server.organization_manager.list_organizations_async(after=after, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -31,7 +31,7 @@ def get_all_orgs(
|
||||
|
||||
|
||||
@router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization")
|
||||
def create_org(
|
||||
async def create_org(
|
||||
request: OrganizationCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
@@ -39,21 +39,21 @@ def create_org(
|
||||
Create a new org in the database
|
||||
"""
|
||||
org = Organization(**request.model_dump())
|
||||
org = server.organization_manager.create_organization(pydantic_org=org)
|
||||
org = await server.organization_manager.create_organization_async(pydantic_org=org)
|
||||
return org
|
||||
|
||||
|
||||
@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization_by_id")
|
||||
def delete_org(
|
||||
async def delete_org(
|
||||
org_id: str = Query(..., description="The org_id key to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
org = server.organization_manager.get_organization_by_id(org_id=org_id)
|
||||
org = await server.organization_manager.get_organization_by_id_async(org_id=org_id)
|
||||
if org is None:
|
||||
raise HTTPException(status_code=404, detail=f"Organization does not exist")
|
||||
server.organization_manager.delete_organization_by_id(org_id=org_id)
|
||||
await server.organization_manager.delete_organization_by_id_async(org_id=org_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -62,16 +62,16 @@ def delete_org(
|
||||
|
||||
|
||||
@router.patch("/", tags=["admin"], response_model=Organization, operation_id="update_organization")
|
||||
def update_org(
|
||||
async def update_org(
|
||||
org_id: str = Query(..., description="The org_id key to be updated."),
|
||||
request: OrganizationUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
try:
|
||||
org = server.organization_manager.get_organization_by_id(org_id=org_id)
|
||||
org = await server.organization_manager.get_organization_by_id_async(org_id=org_id)
|
||||
if org is None:
|
||||
raise HTTPException(status_code=404, detail=f"Organization does not exist")
|
||||
org = server.organization_manager.update_organization(org_id=org_id, name=request.name)
|
||||
org = await server.organization_manager.update_organization_async(org_id=org_id, name=request.name)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -49,7 +49,6 @@ from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
from letta.schemas.providers import (
|
||||
AnthropicBedrockProvider,
|
||||
@@ -1424,16 +1423,6 @@ class SyncServer(Server):
|
||||
# Get the current message
|
||||
return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
|
||||
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
|
||||
"""Get the organization object for org_id if it exists, otherwise return the default organization object"""
|
||||
if org_id is None:
|
||||
org_id = self.organization_manager.DEFAULT_ORG_ID
|
||||
|
||||
try:
|
||||
return self.organization_manager.get_organization_by_id(org_id=org_id)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
|
||||
|
||||
def list_llm_models(
|
||||
self,
|
||||
actor: User,
|
||||
|
||||
@@ -17,9 +17,9 @@ class OrganizationManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def get_default_organization(self) -> PydanticOrganization:
|
||||
async def get_default_organization_async(self) -> PydanticOrganization:
|
||||
"""Fetch the default organization."""
|
||||
return self.get_organization_by_id(self.DEFAULT_ORG_ID)
|
||||
return await self.get_organization_by_id_async(self.DEFAULT_ORG_ID)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
@@ -31,50 +31,78 @@ class OrganizationManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
org = self.get_organization_by_id(pydantic_org.id)
|
||||
return org
|
||||
except NoResultFound:
|
||||
return self._create_organization(pydantic_org=pydantic_org)
|
||||
async def get_organization_by_id_async(self, org_id: str) -> Optional[PydanticOrganization]:
|
||||
"""Fetch an organization by ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
||||
return organization.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def _create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
def create_organization(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
"""Create the default organization."""
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
organization = OrganizationModel.read(db_session=session, identifier=pydantic_org.id)
|
||||
return organization.to_pydantic()
|
||||
except:
|
||||
organization = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
|
||||
organization = organization.create(session)
|
||||
return organization.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
"""Create a new organization."""
|
||||
try:
|
||||
org = await self.get_organization_by_id_async(pydantic_org.id)
|
||||
return org
|
||||
except NoResultFound:
|
||||
return await self._create_organization_async(pydantic_org=pydantic_org)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def _create_organization_async(self, pydantic_org: PydanticOrganization) -> PydanticOrganization:
|
||||
async with db_registry.async_session() as session:
|
||||
org = OrganizationModel(**pydantic_org.model_dump(to_orm=True))
|
||||
org.create(session)
|
||||
await org.create_async(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_default_organization(self) -> PydanticOrganization:
|
||||
"""Create the default organization."""
|
||||
return self.create_organization(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
|
||||
pydantic_org = PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID)
|
||||
return self.create_organization(pydantic_org)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_organization_name_using_id(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
|
||||
async def create_default_organization_async(self) -> PydanticOrganization:
|
||||
"""Create the default organization."""
|
||||
return await self.create_organization_async(PydanticOrganization(name=self.DEFAULT_ORG_NAME, id=self.DEFAULT_ORG_ID))
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_organization_name_using_id_async(self, org_id: str, name: Optional[str] = None) -> PydanticOrganization:
|
||||
"""Update an organization."""
|
||||
with db_registry.session() as session:
|
||||
org = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
async with db_registry.async_session() as session:
|
||||
org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
||||
if name:
|
||||
org.name = name
|
||||
org.update(session)
|
||||
await org.update_async(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_organization(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
|
||||
async def update_organization_async(self, org_id: str, org_update: OrganizationUpdate) -> PydanticOrganization:
|
||||
"""Update an organization."""
|
||||
with db_registry.session() as session:
|
||||
org = OrganizationModel.read(db_session=session, identifier=org_id)
|
||||
async with db_registry.async_session() as session:
|
||||
org = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
||||
if org_update.name:
|
||||
org.name = org_update.name
|
||||
if org_update.privileged_tools:
|
||||
org.privileged_tools = org_update.privileged_tools
|
||||
org.update(session)
|
||||
await org.update_async(session)
|
||||
return org.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@@ -87,10 +115,18 @@ class OrganizationManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def list_organizations(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
|
||||
async def delete_organization_by_id_async(self, org_id: str):
|
||||
"""Delete an organization by marking it as deleted."""
|
||||
async with db_registry.async_session() as session:
|
||||
organization = await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
||||
await organization.hard_delete_async(session)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def list_organizations_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticOrganization]:
|
||||
"""List all organizations with optional pagination."""
|
||||
with db_registry.session() as session:
|
||||
organizations = OrganizationModel.list(
|
||||
async with db_registry.async_session() as session:
|
||||
organizations = await OrganizationModel.list_async(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
|
||||
@@ -38,6 +38,27 @@ class UserManager:
|
||||
|
||||
return user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def create_default_actor_async(self, org_id: str = OrganizationManager.DEFAULT_ORG_ID) -> PydanticUser:
|
||||
"""Create the default user."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Make sure the org id exists
|
||||
try:
|
||||
await OrganizationModel.read_async(db_session=session, identifier=org_id)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"No organization with {org_id} exists in the organization table.")
|
||||
|
||||
# Try to retrieve the user
|
||||
try:
|
||||
actor = await UserModel.read_async(db_session=session, identifier=self.DEFAULT_USER_ID)
|
||||
except NoResultFound:
|
||||
# If it doesn't exist, make it
|
||||
actor = UserModel(id=self.DEFAULT_USER_ID, name=self.DEFAULT_USER_NAME, organization_id=org_id)
|
||||
await actor.create_async(session)
|
||||
|
||||
return actor.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_user(self, pydantic_user: PydanticUser) -> PydanticUser:
|
||||
@@ -154,8 +175,7 @@ class UserManager:
|
||||
try:
|
||||
return await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
|
||||
except NoResultFound:
|
||||
# Fall back to synchronous version since create_default_user isn't async yet
|
||||
return self.create_default_user(org_id=self.DEFAULT_ORG_ID)
|
||||
return await self.create_default_actor_async(org_id=self.DEFAULT_ORG_ID)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
|
||||
@@ -105,14 +105,14 @@ async def _clear_tables():
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def default_organization(server: SyncServer):
|
||||
async def default_organization(server: SyncServer):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_default_organization()
|
||||
yield org
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_organization(server: SyncServer):
|
||||
async def other_organization(server: SyncServer):
|
||||
"""Fixture to create and return the default organization."""
|
||||
org = server.organization_manager.create_organization(pydantic_org=Organization(name="letta"))
|
||||
yield org
|
||||
@@ -2002,55 +2002,61 @@ async def test_list_source_passages_only(server: SyncServer, default_user, defau
|
||||
# ======================================================================================================================
|
||||
# Organization Manager Tests
|
||||
# ======================================================================================================================
|
||||
def test_list_organizations(server: SyncServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_organizations(server: SyncServer, event_loop):
|
||||
# Create a new org and confirm that it is created correctly
|
||||
org_name = "test"
|
||||
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
|
||||
org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name))
|
||||
|
||||
orgs = server.organization_manager.list_organizations()
|
||||
orgs = await server.organization_manager.list_organizations_async()
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == org_name
|
||||
|
||||
# Delete it after
|
||||
server.organization_manager.delete_organization_by_id(org.id)
|
||||
assert len(server.organization_manager.list_organizations()) == 0
|
||||
await server.organization_manager.delete_organization_by_id_async(org.id)
|
||||
orgs = await server.organization_manager.list_organizations_async()
|
||||
assert len(orgs) == 0
|
||||
|
||||
|
||||
def test_create_default_organization(server: SyncServer):
|
||||
server.organization_manager.create_default_organization()
|
||||
retrieved = server.organization_manager.get_default_organization()
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_organization(server: SyncServer, event_loop):
|
||||
await server.organization_manager.create_default_organization_async()
|
||||
retrieved = await server.organization_manager.get_default_organization_async()
|
||||
assert retrieved.name == server.organization_manager.DEFAULT_ORG_NAME
|
||||
|
||||
|
||||
def test_update_organization_name(server: SyncServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_name(server: SyncServer, event_loop):
|
||||
org_name_a = "a"
|
||||
org_name_b = "b"
|
||||
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name_a))
|
||||
org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name_a))
|
||||
assert org.name == org_name_a
|
||||
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
|
||||
org = await server.organization_manager.update_organization_name_using_id_async(org_id=org.id, name=org_name_b)
|
||||
assert org.name == org_name_b
|
||||
|
||||
|
||||
def test_update_organization_privileged_tools(server: SyncServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_organization_privileged_tools(server: SyncServer, event_loop):
|
||||
org_name = "test"
|
||||
org = server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name=org_name))
|
||||
org = await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name=org_name))
|
||||
assert org.privileged_tools == False
|
||||
org = server.organization_manager.update_organization(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
|
||||
org = await server.organization_manager.update_organization_async(org_id=org.id, org_update=OrganizationUpdate(privileged_tools=True))
|
||||
assert org.privileged_tools == True
|
||||
|
||||
|
||||
def test_list_organizations_pagination(server: SyncServer):
|
||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="a"))
|
||||
server.organization_manager.create_organization(pydantic_org=PydanticOrganization(name="b"))
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_organizations_pagination(server: SyncServer, event_loop):
|
||||
await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name="a"))
|
||||
await server.organization_manager.create_organization_async(pydantic_org=PydanticOrganization(name="b"))
|
||||
|
||||
orgs_x = server.organization_manager.list_organizations(limit=1)
|
||||
orgs_x = await server.organization_manager.list_organizations_async(limit=1)
|
||||
assert len(orgs_x) == 1
|
||||
|
||||
orgs_y = server.organization_manager.list_organizations(after=orgs_x[0].id, limit=1)
|
||||
orgs_y = await server.organization_manager.list_organizations_async(after=orgs_x[0].id, limit=1)
|
||||
assert len(orgs_y) == 1
|
||||
assert orgs_y[0].name != orgs_x[0].name
|
||||
|
||||
orgs = server.organization_manager.list_organizations(after=orgs_y[0].id, limit=1)
|
||||
orgs = await server.organization_manager.list_organizations_async(after=orgs_y[0].id, limit=1)
|
||||
assert len(orgs) == 0
|
||||
|
||||
|
||||
@@ -2150,7 +2156,7 @@ async def test_passage_cascade_deletion(
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(server: SyncServer, event_loop):
|
||||
# Create default organization
|
||||
org = server.organization_manager.create_default_organization()
|
||||
org = await server.organization_manager.create_default_organization_async()
|
||||
|
||||
user_name = "user"
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id))
|
||||
@@ -2164,10 +2170,11 @@ async def test_list_users(server: SyncServer, event_loop):
|
||||
assert len(await server.user_manager.list_actors_async()) == 0
|
||||
|
||||
|
||||
def test_create_default_user(server: SyncServer):
|
||||
org = server.organization_manager.create_default_organization()
|
||||
server.user_manager.create_default_user(org_id=org.id)
|
||||
retrieved = server.user_manager.get_default_user()
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_default_user(server: SyncServer, event_loop):
|
||||
org = await server.organization_manager.create_default_organization_async()
|
||||
await server.user_manager.create_default_actor_async(org_id=org.id)
|
||||
retrieved = await server.user_manager.get_default_actor_async()
|
||||
assert retrieved.name == server.user_manager.DEFAULT_USER_NAME
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user