fix: fix client injection code (#6421)
This commit is contained in:
committed by
Caren Thomas
parent
4d5be22d14
commit
0a0cf391fc
@@ -465,3 +465,7 @@ MODAL_DEFAULT_PYTHON_VERSION = "3.12"
|
||||
MODAL_SAFE_IMPORT_MODULES = {"typing", "pydantic", "datetime", "uuid"} # decimal, enum
|
||||
# Default handle for model used to generate tools
|
||||
DEFAULT_GENERATE_TOOL_MODEL_HANDLE = "openai/gpt-4.1"
|
||||
|
||||
# Reserved keyword arguments that are injected by the system into tool functions, not provided by the LLM
|
||||
# These parameters are excluded from tool schema generation
|
||||
TOOL_RESERVED_KWARGS = ["self", "agent_state", "client"]
|
||||
|
||||
@@ -5,7 +5,7 @@ from docstring_parser import parse
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM
|
||||
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, TOOL_RESERVED_KWARGS
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.log import get_logger
|
||||
|
||||
@@ -34,7 +34,7 @@ def validate_google_style_docstring(function):
|
||||
# 3. Args and Returns sections should be properly formatted
|
||||
|
||||
sig = inspect.signature(function)
|
||||
has_params = any(param.name not in ["self", "agent_state"] for param in sig.parameters.values())
|
||||
has_params = any(param.name not in TOOL_RESERVED_KWARGS for param in sig.parameters.values())
|
||||
|
||||
# Check for Args section if function has parameters
|
||||
if has_params and "Args:" not in docstring:
|
||||
@@ -51,7 +51,7 @@ def validate_google_style_docstring(function):
|
||||
|
||||
# Check that each parameter is documented
|
||||
for param in sig.parameters.values():
|
||||
if param.name in ["self", "agent_state"]:
|
||||
if param.name in TOOL_RESERVED_KWARGS:
|
||||
continue
|
||||
if f"{param.name} (" not in args_section and f"{param.name}:" not in args_section:
|
||||
raise ValueError(
|
||||
@@ -448,12 +448,9 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
||||
"parameters": {"type": "object", "properties": {}, "required": []},
|
||||
}
|
||||
|
||||
# TODO: ensure that 'agent' keyword is reserved for `Agent` class
|
||||
|
||||
for param in sig.parameters.values():
|
||||
# Exclude 'self' parameter
|
||||
# TODO: eventually remove this (only applies to BASE_TOOLS)
|
||||
if param.name in ["self", "agent_state"]: # Add agent_manager to excluded
|
||||
# Exclude reserved parameters that are injected by the system
|
||||
if param.name in TOOL_RESERVED_KWARGS:
|
||||
continue
|
||||
|
||||
# Assert that the parameter has a type annotation
|
||||
|
||||
@@ -67,7 +67,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
self.inject_agent_state = False
|
||||
|
||||
# Check for Letta client and agent_id injection
|
||||
self.inject_letta_client = "letta_client" in tool_arguments or "client" in tool_arguments
|
||||
self.inject_letta_client = "client" in tool_arguments
|
||||
self.inject_agent_id = "agent_id" in tool_arguments
|
||||
|
||||
self.is_async_function = self._detect_async_function()
|
||||
@@ -195,20 +195,18 @@ class AsyncToolSandboxBase(ABC):
|
||||
[
|
||||
"# Initialize Letta client for tool execution",
|
||||
"import os",
|
||||
"letta_client = None",
|
||||
"client = None",
|
||||
"if os.getenv('LETTA_API_KEY'):",
|
||||
" # Check letta_client version to use correct parameter name",
|
||||
" from packaging import version as pkg_version",
|
||||
" import letta_client as lc_module",
|
||||
" lc_version = pkg_version.parse(lc_module.__version__)",
|
||||
" if lc_version < pkg_version.parse('1.0.0'):",
|
||||
" letta_client = Letta(",
|
||||
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
|
||||
" client = Letta(",
|
||||
" token=os.getenv('LETTA_API_KEY')",
|
||||
" )",
|
||||
" else:",
|
||||
" letta_client = Letta(",
|
||||
" base_url=os.getenv('LETTA_BASE_URL', 'http://localhost:8283'),",
|
||||
" client = Letta(",
|
||||
" api_key=os.getenv('LETTA_API_KEY')",
|
||||
" )",
|
||||
]
|
||||
@@ -346,9 +344,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
# Check if the function expects 'client' or 'letta_client'
|
||||
tool_arguments = parse_function_arguments(self.tool.source_code, self.tool.name)
|
||||
if "client" in tool_arguments:
|
||||
param_list.append("client=letta_client")
|
||||
elif "letta_client" in tool_arguments:
|
||||
param_list.append("letta_client=letta_client")
|
||||
param_list.append("client=client")
|
||||
|
||||
if self.inject_agent_id:
|
||||
param_list.append("agent_id=agent_id")
|
||||
|
||||
@@ -1287,7 +1287,7 @@ async def test_local_sandbox_with_client_injection(disable_e2b_api_key, list_too
|
||||
# Verify the script contains Letta client initialization
|
||||
assert "from letta_client import Letta" in script, "Script should import Letta client"
|
||||
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
|
||||
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
|
||||
assert "client = Letta(" in script or "client = None" in script, "Script should initialize Letta client"
|
||||
|
||||
# Run the tool and verify it works
|
||||
result = await sandbox.run(agent_state=None)
|
||||
@@ -1345,6 +1345,6 @@ async def test_e2b_sandbox_with_client_injection(check_e2b_key_is_set, list_tool
|
||||
# Verify the script contains Letta client initialization
|
||||
assert "from letta_client import Letta" in script, "Script should import Letta client"
|
||||
assert "LETTA_API_KEY" in script, "Script should check for LETTA_API_KEY"
|
||||
assert "letta_client = Letta(" in script or "letta_client = None" in script, "Script should initialize Letta client"
|
||||
assert "client = Letta(" in script or "client = None" in script, "Script should initialize Letta client"
|
||||
|
||||
# Cannot run the tool since E2B is remote
|
||||
|
||||
@@ -401,6 +401,30 @@ def agent_state_ok(agent_state, value: int) -> str:
|
||||
return "ok"
|
||||
|
||||
|
||||
def client_ok(client, value: int) -> str:
|
||||
"""Ignores client param (injected Letta client).
|
||||
|
||||
Args:
|
||||
value (int): Some value.
|
||||
|
||||
Returns:
|
||||
str: Status.
|
||||
"""
|
||||
return "ok"
|
||||
|
||||
|
||||
def all_reserved_params_ok(agent_state, client, value: int) -> str:
|
||||
"""Ignores all reserved params.
|
||||
|
||||
Args:
|
||||
value (int): Some value.
|
||||
|
||||
Returns:
|
||||
str: Status.
|
||||
"""
|
||||
return "ok"
|
||||
|
||||
|
||||
class Dummy:
|
||||
def method(self, bar: int) -> str: # keeps an explicit self
|
||||
"""Bound-method example.
|
||||
@@ -446,6 +470,8 @@ def missing_param_doc(x: int, y: int) -> str:
|
||||
[
|
||||
(good_function, None),
|
||||
(agent_state_ok, None),
|
||||
(client_ok, None), # client is a reserved param (injected Letta client)
|
||||
(all_reserved_params_ok, None), # all reserved params together
|
||||
(Dummy.method, None), # unbound method keeps `self`
|
||||
(good_function_no_return, None),
|
||||
(no_doc, "has no docstring"),
|
||||
@@ -457,6 +483,28 @@ def test_google_style_docstring_validation(fn, regex):
|
||||
_check(fn, regex)
|
||||
|
||||
|
||||
def test_reserved_params_excluded_from_schema():
|
||||
"""Test that reserved params (agent_state, client) are excluded from generated schema."""
|
||||
from letta.functions.schema_generator import generate_schema
|
||||
|
||||
# Test with client param
|
||||
schema = generate_schema(client_ok)
|
||||
assert "client" not in schema["parameters"]["properties"], "client should be excluded from schema"
|
||||
assert "value" in schema["parameters"]["properties"], "value should be in schema"
|
||||
|
||||
# Test with agent_state param
|
||||
schema = generate_schema(agent_state_ok)
|
||||
assert "agent_state" not in schema["parameters"]["properties"], "agent_state should be excluded from schema"
|
||||
assert "value" in schema["parameters"]["properties"], "value should be in schema"
|
||||
|
||||
# Test with all reserved params
|
||||
schema = generate_schema(all_reserved_params_ok)
|
||||
assert "agent_state" not in schema["parameters"]["properties"], "agent_state should be excluded from schema"
|
||||
assert "client" not in schema["parameters"]["properties"], "client should be excluded from schema"
|
||||
assert "value" in schema["parameters"]["properties"], "value should be in schema"
|
||||
assert schema["parameters"]["required"] == ["value"], "only value should be required"
|
||||
|
||||
|
||||
def test_complex_nested_anyof_schema_to_structured_output():
|
||||
"""Test that complex nested anyOf schemas with inlined $refs can be converted to structured outputs.
|
||||
|
||||
|
||||
Reference in New Issue
Block a user