feat: add organization endpoints and schemas (#1762)

This commit is contained in:
Sarah Wooders
2024-09-20 11:34:57 -07:00
committed by GitHub
parent da9bbeada0
commit a9d8445e4a
9 changed files with 220 additions and 6 deletions

View File

@@ -6,6 +6,7 @@ from requests import HTTPError
from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema
from memgpt.schemas.api_key import APIKey, APIKeyCreate
from memgpt.schemas.organization import Organization, OrganizationCreate
from memgpt.schemas.user import User, UserCreate
@@ -59,8 +60,8 @@ class Admin:
raise HTTPError(response.json())
return APIKey(**response.json())
def create_user(self, name: Optional[str] = None) -> User:
request = UserCreate(name=name)
def create_user(self, name: Optional[str] = None, org_id: Optional[str] = None) -> User:
request = UserCreate(name=name, org_id=org_id)
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump())
if response.status_code != 200:
raise HTTPError(response.json())
@@ -74,6 +75,32 @@ class Admin:
raise HTTPError(response.json())
return User(**response.json())
def create_organization(self, name: Optional[str] = None) -> Organization:
request = OrganizationCreate(name=name)
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/orgs", headers=self.headers, json=request.model_dump())
if response.status_code != 200:
raise HTTPError(response.json())
response_json = response.json()
return Organization(**response_json)
def delete_organization(self, org_id: str) -> Organization:
params = {"org_id": str(org_id)}
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return Organization(**response.json())
def get_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
params = {}
if cursor:
params["cursor"] = str(cursor)
if limit:
params["limit"] = limit
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return [Organization(**org) for org in response.json()]
def _reset_server(self):
# DANGER: this will delete all users and keys
# clear all state associated with users

View File

@@ -29,6 +29,7 @@ from memgpt.schemas.job import Job
from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import Memory
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from memgpt.schemas.organization import Organization
from memgpt.schemas.source import Source
from memgpt.schemas.tool import Tool
from memgpt.schemas.user import User
@@ -121,6 +122,7 @@ class UserModel(Base):
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
org_id = Column(String)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True))
@@ -131,7 +133,22 @@ class UserModel(Base):
return f"<User(id='{self.id}' name='{self.name}')>"
def to_record(self) -> User:
return User(id=self.id, name=self.name, created_at=self.created_at)
return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id)
class OrganizationModel(Base):
__tablename__ = "organizations"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True))
def __repr__(self) -> str:
return f"<Organization(id='{self.id}' name='{self.name}')>"
def to_record(self) -> Organization:
return Organization(id=self.id, name=self.name, created_at=self.created_at)
class APIKeyModel(Base):
@@ -515,6 +532,14 @@ class MetadataStore:
session.add(UserModel(**vars(user)))
session.commit()
@enforce_types
def create_organization(self, organization: Organization):
with self.session_maker() as session:
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
raise ValueError(f"Organization with id {organization.id} already exists")
session.add(OrganizationModel(**vars(organization)))
session.commit()
@enforce_types
def create_block(self, block: Block):
with self.session_maker() as session:
@@ -638,6 +663,16 @@ class MetadataStore:
session.commit()
@enforce_types
def delete_organization(self, org_id: str):
with self.session_maker() as session:
# delete from organizations table
session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete()
# TODO: delete associated data
session.commit()
@enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
@@ -685,6 +720,30 @@ class MetadataStore:
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def get_organization(self, org_id: str) -> Optional[Organization]:
with self.session_maker() as session:
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
if cursor:
query = query.filter(OrganizationModel.id < cursor)
results = query.limit(limit).all()
if not results:
return None, []
organization_records = [r.to_record() for r in results]
next_cursor = organization_records[-1].id
assert isinstance(next_cursor, str)
return next_cursor, organization_records
@enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:

View File

