From f0db8598434fe706c84171c6067d9483f1c4073e Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Fri, 21 Feb 2025 16:58:12 -0800 Subject: [PATCH 1/8] fix: Don't refresh Composio schemas (#1099) --- letta/schemas/tool.py | 14 +------------- 1 file changed, 1 insertion(+), 13 deletions(-) diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index f17498c1..209664b0 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -9,7 +9,7 @@ from letta.constants import ( LETTA_MULTI_AGENT_TOOL_MODULE_NAME, ) from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module -from letta.functions.helpers import generate_composio_action_from_func_name, generate_composio_tool_wrapper, generate_langchain_tool_wrapper +from letta.functions.helpers import generate_composio_tool_wrapper, generate_langchain_tool_wrapper from letta.functions.schema_generator import generate_schema_from_args_schema_v2, generate_tool_schema_for_composio from letta.log import get_logger from letta.orm.enums import ToolType @@ -77,18 +77,6 @@ class Tool(BaseTool): elif self.tool_type in {ToolType.LETTA_MULTI_AGENT_CORE}: # If it's letta multi-agent tool, we also generate the json_schema on the fly here self.json_schema = get_json_schema_from_module(module_name=LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name=self.name) - elif self.tool_type == ToolType.EXTERNAL_COMPOSIO: - # If it is a composio tool, we generate both the source code and json schema on the fly here - # TODO: Deriving the composio action name is brittle, need to think long term about how to improve this - try: - composio_action = generate_composio_action_from_func_name(self.name) - tool_create = ToolCreate.from_composio(composio_action) - self.source_code = tool_create.source_code - self.json_schema = tool_create.json_schema - self.description = tool_create.description - self.tags = tool_create.tags - except Exception as e: - logger.error(f"Encountered exception while attempting to refresh source_code and json_schema for composio_tool: {e}") # At this point, we need to validate that at least json_schema is populated if not self.json_schema: From b132903a217ad974809ecc13c0dab05ac0617bff Mon Sep 17 00:00:00 2001 From: cthomas Date: Mon, 24 Feb 2025 11:20:22 -0800 Subject: [PATCH 2/8] fix: agent deletion bug from source cascade (#1101) --- letta/orm/source.py | 4 ++-- tests/test_managers.py | 24 ++++++++++++++++++++++++ 2 files changed, 26 insertions(+), 2 deletions(-) diff --git a/letta/orm/source.py b/letta/orm/source.py index 055f140e..b71e3989 100644 --- a/letta/orm/source.py +++ b/letta/orm/source.py @@ -42,6 +42,6 @@ class Source(SqlalchemyBase, OrganizationMixin): secondary="sources_agents", back_populates="sources", lazy="selectin", - cascade="all, delete", # Ensures rows in sources_agents are deleted when the source is deleted - passive_deletes=True, # Allows the database to handle deletion of orphaned rows + cascade="save-update", # Only propagate save and update operations + passive_deletes=True, # Let the database handle deletions ) diff --git a/tests/test_managers.py b/tests/test_managers.py index 8d078acf..4d207d8a 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2150,6 +2150,30 @@ def test_delete_source(server: SyncServer, default_user): assert len(sources) == 0 +def test_delete_attached_source(server: SyncServer, sarah_agent, default_user): + """Test deleting a source.""" + source_pydantic = PydanticSource( + name="To Delete", description="This source will be deleted.", embedding_config=DEFAULT_EMBEDDING_CONFIG + ) + source = server.source_manager.create_source(source=source_pydantic, actor=default_user) + + server.agent_manager.attach_source(agent_id=sarah_agent.id, source_id=source.id, actor=default_user) + + # Delete the source + deleted_source = server.source_manager.delete_source(source_id=source.id, actor=default_user) + + # Assertions to verify deletion + assert deleted_source.id == source.id + + # Verify that the source no longer appears in list_sources + sources = server.source_manager.list_sources(actor=default_user) + assert len(sources) == 0 + + # Verify that agent is not deleted + agent = server.agent_manager.get_agent_by_id(sarah_agent.id, actor=default_user) + assert agent is not None + + def test_list_sources(server: SyncServer, default_user): """Test listing sources with pagination.""" # Create multiple sources From 0aee823e68fba2429798a3b7b8c4eb7264a84f86 Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Mon, 24 Feb 2025 18:15:31 -0800 Subject: [PATCH 3/8] feat: add in `claude-3-7-sonnet-20250219` (#1102) Co-authored-by: Charles Packer --- letta/llm_api/anthropic.py | 38 +++++++++++++++++++++++++++++- letta/schemas/providers.py | 47 ++++++++++++++++++++++++++++++++++---- 2 files changed, 80 insertions(+), 5 deletions(-) diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index b8151e7a..205f4cb7 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -47,14 +47,39 @@ BASE_URL = "https://api.anthropic.com/v1" # https://docs.anthropic.com/claude/docs/models-overview # Sadly hardcoded MODEL_LIST = [ + ## Opus { "name": "claude-3-opus-20240229", "context_window": 200000, }, + ## Sonnet + # 3.0 + { + "name": "claude-3-sonnet-20240229", + "context_window": 200000, + }, + # 3.5 + { + "name": "claude-3-5-sonnet-20240620", + "context_window": 200000, + }, + # 3.5 new { "name": "claude-3-5-sonnet-20241022", "context_window": 200000, }, + # 3.7 + { + "name": "claude-3-7-sonnet-20250219", + "context_window": 200000, + }, + ## Haiku + # 3.0 + { + "name": "claude-3-haiku-20240307", + "context_window": 200000, + }, + # 3.5 { "name": "claude-3-5-haiku-20241022", "context_window": 200000, @@ -75,7 +100,18 @@ def anthropic_get_model_list(url: str, api_key: Union[str, None]) -> dict: """https://docs.anthropic.com/claude/docs/models-overview""" # NOTE: currently there is no GET /models, so we need to hardcode - return MODEL_LIST + # return MODEL_LIST + + anthropic_override_key = ProviderManager().get_anthropic_override_key() + if anthropic_override_key: + anthropic_client = anthropic.Anthropic(api_key=anthropic_override_key) + elif model_settings.anthropic_api_key: + anthropic_client = anthropic.Anthropic() + + models = anthropic_client.models.list() + models_json = models.model_dump() + assert "data" in models_json, f"Anthropic model query response missing 'data' field: {models_json}" + return models_json["data"] def convert_tools_to_anthropic_format(tools: List[Tool]) -> List[dict]: diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index bc11909d..1de05b0e 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -410,28 +410,67 @@ class AnthropicProvider(Provider): base_url: str = "https://api.anthropic.com/v1" def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.anthropic import anthropic_get_model_list + from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list models = anthropic_get_model_list(self.base_url, api_key=self.api_key) + """ + Example response: + { + "data": [ + { + "type": "model", + "id": "claude-3-5-sonnet-20241022", + "display_name": "Claude 3.5 Sonnet (New)", + "created_at": "2024-10-22T00:00:00Z" + } + ], + "has_more": true, + "first_id": "", + "last_id": "" + } + """ + configs = [] for model in models: + if model["type"] != "model": + continue + + if "id" not in model: + continue + + # Don't support 2.0 and 2.1 + if model["id"].startswith("claude-2"): + continue + + # Anthropic doesn't return the context window in their API + if "context_window" not in model: + # Remap list to name: context_window + model_library = {m["name"]: m["context_window"] for m in MODEL_LIST} + # Attempt to look it up in a hardcoded list + if model["id"] in model_library: + model["context_window"] = model_library[model["id"]] + else: + # On fallback, we can set 200k (generally safe), but we should warn the user + warnings.warn(f"Couldn't find context window size for model {model['id']}, defaulting to 200,000") + model["context_window"] = 200000 + # We set this to false by default, because Anthropic can # natively support tags inside of content fields # However, putting COT inside of tool calls can make it more # reliable for tool calling (no chance of a non-tool call step) # Since tool_choice_type 'any' doesn't work with in-content COT # NOTE For Haiku, it can be flaky if we don't enable this by default - inner_thoughts_in_kwargs = True if "haiku" in model["name"] else False + inner_thoughts_in_kwargs = True if "haiku" in model["id"] else False configs.append( LLMConfig( - model=model["name"], + model=model["id"], model_endpoint_type="anthropic", model_endpoint=self.base_url, context_window=model["context_window"], - handle=self.get_handle(model["name"]), + handle=self.get_handle(model["id"]), put_inner_thoughts_in_kwargs=inner_thoughts_in_kwargs, ) ) From 4548555d6c4fffc20cd8af55cc1607ce598dbef6 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 25 Feb 2025 11:35:19 -0800 Subject: [PATCH 4/8] chore: change the name of user id to actor (#1098) --- letta/server/rest_api/routers/v1/agents.py | 110 +++++++++--------- letta/server/rest_api/routers/v1/blocks.py | 24 ++-- .../server/rest_api/routers/v1/identities.py | 24 ++-- letta/server/rest_api/routers/v1/jobs.py | 16 +-- letta/server/rest_api/routers/v1/providers.py | 16 ++- letta/server/rest_api/routers/v1/runs.py | 24 ++-- .../rest_api/routers/v1/sandbox_configs.py | 48 ++++---- letta/server/rest_api/routers/v1/sources.py | 40 +++---- letta/server/rest_api/routers/v1/steps.py | 15 +-- letta/server/rest_api/routers/v1/tags.py | 4 +- letta/server/rest_api/routers/v1/tools.py | 40 +++---- letta/services/provider_manager.py | 19 +-- letta/services/step_manager.py | 4 +- tests/test_server.py | 6 +- 14 files changed, 199 insertions(+), 191 deletions(-) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 4c9d93f6..70664d6b 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -44,7 +44,7 @@ def list_agents( description="If True, only returns agents that match ALL given tags. Otherwise, return agents that have ANY of the passed in tags.", ), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), before: Optional[str] = Query(None, description="Cursor for pagination"), after: Optional[str] = Query(None, description="Cursor for pagination"), limit: Optional[int] = Query(None, description="Limit for pagination"), @@ -58,7 +58,7 @@ def list_agents( List all agents associated with a given user. This endpoint retrieves a list of all agents and their configurations associated with the specified user ID. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) # Use dictionary comprehension to build kwargs dynamically kwargs = { @@ -91,12 +91,12 @@ def list_agents( def retrieve_agent_context_window( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the context window of a specific agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.get_agent_context_window(agent_id=agent_id, actor=actor) @@ -107,21 +107,21 @@ class CreateAgentRequest(CreateAgent): """ # Override the user_id field to exclude it from the request body validation - user_id: Optional[str] = Field(None, exclude=True) + actor_id: Optional[str] = Field(None, exclude=True) @router.post("/", response_model=AgentState, operation_id="create_agent") def create_agent( agent: CreateAgentRequest = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware ): """ Create a new agent with the specified configuration. """ try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.create_agent(agent, actor=actor) except Exception as e: traceback.print_exc() @@ -133,10 +133,10 @@ def modify_agent( agent_id: str, update_agent: UpdateAgent = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Update an existing agent""" - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.update_agent(agent_id=agent_id, agent_update=update_agent, actor=actor) @@ -144,10 +144,10 @@ def modify_agent( def list_agent_tools( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Get tools from an existing agent""" - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.list_attached_tools(agent_id=agent_id, actor=actor) @@ -156,12 +156,12 @@ def attach_tool( agent_id: str, tool_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Attach a tool to an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) @@ -170,12 +170,12 @@ def detach_tool( agent_id: str, tool_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Detach a tool from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor) @@ -184,12 +184,12 @@ def attach_source( agent_id: str, source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Attach a source to an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.attach_source(agent_id=agent_id, source_id=source_id, actor=actor) @@ -198,12 +198,12 @@ def detach_source( agent_id: str, source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Detach a source from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.detach_source(agent_id=agent_id, source_id=source_id, actor=actor) @@ -211,12 +211,12 @@ def detach_source( def retrieve_agent( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get the state of the agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) @@ -228,12 +228,12 @@ def retrieve_agent( def delete_agent( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: server.agent_manager.delete_agent(agent_id=agent_id, actor=actor) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Agent id={agent_id} successfully deleted"}) @@ -245,12 +245,12 @@ def delete_agent( def list_agent_sources( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get the sources associated with an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.list_attached_sources(agent_id=agent_id, actor=actor) @@ -259,13 +259,13 @@ def list_agent_sources( def retrieve_agent_memory( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the memory state of a specific agent. This endpoint fetches the current memory state of the agent identified by the user ID and agent ID. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.get_agent_memory(agent_id=agent_id, actor=actor) @@ -275,12 +275,12 @@ def retrieve_core_memory_block( agent_id: str, block_label: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve a memory block from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: return server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) @@ -292,12 +292,12 @@ def retrieve_core_memory_block( def list_core_memory_blocks( agent_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the memory blocks of a specific agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: agent = server.agent_manager.get_agent_by_id(agent_id, actor=actor) return agent.memory.blocks @@ -311,12 +311,12 @@ def modify_core_memory_block( block_label: str, block_update: BlockUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Updates a memory block of an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) block = server.agent_manager.get_block_with_label(agent_id=agent_id, block_label=block_label, actor=actor) block = server.block_manager.update_block(block.id, block_update=block_update, actor=actor) @@ -332,12 +332,12 @@ def attach_core_memory_block( agent_id: str, block_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Attach a block to an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.attach_block(agent_id=agent_id, block_id=block_id, actor=actor) @@ -346,12 +346,12 @@ def detach_core_memory_block( agent_id: str, block_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Detach a block from an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.detach_block(agent_id=agent_id, block_id=block_id, actor=actor) @@ -362,12 +362,12 @@ def list_archival_memory( after: Optional[int] = Query(None, description="Unique ID of the memory to start the query range at."), before: Optional[int] = Query(None, description="Unique ID of the memory to end the query range at."), limit: Optional[int] = Query(None, description="How many results to include in the response."), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve the memories in an agent's archival memory store (paginated query). """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.get_agent_archival( user_id=actor.id, @@ -383,12 +383,12 @@ def create_archival_memory( agent_id: str, request: CreateArchivalMemory = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Insert a memory into an agent's archival memory store. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.insert_archival_memory(agent_id=agent_id, memory_contents=request.text, actor=actor) @@ -401,12 +401,12 @@ def delete_archival_memory( memory_id: str, # memory_id: str = Query(..., description="Unique ID of the memory to be deleted."), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a memory from an agent's archival memory store. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) server.delete_archival_memory(memory_id=memory_id, actor=actor) return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Memory id={memory_id} successfully deleted"}) @@ -427,12 +427,12 @@ def list_messages( use_assistant_message: bool = Query(True, description="Whether to use assistant messages"), assistant_message_tool_name: str = Query(DEFAULT_MESSAGE_TOOL, description="The name of the designated message tool."), assistant_message_tool_kwarg: str = Query(DEFAULT_MESSAGE_TOOL_KWARG, description="The name of the message argument."), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Retrieve message history for an agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.get_agent_recall( user_id=actor.id, @@ -454,13 +454,13 @@ def modify_message( message_id: str, request: MessageUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update the details of a message associated with an agent. """ # TODO: Get rid of agent_id here, it's not really relevant - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor) @@ -474,13 +474,13 @@ async def send_message( agent_id: str, server: SyncServer = Depends(get_letta_server), request: LettaRequest = Body(...), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) result = await server.send_message_to_agent( agent_id=agent_id, actor=actor, @@ -513,14 +513,14 @@ async def send_message_streaming( agent_id: str, server: SyncServer = Depends(get_letta_server), request: LettaStreamingRequest = Body(...), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) result = await server.send_message_to_agent( agent_id=agent_id, actor=actor, @@ -590,13 +590,13 @@ async def send_message_async( background_tasks: BackgroundTasks, server: SyncServer = Depends(get_letta_server), request: LettaRequest = Body(...), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Asynchronously process a user message and return a run object. The actual processing happens in the background, and the status can be checked using the run ID. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) # Create a new job run = Run( @@ -635,8 +635,8 @@ def reset_messages( agent_id: str, add_default_initial_messages: bool = Query(default=False, description="If true, adds the default initial messages after resetting."), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Resets the messages for an agent""" - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.agent_manager.reset_messages(agent_id=agent_id, actor=actor, add_default_initial_messages=add_default_initial_messages) diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 8c5297d0..322f323d 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -21,9 +21,9 @@ def list_blocks( templates_only: bool = Query(True, description="Whether to include only templates"), name: Optional[str] = Query(None, description="Name of the block"), server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.block_manager.get_blocks(actor=actor, label=label, is_template=templates_only, template_name=name) @@ -31,9 +31,9 @@ def list_blocks( def create_block( create_block: CreateBlock = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) block = Block(**create_block.model_dump()) return server.block_manager.create_or_update_block(actor=actor, block=block) @@ -43,9 +43,9 @@ def modify_block( block_id: str, block_update: BlockUpdate = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.block_manager.update_block(block_id=block_id, block_update=block_update, actor=actor) @@ -53,9 +53,9 @@ def modify_block( def delete_block( block_id: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.block_manager.delete_block(block_id=block_id, actor=actor) @@ -63,10 +63,10 @@ def delete_block( def retrieve_block( block_id: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): print("call get block", block_id) - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: block = server.block_manager.get_block_by_id(block_id=block_id, actor=actor) if block is None: @@ -80,13 +80,13 @@ def retrieve_block( def list_agents_for_block( block_id: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Retrieves all agents associated with the specified block. Raises a 404 if the block does not exist. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor) return agents diff --git a/letta/server/rest_api/routers/v1/identities.py b/letta/server/rest_api/routers/v1/identities.py index a4311b12..7b4156a9 100644 --- a/letta/server/rest_api/routers/v1/identities.py +++ b/letta/server/rest_api/routers/v1/identities.py @@ -22,13 +22,13 @@ def list_identities( after: Optional[str] = Query(None), limit: Optional[int] = Query(50), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a list of all identities in the database """ try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) identities = server.identity_manager.list_identities( name=name, @@ -51,10 +51,10 @@ def list_identities( def retrieve_identity( identity_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.identity_manager.get_identity(identity_id=identity_id, actor=actor) except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @@ -64,11 +64,11 @@ def retrieve_identity( def create_identity( identity: IdentityCreate = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware ): try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.identity_manager.create_identity(identity=identity, actor=actor) except HTTPException: raise @@ -80,11 +80,11 @@ def create_identity( def upsert_identity( identity: IdentityCreate = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present x_project: Optional[str] = Header(None, alias="X-Project"), # Only handled by next js middleware ): try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.identity_manager.upsert_identity(identity=identity, actor=actor) except HTTPException: raise @@ -97,10 +97,10 @@ def modify_identity( identity_id: str, identity: IdentityUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.identity_manager.update_identity(identity_id=identity_id, identity=identity, actor=actor) except HTTPException: raise @@ -112,10 +112,10 @@ def modify_identity( def delete_identity( identity_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete an identity by its identifier key """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) server.identity_manager.delete_identity(identity_id=identity_id, actor=actor) diff --git a/letta/server/rest_api/routers/v1/jobs.py b/letta/server/rest_api/routers/v1/jobs.py index 4e41490b..8adbdd2d 100644 --- a/letta/server/rest_api/routers/v1/jobs.py +++ b/letta/server/rest_api/routers/v1/jobs.py @@ -15,12 +15,12 @@ router = APIRouter(prefix="/jobs", tags=["jobs"]) def list_jobs( server: "SyncServer" = Depends(get_letta_server), source_id: Optional[str] = Query(None, description="Only list jobs associated with the source."), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all jobs. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) # TODO: add filtering by status jobs = server.job_manager.list_jobs(actor=actor) @@ -35,12 +35,12 @@ def list_jobs( @router.get("/active", response_model=List[Job], operation_id="list_active_jobs") def list_active_jobs( server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all active jobs. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running]) @@ -48,13 +48,13 @@ def list_active_jobs( @router.get("/{job_id}", response_model=Job, operation_id="retrieve_job") def retrieve_job( job_id: str, - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), ): """ Get the status of a job. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: return server.job_manager.get_job_by_id(job_id=job_id, actor=actor) @@ -65,13 +65,13 @@ def retrieve_job( @router.delete("/{job_id}", response_model=Job, operation_id="delete_job") def delete_job( job_id: str, - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), ): """ Delete a job by its job_id. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: job = server.job_manager.delete_job_by_id(job_id=job_id, actor=actor) diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 7feb1674..c26101d5 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -15,13 +15,15 @@ router = APIRouter(prefix="/providers", tags=["providers"]) def list_providers( after: Optional[str] = Query(None), limit: Optional[int] = Query(50), + actor_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), ): """ Get a list of all custom providers in the database """ try: - providers = server.provider_manager.list_providers(after=after, limit=limit) + actor = server.user_manager.get_user_or_default(user_id=actor_id) + providers = server.provider_manager.list_providers(after=after, limit=limit, actor=actor) except HTTPException: raise except Exception as e: @@ -32,13 +34,13 @@ def list_providers( @router.post("/", tags=["providers"], response_model=Provider, operation_id="create_provider") def create_provider( request: ProviderCreate = Body(...), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Create a new custom provider """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) provider = Provider(**request.model_dump()) provider = server.provider_manager.create_provider(provider, actor=actor) @@ -48,25 +50,29 @@ def create_provider( @router.patch("/", tags=["providers"], response_model=Provider, operation_id="modify_provider") def modify_provider( request: ProviderUpdate = Body(...), + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present server: "SyncServer" = Depends(get_letta_server), ): """ Update an existing custom provider """ - provider = server.provider_manager.update_provider(request) + actor = server.user_manager.get_user_or_default(user_id=actor_id) + provider = server.provider_manager.update_provider(request, actor=actor) return provider @router.delete("/", tags=["providers"], response_model=None, operation_id="delete_provider") def delete_provider( provider_id: str = Query(..., description="The provider_id key to be deleted."), + actor_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), ): """ Delete an existing custom provider """ try: - server.provider_manager.delete_provider_by_id(provider_id=provider_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) + server.provider_manager.delete_provider_by_id(provider_id=provider_id, actor=actor) except HTTPException: raise except Exception as e: diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index d0abd3c3..0e5dff98 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -18,12 +18,12 @@ router = APIRouter(prefix="/runs", tags=["runs"]) @router.get("/", response_model=List[Run], operation_id="list_runs") def list_runs( server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all runs. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return [Run.from_job(job) for job in server.job_manager.list_jobs(actor=actor, job_type=JobType.RUN)] @@ -31,12 +31,12 @@ def list_runs( @router.get("/active", response_model=List[Run], operation_id="list_active_runs") def list_active_runs( server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all active runs. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) active_runs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.RUN) @@ -46,13 +46,13 @@ def list_active_runs( @router.get("/{run_id}", response_model=Run, operation_id="retrieve_run") def retrieve_run( run_id: str, - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), ): """ Get the status of a run. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: job = server.job_manager.get_job_by_id(job_id=run_id, actor=actor) @@ -74,7 +74,7 @@ RunMessagesResponse = Annotated[ async def list_run_messages( run_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), before: Optional[str] = Query(None, description="Cursor for pagination"), after: Optional[str] = Query(None, description="Cursor for pagination"), limit: Optional[int] = Query(100, description="Maximum number of messages to return"), @@ -102,7 +102,7 @@ async def list_run_messages( if order not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'") - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: messages = server.job_manager.get_run_messages( @@ -122,13 +122,13 @@ async def list_run_messages( @router.get("/{run_id}/usage", response_model=UsageStatistics, operation_id="retrieve_run_usage") def retrieve_run_usage( run_id: str, - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), ): """ Get usage statistics for a run. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: usage = server.job_manager.get_job_usage(job_id=run_id, actor=actor) @@ -140,13 +140,13 @@ def retrieve_run_usage( @router.delete("/{run_id}", response_model=Run, operation_id="delete_run") def delete_run( run_id: str, - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), server: "SyncServer" = Depends(get_letta_server), ): """ Delete a run by its run_id. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: job = server.job_manager.delete_job_by_id(job_id=run_id, actor=actor) diff --git a/letta/server/rest_api/routers/v1/sandbox_configs.py b/letta/server/rest_api/routers/v1/sandbox_configs.py index e32acbe0..6ef76a5b 100644 --- a/letta/server/rest_api/routers/v1/sandbox_configs.py +++ b/letta/server/rest_api/routers/v1/sandbox_configs.py @@ -25,9 +25,9 @@ logger = get_logger(__name__) def create_sandbox_config( config_create: SandboxConfigCreate, server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.create_or_update_sandbox_config(config_create, actor) @@ -35,18 +35,18 @@ def create_sandbox_config( @router.post("/e2b/default", response_model=PydanticSandboxConfig) def create_default_e2b_sandbox_config( server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.E2B, actor=actor) @router.post("/local/default", response_model=PydanticSandboxConfig) def create_default_local_sandbox_config( server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor) @@ -54,7 +54,7 @@ def create_default_local_sandbox_config( def create_custom_local_sandbox_config( local_sandbox_config: LocalSandboxConfig, server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): """ Create or update a custom LocalSandboxConfig, including pip_requirements. @@ -67,7 +67,7 @@ def create_custom_local_sandbox_config( ) # Retrieve the user (actor) - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) # Wrap the LocalSandboxConfig into a SandboxConfigCreate sandbox_config_create = SandboxConfigCreate(config=local_sandbox_config) @@ -83,9 +83,9 @@ def update_sandbox_config( sandbox_config_id: str, config_update: SandboxConfigUpdate, server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.update_sandbox_config(sandbox_config_id, config_update, actor) @@ -93,9 +93,9 @@ def update_sandbox_config( def delete_sandbox_config( sandbox_config_id: str, server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) server.sandbox_config_manager.delete_sandbox_config(sandbox_config_id, actor) @@ -105,22 +105,22 @@ def list_sandbox_configs( after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), sandbox_type: Optional[SandboxType] = Query(None, description="Filter for this specific sandbox type"), server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.list_sandbox_configs(actor, limit=limit, after=after, sandbox_type=sandbox_type) @router.post("/local/recreate-venv", response_model=PydanticSandboxConfig) def force_recreate_local_sandbox_venv( server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): """ Forcefully recreate the virtual environment for the local sandbox. Deletes and recreates the venv, then reinstalls required dependencies. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) # Retrieve the local sandbox config sbx_config = server.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=SandboxType.LOCAL, actor=actor) @@ -162,9 +162,9 @@ def create_sandbox_env_var( sandbox_config_id: str, env_var_create: SandboxEnvironmentVariableCreate, server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.create_sandbox_env_var(env_var_create, sandbox_config_id, actor) @@ -173,9 +173,9 @@ def update_sandbox_env_var( env_var_id: str, env_var_update: SandboxEnvironmentVariableUpdate, server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.update_sandbox_env_var(env_var_id, env_var_update, actor) @@ -183,9 +183,9 @@ def update_sandbox_env_var( def delete_sandbox_env_var( env_var_id: str, server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) server.sandbox_config_manager.delete_sandbox_env_var(env_var_id, actor) @@ -195,7 +195,7 @@ def list_sandbox_env_vars( limit: int = Query(1000, description="Number of results to return"), after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), server: SyncServer = Depends(get_letta_server), - user_id: str = Depends(get_user_id), + actor_id: str = Depends(get_user_id), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.sandbox_config_manager.list_sandbox_env_vars(sandbox_config_id, actor, limit=limit, after=after) diff --git a/letta/server/rest_api/routers/v1/sources.py b/letta/server/rest_api/routers/v1/sources.py index 0be1b8c7..08051d63 100644 --- a/letta/server/rest_api/routers/v1/sources.py +++ b/letta/server/rest_api/routers/v1/sources.py @@ -23,12 +23,12 @@ router = APIRouter(prefix="/sources", tags=["sources"]) def retrieve_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get all sources """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) if not source: @@ -40,12 +40,12 @@ def retrieve_source( def get_source_id_by_name( source_name: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a source by name """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) source = server.source_manager.get_source_by_name(source_name=source_name, actor=actor) if not source: @@ -56,12 +56,12 @@ def get_source_id_by_name( @router.get("/", response_model=List[Source], operation_id="list_sources") def list_sources( server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all data sources created by a user. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.list_all_sources(actor=actor) @@ -70,12 +70,12 @@ def list_sources( def create_source( source_create: SourceCreate, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Create a new data source. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) source = Source(**source_create.model_dump()) return server.source_manager.create_source(source=source, actor=actor) @@ -86,12 +86,12 @@ def modify_source( source_id: str, source: SourceUpdate, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update the name or documentation of an existing data source. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) if not server.source_manager.get_source_by_id(source_id=source_id, actor=actor): raise HTTPException(status_code=404, detail=f"Source with id={source_id} does not exist.") return server.source_manager.update_source(source_id=source_id, source_update=source, actor=actor) @@ -101,12 +101,12 @@ def modify_source( def delete_source( source_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a data source. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) server.delete_source(source_id=source_id, actor=actor) @@ -117,12 +117,12 @@ def upload_file_to_source( source_id: str, background_tasks: BackgroundTasks, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Upload a file to a data source. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) source = server.source_manager.get_source_by_id(source_id=source_id, actor=actor) assert source is not None, f"Source with id={source_id} not found." @@ -151,12 +151,12 @@ def upload_file_to_source( def list_source_passages( source_id: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List all passages associated with a data source. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) passages = server.list_data_source_passages(user_id=actor.id, source_id=source_id) return passages @@ -167,12 +167,12 @@ def list_source_files( limit: int = Query(1000, description="Number of files to return"), after: Optional[str] = Query(None, description="Pagination cursor to fetch the next set of results"), server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ List paginated files associated with a data source. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.source_manager.list_files(source_id=source_id, limit=limit, after=after, actor=actor) @@ -183,12 +183,12 @@ def delete_file_from_source( source_id: str, file_id: str, server: "SyncServer" = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a data source. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) deleted_file = server.source_manager.delete_file(file_id=file_id, actor=actor) if deleted_file is None: diff --git a/letta/server/rest_api/routers/v1/steps.py b/letta/server/rest_api/routers/v1/steps.py index cb82bf59..7c67de9c 100644 --- a/letta/server/rest_api/routers/v1/steps.py +++ b/letta/server/rest_api/routers/v1/steps.py @@ -21,13 +21,13 @@ def list_steps( end_date: Optional[str] = Query(None, description='Return steps before this ISO datetime (e.g. "2025-01-29T15:01:19-08:00")'), model: Optional[str] = Query(None, description="Filter by the name of the model used for the step"), server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ List steps with optional pagination and date filters. Dates should be provided in ISO 8601 format (e.g. 2025-01-29T15:01:19-08:00) """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) # Convert ISO strings to datetime objects if provided start_dt = datetime.fromisoformat(start_date) if start_date else None @@ -48,14 +48,15 @@ def list_steps( @router.get("/{step_id}", response_model=Step, operation_id="retrieve_step") def retrieve_step( step_id: str, - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), server: SyncServer = Depends(get_letta_server), ): """ Get a step by ID. """ try: - return server.step_manager.get_step(step_id=step_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) + return server.step_manager.get_step(step_id=step_id, actor=actor) except NoResultFound: raise HTTPException(status_code=404, detail="Step not found") @@ -64,15 +65,15 @@ def retrieve_step( def update_step_transaction_id( step_id: str, transaction_id: str, - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), server: SyncServer = Depends(get_letta_server), ): """ Update the transaction ID for a step. """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - return server.step_manager.update_step_transaction_id(actor, step_id=step_id, transaction_id=transaction_id) + return server.step_manager.update_step_transaction_id(actor=actor, step_id=step_id, transaction_id=transaction_id) except NoResultFound: raise HTTPException(status_code=404, detail="Step not found") diff --git a/letta/server/rest_api/routers/v1/tags.py b/letta/server/rest_api/routers/v1/tags.py index fd889af6..dab01771 100644 --- a/letta/server/rest_api/routers/v1/tags.py +++ b/letta/server/rest_api/routers/v1/tags.py @@ -17,11 +17,11 @@ def list_tags( limit: Optional[int] = Query(50), server: "SyncServer" = Depends(get_letta_server), query_text: Optional[str] = Query(None), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Get a list of all tags in the database """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) tags = server.agent_manager.list_tags(actor=actor, after=after, limit=limit, query_text=query_text) return tags diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index cf8254bf..f198592b 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -29,12 +29,12 @@ logger = get_logger(__name__) def delete_tool( tool_id: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Delete a tool by name """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) server.tool_manager.delete_tool_by_id(tool_id=tool_id, actor=actor) @@ -42,12 +42,12 @@ def delete_tool( def retrieve_tool( tool_id: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a tool by ID """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) tool = server.tool_manager.get_tool_by_id(tool_id=tool_id, actor=actor) if tool is None: # return 404 error @@ -61,13 +61,13 @@ def list_tools( limit: Optional[int] = 50, name: Optional[str] = None, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Get a list of all tools available to agents belonging to the org of the user """ try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) if name is not None: tool = server.tool_manager.get_tool_by_name(tool_name=name, actor=actor) return [tool] if tool else [] @@ -82,13 +82,13 @@ def list_tools( def create_tool( request: ToolCreate = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Create a new tool """ try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) tool = Tool(**request.model_dump()) return server.tool_manager.create_tool(pydantic_tool=tool, actor=actor) except UniqueConstraintViolationError as e: @@ -114,13 +114,13 @@ def create_tool( def upsert_tool( request: ToolCreate = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Create or update a tool """ try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) tool = server.tool_manager.create_or_update_tool(pydantic_tool=Tool(**request.model_dump()), actor=actor) return tool except UniqueConstraintViolationError as e: @@ -142,13 +142,13 @@ def modify_tool( tool_id: str, request: ToolUpdate = Body(...), server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Update an existing tool """ try: - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.tool_manager.update_tool_by_id(tool_id=tool_id, tool_update=request, actor=actor) except LettaToolCreateError as e: # HTTP 400 == Bad Request @@ -163,12 +163,12 @@ def modify_tool( @router.post("/add-base-tools", response_model=List[Tool], operation_id="add_base_tools") def upsert_base_tools( server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Upsert base tools """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) return server.tool_manager.upsert_base_tools(actor=actor) @@ -176,12 +176,12 @@ def upsert_base_tools( def run_tool_from_source( server: SyncServer = Depends(get_letta_server), request: ToolRunFromSource = Body(...), - user_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present + actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """ Attempt to build a tool from source, then run it on the provided arguments """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: return server.run_tool_from_source( @@ -227,12 +227,12 @@ def list_composio_apps(server: SyncServer = Depends(get_letta_server), user_id: def list_composio_actions_by_app( composio_app_name: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Get a list of all Composio actions for a specific app """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) composio_api_key = get_composio_api_key(actor=actor, logger=logger) if not composio_api_key: raise HTTPException( @@ -246,12 +246,12 @@ def list_composio_actions_by_app( def add_composio_tool( composio_action_name: str, server: SyncServer = Depends(get_letta_server), - user_id: Optional[str] = Header(None, alias="user_id"), + actor_id: Optional[str] = Header(None, alias="user_id"), ): """ Add a new Composio tool by action name (Composio refers to each tool as an `Action`) """ - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = server.user_manager.get_user_or_default(user_id=actor_id) try: tool_create = ToolCreate.from_composio(action_name=composio_action_name) diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 20f7c2ad..c81a6234 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -25,15 +25,15 @@ class ProviderManager: provider.resolve_identifier() new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True)) - new_provider.create(session) + new_provider.create(session, actor=actor) return new_provider.to_pydantic() @enforce_types - def update_provider(self, provider_update: ProviderUpdate) -> PydanticProvider: + def update_provider(self, provider_update: ProviderUpdate, actor: PydanticUser) -> PydanticProvider: """Update provider details.""" with self.session_maker() as session: # Retrieve the existing provider by ID - existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id) + existing_provider = ProviderModel.read(db_session=session, identifier=provider_update.id, actor=actor) # Update only the fields that are provided in ProviderUpdate update_data = provider_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) @@ -41,31 +41,32 @@ class ProviderManager: setattr(existing_provider, key, value) # Commit the updated provider - existing_provider.update(session) + existing_provider.update(session, actor=actor) return existing_provider.to_pydantic() @enforce_types - def delete_provider_by_id(self, provider_id: str): + def delete_provider_by_id(self, provider_id: str, actor: PydanticUser): """Delete a provider.""" with self.session_maker() as session: # Clear api key field - existing_provider = ProviderModel.read(db_session=session, identifier=provider_id) + existing_provider = ProviderModel.read(db_session=session, identifier=provider_id, actor=actor) existing_provider.api_key = None - existing_provider.update(session) + existing_provider.update(session, actor=actor) # Soft delete in provider table - existing_provider.delete(session) + existing_provider.delete(session, actor=actor) session.commit() @enforce_types - def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticProvider]: + def list_providers(self, after: Optional[str] = None, limit: Optional[int] = 50, actor: PydanticUser = None) -> List[PydanticProvider]: """List all providers with optional pagination.""" with self.session_maker() as session: providers = ProviderModel.list( db_session=session, after=after, limit=limit, + actor=actor, ) return [provider.to_pydantic() for provider in providers] diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 612c8bf2..278dd292 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -84,9 +84,9 @@ class StepManager: return new_step.to_pydantic() @enforce_types - def get_step(self, step_id: str) -> PydanticStep: + def get_step(self, step_id: str, actor: PydanticUser) -> PydanticStep: with self.session_maker() as session: - step = StepModel.read(db_session=session, identifier=step_id) + step = StepModel.read(db_session=session, identifier=step_id, actor=actor) return step.to_pydantic() @enforce_types diff --git a/tests/test_server.py b/tests/test_server.py index ed5a33f5..e5972724 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1194,7 +1194,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): step_ids = set([msg.step_id for msg in get_messages_response]) completion_tokens, prompt_tokens, total_tokens = 0, 0, 0 for step_id in step_ids: - step = server.step_manager.get_step(step_id=step_id) + step = server.step_manager.get_step(step_id=step_id, actor=actor) assert step, "Step was not logged correctly" assert step.provider_id == provider.id assert step.provider_name == agent.llm_config.model_endpoint_type @@ -1208,7 +1208,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): assert prompt_tokens == usage.prompt_tokens assert total_tokens == usage.total_tokens - server.provider_manager.delete_provider_by_id(provider.id) + server.provider_manager.delete_provider_by_id(provider.id, actor=actor) existing_messages = server.message_manager.list_messages_for_agent(agent_id=agent.id, actor=actor) @@ -1221,7 +1221,7 @@ def test_messages_with_provider_override(server: SyncServer, user_id: str): step_ids = set([msg.step_id for msg in get_messages_response]) completion_tokens, prompt_tokens, total_tokens = 0, 0, 0 for step_id in step_ids: - step = server.step_manager.get_step(step_id=step_id) + step = server.step_manager.get_step(step_id=step_id, actor=actor) assert step, "Step was not logged correctly" assert step.provider_id == None assert step.provider_name == agent.llm_config.model_endpoint_type From 6a20838d85db4818d0fb86ff712bdaca8a1c2f38 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 25 Feb 2025 12:17:19 -0800 Subject: [PATCH 5/8] fix: persist properties on identities update (#1110) --- letta/services/identity_manager.py | 6 ++++++ 1 file changed, 6 insertions(+) diff --git a/letta/services/identity_manager.py b/letta/services/identity_manager.py index 53058960..8dcb83ad 100644 --- a/letta/services/identity_manager.py +++ b/letta/services/identity_manager.py @@ -111,6 +111,12 @@ class IdentityManager: existing_identity.name = identity.name if identity.identity_type is not None: existing_identity.identity_type = identity.identity_type + if identity.properties is not None: + if replace: + existing_identity.properties = [prop.model_dump() for prop in identity.properties] + else: + new_properties = existing_identity.properties + identity.properties + existing_identity.properties = [prop.model_dump() for prop in new_properties] self._process_agent_relationship( session=session, identity=existing_identity, agent_ids=identity.agent_ids, allow_partial=False, replace=replace From 39c17a11de30c6113bb24c086c9746cbf9c35e6f Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 25 Feb 2025 14:23:22 -0800 Subject: [PATCH 6/8] feat: enable listing agents by identity id (#1112) --- letta/server/rest_api/routers/v1/agents.py | 2 ++ 1 file changed, 2 insertions(+) diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 70664d6b..dd75d1cc 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -52,6 +52,7 @@ def list_agents( project_id: Optional[str] = Query(None, description="Search agents by project id"), template_id: Optional[str] = Query(None, description="Search agents by template id"), base_template_id: Optional[str] = Query(None, description="Search agents by base template id"), + identifier_id: Optional[str] = Query(None, description="Search agents by identifier id"), identifier_keys: Optional[List[str]] = Query(None, description="Search agents by identifier keys"), ): """ @@ -68,6 +69,7 @@ def list_agents( "project_id": project_id, "template_id": template_id, "base_template_id": base_template_id, + "identifier_id": identifier_id, }.items() if value is not None } From 997019afe82e1537c646029b2387699141de5415 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 25 Feb 2025 15:13:35 -0800 Subject: [PATCH 7/8] feat: Add tool calling to fast chat completions (#1109) --- letta/constants.py | 7 + letta/helpers/composio_helpers.py | 5 +- letta/helpers/tool_execution_helper.py | 171 ++++++++++++ .../schemas/openai/chat_completion_request.py | 2 +- .../chat_completions/chat_completions.py | 254 +++++++++++++++--- letta/server/rest_api/utils.py | 5 +- letta/services/agent_manager.py | 36 +++ letta/services/block_manager.py | 6 +- letta/services/message_manager.py | 5 +- tests/integration_test_chat_completions.py | 59 +++- 10 files changed, 507 insertions(+), 43 deletions(-) create mode 100644 letta/helpers/tool_execution_helper.py diff --git a/letta/constants.py b/letta/constants.py index 35ab7cb4..468afa4c 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -52,6 +52,8 @@ BASE_TOOLS = ["send_message", "conversation_search", "archival_memory_insert", " BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] # Multi agent tools MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"] +# Set of all built-in Letta tools +LETTA_TOOL_SET = set(BASE_TOOLS + BASE_MEMORY_TOOLS + MULTI_AGENT_TOOLS) # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) @@ -59,6 +61,11 @@ MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_t DEFAULT_MESSAGE_TOOL = "send_message" DEFAULT_MESSAGE_TOOL_KWARG = "message" +PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg" + +REQUEST_HEARTBEAT_PARAM = "request_heartbeat" + + # Structured output models STRUCTURED_OUTPUT_MODELS = {"gpt-4o", "gpt-4o-mini"} diff --git a/letta/helpers/composio_helpers.py b/letta/helpers/composio_helpers.py index 8a8c3249..a3c518ec 100644 --- a/letta/helpers/composio_helpers.py +++ b/letta/helpers/composio_helpers.py @@ -6,10 +6,11 @@ from letta.services.sandbox_config_manager import SandboxConfigManager from letta.settings import tool_settings -def get_composio_api_key(actor: User, logger: Logger) -> Optional[str]: +def get_composio_api_key(actor: User, logger: Optional[Logger] = None) -> Optional[str]: api_keys = SandboxConfigManager().list_sandbox_env_vars_by_key(key="COMPOSIO_API_KEY", actor=actor) if not api_keys: - logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...") + if logger: + logger.warning(f"No API keys found for Composio. Defaulting to the environment variable...") if tool_settings.composio_api_key: return tool_settings.composio_api_key else: diff --git a/letta/helpers/tool_execution_helper.py b/letta/helpers/tool_execution_helper.py new file mode 100644 index 00000000..948772ee --- /dev/null +++ b/letta/helpers/tool_execution_helper.py @@ -0,0 +1,171 @@ +from collections import OrderedDict +from typing import Any, Dict, Optional + +from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, PRE_EXECUTION_MESSAGE_ARG +from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source +from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name +from letta.helpers.composio_helpers import get_composio_api_key +from letta.orm.enums import ToolType +from letta.schemas.agent import AgentState +from letta.schemas.sandbox_config import SandboxRunResult +from letta.schemas.tool import Tool +from letta.schemas.user import User +from letta.services.tool_execution_sandbox import ToolExecutionSandbox +from letta.utils import get_friendly_error_msg + + +def enable_strict_mode(tool_schema: Dict[str, Any]) -> Dict[str, Any]: + """Enables strict mode for a tool schema by setting 'strict' to True and + disallowing additional properties in the parameters. + + Args: + tool_schema (Dict[str, Any]): The original tool schema. + + Returns: + Dict[str, Any]: A new tool schema with strict mode enabled. + """ + schema = tool_schema.copy() + + # Enable strict mode + schema["strict"] = True + + # Ensure parameters is a valid dictionary + parameters = schema.get("parameters", {}) + + if isinstance(parameters, dict) and parameters.get("type") == "object": + # Set additionalProperties to False + parameters["additionalProperties"] = False + schema["parameters"] = parameters + + return schema + + +def add_pre_execution_message(tool_schema: Dict[str, Any]) -> Dict[str, Any]: + """Adds a `pre_execution_message` parameter to a tool schema to prompt a natural, human-like message before executing the tool. + + Args: + tool_schema (Dict[str, Any]): The original tool schema. + + Returns: + Dict[str, Any]: A new tool schema with the `pre_execution_message` field added at the beginning. + """ + schema = tool_schema.copy() + parameters = schema.get("parameters", {}) + + if not isinstance(parameters, dict) or parameters.get("type") != "object": + return schema # Do not modify if schema is not valid + + properties = parameters.get("properties", {}) + required = parameters.get("required", []) + + # Define the new `pre_execution_message` field with a refined description + pre_execution_message_field = { + "type": "string", + "description": ( + "A concise message to be uttered before executing this tool. " + "This should sound natural, as if a person is casually announcing their next action." + "You MUST also include punctuation at the end of this message." + ), + } + + # Ensure the pre-execution message is the first field in properties + updated_properties = OrderedDict() + updated_properties[PRE_EXECUTION_MESSAGE_ARG] = pre_execution_message_field + updated_properties.update(properties) # Retain all existing properties + + # Ensure pre-execution message is the first required field + if PRE_EXECUTION_MESSAGE_ARG not in required: + required = [PRE_EXECUTION_MESSAGE_ARG] + required + + # Update the schema with ordered properties and required list + schema["parameters"] = { + **parameters, + "properties": dict(updated_properties), # Convert OrderedDict back to dict + "required": required, + } + + return schema + + +def remove_request_heartbeat(tool_schema: Dict[str, Any]) -> Dict[str, Any]: + """Removes the `request_heartbeat` parameter from a tool schema if it exists. + + Args: + tool_schema (Dict[str, Any]): The original tool schema. + + Returns: + Dict[str, Any]: A new tool schema without `request_heartbeat`. + """ + schema = tool_schema.copy() + parameters = schema.get("parameters", {}) + + if isinstance(parameters, dict): + properties = parameters.get("properties", {}) + required = parameters.get("required", []) + + # Remove the `request_heartbeat` property if it exists + if "request_heartbeat" in properties: + properties.pop("request_heartbeat") + + # Remove `request_heartbeat` from required fields if present + if "request_heartbeat" in required: + required = [r for r in required if r != "request_heartbeat"] + + # Update parameters with modified properties and required list + schema["parameters"] = {**parameters, "properties": properties, "required": required} + + return schema + + +# TODO: Deprecate the `execute_external_tool` function on the agent body +def execute_external_tool( + agent_state: AgentState, + function_name: str, + function_args: dict, + target_letta_tool: Tool, + actor: User, + allow_agent_state_modifications: bool = False, +) -> tuple[Any, Optional[SandboxRunResult]]: + # TODO: need to have an AgentState object that actually has full access to the block data + # this is because the sandbox tools need to be able to access block.value to edit this data + try: + if target_letta_tool.tool_type == ToolType.EXTERNAL_COMPOSIO: + action_name = generate_composio_action_from_func_name(target_letta_tool.name) + # Get entity ID from the agent_state + entity_id = None + for env_var in agent_state.tool_exec_environment_variables: + if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY: + entity_id = env_var.value + # Get composio_api_key + composio_api_key = get_composio_api_key(actor=actor) + function_response = execute_composio_action( + action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id + ) + return function_response, None + elif target_letta_tool.tool_type == ToolType.CUSTOM: + # Parse the source code to extract function annotations + annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) + # Coerce the function arguments to the correct types based on the annotations + function_args = coerce_dict_args_by_annotations(function_args, annotations) + + # execute tool in a sandbox + # TODO: allow agent_state to specify which sandbox to execute tools in + # TODO: This is only temporary, can remove after we publish a pip package with this object + if allow_agent_state_modifications: + agent_state_copy = agent_state.__deepcopy__() + agent_state_copy.tools = [] + agent_state_copy.tool_rules = [] + else: + agent_state_copy = None + + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor).run(agent_state=agent_state_copy) + function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state + # TODO: Bring this back + # if allow_agent_state_modifications and updated_agent_state is not None: + # self.update_memory_if_changed(updated_agent_state.memory) + return function_response, sandbox_run_result + except Exception as e: + # Need to catch error here, or else trunction wont happen + # TODO: modify to function execution error + function_response = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)) + return function_response, None diff --git a/letta/schemas/openai/chat_completion_request.py b/letta/schemas/openai/chat_completion_request.py index 5b7b2743..12486bca 100644 --- a/letta/schemas/openai/chat_completion_request.py +++ b/letta/schemas/openai/chat_completion_request.py @@ -99,7 +99,7 @@ class ChatCompletionRequest(BaseModel): """https://platform.openai.com/docs/api-reference/chat/create""" model: str - messages: List[ChatMessage] + messages: List[Union[ChatMessage, Dict]] frequency_penalty: Optional[float] = 0 logit_bias: Optional[Dict[str, int]] = None logprobs: Optional[bool] = False diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py index 13fd2347..428dbbd9 100644 --- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py +++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py @@ -1,19 +1,39 @@ import asyncio +import json +import uuid from typing import TYPE_CHECKING, List, Optional, Union import httpx import openai from fastapi import APIRouter, Body, Depends, Header, HTTPException from fastapi.responses import StreamingResponse +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta from openai.types.chat.completion_create_params import CompletionCreateParams from starlette.concurrency import run_in_threadpool from letta.agent import Agent -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, LETTA_TOOL_SET, NON_USER_MSG_PREFIX, PRE_EXECUTION_MESSAGE_ARG +from letta.helpers.tool_execution_helper import ( + add_pre_execution_message, + enable_strict_mode, + execute_external_tool, + remove_request_heartbeat, +) from letta.log import get_logger +from letta.orm.enums import ToolType from letta.schemas.message import Message, MessageCreate +from letta.schemas.openai.chat_completion_request import ( + AssistantMessage, + ChatCompletionRequest, + Tool, + ToolCall, + ToolCallFunction, + ToolMessage, + UserMessage, +) from letta.schemas.user import User from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface +from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser # TODO this belongs in a controller! from letta.server.rest_api.utils import ( @@ -52,20 +72,53 @@ async def create_fast_chat_completions( server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): - # TODO: This is necessary, we need to factor out CompletionCreateParams due to weird behavior + actor = server.user_manager.get_user_or_default(user_id=user_id) + agent_id = str(completion_request.get("user", None)) if agent_id is None: - error_msg = "Must pass agent_id in the 'user' field" - logger.error(error_msg) - raise HTTPException(status_code=400, detail=error_msg) - model = completion_request.get("model") + raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field") - actor = server.user_manager.get_user_or_default(user_id=user_id) + agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + if agent_state.llm_config.model_endpoint_type != "openai": + raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.") + + # Convert Letta messages to OpenAI messages + in_context_messages = server.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor) + openai_messages = convert_letta_messages_to_openai(in_context_messages) + + # Also parse user input from completion_request and append + input_message = get_messages_from_completion_request(completion_request)[-1] + openai_messages.append(input_message) + + # Tools we allow this agent to call + tools = [t for t in agent_state.tools if t.name not in LETTA_TOOL_SET and t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}] + + # Initial request + openai_request = ChatCompletionRequest( + model=agent_state.llm_config.model, + messages=openai_messages, + # TODO: This nested thing here is so ugly, need to refactor + tools=( + [ + Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema)))) + for t in tools + ] + if tools + else None + ), + tool_choice="auto", + user=user_id, + max_completion_tokens=agent_state.llm_config.max_tokens, + temperature=agent_state.llm_config.temperature, + stream=True, + ) + + # Create the OpenAI async client client = openai.AsyncClient( api_key=model_settings.openai_api_key, max_retries=0, http_client=httpx.AsyncClient( - timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0), + timeout=httpx.Timeout(connect=15.0, read=30.0, write=15.0, pool=15.0), follow_redirects=True, limits=httpx.Limits( max_connections=50, @@ -75,38 +128,175 @@ async def create_fast_chat_completions( ), ) - # Magic message manipulating - input_message = get_messages_from_completion_request(completion_request)[-1] - completion_request.pop("messages") - - # Get in context messages - in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor) - openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages) - openai_dict_in_context_messages.append(input_message) + # The messages we want to persist to the Letta agent + user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor) + message_db_queue = [user_message] async def event_stream(): - # TODO: Factor this out into separate interface - response_accumulator = [] + """ + A function-calling loop: + - We stream partial tokens. + - If we detect a tool call (finish_reason="tool_calls"), we parse it, + add two messages to the conversation: + (a) assistant message with tool_calls referencing the same ID + (b) a tool message referencing that ID, containing the tool result. + - Re-invoke the OpenAI request with updated conversation, streaming again. + - End when finish_reason="stop" or no more tool calls. + """ - stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages) + # We'll keep updating this conversation in a loop + conversation = openai_messages[:] - async with stream: - async for chunk in stream: - if chunk.choices and chunk.choices[0].delta.content: - # TODO: This does not support tool calling right now - response_accumulator.append(chunk.choices[0].delta.content) - yield f"data: {chunk.model_dump_json()}\n\n" + while True: + # Make the streaming request to OpenAI + stream = await client.chat.completions.create(**openai_request.model_dump(exclude_unset=True)) - # Construct messages - user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor) - assistant_message = create_assistant_message_from_openai_response( - response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor - ) + content_buffer = [] + tool_call_name = None + tool_call_args_str = "" + tool_call_id = None + tool_call_happened = False + finish_reason_stop = False + optimistic_json_parser = OptimisticJSONParser(strict=True) + current_parsed_json_result = {} + + async with stream: + async for chunk in stream: + choice = chunk.choices[0] + delta = choice.delta + finish_reason = choice.finish_reason # "tool_calls", "stop", or None + + if delta.content: + content_buffer.append(delta.content) + yield f"data: {chunk.model_dump_json()}\n\n" + + # CASE B: Partial tool call info + if delta.tool_calls: + # Typically there's only one in delta.tool_calls + tc = delta.tool_calls[0] + if tc.function.name: + tool_call_name = tc.function.name + if tc.function.arguments: + tool_call_args_str += tc.function.arguments + + # See if we can stream out the pre-execution message + parsed_args = optimistic_json_parser.parse(tool_call_args_str) + if parsed_args.get( + PRE_EXECUTION_MESSAGE_ARG + ) and current_parsed_json_result.get( # Ensure key exists and is not None/empty + PRE_EXECUTION_MESSAGE_ARG + ) != parsed_args.get( + PRE_EXECUTION_MESSAGE_ARG + ): + # Only stream if there's something new to stream + # We do this way to avoid hanging JSON at the end of the stream, e.g. '}' + if parsed_args != current_parsed_json_result: + current_parsed_json_result = parsed_args + synthetic_chunk = ChatCompletionChunk( + id=chunk.id, + object=chunk.object, + created=chunk.created, + model=chunk.model, + choices=[ + Choice( + index=choice.index, + delta=ChoiceDelta(content=tc.function.arguments, role="assistant"), + finish_reason=None, + ) + ], + ) + + yield f"data: {synthetic_chunk.model_dump_json()}\n\n" + + # We might generate a unique ID for the tool call + if tc.id: + tool_call_id = tc.id + + # Check finish_reason + if finish_reason == "tool_calls": + tool_call_happened = True + break + elif finish_reason == "stop": + finish_reason_stop = True + break + + if content_buffer: + # We treat that partial text as an assistant message + content = "".join(content_buffer) + conversation.append({"role": "assistant", "content": content}) + + # Create an assistant message here to persist later + assistant_message = create_assistant_message_from_openai_response( + response_text=content, agent_id=agent_id, model=agent_state.llm_config.model, actor=actor + ) + message_db_queue.append(assistant_message) + + if tool_call_happened: + # Parse the tool call arguments + try: + tool_args = json.loads(tool_call_args_str) + except json.JSONDecodeError: + tool_args = {} + + if not tool_call_id: + # If no tool_call_id given by the model, generate one + tool_call_id = f"call_{uuid.uuid4().hex[:8]}" + + # 1) Insert the "assistant" message with the tool_calls field + # referencing the same tool_call_id + assistant_tool_call_msg = AssistantMessage( + content=None, + tool_calls=[ToolCall(id=tool_call_id, function=ToolCallFunction(name=tool_call_name, arguments=tool_call_args_str))], + ) + + conversation.append(assistant_tool_call_msg.model_dump()) + + # 2) Execute the tool + target_tool = next((x for x in tools if x.name == tool_call_name), None) + if not target_tool: + # Tool not found, handle error + yield f"data: {json.dumps({'error': 'Tool not found', 'tool': tool_call_name})}\n\n" + break + + try: + tool_result, _ = execute_external_tool( + agent_state=agent_state, + function_name=tool_call_name, + function_args=tool_args, + target_letta_tool=target_tool, + actor=actor, + allow_agent_state_modifications=False, + ) + except Exception as e: + tool_result = f"Failed to call tool. Error: {e}" + + # 3) Insert the "tool" message referencing the same tool_call_id + tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id) + + conversation.append(tool_message.model_dump()) + + # 4) Add a user message prompting the tool call result summarization + heartbeat_user_message = UserMessage( + content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user.", + ) + conversation.append(heartbeat_user_message.model_dump()) + + # Now, re-invoke OpenAI with the updated conversation + openai_request.messages = conversation + + continue # Start the while loop again + + if finish_reason_stop: + # Model is done, no more calls + break + + # If we reach here, no tool call, no "stop", but we've ended streaming + # Possibly a model error or some other finish reason. We'll just end. + break - # Persist both in one synchronous DB call, done in a threadpool await run_in_threadpool( server.agent_manager.append_to_in_context_messages, - [user_message, assistant_message], + message_db_queue, agent_id=agent_id, actor=actor, ) diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index d5bf4520..8008d056 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -7,7 +7,6 @@ from datetime import datetime, timezone from enum import Enum from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast -import pytz from fastapi import Header, HTTPException from openai.types.chat import ChatCompletionMessageParam from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall @@ -145,7 +144,7 @@ def create_user_message(input_message: dict, agent_id: str, actor: User) -> Mess Converts a user input message into the internal structured format. """ # Generate timestamp in the correct format - now = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z") + now = datetime.now(timezone.utc).isoformat() # Format message as structured JSON structured_message = {"type": "user_message", "message": input_message["content"], "time": now} @@ -197,7 +196,7 @@ def create_assistant_message_from_openai_response( agent_id=agent_id, model=model, tool_calls=[tool_call], - tool_call_id=None, + tool_call_id=tool_call_id, created_at=datetime.now(timezone.utc), ) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index fb450d3d..8d9743ea 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -21,8 +21,10 @@ from letta.orm.sqlite_functions import adapt_array from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent from letta.schemas.block import Block as PydanticBlock +from letta.schemas.block import BlockUpdate from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.llm_config import LLMConfig +from letta.schemas.memory import Memory from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate from letta.schemas.passage import Passage as PydanticPassage @@ -613,6 +615,40 @@ class AgentManager: ) return self.append_to_in_context_messages([system_message], agent_id=agent_state.id, actor=actor) + # TODO: I moved this from agent.py - replace all mentions of this with the agent_manager version + @enforce_types + def update_memory_if_changed(self, agent_id: str, new_memory: Memory, actor: PydanticUser) -> PydanticAgentState: + """ + Update internal memory object and system prompt if there have been modifications. + + Args: + new_memory (Memory): the new memory object to compare to the current memory object + + Returns: + modified (bool): whether the memory was updated + """ + agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor) + if agent_state.memory.compile() != new_memory.compile(): + # update the blocks (LRW) in the DB + for label in agent_state.memory.list_block_labels(): + updated_value = new_memory.get_block(label).value + if updated_value != agent_state.memory.get_block(label).value: + # update the block if it's changed + block_id = agent_state.memory.get_block(label).id + block = self.block_manager.update_block(block_id=block_id, block_update=BlockUpdate(value=updated_value), actor=actor) + + # refresh memory from DB (using block ids) + agent_state.memory = Memory( + blocks=[self.block_manager.get_block_by_id(block.id, actor=actor) for block in agent_state.memory.get_blocks()] + ) + + # NOTE: don't do this since re-buildin the memory is handled at the start of the step + # rebuild memory - this records the last edited timestamp of the memory + # TODO: pass in update timestamp from block edit time + agent_state = self.rebuild_system_prompt(agent_id=agent_id, actor=actor) + + return agent_state + # ====================================================================================================================== # Source Management # ====================================================================================================================== diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 7ae743a7..fe10671d 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -107,12 +107,14 @@ class BlockManager: @enforce_types def add_default_blocks(self, actor: PydanticUser): for persona_file in list_persona_files(): - text = open(persona_file, "r", encoding="utf-8").read() + with open(persona_file, "r", encoding="utf-8") as f: + text = f.read() name = os.path.basename(persona_file).replace(".txt", "") self.create_or_update_block(Persona(template_name=name, value=text, is_template=True), actor=actor) for human_file in list_human_files(): - text = open(human_file, "r", encoding="utf-8").read() + with open(human_file, "r", encoding="utf-8") as f: + text = f.read() name = os.path.basename(human_file).replace(".txt", "") self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index ed2881b3..26f0bee5 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -2,6 +2,7 @@ from typing import List, Optional from sqlalchemy import and_, or_ +from letta.log import get_logger from letta.orm.agent import Agent as AgentModel from letta.orm.errors import NoResultFound from letta.orm.message import Message as MessageModel @@ -11,6 +12,8 @@ from letta.schemas.message import MessageUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types +logger = get_logger(__name__) + class MessageManager: """Manager class to handle business logic related to Messages.""" @@ -37,7 +40,7 @@ class MessageManager: results = MessageModel.list(db_session=session, id=message_ids, organization_id=actor.organization_id, limit=len(message_ids)) if len(results) != len(message_ids): - raise NoResultFound( + logger.warning( f"Expected {len(message_ids)} messages, but found {len(results)}. Missing ids={set(message_ids) - set([r.id for r in results])}" ) diff --git a/tests/integration_test_chat_completions.py b/tests/integration_test_chat_completions.py index 4ab3b1d8..4b501de3 100644 --- a/tests/integration_test_chat_completions.py +++ b/tests/integration_test_chat_completions.py @@ -14,7 +14,9 @@ from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageStreamStatus from letta.schemas.llm_config import LLMConfig from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage +from letta.schemas.tool import ToolCreate from letta.schemas.usage import LettaUsageStatistics +from letta.services.tool_manager import ToolManager # --- Server Management --- # @@ -69,9 +71,49 @@ def roll_dice_tool(client): @pytest.fixture(scope="function") -def agent(client, roll_dice_tool): +def weather_tool(client): + def get_weather(location: str) -> str: + """ + Fetches the current weather for a given location. + + Parameters: + location (str): The location to get the weather for. + + Returns: + str: A formatted string describing the weather in the given location. + + Raises: + RuntimeError: If the request to fetch weather data fails. + """ + import requests + + url = f"https://wttr.in/{location}?format=%C+%t" + + response = requests.get(url) + if response.status_code == 200: + weather_data = response.text + return f"The weather in {location} is {weather_data}." + else: + raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") + + tool = client.create_or_update_tool(func=get_weather) + # Yield the created tool + yield tool + + +@pytest.fixture(scope="function") +def composio_gmail_get_profile_tool(default_user): + tool_create = ToolCreate.from_composio(action_name="GMAIL_GET_PROFILE") + tool = ToolManager().create_or_update_composio_tool(tool_create=tool_create, actor=default_user) + yield tool + + +@pytest.fixture(scope="function") +def agent(client, roll_dice_tool, weather_tool, composio_gmail_get_profile_tool): """Creates an agent and ensures cleanup after tests.""" - agent_state = client.create_agent(name=f"test_client_{uuid.uuid4()}", tool_ids=[roll_dice_tool.id]) + agent_state = client.create_agent( + name=f"test_compl_{str(uuid.uuid4())[5:]}", tool_ids=[roll_dice_tool.id, weather_tool.id, composio_gmail_get_profile_tool.id] + ) yield agent_state client.delete_agent(agent_state.id) @@ -111,6 +153,19 @@ def _assert_valid_chunk(chunk, idx, chunks): # --- Test Cases --- # +@pytest.mark.parametrize("message", ["What's the weather in SF?"]) +@pytest.mark.parametrize("endpoint", ["fast/chat/completions"]) +def test_tool_usage_fast_chat_completions(mock_e2b_api_key_none, client, agent, message, endpoint): + """Tests chat completion streaming via SSE.""" + request = _get_chat_request(agent.id, message) + + response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers) + + for chunk in response: + if isinstance(chunk, ChatCompletionChunk) and chunk.choices: + print(chunk.choices[0].delta.content) + + @pytest.mark.parametrize("message", ["Tell me something interesting about bananas."]) @pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"]) def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint): From 4145e1652377e7dc1791adf60ec5273976aafd51 Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 25 Feb 2025 15:39:19 -0800 Subject: [PATCH 8/8] fix: unsupported annotation error for pydantic tool args (#1117) Co-authored-by: Sarah Wooders --- letta/agent.py | 15 ++++++++++----- 1 file changed, 10 insertions(+), 5 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 61d69343..e7d435b4 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -245,10 +245,13 @@ class Agent(BaseAgent): action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id ) else: - # Parse the source code to extract function annotations - annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) - # Coerce the function arguments to the correct types based on the annotations - function_args = coerce_dict_args_by_annotations(function_args, annotations) + try: + # Parse the source code to extract function annotations + annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name) + # Coerce the function arguments to the correct types based on the annotations + function_args = coerce_dict_args_by_annotations(function_args, annotations) + except ValueError as e: + self.logger.debug(f"Error coercing function arguments: {e}") # execute tool in a sandbox # TODO: allow agent_state to specify which sandbox to execute tools in @@ -257,7 +260,9 @@ class Agent(BaseAgent): agent_state_copy.tools = [] agent_state_copy.tool_rules = [] - sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user).run(agent_state=agent_state_copy) + sandbox_run_result = ToolExecutionSandbox(function_name, function_args, self.user, tool_object=target_letta_tool).run( + agent_state=agent_state_copy + ) function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state assert orig_memory_str == self.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" if updated_agent_state is not None: