chore: migrate package name to letta (#1775)
Co-authored-by: Charles Packer <packercharles@gmail.com> Co-authored-by: Shubham Naik <shubham.naik10@gmail.com> Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
69
letta/server/rest_api/utils.py
Normal file
69
letta/server/rest_api/utils.py
Normal file
@@ -0,0 +1,69 @@
|
||||
import json
|
||||
import traceback
|
||||
from enum import Enum
|
||||
from typing import AsyncGenerator, Union
|
||||
|
||||
from pydantic import BaseModel
|
||||
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
# from letta.orm.user import User
|
||||
# from letta.orm.utilities import get_db_session
|
||||
|
||||
SSE_PREFIX = "data: "
|
||||
SSE_SUFFIX = "\n\n"
|
||||
SSE_FINISH_MSG = "[DONE]" # mimic openai
|
||||
SSE_ARTIFICIAL_DELAY = 0.1
|
||||
|
||||
|
||||
def sse_formatter(data: Union[dict, str]) -> str:
|
||||
"""Prefix with 'data: ', and always include double newlines"""
|
||||
assert type(data) in [dict, str], f"Expected type dict or str, got type {type(data)}"
|
||||
data_str = json.dumps(data, separators=(",", ":")) if isinstance(data, dict) else data
|
||||
return f"data: {data_str}\n\n"
|
||||
|
||||
|
||||
async def sse_async_generator(generator: AsyncGenerator, finish_message=True):
|
||||
"""
|
||||
Wraps a generator for use in Server-Sent Events (SSE), handling errors and ensuring a completion message.
|
||||
|
||||
Args:
|
||||
- generator: An asynchronous generator yielding data chunks.
|
||||
|
||||
Yields:
|
||||
- Formatted Server-Sent Event strings.
|
||||
"""
|
||||
try:
|
||||
async for chunk in generator:
|
||||
# yield f"data: {json.dumps(chunk)}\n\n"
|
||||
if isinstance(chunk, BaseModel):
|
||||
chunk = chunk.model_dump()
|
||||
elif isinstance(chunk, Enum):
|
||||
chunk = str(chunk.value)
|
||||
elif not isinstance(chunk, dict):
|
||||
chunk = str(chunk)
|
||||
yield sse_formatter(chunk)
|
||||
|
||||
except Exception as e:
|
||||
print("stream decoder hit error:", e)
|
||||
print(traceback.print_stack())
|
||||
yield sse_formatter({"error": "stream decoder encountered an error"})
|
||||
|
||||
finally:
|
||||
if finish_message:
|
||||
# Signal that the stream is complete
|
||||
yield sse_formatter(SSE_FINISH_MSG)
|
||||
|
||||
|
||||
# TODO: why does this double up the interface?
|
||||
def get_letta_server() -> SyncServer:
|
||||
# Check if a global server is already instantiated
|
||||
from letta.server.rest_api.app import server
|
||||
|
||||
# assert isinstance(server, SyncServer)
|
||||
return server
|
||||
|
||||
|
||||
def get_current_interface() -> StreamingServerInterface:
|
||||
return StreamingServerInterface
|
||||
Reference in New Issue
Block a user