Files
letta-server/memgpt/server/rest_api/utils.py
2024-08-16 19:53:21 -07:00

53 lines
1.7 KiB
Python

import json
import traceback
from enum import Enum
from typing import AsyncGenerator, Union
from pydantic import BaseModel
from memgpt.constants import JSON_ENSURE_ASCII
# Server-Sent Events framing: each event is written as SSE_PREFIX + payload +
# SSE_SUFFIX, i.e. a "data: " line terminated by a blank line.
SSE_PREFIX: str = "data: "
SSE_SUFFIX: str = "\n\n"

# Terminal payload emitted after the last event (mimics OpenAI's streaming API).
SSE_FINISH_MSG: str = "[DONE]"

# Artificial delay value; not referenced in this file chunk.
# NOTE(review): presumably seconds, used by streaming callers elsewhere — confirm.
SSE_ARTIFICIAL_DELAY: float = 0.1
def sse_formatter(data: Union[dict, str]) -> str:
    """Format a payload as a single Server-Sent Events message.

    The payload is prefixed with ``SSE_PREFIX`` ("data: ") and terminated with
    ``SSE_SUFFIX`` (double newline), so every message ends with a blank line.

    Args:
        data: A dict (JSON-encoded with ``ensure_ascii=JSON_ENSURE_ASCII``) or a
            str (emitted verbatim). Subclasses of dict/str are accepted too.

    Returns:
        The framed SSE message string.

    Raises:
        TypeError: If ``data`` is neither a dict nor a str.
    """
    # Explicit check instead of ``assert``: assertions are removed under
    # ``python -O``, which would let unsupported objects through unvalidated.
    if isinstance(data, dict):
        data_str = json.dumps(data, ensure_ascii=JSON_ENSURE_ASCII)
    elif isinstance(data, str):
        data_str = data
    else:
        raise TypeError(f"Expected type dict or str, got type {type(data)}")
    return f"{SSE_PREFIX}{data_str}{SSE_SUFFIX}"
async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
    """
    Wrap an async generator so each item it yields is emitted as a Server-Sent Event.

    Each chunk is normalised before formatting: pydantic models are converted
    with ``model_dump()``, Enum members become ``str(member.value)``, dicts pass
    through unchanged, and anything else is converted with ``str()``. If the
    wrapped generator raises an ``Exception``, the error and its traceback are
    printed and a single ``{"error": ...}`` event is emitted instead of
    propagating the exception.

    Args:
    - generator: An asynchronous generator yielding data chunks.
    - finish_message: If True, emit a final ``SSE_FINISH_MSG`` event after the
      stream ends, whether it completed normally or after a handled error.
    Yields:
    - Formatted Server-Sent Event strings.
    """
    try:
        async for chunk in generator:
            if isinstance(chunk, BaseModel):
                chunk = chunk.model_dump()
            elif isinstance(chunk, Enum):
                chunk = str(chunk.value)
            elif not isinstance(chunk, dict):
                chunk = str(chunk)
            yield sse_formatter(chunk)
    except Exception as e:
        print("stream decoder hit error:", e)
        # print_exc() prints the traceback of the exception being handled.
        # (print_stack() would print the current call stack and return None.)
        traceback.print_exc()
        yield sse_formatter({"error": "stream decoder encountered an error"})

    # Deliberately NOT in a ``finally``: if the consumer closes this generator
    # (GeneratorExit, e.g. on client disconnect) or the task is cancelled,
    # yielding during cleanup raises "async generator ignored GeneratorExit" or
    # swallows the cancellation. Those are BaseExceptions, so they skip the
    # handler above and also skip this final event.
    if finish_message:
        # Signal that the stream is complete
        yield sse_formatter(SSE_FINISH_MSG)