feat: Add voice-compatible chat completions endpoint (#774)

This commit is contained in:
Matthew Zhou
2025-01-27 12:25:05 -10:00
committed by GitHub
parent bfe2c6e8f5
commit b6773ea7ff
11 changed files with 996 additions and 34 deletions

View File

@@ -1,18 +1,22 @@
import json
from typing import Generator
from typing import Generator, Union, get_args
import httpx
from httpx_sse import SSEError, connect_sse
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError
from letta.log import get_logger
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage
from letta.schemas.letta_response import LettaStreamingResponse
from letta.schemas.usage import LettaUsageStatistics
logger = get_logger(__name__)
def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStreamingResponse, ChatCompletionChunk], None, None]:
with httpx.Client() as client:
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
@@ -20,22 +24,26 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
# Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
if not event_source.response.is_success:
# handle errors
from letta.utils import printd
pass
printd("Caught error before iterating SSE request:", vars(event_source.response))
printd(event_source.response.read())
logger.warning("Caught error before iterating SSE request:", vars(event_source.response))
logger.warning(event_source.response.read().decode("utf-8"))
try:
response_bytes = event_source.response.read()
response_dict = json.loads(response_bytes.decode("utf-8"))
error_message = response_dict["error"]["message"]
# e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
raise LLMError(error_message)
if (
"error" in response_dict
and "message" in response_dict["error"]
and OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in response_dict["error"]["message"]
):
logger.error(response_dict["error"]["message"])
raise LLMError(response_dict["error"]["message"])
except LLMError:
raise
except:
print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
logger.error(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
event_source.response.raise_for_status()
try:
@@ -58,33 +66,34 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
yield ToolReturnMessage(**chunk_data)
elif "step_count" in chunk_data:
yield LettaUsageStatistics(**chunk_data)
elif chunk_data.get("object") == get_args(ChatCompletionChunk.__annotations__["object"])[0]:
yield ChatCompletionChunk(**chunk_data) # Add your processing logic for chat chunks here
else:
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
except SSEError as e:
print("Caught an error while iterating the SSE stream:", str(e))
logger.error("Caught an error while iterating the SSE stream:", str(e))
if "application/json" in str(e): # Check if the error is because of JSON response
# TODO figure out a better way to catch the error other than re-trying with a POST
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
if response.headers["Content-Type"].startswith("application/json"):
error_details = response.json() # Parse the JSON to get the error message
print("Request:", vars(response.request))
print("POST Error:", error_details)
print("Original SSE Error:", str(e))
logger.error("Request:", vars(response.request))
logger.error("POST Error:", error_details)
logger.error("Original SSE Error:", str(e))
else:
print("Failed to retrieve JSON error message via retry.")
logger.error("Failed to retrieve JSON error message via retry.")
else:
print("SSEError not related to 'application/json' content type.")
logger.error("SSEError not related to 'application/json' content type.")
# Optionally re-raise the exception if you need to propagate it
raise e
except Exception as e:
if event_source.response.request is not None:
print("HTTP Request:", vars(event_source.response.request))
logger.error("HTTP Request:", vars(event_source.response.request))
if event_source.response is not None:
print("HTTP Status:", event_source.response.status_code)
print("HTTP Headers:", event_source.response.headers)
# print("HTTP Body:", event_source.response.text)
print("Exception message:", str(e))
logger.error("HTTP Status:", event_source.response.status_code)
logger.error("HTTP Headers:", event_source.response.headers)
logger.error("Exception message:", str(e))
raise e