feat: Add voice-compatible chat completions endpoint (#774)

This commit is contained in:
Matthew Zhou
2025-01-27 12:25:05 -10:00
committed by GitHub
parent bfe2c6e8f5
commit b6773ea7ff
11 changed files with 996 additions and 34 deletions

View File

@@ -12,7 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
from letta.constants import ADMIN_PREFIX, API_PREFIX
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.log import get_logger
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
@@ -22,6 +22,7 @@ from letta.server.constants import REST_DEFAULT_PORT
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
@@ -241,6 +242,9 @@ def create_application() -> "FastAPI":
app.include_router(users_router, prefix=ADMIN_PREFIX)
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
# openai
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
# /api/auth endpoints
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)

View File

@@ -0,0 +1,256 @@
import asyncio
from collections import deque
from datetime import datetime
from typing import AsyncGenerator, Optional, Union
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
logger = get_logger(__name__)
class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
    """
    Provides an asynchronous streaming mechanism for LLM output. Internally
    maintains a deque of ChatCompletionChunk objects plus an asyncio.Event
    used to wake any consumer awaiting new data.

    Key Behaviors:
    - process_chunk: Accepts ChatCompletionChunkResponse objects (e.g. from an
      OpenAI-like streaming API), potentially transforms them to a partial
      text response, and enqueues them.
    - get_generator: Returns an async generator that yields queued chunks
      as they become available.
    - step_complete, step_yield: End streaming for the current step or entirely,
      depending on the multi_step setting.
    - user_message / internal_monologue / assistant_message / function_message:
      no-op hooks retained for interface compatibility with non-streaming
      contexts.
    """

    # finish_reason value emitted on the terminal chunk of a completion
    FINISH_REASON_STR = "stop"
    # role attached to streamed content deltas
    ASSISTANT_STR = "assistant"

    def __init__(
        self,
        multi_step: bool = True,
        timeout: int = 150,
        # The following are placeholders for potential expansions; they
        # remain if you need to differentiate between actual "assistant messages"
        # vs. tool calls. By default, they are set for the "send_message" tool usage.
        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
        inner_thoughts_in_kwargs: bool = True,
        inner_thoughts_kwarg: str = INNER_THOUGHTS_KWARG,
    ):
        # NOTE(review): the parent AgentChunkStreamingInterface.__init__ is
        # not invoked here — confirm the base class needs no initialization.
        self.streaming_mode = True

        # Parsing state for incremental function-call data: the tool name is
        # accumulated as a string, the argument fragments as a list of strings.
        self.current_function_name: str = ""
        self.current_function_arguments: list = []

        # Internal chunk buffer and event for async notification
        self._chunks: deque = deque()
        self._event: asyncio.Event = asyncio.Event()
        self._active: bool = True

        # Whether or not the stream should remain open across multiple steps
        self.multi_step = multi_step

        # Max seconds _create_generator waits for a new chunk before giving up
        self.timeout = timeout

        # These are placeholders to handle specialized
        # assistant message logic or storing inner thoughts.
        self.assistant_message_tool_name = assistant_message_tool_name
        self.assistant_message_tool_kwarg = assistant_message_tool_kwarg
        self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs
        self.inner_thoughts_kwarg = inner_thoughts_kwarg

    async def _create_generator(
        self,
    ) -> AsyncGenerator[Union[LettaMessage, MessageStreamStatus], None]:
        """
        Yield queued items as they arrive.

        Waits on the internal event with `self.timeout` seconds per round; a
        timeout ends the stream. Ends normally when `_active` becomes False.
        """
        while self._active:
            try:
                await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                # No producer activity within the timeout window: stop streaming.
                break
            while self._chunks:
                yield self._chunks.popleft()
            # Reset so the next wait blocks until a producer signals again.
            self._event.clear()

    def get_generator(self) -> AsyncGenerator:
        """
        Provide the async generator interface. Raises StopIteration
        if the stream is inactive.

        NOTE(review): explicitly raising StopIteration is unusual outside a
        generator body (PEP 479); a RuntimeError may be more appropriate.
        """
        if not self._active:
            raise StopIteration("The stream is not active.")
        return self._create_generator()

    def _push_to_buffer(
        self,
        item: ChatCompletionChunk,
    ):
        """
        Append a chunk to the internal queue and signal waiting consumers.

        Raises:
            RuntimeError: if the stream has already been closed.
        """
        if not self._active:
            raise RuntimeError("Attempted to push to an inactive stream.")
        self._chunks.append(item)
        self._event.set()

    def stream_start(self) -> None:
        """Initialize or reset the streaming state for a new request."""
        self._active = True
        self._chunks.clear()
        self._event.clear()
        self._reset_parsing_state()

    def stream_end(self) -> None:
        """
        Clean up after the current streaming session. Typically called when the
        request is done or the data source has signaled it has no more data.
        Only parsing state is reset; the queue/event are left untouched.
        """
        self._reset_parsing_state()

    def step_complete(self) -> None:
        """
        Indicate that one step of multi-step generation is done.
        If multi_step=False, the stream is closed immediately.
        """
        if not self.multi_step:
            self._active = False
            self._event.set()  # Ensure waiting generators can finalize
        self._reset_parsing_state()

    def step_yield(self) -> None:
        """
        Explicitly end the stream in a multi-step scenario, typically
        called when the entire chain of steps is complete.
        """
        self._active = False
        self._event.set()

    @staticmethod
    def clear() -> None:
        """No-op retained for interface compatibility."""
        return

    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime) -> None:
        """
        Called externally with a ChatCompletionChunkResponse. Transforms
        it if necessary, then enqueues partial messages for streaming back.

        NOTE(review): `message_id` and `message_date` are currently unused;
        they exist to satisfy the interface signature.
        """
        processed_chunk = self._process_chunk_to_openai_style(chunk)
        if processed_chunk is not None:
            self._push_to_buffer(processed_chunk)

    def user_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle user messages. Here, it's a no-op, but included if your
        pipeline needs to respond to user messages distinctly.
        """
        return

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle LLM reasoning or internal monologue. Example usage: if you want
        to capture chain-of-thought for debugging in a non-streaming scenario.
        No-op here.
        """
        return

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle direct assistant messages. This class primarily handles them
        as function calls, so it's a no-op by default.
        """
        return

    def function_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle function-related log messages (e.g. tool invocation notices).
        No-op by default.
        """
        return

    def _process_chunk_to_openai_style(self, chunk: ChatCompletionChunkResponse) -> Optional[ChatCompletionChunk]:
        """
        Optionally transform an inbound OpenAI-style chunk so that partial
        content (especially from a 'send_message' tool) is exposed as text
        deltas in 'content'. Otherwise, pass through or yield finish reasons.
        Returns None for chunks that should not be streamed to the client.
        """
        # Only the first choice is inspected; n>1 completions are not handled.
        choice = chunk.choices[0]
        delta = choice.delta

        # If there's direct content, we usually let it stream as-is
        if delta.content is not None:
            # TODO: Eventually use all of the native OpenAI objects
            return ChatCompletionChunk(**chunk.model_dump(exclude_none=True))

        # If there's a function call, accumulate its name/args. If it's a known
        # text-producing function (like send_message), stream partial text.
        if delta.tool_calls:
            tool_call = delta.tool_calls[0]
            if tool_call.function.name:
                self.current_function_name += tool_call.function.name
            if tool_call.function.arguments:
                self.current_function_arguments.append(tool_call.function.arguments)

            # Only parse arguments for "send_message" to stream partial text
            if self.current_function_name.strip() == self.assistant_message_tool_name:
                combined_args = "".join(self.current_function_arguments)
                parsed_args = OptimisticJSONParser().parse(combined_args)

                # If we can see a "message" field, return it as partial content
                if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
                    # NOTE(review): only the latest fragment is streamed; it may
                    # contain trailing JSON syntax (e.g. a closing quote/brace).
                    # `chunk.created` is assumed to be a datetime — converted to
                    # epoch seconds here.
                    return ChatCompletionChunk(
                        id=chunk.id,
                        object=chunk.object,
                        created=chunk.created.timestamp(),
                        model=chunk.model,
                        choices=[
                            Choice(
                                index=choice.index,
                                delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
                                finish_reason=None,
                            )
                        ],
                    )

        # If there's a finish reason, pass that along (normalized to "stop")
        if choice.finish_reason is not None:
            return ChatCompletionChunk(
                id=chunk.id,
                object=chunk.object,
                created=chunk.created.timestamp(),
                model=chunk.model,
                choices=[
                    Choice(
                        index=choice.index,
                        delta=ChoiceDelta(),
                        finish_reason=self.FINISH_REASON_STR,
                    )
                ],
            )
        return None

    def _reset_parsing_state(self) -> None:
        """Clears internal buffers for function call name/args."""
        self.current_function_name = ""
        self.current_function_arguments = []

