fix: fix poison state from bad approval response (#5979)

* fix: detect and fail on malformed approval responses

* fix: guard against None approvals in utils.py

* fix: add extra warning

* fix: stop silent drops in deserialize_approvals

* fix: patch v3 stream error handling to prevent sending end_turn after an error occurs, and ensure stop_reason is always set when an error occurs

* fix: prevent infinite client hangs by ensuring a terminal event is ALWAYS sent

* fix: ensure terminal events are sent even if the inner stream generator fails to send them
This commit is contained in:
Charles Packer
2025-11-06 20:53:00 -08:00
committed by Caren Thomas
parent 4acda9c80f
commit 363a5c1f92
5 changed files with 189 additions and 40 deletions

View File

@@ -307,6 +307,15 @@ class LettaAgentV3(LettaAgentV2):
except Exception as e:
self.logger.warning(f"Error during agent stream: {e}", exc_info=True)
# Set stop_reason if not already set
if self.stop_reason is None:
# Classify error type
if isinstance(e, LLMError):
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.llm_api_error.value)
else:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
if first_chunk:
# Raise if no chunks sent yet (response not started, can return error status code)
raise
@@ -321,24 +330,48 @@ class LettaAgentV3(LettaAgentV2):
}
yield f"event: error\ndata: {json.dumps(error_chunk)}\n\n"
if run_id:
letta_messages = Message.to_letta_messages_from_list(
self.response_messages_for_metadata, # Use separate accumulator to preserve all messages
use_assistant_message=False, # NOTE: set to false
reverse=False,
# text_is_assistant_message=(self.agent_state.agent_type == AgentType.react_agent),
text_is_assistant_message=True,
)
result = LettaResponse(messages=letta_messages, stop_reason=self.stop_reason, usage=self.usage)
if self.job_update_metadata is None:
self.job_update_metadata = {}
self.job_update_metadata["result"] = result.model_dump(mode="json")
# Return immediately - don't fall through to finish chunks
# This prevents sending end_turn finish chunks after an error
return
await self._request_checkpoint_finish(
request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
)
for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
yield f"data: {finish_chunk}\n\n"
# Cleanup and finalize (only runs if no exception occurred)
try:
if run_id:
letta_messages = Message.to_letta_messages_from_list(
self.response_messages_for_metadata, # Use separate accumulator to preserve all messages
use_assistant_message=False, # NOTE: set to false
reverse=False,
# text_is_assistant_message=(self.agent_state.agent_type == AgentType.react_agent),
text_is_assistant_message=True,
)
result = LettaResponse(messages=letta_messages, stop_reason=self.stop_reason, usage=self.usage)
if self.job_update_metadata is None:
self.job_update_metadata = {}
self.job_update_metadata["result"] = result.model_dump(mode="json")
await self._request_checkpoint_finish(
request_span=request_span, request_start_timestamp_ns=request_start_timestamp_ns, run_id=run_id
)
for finish_chunk in self.get_finish_chunks_for_stream(self.usage, self.stop_reason):
yield f"data: {finish_chunk}\n\n"
except Exception as cleanup_error:
# Error during cleanup/finalization - ensure we still send a terminal event
self.logger.error(f"Error during stream cleanup: {cleanup_error}", exc_info=True)
# Set stop_reason if not already set
if self.stop_reason is None:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
# Send error event
error_chunk = {
"error": {
"type": "cleanup_error",
"message": "An error occurred during stream finalization.",
"detail": str(cleanup_error),
}
}
yield f"event: error\ndata: {json.dumps(error_chunk)}\n\n"
# Note: we don't send finish chunks here since we already errored
@trace_method
async def _step(
@@ -434,6 +467,18 @@ class LettaAgentV3(LettaAgentV2):
if approval_response.approvals:
tool_returns = [r for r in approval_response.approvals if isinstance(r, ToolReturn)]
# Validate that the approval response contains meaningful data
# If all three lists are empty, this is a malformed approval response
if not tool_calls and not tool_call_denials and not tool_returns:
self.logger.error(
f"Invalid approval response: approval_response.approvals is {approval_response.approvals} "
f"but no tool calls, denials, or returns were extracted. "
f"This likely indicates a corrupted or malformed approval payload."
)
self.should_continue = False
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.invalid_tool_call.value)
return
step_id = approval_request.step_id
step_metrics = await self.step_manager.get_step_metrics_async(step_id=step_id, actor=self.actor)
else:

