From 41e868c6ccfa21909d64090da40fb2c08dc784c7 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Wed, 30 Oct 2024 13:55:48 -0700 Subject: [PATCH] fix: Fix create organization bug (#1956) --- letta/client/client.py | 69 ++++++++++++++++++- letta/client/streaming.py | 3 + letta/constants.py | 3 + letta/schemas/letta_response.py | 2 +- letta/server/rest_api/app.py | 6 +- .../rest_api/routers/v1/organizations.py | 4 +- letta/server/server.py | 37 ---------- tests/test_client.py | 21 ++++++ 8 files changed, 99 insertions(+), 46 deletions(-) diff --git a/letta/client/client.py b/letta/client/client.py index 30945530..f0d5203f 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -5,7 +5,7 @@ from typing import Callable, Dict, Generator, List, Optional, Union import requests import letta.utils -from letta.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA +from letta.constants import ADMIN_PREFIX, BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA from letta.data_sources.connectors import DataConnector from letta.functions.functions import parse_source_code from letta.memory import get_memory_functions @@ -39,6 +39,7 @@ from letta.schemas.memory import ( ) from letta.schemas.message import Message, MessageCreate, UpdateMessage from letta.schemas.openai.chat_completions import ToolCall +from letta.schemas.organization import Organization from letta.schemas.passage import Passage from letta.schemas.source import Source, SourceCreate, SourceUpdate from letta.schemas.tool import Tool, ToolCreate, ToolUpdate @@ -282,6 +283,15 @@ class AbstractClient(object): def list_embedding_configs(self) -> List[EmbeddingConfig]: raise NotImplementedError + def create_org(self, name: Optional[str] = None) -> Organization: + raise NotImplementedError + + def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]: + raise NotImplementedError + + def delete_org(self, org_id: str) -> Organization: + raise NotImplementedError + class RESTClient(AbstractClient): """ @@ -1464,6 +1474,54 @@ class RESTClient(AbstractClient): raise ValueError(f"Failed to list embedding configs: {response.text}") return [EmbeddingConfig(**config) for config in response.json()] + def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]: + """ + Retrieves a list of all organizations in the database, with optional pagination. + + @param cursor: the pagination cursor, if any + @param limit: the maximum number of organizations to retrieve + @return: a list of Organization objects + """ + params = {"cursor": cursor, "limit": limit} + response = requests.get(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params) + if response.status_code != 200: + raise ValueError(f"Failed to retrieve organizations: {response.text}") + return [Organization(**org_data) for org_data in response.json()] + + def create_org(self, name: Optional[str] = None) -> Organization: + """ + Creates an organization with the given name. If not provided, we generate a random one. + + @param name: the name of the organization + @return: the created Organization + """ + payload = {"name": name} + response = requests.post(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, json=payload) + if response.status_code != 200: + raise ValueError(f"Failed to create org: {response.text}") + return Organization(**response.json()) + + def delete_org(self, org_id: str) -> Organization: + """ + Deletes an organization by its ID. + + @param org_id: the ID of the organization to delete + @return: the deleted Organization object + """ + # Define query parameters with org_id + params = {"org_id": org_id} + + # Make the DELETE request with query parameters + response = requests.delete(f"{self.base_url}/{ADMIN_PREFIX}/orgs", headers=self.headers, params=params) + + if response.status_code == 404: + raise ValueError(f"Organization with ID '{org_id}' does not exist") + elif response.status_code != 200: + raise ValueError(f"Failed to delete organization: {response.text}") + + # Parse and return the deleted organization + return Organization(**response.json()) + class LocalClient(AbstractClient): """ @@ -2648,3 +2706,12 @@ class LocalClient(AbstractClient): configs (List[EmbeddingConfig]): List of embedding configurations """ return self.server.list_embedding_models() + + def create_org(self, name: Optional[str] = None) -> Organization: + return self.server.organization_manager.create_organization(name=name) + + def list_orgs(self, cursor: Optional[str] = None, limit: Optional[int] = 50) -> List[Organization]: + return self.server.organization_manager.list_organizations(cursor=cursor, limit=limit) + + def delete_org(self, org_id: str) -> Organization: + return self.server.organization_manager.delete_organization_by_id(org_id=org_id) diff --git a/letta/client/streaming.py b/letta/client/streaming.py index 48780d6d..80a8a814 100644 --- a/letta/client/streaming.py +++ b/letta/client/streaming.py @@ -13,6 +13,7 @@ from letta.schemas.letta_message import ( InternalMonologue, ) from letta.schemas.letta_response import LettaStreamingResponse +from letta.schemas.usage import LettaUsageStatistics def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]: @@ -58,6 +59,8 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe yield FunctionCallMessage(**chunk_data) elif "function_return" in chunk_data: yield FunctionReturn(**chunk_data) + elif "usage" in chunk_data: + yield LettaUsageStatistics(**chunk_data["usage"]) else: raise ValueError(f"Unknown message type in chunk_data: {chunk_data}") diff --git a/letta/constants.py b/letta/constants.py index 39319c5a..ccbd4fb0 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -3,6 +3,9 @@ from logging import CRITICAL, DEBUG, ERROR, INFO, NOTSET, WARN, WARNING LETTA_DIR = os.path.join(os.path.expanduser("~"), ".letta") +ADMIN_PREFIX = "/v1/admin" +API_PREFIX = "/v1" +OPENAI_API_PREFIX = "/openai" # String in the error message for when the context window is too large # Example full message: diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index 818e0306..21cc881d 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -36,4 +36,4 @@ class LettaResponse(BaseModel): # The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage -LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus] +LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics] diff --git a/letta/server/rest_api/app.py b/letta/server/rest_api/app.py index 303ed0ad..dfcd567d 100644 --- a/letta/server/rest_api/app.py +++ b/letta/server/rest_api/app.py @@ -8,6 +8,7 @@ import uvicorn from fastapi import FastAPI from starlette.middleware.cors import CORSMiddleware +from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX from letta.server.constants import REST_DEFAULT_PORT # NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests @@ -54,11 +55,6 @@ password = None # #typer.secho(f"Generated admin server password for this session: {password}", fg=typer.colors.GREEN) -ADMIN_PREFIX = "/v1/admin" -API_PREFIX = "/v1" -OPENAI_API_PREFIX = "/openai" - - def create_application() -> "FastAPI": """the application start routine""" # global server diff --git a/letta/server/rest_api/routers/v1/organizations.py b/letta/server/rest_api/routers/v1/organizations.py index c4ac9f2c..a52d81c7 100644 --- a/letta/server/rest_api/routers/v1/organizations.py +++ b/letta/server/rest_api/routers/v1/organizations.py @@ -22,7 +22,7 @@ def get_all_orgs( Get a list of all orgs in the database """ try: - next_cursor, orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit) + orgs = server.organization_manager.list_organizations(cursor=cursor, limit=limit) except HTTPException: raise except Exception as e: @@ -38,7 +38,7 @@ def create_org( """ Create a new org in the database """ - org = server.organization_manager.create_organization(request) + org = server.organization_manager.create_organization(name=request.name) return org diff --git a/letta/server/server.py b/letta/server/server.py index 5ebd77b3..ba1b8d75 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1793,43 +1793,6 @@ class SyncServer(Server): letta_agent = self._get_or_load_agent(agent_id=agent_id) return letta_agent.update_message(request=request) - # TODO decide whether this should be done in the server.py or agent.py - # Reason to put it in agent.py: - # - we use the agent object's persistence_manager to update the message - # - it makes it easy to do things like `retry`, `rethink`, etc. - # Reason to put it in server.py: - # - fundamentally, we should be able to edit a message (without agent id) - # in the server by directly accessing the DB / message store - """ - message = letta_agent.persistence_manager.recall_memory.storage.get(id=request.id) - if message is None: - raise ValueError(f"Message with id {request.id} not found") - - # Override fields - # NOTE: we try to do some sanity checking here (see asserts), but it's not foolproof - if request.role: - message.role = request.role - if request.text: - message.text = request.text - if request.name: - message.name = request.name - if request.tool_calls: - assert message.role == MessageRole.assistant, "Tool calls can only be added to assistant messages" - message.tool_calls = request.tool_calls - if request.tool_call_id: - assert message.role == MessageRole.tool, "tool_call_id can only be added to tool messages" - message.tool_call_id = request.tool_call_id - - # Save the updated message - letta_agent.persistence_manager.recall_memory.storage.update(record=message) - - # Return the updated message - updated_message = letta_agent.persistence_manager.recall_memory.storage.get(id=message.id) - if updated_message is None: - raise ValueError(f"Error persisting message - message with id {request.id} not found") - return updated_message - """ - def rewrite_agent_message(self, agent_id: str, new_text: str) -> Message: # Get the current message diff --git a/tests/test_client.py b/tests/test_client.py index 3fd015c5..807cd7ab 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -229,6 +229,12 @@ def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: A elif chunk == MessageStreamStatus.done_generation: assert not done_gen, "Message stream already done generation" done_gen = True + if isinstance(chunk, LettaUsageStatistics): + # Some rough metrics for a reasonable usage pattern + assert chunk.step_count == 1 + assert chunk.completion_tokens > 10 + assert chunk.prompt_tokens > 1000 + assert chunk.total_tokens > 1000 assert inner_thoughts_exist, "No inner thoughts found" assert send_message_ran, "send_message function call not found" @@ -488,6 +494,21 @@ def test_organization(client: RESTClient): if isinstance(client, LocalClient): pytest.skip("Skipping test_organization because LocalClient does not support organizations") + # create an organization + org_name = "test-org" + org = client.create_org(org_name) + + # assert the id appears + orgs = client.list_orgs() + assert org.id in [o.id for o in orgs] + + org = client.delete_org(org.id) + assert org.name == org_name + + # assert the id is gone + orgs = client.list_orgs() + assert not (org.id in [o.id for o in orgs]) + def test_list_llm_models(client: RESTClient): """Test that if the user's env has the right api keys set, at least one model appears in the model list"""