feat: add error chunk handling for background mode (#4158)
@@ -74,7 +74,6 @@ class RedisSSEStreamWriter:
         except asyncio.CancelledError:
             pass

         # Flush all remaining buffers
         for run_id in list(self.buffer.keys()):
             if self.buffer[run_id]:
                 await self._flush_run(run_id)
@@ -96,24 +95,20 @@ class RedisSSEStreamWriter:
         Returns:
             The sequence ID assigned to this chunk
         """
         # Assign sequence ID
         seq_id = self.seq_counters[run_id]
         self.seq_counters[run_id] += 1

         # Add to buffer
         chunk = {
             "seq_id": seq_id,
             "data": data,
             "timestamp": int(time.time() * 1000),
         }

         # Mark completion if this is the last chunk
         if is_complete:
             chunk["complete"] = "true"

         self.buffer[run_id].append(chunk)

         # Check if we should flush
         should_flush = (
             len(self.buffer[run_id]) >= self.flush_size or is_complete or (time.time() - self.last_flush[run_id]) > self.flush_interval
         )
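
The flush policy above fires on any of three triggers: buffer size, stream completion, or elapsed time since the last flush. Restated as a standalone predicate (a minimal sketch; the default thresholds here are assumptions, the real values live on the writer instance):

    import time

    def should_flush(buffered: int, is_complete: bool, last_flush_at: float,
                     flush_size: int = 10, flush_interval: float = 0.5) -> bool:
        """Mirror of the flush condition in write_chunk; thresholds are illustrative."""
        return (
            buffered >= flush_size
            or is_complete
            or (time.time() - last_flush_at) > flush_interval
        )

    # A final chunk always flushes, regardless of buffer size:
    assert should_flush(buffered=1, is_complete=True, last_flush_at=time.time())
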
@@ -135,25 +130,20 @@ class RedisSSEStreamWriter:
         try:
             client = await self.redis.get_client()

             # Use pipeline for batch writes
             async with client.pipeline(transaction=False) as pipe:
                 for chunk in chunks:
                     pipe.xadd(stream_key, chunk, maxlen=self.max_stream_length, approximate=True)

                 # Set/refresh TTL on the stream
                 pipe.expire(stream_key, self.stream_ttl)

                 # Execute all commands in one round trip
                 await pipe.execute()

             self.last_flush[run_id] = time.time()

             # Log successful flush
             logger.debug(
                 f"Flushed {len(chunks)} chunks to Redis stream {stream_key}, "
                 f"seq_ids {chunks[0]['seq_id']}-{chunks[-1]['seq_id']}"
             )

             # If this was a completion chunk, clean up tracking
             if chunks[-1].get("complete") == "true":
                 self._cleanup_run(run_id)
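
The flush batches all XADD calls plus one EXPIRE into a single pipeline round trip, capping the stream with approximate MAXLEN trimming. A consumer can drain the same stream with XREAD; a rough sketch against redis-py's asyncio client (the `sse:run:{run_id}` key format is an assumption, this diff only ever refers to `stream_key`):

    import asyncio

    import redis.asyncio as redis

    async def drain(run_id: str) -> None:
        r = redis.Redis()
        last_id = "0-0"  # start from the beginning of the stream
        while True:
            # Block up to 5s for entries newer than last_id
            resp = await r.xread({f"sse:run:{run_id}": last_id}, count=100, block=5000)
            if not resp:
                continue  # timed out; poll again
            for _key, entries in resp:
                for entry_id, fields in entries:
                    last_id = entry_id
                    print(fields[b"data"].decode())
                    if fields.get(b"complete") == b"true":
                        return  # writer marked the stream finished

    if __name__ == "__main__":
        asyncio.run(drain("run-123"))
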
@@ -224,10 +214,11 @@ async def create_background_stream_processor(
     try:
         async for chunk in stream_generator:
             # Check if this is the final chunk
-            is_done = "data: [DONE]" in chunk if isinstance(chunk, str) else False
             if isinstance(chunk, tuple):
                 chunk = chunk[0]
+
+            is_done = isinstance(chunk, str) and ("data: [DONE]" in chunk or "event: error" in chunk)

             # Write chunk to Redis
             await writer.write_chunk(run_id=run_id, data=chunk, is_complete=is_done)

             if is_done:
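
The reworked `is_done` treats an SSE error frame as end-of-stream alongside the usual `[DONE]` sentinel, so the writer marks the run complete in Redis on failures too. The substring test against some sample frames (frames invented for illustration):

    frames = [
        'data: {"token": "hi"}\n\n',                               # ordinary chunk -> False
        'event: error\ndata: {"error": {"message": "boom"}}\n\n',  # error frame    -> True
        'data: [DONE]\n\n',                                        # terminal frame -> True
    ]
    for chunk in frames:
        print(isinstance(chunk, str) and ("data: [DONE]" in chunk or "event: error" in chunk))
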
@@ -236,8 +227,8 @@ async def create_background_stream_processor(
     except Exception as e:
         logger.error(f"Error processing stream for run {run_id}: {e}")
         # Write error chunk
-        error_chunk = {"message_type": "error", "error": str(e)}
-        await writer.write_chunk(run_id=run_id, data=f"data: {json.dumps(error_chunk)}\n\n", is_complete=True)
+        error_chunk = {"error": {"message": str(e)}}
+        await writer.write_chunk(run_id=run_id, data=f"event: error\ndata: {json.dumps(error_chunk)}\n\n", is_complete=True)
     finally:
         if should_stop_writer:
             await writer.stop()
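
Emitting `event: error` makes the failure a named SSE event rather than an ordinary data frame, so clients can route it separately from token chunks. A hand-rolled parse of such a frame for illustration (real clients would lean on an SSE library):

    import json

    frame = 'event: error\ndata: {"error": {"message": "boom"}}\n\n'

    event_name, payload = "message", None  # "message" is the SSE default event type
    for line in frame.splitlines():
        if line.startswith("event: "):
            event_name = line[len("event: "):]
        elif line.startswith("data: "):
            payload = json.loads(line[len("data: "):])

    if event_name == "error":
        print(payload["error"]["message"])  # -> boom
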
@@ -288,15 +279,15 @@ async def redis_sse_stream_generator(
                 if chunk_seq_id >= cursor_seq_id:
                     data = fields.get("data", "")
                     if not data:
                         logger.debug(f"No data found for chunk {chunk_seq_id} in run {run_id}")
                         continue

-                    if data == "data: [DONE]\n\n":
-                        yield data
-                        return
                     yield data
                     yielded_any = True

+                    if fields.get("complete") == "true":
+                        return
+
             last_redis_id = entry_id

         if not yielded_any and len(entries) > 1:
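
Since every chunk carries a monotonically increasing `seq_id`, a reconnecting client can resume from the last sequence it saw: entries below the cursor are skipped, and the generator returns once it sees the `complete` flag that both normal termination and the new error path set. The filtering idea in isolation (a sketch; the exact field layout is inferred from the fields this diff reads):

    def replay(entries, cursor_seq_id):
        """Yield only chunks at or past the cursor; stop on a completed chunk."""
        for _entry_id, fields in entries:
            if int(fields.get("seq_id", -1)) < cursor_seq_id:
                continue  # already delivered before the reconnect
            data = fields.get("data", "")
            if data:
                yield data
            if fields.get("complete") == "true":
                return
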