fix: persist streaming errors in run metadata (#8062)

This commit is contained in:
Sarah Wooders
2025-12-26 16:12:35 -07:00
committed by Caren Thomas
parent c5c633285b
commit d5decc2a27
5 changed files with 122 additions and 38 deletions

View File

@@ -87,9 +87,13 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
raise self.llm_client.handle_llm_error(e)
# Process the stream and yield chunks immediately for TTFT
async for chunk in self.interface.process(stream): # TODO: add ttft span
# Yield each chunk immediately as it arrives
yield chunk
# Wrap in error handling to convert provider errors to common LLMError types
try:
async for chunk in self.interface.process(stream): # TODO: add ttft span
# Yield each chunk immediately as it arrives
yield chunk
except Exception as e:
raise self.llm_client.handle_llm_error(e)
# After streaming completes, extract the accumulated data
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()

View File

@@ -4,7 +4,9 @@ import asyncio
import json
import time
from collections import defaultdict
from typing import AsyncIterator, Dict, List, Optional
from collections.abc import AsyncGenerator, AsyncIterator
from contextlib import aclosing
from typing import Dict, List, Optional
from letta.data_sources.redis_client import AsyncRedisClient
from letta.log import get_logger
@@ -194,7 +196,7 @@ class RedisSSEStreamWriter:
async def create_background_stream_processor(
stream_generator,
stream_generator: AsyncGenerator[str | bytes | tuple[str | bytes, int], None],
redis_client: AsyncRedisClient,
run_id: str,
writer: Optional[RedisSSEStreamWriter] = None,
@@ -218,6 +220,7 @@ async def create_background_stream_processor(
stop_reason = None
saw_done = False
saw_error = False
error_metadata = None
if writer is None:
writer = RedisSSEStreamWriter(redis_client)
@@ -227,32 +230,52 @@ async def create_background_stream_processor(
should_stop_writer = False
try:
async for chunk in stream_generator:
if isinstance(chunk, tuple):
chunk = chunk[0]
# Always close the upstream async generator so its `finally` blocks run.
# (e.g., stream adapters may persist terminal error metadata on close)
async with aclosing(stream_generator):
async for chunk in stream_generator:
if isinstance(chunk, tuple):
chunk = chunk[0]
# Track terminal events
if isinstance(chunk, str):
if "data: [DONE]" in chunk:
saw_done = True
if "event: error" in chunk:
saw_error = True
# Track terminal events
if isinstance(chunk, str):
if "data: [DONE]" in chunk:
saw_done = True
if "event: error" in chunk:
saw_error = True
is_done = saw_done or saw_error
# Best-effort extraction of the error payload so we can persist it on the run.
# Chunk format is typically: "event: error\ndata: {json}\n\n"
if saw_error and error_metadata is None:
try:
# Grab the first `data:` line after `event: error`
for line in chunk.splitlines():
if line.startswith("data: "):
maybe_json = line[len("data: ") :].strip()
if maybe_json and maybe_json[0] in "[{":
error_metadata = {"error": json.loads(maybe_json)}
else:
error_metadata = {"error": {"message": maybe_json}}
break
except Exception:
# Don't let parsing failures interfere with streaming
error_metadata = {"error": {"message": "Failed to parse error payload from stream."}}
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
is_done = saw_done or saw_error
if is_done:
break
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
try:
# Extract stop_reason from stop_reason chunks
maybe_json_chunk = chunk.split("data: ")[1]
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
stop_reason = maybe_stop_reason.get("stop_reason")
except:
pass
if is_done:
break
try:
# Extract stop_reason from stop_reason chunks
maybe_json_chunk = chunk.split("data: ")[1]
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
stop_reason = maybe_stop_reason.get("stop_reason")
except:
pass
# Stream ended naturally - check if we got a proper terminal
if not saw_done and not saw_error:
@@ -357,11 +380,11 @@ async def create_background_stream_processor(
logger.warning(f"Unknown stop_reason '{final_stop_reason}' for run {run_id}, defaulting to completed")
run_status = RunStatus.completed
await run_manager.update_run_by_id_async(
run_id=run_id,
update=RunUpdate(status=run_status, stop_reason=final_stop_reason),
actor=actor,
)
update_kwargs = {"status": run_status, "stop_reason": final_stop_reason}
if run_status == RunStatus.failed and error_metadata is not None:
update_kwargs["metadata"] = error_metadata
await run_manager.update_run_by_id_async(run_id=run_id, update=RunUpdate(**update_kwargs), actor=actor)
# Belt-and-suspenders: always append a terminal [DONE] chunk to ensure clients terminate
# Even if a previous chunk set `complete`, an extra [DONE] is harmless and ensures SDKs that

View File

@@ -353,9 +353,16 @@ class RunManager:
logger.warning(f"Run {run_id} completed without a completed_at timestamp")
update.completed_at = get_utc_time().replace(tzinfo=None)
# Update job attributes with only the fields that were explicitly set
# Update run attributes with only the fields that were explicitly set
update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
# Merge metadata updates instead of overwriting.
# This is important for streaming/background flows where different components update
# different parts of metadata (e.g., run_type set at creation, error payload set at terminal).
if "metadata_" in update_data and isinstance(update_data["metadata_"], dict):
existing_metadata = run.metadata_ if isinstance(run.metadata_, dict) else {}
update_data["metadata_"] = {**existing_metadata, **update_data["metadata_"]}
# Automatically update the completion timestamp if status is set to 'completed'
for key, value in update_data.items():
# Ensure completed_at is timezone-naive for database compatibility
@@ -616,9 +623,7 @@ class RunManager:
# Cancellation should be idempotent: if a run is already terminated, treat this as a no-op.
# This commonly happens when a run finishes between client request and server handling.
if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
logger.debug(
f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}"
)
logger.debug(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}")
return
# Check if agent is waiting for approval by examining the last message

View File

@@ -0,0 +1,50 @@
import anthropic
import httpx
import pytest
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
from letta.errors import LLMServerError
from letta.llm_api.anthropic_client import AnthropicClient
from letta.schemas.llm_config import LLMConfig
@pytest.mark.asyncio
async def test_letta_llm_stream_adapter_converts_anthropic_streaming_api_status_error(monkeypatch):
    """Regression: provider APIStatusError raised *during* streaming iteration should be converted via handle_llm_error."""
    http_request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
    http_response = httpx.Response(status_code=500, request=http_request)
    error_body = {
        "type": "error",
        "error": {"details": None, "type": "api_error", "message": "Internal server error"},
        "request_id": "req_011CWSBmrUwW5xdcqjfkUFS4",
    }

    class FakeAsyncStream:
        """Mimics anthropic.AsyncStream enough for AnthropicStreamingInterface (async cm + async iterator)."""

        async def __aenter__(self):
            return self

        async def __aexit__(self, exc_type, exc, tb):
            return None

        def __aiter__(self):
            return self

        async def __anext__(self):
            # Fail on the first iteration step to simulate a provider error mid-stream,
            # after the request itself has already succeeded.
            raise anthropic.APIStatusError("INTERNAL_SERVER_ERROR", response=http_response, body=error_body)

    async def fake_stream_async(self, request_data: dict, llm_config: LLMConfig):
        return FakeAsyncStream()

    # Swap the real Anthropic streaming call for the failing stub.
    monkeypatch.setattr(AnthropicClient, "stream_async", fake_stream_async, raising=True)

    llm_client = AnthropicClient()
    llm_config = LLMConfig(model="claude-sonnet-4-5-20250929", model_endpoint_type="anthropic", context_window=200000)
    adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config)

    chunk_iter = adapter.invoke_llm(request_data={}, messages=[], tools=[], use_assistant_message=True)
    # The adapter must translate the raw provider exception into the common LLMError hierarchy.
    with pytest.raises(LLMServerError):
        async for _ in chunk_iter:
            pass

View File

@@ -218,16 +218,18 @@ async def test_update_run_metadata_persistence(server: SyncServer, sarah_agent,
actor=default_user,
)
# Verify metadata was properly updated
# Verify metadata was properly updated (metadata should merge, not overwrite)
assert updated_run.status == RunStatus.failed
assert updated_run.stop_reason == StopReasonType.llm_api_error
assert updated_run.metadata == error_data
assert updated_run.metadata["type"] == "test"
assert updated_run.metadata["initial"] == "value"
assert "error" in updated_run.metadata
assert updated_run.metadata["error"]["type"] == "llm_timeout"
# Fetch the run again to ensure it's persisted in DB
fetched_run = await server.run_manager.get_run_by_id(created_run.id, actor=default_user)
assert fetched_run.metadata == error_data
assert fetched_run.metadata["type"] == "test"
assert fetched_run.metadata["initial"] == "value"
assert "error" in fetched_run.metadata
assert fetched_run.metadata["error"]["type"] == "llm_timeout"