fix: persist streaming errors in run metadata (#8062)
This commit is contained in:
committed by
Caren Thomas
parent
c5c633285b
commit
d5decc2a27
@@ -87,9 +87,13 @@ class LettaLLMStreamAdapter(LettaLLMAdapter):
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
|
||||
# Process the stream and yield chunks immediately for TTFT
|
||||
async for chunk in self.interface.process(stream): # TODO: add ttft span
|
||||
# Yield each chunk immediately as it arrives
|
||||
yield chunk
|
||||
# Wrap in error handling to convert provider errors to common LLMError types
|
||||
try:
|
||||
async for chunk in self.interface.process(stream): # TODO: add ttft span
|
||||
# Yield each chunk immediately as it arrives
|
||||
yield chunk
|
||||
except Exception as e:
|
||||
raise self.llm_client.handle_llm_error(e)
|
||||
|
||||
# After streaming completes, extract the accumulated data
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
|
||||
@@ -4,7 +4,9 @@ import asyncio
|
||||
import json
|
||||
import time
|
||||
from collections import defaultdict
|
||||
from typing import AsyncIterator, Dict, List, Optional
|
||||
from collections.abc import AsyncGenerator, AsyncIterator
|
||||
from contextlib import aclosing
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from letta.data_sources.redis_client import AsyncRedisClient
|
||||
from letta.log import get_logger
|
||||
@@ -194,7 +196,7 @@ class RedisSSEStreamWriter:
|
||||
|
||||
|
||||
async def create_background_stream_processor(
|
||||
stream_generator,
|
||||
stream_generator: AsyncGenerator[str | bytes | tuple[str | bytes, int], None],
|
||||
redis_client: AsyncRedisClient,
|
||||
run_id: str,
|
||||
writer: Optional[RedisSSEStreamWriter] = None,
|
||||
@@ -218,6 +220,7 @@ async def create_background_stream_processor(
|
||||
stop_reason = None
|
||||
saw_done = False
|
||||
saw_error = False
|
||||
error_metadata = None
|
||||
|
||||
if writer is None:
|
||||
writer = RedisSSEStreamWriter(redis_client)
|
||||
@@ -227,32 +230,52 @@ async def create_background_stream_processor(
|
||||
should_stop_writer = False
|
||||
|
||||
try:
|
||||
async for chunk in stream_generator:
|
||||
if isinstance(chunk, tuple):
|
||||
chunk = chunk[0]
|
||||
# Always close the upstream async generator so its `finally` blocks run.
|
||||
# (e.g., stream adapters may persist terminal error metadata on close)
|
||||
async with aclosing(stream_generator):
|
||||
async for chunk in stream_generator:
|
||||
if isinstance(chunk, tuple):
|
||||
chunk = chunk[0]
|
||||
|
||||
# Track terminal events
|
||||
if isinstance(chunk, str):
|
||||
if "data: [DONE]" in chunk:
|
||||
saw_done = True
|
||||
if "event: error" in chunk:
|
||||
saw_error = True
|
||||
# Track terminal events
|
||||
if isinstance(chunk, str):
|
||||
if "data: [DONE]" in chunk:
|
||||
saw_done = True
|
||||
if "event: error" in chunk:
|
||||
saw_error = True
|
||||
|
||||
is_done = saw_done or saw_error
|
||||
# Best-effort extraction of the error payload so we can persist it on the run.
|
||||
# Chunk format is typically: "event: error\ndata: {json}\n\n"
|
||||
if saw_error and error_metadata is None:
|
||||
try:
|
||||
# Grab the first `data:` line after `event: error`
|
||||
for line in chunk.splitlines():
|
||||
if line.startswith("data: "):
|
||||
maybe_json = line[len("data: ") :].strip()
|
||||
if maybe_json and maybe_json[0] in "[{":
|
||||
error_metadata = {"error": json.loads(maybe_json)}
|
||||
else:
|
||||
error_metadata = {"error": {"message": maybe_json}}
|
||||
break
|
||||
except Exception:
|
||||
# Don't let parsing failures interfere with streaming
|
||||
error_metadata = {"error": {"message": "Failed to parse error payload from stream."}}
|
||||
|
||||
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
|
||||
is_done = saw_done or saw_error
|
||||
|
||||
if is_done:
|
||||
break
|
||||
await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)
|
||||
|
||||
try:
|
||||
# Extract stop_reason from stop_reason chunks
|
||||
maybe_json_chunk = chunk.split("data: ")[1]
|
||||
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
|
||||
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
|
||||
stop_reason = maybe_stop_reason.get("stop_reason")
|
||||
except:
|
||||
pass
|
||||
if is_done:
|
||||
break
|
||||
|
||||
try:
|
||||
# Extract stop_reason from stop_reason chunks
|
||||
maybe_json_chunk = chunk.split("data: ")[1]
|
||||
maybe_stop_reason = json.loads(maybe_json_chunk) if maybe_json_chunk and maybe_json_chunk[0] == "{" else None
|
||||
if maybe_stop_reason and maybe_stop_reason.get("message_type") == "stop_reason":
|
||||
stop_reason = maybe_stop_reason.get("stop_reason")
|
||||
except:
|
||||
pass
|
||||
|
||||
# Stream ended naturally - check if we got a proper terminal
|
||||
if not saw_done and not saw_error:
|
||||
@@ -357,11 +380,11 @@ async def create_background_stream_processor(
|
||||
logger.warning(f"Unknown stop_reason '{final_stop_reason}' for run {run_id}, defaulting to completed")
|
||||
run_status = RunStatus.completed
|
||||
|
||||
await run_manager.update_run_by_id_async(
|
||||
run_id=run_id,
|
||||
update=RunUpdate(status=run_status, stop_reason=final_stop_reason),
|
||||
actor=actor,
|
||||
)
|
||||
update_kwargs = {"status": run_status, "stop_reason": final_stop_reason}
|
||||
if run_status == RunStatus.failed and error_metadata is not None:
|
||||
update_kwargs["metadata"] = error_metadata
|
||||
|
||||
await run_manager.update_run_by_id_async(run_id=run_id, update=RunUpdate(**update_kwargs), actor=actor)
|
||||
|
||||
# Belt-and-suspenders: always append a terminal [DONE] chunk to ensure clients terminate
|
||||
# Even if a previous chunk set `complete`, an extra [DONE] is harmless and ensures SDKs that
|
||||
|
||||
@@ -353,9 +353,16 @@ class RunManager:
|
||||
logger.warning(f"Run {run_id} completed without a completed_at timestamp")
|
||||
update.completed_at = get_utc_time().replace(tzinfo=None)
|
||||
|
||||
# Update job attributes with only the fields that were explicitly set
|
||||
# Update run attributes with only the fields that were explicitly set
|
||||
update_data = update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Merge metadata updates instead of overwriting.
|
||||
# This is important for streaming/background flows where different components update
|
||||
# different parts of metadata (e.g., run_type set at creation, error payload set at terminal).
|
||||
if "metadata_" in update_data and isinstance(update_data["metadata_"], dict):
|
||||
existing_metadata = run.metadata_ if isinstance(run.metadata_, dict) else {}
|
||||
update_data["metadata_"] = {**existing_metadata, **update_data["metadata_"]}
|
||||
|
||||
# Automatically update the completion timestamp if status is set to 'completed'
|
||||
for key, value in update_data.items():
|
||||
# Ensure completed_at is timezone-naive for database compatibility
|
||||
@@ -616,9 +623,7 @@ class RunManager:
|
||||
# Cancellation should be idempotent: if a run is already terminated, treat this as a no-op.
|
||||
# This commonly happens when a run finishes between client request and server handling.
|
||||
if run.stop_reason and run.stop_reason not in [StopReasonType.requires_approval]:
|
||||
logger.debug(
|
||||
f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}"
|
||||
)
|
||||
logger.debug(f"Run {run_id} cannot be cancelled because it is already terminated with stop reason: {run.stop_reason.value}")
|
||||
return
|
||||
|
||||
# Check if agent is waiting for approval by examining the last message
|
||||
|
||||
@@ -0,0 +1,50 @@
|
||||
import anthropic
|
||||
import httpx
|
||||
import pytest
|
||||
|
||||
from letta.adapters.letta_llm_stream_adapter import LettaLLMStreamAdapter
|
||||
from letta.errors import LLMServerError
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_letta_llm_stream_adapter_converts_anthropic_streaming_api_status_error(monkeypatch):
|
||||
"""Regression: provider APIStatusError raised *during* streaming iteration should be converted via handle_llm_error."""
|
||||
|
||||
request = httpx.Request("POST", "https://api.anthropic.com/v1/messages")
|
||||
response = httpx.Response(status_code=500, request=request)
|
||||
body = {
|
||||
"type": "error",
|
||||
"error": {"details": None, "type": "api_error", "message": "Internal server error"},
|
||||
"request_id": "req_011CWSBmrUwW5xdcqjfkUFS4",
|
||||
}
|
||||
|
||||
class FakeAsyncStream:
|
||||
"""Mimics anthropic.AsyncStream enough for AnthropicStreamingInterface (async cm + async iterator)."""
|
||||
|
||||
async def __aenter__(self):
|
||||
return self
|
||||
|
||||
async def __aexit__(self, exc_type, exc, tb):
|
||||
return None
|
||||
|
||||
def __aiter__(self):
|
||||
return self
|
||||
|
||||
async def __anext__(self):
|
||||
raise anthropic.APIStatusError("INTERNAL_SERVER_ERROR", response=response, body=body)
|
||||
|
||||
async def fake_stream_async(self, request_data: dict, llm_config: LLMConfig):
|
||||
return FakeAsyncStream()
|
||||
|
||||
monkeypatch.setattr(AnthropicClient, "stream_async", fake_stream_async, raising=True)
|
||||
|
||||
llm_client = AnthropicClient()
|
||||
llm_config = LLMConfig(model="claude-sonnet-4-5-20250929", model_endpoint_type="anthropic", context_window=200000)
|
||||
adapter = LettaLLMStreamAdapter(llm_client=llm_client, llm_config=llm_config)
|
||||
|
||||
gen = adapter.invoke_llm(request_data={}, messages=[], tools=[], use_assistant_message=True)
|
||||
with pytest.raises(LLMServerError):
|
||||
async for _ in gen:
|
||||
pass
|
||||
@@ -218,16 +218,18 @@ async def test_update_run_metadata_persistence(server: SyncServer, sarah_agent,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify metadata was properly updated
|
||||
# Verify metadata was properly updated (metadata should merge, not overwrite)
|
||||
assert updated_run.status == RunStatus.failed
|
||||
assert updated_run.stop_reason == StopReasonType.llm_api_error
|
||||
assert updated_run.metadata == error_data
|
||||
assert updated_run.metadata["type"] == "test"
|
||||
assert updated_run.metadata["initial"] == "value"
|
||||
assert "error" in updated_run.metadata
|
||||
assert updated_run.metadata["error"]["type"] == "llm_timeout"
|
||||
|
||||
# Fetch the run again to ensure it's persisted in DB
|
||||
fetched_run = await server.run_manager.get_run_by_id(created_run.id, actor=default_user)
|
||||
assert fetched_run.metadata == error_data
|
||||
assert fetched_run.metadata["type"] == "test"
|
||||
assert fetched_run.metadata["initial"] == "value"
|
||||
assert "error" in fetched_run.metadata
|
||||
assert fetched_run.metadata["error"]["type"] == "llm_timeout"
|
||||
|
||||
|
||||
Reference in New Issue
Block a user