View File

@@ -254,14 +254,33 @@ def deserialize_approvals(data: Optional[List[Dict]]) -> List[Union[ApprovalRetu
return []
approvals = []
for item in data:
if "type" in item and item.get("type") == MessageReturnType.approval:
approval_return = ApprovalReturn(**item)
approvals.append(approval_return)
elif "status" in item:
tool_return = ToolReturn(**item)
approvals.append(tool_return)
else:
for idx, item in enumerate(data):
try:
# Check for ApprovalReturn (has type="approval")
if "type" in item and item.get("type") == MessageReturnType.approval:
approval_return = ApprovalReturn(**item)
approvals.append(approval_return)
# Check for ToolReturn (has status field)
elif "status" in item:
# Handle field name variations (tool_return vs func_response)
if "tool_return" in item and "func_response" not in item:
# Client SDK uses "tool_return", internal uses "func_response"
item = {**item, "func_response": item["tool_return"]}
tool_return = ToolReturn(**item)
approvals.append(tool_return)
else:
# Unknown format - log warning with diagnostic info
# Truncate large fields for logging
item_preview = {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in item.items()}
logger.warning(
f"deserialize_approvals: Skipping unrecognized approval item at index {idx}. "
f"Item preview: {item_preview}. Expected 'type=approval' or 'status' field."
)
continue
except Exception as e:
# Log validation errors but continue processing other items
item_preview = {k: (v[:100] + "..." if isinstance(v, str) and len(v) > 100 else v) for k, v in item.items()}
logger.warning(f"deserialize_approvals: Failed to deserialize approval item at index {idx}: {e}. Item preview: {item_preview}")
continue
return approvals

View File

@@ -215,6 +215,9 @@ async def create_background_stream_processor(
actor: Optional actor for run status updates
"""
stop_reason = None
saw_done = False
saw_error = False
if writer is None:
writer = RedisSSEStreamWriter(redis_client)
await writer.start()
@@ -227,7 +230,14 @@ async def create_background_stream_processor(
if isinstance(chunk, tuple):
chunk = chunk[0]
is_done = isinstance(chunk, str) and ("data: [DONE]" in chunk or "event: error" in chunk)
# Track terminal events
if isinstance(chunk, str):
if "data: [DONE]" in chunk:
saw_done = True
if "event: error" in chunk:
saw_error = True
is_done = saw_done or saw_error
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
@@ -235,7 +245,7 @@ async def create_background_stream_processor(
break
try:
# sorry for this
# Extract stop_reason from stop_reason chunks
maybe_json_chunk = chunk.split("data: ")[1]
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
@@ -243,40 +253,89 @@ async def create_background_stream_processor(
except:
pass
# Stream ended naturally - check if we got a proper terminal
if not saw_done and not saw_error:
# Stream ended without terminal event - synthesize one
logger.warning(
f"Stream for run {run_id} ended without terminal event (no [DONE] or event:error). "
f"Last stop_reason seen: {stop_reason}. Synthesizing terminal."
)
if stop_reason:
# We have a stop_reason, send [DONE]
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
saw_done = True
else:
# No stop_reason and no terminal - this is an error condition
error_chunk = {"error": "Stream ended unexpectedly without stop_reason", "code": "STREAM_INCOMPLETE"}
await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=False)
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
saw_error = True
saw_done = True
# Set a default stop_reason so run status can be mapped in finally
stop_reason = StopReasonType.error.value
except RunCancelledException as e:
# Handle cancellation gracefully - don't write error chunk, cancellation event was already sent
logger.info(f"Stream processing stopped due to cancellation for run {run_id}")
# The cancellation event was already yielded by cancellation_aware_stream_wrapper
# Write [DONE] marker to properly close the stream for clients reading from Redis
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
saw_done = True
except Exception as e:
logger.error(f"Error processing stream for run {run_id}: {e}")
# Write error chunk
# error_chunk = {"error": {"message": str(e)}}
# Mark run_id terminal state
error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"}
await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=False)
await writer.write_chunk(run_id=run_id, data="data: [DONE]\n\n", is_complete=True)
saw_error = True
saw_done = True
# Mark run as failed immediately
if run_manager and actor:
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=RunStatus.failed, stop_reason=StopReasonType.error.value, metadata={"error": str(e)}),
actor=actor,
)
error_chunk = {"error": str(e), "code": "INTERNAL_SERVER_ERROR"}
await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=True)
finally:
if should_stop_writer:
await writer.stop()
if run_manager and actor:
if stop_reason == "cancelled":
# Update run status if not already set (e.g., by exception handler)
if run_manager and actor and stop_reason:
# Map stop_reason to run status
# Error states -> failed
if stop_reason in [
StopReasonType.error.value,
StopReasonType.llm_api_error.value,
StopReasonType.invalid_tool_call.value,
StopReasonType.invalid_llm_response.value,
StopReasonType.no_tool_call.value,
]:
run_status = RunStatus.failed
# Cancelled state
elif stop_reason == StopReasonType.cancelled.value:
run_status = RunStatus.cancelled
# Success states -> completed
elif stop_reason in [
StopReasonType.end_turn.value,
StopReasonType.max_steps.value,
StopReasonType.tool_rule.value,
StopReasonType.requires_approval.value,
]:
run_status = RunStatus.completed
else:
# Unknown stop_reason - default to completed but log warning
logger.warning(f"Unknown stop_reason '{stop_reason}' for run {run_id}, defaulting to completed")
run_status = RunStatus.completed
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=run_status, stop_reason=stop_reason or StopReasonType.end_turn.value),
actor=actor,
)
# Only update if we saw a clean terminal (don't overwrite failed status set in except block)
if not saw_error or run_status != RunStatus.completed:
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=run_status, stop_reason=stop_reason),
actor=actor,
)
async def redis_sse_stream_generator(

View File

@@ -186,6 +186,14 @@ def create_approval_response_message_from_input(
)
return maybe_tool_return
# Guard against None approvals - treat as empty list to avoid TypeError
approvals_list = input_message.approvals or []
if input_message.approvals is None:
logger.warning(
"ApprovalCreate.approvals is None; treating as empty list (approval_request_id=%s)",
getattr(input_message, "approval_request_id", None),
)
return [
Message(
role=MessageRole.approval,
@@ -194,7 +202,7 @@ def create_approval_response_message_from_input(
approval_request_id=input_message.approval_request_id,
approve=input_message.approve,
denial_reason=input_message.reason,
approvals=[maybe_convert_tool_return_message(approval) for approval in input_message.approvals],
approvals=[maybe_convert_tool_return_message(approval) for approval in approvals_list],
run_id=run_id,
group_id=input_message.group_id
if input_message.group_id

View File

@@ -301,6 +301,8 @@ class StreamingService:
run_update_metadata = None
stop_reason = None
error_data = None
saw_done = False
saw_error = False
try:
stream = agent_loop.stream(
@@ -314,8 +316,24 @@ class StreamingService:
)
async for chunk in stream:
# Track terminal events
if isinstance(chunk, str):
if "data: [DONE]" in chunk:
saw_done = True
if "event: error" in chunk:
saw_error = True
yield chunk
# Stream completed - check if we got a terminal event
if not saw_done and not saw_error:
# Stream ended without terminal - synthesize one
logger.warning(
f"Stream for run {run_id} ended without terminal event. "
f"Agent stop_reason: {agent_loop.stop_reason}. Synthesizing [DONE]."
)
yield "data: [DONE]\n\n"
saw_done = True
# set run status after successful completion
if agent_loop.stop_reason.stop_reason.value == "cancelled":
run_status = RunStatus.cancelled