feat: Add voice-compatible chat completions endpoint (#774)

This commit is contained in:
Matthew Zhou
2025-01-27 12:25:05 -10:00
committed by GitHub
parent bfe2c6e8f5
commit b6773ea7ff
11 changed files with 996 additions and 34 deletions

View File

@@ -1,18 +1,22 @@
import json
from typing import Generator
from typing import Generator, Union, get_args
import httpx
from httpx_sse import SSEError, connect_sse
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.constants import OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
from letta.errors import LLMError
from letta.log import get_logger
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage, ToolReturnMessage
from letta.schemas.letta_response import LettaStreamingResponse
from letta.schemas.usage import LettaUsageStatistics
logger = get_logger(__name__)
def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingResponse, None, None]:
def _sse_post(url: str, data: dict, headers: dict) -> Generator[Union[LettaStreamingResponse, ChatCompletionChunk], None, None]:
with httpx.Client() as client:
with connect_sse(client, method="POST", url=url, json=data, headers=headers) as event_source:
@@ -20,22 +24,26 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
# Inspect for errors before iterating (see https://github.com/florimondmanca/httpx-sse/pull/12)
if not event_source.response.is_success:
# handle errors
from letta.utils import printd
pass
printd("Caught error before iterating SSE request:", vars(event_source.response))
printd(event_source.response.read())
logger.warning("Caught error before iterating SSE request:", vars(event_source.response))
logger.warning(event_source.response.read().decode("utf-8"))
try:
response_bytes = event_source.response.read()
response_dict = json.loads(response_bytes.decode("utf-8"))
error_message = response_dict["error"]["message"]
# e.g.: This model's maximum context length is 8192 tokens. However, your messages resulted in 8198 tokens (7450 in the messages, 748 in the functions). Please reduce the length of the messages or functions.
if OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in error_message:
raise LLMError(error_message)
if (
"error" in response_dict
and "message" in response_dict["error"]
and OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING in response_dict["error"]["message"]
):
logger.error(response_dict["error"]["message"])
raise LLMError(response_dict["error"]["message"])
except LLMError:
raise
except:
print(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
logger.error(f"Failed to parse SSE message, throwing SSE HTTP error up the stack")
event_source.response.raise_for_status()
try:
@@ -58,33 +66,34 @@ def _sse_post(url: str, data: dict, headers: dict) -> Generator[LettaStreamingRe
yield ToolReturnMessage(**chunk_data)
elif "step_count" in chunk_data:
yield LettaUsageStatistics(**chunk_data)
elif chunk_data.get("object") == get_args(ChatCompletionChunk.__annotations__["object"])[0]:
yield ChatCompletionChunk(**chunk_data) # Add your processing logic for chat chunks here
else:
raise ValueError(f"Unknown message type in chunk_data: {chunk_data}")
except SSEError as e:
print("Caught an error while iterating the SSE stream:", str(e))
logger.error("Caught an error while iterating the SSE stream:", str(e))
if "application/json" in str(e): # Check if the error is because of JSON response
# TODO figure out a better way to catch the error other than re-trying with a POST
response = client.post(url=url, json=data, headers=headers) # Make the request again to get the JSON response
if response.headers["Content-Type"].startswith("application/json"):
error_details = response.json() # Parse the JSON to get the error message
print("Request:", vars(response.request))
print("POST Error:", error_details)
print("Original SSE Error:", str(e))
logger.error("Request:", vars(response.request))
logger.error("POST Error:", error_details)
logger.error("Original SSE Error:", str(e))
else:
print("Failed to retrieve JSON error message via retry.")
logger.error("Failed to retrieve JSON error message via retry.")
else:
print("SSEError not related to 'application/json' content type.")
logger.error("SSEError not related to 'application/json' content type.")
# Optionally re-raise the exception if you need to propagate it
raise e
except Exception as e:
if event_source.response.request is not None:
print("HTTP Request:", vars(event_source.response.request))
logger.error("HTTP Request:", vars(event_source.response.request))
if event_source.response is not None:
print("HTTP Status:", event_source.response.status_code)
print("HTTP Headers:", event_source.response.headers)
# print("HTTP Body:", event_source.response.text)
print("Exception message:", str(e))
logger.error("HTTP Status:", event_source.response.status_code)
logger.error("HTTP Headers:", event_source.response.headers)
logger.error("Exception message:", str(e))
raise e

View File

@@ -116,7 +116,7 @@ class MessageDelta(BaseModel):
content: Optional[str] = None
tool_calls: Optional[List[ToolCallDelta]] = None
# role: Optional[str] = None
role: Optional[str] = None
function_call: Optional[FunctionCallDelta] = None # Deprecated
@@ -132,7 +132,7 @@ class ChatCompletionChunkResponse(BaseModel):
id: str
choices: List[ChunkChoice]
created: datetime.datetime
created: Union[datetime.datetime, str]
model: str
# system_fingerprint: str # docs say this is mandatory, but in reality API returns None
system_fingerprint: Optional[str] = None

View File

@@ -12,7 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
from starlette.middleware.cors import CORSMiddleware
from letta.__init__ import __version__
from letta.constants import ADMIN_PREFIX, API_PREFIX
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
from letta.log import get_logger
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
@@ -22,6 +22,7 @@ from letta.server.constants import REST_DEFAULT_PORT
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
@@ -241,6 +242,9 @@ def create_application() -> "FastAPI":
app.include_router(users_router, prefix=ADMIN_PREFIX)
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
# openai
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
# /api/auth endpoints
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)

View File

@@ -0,0 +1,256 @@
import asyncio
from collections import deque
from datetime import datetime
from typing import AsyncGenerator, Optional, Union
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.log import get_logger
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import LettaMessage
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
logger = get_logger(__name__)
class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
"""
Provides an asynchronous streaming mechanism for LLM output. Internally
maintains a queue of chunks that can be consumed via an async generator.
Key Behaviors:
- process_chunk: Accepts ChatCompletionChunkResponse objects (e.g. from an
OpenAI-like streaming API), potentially transforms them to a partial
text response, and enqueues them.
- get_generator: Returns an async generator that yields messages or status
markers as they become available.
- step_complete, step_yield: End streaming for the current step or entirely,
depending on the multi_step setting.
- function_message, internal_monologue: Handle LLM “function calls” and
“reasoning” messages for non-streaming contexts.
"""
FINISH_REASON_STR = "stop"
ASSISTANT_STR = "assistant"
def __init__(
self,
multi_step: bool = True,
timeout: int = 150,
# The following are placeholders for potential expansions; they
# remain if you need to differentiate between actual "assistant messages"
# vs. tool calls. By default, they are set for the "send_message" tool usage.
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
inner_thoughts_in_kwargs: bool = True,
inner_thoughts_kwarg: str = INNER_THOUGHTS_KWARG,
):
self.streaming_mode = True
# Parsing state for incremental function-call data
self.current_function_name = ""
self.current_function_arguments = []
# Internal chunk buffer and event for async notification
self._chunks = deque()
self._event = asyncio.Event()
self._active = True
# Whether or not the stream should remain open across multiple steps
self.multi_step = multi_step
# Timing / debug parameters
self.timeout = timeout
# These are placeholders to handle specialized
# assistant message logic or storing inner thoughts.
self.assistant_message_tool_name = assistant_message_tool_name
self.assistant_message_tool_kwarg = assistant_message_tool_kwarg
self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs
self.inner_thoughts_kwarg = inner_thoughts_kwarg
async def _create_generator(
self,
) -> AsyncGenerator[Union[LettaMessage, MessageStreamStatus], None]:
"""
An asynchronous generator that yields queued items as they arrive.
Ends when _active is set to False or when timing out.
"""
while self._active:
try:
await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
except asyncio.TimeoutError:
break
while self._chunks:
yield self._chunks.popleft()
self._event.clear()
def get_generator(self) -> AsyncGenerator:
"""
Provide the async generator interface. Will raise StopIteration
if the stream is inactive.
"""
if not self._active:
raise StopIteration("The stream is not active.")
return self._create_generator()
def _push_to_buffer(
self,
item: ChatCompletionChunk,
):
"""
Add an item (a LettaMessage, status marker, or partial chunk)
to the queue and signal waiting consumers.
"""
if not self._active:
raise RuntimeError("Attempted to push to an inactive stream.")
self._chunks.append(item)
self._event.set()
def stream_start(self) -> None:
"""Initialize or reset the streaming state for a new request."""
self._active = True
self._chunks.clear()
self._event.clear()
self._reset_parsing_state()
def stream_end(self) -> None:
"""
Clean up after the current streaming session. Typically called when the
request is done or the data source has signaled it has no more data.
"""
self._reset_parsing_state()
def step_complete(self) -> None:
"""
Indicate that one step of multi-step generation is done.
If multi_step=False, the stream is closed immediately.
"""
if not self.multi_step:
self._active = False
self._event.set() # Ensure waiting generators can finalize
self._reset_parsing_state()
def step_yield(self) -> None:
"""
Explicitly end the stream in a multi-step scenario, typically
called when the entire chain of steps is complete.
"""
self._active = False
self._event.set()
@staticmethod
def clear() -> None:
"""No-op retained for interface compatibility."""
return
def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime) -> None:
"""
Called externally with a ChatCompletionChunkResponse. Transforms
it if necessary, then enqueues partial messages for streaming back.
"""
processed_chunk = self._process_chunk_to_openai_style(chunk)
if processed_chunk is not None:
self._push_to_buffer(processed_chunk)
def user_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""
Handle user messages. Here, it's a no-op, but included if your
pipeline needs to respond to user messages distinctly.
"""
return
def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""
Handle LLM reasoning or internal monologue. Example usage: if you want
to capture chain-of-thought for debugging in a non-streaming scenario.
"""
return
def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""
Handle direct assistant messages. This class primarily handles them
as function calls, so it's a no-op by default.
"""
return
def function_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
"""
Handle function-related log messages, typically of the form:
It's a no-op by default.
"""
return
def _process_chunk_to_openai_style(self, chunk: ChatCompletionChunkResponse) -> Optional[ChatCompletionChunk]:
"""
Optionally transform an inbound OpenAI-style chunk so that partial
content (especially from a 'send_message' tool) is exposed as text
deltas in 'content'. Otherwise, pass through or yield finish reasons.
"""
choice = chunk.choices[0]
delta = choice.delta
# If there's direct content, we usually let it stream as-is
if delta.content is not None:
# TODO: Eventually use all of the native OpenAI objects
return ChatCompletionChunk(**chunk.model_dump(exclude_none=True))
# If there's a function call, accumulate its name/args. If it's a known
# text-producing function (like send_message), stream partial text.
if delta.tool_calls:
tool_call = delta.tool_calls[0]
if tool_call.function.name:
self.current_function_name += tool_call.function.name
if tool_call.function.arguments:
self.current_function_arguments.append(tool_call.function.arguments)
# Only parse arguments for "send_message" to stream partial text
if self.current_function_name.strip() == self.assistant_message_tool_name:
combined_args = "".join(self.current_function_arguments)
parsed_args = OptimisticJSONParser().parse(combined_args)
# If we can see a "message" field, return it as partial content
if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
finish_reason=None,
)
],
)
# If there's a finish reason, pass that along
if choice.finish_reason is not None:
return ChatCompletionChunk(
id=chunk.id,
object=chunk.object,
created=chunk.created.timestamp(),
model=chunk.model,
choices=[
Choice(
index=choice.index,
delta=ChoiceDelta(),
finish_reason=self.FINISH_REASON_STR,
)
],
)
return None
def _reset_parsing_state(self) -> None:
"""Clears internal buffers for function call name/args."""
self.current_function_name = ""
self.current_function_arguments = []

View File

@@ -0,0 +1,185 @@
import json
class OptimisticJSONParser:
"""
A JSON parser that attempts to parse a given string using `json.loads`,
and if that fails, it parses as much valid JSON as possible while
allowing extra tokens to remain. Those extra tokens can be retrieved
from `self.last_parse_reminding`. If `strict` is False, the parser
tries to tolerate incomplete strings and incomplete numbers.
"""
def __init__(self, strict=True):
self.strict = strict
self.parsers = {
" ": self.parse_space,
"\r": self.parse_space,
"\n": self.parse_space,
"\t": self.parse_space,
"[": self.parse_array,
"{": self.parse_object,
'"': self.parse_string,
"t": self.parse_true,
"f": self.parse_false,
"n": self.parse_null,
}
# Register number parser for digits and signs
for char in "0123456789.-":
self.parsers[char] = self.parse_number
self.last_parse_reminding = None
self.on_extra_token = self.default_on_extra_token
def default_on_extra_token(self, text, data, reminding):
pass
def parse(self, input_str):
"""
Try to parse the entire `input_str` as JSON. If parsing fails,
attempts a partial parse, storing leftover text in
`self.last_parse_reminding`. A callback (`on_extra_token`) is
triggered if extra tokens remain.
"""
if len(input_str) >= 1:
try:
return json.loads(input_str)
except json.JSONDecodeError as decode_error:
data, reminding = self.parse_any(input_str, decode_error)
self.last_parse_reminding = reminding
if self.on_extra_token and reminding:
self.on_extra_token(input_str, data, reminding)
return data
else:
return json.loads("{}")
def parse_any(self, input_str, decode_error):
"""Determine which parser to use based on the first character."""
if not input_str:
raise decode_error
parser = self.parsers.get(input_str[0])
if parser is None:
raise decode_error
return parser(input_str, decode_error)
def parse_space(self, input_str, decode_error):
"""Strip leading whitespace and parse again."""
return self.parse_any(input_str.strip(), decode_error)
def parse_array(self, input_str, decode_error):
"""Parse a JSON array, returning the list and remaining string."""
# Skip the '['
input_str = input_str[1:]
array_values = []
input_str = input_str.strip()
while input_str:
if input_str[0] == "]":
# Skip the ']'
input_str = input_str[1:]
break
value, input_str = self.parse_any(input_str, decode_error)
array_values.append(value)
input_str = input_str.strip()
if input_str.startswith(","):
# Skip the ','
input_str = input_str[1:].strip()
return array_values, input_str
def parse_object(self, input_str, decode_error):
"""Parse a JSON object, returning the dict and remaining string."""
# Skip the '{'
input_str = input_str[1:]
obj = {}
input_str = input_str.strip()
while input_str:
if input_str[0] == "}":
# Skip the '}'
input_str = input_str[1:]
break
key, input_str = self.parse_any(input_str, decode_error)
input_str = input_str.strip()
if not input_str or input_str[0] == "}":
obj[key] = None
break
if input_str[0] != ":":
raise decode_error
# Skip ':'
input_str = input_str[1:].strip()
if not input_str or input_str[0] in ",}":
obj[key] = None
if input_str.startswith(","):
input_str = input_str[1:]
break
value, input_str = self.parse_any(input_str, decode_error)
obj[key] = value
input_str = input_str.strip()
if input_str.startswith(","):
# Skip the ','
input_str = input_str[1:].strip()
return obj, input_str
def parse_string(self, input_str, decode_error):
"""Parse a JSON string, respecting escaped quotes if present."""
end = input_str.find('"', 1)
while end != -1 and input_str[end - 1] == "\\":
end = input_str.find('"', end + 1)
if end == -1:
# Incomplete string
if not self.strict:
return input_str[1:], ""
return json.loads(f'"{input_str[1:]}"'), ""
str_val = input_str[: end + 1]
input_str = input_str[end + 1 :]
if not self.strict:
return str_val[1:-1], input_str
return json.loads(str_val), input_str
def parse_number(self, input_str, decode_error):
"""
Parse a number (int or float). Allows digits, '.', '-', but
doesn't fully validate complex exponents unless they appear
before a non-number character.
"""
idx = 0
while idx < len(input_str) and input_str[idx] in "0123456789.-":
idx += 1
num_str = input_str[:idx]
remainder = input_str[idx:]
# If it's only a sign or just '.', return as-is with empty remainder
if not num_str or num_str in {"-", "."}:
return num_str, ""
try:
if num_str.endswith("."):
num = int(num_str[:-1])
else:
num = float(num_str) if any(c in num_str for c in ".eE") else int(num_str)
except ValueError:
raise decode_error
return num, remainder
def parse_true(self, input_str, decode_error):
"""Parse a 'true' value."""
if input_str.startswith(("t", "T")):
return True, input_str[4:]
raise decode_error
def parse_false(self, input_str, decode_error):
"""Parse a 'false' value."""
if input_str.startswith(("f", "F")):
return False, input_str[5:]
raise decode_error
def parse_null(self, input_str, decode_error):
"""Parse a 'null' value."""
if input_str.startswith("n"):
return None, input_str[4:]
raise decode_error

View File

@@ -0,0 +1,161 @@
import asyncio
from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast
from fastapi import APIRouter, Body, Depends, Header, HTTPException
from fastapi.responses import StreamingResponse
from openai.types.chat import ChatCompletionMessageParam
from openai.types.chat.completion_create_params import CompletionCreateParams
from letta.agent import Agent
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.log import get_logger
from letta.schemas.message import MessageCreate
from letta.schemas.openai.chat_completion_response import Message
from letta.schemas.user import User
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
# TODO this belongs in a controller!
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
if TYPE_CHECKING:
from letta.server.server import SyncServer
router = APIRouter(prefix="/v1", tags=["chat_completions"])
logger = get_logger(__name__)
@router.post(
"/chat/completions",
response_model=None,
operation_id="create_chat_completions",
responses={
200: {
"description": "Successful response",
"content": {
"text/event-stream": {"description": "Server-Sent Events stream"},
},
}
},
)
async def create_chat_completions(
completion_request: CompletionCreateParams = Body(...),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
# Validate and process fields
try:
messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
except KeyError:
# Handle the case where "messages" is not present in the request
raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
except TypeError:
# Handle the case where "messages" is not iterable
raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
except Exception as e:
# Catch any other unexpected errors and include the exception message
raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")
if messages[-1]["role"] != "user":
logger.error(f"The last message does not have a `user` role: {messages}")
raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")
input_message = messages[-1]
if not isinstance(input_message["content"], str):
logger.error(f"The input message does not have valid content: {input_message}")
raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")
# Process remaining fields
if not completion_request["stream"]:
raise HTTPException(status_code=400, detail="Must be streaming request: `stream` was set to `False` in the request.")
actor = server.user_manager.get_user_or_default(user_id=user_id)
agent_id = str(completion_request.get("user", None))
if agent_id is None:
error_msg = "Must pass agent_id in the 'user' field"
logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
llm_config = letta_agent.agent_state.llm_config
if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint:
error_msg = f"You can only use models with type 'openai' for chat completions. This agent {agent_id} has llm_config: \n{llm_config.model_dump_json(indent=4)}"
logger.error(error_msg)
raise HTTPException(status_code=400, detail=error_msg)
model = completion_request.get("model")
if model != llm_config.model:
warning_msg = f"The requested model {model} is different from the model specified in this agent's ({agent_id}) llm_config: \n{llm_config.model_dump_json(indent=4)}"
logger.warning(f"Defaulting to {llm_config.model}...")
logger.warning(warning_msg)
logger.info(f"Received input message: {input_message}")
return await send_message_to_agent_chat_completions(
server=server,
letta_agent=letta_agent,
actor=actor,
messages=[MessageCreate(role=input_message["role"], content=input_message["content"])],
)
async def send_message_to_agent_chat_completions(
server: "SyncServer",
letta_agent: Agent,
actor: User,
messages: Union[List[Message], List[MessageCreate]],
assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> StreamingResponse:
"""Split off into a separate function so that it can be imported in the /chat/completion proxy."""
# For streaming response
try:
# TODO: cleanup this logic
llm_config = letta_agent.agent_state.llm_config
# Create a new interface per request
letta_agent.interface = ChatCompletionsStreamingInterface()
streaming_interface = letta_agent.interface
if not isinstance(streaming_interface, ChatCompletionsStreamingInterface):
raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")
# Allow AssistantMessage is desired by client
streaming_interface.assistant_message_tool_name = assistant_message_tool_name
streaming_interface.assistant_message_tool_kwarg = assistant_message_tool_kwarg
# Related to JSON buffer reader
streaming_interface.inner_thoughts_in_kwargs = (
llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
)
# Offload the synchronous message_func to a separate thread
streaming_interface.stream_start()
asyncio.create_task(
asyncio.to_thread(
server.send_messages,
actor=actor,
agent_id=letta_agent.agent_state.id,
messages=messages,
interface=streaming_interface,
)
)
# return a stream
return StreamingResponse(
sse_async_generator(
streaming_interface.get_generator(),
usage_task=None,
finish_message=True,
),
media_type="text/event-stream",
)
except HTTPException:
raise
except Exception as e:
print(e)
import traceback
traceback.print_exc()
raise HTTPException(status_code=500, detail=f"{e}")

View File

@@ -62,6 +62,7 @@ from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
from letta.server.rest_api.interface import StreamingServerInterface
from letta.server.rest_api.utils import sse_async_generator
from letta.services.agent_manager import AgentManager
@@ -719,7 +720,7 @@ class SyncServer(Server):
# whether or not to wrap user and system message as MemGPT-style stringified JSON
wrap_user_message: bool = True,
wrap_system_message: bool = True,
interface: Union[AgentInterface, None] = None, # needed to getting responses
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed to getting responses
metadata: Optional[dict] = None, # Pass through metadata to interface
) -> LettaUsageStatistics:
"""Send a list of messages to the agent
@@ -735,7 +736,7 @@ class SyncServer(Server):
for message in messages:
assert isinstance(message, MessageCreate)
# If wrapping is eanbled, wrap with metadata before placing content inside the Message object
# If wrapping is enabled, wrap with metadata before placing content inside the Message object
if message.role == MessageRole.user and wrap_user_message:
message.content = system.package_user_message(user_message=message.content)
elif message.role == MessageRole.system and wrap_system_message:

