From d5decc2a271c6f5b0159ab2675c333361d9c5f0e Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Fri, 26 Dec 2025 16:12:35 -0700 Subject: [PATCH] fix: persist streaming errors in run metadata (#8062) --- letta/adapters/letta_llm_stream_adapter.py | 10 ++- letta/server/rest_api/redis_stream_manager.py | 79 ++++++++++++------- letta/services/run_manager.py | 13 ++- ...letta_llm_stream_adapter_error_handling.py | 50 ++++++++++++ tests/managers/test_run_manager.py | 8 +- 5 files changed, 122 insertions(+), 38 deletions(-) create mode 100644 tests/adapters/test_letta_llm_stream_adapter_error_handling.py diff --git a/letta/adapters/letta_llm_stream_adapter.py b/letta/adapters/letta_llm_stream_adapter.py index 0c3c4ae2..1b4d7fe9 100644 --- a/letta/adapters/letta_llm_stream_adapter.py +++ b/letta/adapters/letta_llm_stream_adapter.py @@ -87,9 +87,13 @@ class LettaLLMStreamAdapter(LettaLLMAdapter): raise self.llm_client.handle_llm_error(e) # Process the stream and yield chunks immediately for TTFT - async for chunk in self.interface.process(stream): # TODO: add ttft span - # Yield each chunk immediately as it arrives - yield chunk + # Wrap in error handling to convert provider errors to common LLMError types + try: + async for chunk in self.interface.process(stream): # TODO: add ttft span + # Yield each chunk immediately as it arrives + yield chunk + except Exception as e: + raise self.llm_client.handle_llm_error(e) # After streaming completes, extract the accumulated data self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns() diff --git a/letta/server/rest_api/redis_stream_manager.py b/letta/server/rest_api/redis_stream_manager.py index cd1385ff..9fae3710 100644 --- a/letta/server/rest_api/redis_stream_manager.py +++ b/letta/server/rest_api/redis_stream_manager.py @@ -4,7 +4,9 @@ import asyncio import json import time from collections import defaultdict -from typing import AsyncIterator, Dict, List, Optional +from collections.abc import AsyncGenerator, AsyncIterator +from contextlib import aclosing +from typing import Dict, List, Optional from letta.data_sources.redis_client import AsyncRedisClient from letta.log import get_logger @@ -194,7 +196,7 @@ class RedisSSEStreamWriter: async def create_background_stream_processor( - stream_generator, + stream_generator: AsyncGenerator[str | bytes | tuple[str | bytes, int], None], redis_client: AsyncRedisClient, run_id: str, writer: Optional[RedisSSEStreamWriter] = None, @@ -218,6 +220,7 @@ async def create_background_stream_processor( stop_reason = None saw_done = False saw_error = False + error_metadata = None if writer is None: writer = RedisSSEStreamWriter(redis_client) @@ -227,32 +230,52 @@ async def create_background_stream_processor( should_stop_writer = False try: - async for chunk in stream_generator: - if isinstance(chunk, tuple): - chunk = chunk[0] + # Always close the upstream async generator so its `finally` blocks run. + # (e.g., stream adapters may persist terminal error metadata on close) + async with aclosing(stream_generator): + async for chunk in stream_generator: + if isinstance(chunk, tuple): + chunk = chunk[0] - # Track terminal events - if isinstance(chunk, str): - if "data: [DONE]" in chunk: - saw_done = True - if "event: error" in chunk: - saw_error = True + # Track terminal events + if isinstance(chunk, str): + if "data: [DONE]" in chunk: + saw_done = True + if "event: error" in chunk: + saw_error = True - is_done = saw_done or saw_error + # Best-effort extraction of the error payload so we can persist it on the run. + # Chunk format is typically: "event: error\ndata: {json}\n\n" + if saw_error and error_metadata is None: + try: + # Grab the first `data:` line after `event: error` + for line in chunk.splitlines(): + if line.startswith("data: "): + maybe_json = line[len("data: ") :].strip() + if maybe_json and maybe_json[0] in "[{": + error_metadata = {"error": json.loads(maybe_json)} + else: + error_metadata = {"error": {"message": maybe_json}} + break + except Exception: + # Don't let parsing failures interfere with streaming + error_metadata = {"error": {"message": "Failed to parse error payload from stream."}} - await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done) + is_done = saw_done or saw_error - if is_done: - break + await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done) - try: - # Extract stop_reason from stop_reason chunks - maybe_json_chunk = chunk.split("data: ")[1] - maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None - if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason": - stop_reason = maybe_stop_reason.get("stop_reason") - except: - pass + if is_done: + break + + try: + # Extract stop_reason from stop_reason chunks + maybe_json_chunk = chunk.split("data: ")[1] + maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None + if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason": + stop_reason = maybe_stop_reason.get("stop_reason") + except: + pass # Stream ended naturally - check if we got a proper terminal if not saw_done and not saw_error: @@ -357,11 +380,11 @@ async def create_background_stream_processor( logger.warning(f"Unknown stop_reason '{final_stop_reason}' for run {run_id}, defaulting to completed") run_status = RunStatus.completed - await run_manager.update_run_by_id_async( - run_id=run_id, - update=RunUpdate(status=run_status, stop_reason=final_stop_reason), - actor=actor, - ) + update_kwargs = {"status": run_status, "stop_reason": final_stop_reason} + if run_status == RunStatus.failed and error_metadata is not None: + update_kwargs["metadata"] = error_metadata + + await run_manager.update_run_by_id_async(run_id=run_id, update=RunUpdate(**update_kwargs), actor=actor) # Belt-and-suspenders: always append a terminal [DONE] chunk to ensure clients terminate # Even if a previous chunk set `complete`, an extra [DONE] is harmless and ensures SDKs that diff --git a/letta/services/run_manager.py b/letta/services/run_manager.py index 754cfebe..d76e4f2f 100644 --- a/letta/services/run_manager.py +++ b/letta/services/run_manager.py @@ -353,9 +353,16 @@ class RunManager: logger.warning(f"Run {run_id} completed without a completed_at timestamp") update.completed_at = get_utc_time().replace(tzinfo=None) - # Update job attributes with only the fields that were explicitly set + # Update run attributes with only the fields that were explicitly set update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + # Merge metadata updates instead of overwriting. + # This is important for streaming/background flows where different components update + # different parts of metadata (e.g., run_type set at creation, error payload set at terminal). + if "metadata_" in update_data and isinstance(update_data["metadata_"], dict): + existing_metadata = run.metadata_ if isinstance(run.metadata_, dict) else {} + update_data["metadata_"] = {**existing_metadata, **update_data["metadata_"]} + # Automatically update the completion timestamp if status is set to 'completed' for key, value in update_data.items(): # Ensure completed_at is timezone-naive for database compatibility @@ -616,9 +623,7 @@ class RunManager: # Cancellation should be idempotent: if a run is already terminated, treat this as a no-op. # This commonly happens when a run finishes between client request and server handling. if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]: - logger.debug( - f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}" - ) + logger.debug(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}") return # Check if agent is waiting for approval by examining the last message diff --git a/tests/adapters/test_letta_llm_stream_adapter_error_handling.py b/tests/adapters/test_letta_llm_stream_adapter_error_handling.py new file mode 100644 index 00000000..30cffb2c --- /dev/null +++ b/tests/adapters/test_letta_llm_stream_adapter_error_handling.py @@ -0,0 +1,50 @@ +import anthropic +import httpx +import pytest + +from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter +from letta.errors import LLMServerError +from letta.llm_api.anthropic_client import AnthropicClient +from letta.schemas.llm_config import LLMConfig + + +@pytest.mark.asyncio +async def test_letta_llm_stream_adapter_converts_anthropic_streaming_api_status_error(monkeypatch): + """Regression: provider APIStatusError raised *during* streaming iteration should be converted via handle_llm_error.""" + + request = httpx.Request("POST", "https://api.anthropic.com/v1/messages") + response = httpx.Response(status_code=500, request=request) + body = { + "type": "error", + "error": {"details": None, "type": "api_error", "message": "Internal server error"}, + "request_id": "req_011CWSBmrUwW5xdcqjfkUFS4", + } + + class FakeAsyncStream: + """Mimics anthropic.AsyncStream enough for AnthropicStreamingInterface (async cm + async iterator).""" + + async def __aenter__(self): + return self + + async def __aexit__(self, exc_type, exc, tb): + return None + + def __aiter__(self): + return self + + async def __anext__(self): + raise anthropic.APIStatusError("INTERNAL_SERVER_ERROR", response=response, body=body) + + async def fake_stream_async(self, request_data: dict, llm_config: LLMConfig): + return FakeAsyncStream() + + monkeypatch.setattr(AnthropicClient, "stream_async", fake_stream_async, raising=True) + + llm_client = AnthropicClient() + llm_config = LLMConfig(model="claude-sonnet-4-5-20250929", model_endpoint_type="anthropic", context_window=200000) + adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config) + + gen = adapter.invoke_llm(request_data={}, messages=[], tools=[], use_assistant_message=True) + with pytest.raises(LLMServerError): + async for _ in gen: + pass diff --git a/tests/managers/test_run_manager.py b/tests/managers/test_run_manager.py index 01c4b236..bde73cfb 100644 --- a/tests/managers/test_run_manager.py +++ b/tests/managers/test_run_manager.py @@ -218,16 +218,18 @@ async def test_update_run_metadata_persistence(server: SyncServer, sarah_agent, actor=default_user, ) - # Verify metadata was properly updated + # Verify metadata was properly updated (metadata should merge, not overwrite) assert updated_run.status == RunStatus.failed assert updated_run.stop_reason == StopReasonType.llm_api_error - assert updated_run.metadata == error_data + assert updated_run.metadata["type"] == "test" + assert updated_run.metadata["initial"] == "value" assert "error" in updated_run.metadata assert updated_run.metadata["error"]["type"] == "llm_timeout" # Fetch the run again to ensure it's persisted in DB fetched_run = await server.run_manager.get_run_by_id(created_run.id, actor=default_user) - assert fetched_run.metadata == error_data + assert fetched_run.metadata["type"] == "test" + assert fetched_run.metadata["initial"] == "value" assert "error" in fetched_run.metadata assert fetched_run.metadata["error"]["type"] == "llm_timeout"