fix: Fix create organization bug (#1956)
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
import requests
|
||||
|
||||
import letta.utils
|
||||
from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.data_sources.connectors import DataConnector
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.memory import get_memory_functions
|
||||
@@ -39,6 +39,7 @@ from letta.schemas.memory import (
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, UpdateMessage
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
@@ -282,6 +283,15 @@ class AbstractClient(object):
|
||||
def list_embedding_configs(self) -> List[EmbeddingConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_org(self, name: Optional[str] = None) -> Organization:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_org(self, org_id: str) -> Organization:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RESTClient(AbstractClient):
|
||||
"""
|
||||
@@ -1464,6 +1474,54 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to list embedding configs: {response.text}")
|
||||
return [EmbeddingConfig(**config) for config in response.json()]
|
||||
|
||||
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
"""
|
||||
Retrieves a list of all organizations in the database, with optional pagination.
|
||||
|
||||
@param cursor: the pagination cursor, if any
|
||||
@param limit: the maximum number of organizations to retrieve
|
||||
@return: a list of Organization objects
|
||||
"""
|
||||
params = {"cursor": cursor, "limit": limit}
|
||||
response = requests.get(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to retrieve organizations: {response.text}")
|
||||
return [Organization(**org_data) for org_data in response.json()]
|
||||
|
||||
def create_org(self, name: Optional[str] = None) -> Organization:
|
||||
"""
|
||||
Creates an organization with the given name. If not provided, we generate a random one.
|
||||
|
||||
@param name: the name of the organization
|
||||
@return: the created Organization
|
||||
"""
|
||||
payload = {"name": name}
|
||||
response = requests.post(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create org: {response.text}")
|
||||
return Organization(**response.json())
|
||||
|
||||
def delete_org(self, org_id: str) -> Organization:
|
||||
"""
|
||||
Deletes an organization by its ID.
|
||||
|
||||
@param org_id: the ID of the organization to delete
|
||||
@return: the deleted Organization object
|
||||
"""
|
||||
# Define query parameters with org_id
|
||||
params = {"org_id": org_id}
|
||||
|
||||
# Make the DELETE request with query parameters
|
||||
response = requests.delete(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
|
||||
|
||||
if response.status_code == 404:
|
||||
raise ValueError(f"Organization with ID '{org_id}' does not exist")
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete organization: {response.text}")
|
||||
|
||||
# Parse and return the deleted organization
|
||||
return Organization(**response.json())
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
"""
|
||||
@@ -2648,3 +2706,12 @@ class LocalClient(AbstractClient):
|
||||
configs (List[EmbeddingConfig]): List of embedding configurations
|
||||
"""
|
||||
return self.server.list_embedding_models()
|
||||
|
||||
def create_org(self, name: Optional[str] = None) -> Organization:
|
||||
return self.server.organization_manager.create_organization(name=name)
|
||||
|
||||
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)
|
||||
|
||||
def delete_org(self, org_id: str) -> Organization:
|
||||
return self.server.organization_manager.delete_organization_by_id(org_id=org_id)
|
||||
|
||||
Reference in New Issue
Block a user