View File

@@ -0,0 +1,105 @@
import os
import threading
import time
import uuid
import pytest
from dotenv import load_dotenv
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta import RESTClient, create_client
from letta.client.streaming import _sse_post
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_request import ChatCompletionRequest, UserMessage
from letta.schemas.usage import LettaUsageStatistics
def run_server():
load_dotenv()
# _reset_config()
from letta.server.rest_api.app import start_server
print("Starting server...")
start_server(debug=True)
@pytest.fixture(
scope="module",
)
def client():
# get URL from enviornment
server_url = os.getenv("LETTA_SERVER_URL")
if server_url is None:
# run server in thread
server_url = "http://localhost:8283"
print("Starting server thread")
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
# create user via admin client
client = create_client(base_url=server_url, token=None) # This yields control back to the test function
client.set_default_llm_config(LLMConfig.default_config("gpt-4o-mini"))
client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai"))
yield client
# Fixture for test agent
@pytest.fixture(scope="module")
def agent_state(client: RESTClient):
agent_state = client.create_agent(name=f"test_client_{str(uuid.uuid4())}")
yield agent_state
# delete agent
client.delete_agent(agent_state.id)
def test_voice_streaming(mock_e2b_api_key_none, client: RESTClient, agent_state: AgentState):
"""
Test voice streaming for chat completions using the streaming API.
This test ensures the SSE (Server-Sent Events) response from the voice streaming endpoint
adheres to the expected structure and contains valid data for each type of chunk.
"""
# Prepare the chat completion request with streaming enabled
request = ChatCompletionRequest(
model="gpt-4o-mini",
messages=[UserMessage(content="Tell me something interesting about bananas.")],
user=agent_state.id,
stream=True,
)
# Perform a POST request to the voice/chat/completions endpoint and collect the streaming response
response = _sse_post(
f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
)
# Convert the streaming response into a list of chunks for processing
chunks = list(response)
for idx, chunk in enumerate(chunks):
if isinstance(chunk, ChatCompletionChunk):
# Assert that the chunk has at least one choice (a response from the model)
assert len(chunk.choices) > 0, "Each ChatCompletionChunk should have at least one choice."
elif isinstance(chunk, LettaUsageStatistics):
# Assert that the usage statistics contain valid token counts
assert chunk.completion_tokens > 0, "Completion tokens should be greater than 0 in LettaUsageStatistics."
assert chunk.prompt_tokens > 0, "Prompt tokens should be greater than 0 in LettaUsageStatistics."
assert chunk.total_tokens > 0, "Total tokens should be greater than 0 in LettaUsageStatistics."
assert chunk.step_count == 1, "Step count in LettaUsageStatistics should always be 1 for a single request."
elif isinstance(chunk, MessageStreamStatus):
# Assert that the stream ends with a 'done' status
assert chunk == MessageStreamStatus.done, "The last chunk should indicate the stream has completed."
assert idx == len(chunks) - 1, "The 'done' status must be the last chunk in the stream."
else:
# Fail the test if an unexpected chunk type is encountered
pytest.fail(f"Unexpected chunk type: {chunk}", pytrace=True)

