feat: add gating around agent run cancellation (#3446)

This commit is contained in:
cthomas
2025-07-20 23:46:26 -07:00
committed by GitHub
parent f2ce39b5c7
commit c85dca4ec8
3 changed files with 67 additions and 52 deletions

View File

@@ -705,28 +705,32 @@ async def send_message(
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
# Create a new run for execution tracking
job_status = JobStatus.created
run = await server.job_manager.create_job_async(
pydantic_job=Run(
user_id=actor.id,
status=job_status,
metadata={
"job_type": "send_message",
"agent_id": agent_id,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
if settings.track_agent_run:
job_status = JobStatus.created
run = await server.job_manager.create_job_async(
pydantic_job=Run(
user_id=actor.id,
status=job_status,
metadata={
"job_type": "send_message",
"agent_id": agent_id,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
),
),
),
actor=actor,
)
actor=actor,
)
else:
run = None
job_update_metadata = None
# TODO (cliandy): clean this up
redis_client = await get_redis_client()
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
try:
if agent_eligible and model_compatible:
@@ -741,7 +745,7 @@ async def send_message(
job_manager=server.job_manager,
actor=actor,
group=agent.multi_agent_group,
current_run_id=run.id,
current_run_id=run.id if run else None,
)
else:
agent_loop = LettaAgent(
@@ -754,7 +758,7 @@ async def send_message(
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
current_run_id=run.id,
current_run_id=run.id if run else None,
# summarizer settings to be added here
summarizer_mode=(
SummarizationMode.STATIC_MESSAGE_BUFFER
@@ -790,12 +794,13 @@ async def send_message(
job_status = JobStatus.failed
raise
finally:
await server.job_manager.safe_update_job_status_async(
job_id=run.id,
new_status=job_status,
actor=actor,
metadata=job_update_metadata,
)
if settings.track_agent_run:
await server.job_manager.safe_update_job_status_async(
job_id=run.id,
new_status=job_status,
actor=actor,
metadata=job_update_metadata,
)
# noinspection PyInconsistentReturns
@@ -836,29 +841,32 @@ async def send_message_streaming(
not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT
# Create a new job for execution tracking
job_status = JobStatus.created
run = await server.job_manager.create_job_async(
pydantic_job=Run(
user_id=actor.id,
status=job_status,
metadata={
"job_type": "send_message_streaming",
"agent_id": agent_id,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
if settings.track_agent_run:
job_status = JobStatus.created
run = await server.job_manager.create_job_async(
pydantic_job=Run(
user_id=actor.id,
status=job_status,
metadata={
"job_type": "send_message_streaming",
"agent_id": agent_id,
},
request_config=LettaRequestConfig(
use_assistant_message=request.use_assistant_message,
assistant_message_tool_name=request.assistant_message_tool_name,
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
include_return_message_types=request.include_return_message_types,
),
),
),
actor=actor,
)
actor=actor,
)
else:
run = None
job_update_metadata = None
# TODO (cliandy): clean this up
redis_client = await get_redis_client()
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
try:
if agent_eligible and model_compatible:
@@ -875,7 +883,7 @@ async def send_message_streaming(
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
group=agent.multi_agent_group,
current_run_id=run.id,
current_run_id=run.id if run else None,
)
else:
agent_loop = LettaAgent(
@@ -888,7 +896,7 @@ async def send_message_streaming(
actor=actor,
step_manager=server.step_manager,
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
current_run_id=run.id,
current_run_id=run.id if run else None,
# summarizer settings to be added here
summarizer_mode=(
SummarizationMode.STATIC_MESSAGE_BUFFER
@@ -941,12 +949,13 @@ async def send_message_streaming(
job_status = JobStatus.failed
raise
finally:
await server.job_manager.safe_update_job_status_async(
job_id=run.id,
new_status=job_status,
actor=actor,
metadata=job_update_metadata,
)
if settings.track_agent_run:
await server.job_manager.safe_update_job_status_async(
job_id=run.id,
new_status=job_status,
actor=actor,
metadata=job_update_metadata,
)
@router.post("/{agent_id}/messages/cancel", operation_id="cancel_agent_run")
@@ -963,6 +972,8 @@ async def cancel_agent_run(
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
if not settings.track_agent_run:
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
if not run_ids:
redis_client = await get_redis_client()
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")

View File

@@ -7,6 +7,7 @@ from letta.schemas.enums import JobStatus
from letta.schemas.job import Job
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
from letta.settings import settings
router = APIRouter(prefix="/jobs", tags=["jobs"])
@@ -93,6 +94,8 @@ async def cancel_job(
agent execution to terminate as soon as possible.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
if not settings.track_agent_run:
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
try:
# First check if the job exists and is in a cancellable state

View File

@@ -235,6 +235,7 @@ class Settings(BaseSettings):
track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics")
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
track_agent_run: bool = Field(default=True, description="Enable tracking agent run with cancellation support")
# FastAPI Application Settings
uvicorn_workers: int = 1