feat: add organization endpoints and schemas (#1762)
This commit is contained in:
@@ -6,6 +6,7 @@ from requests import HTTPError
|
|||||||
from memgpt.functions.functions import parse_source_code
|
from memgpt.functions.functions import parse_source_code
|
||||||
from memgpt.functions.schema_generator import generate_schema
|
from memgpt.functions.schema_generator import generate_schema
|
||||||
from memgpt.schemas.api_key import APIKey, APIKeyCreate
|
from memgpt.schemas.api_key import APIKey, APIKeyCreate
|
||||||
|
from memgpt.schemas.organization import Organization, OrganizationCreate
|
||||||
from memgpt.schemas.user import User, UserCreate
|
from memgpt.schemas.user import User, UserCreate
|
||||||
|
|
||||||
|
|
||||||
@@ -59,8 +60,8 @@ class Admin:
|
|||||||
raise HTTPError(response.json())
|
raise HTTPError(response.json())
|
||||||
return APIKey(**response.json())
|
return APIKey(**response.json())
|
||||||
|
|
||||||
def create_user(self, name: Optional[str] = None) -> User:
|
def create_user(self, name: Optional[str] = None, org_id: Optional[str] = None) -> User:
|
||||||
request = UserCreate(name=name)
|
request = UserCreate(name=name, org_id=org_id)
|
||||||
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump())
|
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/users", headers=self.headers, json=request.model_dump())
|
||||||
if response.status_code != 200:
|
if response.status_code != 200:
|
||||||
raise HTTPError(response.json())
|
raise HTTPError(response.json())
|
||||||
@@ -74,6 +75,32 @@ class Admin:
|
|||||||
raise HTTPError(response.json())
|
raise HTTPError(response.json())
|
||||||
return User(**response.json())
|
return User(**response.json())
|
||||||
|
|
||||||
|
def create_organization(self, name: Optional[str] = None) -> Organization:
|
||||||
|
request = OrganizationCreate(name=name)
|
||||||
|
response = requests.post(f"{self.base_url}/{self.api_prefix}/admin/orgs", headers=self.headers, json=request.model_dump())
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise HTTPError(response.json())
|
||||||
|
response_json = response.json()
|
||||||
|
return Organization(**response_json)
|
||||||
|
|
||||||
|
def delete_organization(self, org_id: str) -> Organization:
|
||||||
|
params = {"org_id": str(org_id)}
|
||||||
|
response = requests.delete(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise HTTPError(response.json())
|
||||||
|
return Organization(**response.json())
|
||||||
|
|
||||||
|
def get_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||||
|
params = {}
|
||||||
|
if cursor:
|
||||||
|
params["cursor"] = str(cursor)
|
||||||
|
if limit:
|
||||||
|
params["limit"] = limit
|
||||||
|
response = requests.get(f"{self.base_url}/{self.api_prefix}/admin/orgs", params=params, headers=self.headers)
|
||||||
|
if response.status_code != 200:
|
||||||
|
raise HTTPError(response.json())
|
||||||
|
return [Organization(**org) for org in response.json()]
|
||||||
|
|
||||||
def _reset_server(self):
|
def _reset_server(self):
|
||||||
# DANGER: this will delete all users and keys
|
# DANGER: this will delete all users and keys
|
||||||
# clear all state associated with users
|
# clear all state associated with users
|
||||||
|
|||||||
@@ -29,6 +29,7 @@ from memgpt.schemas.job import Job
|
|||||||
from memgpt.schemas.llm_config import LLMConfig
|
from memgpt.schemas.llm_config import LLMConfig
|
||||||
from memgpt.schemas.memory import Memory
|
from memgpt.schemas.memory import Memory
|
||||||
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
from memgpt.schemas.openai.chat_completions import ToolCall, ToolCallFunction
|
||||||
|
from memgpt.schemas.organization import Organization
|
||||||
from memgpt.schemas.source import Source
|
from memgpt.schemas.source import Source
|
||||||
from memgpt.schemas.tool import Tool
|
from memgpt.schemas.tool import Tool
|
||||||
from memgpt.schemas.user import User
|
from memgpt.schemas.user import User
|
||||||
@@ -121,6 +122,7 @@ class UserModel(Base):
|
|||||||
__table_args__ = {"extend_existing": True}
|
__table_args__ = {"extend_existing": True}
|
||||||
|
|
||||||
id = Column(String, primary_key=True)
|
id = Column(String, primary_key=True)
|
||||||
|
org_id = Column(String)
|
||||||
name = Column(String, nullable=False)
|
name = Column(String, nullable=False)
|
||||||
created_at = Column(DateTime(timezone=True))
|
created_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
@@ -131,7 +133,22 @@ class UserModel(Base):
|
|||||||
return f"<User(id='{self.id}' name='{self.name}')>"
|
return f"<User(id='{self.id}' name='{self.name}')>"
|
||||||
|
|
||||||
def to_record(self) -> User:
|
def to_record(self) -> User:
|
||||||
return User(id=self.id, name=self.name, created_at=self.created_at)
|
return User(id=self.id, name=self.name, created_at=self.created_at, org_id=self.org_id)
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationModel(Base):
|
||||||
|
__tablename__ = "organizations"
|
||||||
|
__table_args__ = {"extend_existing": True}
|
||||||
|
|
||||||
|
id = Column(String, primary_key=True)
|
||||||
|
name = Column(String, nullable=False)
|
||||||
|
created_at = Column(DateTime(timezone=True))
|
||||||
|
|
||||||
|
def __repr__(self) -> str:
|
||||||
|
return f"<Organization(id='{self.id}' name='{self.name}')>"
|
||||||
|
|
||||||
|
def to_record(self) -> Organization:
|
||||||
|
return Organization(id=self.id, name=self.name, created_at=self.created_at)
|
||||||
|
|
||||||
|
|
||||||
class APIKeyModel(Base):
|
class APIKeyModel(Base):
|
||||||
@@ -515,6 +532,14 @@ class MetadataStore:
|
|||||||
session.add(UserModel(**vars(user)))
|
session.add(UserModel(**vars(user)))
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def create_organization(self, organization: Organization):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
if session.query(OrganizationModel).filter(OrganizationModel.id == organization.id).count() > 0:
|
||||||
|
raise ValueError(f"Organization with id {organization.id} already exists")
|
||||||
|
session.add(OrganizationModel(**vars(organization)))
|
||||||
|
session.commit()
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def create_block(self, block: Block):
|
def create_block(self, block: Block):
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
@@ -638,6 +663,16 @@ class MetadataStore:
|
|||||||
|
|
||||||
session.commit()
|
session.commit()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def delete_organization(self, org_id: str):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
# delete from organizations table
|
||||||
|
session.query(OrganizationModel).filter(OrganizationModel.id == org_id).delete()
|
||||||
|
|
||||||
|
# TODO: delete associated data
|
||||||
|
|
||||||
|
session.commit()
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
|
# def list_tools(self, user_id: str) -> List[ToolModel]: # TODO: add when users can creat tools
|
||||||
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
|
def list_tools(self, user_id: Optional[str] = None) -> List[ToolModel]:
|
||||||
@@ -685,6 +720,30 @@ class MetadataStore:
|
|||||||
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||||
return results[0].to_record()
|
return results[0].to_record()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def get_organization(self, org_id: str) -> Optional[Organization]:
|
||||||
|
with self.session_maker() as session:
|
||||||
|
results = session.query(OrganizationModel).filter(OrganizationModel.id == org_id).all()
|
||||||
|
if len(results) == 0:
|
||||||
|
return None
|
||||||
|
assert len(results) == 1, f"Expected 1 result, got {len(results)}"
|
||||||
|
return results[0].to_record()
|
||||||
|
|
||||||
|
@enforce_types
|
||||||
|
def list_organizations(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
|
||||||
|
with self.session_maker() as session:
|
||||||
|
query = session.query(OrganizationModel).order_by(desc(OrganizationModel.id))
|
||||||
|
if cursor:
|
||||||
|
query = query.filter(OrganizationModel.id < cursor)
|
||||||
|
results = query.limit(limit).all()
|
||||||
|
if not results:
|
||||||
|
return None, []
|
||||||
|
organization_records = [r.to_record() for r in results]
|
||||||
|
next_cursor = organization_records[-1].id
|
||||||
|
assert isinstance(next_cursor, str)
|
||||||
|
|
||||||
|
return next_cursor, organization_records
|
||||||
|
|
||||||
@enforce_types
|
@enforce_types
|
||||||
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
|
def get_all_users(self, cursor: Optional[str] = None, limit: Optional[int] = 50):
|
||||||
with self.session_maker() as session:
|
with self.session_maker() as session:
|
||||||
|
|||||||
20
memgpt/schemas/organization.py
Normal file
20
memgpt/schemas/organization.py
Normal file
@@ -0,0 +1,20 @@
|
|||||||
|
from datetime import datetime
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from pydantic import Field
|
||||||
|
|
||||||
|
from memgpt.schemas.memgpt_base import MemGPTBase
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationBase(MemGPTBase):
|
||||||
|
__id_prefix__ = "org"
|
||||||
|
|
||||||
|
|
||||||
|
class Organization(OrganizationBase):
|
||||||
|
id: str = OrganizationBase.generate_id_field()
|
||||||
|
name: str = Field(..., description="The name of the organization.")
|
||||||
|
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||||
|
|
||||||
|
|
||||||
|
class OrganizationCreate(OrganizationBase):
|
||||||
|
name: Optional[str] = Field(None, description="The name of the organization.")
|
||||||
@@ -21,9 +21,13 @@ class User(UserBase):
|
|||||||
"""
|
"""
|
||||||
|
|
||||||
id: str = UserBase.generate_id_field()
|
id: str = UserBase.generate_id_field()
|
||||||
|
org_id: Optional[str] = Field(
|
||||||
|
..., description="The organization id of the user"
|
||||||
|
) # TODO: dont make optional, and pass in default org ID
|
||||||
name: str = Field(..., description="The name of the user.")
|
name: str = Field(..., description="The name of the user.")
|
||||||
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
created_at: datetime = Field(default_factory=datetime.utcnow, description="The creation date of the user.")
|
||||||
|
|
||||||
|
|
||||||
class UserCreate(UserBase):
|
class UserCreate(UserBase):
|
||||||
name: Optional[str] = Field(None, description="The name of the user.")
|
name: Optional[str] = Field(None, description="The name of the user.")
|
||||||
|
org_id: Optional[str] = Field(None, description="The organization id of the user.")
|
||||||
|
|||||||
@@ -29,6 +29,9 @@ from memgpt.server.rest_api.routers.openai.chat_completions.chat_completions imp
|
|||||||
|
|
||||||
# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
# from memgpt.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
||||||
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
from memgpt.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||||
|
from memgpt.server.rest_api.routers.v1.organizations import (
|
||||||
|
router as organizations_router,
|
||||||
|
)
|
||||||
from memgpt.server.rest_api.routers.v1.users import (
|
from memgpt.server.rest_api.routers.v1.users import (
|
||||||
router as users_router, # TODO: decide on admin
|
router as users_router, # TODO: decide on admin
|
||||||
)
|
)
|
||||||
@@ -103,6 +106,7 @@ def create_application() -> "FastAPI":
|
|||||||
|
|
||||||
# admin/users
|
# admin/users
|
||||||
app.include_router(users_router, prefix=ADMIN_PREFIX)
|
app.include_router(users_router, prefix=ADMIN_PREFIX)
|
||||||
|
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
|
||||||
|
|
||||||
# openai
|
# openai
|
||||||
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
|
app.include_router(openai_assistants_router, prefix=OPENAI_API_PREFIX)
|
||||||
|
|||||||
61
memgpt/server/rest_api/routers/v1/organizations.py
Normal file
61
memgpt/server/rest_api/routers/v1/organizations.py
Normal file
@@ -0,0 +1,61 @@
|
|||||||
|
from typing import TYPE_CHECKING, List, Optional
|
||||||
|
|
||||||
|
from fastapi import APIRouter, Body, Depends, HTTPException, Query
|
||||||
|
|
||||||
|
from memgpt.schemas.organization import Organization, OrganizationCreate
|
||||||
|
from memgpt.server.rest_api.utils import get_memgpt_server
|
||||||
|
|
||||||
|
if TYPE_CHECKING:
|
||||||
|
from memgpt.server.server import SyncServer
|
||||||
|
|
||||||
|
|
||||||
|
router = APIRouter(prefix="/orgs", tags=["organization", "admin"])
|
||||||
|
|
||||||
|
|
||||||
|
@router.get("/", tags=["admin"], response_model=List[Organization], operation_id="list_orgs")
|
||||||
|
def get_all_orgs(
|
||||||
|
cursor: Optional[str] = Query(None),
|
||||||
|
limit: Optional[int] = Query(50),
|
||||||
|
server: "SyncServer" = Depends(get_memgpt_server),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Get a list of all orgs in the database
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
next_cursor, orgs = server.ms.list_organizations(cursor=cursor, limit=limit)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
|
return orgs
|
||||||
|
|
||||||
|
|
||||||
|
@router.post("/", tags=["admin"], response_model=Organization, operation_id="create_organization")
|
||||||
|
def create_org(
|
||||||
|
request: OrganizationCreate = Body(...),
|
||||||
|
server: "SyncServer" = Depends(get_memgpt_server),
|
||||||
|
):
|
||||||
|
"""
|
||||||
|
Create a new org in the database
|
||||||
|
"""
|
||||||
|
|
||||||
|
org = server.create_organization(request)
|
||||||
|
return org
|
||||||
|
|
||||||
|
|
||||||
|
@router.delete("/", tags=["admin"], response_model=Organization, operation_id="delete_organization")
|
||||||
|
def delete_org(
|
||||||
|
org_id: str = Query(..., description="The org_id key to be deleted."),
|
||||||
|
server: "SyncServer" = Depends(get_memgpt_server),
|
||||||
|
):
|
||||||
|
# TODO make a soft deletion, instead of a hard deletion
|
||||||
|
try:
|
||||||
|
org = server.ms.get_organization(org_id=org_id)
|
||||||
|
if org is None:
|
||||||
|
raise HTTPException(status_code=404, detail=f"Organization does not exist")
|
||||||
|
server.ms.delete_organization(org_id=org_id)
|
||||||
|
except HTTPException:
|
||||||
|
raise
|
||||||
|
except Exception as e:
|
||||||
|
raise HTTPException(status_code=500, detail=f"{e}")
|
||||||
|
return org
|
||||||
@@ -58,6 +58,7 @@ from memgpt.schemas.memgpt_message import MemGPTMessage
|
|||||||
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
from memgpt.schemas.memory import ArchivalMemorySummary, Memory, RecallMemorySummary
|
||||||
from memgpt.schemas.message import Message, UpdateMessage
|
from memgpt.schemas.message import Message, UpdateMessage
|
||||||
from memgpt.schemas.openai.chat_completion_response import UsageStatistics
|
from memgpt.schemas.openai.chat_completion_response import UsageStatistics
|
||||||
|
from memgpt.schemas.organization import Organization, OrganizationCreate
|
||||||
from memgpt.schemas.passage import Passage
|
from memgpt.schemas.passage import Passage
|
||||||
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
|
from memgpt.schemas.source import Source, SourceCreate, SourceUpdate
|
||||||
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
|
from memgpt.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||||
@@ -146,6 +147,7 @@ from memgpt.metadata import (
|
|||||||
APIKeyModel,
|
APIKeyModel,
|
||||||
BlockModel,
|
BlockModel,
|
||||||
JobModel,
|
JobModel,
|
||||||
|
OrganizationModel,
|
||||||
SourceModel,
|
SourceModel,
|
||||||
ToolModel,
|
ToolModel,
|
||||||
UserModel,
|
UserModel,
|
||||||
@@ -177,6 +179,7 @@ Base.metadata.create_all(
|
|||||||
JobModel.__table__,
|
JobModel.__table__,
|
||||||
PassageModel.__table__,
|
PassageModel.__table__,
|
||||||
MessageModel.__table__,
|
MessageModel.__table__,
|
||||||
|
OrganizationModel.__table__,
|
||||||
],
|
],
|
||||||
)
|
)
|
||||||
|
|
||||||
@@ -689,17 +692,32 @@ class SyncServer(Server):
|
|||||||
if not request.name:
|
if not request.name:
|
||||||
# auto-generate a name
|
# auto-generate a name
|
||||||
request.name = create_random_username()
|
request.name = create_random_username()
|
||||||
user = User(name=request.name)
|
user = User(name=request.name, org_id=request.org_id)
|
||||||
self.ms.create_user(user)
|
self.ms.create_user(user)
|
||||||
logger.info(f"Created new user from config: {user}")
|
logger.info(f"Created new user from config: {user}")
|
||||||
|
|
||||||
# add default for the user
|
# add default for the user
|
||||||
|
# TODO: move to org
|
||||||
assert user.id is not None, f"User id is None: {user}"
|
assert user.id is not None, f"User id is None: {user}"
|
||||||
self.add_default_blocks(user.id)
|
self.add_default_blocks(user.id)
|
||||||
self.add_default_tools(module_name="base", user_id=user.id)
|
self.add_default_tools(module_name="base", user_id=user.id)
|
||||||
|
|
||||||
return user
|
return user
|
||||||
|
|
||||||
|
def create_organization(self, request: OrganizationCreate) -> Organization:
|
||||||
|
"""Create a new org using a config"""
|
||||||
|
if not request.name:
|
||||||
|
# auto-generate a name
|
||||||
|
request.name = create_random_username()
|
||||||
|
org = Organization(name=request.name)
|
||||||
|
self.ms.create_organization(org)
|
||||||
|
logger.info(f"Created new org from config: {org}")
|
||||||
|
|
||||||
|
# add default for the org
|
||||||
|
# TODO: add default data
|
||||||
|
|
||||||
|
return org
|
||||||
|
|
||||||
def create_agent(
|
def create_agent(
|
||||||
self,
|
self,
|
||||||
request: CreateAgent,
|
request: CreateAgent,
|
||||||
|
|||||||
@@ -38,11 +38,26 @@ def admin_client():
|
|||||||
yield admin
|
yield admin
|
||||||
|
|
||||||
|
|
||||||
def test_admin_client(admin_client):
|
@pytest.fixture(scope="module")
|
||||||
|
def organization(admin_client):
|
||||||
|
# create an organization
|
||||||
|
org_name = "test_org"
|
||||||
|
org = admin_client.create_organization(org_name)
|
||||||
|
assert org_name == org.name, f"Expected {org_name}, got {org.name}"
|
||||||
|
|
||||||
|
# test listing
|
||||||
|
orgs = admin_client.get_organizations()
|
||||||
|
assert len(orgs) > 0, f"Expected 1 org, got {orgs}"
|
||||||
|
|
||||||
|
yield org
|
||||||
|
admin_client.delete_organization(org.id)
|
||||||
|
|
||||||
|
|
||||||
|
def test_admin_client(admin_client, organization):
|
||||||
|
|
||||||
# create a user
|
# create a user
|
||||||
user_name = "test_user"
|
user_name = "test_user"
|
||||||
user1 = admin_client.create_user(user_name)
|
user1 = admin_client.create_user(user_name, organization.id)
|
||||||
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
|
assert user_name == user1.name, f"Expected {user_name}, got {user1.name}"
|
||||||
|
|
||||||
# create another user
|
# create another user
|
||||||
|
|||||||
@@ -410,3 +410,9 @@ def test_message_update(client: Union[LocalClient, RESTClient], agent: AgentStat
|
|||||||
new_text = "This exact string would never show up in the message???"
|
new_text = "This exact string would never show up in the message???"
|
||||||
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
|
new_message = client.update_message(message_id=message.id, text=new_text, agent_id=agent.id)
|
||||||
assert new_message.text == new_text
|
assert new_message.text == new_text
|
||||||
|
|
||||||
|
|
||||||
|
def test_organization(client: RESTClient):
|
||||||
|
if isinstance(client, LocalClient):
|
||||||
|
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
|
||||||
|
client.base_url
|
||||||
|
|||||||
Reference in New Issue
Block a user