View File

@@ -0,0 +1,248 @@
import json
from unittest.mock import patch
import pytest
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
@pytest.fixture
def strict_parser():
"""Provides a fresh OptimisticJSONParser instance in strict mode."""
return OptimisticJSONParser(strict=True)
@pytest.fixture
def lenient_parser():
"""Provides a fresh OptimisticJSONParser instance in non-strict mode."""
return OptimisticJSONParser(strict=False)
def test_parse_empty_input(strict_parser):
"""
Test parsing an empty string. Should fall back to parsing "{}".
"""
result = strict_parser.parse("")
assert result == {}, "Empty input should parse as an empty dict."
def test_parse_valid_json(strict_parser):
"""
Test parsing a valid JSON string using the standard json.loads logic.
"""
input_str = '{"name": "John", "age": 30}'
result = strict_parser.parse(input_str)
assert result == {"name": "John", "age": 30}, "Should parse valid JSON correctly."
def test_parse_valid_json_array(strict_parser):
"""
Test parsing a valid JSON array.
"""
input_str = '[1, 2, 3, "four"]'
result = strict_parser.parse(input_str)
assert result == [1, 2, 3, "four"], "Should parse valid JSON array correctly."
def test_parse_partial_json_object(strict_parser):
"""
Test parsing a JSON object with extra trailing characters.
The extra characters should trigger on_extra_token.
"""
input_str = '{"key": "value"} trailing'
with patch.object(strict_parser, "on_extra_token") as mock_callback:
result = strict_parser.parse(input_str)
assert result == {"key": "value"}, "Should parse the JSON part properly."
assert strict_parser.last_parse_reminding.strip() == "trailing", "The leftover reminding should be 'trailing'."
mock_callback.assert_called_once()
def test_parse_partial_json_array(strict_parser):
"""
Test parsing a JSON array with extra tokens.
"""
input_str = "[1, 2, 3] extra_tokens"
result = strict_parser.parse(input_str)
assert result == [1, 2, 3], "Should parse array portion properly."
assert strict_parser.last_parse_reminding.strip() == "extra_tokens", "The leftover reminding should capture extra tokens."
def test_parse_number_cases(strict_parser):
"""
Test various number formats.
"""
# We'll parse them individually to ensure the fallback parser handles them.
test_cases = {
"123": 123,
"-42": -42,
"3.14": 3.14,
"-0.001": -0.001,
"10.": 10, # This should convert to int in our parser.
".5": 0.5 if not strict_parser.strict else ".5",
}
for num_str, expected in test_cases.items():
parsed = strict_parser.parse(num_str)
if num_str == ".5" and strict_parser.strict:
# Strict mode won't parse ".5" directly as a valid float by default
# Our current logic may end up raising or partial-parsing.
# Adjust as necessary based on your actual parser's behavior.
assert parsed == ".5" or parsed == 0.5, "Strict handling of '.5' can vary."
else:
assert parsed == expected, f"Number parsing failed for {num_str}"
def test_parse_boolean_true(strict_parser):
assert strict_parser.parse("true") is True, "Should parse 'true'."
# Check leftover
assert strict_parser.last_parse_reminding == "", "No extra tokens expected."
def test_parse_boolean_false(strict_parser):
assert strict_parser.parse("false") is False, "Should parse 'false'."
def test_parse_null(strict_parser):
assert strict_parser.parse("null") is None, "Should parse 'null'."
@pytest.mark.parametrize("invalid_boolean", ["tru", "fa", "fal", "True", "False"])
def test_parse_invalid_booleans(strict_parser, invalid_boolean):
"""
Test some invalid booleans. The parser tries to parse them as partial if possible.
If it fails, it may raise an exception or parse partially based on the code.
"""
try:
result = strict_parser.parse(invalid_boolean)
# If it doesn't raise, it might parse partially or incorrectly.
# Check leftover or the returned data.
# Adjust your assertions based on actual parser behavior.
assert result in [True, False, invalid_boolean], f"Unexpected parse result for {invalid_boolean}: {result}"
except json.JSONDecodeError:
# This is also a valid outcome for truly invalid strings in strict mode.
pass
def test_parse_string_with_escapes(strict_parser):
"""
Test a string containing escaped quotes.
"""
input_str = r'"This is a \"test\" string"'
result = strict_parser.parse(input_str)
assert result == 'This is a "test" string', "String with escaped quotes should parse correctly."
def test_parse_incomplete_string_strict(strict_parser):
"""
Test how a strict parser handles an incomplete string.
"""
input_str = '"Unfinished string with no end'
try:
strict_parser.parse(input_str)
pytest.fail("Expected an error or partial parse with leftover tokens in strict mode.")
except json.JSONDecodeError:
pass # Strict mode might raise
def test_parse_incomplete_string_lenient(lenient_parser):
"""
In non-strict mode, incomplete strings may be returned as-is.
"""
input_str = '"Unfinished string with no end'
result = lenient_parser.parse(input_str)
assert result == "Unfinished string with no end", "Lenient mode should return the incomplete string without quotes."
def test_parse_incomplete_number_strict(strict_parser):
"""
Test how a strict parser handles an incomplete number, like '-' or '.'.
In strict mode, the parser now raises JSONDecodeError rather than
returning the partial string.
"""
input_str = "-"
with pytest.raises(json.JSONDecodeError):
strict_parser.parse(input_str)
def test_object_with_missing_colon(strict_parser):
"""
Test parsing an object missing a colon. Should raise or partially parse.
"""
input_str = '{"key" "value"}'
try:
strict_parser.parse(input_str)
pytest.fail("Parser should raise or handle error with missing colon.")
except json.JSONDecodeError:
pass
def test_object_with_missing_value(strict_parser):
"""
Test parsing an object with a key but no value before a comma or brace.
"""
input_str = '{"key":}'
# Depending on parser logic, "key" might map to None or raise an error.
result = strict_parser.parse(input_str)
# Expect partial parse: {'key': None}
assert result == {"key": None}, "Key without value should map to None."
def test_array_with_trailing_comma(strict_parser):
"""
Test array that might have a trailing comma before closing.
"""
input_str = "[1, 2, 3, ]"
result = strict_parser.parse(input_str)
# The parser does not explicitly handle trailing commas in strict JSON.
# But the fallback logic may allow partial parse. Adjust assertions accordingly.
assert result == [1, 2, 3], "Trailing comma should be handled or partially parsed."
def test_callback_invocation(strict_parser, capsys):
"""
Verify that on_extra_token callback is invoked and prints expected content.
"""
input_str = '{"a":1} leftover'
strict_parser.parse(input_str)
captured = capsys.readouterr().out
assert "Parsed JSON with extra tokens:" in captured, "Callback default_on_extra_token should print a message."
def test_unknown_token(strict_parser):
"""
Test parser behavior when encountering an unknown first character.
Should raise JSONDecodeError in strict mode.
"""
input_str = "@invalid"
with pytest.raises(json.JSONDecodeError):
strict_parser.parse(input_str)
def test_array_nested_objects(lenient_parser):
"""
Test parsing a complex structure with nested arrays/objects.
"""
input_str = '[ {"a":1}, {"b": [2,3]}, 4, "string"] leftover'
result = lenient_parser.parse(input_str)
expected = [{"a": 1}, {"b": [2, 3]}, 4, "string"]
assert result == expected, "Should parse nested arrays/objects correctly."
assert lenient_parser.last_parse_reminding.strip() == "leftover"
def test_multiple_parse_calls(strict_parser):
"""
Test calling parse() multiple times to ensure leftover is reset properly.
"""
input_1 = '{"x":1} trailing1'
input_2 = "[2,3] trailing2"
# First parse
result_1 = strict_parser.parse(input_1)
assert result_1 == {"x": 1}
assert strict_parser.last_parse_reminding.strip() == "trailing1"
# Second parse
result_2 = strict_parser.parse(input_2)
assert result_2 == [2, 3]
assert strict_parser.last_parse_reminding.strip() == "trailing2"

View File

@@ -914,7 +914,7 @@ def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, b
# create agent
agent_state = server.create_agent(
request=CreateAgent(
name="memory_rebuild_test_agent",
name="test_memory_rebuild_count",
tool_ids=[t.id for t in base_tools + base_memory_tools],
memory_blocks=[
CreateBlock(label="human", value="The human's name is Bob."),
@@ -952,18 +952,11 @@ def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, b
num_system_messages, all_messages = count_system_messages_in_recall()
assert num_system_messages == 1, (num_system_messages, all_messages)
# Assuming core memory append actually ran correctly, at this point there should be 2 messages
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
# At this stage, there should be 2 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()
assert num_system_messages == 2, (num_system_messages, all_messages)
# Run server.load_agent, and make sure that the number of system messages is still 2
server.load_agent(agent_id=agent_state.id, actor=actor)
num_system_messages, all_messages = count_system_messages_in_recall()
assert num_system_messages == 2, (num_system_messages, all_messages)
assert num_system_messages == 1, (num_system_messages, all_messages)
finally:
# cleanup