View File

@@ -0,0 +1,185 @@
import json
class OptimisticJSONParser:
    """
    A JSON parser that attempts to parse a given string using `json.loads`,
    and if that fails, it parses as much valid JSON as possible while
    allowing extra tokens to remain. Those extra tokens can be retrieved
    from `self.last_parse_reminding`. If `strict` is False, the parser
    tries to tolerate incomplete strings and incomplete numbers.

    Fixes over the naive implementation:
    - numbers in exponent notation (``1e3``, ``2.5E-2``) are consumed fully
      instead of stopping at the ``e`` and raising on the leftover;
    - a quote preceded by an *escaped backslash* (``\\\\"``) is correctly
      treated as a real string terminator (backslash parity is checked).
    """

    def __init__(self, strict=True):
        # When strict, incomplete strings are round-tripped through
        # json.loads (so escapes are decoded); when not strict, raw
        # slices are returned without unescaping.
        self.strict = strict
        # Dispatch table keyed on the first character of the remaining input.
        self.parsers = {
            " ": self.parse_space,
            "\r": self.parse_space,
            "\n": self.parse_space,
            "\t": self.parse_space,
            "[": self.parse_array,
            "{": self.parse_object,
            '"': self.parse_string,
            "t": self.parse_true,
            "f": self.parse_false,
            "n": self.parse_null,
        }
        # Register number parser for digits and signs
        for char in "0123456789.-":
            self.parsers[char] = self.parse_number
        # Leftover (unparsed) text from the most recent partial parse.
        self.last_parse_reminding = None
        # Hook invoked as (text, data, reminding) when extra tokens remain.
        self.on_extra_token = self.default_on_extra_token

    def default_on_extra_token(self, text, data, reminding):
        """Default extra-token callback: ignore leftovers silently."""
        pass

    def parse(self, input_str):
        """
        Try to parse the entire `input_str` as JSON. If parsing fails,
        attempts a partial parse, storing leftover text in
        `self.last_parse_reminding`. A callback (`on_extra_token`) is
        triggered if extra tokens remain.
        """
        if len(input_str) >= 1:
            try:
                return json.loads(input_str)
            except json.JSONDecodeError as decode_error:
                data, reminding = self.parse_any(input_str, decode_error)
                self.last_parse_reminding = reminding
                if self.on_extra_token and reminding:
                    self.on_extra_token(input_str, data, reminding)
                return data
        else:
            # Empty input parses as an empty object.
            return json.loads("{}")

    def parse_any(self, input_str, decode_error):
        """Determine which parser to use based on the first character."""
        if not input_str:
            raise decode_error
        parser = self.parsers.get(input_str[0])
        if parser is None:
            raise decode_error
        return parser(input_str, decode_error)

    def parse_space(self, input_str, decode_error):
        """Strip leading whitespace and parse again."""
        return self.parse_any(input_str.strip(), decode_error)

    def parse_array(self, input_str, decode_error):
        """Parse a JSON array, returning the list and remaining string."""
        # Skip the '['
        input_str = input_str[1:]
        array_values = []
        input_str = input_str.strip()
        while input_str:
            if input_str[0] == "]":
                # Skip the ']'
                input_str = input_str[1:]
                break
            value, input_str = self.parse_any(input_str, decode_error)
            array_values.append(value)
            input_str = input_str.strip()
            if input_str.startswith(","):
                # Skip the ','
                input_str = input_str[1:].strip()
        return array_values, input_str

    def parse_object(self, input_str, decode_error):
        """Parse a JSON object, returning the dict and remaining string."""
        # Skip the '{'
        input_str = input_str[1:]
        obj = {}
        input_str = input_str.strip()
        while input_str:
            if input_str[0] == "}":
                # Skip the '}'
                input_str = input_str[1:]
                break
            key, input_str = self.parse_any(input_str, decode_error)
            input_str = input_str.strip()
            if not input_str or input_str[0] == "}":
                # Truncated after the key: record it with a None value.
                obj[key] = None
                break
            if input_str[0] != ":":
                raise decode_error
            # Skip ':'
            input_str = input_str[1:].strip()
            if not input_str or input_str[0] in ",}":
                # Truncated after the colon: value is unknown, use None.
                obj[key] = None
                if input_str.startswith(","):
                    input_str = input_str[1:]
                break
            value, input_str = self.parse_any(input_str, decode_error)
            obj[key] = value
            input_str = input_str.strip()
            if input_str.startswith(","):
                # Skip the ','
                input_str = input_str[1:].strip()
        return obj, input_str

    def parse_string(self, input_str, decode_error):
        """
        Parse a JSON string, respecting escaped quotes.

        A quote counts as escaped only when preceded by an ODD number of
        backslashes; `\\\\"` (escaped backslash, then quote) terminates the
        string.
        """
        end = input_str.find('"', 1)
        while end != -1:
            # Count backslashes immediately preceding the quote.
            backslashes = 0
            pos = end - 1
            while pos > 0 and input_str[pos] == "\\":
                backslashes += 1
                pos -= 1
            if backslashes % 2 == 0:
                break  # even count: the quote is a real terminator
            end = input_str.find('"', end + 1)
        if end == -1:
            # Incomplete string
            if not self.strict:
                return input_str[1:], ""
            return json.loads(f'"{input_str[1:]}"'), ""
        str_val = input_str[: end + 1]
        input_str = input_str[end + 1 :]
        if not self.strict:
            return str_val[1:-1], input_str
        return json.loads(str_val), input_str

    def parse_number(self, input_str, decode_error):
        """
        Parse a number (int or float), including exponent notation
        (e.g. ``1e3``, ``2.5E-2``). A truncated exponent (``1e``) falls back
        to the mantissa, leaving the exponent characters as remainder.
        """
        idx = 0
        # Also consume exponent characters so "1e3" is one token, not "1"+"e3".
        while idx < len(input_str) and input_str[idx] in "0123456789.-+eE":
            idx += 1
        num_str = input_str[:idx]
        remainder = input_str[idx:]
        # If it's only a sign or just '.', return as-is with empty remainder
        if not num_str or num_str in {"-", "."}:
            return num_str, ""
        try:
            if num_str.endswith("."):
                num = int(num_str[:-1])
            else:
                num = float(num_str) if any(c in num_str for c in ".eE") else int(num_str)
        except ValueError:
            # Tolerate a truncated exponent ("1e", "1e-"): parse the mantissa
            # and push the dangling exponent characters back to the remainder.
            mantissa = num_str.rstrip("+-eE")
            if mantissa and mantissa not in {"-", "."}:
                try:
                    if mantissa.endswith("."):
                        num = int(mantissa[:-1])
                    else:
                        num = float(mantissa) if any(c in mantissa for c in ".eE") else int(mantissa)
                    return num, num_str[len(mantissa):] + remainder
                except ValueError:
                    pass
            raise decode_error
        return num, remainder

    def parse_true(self, input_str, decode_error):
        """Parse a 'true' value (leniently accepts a leading 't'/'T')."""
        if input_str.startswith(("t", "T")):
            return True, input_str[4:]
        raise decode_error

    def parse_false(self, input_str, decode_error):
        """Parse a 'false' value (leniently accepts a leading 'f'/'F')."""
        if input_str.startswith(("f", "F")):
            return False, input_str[5:]
        raise decode_error

    def parse_null(self, input_str, decode_error):
        """Parse a 'null' value."""
        if input_str.startswith("n"):
            return None, input_str[4:]
        raise decode_error

