* feat(core): add image support in tool returns [LET-7140]

  Enable tool_return to support both string and ImageContent content parts,
  matching the pattern used for user message inputs. This allows tools executed
  client-side to return images back to the agent.

  Changes:
  - Add LettaToolReturnContentUnion type for text/image content parts
  - Update ToolReturn schema to accept Union[str, List[content parts]]
  - Update converters for each provider:
    - OpenAI Chat Completions: placeholder text for images
    - OpenAI Responses API: full image support
    - Anthropic: full image support with base64
    - Google: placeholder text for images
  - Add resolve_tool_return_images() for URL-to-base64 conversion
  - Make create_approval_response_message_from_input() async

  🐾 Generated with [Letta Code](https://letta.com)

* fix(core): support images in Google tool returns as sibling parts

  Following the gemini-cli pattern: images in tool returns are sent as sibling
  inlineData parts alongside the functionResponse, rather than inside it.

* test(core): add integration tests for multi-modal tool returns [LET-7140]

  Tests verify that:
  - Models with image support (Anthropic, OpenAI Responses API) can see images
    in tool returns and identify the secret text
  - Models without image support (Chat Completions) get placeholder text and
    cannot see the actual image content
  - Tool returns with images persist correctly in the database

  Uses secret.png test image containing hidden text "FIREBRAWL" that models
  must identify to pass the test.

  Also fixes misleading comment about Anthropic only supporting base64 images -
  they support URLs too, we just pre-resolve for consistency.

* refactor: simplify tool return image support implementation

  Reduce code verbosity while maintaining all functionality:
  - Extract _resolve_url_to_base64() helper in message_helper.py (eliminates duplication)
  - Add _get_text_from_part() helper for text extraction
  - Add _get_base64_image_data() helper for image data extraction
  - Add _tool_return_to_google_parts() to simplify the Google implementation
  - Add _image_dict_to_data_url() for the OpenAI Responses format
  - Use walrus operator and list comprehensions where appropriate
  - Add integration_test_multi_modal_tool_returns.py to the CI workflow

  Net change: -120 lines while preserving all features and test coverage.

* fix(tests): improve prompt for multi-modal tool return tests

  Make prompts more direct to reduce LLM flakiness:
  - Simplify tool description: "Retrieves a secret image with hidden text. Call this function to get the image."
  - Change user prompt from a verbose request to a direct command: "Call the get_secret_image function now."
  - Apply to both test methods

  This reduces ambiguity and makes tool calling more reliable across different LLM models.

* fix bugs

* test(core): add google_ai/gemini-2.0-flash-exp to multi-modal tests

  Add the Gemini model to test coverage for multi-modal tool returns. Google AI
  already supports images in tool returns via sibling inlineData parts.

* fix(ui): handle multi-modal tool_return type in frontend components

  Convert Union<string, LettaToolReturnContentUnion[]> to string for display:
  - ViewRunDetails: convert array to '[Image here]' placeholder
  - ToolCallMessageComponent: convert array to '[Image here]' placeholder

  Fixes TypeScript errors in web, desktop-ui, and docker-ui type-checks.

---------

Co-authored-by: Letta <noreply@letta.com>
Co-authored-by: Caren Thomas <carenthomas@gmail.com>
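The URL-to-base64 resolution and data-URL conversion described above can be sketched roughly as follows. The helper names echo those in the PR description (`resolve_tool_return_images`, `_image_dict_to_data_url`), but the bodies and the content-part dict shape are illustrative assumptions, not the actual implementation:

```python
import base64


def resolve_bytes_to_base64(raw: bytes) -> str:
    # Hypothetical core of the URL-to-base64 step: once image bytes have
    # been fetched from a URL, encode them so providers that require
    # inline data (e.g. Anthropic-style base64 blobs) can consume them.
    return base64.b64encode(raw).decode("ascii")


def image_dict_to_data_url(image: dict) -> str:
    # Illustrative version of an _image_dict_to_data_url-style helper:
    # wrap an already-base64-encoded payload in the data-URL form used
    # for image inputs. The dict keys here are assumptions.
    media_type = image.get("media_type", "image/png")
    return f"data:{media_type};base64,{image['data']}"
```

For example, `image_dict_to_data_url({"media_type": "image/png", "data": resolve_bytes_to_base64(raw_png_bytes)})` would yield a data URL suitable for an inline image content part.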
781 lines
37 KiB
Python
from datetime import datetime
from typing import List, Literal, Optional

from httpx import AsyncClient

from letta.data_sources.redis_client import get_redis_client
from letta.errors import LettaInvalidArgumentError
from letta.helpers.datetime_helpers import get_utc_time
from letta.log import get_logger
from letta.log_context import update_log_context
from letta.orm.agent import Agent as AgentModel
from letta.orm.errors import NoResultFound
from letta.orm.message import Message as MessageModel
from letta.orm.run import Run as RunModel
from letta.orm.run_metrics import RunMetrics as RunMetricsModel
from letta.orm.sqlalchemy_base import AccessType
from letta.orm.step import Step as StepModel
from letta.otel.tracing import log_event, trace_method
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, PrimitiveType, RunStatus
from letta.schemas.job import LettaRequestConfig
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.run import Run as PydanticRun, RunUpdate
from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics
from letta.schemas.step import Step as PydanticStep
from letta.schemas.usage import LettaUsageStatistics, normalize_cache_tokens, normalize_reasoning_tokens
from letta.schemas.user import User as PydanticUser
from letta.server.db import db_registry
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.services.message_manager import MessageManager
from letta.services.step_manager import StepManager
from letta.utils import enforce_types
from letta.validators import raise_on_invalid_id

logger = get_logger(__name__)

class RunManager:
    """Manager class to handle business logic related to Runs."""

    def __init__(self):
        """Initialize the RunManager."""
        self.step_manager = StepManager()
        self.message_manager = MessageManager()
        self.agent_manager = AgentManager()

    @enforce_types
    async def create_run(self, pydantic_run: PydanticRun, actor: PydanticUser) -> PydanticRun:
        """Create a new run."""
        async with db_registry.async_session() as session:
            # Get agent_id from the pydantic object
            agent_id = pydantic_run.agent_id

            # Verify agent exists before creating the run
            await validate_agent_exists_async(session, agent_id, actor)
            organization_id = actor.organization_id

            run_data = pydantic_run.model_dump(exclude_none=True)
            # Handle metadata field mapping (Pydantic uses 'metadata', ORM uses 'metadata_')
            if "metadata" in run_data:
                run_data["metadata_"] = run_data.pop("metadata")

            run = RunModel(**run_data)
            run.organization_id = organization_id

            # Get the project_id from the agent
            agent = await session.get(AgentModel, agent_id)
            project_id = agent.project_id if agent else None
            run.project_id = project_id

            run = await run.create_async(session, actor=actor, no_commit=True, no_refresh=True)

            update_log_context(run_id=run.id)

            # Create run metrics with start timestamp
            import time

            metrics = RunMetricsModel(
                id=run.id,
                organization_id=organization_id,
                agent_id=agent_id,
                project_id=project_id,
                run_start_ns=int(time.time() * 1e9),  # Current time in nanoseconds
                num_steps=0,  # Initialize to 0
            )
            await metrics.create_async(session)
            # Commit is handled by the session context manager

            return run.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun:
        """Get a run by its ID."""
        update_log_context(run_id=run_id)
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")
            return run.to_pydantic()

    @enforce_types
    async def get_run_with_status(self, run_id: str, actor: PydanticUser) -> PydanticRun:
        """Get a run by its ID and update status from Lettuce if applicable."""
        update_log_context(run_id=run_id)
        run = await self.get_run_by_id(run_id=run_id, actor=actor)

        use_lettuce = run.metadata and run.metadata.get("lettuce")
        if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]:
            try:
                from letta.services.lettuce import LettuceClient

                lettuce_client = await LettuceClient.create()
                status = await lettuce_client.get_status(run_id=run_id)

                # Map the status to our enum
                if status == "RUNNING":
                    run.status = RunStatus.running
                elif status == "COMPLETED":
                    run.status = RunStatus.completed
                elif status == "FAILED":
                    run.status = RunStatus.failed
                elif status == "CANCELLED":
                    run.status = RunStatus.cancelled
            except Exception as e:
                logger.error(f"Failed to get status from Lettuce for run {run_id}: {str(e)}")
                # Return run with current status from DB if Lettuce fails

        return run

    @enforce_types
    async def list_runs(
        self,
        actor: PydanticUser,
        run_id: Optional[str] = None,
        agent_id: Optional[str] = None,
        agent_ids: Optional[List[str]] = None,
        statuses: Optional[List[RunStatus]] = None,
        limit: Optional[int] = 50,
        before: Optional[str] = None,
        after: Optional[str] = None,
        ascending: bool = False,
        stop_reason: Optional[str] = None,
        background: Optional[bool] = None,
        template_family: Optional[str] = None,
        step_count: Optional[int] = None,
        step_count_operator: ComparisonOperator = ComparisonOperator.EQ,
        tools_used: Optional[List[str]] = None,
        project_id: Optional[str] = None,
        conversation_id: Optional[str] = None,
        order_by: Literal["created_at", "duration"] = "created_at",
        duration_percentile: Optional[int] = None,
        duration_filter: Optional[dict] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> List[PydanticRun]:
        """List runs with filtering options."""
        async with db_registry.async_session() as session:
            from sqlalchemy import func, or_, select

            # Always join with run_metrics to get duration data
            query = (
                select(RunModel, RunMetricsModel.run_ns)
                .outerjoin(RunMetricsModel, RunModel.id == RunMetricsModel.id)
                .filter(RunModel.organization_id == actor.organization_id)
            )

            # Filter by project_id if provided
            if project_id:
                query = query.filter(RunModel.project_id == project_id)

            if run_id:
                query = query.filter(RunModel.id == run_id)

            # Handle agent filtering
            if agent_id:
                agent_ids = [agent_id]
            if agent_ids:
                query = query.filter(RunModel.agent_id.in_(agent_ids))

            # Filter by status
            if statuses:
                query = query.filter(RunModel.status.in_(statuses))

            # Filter by stop reason
            if stop_reason:
                query = query.filter(RunModel.stop_reason == stop_reason)

            # Filter by background
            if background is not None:
                query = query.filter(RunModel.background == background)

            # Filter by conversation_id
            if conversation_id is not None:
                query = query.filter(RunModel.conversation_id == conversation_id)

            # Filter by template_family (base_template_id)
            if template_family:
                query = query.filter(RunModel.base_template_id == template_family)

            # Filter by date range
            if start_date:
                query = query.filter(RunModel.created_at >= start_date)
            if end_date:
                query = query.filter(RunModel.created_at <= end_date)

            # Filter by step_count with the specified operator
            if step_count is not None:
                if step_count_operator == ComparisonOperator.EQ:
                    query = query.filter(RunMetricsModel.num_steps == step_count)
                elif step_count_operator == ComparisonOperator.GTE:
                    query = query.filter(RunMetricsModel.num_steps >= step_count)
                elif step_count_operator == ComparisonOperator.LTE:
                    query = query.filter(RunMetricsModel.num_steps <= step_count)

            # Filter by tools used ids
            if tools_used:
                from sqlalchemy import String, cast as sa_cast, type_coerce
                from sqlalchemy.dialects.postgresql import ARRAY, JSONB

                # Use ?| operator to check if any tool_id exists in the array (OR logic)
                jsonb_tools = sa_cast(RunMetricsModel.tools_used, JSONB)
                tools_array = type_coerce(tools_used, ARRAY(String))
                query = query.filter(jsonb_tools.op("?|")(tools_array))

            # Ensure run_ns is not null when working with duration
            if order_by == "duration" or duration_percentile is not None or duration_filter is not None:
                query = query.filter(RunMetricsModel.run_ns.isnot(None))

            # Apply duration filter if requested
            if duration_filter is not None:
                duration_value = duration_filter.get("value") if isinstance(duration_filter, dict) else duration_filter.value
                duration_operator = duration_filter.get("operator") if isinstance(duration_filter, dict) else duration_filter.operator

                if duration_operator == "gt":
                    query = query.filter(RunMetricsModel.run_ns > duration_value)
                elif duration_operator == "lt":
                    query = query.filter(RunMetricsModel.run_ns < duration_value)
                elif duration_operator == "eq":
                    query = query.filter(RunMetricsModel.run_ns == duration_value)

            # Apply duration percentile filter if requested
            if duration_percentile is not None:
                # Calculate the percentile threshold
                percentile_query = (
                    select(func.percentile_cont(duration_percentile / 100.0).within_group(RunMetricsModel.run_ns))
                    .select_from(RunMetricsModel)
                    .join(RunModel, RunModel.id == RunMetricsModel.id)
                    .filter(RunModel.organization_id == actor.organization_id)
                    .filter(RunMetricsModel.run_ns.isnot(None))
                )

                # Apply same filters to percentile calculation
                if project_id:
                    percentile_query = percentile_query.filter(RunModel.project_id == project_id)
                if agent_ids:
                    percentile_query = percentile_query.filter(RunModel.agent_id.in_(agent_ids))
                if statuses:
                    percentile_query = percentile_query.filter(RunModel.status.in_(statuses))

                # Execute percentile query
                percentile_result = await session.execute(percentile_query)
                percentile_threshold = percentile_result.scalar()

                # Filter by percentile threshold (runs slower than the percentile)
                if percentile_threshold is not None:
                    query = query.filter(RunMetricsModel.run_ns >= percentile_threshold)

            # Apply sorting based on order_by
            if order_by == "duration":
                # Sort by duration
                if ascending:
                    query = query.order_by(RunMetricsModel.run_ns.asc())
                else:
                    query = query.order_by(RunMetricsModel.run_ns.desc())
            else:
                # Apply pagination for created_at ordering
                from letta.services.helpers.run_manager_helper import _apply_pagination_async

                query = await _apply_pagination_async(query, before, after, session, ascending=ascending)

            # Apply limit (always enforce a maximum to prevent unbounded queries)
            # If no limit specified, default to 100; enforce maximum of 1000
            effective_limit = limit if limit is not None else 100
            effective_limit = min(effective_limit, 1000)
            query = query.limit(effective_limit)

            result = await session.execute(query)
            rows = result.all()

            # Populate total_duration_ns from run_metrics.run_ns
            pydantic_runs = []
            for row in rows:
                run_model = row[0]
                run_ns = row[1]

                pydantic_run = run_model.to_pydantic()
                if run_ns is not None:
                    pydantic_run.total_duration_ns = run_ns

                pydantic_runs.append(pydantic_run)

            return pydantic_runs

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def delete_run(self, run_id: str, actor: PydanticUser) -> None:
        """Delete a run by its ID."""
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")

            await run.hard_delete_async(db_session=session, actor=actor)

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    @trace_method
    async def update_run_by_id_async(
        self,
        run_id: str,
        update: RunUpdate,
        actor: PydanticUser,
        refresh_result_messages: bool = True,
        conversation_id: Optional[str] = None,
    ) -> PydanticRun:
        """Update a run using a RunUpdate object."""
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)

            # Check if this is a terminal update and whether we should dispatch a callback
            needs_callback = False
            callback_url = None
            not_completed_before = not bool(run.completed_at)
            is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed, RunStatus.cancelled}
            if is_terminal_update and not_completed_before and run.callback_url:
                needs_callback = True
                callback_url = run.callback_url

            # Validate run lifecycle (only log the errors)
            if run.status in {RunStatus.completed}:
                if update.status not in {RunStatus.cancelled}:
                    # A completed run can only be marked as cancelled
                    logger.error(
                        f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being marked as {update.status} with stop reason {update.stop_reason}"
                    )
                if update.stop_reason not in {StopReasonType.requires_approval}:
                    # A completed run can only be cancelled if the stop reason is requires_approval
                    logger.error(
                        f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being marked as {update.status} with stop reason {update.stop_reason}"
                    )
            if run.status in {RunStatus.failed, RunStatus.cancelled}:
                logger.error(
                    f"Run {run_id} is already in a terminal state {run.status} with stop reason {run.stop_reason}, but is being updated with data {update.model_dump()}"
                )

            # Housekeeping only when the run is actually completing
            if not_completed_before and is_terminal_update:
                if not update.stop_reason:
                    logger.error(f"Run {run_id} completed without a stop reason")
                if not update.completed_at:
                    logger.warning(f"Run {run_id} completed without a completed_at timestamp")
                    update.completed_at = get_utc_time().replace(tzinfo=None)

            # Update run attributes with only the fields that were explicitly set
            update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)

            # Merge metadata updates instead of overwriting.
            # This is important for streaming/background flows where different components update
            # different parts of metadata (e.g., run_type set at creation, error payload set at terminal).
            if "metadata_" in update_data and isinstance(update_data["metadata_"], dict):
                existing_metadata = run.metadata_ if isinstance(run.metadata_, dict) else {}
                update_data["metadata_"] = {**existing_metadata, **update_data["metadata_"]}

            # Apply the updates to the run
            for key, value in update_data.items():
                # Ensure completed_at is timezone-naive for database compatibility
                if key == "completed_at" and value is not None and hasattr(value, "replace"):
                    value = value.replace(tzinfo=None)
                setattr(run, key, value)

            await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            final_metadata = run.metadata_
            pydantic_run = run.to_pydantic()
            # Commit is handled by the session context manager

        # Release conversation lock if conversation_id was provided
        if is_terminal_update and conversation_id:
            try:
                redis_client = await get_redis_client()
                await redis_client.release_conversation_lock(conversation_id)
            except Exception as lock_error:
                logger.warning(f"Failed to release conversation lock for conversation {conversation_id}: {lock_error}")

        # Update agent's last_stop_reason when run completes
        # Do this after run update is committed to database
        if is_terminal_update and update.stop_reason:
            try:
                from letta.schemas.agent import UpdateAgent

                await self.agent_manager.update_agent_async(
                    agent_id=pydantic_run.agent_id,
                    agent_update=UpdateAgent(last_stop_reason=update.stop_reason),
                    actor=actor,
                )
            except Exception as e:
                logger.error(f"Failed to update agent's last_stop_reason for run {run_id}: {e}")

        # Update run metrics table
        num_steps = len(await self.step_manager.list_steps_async(run_id=run_id, actor=actor))

        # Collect tools used from run messages
        from letta.services.tool_manager import ToolManager

        tool_manager = ToolManager()
        tools_used = set()
        messages = await self.message_manager.list_messages(actor=actor, run_id=run_id)
        for message in messages:
            if message.tool_calls:
                for tool_call in message.tool_calls:
                    if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"):
                        # Resolve the tool ID from the tool name
                        tool_name = tool_call.function.name
                        tool_id = await tool_manager.get_tool_id_by_name_async(tool_name, actor)
                        if tool_id:
                            tools_used.add(tool_id)

        async with db_registry.async_session() as session:
            metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
            # Calculate runtime if run is completing:
            # use total_duration_ns from RunUpdate if provided,
            # otherwise fall back to system time
            if is_terminal_update:
                if update.total_duration_ns is not None:
                    metrics.run_ns = update.total_duration_ns
                elif metrics.run_start_ns:
                    import time

                    current_ns = int(time.time() * 1e9)
                    metrics.run_ns = current_ns - metrics.run_start_ns
            metrics.num_steps = num_steps
            metrics.tools_used = list(tools_used) if tools_used else None
            await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
            # Commit is handled by the session context manager

        # Dispatch callback outside of database session if needed
        if needs_callback:
            if refresh_result_messages:
                result = LettaResponse(
                    messages=await self.get_run_messages(run_id=run_id, actor=actor),
                    stop_reason=LettaStopReason(stop_reason=pydantic_run.stop_reason),
                    usage=await self.get_run_usage(run_id=run_id, actor=actor),
                )
                final_metadata["result"] = result.model_dump()
            callback_info = {
                "run_id": run_id,
                "callback_url": callback_url,
                "status": update.status,
                "completed_at": get_utc_time().replace(tzinfo=None),
                "metadata": final_metadata,
            }
            callback_result = await self._dispatch_callback_async(callback_info)

            # Update callback status in a separate transaction
            async with db_registry.async_session() as session:
                run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
                run.callback_sent_at = callback_result["callback_sent_at"]
                run.callback_status_code = callback_result.get("callback_status_code")
                run.callback_error = callback_result.get("callback_error")
                pydantic_run = run.to_pydantic()
                await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
                # Commit is handled by the session context manager

        return pydantic_run

    @trace_method
    async def _dispatch_callback_async(self, callback_info: dict) -> dict:
        """POST a standard JSON payload to callback_url and return callback status asynchronously."""
        payload = {
            "run_id": callback_info["run_id"],
            "status": callback_info["status"],
            "completed_at": callback_info["completed_at"].isoformat() if callback_info["completed_at"] else None,
            "metadata": callback_info["metadata"],
        }

        callback_sent_at = get_utc_time().replace(tzinfo=None)
        result = {"callback_sent_at": callback_sent_at}

        try:
            async with AsyncClient() as client:
                log_event("POST callback dispatched", payload)
                resp = await client.post(callback_info["callback_url"], json=payload, timeout=5.0)
                log_event("POST callback finished")
                result["callback_status_code"] = resp.status_code
        except Exception as e:
            error_message = f"Failed to dispatch callback for run {callback_info['run_id']} to {callback_info['callback_url']}: {e!r}"
            logger.error(error_message)
            result["callback_error"] = error_message
            # Continue silently - callback failures should not affect run completion

        return result

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics:
        """Get usage statistics for a run."""
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")

        steps = await self.step_manager.list_steps_async(run_id=run_id, actor=actor)
        total_usage = LettaUsageStatistics()
        for step in steps:
            total_usage.prompt_tokens += step.prompt_tokens
            total_usage.completion_tokens += step.completion_tokens
            total_usage.total_tokens += step.total_tokens
            total_usage.step_count += 1

            # Aggregate cache and reasoning tokens from detailed breakdowns using normalized helpers.
            # Handle None defaults: only set if we have data, accumulate if already set.
            cached_input, cache_write = normalize_cache_tokens(step.prompt_tokens_details)
            if cached_input > 0 or total_usage.cached_input_tokens is not None:
                total_usage.cached_input_tokens = (total_usage.cached_input_tokens or 0) + cached_input
            if cache_write > 0 or total_usage.cache_write_tokens is not None:
                total_usage.cache_write_tokens = (total_usage.cache_write_tokens or 0) + cache_write
            reasoning = normalize_reasoning_tokens(step.completion_tokens_details)
            if reasoning > 0 or total_usage.reasoning_tokens is not None:
                total_usage.reasoning_tokens = (total_usage.reasoning_tokens or 0) + reasoning

        return total_usage

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_messages(
        self,
        run_id: str,
        actor: PydanticUser,
        limit: Optional[int] = 100,
        before: Optional[str] = None,
        after: Optional[str] = None,
        order: Literal["asc", "desc"] = "asc",
    ) -> List[LettaMessage]:
        """Get the messages of a run."""
        run = await self.get_run_by_id(run_id=run_id, actor=actor)
        request_config = run.request_config
        agent = await self.agent_manager.get_agent_by_id_async(agent_id=run.agent_id, actor=actor, include_relationships=[])
        text_is_assistant_message = agent.agent_type == AgentType.letta_v1_agent

        messages = await self.message_manager.list_messages(
            actor=actor,
            run_id=run_id,
            limit=limit,
            before=before,
            after=after,
            ascending=(order == "asc"),
        )
        letta_messages = PydanticMessage.to_letta_messages_from_list(
            messages, reverse=(order != "asc"), text_is_assistant_message=text_is_assistant_message
        )

        if request_config and request_config.include_return_message_types:
            include_return_message_types_set = set(request_config.include_return_message_types)
            letta_messages = [msg for msg in letta_messages if msg.message_type in include_return_message_types_set]

        return letta_messages

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]:
        """Get the letta request config from a run."""
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")
            pydantic_run = run.to_pydantic()
            return pydantic_run.request_config

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics:
        """Get metrics for a run."""
        async with db_registry.async_session() as session:
            metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
            return metrics.to_pydantic()

    @enforce_types
    @raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
    async def get_run_steps(
        self,
        run_id: str,
        actor: PydanticUser,
        limit: Optional[int] = 100,
        before: Optional[str] = None,
        after: Optional[str] = None,
        ascending: bool = False,
    ) -> List[PydanticStep]:
        """Get steps for a run."""
        async with db_registry.async_session() as session:
            run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
            if not run:
                raise NoResultFound(f"Run with id {run_id} not found")

        steps = await self.step_manager.list_steps_async(
            actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc"
        )
        return steps

@enforce_types
|
|
async def cancel_run(self, actor: PydanticUser, agent_id: Optional[str] = None, run_id: Optional[str] = None) -> None:
|
|
"""Cancel a run."""
|
|
|
|
# make sure run_id and agent_id are not both None
|
|
if not run_id:
|
|
# get the last agent run
|
|
if not agent_id:
|
|
raise ValueError("Agent ID is required to cancel a run by ID")
|
|
logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
|
|
run_ids = await self.list_runs(
|
|
actor=actor,
|
|
ascending=False,
|
|
agent_id=agent_id,
|
|
)
|
|
run_ids = [run.id for run in run_ids]
|
|
else:
|
|
# get the agent
|
|
run = await self.get_run_by_id(run_id=run_id, actor=actor)
|
|
if not run:
|
|
raise NoResultFound(f"Run with id {run_id} not found")
|
|
agent_id = run.agent_id
|
|
|
|
logger.debug(f"Cancelling run {run_id} for agent {agent_id}")
|
|
|
|
# Cancellation should be idempotent: if a run is already terminated, treat this as a no-op.
|
|
# This commonly happens when a run finishes between client request and server handling.
|
|
if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
|
|
logger.debug(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}")
|
|
return
|
|
|
|
# Check if agent is waiting for approval by examining the last message
|
|
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
current_in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
|
was_pending_approval = current_in_context_messages and current_in_context_messages[-1].is_approval_request()
|
|
|
|
        # cancel the run
        # NOTE: this should update the agent's last stop reason to cancelled
        run = await self.update_run_by_id_async(
            run_id=run_id,
            update=RunUpdate(status=RunStatus.cancelled, stop_reason=StopReasonType.cancelled),
            actor=actor,
            conversation_id=run.conversation_id,
        )

        # clean up the agent's state: if the agent was pending approval, resolve the approval state
        if was_pending_approval:
            logger.debug(f"Agent was waiting for approval, adding denial messages for run {run_id}")
            approval_request_message = current_in_context_messages[-1]

            # Find ALL pending tool calls (both requiring approval and not requiring approval);
            # the assistant message may carry tool calls that did not require approval
            all_pending_tool_calls = []
            if approval_request_message.tool_calls:
                all_pending_tool_calls.extend(approval_request_message.tool_calls)

            # Check if there's an assistant message before the approval request with additional tool calls
            if len(current_in_context_messages) >= 2:
                potential_assistant_msg = current_in_context_messages[-2]
                if potential_assistant_msg.role == MessageRole.assistant and potential_assistant_msg.tool_calls:
                    # Add any tool calls from the assistant message that aren't already in the approval request
                    approval_tool_call_ids = (
                        {tc.id for tc in approval_request_message.tool_calls} if approval_request_message.tool_calls else set()
                    )
                    for tool_call in potential_assistant_msg.tool_calls:
                        if tool_call.id not in approval_tool_call_ids:
                            all_pending_tool_calls.append(tool_call)

            # Ensure we have tool calls to deny
            if all_pending_tool_calls:
                from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL
                from letta.schemas.letta_message import ApprovalReturn
                from letta.schemas.message import ApprovalCreate
                from letta.server.rest_api.utils import (
                    create_approval_response_message_from_input,
                    create_tool_message_from_returns,
                    create_tool_returns_for_denials,
                )

                # Create denials for the tool calls attached to the approval request;
                # tool calls that did not require approval are handled via tool returns below
                denials = (
                    [
                        ApprovalReturn(
                            tool_call_id=tool_call.id,
                            approve=False,
                            reason=TOOL_CALL_DENIAL_ON_CANCEL,
                        )
                        for tool_call in approval_request_message.tool_calls
                    ]
                    if approval_request_message.tool_calls
                    else []
                )

                # Create an ApprovalCreate input with the denials
                approval_input = ApprovalCreate(
                    approvals=denials,
                    approval_request_id=approval_request_message.id,
                )

                # Use the standard helper to create properly formatted approval response messages
                approval_response_messages = await create_approval_response_message_from_input(
                    agent_state=agent_state,
                    input_message=approval_input,
                    run_id=run_id,
                )

                # Create tool returns for ALL denied tool calls using the shared helper
                # (this includes both tool calls requiring approval and those that didn't)
                tool_returns = create_tool_returns_for_denials(
                    tool_calls=all_pending_tool_calls,
                    denial_reason=TOOL_CALL_DENIAL_ON_CANCEL,
                    timezone=agent_state.timezone,
                )

                # Create a single tool message carrying all denial returns
                tool_message = create_tool_message_from_returns(
                    agent_id=agent_state.id,
                    model=agent_state.llm_config.model,
                    tool_returns=tool_returns,
                    run_id=run_id,
                )

                # Combine approval response and tool messages, then checkpoint them
                from letta.agents.agent_loop import AgentLoop

                new_messages = approval_response_messages + [tool_message]
                agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor)
                new_in_context_messages = current_in_context_messages + new_messages
                await agent_loop._checkpoint_messages(
                    run_id=run_id,
                    step_id=approval_request_message.step_id,
                    new_messages=new_messages,
                    in_context_messages=new_in_context_messages,
                )

                logger.debug(
                    f"Inserted approval response with {len(denials)} denials and a tool return message for cancelled run {run_id}. "
                    f"Approval request message ID: {approval_request_message.id}"
                )
            else:
                logger.warning(
                    f"Last message is an approval request but has no pending tool calls to deny. "
                    f"Message ID: {approval_request_message.id}, Run ID: {run_id}"
                )

        return run