From 70a2b1de7993f424a9839d769a78d5e4e95b0440 Mon Sep 17 00:00:00 2001 From: Andy Li <55300002+cliandy@users.noreply.github.com> Date: Thu, 29 May 2025 10:29:05 -0700 Subject: [PATCH] feat: insert heartbeat dynamically and remove from stored json (#2501) --- letta/agent.py | 5 ++- letta/agents/letta_agent.py | 7 +++- letta/constants.py | 1 + letta/functions/functions.py | 2 + letta/functions/schema_generator.py | 30 +++++---------- letta/schemas/tool.py | 38 +++++++++---------- letta/services/helpers/tool_parser_helper.py | 39 +++++++++++++++++++- tests/test_tool_schema_parsing.py | 2 +- 8 files changed, 80 insertions(+), 44 deletions(-) diff --git a/letta/agent.py b/letta/agent.py index 4c8b2553..c4565491 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -59,6 +59,7 @@ from letta.schemas.usage import LettaUsageStatistics from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block +from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema from letta.services.job_manager import JobManager from letta.services.mcp.base_client import AsyncBaseMCPClient from letta.services.message_manager import MessageManager @@ -327,7 +328,9 @@ class Agent(BaseAgent): return None allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names] - allowed_functions = self._runtime_override_tool_json_schema(allowed_functions) + allowed_functions = runtime_override_tool_json_schema( + tool_list=allowed_functions, response_format=self.agent_state.response_format, request_heartbeat=True + ) # For the first message, force the initial tool if one is specified force_tool_call = None diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index b6a48330..04dca491 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -33,6 +33,7 @@ from letta.schemas.user import User from letta.server.rest_api.utils import create_letta_messages_from_llm_response from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager +from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.services.step_manager import NoopStepManager, StepManager @@ -478,7 +479,7 @@ class LettaAgent(BaseAgent): in_context_messages: List[Message], agent_state: AgentState, tool_rules_solver: ToolRulesSolver, - ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: + ) -> dict: self.num_messages, self.num_archival_memories = await asyncio.gather( ( self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) @@ -510,7 +511,6 @@ class LettaAgent(BaseAgent): ToolType.EXTERNAL_COMPOSIO, ToolType.EXTERNAL_MCP, } - or (t.tool_type == ToolType.EXTERNAL_COMPOSIO) ] # Mirror the sync agent loop: get allowed tools or allow all if none are allowed @@ -528,6 +528,9 @@ class LettaAgent(BaseAgent): force_tool_call = valid_tool_names[0] allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] + allowed_tools = runtime_override_tool_json_schema( + tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True + ) return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call) diff --git a/letta/constants.py b/letta/constants.py index 88bb7f53..334128e0 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -139,6 +139,7 @@ DEFAULT_MESSAGE_TOOL_KWARG = "message" PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg" REQUEST_HEARTBEAT_PARAM = "request_heartbeat" +REQUEST_HEARTBEAT_DESCRIPTION = "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function." # Structured output models diff --git a/letta/functions/functions.py b/letta/functions/functions.py index b0c41a86..39a25381 100644 --- a/letta/functions/functions.py +++ b/letta/functions/functions.py @@ -12,6 +12,8 @@ from letta.functions.schema_generator import generate_schema def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict: """Derives the OpenAI JSON schema for a given function source code. + # TODO (cliandy): I don't think we need to or should execute here + # TODO (cliandy): CONFIRM THIS BEFORE MERGING. First, attempts to execute the source code in a custom environment with only the necessary imports. Then, it generates the schema from the function's docstring and signature. """ diff --git a/letta/functions/schema_generator.py b/letta/functions/schema_generator.py index 6f1c003c..db8ae2a0 100644 --- a/letta/functions/schema_generator.py +++ b/letta/functions/schema_generator.py @@ -7,6 +7,7 @@ from docstring_parser import parse from pydantic import BaseModel from typing_extensions import Literal +from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM from letta.functions.mcp_client.types import MCPTool @@ -422,17 +423,6 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[ # TODO is this not duplicating the other append directly above? if param.annotation == inspect.Parameter.empty: schema["parameters"]["required"].append(param.name) - - # append the heartbeat - # TODO: don't hard-code - # TODO: if terminal, don't include this - # if function.__name__ not in ["send_message"]: - schema["parameters"]["properties"]["request_heartbeat"] = { - "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", - } - schema["parameters"]["required"].append("request_heartbeat") - return schema @@ -455,11 +445,11 @@ def generate_schema_from_args_schema_v2( } if append_heartbeat: - function_call_json["parameters"]["properties"]["request_heartbeat"] = { + function_call_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = { "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", + "description": REQUEST_HEARTBEAT_DESCRIPTION, } - function_call_json["parameters"]["required"].append("request_heartbeat") + function_call_json["parameters"]["required"].append(REQUEST_HEARTBEAT_PARAM) return function_call_json @@ -486,11 +476,11 @@ def generate_tool_schema_for_mcp( # Add the optional heartbeat parameter if append_heartbeat: - parameters_schema["properties"]["request_heartbeat"] = { + parameters_schema["properties"][REQUEST_HEARTBEAT_PARAM] = { "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", + "description": REQUEST_HEARTBEAT_DESCRIPTION, } - parameters_schema["required"].append("request_heartbeat") + parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM) # Return the final schema if strict: @@ -548,11 +538,11 @@ def generate_tool_schema_for_composio( # Add the optional heartbeat parameter if append_heartbeat: - properties_json["request_heartbeat"] = { + properties_json[REQUEST_HEARTBEAT_PARAM] = { "type": "boolean", - "description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.", + "description": REQUEST_HEARTBEAT_DESCRIPTION, } - required_fields.append("request_heartbeat") + required_fields.append(REQUEST_HEARTBEAT_PARAM) # Return the final schema if strict: diff --git a/letta/schemas/tool.py b/letta/schemas/tool.py index af20bd2a..d805f6ec 100644 --- a/letta/schemas/tool.py +++ b/letta/schemas/tool.py @@ -72,31 +72,28 @@ class Tool(BaseTool): """ from letta.functions.helpers import generate_model_from_args_json_schema - if self.tool_type == ToolType.CUSTOM: - # If it's a custom tool, we need to ensure source_code is present + if self.tool_type is ToolType.CUSTOM: if not self.source_code: error_msg = f"Custom tool with id={self.id} is missing source_code field." logger.error(error_msg) raise ValueError(error_msg) # Always derive json_schema for freshest possible json_schema - # TODO: Instead of checking the tag, we should having `COMPOSIO` as a specific ToolType - # TODO: We skip this for Composio bc composio json schemas are derived differently - if not (COMPOSIO_TOOL_TAG_NAME in self.tags): - if self.args_json_schema is not None: - name, description = get_function_name_and_docstring(self.source_code, self.name) - args_schema = generate_model_from_args_json_schema(self.args_json_schema) - self.json_schema = generate_schema_from_args_schema_v2( - args_schema=args_schema, - name=name, - description=description, - ) - else: - try: - self.json_schema = derive_openai_json_schema(source_code=self.source_code) - except Exception as e: - error_msg = f"Failed to derive json schema for tool with id={self.id} name={self.name}. Error: {str(e)}" - logger.error(error_msg) + if self.args_json_schema is not None: + name, description = get_function_name_and_docstring(self.source_code, self.name) + args_schema = generate_model_from_args_json_schema(self.args_json_schema) + self.json_schema = generate_schema_from_args_schema_v2( + args_schema=args_schema, + name=name, + description=description, + append_heartbeat=False, + ) + else: + try: + self.json_schema = derive_openai_json_schema(source_code=self.source_code) + except Exception as e: + error_msg = f"Failed to derive json schema for tool with id={self.id} name={self.name}. Error: {str(e)}" + logger.error(error_msg) elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE, ToolType.LETTA_SLEEPTIME_CORE}: # If it's letta core tool, we generate the json_schema on the fly here self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name) @@ -109,6 +106,9 @@ class Tool(BaseTool): elif self.tool_type in {ToolType.LETTA_BUILTIN}: # If it's letta voice tool, we generate the json_schema on the fly here self.json_schema = get_json_schema_from_module(module_name=LETTA_BUILTIN_TOOL_MODULE_NAME, function_name=self.name) + elif self.tool_type in {ToolType.EXTERNAL_COMPOSIO}: + # Composio schemas handled separately + pass # At this point, we need to validate that at least json_schema is populated if not self.json_schema: diff --git a/letta/services/helpers/tool_parser_helper.py b/letta/services/helpers/tool_parser_helper.py index 36798a91..145eed52 100644 --- a/letta/services/helpers/tool_parser_helper.py +++ b/letta/services/helpers/tool_parser_helper.py @@ -3,8 +3,10 @@ import base64 import pickle from typing import Any +from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, SEND_MESSAGE_TOOL_NAME from letta.schemas.agent import AgentState -from letta.types import JsonValue +from letta.schemas.response_format import ResponseFormat, ResponseFormatType, ResponseFormatUnion +from letta.types import JsonDict, JsonValue def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]: @@ -61,3 +63,38 @@ def convert_param_to_str_value(param_type: str, raw_value: JsonValue) -> str: # raise ValueError(f'Invalid array value: "{raw_value}"') # return raw_value.strip() return str(raw_value) + + +def runtime_override_tool_json_schema( + tool_list: list[JsonDict], + response_format: ResponseFormatUnion | None, + request_heartbeat: bool = True, +) -> list[JsonDict]: + """Override the tool JSON schemas at runtime if certain conditions are met. + + Cases: + 1. We will inject `send_message` tool calls with `response_format` if provided + 2. Tools will have an additional `request_heartbeat` parameter added. + """ + for tool_json in tool_list: + if tool_json["name"] == SEND_MESSAGE_TOOL_NAME and response_format and response_format.type != ResponseFormatType.text: + if response_format.type == ResponseFormatType.json_schema: + tool_json["parameters"]["properties"]["message"] = response_format.json_schema["schema"] + if response_format.type == ResponseFormatType.json_object: + tool_json["parameters"]["properties"]["message"] = { + "type": "object", + "description": "Message contents. All unicode (including emojis) are supported.", + "additionalProperties": True, + "properties": {}, + } + if request_heartbeat: + # TODO (cliandy): see support for tool control loop parameters + if tool_json["name"] != SEND_MESSAGE_TOOL_NAME: + tool_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = { + "type": "boolean", + "description": REQUEST_HEARTBEAT_DESCRIPTION, + } + if REQUEST_HEARTBEAT_PARAM not in tool_json["parameters"]["required"]: + tool_json["parameters"]["required"].append(REQUEST_HEARTBEAT_PARAM) + + return tool_list diff --git a/tests/test_tool_schema_parsing.py b/tests/test_tool_schema_parsing.py index 08ad8052..89b2ed73 100644 --- a/tests/test_tool_schema_parsing.py +++ b/tests/test_tool_schema_parsing.py @@ -67,7 +67,7 @@ def _run_schema_test(schema_name: str, desired_function_name: str, expect_struct with open(os.path.join(os.path.dirname(__file__), f"test_tool_schema_parsing_files/{schema_name}.json"), "r") as file: expected_schema = json.load(file) - _compare_schemas(schema, expected_schema) + _compare_schemas(schema, expected_schema, False) # Convert to structured output and compare if expect_structured_output_fail: