feat: add organization endpoints and schemas (#1762)
This commit is contained in:
@@ -6,6 +6,7 @@ from requests import HTTPError
|
||||
from memgpt.functions.functions import parse_source_code
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
from memgpt.schemas.api_key import APIKey, APIKeyCreate
|
||||
from memgpt.schemas.organization import Organization, OrganizationCreate
|
||||
from memgpt.schemas.user import User, UserCreate
|
||||
|
||||
|
||||
@@ -59,8 +60,8 @@ class Admin:
|
||||
raise HTTPError(response.json())
|
||||
return APIKey(**response.json())
|
||||
|
||||
def create_user(self, name: Optional[str] = None) -> User:
|
||||
request = UserCreate(name=name)
|
||||
def create_user(self, name: Optional[str] = None, org_id: Optional[str] = None) -> User:
|
||||
request = UserCreate(name=name, org_id=org_id)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump())
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
@@ -74,6 +75,32 @@ class Admin:
|
||||
raise HTTPError(response.json())
|
||||
return User(**response.json())
|
||||
|
||||
def create_organization(self, name: Optional[str] = None) -> Organization:
|
||||
request = OrganizationCreate(name=name)
|
||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/orgs", headers=self.headers, json=request.model_dump())
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
response_json = response.json()
|
||||
return Organization(**response_json)
|
||||
|
||||
def delete_organization(self, org_id: str) -> Organization:
|
||||
params = {"org_id": str(org_id)}
|
||||
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return Organization(**response.json())
|
||||
|
||||
def get_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
params = {}
|
||||
if cursor:
|
||||
params["cursor"] = str(cursor)
|
||||
if limit:
|
||||
params["limit"] = limit
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return [Organization(**org) for org in response.json()]
|
||||
|
||||
def _reset_server(self):
|
||||
# DANGER: this will delete all users and keys
|
||||
# clear all state associated with users
|
||||
|
||||
@@ -29,6 +29,7 @@ from memgpt.schemas.job import Job
|
||||
from memgpt.schemas.llm_config import LLMConfig
|
||||
from memgpt.schemas.memory import Memory
|
||||
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||
from memgpt.schemas.organization import Organization
|
||||
from memgpt.schemas.source import Source
|
||||
from memgpt.schemas.tool import Tool
|
||||
from memgpt.schemas.user import User
|
||||
@@ -121,6 +122,7 @@ class UserModel(Base):
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
org_id = Column(String)
|
||||
name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
@@ -131,7 +133,22 @@ class UserModel(Base):
|
||||
return f"<User(id='{self.id}' name='{self.name}')>"
|
||||
|
||||
def to_record(self) -> User:
|
||||
return User(id=self.id, name=self.name, created_at=self.created_at)
|
||||
return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id)
|
||||
|
||||
|
||||
class OrganizationModel(Base):
|
||||
__tablename__ = "organizations"
|
||||
__table_args__ = {"extend_existing": True}
|
||||
|
||||
id = Column(String, primary_key=True)
|
||||
name = Column(String, nullable=False)
|
||||
created_at = Column(DateTime(timezone=True))
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"<Organization(id='{self.id}' name='{self.name}')>"
|
||||
|
||||
def to_record(self) -> Organization:
|
||||
return Organization(id=self.id, name=self.name, created_at=self.created_at)
|
||||
|
||||
|
||||
class APIKeyModel(Base):
|
||||
@@ -515,6 +532,14 @@ class MetadataStore:
|
||||
session.add(UserModel(**vars(user)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_organization(self, organization: Organization):
|
||||
with self.session_maker() as session:
|
||||
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
|
||||
raise ValueError(f"Organization with id {organization.id} already exists")
|
||||
session.add(OrganizationModel(**vars(organization)))
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def create_block(self, block: Block):
|
||||
with self.session_maker() as session:
|
||||
@@ -638,6 +663,16 @@ class MetadataStore:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_organization(self, org_id: str):
|
||||
with self.session_maker() as session:
|
||||
# delete from organizations table
|
||||
session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete()
|
||||
|
||||
# TODO: delete associated data
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
|
||||
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
|
||||
@@ -685,6 +720,30 @@ class MetadataStore:
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def get_organization(self, org_id: str) -> Optional[Organization]:
|
||||
with self.session_maker() as session:
|
||||
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
|
||||
if len(results) == 0:
|
||||
return None
|
||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||
return results[0].to_record()
|
||||
|
||||
@enforce_types
|
||||
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
|
||||
with self.session_maker() as session:
|
||||
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
|
||||
if cursor:
|
||||
query = query.filter(OrganizationModel.id < cursor)
|
||||
results = query.limit(limit).all()
|
||||
if not results:
|
||||
return None, []
|
||||
organization_records = [r.to_record() for r in results]
|
||||
next_cursor = organization_records[-1].id
|
||||
assert isinstance(next_cursor, str)
|
||||
|
||||
return next_cursor, organization_records
|
||||
|
||||
@enforce_types
|
||||
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
|
||||
with self.session_maker() as session:
|
||||
|
||||
20
memgpt/schemas/organization.py
Normal file
20
memgpt/schemas/organization.py
Normal file
@@ -0,0 +1,20 @@
|
||||
from datetime import datetime
|
||||
from typing import Optional
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from memgpt.schemas.memgpt_base import MemGPTBase
|
||||
|
||||
|
||||
class OrganizationBase(MemGPTBase):
|
||||
__id_prefix__ = "org"
|
||||
|
||||
|
||||
class Organization(OrganizationBase):
|
||||
id: str = OrganizationBase.generate_id_field()
|
||||
name: str = Field(..., description="The name of the organization.")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||
|
||||
|
||||
class OrganizationCreate(OrganizationBase):
|
||||
name: Optional[str] = Field(None, description="The name of the organization.")
|
||||
@@ -21,9 +21,13 @@ class User(UserBase):
|
||||
"""
|
||||
|
||||
id: str = UserBase.generate_id_field()
|
||||
org_id: Optional[str] = Field(
|
||||
..., description="The organization id of the user"
|
||||
) # TODO: dont make optional, and pass in default org ID
|
||||
name: str = Field(..., description="The name of the user.")
|
||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||
|
||||
|
||||
class UserCreate(UserBase):
|
||||
name: Optional[str] = Field(None, description="The name of the user.")
|
||||
org_id: Optional[str] = Field(None, description="The organization id of the user.")
|
||||
|
||||
@@ -29,6 +29,9 @@ from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions imp
|
||||
|
||||
# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
||||
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
from memgpt.server.rest_api.routers.v1.organizations import (
|
||||
router as organizations_router,
|
||||
)
|
||||
from memgpt.server.rest_api.routers.v1.users import (
|
||||
router as users_router, # TODO: decide on admin
|
||||
)
|
||||
@@ -103,6 +106,7 @@ def create_application() -> "FastAPI":
|
||||
|
||||
# admin/users
|
||||
app.include_router(users_router, prefix=ADMIN_PREFIX)
|
||||
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
|
||||
|
||||
# openai
|
||||
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
|
||||
|
||||
61
memgpt/server/rest_api/routers/v1/organizations.py
Normal file
61
memgpt/server/rest_api/routers/v1/organizations.py
Normal file
@@ -0,0 +1,61 @@
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||
|
||||
from memgpt.schemas.organization import Organization, OrganizationCreate
|
||||
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
|
||||
router = APIRouter(prefix="/orgs", tags=["organization", "admin"])
|
||||
|
||||
|
||||
@router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs")
|
||||
def get_all_orgs(
|
||||
cursor: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Get a list of all orgs in the database
|
||||
"""
|
||||
try:
|
||||
next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return orgs
|
||||
|
||||
|
||||
@router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization")
|
||||
def create_org(
|
||||
request: OrganizationCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
"""
|
||||
Create a new org in the database
|
||||
"""
|
||||
|
||||
org = server.create_organization(request)
|
||||
return org
|
||||
|
||||
|
||||
@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization")
|
||||
def delete_org(
|
||||
org_id: str = Query(..., description="The org_id key to be deleted."),
|
||||
server: "SyncServer" = Depends(get_memgpt_server),
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
org = server.ms.get_organization(org_id=org_id)
|
||||
if org is None:
|
||||
raise HTTPException(status_code=404, detail=f"Organization does not exist")
|
||||
server.ms.delete_organization(org_id=org_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"{e}")
|
||||
return org
|
||||
@@ -58,6 +58,7 @@ from memgpt.schemas.memgpt_message import MemGPTMessage
|
||||
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
||||
from memgpt.schemas.message import Message, UpdateMessage
|
||||
from memgpt.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from memgpt.schemas.organization import Organization, OrganizationCreate
|
||||
from memgpt.schemas.passage import Passage
|
||||
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
@@ -146,6 +147,7 @@ from memgpt.metadata import (
|
||||
APIKeyModel,
|
||||
BlockModel,
|
||||
JobModel,
|
||||
OrganizationModel,
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
UserModel,
|
||||
@@ -177,6 +179,7 @@ Base.metadata.create_all(
|
||||
JobModel.__table__,
|
||||
PassageModel.__table__,
|
||||
MessageModel.__table__,
|
||||
OrganizationModel.__table__,
|
||||
],
|
||||
)
|
||||
|
||||
@@ -689,17 +692,32 @@ class SyncServer(Server):
|
||||
if not request.name:
|
||||
# auto-generate a name
|
||||
request.name = create_random_username()
|
||||
user = User(name=request.name)
|
||||
user = User(name=request.name, org_id=request.org_id)
|
||||
self.ms.create_user(user)
|
||||
logger.info(f"Created new user from config: {user}")
|
||||
|
||||
# add default for the user
|
||||
# TODO: move to org
|
||||
assert user.id is not None, f"User id is None: {user}"
|
||||
self.add_default_blocks(user.id)
|
||||
self.add_default_tools(module_name="base", user_id=user.id)
|
||||
|
||||
return user
|
||||
|
||||
def create_organization(self, request: OrganizationCreate) -> Organization:
|
||||
"""Create a new org using a config"""
|
||||
if not request.name:
|
||||
# auto-generate a name
|
||||
request.name = create_random_username()
|
||||
org = Organization(name=request.name)
|
||||
self.ms.create_organization(org)
|
||||
logger.info(f"Created new org from config: {org}")
|
||||
|
||||
# add default for the org
|
||||
# TODO: add default data
|
||||
|
||||
return org
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
|
||||
@@ -38,11 +38,26 @@ def admin_client():
|
||||
yield admin
|
||||
|
||||
|
||||
def test_admin_client(admin_client):
|
||||
@pytest.fixture(scope="module")
|
||||
def organization(admin_client):
|
||||
# create an organization
|
||||
org_name = "test_org"
|
||||
org = admin_client.create_organization(org_name)
|
||||
assert org_name == org.name, f"Expected {org_name}, got {org.name}"
|
||||
|
||||
# test listing
|
||||
orgs = admin_client.get_organizations()
|
||||
assert len(orgs) > 0, f"Expected 1 org, got {orgs}"
|
||||
|
||||
yield org
|
||||
admin_client.delete_organization(org.id)
|
||||
|
||||
|
||||
def test_admin_client(admin_client, organization):
|
||||
|
||||
# create a user
|
||||
user_name = "test_user"
|
||||
user1 = admin_client.create_user(user_name)
|
||||
user1 = admin_client.create_user(user_name, organization.id)
|
||||
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
|
||||
|
||||
# create another user
|
||||
|
||||
@@ -410,3 +410,9 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
|
||||
new_text = "This exact string would never show up in the message???"
|
||||
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
|
||||
assert new_message.text == new_text
|
||||
|
||||
|
||||
def test_organization(client: RESTClient):
|
||||
if isinstance(client, LocalClient):
|
||||
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
|
||||
client.base_url
|
||||
|
||||
Reference in New Issue
Block a user