from datetime import datetime
from multiprocessing import Value  # NOTE(review): appears unused in this file — likely an accidental auto-import; verify before removing
from pickletools import pyunicode  # NOTE(review): appears unused in this file — likely an accidental auto-import; verify before removing
from typing import List, Literal, Optional

from httpx import AsyncClient

from letta.data_sources.redis_client import get_redis_client
from letta.errors import LettaInvalidArgumentError
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.log_context import update_log_context
from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.orm.run import Run as RunModel
from letta.orm.run_metrics import RunMetrics as RunMetricsModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, PrimitiveType, RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics, normalize_cache_tokens, normalize_reasoning_tokens
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.services.message_manager import MessageManager
from letta.services.step_manager import StepManager
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id

logger = get_logger(__name__)


class RunManager:
    """Manager class to handle business logic related to Runs."""

    def __init__(self):
        """Initialize the RunManager with its collaborating managers."""
        self.step_manager = StepManager()
        self.message_manager = MessageManager()
        self.agent_manager = AgentManager()

    @enforce_types
    async def create_run(self, pydantic_run: PydanticRun, actor: PydanticUser) -> PydanticRun:
        """Create a new run and its companion RunMetrics row.

        Args:
            pydantic_run: Run payload; must carry a valid ``agent_id``.
            actor: User performing the action (supplies organization scoping).

        Returns:
            The persisted run as a Pydantic model.
        """
        async with db_registry.async_session() as session:
            # Get agent_id from the pydantic object
            agent_id = pydantic_run.agent_id

            # Verify agent exists before creating the run
            await validate_agent_exists_async(session, agent_id, actor)
            organization_id = actor.organization_id

            run_data = pydantic_run.model_dump(exclude_none=True)
            # Handle metadata field mapping (Pydantic uses 'metadata', ORM uses 'metadata_')
            if "metadata" in run_data:
                run_data["metadata_"] = run_data.pop("metadata")

            run = RunModel(**run_data)
            run.organization_id = organization_id

            # Get the project_id from the agent
            agent = await session.get(AgentModel, agent_id)
            project_id = agent.project_id if agent else None
            run.project_id = project_id

            run = await run.create_async(session, actor=actor, no_commit=True, no_refresh=True)
            update_log_context(run_id=run.id)

            # Create run metrics with start timestamp
            import time

            metrics = RunMetricsModel(
                id=run.id,
                organization_id=organization_id,
                agent_id=agent_id,
                project_id=project_id,
                run_start_ns=int(time.time() * 1e9),  # Current time in nanoseconds
                num_steps=0,  # Initialize to 0
            )
            await metrics.create_async(session)
            # context manager now handles commits
            # await session.commit()
            return run.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun:
        """Get a run by its ID.

        Raises:
            NoResultFound: If no run with ``run_id`` is visible to the actor's organization.
        """
        update_log_context(run_id=run_id)
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")
            return run.to_pydantic()

    @enforce_types
    async def get_run_with_status(self, run_id: str, actor: PydanticUser) -> PydanticRun:
        """Get a run by its ID and update status from Lettuce if applicable.

        When the run's metadata marks it as a Lettuce-managed run and it is not yet
        terminal in the DB, the live status is fetched from Lettuce and mapped onto
        the returned object (the DB row itself is not updated here).
        """
        update_log_context(run_id=run_id)
        run = await self.get_run_by_id(run_id=run_id, actor=actor)

        use_lettuce = run.metadata and run.metadata.get("lettuce")
        if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]:
            try:
                from letta.services.lettuce import LettuceClient

                lettuce_client = await LettuceClient.create()
                status = await lettuce_client.get_status(run_id=run_id)
                # Map the status to our enum
                if status == "RUNNING":
                    run.status = RunStatus.running
                elif status == "COMPLETED":
                    run.status = RunStatus.completed
                elif status == "FAILED":
                    run.status = RunStatus.failed
                elif status == "CANCELLED":
                    run.status = RunStatus.cancelled
            except Exception as e:
                logger.error(f"Failed to get status from Lettuce for run {run_id}: {str(e)}")
                # Return run with current status from DB if Lettuce fails
        return run

    @enforce_types
    async def list_runs(
        self,
        actor: PydanticUser,
        run_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        agent_ids: Optional[List[str]] = None,
        statuses: Optional[List[RunStatus]] = None,
        limit: Optional[int] = 50,
        before: Optional[str] = None,
        after: Optional[str] = None,
        ascending: bool = False,
        stop_reason: Optional[str] = None,
        background: Optional[bool] = None,
        template_family: Optional[str] = None,
        step_count: Optional[int] = None,
        step_count_operator: ComparisonOperator = ComparisonOperator.EQ,
        tools_used: Optional[List[str]] = None,
        project_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        order_by: Literal["created_at", "duration"] = "created_at",
        duration_percentile: Optional[int] = None,
        duration_filter: Optional[dict] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> List[PydanticRun]:
        """List runs with filtering options.

        Notes:
            - ``agent_id`` is folded into ``agent_ids`` when provided.
            - ``duration_filter`` may be a dict (``{"value", "operator"}``) or an
              object with ``.value``/``.operator``; operators are "gt"/"lt"/"eq".
            - When ``order_by == "duration"``, cursor pagination (``before``/``after``)
              is not applied — only the duration sort and the limit.
            - The limit defaults to 100 when ``None`` and is capped at 1000.
        """
        async with db_registry.async_session() as session:
            from sqlalchemy import func, or_, select

            # Always join with run_metrics to get duration data
            query = (
                select(RunModel, RunMetricsModel.run_ns)
                .outerjoin(RunMetricsModel, RunModel.id == RunMetricsModel.id)
                .filter(RunModel.organization_id == actor.organization_id)
            )

            # Filter by project_id if provided
            if project_id:
                query = query.filter(RunModel.project_id == project_id)

            if run_id:
                query = query.filter(RunModel.id == run_id)

            # Handle agent filtering
            if agent_id:
                agent_ids = [agent_id]
            if agent_ids:
                query = query.filter(RunModel.agent_id.in_(agent_ids))

            # Filter by status
            if statuses:
                query = query.filter(RunModel.status.in_(statuses))

            # Filter by stop reason
            if stop_reason:
                query = query.filter(RunModel.stop_reason == stop_reason)

            # Filter by background
            if background is not None:
                query = query.filter(RunModel.background == background)

            # Filter by conversation_id
            if conversation_id is not None:
                query = query.filter(RunModel.conversation_id == conversation_id)

            # Filter by template_family (base_template_id)
            if template_family:
                query = query.filter(RunModel.base_template_id == template_family)

            # Filter by date range
            if start_date:
                query = query.filter(RunModel.created_at >= start_date)
            if end_date:
                query = query.filter(RunModel.created_at <= end_date)

            # Filter by step_count with the specified operator
            if step_count is not None:
                if step_count_operator == ComparisonOperator.EQ:
                    query = query.filter(RunMetricsModel.num_steps == step_count)
                elif step_count_operator == ComparisonOperator.GTE:
                    query = query.filter(RunMetricsModel.num_steps >= step_count)
                elif step_count_operator == ComparisonOperator.LTE:
                    query = query.filter(RunMetricsModel.num_steps <= step_count)

            # Filter by tools used ids
            if tools_used:
                from sqlalchemy import String, cast as sa_cast, type_coerce
                from sqlalchemy.dialects.postgresql import ARRAY, JSONB

                # Use ?| operator to check if any tool_id exists in the array (OR logic)
                jsonb_tools = sa_cast(RunMetricsModel.tools_used, JSONB)
                tools_array = type_coerce(tools_used, ARRAY(String))
                query = query.filter(jsonb_tools.op("?|")(tools_array))

            # Ensure run_ns is not null when working with duration
            if order_by == "duration" or duration_percentile is not None or duration_filter is not None:
                query = query.filter(RunMetricsModel.run_ns.isnot(None))

            # Apply duration filter if requested
            if duration_filter is not None:
                duration_value = duration_filter.get("value") if isinstance(duration_filter, dict) else duration_filter.value
                duration_operator = duration_filter.get("operator") if isinstance(duration_filter, dict) else duration_filter.operator
                if duration_operator == "gt":
                    query = query.filter(RunMetricsModel.run_ns > duration_value)
                elif duration_operator == "lt":
                    query = query.filter(RunMetricsModel.run_ns < duration_value)
                elif duration_operator == "eq":
                    query = query.filter(RunMetricsModel.run_ns == duration_value)

            # Apply duration percentile filter if requested
            if duration_percentile is not None:
                # Calculate the percentile threshold
                percentile_query = (
                    select(func.percentile_cont(duration_percentile / 100.0).within_group(RunMetricsModel.run_ns))
                    .select_from(RunMetricsModel)
                    .join(RunModel, RunModel.id == RunMetricsModel.id)
                    .filter(RunModel.organization_id == actor.organization_id)
                    .filter(RunMetricsModel.run_ns.isnot(None))
                )
                # Apply same filters to percentile calculation
                if project_id:
                    percentile_query = percentile_query.filter(RunModel.project_id == project_id)
                if agent_ids:
                    percentile_query = percentile_query.filter(RunModel.agent_id.in_(agent_ids))
                if statuses:
                    percentile_query = percentile_query.filter(RunModel.status.in_(statuses))

                # Execute percentile query
                percentile_result = await session.execute(percentile_query)
                percentile_threshold = percentile_result.scalar()

                # Filter by percentile threshold (runs slower than the percentile)
                if percentile_threshold is not None:
                    query = query.filter(RunMetricsModel.run_ns >= percentile_threshold)

            # Apply sorting based on order_by
            if order_by == "duration":
                # Sort by duration
                if ascending:
                    query = query.order_by(RunMetricsModel.run_ns.asc())
                else:
                    query = query.order_by(RunMetricsModel.run_ns.desc())
            else:
                # Apply pagination for created_at ordering
                from letta.services.helpers.run_manager_helper import _apply_pagination_async

                query = await _apply_pagination_async(query, before, after, session, ascending=ascending)

            # Apply limit (always enforce a maximum to prevent unbounded queries)
            # If no limit specified, default to 100; enforce maximum of 1000
            effective_limit = limit if limit is not None else 100
            effective_limit = min(effective_limit, 1000)
            query = query.limit(effective_limit)

            result = await session.execute(query)
            rows = result.all()

            # Populate total_duration_ns from run_metrics.run_ns
            pydantic_runs = []
            for row in rows:
                run_model = row[0]
                run_ns = row[1]
                pydantic_run = run_model.to_pydantic()
                if run_ns is not None:
                    pydantic_run.total_duration_ns = run_ns
                pydantic_runs.append(pydantic_run)
            return pydantic_runs

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def delete_run(self, run_id: str, actor: PydanticUser) -> None:
        """Delete a run by its ID (hard delete).

        Raises:
            NoResultFound: If the run does not exist in the actor's organization.
        """
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")
            await run.hard_delete_async(db_session=session, actor=actor)

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    @trace_method
    async def update_run_by_id_async(
        self,
        run_id: str,
        update: RunUpdate,
        actor: PydanticUser,
        refresh_result_messages: bool = True,
        conversation_id: Optional[str] = None,
    ) -> PydanticRun:
        """Update a run using a RunUpdate object.

        On a terminal update (completed/failed/cancelled) this also:
        - releases the conversation lock when ``conversation_id`` is given,
        - updates the agent's ``last_stop_reason``,
        - refreshes the run's metrics (duration, step count, tools used),
        - dispatches the registered callback URL, if any, recording the outcome.

        Args:
            run_id: ID of the run to update.
            update: Fields to apply; only explicitly-set, non-None fields are written.
            actor: User performing the action.
            refresh_result_messages: When dispatching a callback, include the full
                LettaResponse result payload in the callback metadata.
            conversation_id: If set and the update is terminal, release this
                conversation's Redis lock.

        Returns:
            The updated run as a Pydantic model.
        """
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)

            # Check if this is a terminal update and whether we should dispatch a callback
            needs_callback = False
            callback_url = None
            not_completed_before = not bool(run.completed_at)
            is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed, RunStatus.cancelled}
            if is_terminal_update and not_completed_before and run.callback_url:
                needs_callback = True
                callback_url = run.callback_url

            # validate run lifecycle (only log the errors)
            if run.status in {RunStatus.completed}:
                if update.status not in {RunStatus.cancelled}:
                    # a completed run can only be marked as cancelled
                    logger.error(
                        f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being marked as {update.status} with stop reason {update.stop_reason}"
                    )
                if update.stop_reason not in {StopReasonType.requires_approval}:
                    # a completed run can only be cancelled if the stop reason is requires approval
                    logger.error(
                        f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being marked as {update.status} with stop reason {update.stop_reason}"
                    )
            if run.status in {RunStatus.failed, RunStatus.cancelled}:
                logger.error(
                    f"Run {run_id} is already in a terminal state {run.status} with stop reason {run.stop_reason}, but is being updated with data {update.model_dump()}"
                )

            # Housekeeping only when the run is actually completing
            if not_completed_before and is_terminal_update:
                if not update.stop_reason:
                    logger.error(f"Run {run_id} completed without a stop reason")
                if not update.completed_at:
                    logger.warning(f"Run {run_id} completed without a completed_at timestamp")
                    update.completed_at = get_utc_time().replace(tzinfo=None)

            # Update run attributes with only the fields that were explicitly set
            update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

            # Merge metadata updates instead of overwriting.
            # This is important for streaming/background flows where different components update
            # different parts of metadata (e.g., run_type set at creation, error payload set at terminal).
            if "metadata_" in update_data and isinstance(update_data["metadata_"], dict):
                existing_metadata = run.metadata_ if isinstance(run.metadata_, dict) else {}
                update_data["metadata_"] = {**existing_metadata, **update_data["metadata_"]}

            # Automatically update the completion timestamp if status is set to 'completed'
            for key, value in update_data.items():
                # Ensure completed_at is timezone-naive for database compatibility
                if key == "completed_at" and value is not None and hasattr(value, "replace"):
                    value = value.replace(tzinfo=None)
                setattr(run, key, value)

            await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            final_metadata = run.metadata_
            pydantic_run = run.to_pydantic()
            # context manager now handles commits
            # await session.commit()

        # Release conversation lock if conversation_id was provided
        if is_terminal_update and conversation_id:
            try:
                redis_client = await get_redis_client()
                await redis_client.release_conversation_lock(conversation_id)
            except Exception as lock_error:
                logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {lock_error}")

        # Update agent's last_stop_reason when run completes
        # Do this after run update is committed to database
        if is_terminal_update and update.stop_reason:
            try:
                from letta.schemas.agent import UpdateAgent

                await self.agent_manager.update_agent_async(
                    agent_id=pydantic_run.agent_id,
                    agent_update=UpdateAgent(last_stop_reason=update.stop_reason),
                    actor=actor,
                )
            except Exception as e:
                logger.error(f"Failed to update agent's last_stop_reason for run {run_id}: {e}")

        # update run metrics table
        num_steps = len(await self.step_manager.list_steps_async(run_id=run_id, actor=actor))

        # Collect tools used from run messages
        tools_used = set()
        messages = await self.message_manager.list_messages(actor=actor, run_id=run_id)
        # Hoisted out of the loop: one ToolManager for all lookups (was re-created per tool call)
        from letta.services.tool_manager import ToolManager

        tool_manager = ToolManager()
        for message in messages:
            if message.tool_calls:
                for tool_call in message.tool_calls:
                    if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"):
                        # Get tool ID from tool name
                        tool_id = await tool_manager.get_tool_id_by_name_async(tool_call.function.name, actor)
                        if tool_id:
                            tools_used.add(tool_id)

        async with db_registry.async_session() as session:
            metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
            # Calculate runtime if run is completing
            if is_terminal_update:
                # Use total_duration_ns from RunUpdate if provided
                # Otherwise fall back to system time
                if update.total_duration_ns is not None:
                    metrics.run_ns = update.total_duration_ns
                elif metrics.run_start_ns:
                    import time

                    current_ns = int(time.time() * 1e9)
                    metrics.run_ns = current_ns - metrics.run_start_ns
            metrics.num_steps = num_steps
            metrics.tools_used = list(tools_used) if tools_used else None
            await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            # context manager now handles commits
            # await session.commit()

        # Dispatch callback outside of database session if needed
        if needs_callback:
            if refresh_result_messages:
                result = LettaResponse(
                    messages=await self.get_run_messages(run_id=run_id, actor=actor),
                    stop_reason=LettaStopReason(stop_reason=pydantic_run.stop_reason),
                    usage=await self.get_run_usage(run_id=run_id, actor=actor),
                )
                if final_metadata is None:
                    # Guard: a run created without metadata would otherwise raise TypeError here
                    final_metadata = {}
                final_metadata["result"] = result.model_dump()
            callback_info = {
                "run_id": run_id,
                "callback_url": callback_url,
                "status": update.status,
                "completed_at": get_utc_time().replace(tzinfo=None),
                "metadata": final_metadata,
            }
            callback_result = await self._dispatch_callback_async(callback_info)

            # Update callback status in a separate transaction
            async with db_registry.async_session() as session:
                run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
                run.callback_sent_at = callback_result["callback_sent_at"]
                run.callback_status_code = callback_result.get("callback_status_code")
                run.callback_error = callback_result.get("callback_error")
                pydantic_run = run.to_pydantic()
                await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
                # context manager now handles commits
                # await session.commit()

        return pydantic_run

    @trace_method
    async def _dispatch_callback_async(self, callback_info: dict) -> dict:
        """
        POST a standard JSON payload to callback_url and return callback status asynchronously.

        Returns a dict with ``callback_sent_at`` always set, plus either
        ``callback_status_code`` on success or ``callback_error`` on failure.
        Callback failures never raise — run completion must not depend on them.
        """
        payload = {
            "run_id": callback_info["run_id"],
            "status": callback_info["status"],
            "completed_at": callback_info["completed_at"].isoformat() if callback_info["completed_at"] else None,
            "metadata": callback_info["metadata"],
        }
        callback_sent_at = get_utc_time().replace(tzinfo=None)
        result = {"callback_sent_at": callback_sent_at}

        try:
            async with AsyncClient() as client:
                log_event("POST callback dispatched", payload)
                resp = await client.post(callback_info["callback_url"], json=payload, timeout=5.0)
                log_event("POST callback finished")
                result["callback_status_code"] = resp.status_code
        except Exception as e:
            error_message = f"Failed to dispatch callback for run {callback_info['run_id']} to {callback_info['callback_url']}: {e!r}"
            logger.error(error_message)
            result["callback_error"] = error_message
            # Continue silently - callback failures should not affect run completion
        return result

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics:
        """Get aggregated usage statistics (tokens, steps) for a run.

        Raises:
            NoResultFound: If the run does not exist in the actor's organization.
        """
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")

            steps = await self.step_manager.list_steps_async(run_id=run_id, actor=actor)

            total_usage = LettaUsageStatistics()
            for step in steps:
                total_usage.prompt_tokens += step.prompt_tokens
                total_usage.completion_tokens += step.completion_tokens
                total_usage.total_tokens += step.total_tokens
                total_usage.step_count += 1

                # Aggregate cache and reasoning tokens from detailed breakdowns using normalized helpers
                # Handle None defaults: only set if we have data, accumulate if already set
                cached_input, cache_write = normalize_cache_tokens(step.prompt_tokens_details)
                if cached_input > 0 or total_usage.cached_input_tokens is not None:
                    total_usage.cached_input_tokens = (total_usage.cached_input_tokens or 0) + cached_input
                if cache_write > 0 or total_usage.cache_write_tokens is not None:
                    total_usage.cache_write_tokens = (total_usage.cache_write_tokens or 0) + cache_write

                reasoning = normalize_reasoning_tokens(step.completion_tokens_details)
                if reasoning > 0 or total_usage.reasoning_tokens is not None:
                    total_usage.reasoning_tokens = (total_usage.reasoning_tokens or 0) + reasoning

            return total_usage

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_messages(
        self,
        run_id: str,
        actor: PydanticUser,
        limit: Optional[int] = 100,
        before: Optional[str] = None,
        after: Optional[str] = None,
        order: Literal["asc", "desc"] = "asc",
    ) -> List[LettaMessage]:
        """Get the result of a run.

        Converts the run's messages to LettaMessages, optionally filtered by the
        run's request_config ``include_return_message_types``.
        """
        run = await self.get_run_by_id(run_id=run_id, actor=actor)
        request_config = run.request_config
        agent = await self.agent_manager.get_agent_by_id_async(agent_id=run.agent_id, actor=actor, include_relationships=[])
        text_is_assistant_message = agent.agent_type == AgentType.letta_v1_agent

        messages = await self.message_manager.list_messages(
            actor=actor,
            run_id=run_id,
            limit=limit,
            before=before,
            after=after,
            ascending=(order == "asc"),
        )
        letta_messages = PydanticMessage.to_letta_messages_from_list(
            messages, reverse=(order != "asc"), text_is_assistant_message=text_is_assistant_message
        )

        if request_config and request_config.include_return_message_types:
            include_return_message_types_set = set(request_config.include_return_message_types)
            letta_messages = [msg for msg in letta_messages if msg.message_type in include_return_message_types_set]
        return letta_messages

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]:
        """Get the letta request config from a run.

        Raises:
            NoResultFound: If the run does not exist in the actor's organization.
        """
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")
            pydantic_run = run.to_pydantic()
            return pydantic_run.request_config

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics:
        """Get metrics for a run."""
        async with db_registry.async_session() as session:
            metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
            return metrics.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_steps(
        self,
        run_id: str,
        actor: PydanticUser,
        limit: Optional[int] = 100,
        before: Optional[str] = None,
        after: Optional[str] = None,
        ascending: bool = False,
    ) -> List[PydanticStep]:
        """Get steps for a run.

        Raises:
            NoResultFound: If the run does not exist in the actor's organization.
        """
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")

            steps = await self.step_manager.list_steps_async(
                actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc"
            )
            return steps

    @enforce_types
    async def cancel_run(self, actor: PydanticUser, agent_id: Optional[str] = None, run_id: Optional[str] = None) -> Optional[PydanticRun]:
        """Cancel a run.

        Either ``run_id`` or ``agent_id`` must be provided; when only ``agent_id``
        is given, the agent's most recent run is cancelled. If the agent was
        waiting for tool-call approval, denial messages are checkpointed so the
        agent's conversation state stays consistent.

        Returns:
            The cancelled run, or None if the run was already terminated
            (cancellation is idempotent).

        Raises:
            ValueError: If neither run_id nor agent_id is provided.
            NoResultFound: If no matching run exists.
        """
        # make sure run_id and agent_id are not both None
        if not run_id:
            # get the last agent run
            if not agent_id:
                raise ValueError("Agent ID is required to cancel a run by ID")
            logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
            runs = await self.list_runs(
                actor=actor,
                ascending=False,
                agent_id=agent_id,
            )
            # BUGFIX: original code built a list of run ids here but never selected one,
            # leaving `run`/`run_id` unbound and crashing below. Pick the most recent run
            # (ascending=False returns newest first).
            if not runs:
                raise NoResultFound(f"Run with id {run_id} not found")
            run = runs[0]
            run_id = run.id
        else:
            # get the agent
            run = await self.get_run_by_id(run_id=run_id, actor=actor)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")
            agent_id = run.agent_id
        logger.debug(f"Cancelling run {run_id} for agent {agent_id}")

        # Cancellation should be idempotent: if a run is already terminated, treat this as a no-op.
        # This commonly happens when a run finishes between client request and server handling.
        if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
            logger.debug(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}")
            return

        # Check if agent is waiting for approval by examining the last message
        agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
        current_in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
        was_pending_approval = current_in_context_messages and current_in_context_messages[-1].is_approval_request()

        # cancel the run
        # NOTE: this should update the agent's last stop reason to cancelled
        run = await self.update_run_by_id_async(
            run_id=run_id,
            update=RunUpdate(status=RunStatus.cancelled, stop_reason=StopReasonType.cancelled),
            actor=actor,
            conversation_id=run.conversation_id,
        )

        # cleanup the agent's state
        # if was pending approval, we need to cleanup the approval state
        if was_pending_approval:
            logger.debug(f"Agent was waiting for approval, adding denial messages for run {run_id}")
            approval_request_message = current_in_context_messages[-1]

            # Find ALL pending tool calls (both requiring approval and not requiring approval)
            # The assistant message may have tool calls that didn't require approval
            all_pending_tool_calls = []
            if approval_request_message.tool_calls:
                all_pending_tool_calls.extend(approval_request_message.tool_calls)

            # Check if there's an assistant message before the approval request with additional tool calls
            if len(current_in_context_messages) >= 2:
                potential_assistant_msg = current_in_context_messages[-2]
                if potential_assistant_msg.role == MessageRole.assistant and potential_assistant_msg.tool_calls:
                    # Add any tool calls from the assistant message that aren't already in the approval request
                    approval_tool_call_ids = (
                        {tc.id for tc in approval_request_message.tool_calls} if approval_request_message.tool_calls else set()
                    )
                    for tool_call in potential_assistant_msg.tool_calls:
                        if tool_call.id not in approval_tool_call_ids:
                            all_pending_tool_calls.append(tool_call)

            # Ensure we have tool calls to deny
            if all_pending_tool_calls:
                from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL
                from letta.schemas.letta_message import ApprovalReturn
                from letta.schemas.message import ApprovalCreate
                from letta.server.rest_api.utils import (
                    create_approval_response_message_from_input,
                    create_tool_message_from_returns,
                    create_tool_returns_for_denials,
                )

                # Create denials for the approval request's own tool calls
                # (tool returns below additionally cover calls that didn't require approval)
                denials = (
                    [
                        ApprovalReturn(
                            tool_call_id=tool_call.id,
                            approve=False,
                            reason=TOOL_CALL_DENIAL_ON_CANCEL,
                        )
                        for tool_call in approval_request_message.tool_calls
                    ]
                    if approval_request_message.tool_calls
                    else []
                )

                # Create an ApprovalCreate input with the denials
                approval_input = ApprovalCreate(
                    approvals=denials,
                    approval_request_id=approval_request_message.id,
                )

                # Use the standard function to create properly formatted approval response messages
                approval_response_messages = await create_approval_response_message_from_input(
                    agent_state=agent_state,
                    input_message=approval_input,
                    run_id=run_id,
                )

                # Create tool returns for ALL denied tool calls using shared helper
                # This includes both tool calls requiring approval AND those that didn't
                tool_returns = create_tool_returns_for_denials(
                    tool_calls=all_pending_tool_calls,
                    denial_reason=TOOL_CALL_DENIAL_ON_CANCEL,
                    timezone=agent_state.timezone,
                )

                # Create tool message with all denial returns using shared helper
                tool_message = create_tool_message_from_returns(
                    agent_id=agent_state.id,
                    model=agent_state.llm_config.model,
                    tool_returns=tool_returns,
                    run_id=run_id,
                )

                # Combine approval response and tool messages
                new_messages = approval_response_messages + [tool_message]

                # Checkpoint the new messages (persists them and updates the agent's
                # in-context message ids; replaces the older manual persistence path)
                from letta.agents.agent_loop import AgentLoop

                agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor)
                new_in_context_messages = current_in_context_messages + new_messages
                await agent_loop._checkpoint_messages(
                    run_id=run_id,
                    step_id=approval_request_message.step_id,
                    new_messages=new_messages,
                    in_context_messages=new_in_context_messages,
                )

                logger.debug(
                    f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "
                    f"Approval request message ID: {approval_request_message.id}"
                )
            else:
                logger.warning(
                    f"Last message is an approval request but has no tool_calls. "
                    f"Message ID: {approval_request_message.id}, Run ID: {run_id}"
                )
        return run