From a9d8445e4aba0bd2c2d547aeb543141d3d553f41 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 20 Sep 2024 11:34:57 -0700 Subject: [PATCH] feat: add organization endpoints and schemas (#1762) --- memgpt/client/admin.py | 31 +++++++++- memgpt/metadata.py | 61 ++++++++++++++++++- memgpt/schemas/organization.py | 20 ++++++ memgpt/schemas/user.py | 4 ++ memgpt/server/rest_api/app.py | 4 ++ .../rest_api/routers/v1/organizations.py | 61 +++++++++++++++++++ memgpt/server/server.py | 20 +++++- tests/test_admin_client.py | 19 +++++- tests/test_client.py | 6 ++ 9 files changed, 220 insertions(+), 6 deletions(-) create mode 100644 memgpt/schemas/organization.py create mode 100644 memgpt/server/rest_api/routers/v1/organizations.py diff --git a/memgpt/client/admin.py b/memgpt/client/admin.py index cb59063a..f1f49f03 100644 --- a/memgpt/client/admin.py +++ b/memgpt/client/admin.py @@ -6,6 +6,7 @@ from requests import HTTPError from memgpt.functions.functions import parse_source_code from memgpt.functions.schema_generator import generate_schema from memgpt.schemas.api_key import APIKey, APIKeyCreate +from memgpt.schemas.organization import Organization, OrganizationCreate from memgpt.schemas.user import User, UserCreate @@ -59,8 +60,8 @@ class Admin: raise HTTPError(response.json()) return APIKey(**response.json()) - def create_user(self, name: Optional[str] = None) -> User: - request = UserCreate(name=name) + def create_user(self, name: Optional[str] = None, org_id: Optional[str] = None) -> User: + request = UserCreate(name=name, org_id=org_id) response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump()) if response.status_code != 200: raise HTTPError(response.json()) @@ -74,6 +75,32 @@ class Admin: raise HTTPError(response.json()) return User(**response.json()) + def create_organization(self, name: Optional[str] = None) -> Organization: + request = OrganizationCreate(name=name) + response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/orgs", headers=self.headers, json=request.model_dump()) + if response.status_code != 200: + raise HTTPError(response.json()) + response_json = response.json() + return Organization(**response_json) + + def delete_organization(self, org_id: str) -> Organization: + params = {"org_id": str(org_id)} + response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers) + if response.status_code != 200: + raise HTTPError(response.json()) + return Organization(**response.json()) + + def get_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]: + params = {} + if cursor: + params["cursor"] = str(cursor) + if limit: + params["limit"] = limit + response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers) + if response.status_code != 200: + raise HTTPError(response.json()) + return [Organization(**org) for org in response.json()] + def _reset_server(self): # DANGER: this will delete all users and keys # clear all state associated with users diff --git a/memgpt/metadata.py b/memgpt/metadata.py index 1eafd0fe..2496a9f8 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -29,6 +29,7 @@ from memgpt.schemas.job import Job from memgpt.schemas.llm_config import LLMConfig from memgpt.schemas.memory import Memory from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction +from memgpt.schemas.organization import Organization from memgpt.schemas.source import Source from memgpt.schemas.tool import Tool from memgpt.schemas.user import User @@ -121,6 +122,7 @@ class UserModel(Base): __table_args__ = {"extend_existing": True} id = Column(String, primary_key=True) + org_id = Column(String) name = Column(String, nullable=False) created_at = Column(DateTime(timezone=True)) @@ -131,7 +133,22 @@ class UserModel(Base): return f"" def to_record(self) -> User: - return User(id=self.id, name=self.name, created_at=self.created_at) + return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id) + + +class OrganizationModel(Base): + __tablename__ = "organizations" + __table_args__ = {"extend_existing": True} + + id = Column(String, primary_key=True) + name = Column(String, nullable=False) + created_at = Column(DateTime(timezone=True)) + + def __repr__(self) -> str: + return f"" + + def to_record(self) -> Organization: + return Organization(id=self.id, name=self.name, created_at=self.created_at) class APIKeyModel(Base): @@ -515,6 +532,14 @@ class MetadataStore: session.add(UserModel(**vars(user))) session.commit() + @enforce_types + def create_organization(self, organization: Organization): + with self.session_maker() as session: + if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0: + raise ValueError(f"Organization with id {organization.id} already exists") + session.add(OrganizationModel(**vars(organization))) + session.commit() + @enforce_types def create_block(self, block: Block): with self.session_maker() as session: @@ -638,6 +663,16 @@ class MetadataStore: session.commit() + @enforce_types + def delete_organization(self, org_id: str): + with self.session_maker() as session: + # delete from organizations table + session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete() + + # TODO: delete associated data + + session.commit() + @enforce_types # def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]: @@ -685,6 +720,30 @@ class MetadataStore: assert len(results) == 1, f"Expected 1 result, got {len(results)}" return results[0].to_record() + @enforce_types + def get_organization(self, org_id: str) -> Optional[Organization]: + with self.session_maker() as session: + results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all() + if len(results) == 0: + return None + assert len(results) == 1, f"Expected 1 result, got {len(results)}" + return results[0].to_record() + + @enforce_types + def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50): + with self.session_maker() as session: + query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id)) + if cursor: + query = query.filter(OrganizationModel.id < cursor) + results = query.limit(limit).all() + if not results: + return None, [] + organization_records = [r.to_record() for r in results] + next_cursor = organization_records[-1].id + assert isinstance(next_cursor, str) + + return next_cursor, organization_records + @enforce_types def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50): with self.session_maker() as session: diff --git a/memgpt/schemas/organization.py b/memgpt/schemas/organization.py new file mode 100644 index 00000000..c9313107 --- /dev/null +++ b/memgpt/schemas/organization.py @@ -0,0 +1,20 @@ +from datetime import datetime +from typing import Optional + +from pydantic import Field + +from memgpt.schemas.memgpt_base import MemGPTBase + + +class OrganizationBase(MemGPTBase): + __id_prefix__ = "org" + + +class Organization(OrganizationBase): + id: str = OrganizationBase.generate_id_field() + name: str = Field(..., description="The name of the organization.") + created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.") + + +class OrganizationCreate(OrganizationBase): + name: Optional[str] = Field(None, description="The name of the organization.") diff --git a/memgpt/schemas/user.py b/memgpt/schemas/user.py index cac5b5da..d2b5007d 100644 --- a/memgpt/schemas/user.py +++ b/memgpt/schemas/user.py @@ -21,9 +21,13 @@ class User(UserBase): """ id: str = UserBase.generate_id_field() + org_id: Optional[str] = Field( + ..., description="The organization id of the user" + ) # TODO: dont make optional, and pass in default org ID name: str = Field(..., description="The name of the user.") created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.") class UserCreate(UserBase): name: Optional[str] = Field(None, description="The name of the user.") + org_id: Optional[str] = Field(None, description="The organization id of the user.") diff --git a/memgpt/server/rest_api/app.py b/memgpt/server/rest_api/app.py index f53ad27a..7e977119 100644 --- a/memgpt/server/rest_api/app.py +++ b/memgpt/server/rest_api/app.py @@ -29,6 +29,9 @@ from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions imp # from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes +from memgpt.server.rest_api.routers.v1.organizations import ( + router as organizations_router, +) from memgpt.server.rest_api.routers.v1.users import ( router as users_router, # TODO: decide on admin ) @@ -103,6 +106,7 @@ def create_application() -> "FastAPI": # admin/users app.include_router(users_router, prefix=ADMIN_PREFIX) + app.include_router(organizations_router, prefix=ADMIN_PREFIX) # openai app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX) diff --git a/memgpt/server/rest_api/routers/v1/organizations.py b/memgpt/server/rest_api/routers/v1/organizations.py new file mode 100644 index 00000000..2b0873c2 --- /dev/null +++ b/memgpt/server/rest_api/routers/v1/organizations.py @@ -0,0 +1,61 @@ +from typing import TYPE_CHECKING, List, Optional + +from fastapi import APIRouter, Body, Depends, HTTPException, Query + +from memgpt.schemas.organization import Organization, OrganizationCreate +from memgpt.server.rest_api.utils import get_memgpt_server + +if TYPE_CHECKING: + from memgpt.server.server import SyncServer + + +router = APIRouter(prefix="/orgs", tags=["organization", "admin"]) + + +@router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs") +def get_all_orgs( + cursor: Optional[str] = Query(None), + limit: Optional[int] = Query(50), + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Get a list of all orgs in the database + """ + try: + next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + return orgs + + +@router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization") +def create_org( + request: OrganizationCreate = Body(...), + server: "SyncServer" = Depends(get_memgpt_server), +): + """ + Create a new org in the database + """ + + org = server.create_organization(request) + return org + + +@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization") +def delete_org( + org_id: str = Query(..., description="The org_id key to be deleted."), + server: "SyncServer" = Depends(get_memgpt_server), +): + # TODO make a soft deletion, instead of a hard deletion + try: + org = server.ms.get_organization(org_id=org_id) + if org is None: + raise HTTPException(status_code=404, detail=f"Organization does not exist") + server.ms.delete_organization(org_id=org_id) + except HTTPException: + raise + except Exception as e: + raise HTTPException(status_code=500, detail=f"{e}") + return org diff --git a/memgpt/server/server.py b/memgpt/server/server.py index e6747b48..b9658fbf 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -58,6 +58,7 @@ from memgpt.schemas.memgpt_message import MemGPTMessage from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from memgpt.schemas.message import Message, UpdateMessage from memgpt.schemas.openai.chat_completion_response import UsageStatistics +from memgpt.schemas.organization import Organization, OrganizationCreate from memgpt.schemas.passage import Passage from memgpt.schemas.source import Source, SourceCreate, SourceUpdate from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate @@ -146,6 +147,7 @@ from memgpt.metadata import ( APIKeyModel, BlockModel, JobModel, + OrganizationModel, SourceModel, ToolModel, UserModel, @@ -177,6 +179,7 @@ Base.metadata.create_all( JobModel.__table__, PassageModel.__table__, MessageModel.__table__, + OrganizationModel.__table__, ], ) @@ -689,17 +692,32 @@ class SyncServer(Server): if not request.name: # auto-generate a name request.name = create_random_username() - user = User(name=request.name) + user = User(name=request.name, org_id=request.org_id) self.ms.create_user(user) logger.info(f"Created new user from config: {user}") # add default for the user + # TODO: move to org assert user.id is not None, f"User id is None: {user}" self.add_default_blocks(user.id) self.add_default_tools(module_name="base", user_id=user.id) return user + def create_organization(self, request: OrganizationCreate) -> Organization: + """Create a new org using a config""" + if not request.name: + # auto-generate a name + request.name = create_random_username() + org = Organization(name=request.name) + self.ms.create_organization(org) + logger.info(f"Created new org from config: {org}") + + # add default for the org + # TODO: add default data + + return org + def create_agent( self, request: CreateAgent, diff --git a/tests/test_admin_client.py b/tests/test_admin_client.py index 8df80f61..54c6feae 100644 --- a/tests/test_admin_client.py +++ b/tests/test_admin_client.py @@ -38,11 +38,26 @@ def admin_client(): yield admin -def test_admin_client(admin_client): +@pytest.fixture(scope="module") +def organization(admin_client): + # create an organization + org_name = "test_org" + org = admin_client.create_organization(org_name) + assert org_name == org.name, f"Expected {org_name}, got {org.name}" + + # test listing + orgs = admin_client.get_organizations() + assert len(orgs) > 0, f"Expected 1 org, got {orgs}" + + yield org + admin_client.delete_organization(org.id) + + +def test_admin_client(admin_client, organization): # create a user user_name = "test_user" - user1 = admin_client.create_user(user_name) + user1 = admin_client.create_user(user_name, organization.id) assert user_name == user1.name, f"Expected {user_name}, got {user1.name}" # create another user diff --git a/tests/test_client.py b/tests/test_client.py index dbf24429..b8ae1abe 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -410,3 +410,9 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat new_text = "This exact string would never show up in the message???" new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id) assert new_message.text == new_text + + +def test_organization(client: RESTClient): + if isinstance(client, LocalClient): + pytest.skip("Skipping test_organization because LocalClient does not support organizations") + client.base_url