723 lines
34 KiB
Python
723 lines
34 KiB
Python
from datetime import datetime
|
|
from multiprocessing import Value
|
|
from pickletools import pyunicode
|
|
from typing import List, Literal, Optional
|
|
|
|
from httpx import AsyncClient
|
|
|
|
from letta.errors import LettaInvalidArgumentError
|
|
from letta.helpers.datetime_helpers import get_utc_time
|
|
from letta.log import get_logger
|
|
from letta.log_context import update_log_context
|
|
from letta.orm.agent import Agent as AgentModel
|
|
from letta.orm.errors import NoResultFound
|
|
from letta.orm.message import Message as MessageModel
|
|
from letta.orm.run import Run as RunModel
|
|
from letta.orm.run_metrics import RunMetrics as RunMetricsModel
|
|
from letta.orm.sqlalchemy_base import AccessType
|
|
from letta.orm.step import Step as StepModel
|
|
from letta.otel.tracing import log_event, trace_method
|
|
from letta.schemas.enums import AgentType, ComparisonOperator, MessageRole, PrimitiveType, RunStatus
|
|
from letta.schemas.job import LettaRequestConfig
|
|
from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
|
|
from letta.schemas.letta_response import LettaResponse
|
|
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
|
from letta.schemas.message import Message as PydanticMessage
|
|
from letta.schemas.run import Run as PydanticRun, RunUpdate
|
|
from letta.schemas.run_metrics import RunMetrics as PydanticRunMetrics
|
|
from letta.schemas.step import Step as PydanticStep
|
|
from letta.schemas.usage import LettaUsageStatistics, normalize_cache_tokens, normalize_reasoning_tokens
|
|
from letta.schemas.user import User as PydanticUser
|
|
from letta.server.db import db_registry
|
|
from letta.services.agent_manager import AgentManager
|
|
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
|
from letta.services.message_manager import MessageManager
|
|
from letta.services.step_manager import StepManager
|
|
from letta.utils import enforce_types
|
|
from letta.validators import raise_on_invalid_id
|
|
|
|
logger = get_logger(__name__)
|
|
|
|
|
|
class RunManager:
|
|
"""Manager class to handle business logic related to Runs."""
|
|
|
|
def __init__(self):
|
|
"""Initialize the RunManager."""
|
|
self.step_manager = StepManager()
|
|
self.message_manager = MessageManager()
|
|
self.agent_manager = AgentManager()
|
|
|
|
@enforce_types
|
|
async def create_run(self, pydantic_run: PydanticRun, actor: PydanticUser) -> PydanticRun:
|
|
"""Create a new run."""
|
|
async with db_registry.async_session() as session:
|
|
# Get agent_id from the pydantic object
|
|
agent_id = pydantic_run.agent_id
|
|
|
|
# Verify agent exists before creating the run
|
|
await validate_agent_exists_async(session, agent_id, actor)
|
|
organization_id = actor.organization_id
|
|
|
|
run_data = pydantic_run.model_dump(exclude_none=True)
|
|
# Handle metadata field mapping (Pydantic uses 'metadata', ORM uses 'metadata_')
|
|
if "metadata" in run_data:
|
|
run_data["metadata_"] = run_data.pop("metadata")
|
|
|
|
run = RunModel(**run_data)
|
|
run.organization_id = organization_id
|
|
|
|
# Get the project_id from the agent
|
|
agent = await session.get(AgentModel, agent_id)
|
|
project_id = agent.project_id if agent else None
|
|
run.project_id = project_id
|
|
|
|
run = await run.create_async(session, actor=actor, no_commit=True, no_refresh=True)
|
|
|
|
update_log_context(run_id=run.id)
|
|
|
|
# Create run metrics with start timestamp
|
|
import time
|
|
|
|
metrics = RunMetricsModel(
|
|
id=run.id,
|
|
organization_id=organization_id,
|
|
agent_id=agent_id,
|
|
project_id=project_id,
|
|
run_start_ns=int(time.time() * 1e9), # Current time in nanoseconds
|
|
num_steps=0, # Initialize to 0
|
|
)
|
|
await metrics.create_async(session)
|
|
await session.commit()
|
|
|
|
return run.to_pydantic()
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
async def get_run_by_id(self, run_id: str, actor: PydanticUser) -> PydanticRun:
|
|
"""Get a run by its ID."""
|
|
update_log_context(run_id=run_id)
|
|
async with db_registry.async_session() as session:
|
|
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
|
|
if not run:
|
|
raise NoResultFound(f"Run with id {run_id} not found")
|
|
return run.to_pydantic()
|
|
|
|
@enforce_types
|
|
async def get_run_with_status(self, run_id: str, actor: PydanticUser) -> PydanticRun:
|
|
"""Get a run by its ID and update status from Lettuce if applicable."""
|
|
update_log_context(run_id=run_id)
|
|
run = await self.get_run_by_id(run_id=run_id, actor=actor)
|
|
|
|
use_lettuce = run.metadata and run.metadata.get("lettuce")
|
|
if use_lettuce and run.status not in [RunStatus.completed, RunStatus.failed, RunStatus.cancelled]:
|
|
try:
|
|
from letta.services.lettuce import LettuceClient
|
|
|
|
lettuce_client = await LettuceClient.create()
|
|
status = await lettuce_client.get_status(run_id=run_id)
|
|
|
|
# Map the status to our enum
|
|
if status == "RUNNING":
|
|
run.status = RunStatus.running
|
|
elif status == "COMPLETED":
|
|
run.status = RunStatus.completed
|
|
elif status == "FAILED":
|
|
run.status = RunStatus.failed
|
|
elif status == "CANCELLED":
|
|
run.status = RunStatus.cancelled
|
|
except Exception as e:
|
|
logger.error(f"Failed to get status from Lettuce for run {run_id}: {str(e)}")
|
|
# Return run with current status from DB if Lettuce fails
|
|
|
|
return run
|
|
|
|
    @enforce_types
    async def list_runs(
        self,
        actor: PydanticUser,
        agent_id: Optional[str] = None,
        agent_ids: Optional[List[str]] = None,
        statuses: Optional[List[RunStatus]] = None,
        limit: Optional[int] = 50,
        before: Optional[str] = None,
        after: Optional[str] = None,
        ascending: bool = False,
        stop_reason: Optional[str] = None,
        background: Optional[bool] = None,
        template_family: Optional[str] = None,
        step_count: Optional[int] = None,
        step_count_operator: ComparisonOperator = ComparisonOperator.EQ,
        tools_used: Optional[List[str]] = None,
        project_id: Optional[str] = None,
        order_by: Literal["created_at", "duration"] = "created_at",
        duration_percentile: Optional[int] = None,
        duration_filter: Optional[dict] = None,
        start_date: Optional[datetime] = None,
        end_date: Optional[datetime] = None,
    ) -> List[PydanticRun]:
        """List runs with filtering options.

        Runs are always outer-joined against their metrics row so duration
        (run_ns) can be filtered/sorted on and copied into the result.

        Args:
            actor: Caller; results are scoped to actor.organization_id.
            agent_id: Convenience single-agent filter; folded into agent_ids.
            agent_ids: Restrict to runs owned by these agents.
            statuses: Restrict to runs in any of these statuses.
            limit: Max rows (defaults to 100 if None; hard cap 1000).
            before/after: Pagination cursors. NOTE: only applied when
                order_by == "created_at" — duration ordering ignores them.
            ascending: Sort direction for the chosen order_by column.
            stop_reason: Exact-match filter on the run's stop reason.
            background: Filter on the background flag when not None.
            template_family: Matches the run's base_template_id.
            step_count / step_count_operator: Filter on metrics.num_steps
                (EQ/GTE/LTE supported; other operators are ignored).
            tools_used: Match runs whose metrics.tools_used contains ANY of
                these ids (OR semantics via the JSONB ?| operator).
            project_id: Restrict to one project.
            order_by: "created_at" (default) or "duration" (metrics.run_ns).
            duration_percentile: Keep only runs at/above this percentile of
                run_ns (e.g. 95 -> slowest 5%), computed over the same
                org/project/agent/status scope.
            duration_filter: {"value": ns, "operator": "gt"|"lt"|"eq"} (or an
                object with .value/.operator) applied to metrics.run_ns.
            start_date / end_date: Inclusive created_at window.

        Returns:
            Runs as Pydantic models, with total_duration_ns populated from
            the metrics row when available.
        """
        async with db_registry.async_session() as session:
            from sqlalchemy import func, or_, select

            # Always join with run_metrics to get duration data
            query = (
                select(RunModel, RunMetricsModel.run_ns)
                .outerjoin(RunMetricsModel, RunModel.id == RunMetricsModel.id)
                .filter(RunModel.organization_id == actor.organization_id)
            )

            # Filter by project_id if provided
            if project_id:
                query = query.filter(RunModel.project_id == project_id)

            # Handle agent filtering (agent_id is shorthand for a one-element agent_ids)
            if agent_id:
                agent_ids = [agent_id]
            if agent_ids:
                query = query.filter(RunModel.agent_id.in_(agent_ids))

            # Filter by status
            if statuses:
                query = query.filter(RunModel.status.in_(statuses))

            # Filter by stop reason
            if stop_reason:
                query = query.filter(RunModel.stop_reason == stop_reason)

            # Filter by background (explicit None check so False still filters)
            if background is not None:
                query = query.filter(RunModel.background == background)

            # Filter by template_family (base_template_id)
            if template_family:
                query = query.filter(RunModel.base_template_id == template_family)

            # Filter by date range
            if start_date:
                query = query.filter(RunModel.created_at >= start_date)
            if end_date:
                query = query.filter(RunModel.created_at <= end_date)

            # Filter by step_count with the specified operator
            # (operators other than EQ/GTE/LTE fall through with no filter)
            if step_count is not None:
                if step_count_operator == ComparisonOperator.EQ:
                    query = query.filter(RunMetricsModel.num_steps == step_count)
                elif step_count_operator == ComparisonOperator.GTE:
                    query = query.filter(RunMetricsModel.num_steps >= step_count)
                elif step_count_operator == ComparisonOperator.LTE:
                    query = query.filter(RunMetricsModel.num_steps <= step_count)

            # Filter by tools used ids
            # NOTE: uses the JSONB ?| operator, so this path is Postgres-only.
            if tools_used:
                from sqlalchemy import String, cast as sa_cast, type_coerce
                from sqlalchemy.dialects.postgresql import ARRAY, JSONB

                # Use ?| operator to check if any tool_id exists in the array (OR logic)
                jsonb_tools = sa_cast(RunMetricsModel.tools_used, JSONB)
                tools_array = type_coerce(tools_used, ARRAY(String))
                query = query.filter(jsonb_tools.op("?|")(tools_array))

            # Ensure run_ns is not null when working with duration
            if order_by == "duration" or duration_percentile is not None or duration_filter is not None:
                query = query.filter(RunMetricsModel.run_ns.isnot(None))

            # Apply duration filter if requested
            # (accepts either a plain dict or an object with .value/.operator)
            if duration_filter is not None:
                duration_value = duration_filter.get("value") if isinstance(duration_filter, dict) else duration_filter.value
                duration_operator = duration_filter.get("operator") if isinstance(duration_filter, dict) else duration_filter.operator

                if duration_operator == "gt":
                    query = query.filter(RunMetricsModel.run_ns > duration_value)
                elif duration_operator == "lt":
                    query = query.filter(RunMetricsModel.run_ns < duration_value)
                elif duration_operator == "eq":
                    query = query.filter(RunMetricsModel.run_ns == duration_value)

            # Apply duration percentile filter if requested
            if duration_percentile is not None:
                # Calculate the percentile threshold over the same scoped population
                percentile_query = (
                    select(func.percentile_cont(duration_percentile / 100.0).within_group(RunMetricsModel.run_ns))
                    .select_from(RunMetricsModel)
                    .join(RunModel, RunModel.id == RunMetricsModel.id)
                    .filter(RunModel.organization_id == actor.organization_id)
                    .filter(RunMetricsModel.run_ns.isnot(None))
                )

                # Apply same filters to percentile calculation
                # (project/agent/status only — date and other filters are not mirrored here)
                if project_id:
                    percentile_query = percentile_query.filter(RunModel.project_id == project_id)
                if agent_ids:
                    percentile_query = percentile_query.filter(RunModel.agent_id.in_(agent_ids))
                if statuses:
                    percentile_query = percentile_query.filter(RunModel.status.in_(statuses))

                # Execute percentile query
                percentile_result = await session.execute(percentile_query)
                percentile_threshold = percentile_result.scalar()

                # Filter by percentile threshold (runs slower than the percentile)
                if percentile_threshold is not None:
                    query = query.filter(RunMetricsModel.run_ns >= percentile_threshold)

            # Apply sorting based on order_by
            if order_by == "duration":
                # Sort by duration; before/after cursors are NOT applied on this path
                if ascending:
                    query = query.order_by(RunMetricsModel.run_ns.asc())
                else:
                    query = query.order_by(RunMetricsModel.run_ns.desc())
            else:
                # Apply pagination for created_at ordering
                from letta.services.helpers.run_manager_helper import _apply_pagination_async

                query = await _apply_pagination_async(query, before, after, session, ascending=ascending)

            # Apply limit (always enforce a maximum to prevent unbounded queries)
            # If no limit specified, default to 100; enforce maximum of 1000
            effective_limit = limit if limit is not None else 100
            effective_limit = min(effective_limit, 1000)
            query = query.limit(effective_limit)

            result = await session.execute(query)
            rows = result.all()

            # Populate total_duration_ns from run_metrics.run_ns
            pydantic_runs = []
            for row in rows:
                run_model = row[0]
                run_ns = row[1]

                pydantic_run = run_model.to_pydantic()
                if run_ns is not None:
                    pydantic_run.total_duration_ns = run_ns

                pydantic_runs.append(pydantic_run)

            return pydantic_runs
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
async def delete_run(self, run_id: str, actor: PydanticUser) -> None:
|
|
"""Delete a run by its ID."""
|
|
async with db_registry.async_session() as session:
|
|
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
|
|
if not run:
|
|
raise NoResultFound(f"Run with id {run_id} not found")
|
|
|
|
await run.hard_delete_async(db_session=session, actor=actor)
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
@trace_method
|
|
async def update_run_by_id_async(
|
|
self, run_id: str, update: RunUpdate, actor: PydanticUser, refresh_result_messages: bool = True
|
|
) -> PydanticRun:
|
|
"""Update a run using a RunUpdate object."""
|
|
async with db_registry.async_session() as session:
|
|
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
|
|
|
|
# Check if this is a terminal update and whether we should dispatch a callback
|
|
needs_callback = False
|
|
callback_url = None
|
|
not_completed_before = not bool(run.completed_at)
|
|
is_terminal_update = update.status in {RunStatus.completed, RunStatus.failed, RunStatus.cancelled}
|
|
if is_terminal_update and not_completed_before and run.callback_url:
|
|
needs_callback = True
|
|
callback_url = run.callback_url
|
|
|
|
# validate run lifecycle (only log the errors)
|
|
if run.status in {RunStatus.completed}:
|
|
if update.status not in {RunStatus.cancelled}:
|
|
# a completed run can only be marked as cancelled
|
|
logger.error(
|
|
f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being marked as {update.status} with stop reason {update.stop_reason}"
|
|
)
|
|
if update.stop_reason not in {StopReasonType.requires_approval}:
|
|
# a completed run can only be cancelled if the stop reason is requires approval
|
|
logger.error(
|
|
f"Run {run_id} is already completed with stop reason {run.stop_reason}, but is being marked as {update.status} with stop reason {update.stop_reason}"
|
|
)
|
|
if run.status in {RunStatus.failed, RunStatus.cancelled}:
|
|
logger.error(
|
|
f"Run {run_id} is already in a terminal state {run.status} with stop reason {run.stop_reason}, but is being updated with data {update.model_dump()}"
|
|
)
|
|
|
|
# Housekeeping only when the run is actually completing
|
|
if not_completed_before and is_terminal_update:
|
|
if not update.stop_reason:
|
|
logger.error(f"Run {run_id} completed without a stop reason")
|
|
if not update.completed_at:
|
|
logger.warning(f"Run {run_id} completed without a completed_at timestamp")
|
|
update.completed_at = get_utc_time().replace(tzinfo=None)
|
|
|
|
# Update job attributes with only the fields that were explicitly set
|
|
update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
|
|
|
# Automatically update the completion timestamp if status is set to 'completed'
|
|
for key, value in update_data.items():
|
|
# Ensure completed_at is timezone-naive for database compatibility
|
|
if key == "completed_at" and value is not None and hasattr(value, "replace"):
|
|
value = value.replace(tzinfo=None)
|
|
setattr(run, key, value)
|
|
|
|
await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
|
final_metadata = run.metadata_
|
|
pydantic_run = run.to_pydantic()
|
|
|
|
await session.commit()
|
|
|
|
# Update agent's last_stop_reason when run completes
|
|
# Do this after run update is committed to database
|
|
if is_terminal_update and update.stop_reason:
|
|
try:
|
|
from letta.schemas.agent import UpdateAgent
|
|
|
|
await self.agent_manager.update_agent_async(
|
|
agent_id=pydantic_run.agent_id,
|
|
agent_update=UpdateAgent(last_stop_reason=update.stop_reason),
|
|
actor=actor,
|
|
)
|
|
except Exception as e:
|
|
logger.error(f"Failed to update agent's last_stop_reason for run {run_id}: {e}")
|
|
|
|
# update run metrics table
|
|
num_steps = len(await self.step_manager.list_steps_async(run_id=run_id, actor=actor))
|
|
|
|
# Collect tools used from run messages
|
|
tools_used = set()
|
|
messages = await self.message_manager.list_messages(actor=actor, run_id=run_id)
|
|
for message in messages:
|
|
if message.tool_calls:
|
|
for tool_call in message.tool_calls:
|
|
if hasattr(tool_call, "function") and hasattr(tool_call.function, "name"):
|
|
# Get tool ID from tool name
|
|
from letta.services.tool_manager import ToolManager
|
|
|
|
tool_manager = ToolManager()
|
|
tool_name = tool_call.function.name
|
|
tool_id = await tool_manager.get_tool_id_by_name_async(tool_name, actor)
|
|
if tool_id:
|
|
tools_used.add(tool_id)
|
|
|
|
async with db_registry.async_session() as session:
|
|
metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
|
|
# Calculate runtime if run is completing
|
|
if is_terminal_update:
|
|
# Use total_duration_ns from RunUpdate if provided
|
|
# Otherwise fall back to system time
|
|
if update.total_duration_ns is not None:
|
|
metrics.run_ns = update.total_duration_ns
|
|
elif metrics.run_start_ns:
|
|
import time
|
|
|
|
current_ns = int(time.time() * 1e9)
|
|
metrics.run_ns = current_ns - metrics.run_start_ns
|
|
metrics.num_steps = num_steps
|
|
metrics.tools_used = list(tools_used) if tools_used else None
|
|
await metrics.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
|
await session.commit()
|
|
|
|
# Dispatch callback outside of database session if needed
|
|
if needs_callback:
|
|
if refresh_result_messages:
|
|
result = LettaResponse(
|
|
messages=await self.get_run_messages(run_id=run_id, actor=actor),
|
|
stop_reason=LettaStopReason(stop_reason=pydantic_run.stop_reason),
|
|
usage=await self.get_run_usage(run_id=run_id, actor=actor),
|
|
)
|
|
final_metadata["result"] = result.model_dump()
|
|
callback_info = {
|
|
"run_id": run_id,
|
|
"callback_url": callback_url,
|
|
"status": update.status,
|
|
"completed_at": get_utc_time().replace(tzinfo=None),
|
|
"metadata": final_metadata,
|
|
}
|
|
callback_result = await self._dispatch_callback_async(callback_info)
|
|
|
|
# Update callback status in a separate transaction
|
|
async with db_registry.async_session() as session:
|
|
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor)
|
|
run.callback_sent_at = callback_result["callback_sent_at"]
|
|
run.callback_status_code = callback_result.get("callback_status_code")
|
|
run.callback_error = callback_result.get("callback_error")
|
|
pydantic_run = run.to_pydantic()
|
|
await run.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
|
await session.commit()
|
|
|
|
return pydantic_run
|
|
|
|
@trace_method
|
|
async def _dispatch_callback_async(self, callback_info: dict) -> dict:
|
|
"""
|
|
POST a standard JSON payload to callback_url and return callback status asynchronously.
|
|
"""
|
|
payload = {
|
|
"run_id": callback_info["run_id"],
|
|
"status": callback_info["status"],
|
|
"completed_at": callback_info["completed_at"].isoformat() if callback_info["completed_at"] else None,
|
|
"metadata": callback_info["metadata"],
|
|
}
|
|
|
|
callback_sent_at = get_utc_time().replace(tzinfo=None)
|
|
result = {"callback_sent_at": callback_sent_at}
|
|
|
|
try:
|
|
async with AsyncClient() as client:
|
|
log_event("POST callback dispatched", payload)
|
|
resp = await client.post(callback_info["callback_url"], json=payload, timeout=5.0)
|
|
log_event("POST callback finished")
|
|
result["callback_status_code"] = resp.status_code
|
|
except Exception as e:
|
|
error_message = f"Failed to dispatch callback for run {callback_info['run_id']} to {callback_info['callback_url']}: {e!r}"
|
|
logger.error(error_message)
|
|
result["callback_error"] = error_message
|
|
# Continue silently - callback failures should not affect run completion
|
|
|
|
return result
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
async def get_run_usage(self, run_id: str, actor: PydanticUser) -> LettaUsageStatistics:
|
|
"""Get usage statistics for a run."""
|
|
async with db_registry.async_session() as session:
|
|
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
|
|
if not run:
|
|
raise NoResultFound(f"Run with id {run_id} not found")
|
|
|
|
steps = await self.step_manager.list_steps_async(run_id=run_id, actor=actor)
|
|
total_usage = LettaUsageStatistics()
|
|
for step in steps:
|
|
total_usage.prompt_tokens += step.prompt_tokens
|
|
total_usage.completion_tokens += step.completion_tokens
|
|
total_usage.total_tokens += step.total_tokens
|
|
total_usage.step_count += 1
|
|
|
|
# Aggregate cache and reasoning tokens from detailed breakdowns using normalized helpers
|
|
# Handle None defaults: only set if we have data, accumulate if already set
|
|
cached_input, cache_write = normalize_cache_tokens(step.prompt_tokens_details)
|
|
if cached_input > 0 or total_usage.cached_input_tokens is not None:
|
|
total_usage.cached_input_tokens = (total_usage.cached_input_tokens or 0) + cached_input
|
|
if cache_write > 0 or total_usage.cache_write_tokens is not None:
|
|
total_usage.cache_write_tokens = (total_usage.cache_write_tokens or 0) + cache_write
|
|
reasoning = normalize_reasoning_tokens(step.completion_tokens_details)
|
|
if reasoning > 0 or total_usage.reasoning_tokens is not None:
|
|
total_usage.reasoning_tokens = (total_usage.reasoning_tokens or 0) + reasoning
|
|
|
|
return total_usage
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
async def get_run_messages(
|
|
self,
|
|
run_id: str,
|
|
actor: PydanticUser,
|
|
limit: Optional[int] = 100,
|
|
before: Optional[str] = None,
|
|
after: Optional[str] = None,
|
|
order: Literal["asc", "desc"] = "asc",
|
|
) -> List[LettaMessage]:
|
|
"""Get the result of a run."""
|
|
run = await self.get_run_by_id(run_id=run_id, actor=actor)
|
|
request_config = run.request_config
|
|
agent = await self.agent_manager.get_agent_by_id_async(agent_id=run.agent_id, actor=actor, include_relationships=[])
|
|
text_is_assistant_message = agent.agent_type == AgentType.letta_v1_agent
|
|
|
|
messages = await self.message_manager.list_messages(
|
|
actor=actor,
|
|
run_id=run_id,
|
|
limit=limit,
|
|
before=before,
|
|
after=after,
|
|
ascending=(order == "asc"),
|
|
)
|
|
letta_messages = PydanticMessage.to_letta_messages_from_list(
|
|
messages, reverse=(order != "asc"), text_is_assistant_message=text_is_assistant_message
|
|
)
|
|
|
|
if request_config and request_config.include_return_message_types:
|
|
include_return_message_types_set = set(request_config.include_return_message_types)
|
|
letta_messages = [msg for msg in letta_messages if msg.message_type in include_return_message_types_set]
|
|
|
|
return letta_messages
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
async def get_run_request_config(self, run_id: str, actor: PydanticUser) -> Optional[LettaRequestConfig]:
|
|
"""Get the letta request config from a run."""
|
|
async with db_registry.async_session() as session:
|
|
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
|
|
if not run:
|
|
raise NoResultFound(f"Run with id {run_id} not found")
|
|
pydantic_run = run.to_pydantic()
|
|
return pydantic_run.request_config
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
async def get_run_metrics_async(self, run_id: str, actor: PydanticUser) -> PydanticRunMetrics:
|
|
"""Get metrics for a run."""
|
|
async with db_registry.async_session() as session:
|
|
metrics = await RunMetricsModel.read_async(db_session=session, identifier=run_id, actor=actor)
|
|
return metrics.to_pydantic()
|
|
|
|
@enforce_types
|
|
@raise_on_invalid_id(param_name="run_id", expected_prefix=PrimitiveType.RUN)
|
|
async def get_run_steps(
|
|
self,
|
|
run_id: str,
|
|
actor: PydanticUser,
|
|
limit: Optional[int] = 100,
|
|
before: Optional[str] = None,
|
|
after: Optional[str] = None,
|
|
ascending: bool = False,
|
|
) -> List[PydanticStep]:
|
|
"""Get steps for a run."""
|
|
async with db_registry.async_session() as session:
|
|
run = await RunModel.read_async(db_session=session, identifier=run_id, actor=actor, access_type=AccessType.ORGANIZATION)
|
|
if not run:
|
|
raise NoResultFound(f"Run with id {run_id} not found")
|
|
|
|
steps = await self.step_manager.list_steps_async(
|
|
actor=actor, run_id=run_id, limit=limit, before=before, after=after, order="asc" if ascending else "desc"
|
|
)
|
|
return steps
|
|
|
|
@enforce_types
|
|
async def cancel_run(self, actor: PydanticUser, agent_id: Optional[str] = None, run_id: Optional[str] = None) -> None:
|
|
"""Cancel a run."""
|
|
|
|
# make sure run_id and agent_id are not both None
|
|
if not run_id:
|
|
# get the last agent run
|
|
if not agent_id:
|
|
raise ValueError("Agent ID is required to cancel a run by ID")
|
|
logger.warning("Cannot find run associated with agent to cancel in redis, fetching from db.")
|
|
run_ids = await self.list_runs(
|
|
actor=actor,
|
|
ascending=False,
|
|
agent_id=agent_id,
|
|
)
|
|
run_ids = [run.id for run in run_ids]
|
|
else:
|
|
# get the agent
|
|
run = await self.get_run_by_id(run_id=run_id, actor=actor)
|
|
if not run:
|
|
raise NoResultFound(f"Run with id {run_id} not found")
|
|
agent_id = run.agent_id
|
|
|
|
logger.debug(f"Cancelling run {run_id} for agent {agent_id}")
|
|
|
|
# check if run can be cancelled (cannot cancel a completed, failed, or cancelled run)
|
|
if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
|
|
logger.error(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}")
|
|
raise LettaInvalidArgumentError(
|
|
f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}"
|
|
)
|
|
|
|
# Check if agent is waiting for approval by examining the last message
|
|
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
|
current_in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor)
|
|
was_pending_approval = current_in_context_messages and current_in_context_messages[-1].is_approval_request()
|
|
|
|
# cancel the run
|
|
# NOTE: this should update the agent's last stop reason to cancelled
|
|
run = await self.update_run_by_id_async(
|
|
run_id=run_id, update=RunUpdate(status=RunStatus.cancelled, stop_reason=StopReasonType.cancelled), actor=actor
|
|
)
|
|
|
|
# cleanup the agent's state
|
|
# if was pending approval, we need to cleanup the approval state
|
|
if was_pending_approval:
|
|
logger.debug(f"Agent was waiting for approval, adding denial messages for run {run_id}")
|
|
approval_request_message = current_in_context_messages[-1]
|
|
|
|
# Ensure the approval request has tool calls to deny
|
|
if approval_request_message.tool_calls:
|
|
from letta.constants import TOOL_CALL_DENIAL_ON_CANCEL
|
|
from letta.schemas.letta_message import ApprovalReturn
|
|
from letta.schemas.message import ApprovalCreate
|
|
from letta.server.rest_api.utils import (
|
|
create_approval_response_message_from_input,
|
|
create_tool_message_from_returns,
|
|
create_tool_returns_for_denials,
|
|
)
|
|
|
|
# Create denials for ALL pending tool calls
|
|
denials = [
|
|
ApprovalReturn(
|
|
tool_call_id=tool_call.id,
|
|
approve=False,
|
|
reason=TOOL_CALL_DENIAL_ON_CANCEL,
|
|
)
|
|
for tool_call in approval_request_message.tool_calls
|
|
]
|
|
|
|
# Create an ApprovalCreate input with the denials
|
|
approval_input = ApprovalCreate(
|
|
approvals=denials,
|
|
approval_request_id=approval_request_message.id,
|
|
)
|
|
|
|
# Use the standard function to create properly formatted approval response messages
|
|
approval_response_messages = create_approval_response_message_from_input(
|
|
agent_state=agent_state,
|
|
input_message=approval_input,
|
|
run_id=run_id,
|
|
)
|
|
|
|
# Create tool returns for ALL denied tool calls using shared helper
|
|
# This handles all pending tool calls at once since they all have the same denial reason
|
|
tool_returns = create_tool_returns_for_denials(
|
|
tool_calls=approval_request_message.tool_calls, # ALL pending tool calls
|
|
denial_reason=TOOL_CALL_DENIAL_ON_CANCEL,
|
|
timezone=agent_state.timezone,
|
|
)
|
|
|
|
# Create tool message with all denial returns using shared helper
|
|
tool_message = create_tool_message_from_returns(
|
|
agent_id=agent_state.id,
|
|
model=agent_state.llm_config.model,
|
|
tool_returns=tool_returns,
|
|
run_id=run_id,
|
|
)
|
|
|
|
# Combine approval response and tool messages
|
|
new_messages = approval_response_messages + [tool_message]
|
|
|
|
# Checkpoint the new messages
|
|
from letta.agents.agent_loop import AgentLoop
|
|
|
|
agent_loop = AgentLoop.load(agent_state=agent_state, actor=actor)
|
|
new_in_context_messages = current_in_context_messages + new_messages
|
|
await agent_loop._checkpoint_messages(
|
|
run_id=run_id,
|
|
step_id=approval_request_message.step_id,
|
|
new_messages=new_messages,
|
|
in_context_messages=new_in_context_messages,
|
|
)
|
|
|
|
# persisted_messages = await self.message_manager.create_many_messages_async(
|
|
# pydantic_msgs=new_messages,
|
|
# actor=actor,
|
|
# run_id=run_id,
|
|
# )
|
|
# logger.debug(f"Persisted {len(persisted_messages)} messages (approval + tool returns)")
|
|
|
|
## Update the agent's message_ids to include the new messages (approval + tool message)
|
|
# agent_state.message_ids = agent_state.message_ids + [m.id for m in persisted_messages]
|
|
# await self.agent_manager.update_message_ids_async(agent_id=agent_state.id, message_ids=agent_state.message_ids, actor=actor)
|
|
|
|
logger.debug(
|
|
f"Inserted approval response with {len(denials)} denials and tool return message for cancelled run {run_id}. "
|
|
f"Approval request message ID: {approval_request_message.id}"
|
|
)
|
|
else:
|
|
logger.warning(
|
|
f"Last message is an approval request but has no tool_calls. "
|
|
f"Message ID: {approval_request_message.id}, Run ID: {run_id}"
|
|
)
|
|
|
|
return run
|