fix: Fix create organization bug (#1956)
This commit is contained in:
@@ -5,7 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
|
||||
import requests
|
||||
|
||||
import letta.utils
|
||||
from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from letta.data_sources.connectors import DataConnector
|
||||
from letta.functions.functions import parse_source_code
|
||||
from letta.memory import get_memory_functions
|
||||
@@ -39,6 +39,7 @@ from letta.schemas.memory import (
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, UpdateMessage
|
||||
from letta.schemas.openai.chat_completions import ToolCall
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
|
||||
@@ -282,6 +283,15 @@ class AbstractClient(object):
|
||||
def list_embedding_configs(self) -> List[EmbeddingConfig]:
|
||||
raise NotImplementedError
|
||||
|
||||
def create_org(self, name: Optional[str] = None) -> Organization:
|
||||
raise NotImplementedError
|
||||
|
||||
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
raise NotImplementedError
|
||||
|
||||
def delete_org(self, org_id: str) -> Organization:
|
||||
raise NotImplementedError
|
||||
|
||||
|
||||
class RESTClient(AbstractClient):
|
||||
"""
|
||||
@@ -1464,6 +1474,54 @@ class RESTClient(AbstractClient):
|
||||
raise ValueError(f"Failed to list embedding configs: {response.text}")
|
||||
return [EmbeddingConfig(**config) for config in response.json()]
|
||||
|
||||
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
"""
|
||||
Retrieves a list of all organizations in the database, with optional pagination.
|
||||
|
||||
@param cursor: the pagination cursor, if any
|
||||
@param limit: the maximum number of organizations to retrieve
|
||||
@return: a list of Organization objects
|
||||
"""
|
||||
params = {"cursor": cursor, "limit": limit}
|
||||
response = requests.get(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to retrieve organizations: {response.text}")
|
||||
return [Organization(**org_data) for org_data in response.json()]
|
||||
|
||||
def create_org(self, name: Optional[str] = None) -> Organization:
|
||||
"""
|
||||
Creates an organization with the given name. If not provided, we generate a random one.
|
||||
|
||||
@param name: the name of the organization
|
||||
@return: the created Organization
|
||||
"""
|
||||
payload = {"name": name}
|
||||
response = requests.post(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, json=payload)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create org: {response.text}")
|
||||
return Organization(**response.json())
|
||||
|
||||
def delete_org(self, org_id: str) -> Organization:
|
||||
"""
|
||||
Deletes an organization by its ID.
|
||||
|
||||
@param org_id: the ID of the organization to delete
|
||||
@return: the deleted Organization object
|
||||
"""
|
||||
# Define query parameters with org_id
|
||||
params = {"org_id": org_id}
|
||||
|
||||
# Make the DELETE request with query parameters
|
||||
response = requests.delete(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
|
||||
|
||||
if response.status_code == 404:
|
||||
raise ValueError(f"Organization with ID '{org_id}' does not exist")
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to delete organization: {response.text}")
|
||||
|
||||
# Parse and return the deleted organization
|
||||
return Organization(**response.json())
|
||||
|
||||
|
||||
class LocalClient(AbstractClient):
|
||||
"""
|
||||
@@ -2648,3 +2706,12 @@ class LocalClient(AbstractClient):
|
||||
configs (List[EmbeddingConfig]): List of embedding configurations
|
||||
"""
|
||||
return self.server.list_embedding_models()
|
||||
|
||||
def create_org(self, name: Optional[str] = None) -> Organization:
|
||||
return self.server.organization_manager.create_organization(name=name)
|
||||
|
||||
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
|
||||
return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)
|
||||
|
||||
def delete_org(self, org_id: str) -> Organization:
|
||||
return self.server.organization_manager.delete_organization_by_id(org_id=org_id)
|
||||
|
||||
@@ -13,6 +13,7 @@ from letta.schemas.letta_message import (
|
||||
InternalMonologue,
|
||||
)
|
||||
from letta.schemas.letta_response import LettaStreamingResponse
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
|
||||
def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
|
||||
@@ -58,6 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
|
||||
yield FunctionCallMessage(**chunk_data)
|
||||
elif "function_return" in chunk_data:
|
||||
yield FunctionReturn(**chunk_data)
|
||||
elif "usage" in chunk_data:
|
||||
yield LettaUsageStatistics(**chunk_data["usage"])
|
||||
else:
|
||||
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
|
||||
|
||||
|
||||
@@ -3,6 +3,9 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
|
||||
|
||||
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
|
||||
|
||||
ADMIN_PREFIX = "/v1/admin"
|
||||
API_PREFIX = "/v1"
|
||||
OPENAI_API_PREFIX = "/openai"
|
||||
|
||||
# String in the error message for when the context window is too large
|
||||
# Example full message:
|
||||
|
||||
@@ -36,4 +36,4 @@ class LettaResponse(BaseModel):
|
||||
|
||||
|
||||
# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
|
||||
LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus]
|
||||
LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]
|
||||
|
||||
@@ -8,6 +8,7 @@ import uvicorn
|
||||
from fastapi import FastAPI
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.server.constants import REST_DEFAULT_PORT
|
||||
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
@@ -54,11 +55,6 @@ password = None
|
||||
# #typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN)
|
||||
|
||||
|
||||
ADMIN_PREFIX = "/v1/admin"
|
||||
API_PREFIX = "/v1"
|
||||
OPENAI_API_PREFIX = "/openai"
|
||||
|
||||
|
||||
def create_application() -> "FastAPI":
|
||||
"""the application start routine"""
|
||||
# global server
|
||||
|
||||
@@ -22,7 +22,7 @@ def get_all_orgs(
|
||||
Get a list of all orgs in the database
|
||||
"""
|
||||
try:
|
||||
next_cursor, orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit)
|
||||
orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -38,7 +38,7 @@ def create_org(
|
||||
"""
|
||||
Create a new org in the database
|
||||
"""
|
||||
org = server.organization_manager.create_organization(request)
|
||||
org = server.organization_manager.create_organization(name=request.name)
|
||||
return org
|
||||
|
||||
|
||||
|
||||
@@ -1793,43 +1793,6 @@ class SyncServer(Server):
|
||||
letta_agent = self._get_or_load_agent(agent_id=agent_id)
|
||||
return letta_agent.update_message(request=request)
|
||||
|
||||
# TODO decide whether this should be done in the server.py or agent.py
|
||||
# Reason to put it in agent.py:
|
||||
# - we use the agent object's persistence_manager to update the message
|
||||
# - it makes it easy to do things like `retry`, `rethink`, etc.
|
||||
# Reason to put it in server.py:
|
||||
# - fundamentally, we should be able to edit a message (without agent id)
|
||||
# in the server by directly accessing the DB / message store
|
||||
"""
|
||||
message = letta_agent.persistence_manager.recall_memory.storage.get(id=request.id)
|
||||
if message is None:
|
||||
raise ValueError(f"Message with id {request.id} not found")
|
||||
|
||||
# Override fields
|
||||
# NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof
|
||||
if request.role:
|
||||
message.role = request.role
|
||||
if request.text:
|
||||
message.text = request.text
|
||||
if request.name:
|
||||
message.name = request.name
|
||||
if request.tool_calls:
|
||||
assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages"
|
||||
message.tool_calls = request.tool_calls
|
||||
if request.tool_call_id:
|
||||
assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages"
|
||||
message.tool_call_id = request.tool_call_id
|
||||
|
||||
# Save the updated message
|
||||
letta_agent.persistence_manager.recall_memory.storage.update(record=message)
|
||||
|
||||
# Return the updated message
|
||||
updated_message = letta_agent.persistence_manager.recall_memory.storage.get(id=message.id)
|
||||
if updated_message is None:
|
||||
raise ValueError(f"Error persisting message - message with id {request.id} not found")
|
||||
return updated_message
|
||||
"""
|
||||
|
||||
def rewrite_agent_message(self, agent_id: str, new_text: str) -> Message:
|
||||
|
||||
# Get the current message
|
||||
|
||||
@@ -229,6 +229,12 @@ def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: A
|
||||
elif chunk == MessageStreamStatus.done_generation:
|
||||
assert not done_gen, "Message stream already done generation"
|
||||
done_gen = True
|
||||
if isinstance(chunk, LettaUsageStatistics):
|
||||
# Some rough metrics for a reasonable usage pattern
|
||||
assert chunk.step_count == 1
|
||||
assert chunk.completion_tokens > 10
|
||||
assert chunk.prompt_tokens > 1000
|
||||
assert chunk.total_tokens > 1000
|
||||
|
||||
assert inner_thoughts_exist, "No inner thoughts found"
|
||||
assert send_message_ran, "send_message function call not found"
|
||||
@@ -488,6 +494,21 @@ def test_organization(client: RESTClient):
|
||||
if isinstance(client, LocalClient):
|
||||
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
|
||||
|
||||
# create an organization
|
||||
org_name = "test-org"
|
||||
org = client.create_org(org_name)
|
||||
|
||||
# assert the id appears
|
||||
orgs = client.list_orgs()
|
||||
assert org.id in [o.id for o in orgs]
|
||||
|
||||
org = client.delete_org(org.id)
|
||||
assert org.name == org_name
|
||||
|
||||
# assert the id is gone
|
||||
orgs = client.list_orgs()
|
||||
assert not (org.id in [o.id for o in orgs])
|
||||
|
||||
|
||||
def test_list_llm_models(client: RESTClient):
|
||||
"""Test that if the user's env has the right api keys set, at least one model appears in the model list"""
|
||||
|
||||
Reference in New Issue
Block a user