Files
letta-server/letta/agents/helpers.py
Kian Jones 25d54dd896 chore: enable F821, F401, W293 (#9503)
* auto fixes

* auto fix pt2 and transitive deps and undefined var checking locals()

* manual fixes (ignored or letta-code fixed)

* fix circular import
2026-02-24 10:55:08 -08:00

544 lines
24 KiB
Python

import json
import xml.etree.ElementTree as ET
from typing import Any, Dict, List, Optional, Tuple
from uuid import UUID, uuid4
from letta.errors import LettaError, PendingApprovalError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message import MessageType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import ApprovalCreate, Message, MessageCreate, MessageCreateBase
from letta.schemas.tool_execution_result import ToolExecutionResult
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.utils import create_approval_response_message_from_input, create_input_messages
from letta.services.message_manager import MessageManager
# Module-level logger for agent helper functions in this file.
logger = get_logger(__name__)
def _create_letta_response(
    new_in_context_messages: list[Message],
    use_assistant_message: bool,
    usage: LettaUsageStatistics,
    stop_reason: Optional[LettaStopReason] = None,
    include_return_message_types: Optional[List[MessageType]] = None,
) -> LettaResponse:
    """
    Package the newly created/persisted messages into a LettaResponse.
    """
    # NOTE: hacky solution to avoid returning heartbeat messages and the original user message
    non_user_messages = [msg for msg in new_in_context_messages if msg.role != "user"]

    # Convert the persisted Message records into Letta message schemas
    letta_messages = Message.to_letta_messages_from_list(
        messages=non_user_messages, use_assistant_message=use_assistant_message, reverse=False
    )

    # Approval responses are internal bookkeeping; drop them from the payload
    letta_messages = [msg for msg in letta_messages if msg.message_type != "approval_response_message"]

    # Optionally narrow the payload to only the message types the caller asked for
    if include_return_message_types is not None:
        letta_messages = [msg for msg in letta_messages if msg.message_type in include_return_message_types]

    # Default to an end-of-turn stop reason when the caller supplied none
    if stop_reason is None:
        stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)

    return LettaResponse(messages=letta_messages, stop_reason=stop_reason, usage=usage)
async def _prepare_in_context_messages_async(
    input_messages: List[MessageCreate],
    agent_state: AgentState,
    message_manager: MessageManager,
    actor: User,
    run_id: str,
) -> Tuple[List[Message], List[Message]]:
    """
    Prepares in-context messages for an agent, based on the current state and a new user input.

    Async version of _prepare_in_context_messages.

    Args:
        input_messages (List[MessageCreate]): The new user input messages to process.
        agent_state (AgentState): The current state of the agent, including message buffer config.
        message_manager (MessageManager): The manager used to retrieve and create messages.
        actor (User): The user performing the action, used for access control and attribution.
        run_id (str): The run ID associated with this message processing.

    Returns:
        Tuple[List[Message], List[Message]]: A tuple containing:
            - The current in-context messages (existing context for the agent).
            - The new in-context messages (messages created from the new input).
    """
    if agent_state.message_buffer_autoclear:
        # Autoclear keeps only the most recent system message (conventionally at index 0)
        system_message = await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
        current_in_context_messages = [system_message]
    else:
        # Full context: resolve every message id the agent currently tracks
        current_in_context_messages = await message_manager.get_messages_by_ids_async(
            message_ids=agent_state.message_ids, actor=actor
        )

    # Build Message objects from the new input and persist them right away
    unpersisted_messages = await create_input_messages(
        input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
    )
    new_in_context_messages = await message_manager.create_many_messages_async(
        unpersisted_messages,
        actor=actor,
        project_id=agent_state.project_id,
    )

    return current_in_context_messages, new_in_context_messages
@trace_method
def validate_persisted_tool_call_ids(tool_return_message: Message, approval_response_message: ApprovalCreate) -> bool:
    """Return True iff the persisted tool returns and the approval responses reference exactly the same tool call ids."""
    persisted_returns = tool_return_message.tool_returns
    approvals = approval_response_message.approvals
    # Either side empty/missing means there is nothing to match up
    if not persisted_returns or not approvals:
        return False
    persisted_ids = {tool_return.tool_call_id for tool_return in persisted_returns}
    response_ids = {approval.tool_call_id for approval in approvals}
    # Require an exact set match (no extras on either side)
    return persisted_ids == response_ids
@trace_method
def validate_approval_tool_call_ids(approval_request_message: Message, approval_response_message: ApprovalCreate):
    """Raise ValueError unless the approval response covers exactly the tool calls in the approval request."""
    # Collect the tool call ids the request is asking approval for
    if approval_request_message.tool_calls:
        approval_request_tool_call_ids = [tool_call.id for tool_call in approval_request_message.tool_calls]
    elif approval_request_message.tool_call_id:
        approval_request_tool_call_ids = [approval_request_message.tool_call_id]
    else:
        raise ValueError(
            f"Invalid tool call IDs. Approval request message '{approval_request_message.id}' does not contain any tool calls."
        )

    approval_responses = approval_response_message.approvals
    if not approval_responses:
        raise ValueError("Invalid approval response. Approval response message does not contain any approvals.")
    approval_response_tool_call_ids = [response.tool_call_id for response in approval_responses]

    # The two id sets must match exactly
    if set(approval_request_tool_call_ids) != set(approval_response_tool_call_ids):
        if len(approval_request_tool_call_ids) == 1 and approval_response_tool_call_ids[0] == approval_request_message.id:
            # legacy case where we used to use message id instead of tool call id
            return
        raise ValueError(
            f"Invalid tool call IDs. Expected '{approval_request_tool_call_ids}', but received '{approval_response_tool_call_ids}'."
        )
@trace_method
async def _prepare_in_context_messages_no_persist_async(
    input_messages: List[MessageCreateBase],
    agent_state: AgentState,
    message_manager: MessageManager,
    actor: User,
    run_id: Optional[str] = None,
    conversation_id: Optional[str] = None,
) -> Tuple[List[Message], List[Message]]:
    """
    Prepares in-context messages for an agent, based on the current state and a new user input.

    When conversation_id is provided, messages are loaded from the conversation_messages
    table instead of agent_state.message_ids.

    Args:
        input_messages (List[MessageCreate]): The new user input messages to process.
        agent_state (AgentState): The current state of the agent, including message buffer config.
        message_manager (MessageManager): The manager used to retrieve and create messages.
        actor (User): The user performing the action, used for access control and attribution.
        run_id (str): The run ID associated with this message processing.
        conversation_id (str): Optional conversation ID to load messages from.

    Returns:
        Tuple[List[Message], List[Message]]: A tuple containing:
            - The current in-context messages (existing context for the agent).
            - The new in-context messages (messages created from the new input).

    Raises:
        LettaError: If the agent has no in-context messages in default (non-conversation) mode.
        ValueError: If an approval response arrives with no matching pending approval request.
        PendingApprovalError: If a regular message arrives while an approval request is pending.
    """
    if conversation_id:
        # Conversation mode: load messages from conversation_messages table
        from letta.services.conversation_manager import ConversationManager

        conversation_manager = ConversationManager()
        message_ids = await conversation_manager.get_message_ids_for_conversation(
            conversation_id=conversation_id,
            actor=actor,
        )
        if agent_state.message_buffer_autoclear and message_ids:
            # If autoclear is enabled, only include the system message
            current_in_context_messages = [await message_manager.get_message_by_id_async(message_id=message_ids[0], actor=actor)]
        elif message_ids:
            # Otherwise, include the full list of messages from the conversation
            current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
        else:
            # No messages in conversation yet - compile a new system message for this conversation
            # Each conversation gets its own system message (captures memory state at conversation start)
            from letta.prompts.prompt_generator import PromptGenerator
            from letta.services.passage_manager import PassageManager

            num_messages = await message_manager.size_async(actor=actor, agent_id=agent_state.id)
            passage_manager = PassageManager()
            num_archival_memories = await passage_manager.agent_passage_size_async(actor=actor, agent_id=agent_state.id)
            system_message_str = await PromptGenerator.compile_system_message_async(
                system_prompt=agent_state.system,
                in_context_memory=agent_state.memory,
                in_context_memory_last_edit=get_utc_time(),
                timezone=agent_state.timezone,
                user_defined_variables=None,
                append_icm_if_missing=True,
                previous_message_count=num_messages,
                archival_memory_size=num_archival_memories,
                sources=agent_state.sources,
                max_files_open=agent_state.max_files_open,
            )
            system_message = Message.dict_to_message(
                agent_id=agent_state.id,
                model=agent_state.llm_config.model,
                openai_message_dict={"role": "system", "content": system_message_str},
            )
            # Persist the new system message
            persisted_messages = await message_manager.create_many_messages_async([system_message], actor=actor)
            system_message = persisted_messages[0]
            # Add it to the conversation tracking
            await conversation_manager.add_messages_to_conversation(
                conversation_id=conversation_id,
                agent_id=agent_state.id,
                message_ids=[system_message.id],
                actor=actor,
                starting_position=0,
            )
            current_in_context_messages = [system_message]
    else:
        # Default mode: load messages from agent_state.message_ids
        if not agent_state.message_ids:
            raise LettaError(
                message=f"Agent {agent_state.id} has no in-context messages. "
                "This typically means the agent's system message was not initialized correctly.",
            )

        if agent_state.message_buffer_autoclear:
            # If autoclear is enabled, only include the most recent system message (usually at index 0)
            current_in_context_messages = [
                await message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
            ]
        else:
            # Otherwise, include the full list of messages by ID for context
            current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)

    # Convert ToolReturnCreate to ApprovalCreate for unified processing
    if input_messages[0].type == "tool_return":
        tool_return_msg = input_messages[0]
        input_messages = [
            ApprovalCreate(approvals=tool_return_msg.tool_returns),
            *input_messages[1:],
        ]

    # Check for approval-related message validation
    if input_messages[0].type == "approval":
        # User is trying to send an approval response
        if current_in_context_messages and current_in_context_messages[-1].role != "approval":
            # No pending approval request - check if this is an idempotent retry
            # Check last few messages for a tool return matching the approval's tool_call_ids
            # (approved tool return should be recent, but server-side tool calls may come after it)
            approval_already_processed = False
            recent_messages = current_in_context_messages[-10:]  # Only check last 10 messages
            for msg in reversed(recent_messages):
                if msg.role == "tool" and validate_persisted_tool_call_ids(msg, input_messages[0]):
                    logger.info(
                        f"Idempotency check: Found matching tool return in recent history. "
                        f"tool_returns={msg.tool_returns}, approval_response.approvals={input_messages[0].approvals}"
                    )
                    approval_already_processed = True
                    break
            if approval_already_processed:
                # Approval already handled, just process follow-up messages if any or manually inject keep-alive message
                keep_alive_messages = input_messages[1:] or [
                    MessageCreate(
                        role="user",
                        content=[
                            TextContent(
                                text="<system-alert>Automated keep-alive ping. Ignore this message and continue from where you stopped.</system-alert>"
                            )
                        ],
                    )
                ]
                new_in_context_messages = await create_input_messages(
                    input_messages=keep_alive_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
                )
                return current_in_context_messages, new_in_context_messages
            # Fix: use logger.warning — logger.warn is a deprecated alias of warning
            logger.warning(
                f"Cannot process approval response: No tool call is currently awaiting approval. Last message: {current_in_context_messages[-1]}"
            )
            raise ValueError(
                "Cannot process approval response: No tool call is currently awaiting approval. "
                "Please send a regular message to interact with the agent."
            )
        validate_approval_tool_call_ids(current_in_context_messages[-1], input_messages[0])
        new_in_context_messages = await create_approval_response_message_from_input(
            agent_state=agent_state, input_message=input_messages[0], run_id=run_id
        )
        if len(input_messages) > 1:
            follow_up_messages = await create_input_messages(
                input_messages=input_messages[1:], agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
            )
            new_in_context_messages.extend(follow_up_messages)
    else:
        # User is trying to send a regular message
        if current_in_context_messages and current_in_context_messages[-1].is_approval_request():
            raise PendingApprovalError(pending_request_id=current_in_context_messages[-1].id)
        # Create a new user message from the input but dont store it yet
        new_in_context_messages = await create_input_messages(
            input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, run_id=run_id, actor=actor
        )

    return current_in_context_messages, new_in_context_messages
def serialize_message_history(messages: List[str], context: str) -> str:
    """
    Serialize a message list plus a context string into an XML document:

        <memory>
          <messages>
            <message>…</message>
            <message>…</message>
          </messages>
          <context>…</context>
        </memory>
    """
    memory_el = ET.Element("memory")
    messages_el = ET.SubElement(memory_el, "messages")
    for message_text in messages:
        ET.SubElement(messages_el, "message").text = message_text
    context_el = ET.SubElement(memory_el, "context")
    context_el.text = context
    # ET.tostring escapes reserved XML characters (<, >, &) for us
    return ET.tostring(memory_el, encoding="unicode")
def deserialize_message_history(xml_str: str) -> Tuple[List[str], str]:
    """
    Parse the XML produced by serialize_message_history back into (messages, context).

    Args:
        xml_str: Document of the form <memory><messages>…</messages><context>…</context></memory>.

    Returns:
        Tuple[List[str], str]: The list of message strings and the context string.

    Raises:
        ValueError: If the XML is malformed or the <messages>/<context> tags are missing.
    """
    try:
        root = ET.fromstring(xml_str)
    except ET.ParseError as e:
        # Fix: chain the original parse error (`from e`) so the root cause
        # stays visible in tracebacks instead of being reported as
        # "during handling of the above exception, another exception occurred"
        raise ValueError(f"Invalid XML: {e}") from e

    msgs_el = root.find("messages")
    if msgs_el is None:
        raise ValueError("Missing <messages> section")
    # .text may be None if empty, so coerce to empty string
    messages = [m.text or "" for m in msgs_el.findall("message")]

    sum_el = root.find("context")
    if sum_el is None:
        raise ValueError("Missing <context> section")
    context = sum_el.text or ""

    return messages, context
def generate_step_id(uid: Optional[UUID] = None) -> str:
    """Return a step identifier of the form ``step-<uuid>``, minting a random UUID when none is given."""
    if uid is None:
        uid = uuid4()
    return f"step-{uid}"
def _safe_load_tool_call_str(tool_call_args_str: str) -> dict:
"""Lenient JSON → dict with fallback to eval on assertion failure."""
# Temp hack to gracefully handle parallel tool calling attempt, only take first one
if "}{" in tool_call_args_str:
tool_call_args_str = tool_call_args_str.split("}{", 1)[0] + "}"
try:
tool_args = json.loads(tool_call_args_str)
if not isinstance(tool_args, dict):
# Load it again - this is due to sometimes Anthropic returning weird json @caren
tool_args = json.loads(tool_args)
except json.JSONDecodeError:
logger.error("Failed to JSON decode tool call argument string: %s", tool_call_args_str)
tool_args = {}
return tool_args
def _json_type_matches(value: Any, expected_type: Any) -> bool:
"""Basic JSON Schema type checking for common types.
expected_type can be a string (e.g., "string") or a list (union).
This is intentionally lightweight; deeper validation can be added as needed.
"""
def match_one(v: Any, t: str) -> bool:
if t == "string":
return isinstance(v, str)
if t == "integer":
# bool is subclass of int in Python; exclude
return isinstance(v, int) and not isinstance(v, bool)
if t == "number":
return (isinstance(v, int) and not isinstance(v, bool)) or isinstance(v, float)
if t == "boolean":
return isinstance(v, bool)
if t == "object":
return isinstance(v, dict)
if t == "array":
return isinstance(v, list)
if t == "null":
return v is None
# Fallback: don't over-reject on unknown types
return True
if isinstance(expected_type, list):
return any(match_one(value, t) for t in expected_type)
if isinstance(expected_type, str):
return match_one(value, expected_type)
return True
def _schema_accepts_value(prop_schema: Dict[str, Any], value: Any) -> bool:
"""Check if a value is acceptable for a property schema.
Handles: type, enum, const, anyOf, oneOf (by shallow traversal).
"""
if prop_schema is None:
return True
# const has highest precedence
if "const" in prop_schema:
return value == prop_schema["const"]
# enums
if "enum" in prop_schema:
try:
return value in prop_schema["enum"]
except Exception:
return False
# unions
for union_key in ("anyOf", "oneOf"):
if union_key in prop_schema and isinstance(prop_schema[union_key], list):
for sub in prop_schema[union_key]:
if _schema_accepts_value(sub, value):
return True
return False
# type-based
if "type" in prop_schema:
if not _json_type_matches(value, prop_schema["type"]):
return False
# No strict constraints specified: accept
return True
def merge_and_validate_prefilled_args(tool: "Tool", llm_args: Dict[str, Any], prefilled_args: Dict[str, Any]) -> Dict[str, Any]:  # noqa: F821
    """Merge LLM-provided args with prefilled args from tool rules.

    - Overlapping keys are replaced by prefilled values (prefilled wins).
    - Validates that prefilled keys exist on the tool schema and that values satisfy
      basic JSON Schema constraints (type/enum/const/anyOf/oneOf).
    - Returns merged args, or raises ValueError on invalid prefilled inputs.

    Raises:
        TypeError: If ``tool`` is not a Tool instance.
        ValueError: If a prefilled key is unknown or its value violates the schema.
    """
    from letta.schemas.tool import Tool  # local import to avoid circulars in type hints

    # Fix: explicit raise instead of `assert`, which is silently stripped
    # when running under `python -O`.
    if not isinstance(tool, Tool):
        raise TypeError(f"Expected a Tool instance, got {type(tool).__name__}")

    schema = (tool.json_schema or {}).get("parameters", {})
    props: Dict[str, Any] = schema.get("properties", {}) if isinstance(schema, dict) else {}

    # Collect all validation problems so the caller sees every issue at once
    errors: list[str] = []
    for k, v in prefilled_args.items():
        if k not in props:
            errors.append(f"Unknown argument '{k}' for tool '{tool.name}'.")
            continue
        if not _schema_accepts_value(props.get(k), v):
            expected = props.get(k, {}).get("type")
            errors.append(f"Invalid value for '{k}': {v!r} does not match expected schema type {expected!r}.")
    if errors:
        raise ValueError("; ".join(errors))

    # Prefilled values win over the LLM-provided ones on key collisions
    merged = dict(llm_args or {})
    merged.update(prefilled_args)
    return merged
