feat: Add voice-compatible chat completions endpoint (#774)
This commit is contained in:
@@ -1,18 +1,22 @@
|
||||
import json
|
||||
from typing import Generator
|
||||
from typing import Generator, Union, get_args
|
||||
|
||||
import httpx
|
||||
from httpx_sse import SSEError, connect_sse
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||
from letta.errors import LLMError
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage
|
||||
from letta.schemas.letta_response import LettaStreamingResponse
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
|
||||
|
||||
def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStreamingResponse, ChatCompletionChunk], None, None]:
|
||||
|
||||
with httpx.Client() as client:
|
||||
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
|
||||
@@ -20,22 +24,26 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
|
||||
# Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
|
||||
if not event_source.response.is_success:
|
||||
# handle errors
|
||||
from letta.utils import printd
|
||||
pass
|
||||
|
||||
printd("Caught error before iterating SSE request:", vars(event_source.response))
|
||||
printd(event_source.response.read())
|
||||
logger.warning("Caught error before iterating SSE request:", vars(event_source.response))
|
||||
logger.warning(event_source.response.read().decode("utf-8"))
|
||||
|
||||
try:
|
||||
response_bytes = event_source.response.read()
|
||||
response_dict = json.loads(response_bytes.decode("utf-8"))
|
||||
error_message = response_dict["error"]["message"]
|
||||
# e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
|
||||
if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
|
||||
raise LLMError(error_message)
|
||||
if (
|
||||
"error" in response_dict
|
||||
and "message" in response_dict["error"]
|
||||
and OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in response_dict["error"]["message"]
|
||||
):
|
||||
logger.error(response_dict["error"]["message"])
|
||||
raise LLMError(response_dict["error"]["message"])
|
||||
except LLMError:
|
||||
raise
|
||||
except:
|
||||
print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
|
||||
logger.error(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
|
||||
event_source.response.raise_for_status()
|
||||
|
||||
try:
|
||||
@@ -58,33 +66,34 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
|
||||
yield ToolReturnMessage(**chunk_data)
|
||||
elif "step_count" in chunk_data:
|
||||
yield LettaUsageStatistics(**chunk_data)
|
||||
elif chunk_data.get("object") == get_args(ChatCompletionChunk.__annotations__["object"])[0]:
|
||||
yield ChatCompletionChunk(**chunk_data) # Add your processing logic for chat chunks here
|
||||
else:
|
||||
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
|
||||
|
||||
except SSEError as e:
|
||||
print("Caught an error while iterating the SSE stream:", str(e))
|
||||
logger.error("Caught an error while iterating the SSE stream:", str(e))
|
||||
if "application/json" in str(e): # Check if the error is because of JSON response
|
||||
# TODO figure out a better way to catch the error other than re-trying with a POST
|
||||
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
|
||||
if response.headers["Content-Type"].startswith("application/json"):
|
||||
error_details = response.json() # Parse the JSON to get the error message
|
||||
print("Request:", vars(response.request))
|
||||
print("POST Error:", error_details)
|
||||
print("Original SSE Error:", str(e))
|
||||
logger.error("Request:", vars(response.request))
|
||||
logger.error("POST Error:", error_details)
|
||||
logger.error("Original SSE Error:", str(e))
|
||||
else:
|
||||
print("Failed to retrieve JSON error message via retry.")
|
||||
logger.error("Failed to retrieve JSON error message via retry.")
|
||||
else:
|
||||
print("SSEError not related to 'application/json' content type.")
|
||||
logger.error("SSEError not related to 'application/json' content type.")
|
||||
|
||||
# Optionally re-raise the exception if you need to propagate it
|
||||
raise e
|
||||
|
||||
except Exception as e:
|
||||
if event_source.response.request is not None:
|
||||
print("HTTP Request:", vars(event_source.response.request))
|
||||
logger.error("HTTP Request:", vars(event_source.response.request))
|
||||
if event_source.response is not None:
|
||||
print("HTTP Status:", event_source.response.status_code)
|
||||
print("HTTP Headers:", event_source.response.headers)
|
||||
# print("HTTP Body:", event_source.response.text)
|
||||
print("Exception message:", str(e))
|
||||
logger.error("HTTP Status:", event_source.response.status_code)
|
||||
logger.error("HTTP Headers:", event_source.response.headers)
|
||||
logger.error("Exception message:", str(e))
|
||||
raise e
|
||||
|
||||
Reference in New Issue
Block a user