View File

@@ -0,0 +1,161 @@
import asyncio
from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.completion_create_params import CompletionCreateParams
from letta.agent import Agent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.schemas.message import MessageCreate
from letta.schemas.openai.chat_completion_response import Message
from letta.schemas.user import User
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
# TODO this belongs in a controller!
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
if TYPE_CHECKING:
from letta.server.server import SyncServer
router = APIRouter(prefix="/v1", tags=["chat_completions"])
logger = get_logger(__name__)
@router.post(
    "/chat/completions",
    response_model=None,
    operation_id="create_chat_completions",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def create_chat_completions(
    completion_request: CompletionCreateParams = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),
):
    """OpenAI-compatible `/chat/completions` endpoint backed by a Letta agent.

    The target agent id is passed through the OpenAI `user` field. The
    request must be streaming (`stream=True`) and must end with a
    user-role message whose content is a plain string. Returns a
    Server-Sent Events stream of ChatCompletionChunk payloads.

    Raises:
        HTTPException(400): on any malformed/unsupported request.
    """
    # Validate and process fields
    try:
        messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
    except KeyError:
        # Handle the case where "messages" is not present in the request
        raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
    except TypeError:
        # Handle the case where "messages" is not iterable
        raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
    except Exception as e:
        # Catch any other unexpected errors and include the exception message
        raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")
    # Guard the empty list before indexing messages[-1] (would IndexError -> 500).
    if not messages:
        raise HTTPException(status_code=400, detail="'messages' must contain at least one message.")
    if messages[-1]["role"] != "user":
        logger.error(f"The last message does not have a `user` role: {messages}")
        raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")
    input_message = messages[-1]
    if not isinstance(input_message["content"], str):
        logger.error(f"The input message does not have valid content: {input_message}")
        raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")

    # Only streaming requests are supported. Use .get() so a missing key is
    # rejected with a 400 rather than raising KeyError (-> 500).
    if not completion_request.get("stream"):
        raise HTTPException(status_code=400, detail="Must be streaming request: `stream` was set to `False` in the request.")

    actor = server.user_manager.get_user_or_default(user_id=user_id)

    # The agent id is smuggled through the OpenAI `user` field. Check for
    # absence BEFORE stringifying: str(None) == "None" would defeat the check.
    agent_id = completion_request.get("user")
    if agent_id is None:
        error_msg = "Must pass agent_id in the 'user' field"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    agent_id = str(agent_id)

    letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
    llm_config = letta_agent.agent_state.llm_config
    # Only plain OpenAI endpoints are supported (the hosted proxy is excluded).
    if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint:
        error_msg = f"You can only use models with type 'openai' for chat completions. This agent {agent_id} has llm_config: \n{llm_config.model_dump_json(indent=4)}"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)

    # A mismatched model is tolerated: the agent's configured model wins.
    model = completion_request.get("model")
    if model != llm_config.model:
        warning_msg = f"The requested model {model} is different from the model specified in this agent's ({agent_id}) llm_config: \n{llm_config.model_dump_json(indent=4)}"
        logger.warning(f"Defaulting to {llm_config.model}...")
        logger.warning(warning_msg)

    logger.info(f"Received input message: {input_message}")
    return await send_message_to_agent_chat_completions(
        server=server,
        letta_agent=letta_agent,
        actor=actor,
        messages=[MessageCreate(role=input_message["role"], content=input_message["content"])],
    )