def _pop_heartbeat(tool_args: dict) -> bool:
hb = tool_args.pop("request_heartbeat", False)
return str(hb).lower() == "true" if isinstance(hb, str) else bool(hb)
def _build_rule_violation_result(tool_name: str, valid: list[str], solver: ToolRulesSolver) -> ToolExecutionResult:
    """Build an error ToolExecutionResult explaining that calling `tool_name` violates the active tool rules."""
    violated_rules = solver.guess_rule_violation(tool_name)
    if violated_rules:
        # Append the solver's best guesses about which rules were broken
        hint_txt = "\n** Hint: Possible rules that were violated:\n" + "\n".join(f"\t- {rule}" for rule in violated_rules)
    else:
        hint_txt = ""
    message = f"[ToolConstraintError] Cannot call {tool_name}, valid tools include: {valid}.{hint_txt}"
    return ToolExecutionResult(status="error", func_return=message)
def _load_last_function_response(in_context_messages: list[Message]):
    """Load the last function response from message history"""
    for message in reversed(in_context_messages):
        # Only consider tool messages whose content is a single text part
        is_single_text_tool_message = (
            message.role == MessageRole.tool
            and message.content
            and len(message.content) == 1
            and isinstance(message.content[0], TextContent)
        )
        if not is_single_text_tool_message:
            continue
        raw_text = message.content[0].text
        try:
            parsed = json.loads(raw_text)
        except (json.JSONDecodeError, KeyError):
            raise ValueError(f"Invalid JSON format in message: {raw_text}")
        if parsed.get("message"):
            return parsed["message"]
    return None
def _maybe_get_approval_messages(messages: list[Message]) -> Tuple[Message | None, Message | None]:
    """Return the trailing (approval_request, approval_response) pair when the last two messages both have role "approval"; otherwise (None, None)."""
    if len(messages) < 2:
        return None, None
    request_candidate, response_candidate = messages[-2], messages[-1]
    if request_candidate.role == "approval" and response_candidate.role == "approval":
        return request_candidate, response_candidate
    return None, None
def _maybe_get_pending_tool_call_message(messages: list[Message]) -> Message | None:
    """
    Only used in the case where hitl is invoked with parallel tool calling,
    where agent calls some tools that require approval, and others that don't.
    """
    if len(messages) < 3:
        return None
    candidate = messages[-3]
    # The pending tool-call message is an assistant message carrying at least one tool call
    has_tool_calls = candidate.tool_calls is not None and len(candidate.tool_calls) > 0
    if candidate.role == "assistant" and has_tool_calls:
        return candidate
    return None