fix: fix poison state from bad approval response (#5979)
* fix: detect and fail on malformed approval responses * fix: guard against None approvals in utils.py * fix: add extra warning * fix: stop silent drops in deserialize_approvals * fix: patch v3 stream error handling to prevent sending end_turn after an error occurs, and ensure stop_reason is always set when an error occurs * fix: prevent infinite client hangs by ensuring a terminal event is ALWAYS sent * fix: ensure terminal events are sent even if the inner stream generator fails to send them
This commit is contained in:
committed by
Caren Thomas
parent
4acda9c80f
commit
363a5c1f92
@@ -307,6 +307,15 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Error during agent stream: {e}", exc_info=True)
|
||||
|
||||
# Set stop_reason if not already set
|
||||
if self.stop_reason is None:
|
||||
# Classify error type
|
||||
if isinstance(e, LLMError):
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error.value)
|
||||
else:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
|
||||
if first_chunk:
|
||||
# Raise if no chunks sent yet (response not started, can return error status code)
|
||||
raise
|
||||
@@ -321,24 +330,48 @@ class LettaAgentV3(LettaAgentV2):
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_chunk)}\n\n"
|
||||
|
||||
if run_id:
|
||||
letta_messages = Message.to_letta_messages_from_list(
|
||||
self.response_messages_for_metadata, # Use separate accumulator to preserve all messages
|
||||
use_assistant_message=False, # NOTE: set to false
|
||||
reverse=False,
|
||||
# text_is_assistant_message=(self.agent_state.agent_type == AgentType.react_agent),
|
||||
text_is_assistant_message=True,
|
||||
)
|
||||
result = LettaResponse(messages=letta_messages, stop_reason=self.stop_reason, usage=self.usage)
|
||||
if self.job_update_metadata is None:
|
||||
self.job_update_metadata = {}
|
||||
self.job_update_metadata["result"] = result.model_dump(mode="json")
|
||||
# Return immediately - don't fall through to finish chunks
|
||||
# This prevents sending end_turn finish chunks after an error
|
||||
return
|
||||
|
||||
await self._request_checkpoint_finish(
|
||||
request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
|
||||
)
|
||||
for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
|
||||
yield f"data: {finish_chunk}\n\n"
|
||||
# Cleanup and finalize (only runs if no exception occurred)
|
||||
try:
|
||||
if run_id:
|
||||
letta_messages = Message.to_letta_messages_from_list(
|
||||
self.response_messages_for_metadata, # Use separate accumulator to preserve all messages
|
||||
use_assistant_message=False, # NOTE: set to false
|
||||
reverse=False,
|
||||
# text_is_assistant_message=(self.agent_state.agent_type == AgentType.react_agent),
|
||||
text_is_assistant_message=True,
|
||||
)
|
||||
result = LettaResponse(messages=letta_messages, stop_reason=self.stop_reason, usage=self.usage)
|
||||
if self.job_update_metadata is None:
|
||||
self.job_update_metadata = {}
|
||||
self.job_update_metadata["result"] = result.model_dump(mode="json")
|
||||
|
||||
await self._request_checkpoint_finish(
|
||||
request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
|
||||
)
|
||||
for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
|
||||
yield f"data: {finish_chunk}\n\n"
|
||||
except Exception as cleanup_error:
|
||||
# Error during cleanup/finalization - ensure we still send a terminal event
|
||||
self.logger.error(f"Error during stream cleanup: {cleanup_error}", exc_info=True)
|
||||
|
||||
# Set stop_reason if not already set
|
||||
if self.stop_reason is None:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
|
||||
# Send error event
|
||||
error_chunk = {
|
||||
"error": {
|
||||
"type": "cleanup_error",
|
||||
"message": "An error occurred during stream finalization.",
|
||||
"detail": str(cleanup_error),
|
||||
}
|
||||
}
|
||||
yield f"event: error\ndata: {json.dumps(error_chunk)}\n\n"
|
||||
# Note: we don't send finish chunks here since we already errored
|
||||
|
||||
@trace_method
|
||||
async def _step(
|
||||
@@ -434,6 +467,18 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if approval_response.approvals:
|
||||
tool_returns = [r for r in approval_response.approvals if isinstance(r, ToolReturn)]
|
||||
|
||||
# Validate that the approval response contains meaningful data
|
||||
# If all three lists are empty, this is a malformed approval response
|
||||
if not tool_calls and not tool_call_denials and not tool_returns:
|
||||
self.logger.error(
|
||||
f"Invalid approval response: approval_response.approvals is {approval_response.approvals} "
|
||||
f"but no tool calls, denials, or returns were extracted. "
|
||||
f"This likely indicates a corrupted or malformed approval payload."
|
||||
)
|
||||
self.should_continue = False
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
|
||||
return
|
||||
|
||||
step_id = approval_request.step_id
|
||||
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
|
||||
else:
|
||||
|
||||
@@ -254,14 +254,33 @@ def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalRetu
|
||||
return []
|
||||
|
||||
approvals = []
|
||||
for item in data:
|
||||
if "type" in item and item.get("type") == MessageReturnType.approval:
|
||||
approval_return = ApprovalReturn(**item)
|
||||
approvals.append(approval_return)
|
||||
elif "status" in item:
|
||||
tool_return = ToolReturn(**item)
|
||||
approvals.append(tool_return)
|
||||
else:
|
||||
for idx, item in enumerate(data):
|
||||
try:
|
||||
# Check for ApprovalReturn (has type="approval")
|
||||
if "type" in item and item.get("type") == MessageReturnType.approval:
|
||||
approval_return = ApprovalReturn(**item)
|
||||
approvals.append(approval_return)
|
||||
# Check for ToolReturn (has status field)
|
||||
elif "status" in item:
|
||||
# Handle field name variations (tool_return vs func_response)
|
||||
if "tool_return" in item and "func_response" not in item:
|
||||
# Client SDK uses "tool_return", internal uses "func_response"
|
||||
item = {**item, "func_response": item["tool_return"]}
|
||||
tool_return = ToolReturn(**item)
|
||||
approvals.append(tool_return)
|
||||
else:
|
||||
# Unknown format - log warning with diagnostic info
|
||||
# Truncate large fields for logging
|
||||
item_preview = {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in item.items()}
|
||||
logger.warning(
|
||||
f"deserialize_approvals: Skipping unrecognized approval item at index {idx}. "
|
||||
f"Item preview: {item_preview}. Expected 'type=approval' or 'status' field."
|
||||
)
|
||||
continue
|
||||
except Exception as e:
|
||||
# Log validation errors but continue processing other items
|
||||
item_preview = {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in item.items()}
|
||||
logger.warning(f"deserialize_approvals: Failed to deserialize approval item at index {idx}: {e}. Item preview: {item_preview}")
|
||||
continue
|
||||
|
||||
return approvals
|
||||
|
||||
@@ -215,6 +215,9 @@ async def create_background_stream_processor(
|
||||
actor: Optional actor for run status updates
|
||||
"""
|
||||
stop_reason = None
|
||||
saw_done = False
|
||||
saw_error = False
|
||||
|
||||
if writer is None:
|
||||
writer = RedisSSEStreamWriter(redis_client)
|
||||
await writer.start()
|
||||
@@ -227,7 +230,14 @@ async def create_background_stream_processor(
|
||||
if isinstance(chunk, tuple):
|
||||
chunk = chunk[0]
|
||||
|
||||
is_done = isinstance(chunk, str) and ("data: [DONE]" in chunk or "event: error" in chunk)
|
||||
# Track terminal events
|
||||
if isinstance(chunk, str):
|
||||
if "data: [DONE]" in chunk:
|
||||
saw_done = True
|
||||
if "event: error" in chunk:
|
||||
saw_error = True
|
||||
|
||||
is_done = saw_done or saw_error
|
||||
|
||||
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
|
||||
|
||||
@@ -235,7 +245,7 @@ async def create_background_stream_processor(
|
||||
break
|
||||
|
||||
try:
|
||||
# sorry for this
|
||||
# Extract stop_reason from stop_reason chunks
|
||||
maybe_json_chunk = chunk.split("data: ")[1]
|
||||
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
|
||||
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
|
||||
@@ -243,40 +253,89 @@ async def create_background_stream_processor(
|
||||
except:
|
||||
pass
|
||||
|
||||
# Stream ended naturally - check if we got a proper terminal
|
||||
if not saw_done and not saw_error:
|
||||
# Stream ended without terminal event - synthesize one
|
||||
logger.warning(
|
||||
f"Stream for run {run_id} ended without terminal event (no [DONE] or event:error). "
|
||||
f"Last stop_reason seen: {stop_reason}. Synthesizing terminal."
|
||||
)
|
||||
if stop_reason:
|
||||
# We have a stop_reason, send [DONE]
|
||||
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
|
||||
saw_done = True
|
||||
else:
|
||||
# No stop_reason and no terminal - this is an error condition
|
||||
error_chunk = {"error": "Stream ended unexpectedly without stop_reason", "code": "STREAM_INCOMPLETE"}
|
||||
await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=False)
|
||||
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
|
||||
saw_error = True
|
||||
saw_done = True
|
||||
# Set a default stop_reason so run status can be mapped in finally
|
||||
stop_reason = StopReasonType.error.value
|
||||
|
||||
except RunCancelledException as e:
|
||||
# Handle cancellation gracefully - don't write error chunk, cancellation event was already sent
|
||||
logger.info(f"Stream processing stopped due to cancellation for run {run_id}")
|
||||
# The cancellation event was already yielded by cancellation_aware_stream_wrapper
|
||||
# Write [DONE] marker to properly close the stream for clients reading from Redis
|
||||
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
|
||||
saw_done = True
|
||||
except Exception as e:
|
||||
logger.error(f"Error processing stream for run {run_id}: {e}")
|
||||
# Write error chunk
|
||||
# error_chunk = {"error": {"message": str(e)}}
|
||||
# Mark run_id terminal state
|
||||
error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"}
|
||||
await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=False)
|
||||
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
|
||||
saw_error = True
|
||||
saw_done = True
|
||||
|
||||
# Mark run as failed immediately
|
||||
if run_manager and actor:
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value, metadata={"error": str(e)}),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"}
|
||||
await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=True)
|
||||
finally:
|
||||
if should_stop_writer:
|
||||
await writer.stop()
|
||||
if run_manager and actor:
|
||||
if stop_reason == "cancelled":
|
||||
|
||||
# Update run status if not already set (e.g., by exception handler)
|
||||
if run_manager and actor and stop_reason:
|
||||
# Map stop_reason to run status
|
||||
# Error states -> failed
|
||||
if stop_reason in [
|
||||
StopReasonType.error.value,
|
||||
StopReasonType.llm_api_error.value,
|
||||
StopReasonType.invalid_tool_call.value,
|
||||
StopReasonType.invalid_llm_response.value,
|
||||
StopReasonType.no_tool_call.value,
|
||||
]:
|
||||
run_status = RunStatus.failed
|
||||
# Cancelled state
|
||||
elif stop_reason == StopReasonType.cancelled.value:
|
||||
run_status = RunStatus.cancelled
|
||||
# Success states -> completed
|
||||
elif stop_reason in [
|
||||
StopReasonType.end_turn.value,
|
||||
StopReasonType.max_steps.value,
|
||||
StopReasonType.tool_rule.value,
|
||||
StopReasonType.requires_approval.value,
|
||||
]:
|
||||
run_status = RunStatus.completed
|
||||
else:
|
||||
# Unknown stop_reason - default to completed but log warning
|
||||
logger.warning(f"Unknown stop_reason '{stop_reason}' for run {run_id}, defaulting to completed")
|
||||
run_status = RunStatus.completed
|
||||
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=run_status, stop_reason=stop_reason or StopReasonType.end_turn.value),
|
||||
actor=actor,
|
||||
)
|
||||
# Only update if we saw a clean terminal (don't overwrite failed status set in except block)
|
||||
if not saw_error or run_status != RunStatus.completed:
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=run_status, stop_reason=stop_reason),
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
|
||||
async def redis_sse_stream_generator(
|
||||
|
||||
@@ -186,6 +186,14 @@ def create_approval_response_message_from_input(
|
||||
)
|
||||
return maybe_tool_return
|
||||
|
||||
# Guard against None approvals - treat as empty list to avoid TypeError
|
||||
approvals_list = input_message.approvals or []
|
||||
if input_message.approvals is None:
|
||||
logger.warning(
|
||||
"ApprovalCreate.approvals is None; treating as empty list (approval_request_id=%s)",
|
||||
getattr(input_message, "approval_request_id", None),
|
||||
)
|
||||
|
||||
return [
|
||||
Message(
|
||||
role=MessageRole.approval,
|
||||
@@ -194,7 +202,7 @@ def create_approval_response_message_from_input(
|
||||
approval_request_id=input_message.approval_request_id,
|
||||
approve=input_message.approve,
|
||||
denial_reason=input_message.reason,
|
||||
approvals=[maybe_convert_tool_return_message(approval) for approval in input_message.approvals],
|
||||
approvals=[maybe_convert_tool_return_message(approval) for approval in approvals_list],
|
||||
run_id=run_id,
|
||||
group_id=input_message.group_id
|
||||
if input_message.group_id
|
||||
|
||||
@@ -301,6 +301,8 @@ class StreamingService:
|
||||
run_update_metadata = None
|
||||
stop_reason = None
|
||||
error_data = None
|
||||
saw_done = False
|
||||
saw_error = False
|
||||
|
||||
try:
|
||||
stream = agent_loop.stream(
|
||||
@@ -314,8 +316,24 @@ class StreamingService:
|
||||
)
|
||||
|
||||
async for chunk in stream:
|
||||
# Track terminal events
|
||||
if isinstance(chunk, str):
|
||||
if "data: [DONE]" in chunk:
|
||||
saw_done = True
|
||||
if "event: error" in chunk:
|
||||
saw_error = True
|
||||
yield chunk
|
||||
|
||||
# Stream completed - check if we got a terminal event
|
||||
if not saw_done and not saw_error:
|
||||
# Stream ended without terminal - synthesize one
|
||||
logger.warning(
|
||||
f"Stream for run {run_id} ended without terminal event. "
|
||||
f"Agent stop_reason: {agent_loop.stop_reason}. Synthesizing [DONE]."
|
||||
)
|
||||
yield "data: [DONE]\n\n"
|
||||
saw_done = True
|
||||
|
||||
# set run status after successful completion
|
||||
if agent_loop.stop_reason.stop_reason.value == "cancelled":
|
||||
run_status = RunStatus.cancelled
|
||||
|
||||
Reference in New Issue
Block a user