async def send_message_to_agent_chat_completions(
    server: "SyncServer",
    letta_agent: Agent,
    actor: User,
    messages: Union[List[Message], List[MessageCreate]],
    assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> StreamingResponse:
    """Split off into a separate function so that it can be imported in the /chat/completion proxy.

    Kicks off the (synchronous) agent execution in a worker thread and
    returns an SSE StreamingResponse fed by the agent's streaming interface.

    Raises:
        HTTPException(500): on any unexpected setup failure.
    """
    # For streaming response
    try:
        # TODO: cleanup this logic
        llm_config = letta_agent.agent_state.llm_config

        # Create a new interface per request so state from a previous
        # request cannot leak into this stream.
        letta_agent.interface = ChatCompletionsStreamingInterface()
        streaming_interface = letta_agent.interface
        if not isinstance(streaming_interface, ChatCompletionsStreamingInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

        # Allow AssistantMessage if desired by client
        streaming_interface.assistant_message_tool_name = assistant_message_tool_name
        streaming_interface.assistant_message_tool_kwarg = assistant_message_tool_kwarg

        # Related to JSON buffer reader
        streaming_interface.inner_thoughts_in_kwargs = (
            llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
        )

        # Offload the synchronous message_func to a separate thread
        streaming_interface.stream_start()
        task = asyncio.create_task(
            asyncio.to_thread(
                server.send_messages,
                actor=actor,
                agent_id=letta_agent.agent_state.id,
                messages=messages,
                interface=streaming_interface,
            )
        )

        def _log_agent_task_result(t: "asyncio.Task") -> None:
            # Retrieve the task result so background failures are logged
            # instead of producing a "Task exception was never retrieved"
            # warning and vanishing silently.
            if t.cancelled():
                return
            exc = t.exception()
            if exc is not None:
                logger.error("Background agent execution failed", exc_info=exc)

        task.add_done_callback(_log_agent_task_result)

        # return a stream
        return StreamingResponse(
            sse_async_generator(
                streaming_interface.get_generator(),
                usage_task=None,
                finish_message=True,
            ),
            media_type="text/event-stream",
        )
    except HTTPException:
        raise
    except Exception as e:
        # Log the full traceback via the module logger instead of print().
        logger.exception(f"Error setting up chat completions stream: {e}")
        raise HTTPException(status_code=500, detail=f"{e}")