fix: Fix create organization bug (#1956)

This commit is contained in:
Matthew Zhou
2024-10-30 13:55:48 -07:00
committed by GitHub
parent 7fa632aa94
commit 41e868c6cc
8 changed files with 99 additions and 46 deletions

View File

@@ -5,7 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union
import requests
import letta.utils
from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.data_sources.connectors import DataConnector
from letta.functions.functions import parse_source_code
from letta.memory import get_memory_functions
@@ -39,6 +39,7 @@ from letta.schemas.memory import (
)
from letta.schemas.message import Message, MessageCreate, UpdateMessage
from letta.schemas.openai.chat_completions import ToolCall
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.source import Source, SourceCreate, SourceUpdate
from letta.schemas.tool import Tool, ToolCreate, ToolUpdate
@@ -282,6 +283,15 @@ class AbstractClient(object):
def list_embedding_configs(self) -> List[EmbeddingConfig]:
raise NotImplementedError
def create_org(self, name: Optional[str] = None) -> Organization:
raise NotImplementedError
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
raise NotImplementedError
def delete_org(self, org_id: str) -> Organization:
raise NotImplementedError
class RESTClient(AbstractClient):
"""
@@ -1464,6 +1474,54 @@ class RESTClient(AbstractClient):
raise ValueError(f"Failed to list embedding configs: {response.text}")
return [EmbeddingConfig(**config) for config in response.json()]
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
"""
Retrieves a list of all organizations in the database, with optional pagination.
@param cursor: the pagination cursor, if any
@param limit: the maximum number of organizations to retrieve
@return: a list of Organization objects
"""
params = {"cursor": cursor, "limit": limit}
response = requests.get(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
if response.status_code != 200:
raise ValueError(f"Failed to retrieve organizations: {response.text}")
return [Organization(**org_data) for org_data in response.json()]
def create_org(self, name: Optional[str] = None) -> Organization:
"""
Creates an organization with the given name. If not provided, we generate a random one.
@param name: the name of the organization
@return: the created Organization
"""
payload = {"name": name}
response = requests.post(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, json=payload)
if response.status_code != 200:
raise ValueError(f"Failed to create org: {response.text}")
return Organization(**response.json())
def delete_org(self, org_id: str) -> Organization:
"""
Deletes an organization by its ID.
@param org_id: the ID of the organization to delete
@return: the deleted Organization object
"""
# Define query parameters with org_id
params = {"org_id": org_id}
# Make the DELETE request with query parameters
response = requests.delete(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params)
if response.status_code == 404:
raise ValueError(f"Organization with ID '{org_id}' does not exist")
elif response.status_code != 200:
raise ValueError(f"Failed to delete organization: {response.text}")
# Parse and return the deleted organization
return Organization(**response.json())
class LocalClient(AbstractClient):
"""
@@ -2648,3 +2706,12 @@ class LocalClient(AbstractClient):
configs (List[EmbeddingConfig]): List of embedding configurations
"""
return self.server.list_embedding_models()
def create_org(self, name: Optional[str] = None) -> Organization:
return self.server.organization_manager.create_organization(name=name)
def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]:
return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit)
def delete_org(self, org_id: str) -> Organization:
return self.server.organization_manager.delete_organization_by_id(org_id=org_id)

View File

@@ -13,6 +13,7 @@ from letta.schemas.letta_message import (
InternalMonologue,
)
from letta.schemas.letta_response import LettaStreamingResponse
from letta.schemas.usage import LettaUsageStatistics
def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
@@ -58,6 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
yield FunctionCallMessage(**chunk_data)
elif "function_return" in chunk_data:
yield FunctionReturn(**chunk_data)
elif "usage" in chunk_data:
yield LettaUsageStatistics(**chunk_data["usage"])
else:
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")

View File

@@ -3,6 +3,9 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING
LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta")
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"
# String in the error message for when the context window is too large
# Example full message:

View File

@@ -36,4 +36,4 @@ class LettaResponse(BaseModel):
# The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus]
LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]

View File

@@ -8,6 +8,7 @@ import uvicorn
from fastapi import FastAPI
from starlette.middleware.cors import CORSMiddleware
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.server.constants import REST_DEFAULT_PORT
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
@@ -54,11 +55,6 @@ password = None
# #typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN)
ADMIN_PREFIX = "/v1/admin"
API_PREFIX = "/v1"
OPENAI_API_PREFIX = "/openai"
def create_application() -> "FastAPI":
"""the application start routine"""
# global server

View File

@@ -22,7 +22,7 @@ def get_all_orgs(
Get a list of all orgs in the database
"""
try:
next_cursor, orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit)
orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit)
except HTTPException:
raise
except Exception as e:
@@ -38,7 +38,7 @@ def create_org(
"""
Create a new org in the database
"""
org = server.organization_manager.create_organization(request)
org = server.organization_manager.create_organization(name=request.name)
return org

View File

@@ -1793,43 +1793,6 @@ class SyncServer(Server):
letta_agent = self._get_or_load_agent(agent_id=agent_id)
return letta_agent.update_message(request=request)
# TODO decide whether this should be done in the server.py or agent.py
# Reason to put it in agent.py:
# - we use the agent object's persistence_manager to update the message
# - it makes it easy to do things like `retry`, `rethink`, etc.
# Reason to put it in server.py:
# - fundamentally, we should be able to edit a message (without agent id)
# in the server by directly accessing the DB / message store
"""
message = letta_agent.persistence_manager.recall_memory.storage.get(id=request.id)
if message is None:
raise ValueError(f"Message with id {request.id} not found")
# Override fields
# NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof
if request.role:
message.role = request.role
if request.text:
message.text = request.text
if request.name:
message.name = request.name
if request.tool_calls:
assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages"
message.tool_calls = request.tool_calls
if request.tool_call_id:
assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages"
message.tool_call_id = request.tool_call_id
# Save the updated message
letta_agent.persistence_manager.recall_memory.storage.update(record=message)
# Return the updated message
updated_message = letta_agent.persistence_manager.recall_memory.storage.get(id=message.id)
if updated_message is None:
raise ValueError(f"Error persisting message - message with id {request.id} not found")
return updated_message
"""
def rewrite_agent_message(self, agent_id: str, new_text: str) -> Message:
# Get the current message

View File

@@ -229,6 +229,12 @@ def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: A
elif chunk == MessageStreamStatus.done_generation:
assert not done_gen, "Message stream already done generation"
done_gen = True
if isinstance(chunk, LettaUsageStatistics):
# Some rough metrics for a reasonable usage pattern
assert chunk.step_count == 1
assert chunk.completion_tokens > 10
assert chunk.prompt_tokens > 1000
assert chunk.total_tokens > 1000
assert inner_thoughts_exist, "No inner thoughts found"
assert send_message_ran, "send_message function call not found"
@@ -488,6 +494,21 @@ def test_organization(client: RESTClient):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_organization because LocalClient does not support organizations")
# create an organization
org_name = "test-org"
org = client.create_org(org_name)
# assert the id appears
orgs = client.list_orgs()
assert org.id in [o.id for o in orgs]
org = client.delete_org(org.id)
assert org.name == org_name
# assert the id is gone
orgs = client.list_orgs()
assert not (org.id in [o.id for o in orgs])
def test_list_llm_models(client: RESTClient):
"""Test that if the user's env has the right api keys set, at least one model appears in the model list"""