feat: add gating around agent run cancellation (#3446)
This commit is contained in:
@@ -705,28 +705,32 @@ async def send_message(
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "together", "google_ai", "google_vertex", "bedrock"]
|
||||
|
||||
# Create a new run for execution tracking
|
||||
job_status = JobStatus.created
|
||||
run = await server.job_manager.create_job_async(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
metadata={
|
||||
"job_type": "send_message",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
if settings.track_agent_run:
|
||||
job_status = JobStatus.created
|
||||
run = await server.job_manager.create_job_async(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
metadata={
|
||||
"job_type": "send_message",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
actor=actor,
|
||||
)
|
||||
else:
|
||||
run = None
|
||||
|
||||
job_update_metadata = None
|
||||
# TODO (cliandy): clean this up
|
||||
redis_client = await get_redis_client()
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
|
||||
|
||||
try:
|
||||
if agent_eligible and model_compatible:
|
||||
@@ -741,7 +745,7 @@ async def send_message(
|
||||
job_manager=server.job_manager,
|
||||
actor=actor,
|
||||
group=agent.multi_agent_group,
|
||||
current_run_id=run.id,
|
||||
current_run_id=run.id if run else None,
|
||||
)
|
||||
else:
|
||||
agent_loop = LettaAgent(
|
||||
@@ -754,7 +758,7 @@ async def send_message(
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
current_run_id=run.id,
|
||||
current_run_id=run.id if run else None,
|
||||
# summarizer settings to be added here
|
||||
summarizer_mode=(
|
||||
SummarizationMode.STATIC_MESSAGE_BUFFER
|
||||
@@ -790,12 +794,13 @@ async def send_message(
|
||||
job_status = JobStatus.failed
|
||||
raise
|
||||
finally:
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
metadata=job_update_metadata,
|
||||
)
|
||||
if settings.track_agent_run:
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
metadata=job_update_metadata,
|
||||
)
|
||||
|
||||
|
||||
# noinspection PyInconsistentReturns
|
||||
@@ -836,29 +841,32 @@ async def send_message_streaming(
|
||||
not_letta_endpoint = agent.llm_config.model_endpoint != LETTA_MODEL_ENDPOINT
|
||||
|
||||
# Create a new run for execution tracking
|
||||
job_status = JobStatus.created
|
||||
run = await server.job_manager.create_job_async(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
metadata={
|
||||
"job_type": "send_message_streaming",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
if settings.track_agent_run:
|
||||
job_status = JobStatus.created
|
||||
run = await server.job_manager.create_job_async(
|
||||
pydantic_job=Run(
|
||||
user_id=actor.id,
|
||||
status=job_status,
|
||||
metadata={
|
||||
"job_type": "send_message_streaming",
|
||||
"agent_id": agent_id,
|
||||
},
|
||||
request_config=LettaRequestConfig(
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_tool_name=request.assistant_message_tool_name,
|
||||
assistant_message_tool_kwarg=request.assistant_message_tool_kwarg,
|
||||
include_return_message_types=request.include_return_message_types,
|
||||
),
|
||||
),
|
||||
),
|
||||
actor=actor,
|
||||
)
|
||||
actor=actor,
|
||||
)
|
||||
else:
|
||||
run = None
|
||||
|
||||
job_update_metadata = None
|
||||
# TODO (cliandy): clean this up
|
||||
redis_client = await get_redis_client()
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id)
|
||||
await redis_client.set(f"{REDIS_RUN_ID_PREFIX}:{agent_id}", run.id if run else None)
|
||||
|
||||
try:
|
||||
if agent_eligible and model_compatible:
|
||||
@@ -875,7 +883,7 @@ async def send_message_streaming(
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
group=agent.multi_agent_group,
|
||||
current_run_id=run.id,
|
||||
current_run_id=run.id if run else None,
|
||||
)
|
||||
else:
|
||||
agent_loop = LettaAgent(
|
||||
@@ -888,7 +896,7 @@ async def send_message_streaming(
|
||||
actor=actor,
|
||||
step_manager=server.step_manager,
|
||||
telemetry_manager=server.telemetry_manager if settings.llm_api_logging else NoopTelemetryManager(),
|
||||
current_run_id=run.id,
|
||||
current_run_id=run.id if run else None,
|
||||
# summarizer settings to be added here
|
||||
summarizer_mode=(
|
||||
SummarizationMode.STATIC_MESSAGE_BUFFER
|
||||
@@ -941,12 +949,13 @@ async def send_message_streaming(
|
||||
job_status = JobStatus.failed
|
||||
raise
|
||||
finally:
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
metadata=job_update_metadata,
|
||||
)
|
||||
if settings.track_agent_run:
|
||||
await server.job_manager.safe_update_job_status_async(
|
||||
job_id=run.id,
|
||||
new_status=job_status,
|
||||
actor=actor,
|
||||
metadata=job_update_metadata,
|
||||
)
|
||||
|
||||
|
||||
@router.post("/{agent_id}/messages/cancel", operation_id="cancel_agent_run")
|
||||
@@ -963,6 +972,8 @@ async def cancel_agent_run(
|
||||
"""
|
||||
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
if not settings.track_agent_run:
|
||||
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
|
||||
if not run_ids:
|
||||
redis_client = await get_redis_client()
|
||||
run_id = await redis_client.get(f"{REDIS_RUN_ID_PREFIX}:{agent_id}")
|
||||
|
||||
@@ -7,6 +7,7 @@ from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
router = APIRouter(prefix="/jobs", tags=["jobs"])
|
||||
|
||||
@@ -93,6 +94,8 @@ async def cancel_job(
|
||||
agent execution to terminate as soon as possible.
|
||||
"""
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
if not settings.track_agent_run:
|
||||
raise HTTPException(status_code=400, detail="Agent run tracking is disabled")
|
||||
|
||||
try:
|
||||
# First check if the job exists and is in a cancellable state
|
||||
|
||||
@@ -235,6 +235,7 @@ class Settings(BaseSettings):
|
||||
track_last_agent_run: bool = Field(default=False, description="Update last agent run metrics")
|
||||
track_errored_messages: bool = Field(default=True, description="Enable tracking for errored messages")
|
||||
track_stop_reason: bool = Field(default=True, description="Enable tracking stop reason on steps.")
|
||||
track_agent_run: bool = Field(default=True, description="Enable tracking agent run with cancellation support")
|
||||
|
||||
# FastAPI Application Settings
|
||||
uvicorn_workers: int = 1
|
||||
|
||||
Reference in New Issue
Block a user