@@ -0,0 +1,20 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class OrganizationBase(MemGPTBase):
__id_prefix__ = "org"
class Organization(OrganizationBase):
id: str = OrganizationBase.generate_id_field()
name: str = Field(..., description="The name of the organization.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
class OrganizationCreate(OrganizationBase):
name: Optional[str] = Field(None, description="The name of the organization.")

View File

@@ -21,9 +21,13 @@ class User(UserBase):
"""
id: str = UserBase.generate_id_field()
org_id: Optional[str] = Field(
..., description="The organization id of the user"
) # TODO: dont make optional, and pass in default org ID
name: str = Field(..., description="The name of the user.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
class UserCreate(UserBase):
name: Optional[str] = Field(None, description="The name of the user.")
org_id: Optional[str] = Field(None, description="The organization id of the user.")

View File

@@ -29,6 +29,9 @@ from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions imp
# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
from memgpt.server.rest_api.routers.v1.organizations import (
router as organizations_router,
)
from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin
)
@@ -103,6 +106,7 @@ def create_application() -> "FastAPI":
# admin/users
app.include_router(users_router, prefix=ADMIN_PREFIX)
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
# openai
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)

View File

@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from memgpt.schemas.organization import Organization, OrganizationCreate
from memgpt.server.rest_api.utils import get_memgpt_server
if TYPE_CHECKING:
from memgpt.server.server import SyncServer
router = APIRouter(prefix="/orgs", tags=["organization", "admin"])
@router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs")
def get_all_orgs(
cursor: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_memgpt_server),
):
"""
Get a list of all orgs in the database
"""
try:
next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return orgs
@router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization")
def create_org(
request: OrganizationCreate = Body(...),
server: "SyncServer" = Depends(get_memgpt_server),
):
"""
Create a new org in the database
"""
org = server.create_organization(request)
return org
@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization")
def delete_org(
org_id: str = Query(..., description="The org_id key to be deleted."),
server: "SyncServer" = Depends(get_memgpt_server),
):
# TODO make a soft deletion, instead of a hard deletion
try:
org = server.ms.get_organization(org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail=f"Organization does not exist")
server.ms.delete_organization(org_id=org_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return org

View File

@@ -58,6 +58,7 @@ from memgpt.schemas.memgpt_message import MemGPTMessage
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from memgpt.schemas.message import Message, UpdateMessage
from memgpt.schemas.openai.chat_completion_response import UsageStatistics
from memgpt.schemas.organization import Organization, OrganizationCreate
from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
@@ -146,6 +147,7 @@ from memgpt.metadata import (
APIKeyModel,
BlockModel,
JobModel,
OrganizationModel,
SourceModel,
ToolModel,
UserModel,
@@ -177,6 +179,7 @@ Base.metadata.create_all(
JobModel.__table__,
PassageModel.__table__,
MessageModel.__table__,
OrganizationModel.__table__,
],
)
@@ -689,17 +692,32 @@ class SyncServer(Server):
if not request.name:
# auto-generate a name
request.name = create_random_username()
user = User(name=request.name)
user = User(name=request.name, org_id=request.org_id)
self.ms.create_user(user)
logger.info(f"Created new user from config: {user}")
# add default for the user
# TODO: move to org
assert user.id is not None, f"User id is None: {user}"
self.add_default_blocks(user.id)
self.add_default_tools(module_name="base", user_id=user.id)
return user
def create_organization(self, request: OrganizationCreate) -> Organization:
"""Create a new org using a config"""
if not request.name:
# auto-generate a name
request.name = create_random_username()
org = Organization(name=request.name)
self.ms.create_organization(org)
logger.info(f"Created new org from config: {org}")
# add default for the org
# TODO: add default data
return org
def create_agent(
self,
request: CreateAgent,

View File

@@ -38,11 +38,26 @@ def admin_client():
yield admin
def test_admin_client(admin_client):
@pytest.fixture(scope="module")
def organization(admin_client):
# create an organization
org_name = "test_org"
org = admin_client.create_organization(org_name)
assert org_name == org.name, f"Expected {org_name}, got {org.name}"
# test listing
orgs = admin_client.get_organizations()
assert len(orgs) > 0, f"Expected 1 org, got {orgs}"
yield org
admin_client.delete_organization(org.id)
def test_admin_client(admin_client, organization):
# create a user
user_name = "test_user"
user1 = admin_client.create_user(user_name)
user1 = admin_client.create_user(user_name, organization.id)
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
# create another user

View File

@@ -410,3 +410,9 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
new_text = "This exact string would never show up in the message???"
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
assert new_message.text == new_text
def test_organization(client: RESTClient):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
client.base_url