fix: Fix create organization bug (#1956)

This commit is contained in:
Matthew Zhou
2024-10-30 13:55:48 -07:00
committed by GitHub
parent 7fa632aa94
commit 41e868c6cc
8 changed files with 99 additions and 46 deletions

View File

@@ -5,7 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
import requests
import letta.utils
from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
from letta.memory import get_memory_functions
@@ -39,6 +39,7 @@ from letta.schemas.memory import (
)
from letta.schemas.message import Message, MessageCreate, UpdateMessage
from letta.schemas.openai.chat_completions import ToolCall
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
@@ -282,6 +283,15 @@ class AbstractClient(object):
def list_embedding_configs(self) -> List[EmbeddingConfig]:
raise NotImplementedError
def create_org(self, name: Optional[str] = None) -> Organization:
raise NotImplementedError
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
raise NotImplementedError
def delete_org(self, org_id: str) -> Organization:
raise NotImplementedError
class RESTClient(AbstractClient):
"""
@@ -1464,6 +1474,54 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to list embedding configs: {response.text}")
return [EmbeddingConfig(**config) for config in response.json()]
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
"""
Retrieves a list of all organizations in the database, with optional pagination.
@param cursor: the pagination cursor, if any
@param limit: the maximum number of organizations to retrieve
@return: a list of Organization objects
"""
params = {"cursor": cursor, "limit": limit}
response = requests.get(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
if response.status_code != 200:
raise ValueError(f"Failed to retrieve organizations: {response.text}")
return [Organization(**org_data) for org_data in response.json()]
def create_org(self, name: Optional[str] = None) -> Organization:
"""
Creates an organization with the given name. If not provided, we generate a random one.
@param name: the name of the organization
@return: the created Organization
"""
payload = {"name": name}
response = requests.post(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, json=payload)
if response.status_code != 200:
raise ValueError(f"Failed to create org: {response.text}")
return Organization(**response.json())
def delete_org(self, org_id: str) -> Organization:
"""
Deletes an organization by its ID.
@param org_id: the ID of the organization to delete
@return: the deleted Organization object
"""
# Define query parameters with org_id
params = {"org_id": org_id}
# Make the DELETE request with query parameters
response = requests.delete(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
if response.status_code == 404:
raise ValueError(f"Organization with ID '{org_id}' does not exist")
elif response.status_code != 200:
raise ValueError(f"Failed to delete organization: {response.text}")
# Parse and return the deleted organization
return Organization(**response.json())
class LocalClient(AbstractClient):
"""
@@ -2648,3 +2706,12 @@ class LocalClient(AbstractClient):
configs (List[EmbeddingConfig]): List of embedding configurations
"""
return self.server.list_embedding_models()
def create_org(self, name: Optional[str] = None) -> Organization:
return self.server.organization_manager.create_organization(name=name)
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)
def delete_org(self, org_id: str) -> Organization:
return self.server.organization_manager.delete_organization_by_id(org_id=org_id)