feat: insert heartbeat dynamically and remove from stored json (#2501)
This commit is contained in:
@@ -59,6 +59,7 @@ from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
from letta.services.message_manager import MessageManager
|
||||
@@ -327,7 +328,9 @@ class Agent(BaseAgent):
|
||||
return None
|
||||
|
||||
allowed_functions = [func for func in agent_state_tool_jsons if func["name"] in allowed_tool_names]
|
||||
allowed_functions = self._runtime_override_tool_json_schema(allowed_functions)
|
||||
allowed_functions = runtime_override_tool_json_schema(
|
||||
tool_list=allowed_functions, response_format=self.agent_state.response_format, request_heartbeat=True
|
||||
)
|
||||
|
||||
# For the first message, force the initial tool if one is specified
|
||||
force_tool_call = None
|
||||
|
||||
@@ -33,6 +33,7 @@ from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_letta_messages_from_llm_response
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.step_manager import NoopStepManager, StepManager
|
||||
@@ -478,7 +479,7 @@ class LettaAgent(BaseAgent):
|
||||
in_context_messages: List[Message],
|
||||
agent_state: AgentState,
|
||||
tool_rules_solver: ToolRulesSolver,
|
||||
) -> ChatCompletion | AsyncStream[ChatCompletionChunk]:
|
||||
) -> dict:
|
||||
self.num_messages, self.num_archival_memories = await asyncio.gather(
|
||||
(
|
||||
self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
@@ -510,7 +511,6 @@ class LettaAgent(BaseAgent):
|
||||
ToolType.EXTERNAL_COMPOSIO,
|
||||
ToolType.EXTERNAL_MCP,
|
||||
}
|
||||
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
|
||||
]
|
||||
|
||||
# Mirror the sync agent loop: get allowed tools or allow all if none are allowed
|
||||
@@ -528,6 +528,9 @@ class LettaAgent(BaseAgent):
|
||||
force_tool_call = valid_tool_names[0]
|
||||
|
||||
allowed_tools = [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
|
||||
allowed_tools = runtime_override_tool_json_schema(
|
||||
tool_list=allowed_tools, response_format=agent_state.response_format, request_heartbeat=True
|
||||
)
|
||||
|
||||
return llm_client.build_request_data(in_context_messages, agent_state.llm_config, allowed_tools, force_tool_call)
|
||||
|
||||
|
||||
@@ -139,6 +139,7 @@ DEFAULT_MESSAGE_TOOL_KWARG = "message"
|
||||
PRE_EXECUTION_MESSAGE_ARG = "pre_exec_msg"
|
||||
|
||||
REQUEST_HEARTBEAT_PARAM = "request_heartbeat"
|
||||
REQUEST_HEARTBEAT_DESCRIPTION = "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function."
|
||||
|
||||
|
||||
# Structured output models
|
||||
|
||||
@@ -12,6 +12,8 @@ from letta.functions.schema_generator import generate_schema
|
||||
def derive_openai_json_schema(source_code: str, name: Optional[str] = None) -> dict:
|
||||
"""Derives the OpenAI JSON schema for a given function source code.
|
||||
|
||||
# TODO (cliandy): I don't think we need to or should execute here
|
||||
# TODO (cliandy): CONFIRM THIS BEFORE MERGING.
|
||||
First, attempts to execute the source code in a custom environment with only the necessary imports.
|
||||
Then, it generates the schema from the function's docstring and signature.
|
||||
"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from docstring_parser import parse
|
||||
from pydantic import BaseModel
|
||||
from typing_extensions import Literal
|
||||
|
||||
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
|
||||
|
||||
@@ -422,17 +423,6 @@ def generate_schema(function, name: Optional[str] = None, description: Optional[
|
||||
# TODO is this not duplicating the other append directly above?
|
||||
if param.annotation == inspect.Parameter.empty:
|
||||
schema["parameters"]["required"].append(param.name)
|
||||
|
||||
# append the heartbeat
|
||||
# TODO: don't hard-code
|
||||
# TODO: if terminal, don't include this
|
||||
# if function.__name__ not in ["send_message"]:
|
||||
schema["parameters"]["properties"]["request_heartbeat"] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
}
|
||||
schema["parameters"]["required"].append("request_heartbeat")
|
||||
|
||||
return schema
|
||||
|
||||
|
||||
@@ -455,11 +445,11 @@ def generate_schema_from_args_schema_v2(
|
||||
}
|
||||
|
||||
if append_heartbeat:
|
||||
function_call_json["parameters"]["properties"]["request_heartbeat"] = {
|
||||
function_call_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
"description": REQUEST_HEARTBEAT_DESCRIPTION,
|
||||
}
|
||||
function_call_json["parameters"]["required"].append("request_heartbeat")
|
||||
function_call_json["parameters"]["required"].append(REQUEST_HEARTBEAT_PARAM)
|
||||
|
||||
return function_call_json
|
||||
|
||||
@@ -486,11 +476,11 @@ def generate_tool_schema_for_mcp(
|
||||
|
||||
# Add the optional heartbeat parameter
|
||||
if append_heartbeat:
|
||||
parameters_schema["properties"]["request_heartbeat"] = {
|
||||
parameters_schema["properties"][REQUEST_HEARTBEAT_PARAM] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
"description": REQUEST_HEARTBEAT_DESCRIPTION,
|
||||
}
|
||||
parameters_schema["required"].append("request_heartbeat")
|
||||
parameters_schema["required"].append(REQUEST_HEARTBEAT_PARAM)
|
||||
|
||||
# Return the final schema
|
||||
if strict:
|
||||
@@ -548,11 +538,11 @@ def generate_tool_schema_for_composio(
|
||||
|
||||
# Add the optional heartbeat parameter
|
||||
if append_heartbeat:
|
||||
properties_json["request_heartbeat"] = {
|
||||
properties_json[REQUEST_HEARTBEAT_PARAM] = {
|
||||
"type": "boolean",
|
||||
"description": "Request an immediate heartbeat after function execution. Set to `True` if you want to send a follow-up message or run a follow-up function.",
|
||||
"description": REQUEST_HEARTBEAT_DESCRIPTION,
|
||||
}
|
||||
required_fields.append("request_heartbeat")
|
||||
required_fields.append(REQUEST_HEARTBEAT_PARAM)
|
||||
|
||||
# Return the final schema
|
||||
if strict:
|
||||
|
||||
@@ -72,31 +72,28 @@ class Tool(BaseTool):
|
||||
"""
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
|
||||
if self.tool_type == ToolType.CUSTOM:
|
||||
# If it's a custom tool, we need to ensure source_code is present
|
||||
if self.tool_type is ToolType.CUSTOM:
|
||||
if not self.source_code:
|
||||
error_msg = f"Custom tool with id={self.id} is missing source_code field."
|
||||
logger.error(error_msg)
|
||||
raise ValueError(error_msg)
|
||||
|
||||
# Always derive json_schema for freshest possible json_schema
|
||||
# TODO: Instead of checking the tag, we should having `COMPOSIO` as a specific ToolType
|
||||
# TODO: We skip this for Composio bc composio json schemas are derived differently
|
||||
if not (COMPOSIO_TOOL_TAG_NAME in self.tags):
|
||||
if self.args_json_schema is not None:
|
||||
name, description = get_function_name_and_docstring(self.source_code, self.name)
|
||||
args_schema = generate_model_from_args_json_schema(self.args_json_schema)
|
||||
self.json_schema = generate_schema_from_args_schema_v2(
|
||||
args_schema=args_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to derive json schema for tool with id={self.id} name={self.name}. Error: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
if self.args_json_schema is not None:
|
||||
name, description = get_function_name_and_docstring(self.source_code, self.name)
|
||||
args_schema = generate_model_from_args_json_schema(self.args_json_schema)
|
||||
self.json_schema = generate_schema_from_args_schema_v2(
|
||||
args_schema=args_schema,
|
||||
name=name,
|
||||
description=description,
|
||||
append_heartbeat=False,
|
||||
)
|
||||
else:
|
||||
try:
|
||||
self.json_schema = derive_openai_json_schema(source_code=self.source_code)
|
||||
except Exception as e:
|
||||
error_msg = f"Failed to derive json schema for tool with id={self.id} name={self.name}. Error: {str(e)}"
|
||||
logger.error(error_msg)
|
||||
elif self.tool_type in {ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE, ToolType.LETTA_SLEEPTIME_CORE}:
|
||||
# If it's letta core tool, we generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_CORE_TOOL_MODULE_NAME, function_name=self.name)
|
||||
@@ -109,6 +106,9 @@ class Tool(BaseTool):
|
||||
elif self.tool_type in {ToolType.LETTA_BUILTIN}:
|
||||
# If it's letta voice tool, we generate the json_schema on the fly here
|
||||
self.json_schema = get_json_schema_from_module(module_name=LETTA_BUILTIN_TOOL_MODULE_NAME, function_name=self.name)
|
||||
elif self.tool_type in {ToolType.EXTERNAL_COMPOSIO}:
|
||||
# Composio schemas handled separately
|
||||
pass
|
||||
|
||||
# At this point, we need to validate that at least json_schema is populated
|
||||
if not self.json_schema:
|
||||
|
||||
@@ -3,8 +3,10 @@ import base64
|
||||
import pickle
|
||||
from typing import Any
|
||||
|
||||
from letta.constants import REQUEST_HEARTBEAT_DESCRIPTION, REQUEST_HEARTBEAT_PARAM, SEND_MESSAGE_TOOL_NAME
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.types import JsonValue
|
||||
from letta.schemas.response_format import ResponseFormat, ResponseFormatType, ResponseFormatUnion
|
||||
from letta.types import JsonDict, JsonValue
|
||||
|
||||
|
||||
def parse_stdout_best_effort(text: str | bytes) -> tuple[Any, AgentState | None]:
|
||||
@@ -61,3 +63,38 @@ def convert_param_to_str_value(param_type: str, raw_value: JsonValue) -> str:
|
||||
# raise ValueError(f'Invalid array value: "{raw_value}"')
|
||||
# return raw_value.strip()
|
||||
return str(raw_value)
|
||||
|
||||
|
||||
def runtime_override_tool_json_schema(
|
||||
tool_list: list[JsonDict],
|
||||
response_format: ResponseFormatUnion | None,
|
||||
request_heartbeat: bool = True,
|
||||
) -> list[JsonDict]:
|
||||
"""Override the tool JSON schemas at runtime if certain conditions are met.
|
||||
|
||||
Cases:
|
||||
1. We will inject `send_message` tool calls with `response_format` if provided
|
||||
2. Tools will have an additional `request_heartbeat` parameter added.
|
||||
"""
|
||||
for tool_json in tool_list:
|
||||
if tool_json["name"] == SEND_MESSAGE_TOOL_NAME and response_format and response_format.type != ResponseFormatType.text:
|
||||
if response_format.type == ResponseFormatType.json_schema:
|
||||
tool_json["parameters"]["properties"]["message"] = response_format.json_schema["schema"]
|
||||
if response_format.type == ResponseFormatType.json_object:
|
||||
tool_json["parameters"]["properties"]["message"] = {
|
||||
"type": "object",
|
||||
"description": "Message contents. All unicode (including emojis) are supported.",
|
||||
"additionalProperties": True,
|
||||
"properties": {},
|
||||
}
|
||||
if request_heartbeat:
|
||||
# TODO (cliandy): see support for tool control loop parameters
|
||||
if tool_json["name"] != SEND_MESSAGE_TOOL_NAME:
|
||||
tool_json["parameters"]["properties"][REQUEST_HEARTBEAT_PARAM] = {
|
||||
"type": "boolean",
|
||||
"description": REQUEST_HEARTBEAT_DESCRIPTION,
|
||||
}
|
||||
if REQUEST_HEARTBEAT_PARAM not in tool_json["parameters"]["required"]:
|
||||
tool_json["parameters"]["required"].append(REQUEST_HEARTBEAT_PARAM)
|
||||
|
||||
return tool_list
|
||||
|
||||
@@ -67,7 +67,7 @@ def _run_schema_test(schema_name: str, desired_function_name: str, expect_struct
|
||||
with open(os.path.join(os.path.dirname(__file__), f"test_tool_schema_parsing_files/{schema_name}.json"), "r") as file:
|
||||
expected_schema = json.load(file)
|
||||
|
||||
_compare_schemas(schema, expected_schema)
|
||||
_compare_schemas(schema, expected_schema, False)
|
||||
|
||||
# Convert to structured output and compare
|
||||
if expect_structured_output_fail:
|
||||
|
||||
Reference in New Issue
Block a user