Files
letta-server/letta/services/run_manager.py
Charles Packer 2fc592e0b6 feat(core): add image support in tool returns [LET-7140] (#8985)
* feat(core): add image support in tool returns [LET-7140]

Enable tool_return to support both string and ImageContent content parts,
matching the pattern used for user message inputs. This allows tools
executed client-side to return images back to the agent.

Changes:
- Add LettaToolReturnContentUnion type for text/image content parts
- Update ToolReturn schema to accept Union[str, List[content parts]]
- Update converters for each provider:
  - OpenAI Chat Completions: placeholder text for images
  - OpenAI Responses API: full image support
  - Anthropic: full image support with base64
  - Google: placeholder text for images
- Add resolve_tool_return_images() for URL-to-base64 conversion
- Make create_approval_response_message_from_input() async
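The URL-to-base64 resolution can be sketched roughly as follows. This is a hypothetical illustration of the content part that `resolve_tool_return_images()` would produce after fetching a URL; the field names below are assumptions, not the actual schema:

```python
import base64


def image_bytes_to_base64_part(data: bytes, media_type: str = "image/png") -> dict:
    # Hypothetical sketch of a base64 image content part; the real
    # ImageContent schema's field names may differ.
    return {
        "type": "image",
        "source": {
            "type": "base64",
            "media_type": media_type,
            "data": base64.b64encode(data).decode("ascii"),
        },
    }
```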

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix(core): support images in Google tool returns as sibling parts

Following the gemini-cli pattern: images in tool returns are sent as
sibling inlineData parts alongside the functionResponse, rather than
inside it.
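A minimal sketch of that sibling-part layout (the dict shapes here follow the Gemini REST format as an assumption; they are not the converter's actual code):

```python
import base64


def tool_return_to_google_parts(tool_name: str, text: str, images: list[bytes]) -> list[dict]:
    # The functionResponse part carries only the textual tool output...
    parts = [{"functionResponse": {"name": tool_name, "response": {"output": text}}}]
    # ...and each image is appended as a sibling inlineData part,
    # rather than being nested inside the functionResponse.
    for img in images:
        parts.append(
            {"inlineData": {"mimeType": "image/png", "data": base64.b64encode(img).decode("ascii")}}
        )
    return parts
```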

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* test(core): add integration tests for multi-modal tool returns [LET-7140]

Tests verify that:
- Models with image support (Anthropic, OpenAI Responses API) can see
  images in tool returns and identify the secret text
- Models without image support (Chat Completions) get placeholder text
  and cannot see the actual image content
- Tool returns with images persist correctly in the database

Uses secret.png test image containing hidden text "FIREBRAWL" that
models must identify to pass the test.

Also fixes a misleading comment claiming Anthropic only supports base64
images; they support URLs too, we just pre-resolve for consistency.

🐾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* refactor: simplify tool return image support implementation

Reduce code verbosity while maintaining all functionality:
- Extract _resolve_url_to_base64() helper in message_helper.py (eliminates duplication)
- Add _get_text_from_part() helper for text extraction
- Add _get_base64_image_data() helper for image data extraction
- Add _tool_return_to_google_parts() to simplify Google implementation
- Add _image_dict_to_data_url() for OpenAI Responses format
- Use walrus operator and list comprehensions where appropriate
- Add integration_test_multi_modal_tool_returns.py to CI workflow

Net change: -120 lines while preserving all features and test coverage.
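For illustration, the data-URL conversion for the Responses format likely reduces to something like this (the signature is a guess; the real `_image_dict_to_data_url` presumably takes whatever dict shape the converters use):

```python
import base64


def image_to_data_url(media_type: str, data: bytes) -> str:
    # Formats raw image bytes as data:<media_type>;base64,<payload>,
    # the inline-image form the OpenAI Responses API accepts.
    return f"data:{media_type};base64,{base64.b64encode(data).decode('ascii')}"
```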

👾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix(tests): improve prompt for multi-modal tool return tests

Make prompts more direct to reduce LLM flakiness:
- Simplify tool description: "Retrieves a secret image with hidden text. Call this function to get the image."
- Change user prompt from verbose request to direct command: "Call the get_secret_image function now."
- Apply to both test methods

This reduces ambiguity and makes tool calling more reliable across different LLM models.

👾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix bugs

* test(core): add google_ai/gemini-2.0-flash-exp to multi-modal tests

Add Gemini model to test coverage for multi-modal tool returns. Google AI already supports images in tool returns via sibling inlineData parts.

👾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

* fix(ui): handle multi-modal tool_return type in frontend components

Convert Union<string, LettaToolReturnContentUnion[]> to string for display:
- ViewRunDetails: Convert array to '[Image here]' placeholder
- ToolCallMessageComponent: Convert array to '[Image here]' placeholder

Fixes TypeScript errors in web, desktop-ui, and docker-ui type-checks.

👾 Generated with [Letta Code](https://letta.com)

Co-Authored-By: Letta <noreply@letta.com>

---------

Co-authored-by: Letta <noreply@letta.com>
Co-authored-by: Caren Thomas <carenthomas@gmail.com>
2026-01-29 12:43:53 -08:00

781 lines
37 KiB
Python

from datetime import datetime
from typing import List, Literal, Optional
from httpx import AsyncClient
from letta.data_sources.redis_client import get_redis_client
from letta.errors import LettaInvalidArgumentError
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.log_context import update_log_context
from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.orm.run import Run as RunModel
from letta.orm.run_metrics import RunMetrics as RunMetricsModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, PrimitiveType, RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics, normalize_cache_tokens, normalize_reasoning_tokens
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.services.message_manager import MessageManager
from letta.services.step_manager import StepManager
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id
logger = get_logger(__name__)
class RunManager:
"""Manager class to handle business logic related to Runs."""
def __init__(self):
"""Initialize the RunManager."""
self.step_manager = StepManager()
self.message_manager = MessageManager()
self.agent_manager = AgentManager()
@enforce_types
async def create_run(self, pydantic_run: PydanticRun, actor: PydanticUser) -> PydanticRun:
"""Create a new run."""
async with db_registry.async_session() as session:
# Get agent_id from the pydantic object
agent_id = pydantic_run.agent_id
# Verify agent exists before creating the run
await validate_agent_exists_async(session, agent_id, actor)
organization_id = actor.organization_id
run_data = pydantic_run.model_dump(exclude_none=True)
# Handle metadata field mapping (Pydantic uses 'metadata', ORM uses 'metadata_')
if "metadata" in run_data:
run_data["metadata_"] = run_data.pop("metadata")
run = RunModel(**run_data)
run.organization_id = organization_id
# Get the project_id from the agent
agent = await session.get(AgentModel, agent_id)
project_id = agent.project_id if agent else None
run.project_id = project_id
run = await run.create_async(session, actor=actor, no_commit=True, no_refresh=True)
update_log_context(run_id=run.id)
# Create run metrics with start timestamp
import time
metrics = RunMetricsModel(
id=run.id,
organization_id=organization_id,
agent_id=agent_id,
project_id=project_id,
run_start_ns=int(time.time() * 1e9), # Current time in nanoseconds
num_steps=0, # Initialize to 0
)
await metrics.create_async(session)
# context manager now handles commits
# await session.commit()
return run.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun:
"""Get a run by its ID."""
update_log_context(run_id=run_id)
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
if not run:
raise NoResultFound(f"Run with id {run_id} not found")
return run.to_pydantic()
@enforce_types
async def get_run_with_status(self, run_id: str, actor: PydanticUser) -> PydanticRun:
"""Get a run by its ID and update status from Lettuce if applicable."""
update_log_context(run_id=run_id)
run = await self.get_run_by_id(run_id=run_id, actor=actor)
use_lettuce = run.metadata and run.metadata.get("lettuce")
if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]:
try:
from letta.services.lettuce import LettuceClient
lettuce_client = await LettuceClient.create()
status = await lettuce_client.get_status(run_id=run_id)
# Map the status to our enum
if status == "RUNNING":
run.status = RunStatus.running
elif status == "COMPLETED":
run.status = RunStatus.completed
elif status == "FAILED":
run.status = RunStatus.failed
elif status == "CANCELLED":
run.status = RunStatus.cancelled
except Exception as e:
logger.error(f"Failed to get status from Lettuce for run {run_id}: {str(e)}")
# Return run with current status from DB if Lettuce fails
return run
@enforce_types
async def list_runs(
self,
actor: PydanticUser,
run_id: Optional[str] = None,
agent_id: Optional[str] = None,
agent_ids: Optional[List[str]] = None,
statuses: Optional[List[RunStatus]] = None,
limit: Optional[int] = 50,
before: Optional[str] = None,
after: Optional[str] = None,
ascending: bool = False,
stop_reason: Optional[str] = None,
background: Optional[bool] = None,
template_family: Optional[str] = None,
step_count: Optional[int] = None,
step_count_operator: ComparisonOperator = ComparisonOperator.EQ,
tools_used: Optional[List[str]] = None,
project_id: Optional[str] = None,
conversation_id: Optional[str] = None,
order_by: Literal["created_at", "duration"] = "created_at",
duration_percentile: Optional[int] = None,
duration_filter: Optional[dict] = None,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
) -> List[PydanticRun]:
"""List runs with filtering options."""
async with db_registry.async_session() as session:
from sqlalchemy import func, or_, select
# Always join with run_metrics to get duration data
query = (
select(RunModel, RunMetricsModel.run_ns)
.outerjoin(RunMetricsModel, RunModel.id == RunMetricsModel.id)
.filter(RunModel.organization_id == actor.organization_id)
)
# Filter by project_id if provided
if project_id:
query = query.filter(RunModel.project_id == project_id)
if run_id:
query = query.filter(RunModel.id == run_id)
# Handle agent filtering
if agent_id:
agent_ids = [agent_id]
if agent_ids:
query = query.filter(RunModel.agent_id.in_(agent_ids))
# Filter by status
if statuses:
query = query.filter(RunModel.status.in_(statuses))
# Filter by stop reason
if stop_reason:
query = query.filter(RunModel.stop_reason == stop_reason)
# Filter by background
if background is not None:
query = query.filter(RunModel.background == background)
# Filter by conversation_id
if conversation_id is not None:
query = query.filter(RunModel.conversation_id == conversation_id)
# Filter by template_family (base_template_id)
if template_family:
query = query.filter(RunModel.base_template_id == template_family)
# Filter by date range
if start_date:
query = query.filter(RunModel.created_at >= start_date)
if end_date:
query = query.filter(RunModel.created_at <= end_date)
# Filter by step_count with the specified operator
if step_count is not None:
if step_count_operator == ComparisonOperator.EQ:
query = query.filter(RunMetricsModel.num_steps == step_count)
elif step_count_operator == ComparisonOperator.GTE:
query = query.filter(RunMetricsModel.num_steps >= step_count)
elif step_count_operator == ComparisonOperator.LTE:
query = query.filter(RunMetricsModel.num_steps <= step_count)
# Filter by tools used ids
if tools_used:
from sqlalchemy import String, cast as sa_cast, type_coerce
from sqlalchemy.dialects.postgresql import ARRAY, JSONB
# Use ?| operator to check if any tool_id exists in the array (OR logic)
jsonb_tools = sa_cast(RunMetricsModel.tools_used, JSONB)
tools_array = type_coerce(tools_used, ARRAY(String))
query = query.filter(jsonb_tools.op("?|")(tools_array))
# Ensure run_ns is not null when working with duration
if order_by == "duration" or duration_percentile is not None or duration_filter is not None:
query = query.filter(RunMetricsModel.run_ns.isnot(None))
# Apply duration filter if requested
if duration_filter is not None:
duration_value = duration_filter.get("value") if isinstance(duration_filter, dict) else duration_filter.value
duration_operator = duration_filter.get("operator") if isinstance(duration_filter, dict) else duration_filter.operator
if duration_operator == "gt":
query = query.filter(RunMetricsModel.run_ns > duration_value)
elif duration_operator == "lt":
query = query.filter(RunMetricsModel.run_ns < duration_value)
elif duration_operator == "eq":
query = query.filter(RunMetricsModel.run_ns == duration_value)
# Apply duration percentile filter if requested
if duration_percentile is not None:
# Calculate the percentile threshold
percentile_query = (
select(func.percentile_cont(duration_percentile / 100.0).within_group(RunMetricsModel.run_ns))
.select_from(RunMetricsModel)
.join(RunModel, RunModel.id == RunMetricsModel.id)
.filter(RunModel.organization_id == actor.organization_id)
.filter(RunMetricsModel.run_ns.isnot(None))
)
# Apply same filters to percentile calculation
if project_id:
percentile_query = percentile_query.filter(RunModel.project_id == project_id)
if agent_ids:
percentile_query = percentile_query.filter(RunModel.agent_id.in_(agent_ids))
if statuses:
percentile_query = percentile_query.filter(RunModel.status.in_(statuses))
# Execute percentile query
percentile_result = await session.execute(percentile_query)
percentile_threshold = percentile_result.scalar()
# Filter by percentile threshold (runs slower than the percentile)
if percentile_threshold is not None:
query = query.filter(RunMetricsModel.run_ns >= percentile_threshold)
# Apply sorting based on order_by
if order_by == "duration":
# Sort by duration
if ascending:
query = query.order_by(RunMetricsModel.run_ns.asc())
else:
query = query.order_by(RunMetricsModel.run_ns.desc())
else:
# Apply pagination for created_at ordering
from letta.services.helpers.run_manager_helper import _apply_pagination_async
query = await _apply_pagination_async(query, before, after, session, ascending=ascending)
# Apply limit (always enforce a maximum to prevent unbounded queries)
# If no limit specified, default to 100; enforce maximum of 1000
effective_limit = limit if limit is not None else 100
effective_limit = min(effective_limit, 1000)
query = query.limit(effective_limit)
result = await session.execute(query)
rows = result.all()
# Populate total_duration_ns from run_metrics.run_ns
pydantic_runs = []
for row in rows:
run_model = row[0]
run_ns = row[1]
pydantic_run = run_model.to_pydantic()
if run_ns is not None:
pydantic_run.total_duration_ns = run_ns
pydantic_runs.append(pydantic_run)
return pydantic_runs
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def delete_run(self, run_id: str, actor: PydanticUser) -> None:
"""Delete a run by its ID."""
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
if not run:
raise NoResultFound(f"Run with id {run_id} not found")
await run.hard_delete_async(db_session=session, actor=actor)
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
@trace_method
async def update_run_by_id_async(
self,
run_id: str,
update: RunUpdate,
actor: PydanticUser,
refresh_result_messages: bool = True,
conversation_id: Optional[str] = None,
) -> PydanticRun:
"""Update a run using a RunUpdate object."""
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
# Check if this is a terminal update and whether we should dispatch a callback
needs_callback = False
callback_url = None
not_completed_before = not bool(run.completed_at)
is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed, RunStatus.cancelled}
if is_terminal_update and not_completed_before and run.callback_url:
needs_callback = True
callback_url = run.callback_url
# validate run lifecycle (only log the errors)
if run.status in {RunStatus.completed}:
if update.status not in {RunStatus.cancelled}:
# a completed run can only be marked as cancelled
logger.error(
f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being marked as {update.status} with stop reason {update.stop_reason}"
)
if update.stop_reason not in {StopReasonType.requires_approval}:
# a completed run can only be cancelled if the stop reason is requires approval
logger.error(
f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being cancelled with stop reason {update.stop_reason} (expected requires_approval)"
)
if run.status in {RunStatus.failed, RunStatus.cancelled}:
logger.error(
f"Run {run_id} is already in a terminal state {run.status} with stop reason {run.stop_reason}, but is being updated with data {update.model_dump()}"
)
# Housekeeping only when the run is actually completing
if not_completed_before and is_terminal_update:
if not update.stop_reason:
logger.error(f"Run {run_id} completed without a stop reason")
if not update.completed_at:
logger.warning(f"Run {run_id} completed without a completed_at timestamp")
update.completed_at = get_utc_time().replace(tzinfo=None)
# Update run attributes with only the fields that were explicitly set
update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Merge metadata updates instead of overwriting.
# This is important for streaming/background flows where different components update
# different parts of metadata (e.g., run_type set at creation, error payload set at terminal).
if "metadata_" in update_data and isinstance(update_data["metadata_"], dict):
existing_metadata = run.metadata_ if isinstance(run.metadata_, dict) else {}
update_data["metadata_"] = {**existing_metadata, **update_data["metadata_"]}
# Apply each explicitly-set field to the run
for key, value in update_data.items():
# Ensure completed_at is timezone-naive for database compatibility
if key == "completed_at" and value is not None and hasattr(value, "replace"):
value = value.replace(tzinfo=None)
setattr(run, key, value)
await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
final_metadata = run.metadata_
pydantic_run = run.to_pydantic()
# context manager now handles commits
# await session.commit()
# Release conversation lock if conversation_id was provided
if is_terminal_update and conversation_id:
try:
redis_client = await get_redis_client()
await redis_client.release_conversation_lock(conversation_id)
except Exception as lock_error:
logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {lock_error}")
# Update agent's last_stop_reason when run completes
# Do this after run update is committed to database
if is_terminal_update and update.stop_reason:
try:
from letta.schemas.agent import UpdateAgent
await self.agent_manager.update_agent_async(
agent_id=pydantic_run.agent_id,
agent_update=UpdateAgent(last_stop_reason=update.stop_reason),
actor=actor,
)
except Exception as e:
logger.error(f"Failed to update agent's last_stop_reason for run {run_id}: {e}")
# update run metrics table
num_steps = len(await self.step_manager.list_steps_async(run_id=run_id, actor=actor))
# Collect tools used from run messages
tools_used = set()
from letta.services.tool_manager import ToolManager
tool_manager = ToolManager()
messages = await self.message_manager.list_messages(actor=actor, run_id=run_id)
for message in messages:
if message.tool_calls:
for tool_call in message.tool_calls:
if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"):
# Resolve the tool ID from the tool name (ToolManager instantiated once, not per call)
tool_id = await tool_manager.get_tool_id_by_name_async(tool_call.function.name, actor)
if tool_id:
tools_used.add(tool_id)
async with db_registry.async_session() as session:
metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
# Calculate runtime if run is completing
if is_terminal_update:
# Use total_duration_ns from RunUpdate if provided
# Otherwise fall back to system time
if update.total_duration_ns is not None:
metrics.run_ns = update.total_duration_ns
elif metrics.run_start_ns:
import time
current_ns = int(time.time() * 1e9)
metrics.run_ns = current_ns - metrics.run_start_ns
metrics.num_steps = num_steps
metrics.tools_used = list(tools_used) if tools_used else None
await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
# context manager now handles commits
# await session.commit()
# Dispatch callback outside of database session if needed
if needs_callback:
if refresh_result_messages:
result = LettaResponse(
messages=await self.get_run_messages(run_id=run_id, actor=actor),
stop_reason=LettaStopReason(stop_reason=pydantic_run.stop_reason),
usage=await self.get_run_usage(run_id=run_id, actor=actor),
)
final_metadata["result"] = result.model_dump()
callback_info = {
"run_id": run_id,
"callback_url": callback_url,
"status": update.status,
"completed_at": get_utc_time().replace(tzinfo=None),
"metadata": final_metadata,
}
callback_result = await self._dispatch_callback_async(callback_info)
# Update callback status in a separate transaction
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
run.callback_sent_at = callback_result["callback_sent_at"]
run.callback_status_code = callback_result.get("callback_status_code")
run.callback_error = callback_result.get("callback_error")
pydantic_run = run.to_pydantic()
await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
# context manager now handles commits
# await session.commit()
return pydantic_run
@trace_method
async def _dispatch_callback_async(self, callback_info: dict) -> dict:
"""
POST a standard JSON payload to callback_url and return callback status asynchronously.
"""
payload = {
"run_id": callback_info["run_id"],
"status": callback_info["status"],
"completed_at": callback_info["completed_at"].isoformat() if callback_info["completed_at"] else None,
"metadata": callback_info["metadata"],
}
callback_sent_at = get_utc_time().replace(tzinfo=None)
result = {"callback_sent_at": callback_sent_at}
try:
async with AsyncClient() as client:
log_event("POST callback dispatched", payload)
resp = await client.post(callback_info["callback_url"], json=payload, timeout=5.0)
log_event("POST callback finished")
result["callback_status_code"] = resp.status_code
except Exception as e:
error_message = f"Failed to dispatch callback for run {callback_info['run_id']} to {callback_info['callback_url']}: {e!r}"
logger.error(error_message)
result["callback_error"] = error_message
# Record the failure and continue; callback failures should not affect run completion
return result
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics:
"""Get usage statistics for a run."""
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
if not run:
raise NoResultFound(f"Run with id {run_id} not found")
steps = await self.step_manager.list_steps_async(run_id=run_id, actor=actor)
total_usage = LettaUsageStatistics()
for step in steps:
total_usage.prompt_tokens += step.prompt_tokens
total_usage.completion_tokens += step.completion_tokens
total_usage.total_tokens += step.total_tokens
total_usage.step_count += 1
# Aggregate cache and reasoning tokens from detailed breakdowns using normalized helpers
# Handle None defaults: only set if we have data, accumulate if already set
cached_input, cache_write = normalize_cache_tokens(step.prompt_tokens_details)
if cached_input > 0 or total_usage.cached_input_tokens is not None:
total_usage.cached_input_tokens = (total_usage.cached_input_tokens or 0) + cached_input
if cache_write > 0 or total_usage.cache_write_tokens is not None:
total_usage.cache_write_tokens = (total_usage.cache_write_tokens or 0) + cache_write
reasoning = normalize_reasoning_tokens(step.completion_tokens_details)
if reasoning > 0 or total_usage.reasoning_tokens is not None:
total_usage.reasoning_tokens = (total_usage.reasoning_tokens or 0) + reasoning
return total_usage
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_messages(
self,
run_id: str,
actor: PydanticUser,
limit: Optional[int] = 100,
before: Optional[str] = None,
after: Optional[str] = None,
order: Literal["asc", "desc"] = "asc",
) -> List[LettaMessage]:
"""Get the result of a run."""
run = await self.get_run_by_id(run_id=run_id, actor=actor)
request_config = run.request_config
agent = await self.agent_manager.get_agent_by_id_async(agent_id=run.agent_id, actor=actor, include_relationships=[])
text_is_assistant_message = agent.agent_type == AgentType.letta_v1_agent
messages = await self.message_manager.list_messages(
actor=actor,
run_id=run_id,
limit=limit,
before=before,
after=after,
ascending=(order == "asc"),
)
letta_messages = PydanticMessage.to_letta_messages_from_list(
messages, reverse=(order != "asc"), text_is_assistant_message=text_is_assistant_message
)
if request_config and request_config.include_return_message_types:
include_return_message_types_set = set(request_config.include_return_message_types)
letta_messages = [msg for msg in letta_messages if msg.message_type in include_return_message_types_set]
return letta_messages
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]:
"""Get the letta request config from a run."""
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
if not run:
raise NoResultFound(f"Run with id {run_id} not found")
pydantic_run = run.to_pydantic()
return pydantic_run.request_config
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics:
"""Get metrics for a run."""
async with db_registry.async_session() as session:
metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
return metrics.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
async def get_run_steps(
self,
run_id: str,
actor: PydanticUser,
limit: Optional[int] = 100,
before: Optional[str] = None,
after: Optional[str] = None,
ascending: bool = False,
) -> List[PydanticStep]:
"""Get steps for a run."""
async with db_registry.async_session() as session:
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
if not run:
raise NoResultFound(f"Run with id {run_id} not found")
steps = await self.step_manager.list_steps_async(
actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc"
)
return steps
@enforce_types
async def cancel_run(self, actor: PydanticUser, agent_id: Optional[str] = None, run_id: Optional[str] = None) -> None:
"""Cancel a run."""
# make sure run_id and agent_id are not both None
if not run_id:
# get the last agent run
if not agent_id:
raise ValueError("Either run_id or agent_id must be provided to cancel a run")
logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
runs = await self.list_runs(
actor=actor,
ascending=False,
agent_id=agent_id,
)
if not runs:
raise NoResultFound(f"No runs found for agent {agent_id}")
# Cancel the most recent run for this agent
run = runs[0]
run_id = run.id
else:
# get the agent
run = await self.get_run_by_id(run_id=run_id, actor=actor)
if not run:
raise NoResultFound(f"Run with id {run_id} not found")
agent_id = run.agent_id
logger.debug(f"Cancelling run {run_id} for agent {agent_id}")
# Cancellation should be idempotent: if a run is already terminated, treat this as a no-op.
# This commonly happens when a run finishes between client request and server handling.
if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
logger.debug(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}")
return
# Check if agent is waiting for approval by examining the last message
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
current_in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
was_pending_approval = current_in_context_messages and current_in_context_messages[-1].is_approval_request()
# cancel the run
# NOTE: this should update the agent's last stop reason to cancelled
run = await self.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.cancelled, stop_reason=StopReasonType.cancelled),
actor=actor,
conversation_id=run.conversation_id,
)
# cleanup the agent's state
# if was pending approval, we need to cleanup the approval state
if was_pending_approval:
logger.debug(f"Agent was waiting for approval, adding denial messages for run {run_id}")
approval_request_message = current_in_context_messages[-1]
# Find ALL pending tool calls (both requiring approval and not requiring approval)
# The assistant message may have tool calls that didn't require approval
all_pending_tool_calls = []
if approval_request_message.tool_calls:
all_pending_tool_calls.extend(approval_request_message.tool_calls)
# Check if there's an assistant message before the approval request with additional tool calls
if len(current_in_context_messages) >= 2:
potential_assistant_msg = current_in_context_messages[-2]
if potential_assistant_msg.role == MessageRole.assistant and potential_assistant_msg.tool_calls:
# Add any tool calls from the assistant message that aren't already in the approval request
approval_tool_call_ids = (
{tc.id for tc in approval_request_message.tool_calls} if approval_request_message.tool_calls else set()
)
for tool_call in potential_assistant_msg.tool_calls:
if tool_call.id not in approval_tool_call_ids:
all_pending_tool_calls.append(tool_call)
# Ensure we have tool calls to deny
if all_pending_tool_calls:
from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL
from letta.schemas.letta_message import ApprovalReturn
from letta.schemas.message import ApprovalCreate
from letta.server.rest_api.utils import (
create_approval_response_message_from_input,
create_tool_message_from_returns,
create_tool_returns_for_denials,
)
# Create denials for ALL pending tool calls (including those that didn't require approval)
denials = (
[
ApprovalReturn(
tool_call_id=tool_call.id,
approve=False,
reason=TOOL_CALL_DENIAL_ON_CANCEL,
)
for tool_call in approval_request_message.tool_calls
]
if approval_request_message.tool_calls
else []
)
# Create an ApprovalCreate input with the denials
approval_input = ApprovalCreate(
approvals=denials,
approval_request_id=approval_request_message.id,
)
# Use the standard function to create properly formatted approval response messages
approval_response_messages = await create_approval_response_message_from_input(
agent_state=agent_state,
input_message=approval_input,
run_id=run_id,
)
# Create tool returns for ALL denied tool calls using shared helper
# This includes both tool calls requiring approval AND those that didn't
tool_returns = create_tool_returns_for_denials(
tool_calls=all_pending_tool_calls,
denial_reason=TOOL_CALL_DENIAL_ON_CANCEL,
timezone=agent_state.timezone,
)
# Create tool message with all denial returns using shared helper
tool_message = create_tool_message_from_returns(
agent_id=agent_state.id,
model=agent_state.llm_config.model,
tool_returns=tool_returns,
run_id=run_id,
)
# Combine approval response and tool messages
new_messages = approval_response_messages + [tool_message]
# Checkpoint the new messages
from letta.agents.agent_loop import AgentLoop
agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor)
new_in_context_messages = current_in_context_messages + new_messages
await agent_loop._checkpoint_messages(
run_id=run_id,
step_id=approval_request_message.step_id,
new_messages=new_messages,
in_context_messages=new_in_context_messages,
)
# persisted_messages = await self.message_manager.create_many_messages_async(
# pydantic_msgs=new_messages,
# actor=actor,
# run_id=run_id,
# )
# logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
## Update the agent's message_ids to include the new messages (approval + tool message)
# agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
# await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
logger.debug(
f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "
f"Approval request message ID: {approval_request_message.id}"
)
else:
logger.warning(
f"Last message is an approval request but has no tool_calls. "
f"Message ID: {approval_request_message.id}, Run ID: {run_id}"
)
return run