diff --git a/letta/constants.py b/letta/constants.py index ba105e8f..06274b9a 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -465,3 +465,7 @@ MODAL_DEFAULT_PYTHON_VERSION = "3.12" MODAL_SAFE_IMPORT_MODULES = {"typing", "pydantic", "datetime", "uuid"} # decimal, enum # Default handle for model used to generate tools DEFAULT_GENERATE_TOOL_MODEL_HANDLE = "openai/gpt-4.1" + +# Reserved keyword arguments that are injected by the system into tool functions, not provided by the LLM +# These parameters are excluded from tool schema generation +TOOL_RESERVED_KWARGS = ["self", "agent_state", "client"] diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 9fe0e37b..3f549069 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -5,7 +5,7 @@ from docstring_parser import parse from pydantic import BaseModel from typing_extensions import Literal -from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM +from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, TOOL_RESERVED_KWARGS from letta.functions.mcp_client.types import MCPTool from letta.log import get_logger @@ -34,7 +34,7 @@ def validate_google_style_docstring(function): # 3. Args and Returns sections should be properly formatted sig = inspect.signature(function) - has_params = any(param.name not in ["self", "agent_state"] for param in sig.parameters.values()) + has_params = any(param.name not in TOOL_RESERVED_KWARGS for param in sig.parameters.values()) # Check for Args section if function has parameters if has_params and "Args:" not in docstring: @@ -51,7 +51,7 @@ def validate_google_style_docstring(function): # Check that each parameter is documented for param in sig.parameters.values(): - if param.name in ["self", "agent_state"]: + if param.name in TOOL_RESERVED_KWARGS: continue if f"{param.name} (" not in args_section and f"{param.name}:" not in args_section: raise ValueError( @@ -448,12 +448,9 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ "parameters": {"type": "object", "properties": {}, "required": []}, } - # TODO: ensure that 'agent' keyword is reserved for `Agent` class - for param in sig.parameters.values(): - # Exclude 'self' parameter - # TODO: eventually remove this (only applies to BASE_TOOLS) - if param.name in ["self", "agent_state"]: # Add agent_manager to excluded + # Exclude reserved parameters that are injected by the system + if param.name in TOOL_RESERVED_KWARGS: continue # Assert that the parameter has a type annotation diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py index cbc23c9a..c12e5fa5 100644 --- a/letta/services/tool_sandbox/base.py +++ b/letta/services/tool_sandbox/base.py @@ -67,7 +67,7 @@ class AsyncToolSandboxBase(ABC): self.inject_agent_state = False # Check for Letta client and agent_id injection - self.inject_letta_client = "letta_client" in tool_arguments or "client" in tool_arguments + self.inject_letta_client = "client" in tool_arguments self.inject_agent_id = "agent_id" in tool_arguments self.is_async_function = self._detect_async_function() @@ -195,20 +195,18 @@ class AsyncToolSandboxBase(ABC): [ "# Initialize Letta client for tool execution", "import os", - "letta_client = None", + "client = None", "if os.getenv('LETTA_API_KEY'):", " # Check letta_client version to use correct parameter name", " from packaging import version as pkg_version", " import letta_client as lc_module", " lc_version = pkg_version.parse(lc_module.__version__)", " if lc_version < pkg_version.parse('1.0.0'):", - " letta_client = Letta(", - " base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),", + " client = Letta(", " token=os.getenv('LETTA_API_KEY')", " )", " else:", - " letta_client = Letta(", - " base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),", + " client = Letta(", " api_key=os.getenv('LETTA_API_KEY')", " )", ] @@ -346,9 +344,7 @@ class AsyncToolSandboxBase(ABC): # Check if the function expects 'client' or 'letta_client' tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name) if "client" in tool_arguments: - param_list.append("client=letta_client") - elif "letta_client" in tool_arguments: - param_list.append("letta_client=letta_client") + param_list.append("client=client") if self.inject_agent_id: param_list.append("agent_id=agent_id") diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index 1a101cdf..4aeeb03a 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -1287,7 +1287,7 @@ async def test_local_sandbox_with_client_injection(disable_e2b_api_key, list_too # Verify the script contains Letta client initialization assert "from letta_client import Letta" in script, "Script should import Letta client" assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY" - assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client" + assert "client = Letta(" in script or "client = None" in script, "Script should initialize Letta client" # Run the tool and verify it works result = await sandbox.run(agent_state=None) @@ -1345,6 +1345,6 @@ async def test_e2b_sandbox_with_client_injection(check_e2b_key_is_set, list_tool # Verify the script contains Letta client initialization assert "from letta_client import Letta" in script, "Script should import Letta client" assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY" - assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client" + assert "client = Letta(" in script or "client = None" in script, "Script should initialize Letta client" # Cannot run the tool since E2B is remote diff --git a/tests/test_tool_schema_parsing.py b/tests/test_tool_schema_parsing.py index e870c782..bdaa4817 100644 --- a/tests/test_tool_schema_parsing.py +++ b/tests/test_tool_schema_parsing.py @@ -401,6 +401,30 @@ def agent_state_ok(agent_state, value: int) -> str: return "ok" +def client_ok(client, value: int) -> str: + """Ignores client param (injected Letta client). + + Args: + value (int): Some value. + + Returns: + str: Status. + """ + return "ok" + + +def all_reserved_params_ok(agent_state, client, value: int) -> str: + """Ignores all reserved params. + + Args: + value (int): Some value. + + Returns: + str: Status. + """ + return "ok" + + class Dummy: def method(self, bar: int) -> str: # keeps an explicit self """Bound-method example. @@ -446,6 +470,8 @@ def missing_param_doc(x: int, y: int) -> str: [ (good_function, None), (agent_state_ok, None), + (client_ok, None), # client is a reserved param (injected Letta client) + (all_reserved_params_ok, None), # all reserved params together (Dummy.method, None), # unbound method keeps `self` (good_function_no_return, None), (no_doc, "has no docstring"), @@ -457,6 +483,28 @@ def test_google_style_docstring_validation(fn, regex): _check(fn, regex) +def test_reserved_params_excluded_from_schema(): + """Test that reserved params (agent_state, client) are excluded from generated schema.""" + from letta.functions.schema_generator import generate_schema + + # Test with client param + schema = generate_schema(client_ok) + assert "client" not in schema["parameters"]["properties"], "client should be excluded from schema" + assert "value" in schema["parameters"]["properties"], "value should be in schema" + + # Test with agent_state param + schema = generate_schema(agent_state_ok) + assert "agent_state" not in schema["parameters"]["properties"], "agent_state should be excluded from schema" + assert "value" in schema["parameters"]["properties"], "value should be in schema" + + # Test with all reserved params + schema = generate_schema(all_reserved_params_ok) + assert "agent_state" not in schema["parameters"]["properties"], "agent_state should be excluded from schema" + assert "client" not in schema["parameters"]["properties"], "client should be excluded from schema" + assert "value" in schema["parameters"]["properties"], "value should be in schema" + assert schema["parameters"]["required"] == ["value"], "only value should be required" + + def test_complex_nested_anyof_schema_to_structured_output(): """Test that complex nested anyOf schemas with inlined $refs can be converted to structured outputs.