feat: add organization endpoints and schemas (#1762)

This commit is contained in:
Sarah Wooders
2024-09-20 11:34:57 -07:00
committed by GitHub
parent da9bbeada0
commit a9d8445e4a
9 changed files with 220 additions and 6 deletions

View File

@@ -6,6 +6,7 @@ from requests import HTTPError
from memgpt.functions.functions import parse_source_code from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema from memgpt.functions.schema_generator import generate_schema
from memgpt.schemas.api_key import APIKey, APIKeyCreate from memgpt.schemas.api_key import APIKey, APIKeyCreate
from memgpt.schemas.organization import Organization, OrganizationCreate
from memgpt.schemas.user import User, UserCreate from memgpt.schemas.user import User, UserCreate
@@ -59,8 +60,8 @@ class Admin:
raise HTTPError(response.json()) raise HTTPError(response.json())
return APIKey(**response.json()) return APIKey(**response.json())
def create_user(self, name: Optional[str] = None) -> User: def create_user(self, name: Optional[str] = None, org_id: Optional[str] = None) -> User:
request = UserCreate(name=name) request = UserCreate(name=name, org_id=org_id)
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump()) response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump())
if response.status_code != 200: if response.status_code != 200:
raise HTTPError(response.json()) raise HTTPError(response.json())
@@ -74,6 +75,32 @@ class Admin:
raise HTTPError(response.json()) raise HTTPError(response.json())
return User(**response.json()) return User(**response.json())
def create_organization(self, name: Optional[str] = None) -> Organization:
request = OrganizationCreate(name=name)
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/orgs", headers=self.headers, json=request.model_dump())
if response.status_code != 200:
raise HTTPError(response.json())
response_json = response.json()
return Organization(**response_json)
def delete_organization(self, org_id: str) -> Organization:
params = {"org_id": str(org_id)}
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return Organization(**response.json())
def get_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
params = {}
if cursor:
params["cursor"] = str(cursor)
if limit:
params["limit"] = limit
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
return [Organization(**org) for org in response.json()]
def _reset_server(self): def _reset_server(self):
# DANGER: this will delete all users and keys # DANGER: this will delete all users and keys
# clear all state associated with users # clear all state associated with users

View File

@@ -29,6 +29,7 @@ from memgpt.schemas.job import Job
from memgpt.schemas.llm_config import LLMConfig from memgpt.schemas.llm_config import LLMConfig
from memgpt.schemas.memory import Memory from memgpt.schemas.memory import Memory
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
from memgpt.schemas.organization import Organization
from memgpt.schemas.source import Source from memgpt.schemas.source import Source
from memgpt.schemas.tool import Tool from memgpt.schemas.tool import Tool
from memgpt.schemas.user import User from memgpt.schemas.user import User
@@ -121,6 +122,7 @@ class UserModel(Base):
__table_args__ = {"extend_existing": True} __table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True) id = Column(String, primary_key=True)
org_id = Column(String)
name = Column(String, nullable=False) name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True)) created_at = Column(DateTime(timezone=True))
@@ -131,7 +133,22 @@ class UserModel(Base):
return f"<User(id='{self.id}' name='{self.name}')>" return f"<User(id='{self.id}' name='{self.name}')>"
def to_record(self) -> User: def to_record(self) -> User:
return User(id=self.id, name=self.name, created_at=self.created_at) return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id)
class OrganizationModel(Base):
__tablename__ = "organizations"
__table_args__ = {"extend_existing": True}
id = Column(String, primary_key=True)
name = Column(String, nullable=False)
created_at = Column(DateTime(timezone=True))
def __repr__(self) -> str:
return f"<Organization(id='{self.id}' name='{self.name}')>"
def to_record(self) -> Organization:
return Organization(id=self.id, name=self.name, created_at=self.created_at)
class APIKeyModel(Base): class APIKeyModel(Base):
@@ -515,6 +532,14 @@ class MetadataStore:
session.add(UserModel(**vars(user))) session.add(UserModel(**vars(user)))
session.commit() session.commit()
@enforce_types
def create_organization(self, organization: Organization):
with self.session_maker() as session:
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
raise ValueError(f"Organization with id {organization.id} already exists")
session.add(OrganizationModel(**vars(organization)))
session.commit()
@enforce_types @enforce_types
def create_block(self, block: Block): def create_block(self, block: Block):
with self.session_maker() as session: with self.session_maker() as session:
@@ -638,6 +663,16 @@ class MetadataStore:
session.commit() session.commit()
@enforce_types
def delete_organization(self, org_id: str):
with self.session_maker() as session:
# delete from organizations table
session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete()
# TODO: delete associated data
session.commit()
@enforce_types @enforce_types
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools # def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]: def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
@@ -685,6 +720,30 @@ class MetadataStore:
assert len(results) == 1, f"Expected 1 result, got {len(results)}" assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record() return results[0].to_record()
@enforce_types
def get_organization(self, org_id: str) -> Optional[Organization]:
with self.session_maker() as session:
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
if len(results) == 0:
return None
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
return results[0].to_record()
@enforce_types
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session:
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
if cursor:
query = query.filter(OrganizationModel.id < cursor)
results = query.limit(limit).all()
if not results:
return None, []
organization_records = [r.to_record() for r in results]
next_cursor = organization_records[-1].id
assert isinstance(next_cursor, str)
return next_cursor, organization_records
@enforce_types @enforce_types
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50): def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
with self.session_maker() as session: with self.session_maker() as session:

View File

@@ -0,0 +1,20 @@
from datetime import datetime
from typing import Optional
from pydantic import Field
from memgpt.schemas.memgpt_base import MemGPTBase
class OrganizationBase(MemGPTBase):
__id_prefix__ = "org"
class Organization(OrganizationBase):
id: str = OrganizationBase.generate_id_field()
name: str = Field(..., description="The name of the organization.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
class OrganizationCreate(OrganizationBase):
name: Optional[str] = Field(None, description="The name of the organization.")

View File

@@ -21,9 +21,13 @@ class User(UserBase):
""" """
id: str = UserBase.generate_id_field() id: str = UserBase.generate_id_field()
org_id: Optional[str] = Field(
..., description="The organization id of the user"
) # TODO: dont make optional, and pass in default org ID
name: str = Field(..., description="The name of the user.") name: str = Field(..., description="The name of the user.")
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.") created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
class UserCreate(UserBase): class UserCreate(UserBase):
name: Optional[str] = Field(None, description="The name of the user.") name: Optional[str] = Field(None, description="The name of the user.")
org_id: Optional[str] = Field(None, description="The organization id of the user.")

View File

@@ -29,6 +29,9 @@ from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions imp
# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM # from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
from memgpt.server.rest_api.routers.v1.organizations import (
router as organizations_router,
)
from memgpt.server.rest_api.routers.v1.users import ( from memgpt.server.rest_api.routers.v1.users import (
router as users_router, # TODO: decide on admin router as users_router, # TODO: decide on admin
) )
@@ -103,6 +106,7 @@ def create_application() -> "FastAPI":
# admin/users # admin/users
app.include_router(users_router, prefix=ADMIN_PREFIX) app.include_router(users_router, prefix=ADMIN_PREFIX)
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
# openai # openai
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX) app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)

View File

@@ -0,0 +1,61 @@
from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, HTTPException, Query
from memgpt.schemas.organization import Organization, OrganizationCreate
from memgpt.server.rest_api.utils import get_memgpt_server
if TYPE_CHECKING:
from memgpt.server.server import SyncServer
router = APIRouter(prefix="/orgs", tags=["organization", "admin"])
@router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs")
def get_all_orgs(
cursor: Optional[str] = Query(None),
limit: Optional[int] = Query(50),
server: "SyncServer" = Depends(get_memgpt_server),
):
"""
Get a list of all orgs in the database
"""
try:
next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return orgs
@router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization")
def create_org(
request: OrganizationCreate = Body(...),
server: "SyncServer" = Depends(get_memgpt_server),
):
"""
Create a new org in the database
"""
org = server.create_organization(request)
return org
@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization")
def delete_org(
org_id: str = Query(..., description="The org_id key to be deleted."),
server: "SyncServer" = Depends(get_memgpt_server),
):
# TODO make a soft deletion, instead of a hard deletion
try:
org = server.ms.get_organization(org_id=org_id)
if org is None:
raise HTTPException(status_code=404, detail=f"Organization does not exist")
server.ms.delete_organization(org_id=org_id)
except HTTPException:
raise
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
return org

View File

@@ -58,6 +58,7 @@ from memgpt.schemas.memgpt_message import MemGPTMessage
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
from memgpt.schemas.message import Message, UpdateMessage from memgpt.schemas.message import Message, UpdateMessage
from memgpt.schemas.openai.chat_completion_response import UsageStatistics from memgpt.schemas.openai.chat_completion_response import UsageStatistics
from memgpt.schemas.organization import Organization, OrganizationCreate
from memgpt.schemas.passage import Passage from memgpt.schemas.passage import Passage
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
@@ -146,6 +147,7 @@ from memgpt.metadata import (
APIKeyModel, APIKeyModel,
BlockModel, BlockModel,
JobModel, JobModel,
OrganizationModel,
SourceModel, SourceModel,
ToolModel, ToolModel,
UserModel, UserModel,
@@ -177,6 +179,7 @@ Base.metadata.create_all(
JobModel.__table__, JobModel.__table__,
PassageModel.__table__, PassageModel.__table__,
MessageModel.__table__, MessageModel.__table__,
OrganizationModel.__table__,
], ],
) )
@@ -689,17 +692,32 @@ class SyncServer(Server):
if not request.name: if not request.name:
# auto-generate a name # auto-generate a name
request.name = create_random_username() request.name = create_random_username()
user = User(name=request.name) user = User(name=request.name, org_id=request.org_id)
self.ms.create_user(user) self.ms.create_user(user)
logger.info(f"Created new user from config: {user}") logger.info(f"Created new user from config: {user}")
# add default for the user # add default for the user
# TODO: move to org
assert user.id is not None, f"User id is None: {user}" assert user.id is not None, f"User id is None: {user}"
self.add_default_blocks(user.id) self.add_default_blocks(user.id)
self.add_default_tools(module_name="base", user_id=user.id) self.add_default_tools(module_name="base", user_id=user.id)
return user return user
def create_organization(self, request: OrganizationCreate) -> Organization:
"""Create a new org using a config"""
if not request.name:
# auto-generate a name
request.name = create_random_username()
org = Organization(name=request.name)
self.ms.create_organization(org)
logger.info(f"Created new org from config: {org}")
# add default for the org
# TODO: add default data
return org
def create_agent( def create_agent(
self, self,
request: CreateAgent, request: CreateAgent,

View File

@@ -38,11 +38,26 @@ def admin_client():
yield admin yield admin
def test_admin_client(admin_client): @pytest.fixture(scope="module")
def organization(admin_client):
# create an organization
org_name = "test_org"
org = admin_client.create_organization(org_name)
assert org_name == org.name, f"Expected {org_name}, got {org.name}"
# test listing
orgs = admin_client.get_organizations()
assert len(orgs) > 0, f"Expected 1 org, got {orgs}"
yield org
admin_client.delete_organization(org.id)
def test_admin_client(admin_client, organization):
# create a user # create a user
user_name = "test_user" user_name = "test_user"
user1 = admin_client.create_user(user_name) user1 = admin_client.create_user(user_name, organization.id)
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}" assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
# create another user # create another user

View File

@@ -410,3 +410,9 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
new_text = "This exact string would never show up in the message???" new_text = "This exact string would never show up in the message???"
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id) new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
assert new_message.text == new_text assert new_message.text == new_text
def test_organization(client: RESTClient):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
client.base_url