fix: fix client injection code (#6421)

This commit is contained in:
Sarah Wooders
2025-11-26 17:16:02 -08:00
committed by Caren Thomas
parent 4d5be22d14
commit 0a0cf391fc
5 changed files with 64 additions and 19 deletions

View File

@@ -465,3 +465,7 @@ MODAL_DEFAULT_PYTHON_VERSION = "3.12"
MODAL_SAFE_IMPORT_MODULES = {"typing", "pydantic", "datetime", "uuid"} # decimal, enum
# Default handle for model used to generate tools
DEFAULT_GENERATE_TOOL_MODEL_HANDLE = "openai/gpt-4.1"
# Reserved keyword arguments that are injected by the system into tool functions, not provided by the LLM
# These parameters are excluded from tool schema generation
TOOL_RESERVED_KWARGS = ["self", "agent_state", "client"]

View File

@@ -5,7 +5,7 @@ from docstring_parser import parse
from pydantic import BaseModel
from typing_extensions import Literal
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, TOOL_RESERVED_KWARGS
from letta.functions.mcp_client.types import MCPTool
from letta.log import get_logger
@@ -34,7 +34,7 @@ def validate_google_style_docstring(function):
# 3. Args and Returns sections should be properly formatted
sig = inspect.signature(function)
has_params = any(param.name not in ["self", "agent_state"] for param in sig.parameters.values())
has_params = any(param.name not in TOOL_RESERVED_KWARGS for param in sig.parameters.values())
# Check for Args section if function has parameters
if has_params and "Args:" not in docstring:
@@ -51,7 +51,7 @@ def validate_google_style_docstring(function):
# Check that each parameter is documented
for param in sig.parameters.values():
if param.name in ["self", "agent_state"]:
if param.name in TOOL_RESERVED_KWARGS:
continue
if f"{param.name} (" not in args_section and f"{param.name}:" not in args_section:
raise ValueError(
@@ -448,12 +448,9 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
"parameters": {"type": "object", "properties": {}, "required": []},
}
# TODO: ensure that 'agent' keyword is reserved for `Agent` class
for param in sig.parameters.values():
# Exclude 'self' parameter
# TODO: eventually remove this (only applies to BASE_TOOLS)
if param.name in ["self", "agent_state"]: # Add agent_manager to excluded
# Exclude reserved parameters that are injected by the system
if param.name in TOOL_RESERVED_KWARGS:
continue
# Assert that the parameter has a type annotation

View File

@@ -67,7 +67,7 @@ class AsyncToolSandboxBase(ABC):
self.inject_agent_state = False
# Check for Letta client and agent_id injection
self.inject_letta_client = "letta_client" in tool_arguments or "client" in tool_arguments
self.inject_letta_client = "client" in tool_arguments
self.inject_agent_id = "agent_id" in tool_arguments
self.is_async_function = self._detect_async_function()
@@ -195,20 +195,18 @@ class AsyncToolSandboxBase(ABC):
[
"# Initialize Letta client for tool execution",
"import os",
"letta_client = None",
"client = None",
"if os.getenv('LETTA_API_KEY'):",
" # Check letta_client version to use correct parameter name",
" from packaging import version as pkg_version",
" import letta_client as lc_module",
" lc_version = pkg_version.parse(lc_module.__version__)",
" if lc_version < pkg_version.parse('1.0.0'):",
" letta_client = Letta(",
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
" client = Letta(",
" token=os.getenv('LETTA_API_KEY')",
" )",
" else:",
" letta_client = Letta(",
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
" client = Letta(",
" api_key=os.getenv('LETTA_API_KEY')",
" )",
]
@@ -346,9 +344,7 @@ class AsyncToolSandboxBase(ABC):
# Check if the function expects 'client' or 'letta_client'
tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name)
if "client" in tool_arguments:
param_list.append("client=letta_client")
elif "letta_client" in tool_arguments:
param_list.append("letta_client=letta_client")
param_list.append("client=client")
if self.inject_agent_id:
param_list.append("agent_id=agent_id")

View File

@@ -1287,7 +1287,7 @@ async def test_local_sandbox_with_client_injection(disable_e2b_api_key, list_too
# Verify the script contains Letta client initialization
assert "from letta_client import Letta" in script, "Script should import Letta client"
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
assert "client = Letta(" in script or "client = None" in script, "Script should initialize Letta client"
# Run the tool and verify it works
result = await sandbox.run(agent_state=None)
@@ -1345,6 +1345,6 @@ async def test_e2b_sandbox_with_client_injection(check_e2b_key_is_set, list_tool
# Verify the script contains Letta client initialization
assert "from letta_client import Letta" in script, "Script should import Letta client"
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
assert "client = Letta(" in script or "client = None" in script, "Script should initialize Letta client"
# Cannot run the tool since E2B is remote

View File

@@ -401,6 +401,30 @@ def agent_state_ok(agent_state, value: int) -> str:
return "ok"
def client_ok(client, value: int) -> str:
"""Ignores client param (injected Letta client).
Args:
value (int): Some value.
Returns:
str: Status.
"""
return "ok"
def all_reserved_params_ok(agent_state, client, value: int) -> str:
"""Ignores all reserved params.
Args:
value (int): Some value.
Returns:
str: Status.
"""
return "ok"
class Dummy:
def method(self, bar: int) -> str: # keeps an explicit self
"""Bound-method example.
@@ -446,6 +470,8 @@ def missing_param_doc(x: int, y: int) -> str:
[
(good_function, None),
(agent_state_ok, None),
(client_ok, None), # client is a reserved param (injected Letta client)
(all_reserved_params_ok, None), # all reserved params together
(Dummy.method, None), # unbound method keeps `self`
(good_function_no_return, None),
(no_doc, "has no docstring"),
@@ -457,6 +483,28 @@ def test_google_style_docstring_validation(fn, regex):
_check(fn, regex)
def test_reserved_params_excluded_from_schema():
"""Test that reserved params (agent_state, client) are excluded from generated schema."""
from letta.functions.schema_generator import generate_schema
# Test with client param
schema = generate_schema(client_ok)
assert "client" not in schema["parameters"]["properties"], "client should be excluded from schema"
assert "value" in schema["parameters"]["properties"], "value should be in schema"
# Test with agent_state param
schema = generate_schema(agent_state_ok)
assert "agent_state" not in schema["parameters"]["properties"], "agent_state should be excluded from schema"
assert "value" in schema["parameters"]["properties"], "value should be in schema"
# Test with all reserved params
schema = generate_schema(all_reserved_params_ok)
assert "agent_state" not in schema["parameters"]["properties"], "agent_state should be excluded from schema"
assert "client" not in schema["parameters"]["properties"], "client should be excluded from schema"
assert "value" in schema["parameters"]["properties"], "value should be in schema"
assert schema["parameters"]["required"] == ["value"], "only value should be required"
def test_complex_nested_anyof_schema_to_structured_output():
"""Test that complex nested anyOf schemas with inlined $refs can be converted to structured outputs.