feat: Add voice-compatible chat completions endpoint (#774)
This commit is contained in:
@@ -12,7 +12,7 @@ from starlette.middleware.base import BaseHTTPMiddleware
|
||||
from starlette.middleware.cors import CORSMiddleware
|
||||
|
||||
from letta.__init__ import __version__
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
@@ -22,6 +22,7 @@ from letta.server.constants import REST_DEFAULT_PORT
|
||||
# NOTE(charles): these are extra routes that are not part of v1 but we still need to mount to pass tests
|
||||
from letta.server.rest_api.auth.index import setup_auth_router # TODO: probably remove right?
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
from letta.server.rest_api.routers.openai.chat_completions.chat_completions import router as openai_chat_completions_router
|
||||
|
||||
# from letta.orm.utilities import get_db_session # TODO(ethan) reenable once we merge ORM
|
||||
from letta.server.rest_api.routers.v1 import ROUTERS as v1_routes
|
||||
@@ -241,6 +242,9 @@ def create_application() -> "FastAPI":
|
||||
app.include_router(users_router, prefix=ADMIN_PREFIX)
|
||||
app.include_router(organizations_router, prefix=ADMIN_PREFIX)
|
||||
|
||||
# openai
|
||||
app.include_router(openai_chat_completions_router, prefix=OPENAI_API_PREFIX)
|
||||
|
||||
# /api/auth endpoints
|
||||
app.include_router(setup_auth_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
|
||||
256
letta/server/rest_api/chat_completions_interface.py
Normal file
256
letta/server/rest_api/chat_completions_interface.py
Normal file
@@ -0,0 +1,256 @@
|
||||
import asyncio
|
||||
from collections import deque
|
||||
from datetime import datetime
|
||||
from typing import AsyncGenerator, Optional, Union
|
||||
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
|
||||
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
|
||||
from letta.streaming_interface import AgentChunkStreamingInterface
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
    """
    Provides an asynchronous streaming mechanism for LLM output. Internally
    maintains a queue of chunks that can be consumed via an async generator.

    Key Behaviors:
    - process_chunk: Accepts ChatCompletionChunkResponse objects (e.g. from an
      OpenAI-like streaming API), potentially transforms them to a partial
      text response, and enqueues them.
    - get_generator: Returns an async generator that yields messages or status
      markers as they become available.
    - step_complete, step_yield: End streaming for the current step or entirely,
      depending on the multi_step setting.
    - function_message, internal_monologue: Handle LLM "function calls" and
      "reasoning" messages for non-streaming contexts.
    """

    # Finish reason reported to clients when a terminal chunk is emitted.
    FINISH_REASON_STR = "stop"
    # Role attached to partial content surfaced from tool-call arguments.
    ASSISTANT_STR = "assistant"

    def __init__(
        self,
        multi_step: bool = True,
        timeout: int = 150,
        # The following are placeholders for potential expansions; they
        # remain if you need to differentiate between actual "assistant messages"
        # vs. tool calls. By default, they are set for the "send_message" tool usage.
        assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
        assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
        inner_thoughts_in_kwargs: bool = True,
        inner_thoughts_kwarg: str = INNER_THOUGHTS_KWARG,
    ):
        """
        Args:
            multi_step: If True, the stream stays open across multiple agent
                steps; if False, ``step_complete`` closes the stream.
            timeout: Seconds the generator waits for a new chunk before
                terminating the stream.
            assistant_message_tool_name: Tool whose arguments are surfaced as
                plain assistant text (defaults to "send_message").
            assistant_message_tool_kwarg: Argument key inside that tool whose
                value is streamed back as content.
            inner_thoughts_in_kwargs: Whether inner thoughts arrive as a kwarg
                of the tool call (placeholder; not consumed in this class).
            inner_thoughts_kwarg: Name of that kwarg.
        """
        self.streaming_mode = True

        # Parsing state for incremental function-call data
        self.current_function_name = ""
        self.current_function_arguments = []

        # Internal chunk buffer and event for async notification
        self._chunks = deque()
        self._event = asyncio.Event()
        self._active = True

        # Whether or not the stream should remain open across multiple steps
        self.multi_step = multi_step

        # Timing / debug parameters
        self.timeout = timeout

        # These are placeholders to handle specialized
        # assistant message logic or storing inner thoughts.
        self.assistant_message_tool_name = assistant_message_tool_name
        self.assistant_message_tool_kwarg = assistant_message_tool_kwarg
        self.inner_thoughts_in_kwargs = inner_thoughts_in_kwargs
        self.inner_thoughts_kwarg = inner_thoughts_kwarg

    async def _create_generator(
        self,
    ) -> AsyncGenerator[Union[LettaMessage, MessageStreamStatus], None]:
        """
        An asynchronous generator that yields queued items as they arrive.
        Ends when _active is set to False or when timing out.
        """
        while self._active:
            try:
                # Block until a producer signals new data (or the stream closes).
                await asyncio.wait_for(self._event.wait(), timeout=self.timeout)
            except asyncio.TimeoutError:
                # No data within `timeout` seconds: terminate the stream.
                break

            # Drain everything that was queued before clearing the event.
            while self._chunks:
                yield self._chunks.popleft()

            self._event.clear()

    def get_generator(self) -> AsyncGenerator:
        """
        Provide the async generator interface. Will raise StopIteration
        if the stream is inactive.
        """
        if not self._active:
            # NOTE(review): raising StopIteration outside a generator is unusual;
            # callers appear to rely on it as an "inactive stream" signal.
            raise StopIteration("The stream is not active.")
        return self._create_generator()

    def _push_to_buffer(
        self,
        item: ChatCompletionChunk,
    ):
        """
        Add an item (a LettaMessage, status marker, or partial chunk)
        to the queue and signal waiting consumers.

        Raises:
            RuntimeError: If the stream has already been closed.
        """
        if not self._active:
            raise RuntimeError("Attempted to push to an inactive stream.")
        self._chunks.append(item)
        self._event.set()

    def stream_start(self) -> None:
        """Initialize or reset the streaming state for a new request."""
        self._active = True
        self._chunks.clear()
        self._event.clear()
        self._reset_parsing_state()

    def stream_end(self) -> None:
        """
        Clean up after the current streaming session. Typically called when the
        request is done or the data source has signaled it has no more data.
        """
        self._reset_parsing_state()

    def step_complete(self) -> None:
        """
        Indicate that one step of multi-step generation is done.
        If multi_step=False, the stream is closed immediately.
        """
        if not self.multi_step:
            self._active = False
            self._event.set()  # Ensure waiting generators can finalize
        self._reset_parsing_state()

    def step_yield(self) -> None:
        """
        Explicitly end the stream in a multi-step scenario, typically
        called when the entire chain of steps is complete.
        """
        self._active = False
        self._event.set()

    @staticmethod
    def clear() -> None:
        """No-op retained for interface compatibility."""
        return

    def process_chunk(self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime) -> None:
        """
        Called externally with a ChatCompletionChunkResponse. Transforms
        it if necessary, then enqueues partial messages for streaming back.

        Args:
            chunk: The inbound provider chunk to transform and enqueue.
            message_id: Identifier of the message (unused here; part of the
                interface contract).
            message_date: Timestamp of the message (unused here).
        """
        processed_chunk = self._process_chunk_to_openai_style(chunk)
        if processed_chunk is not None:
            self._push_to_buffer(processed_chunk)

    def user_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle user messages. Here, it's a no-op, but included if your
        pipeline needs to respond to user messages distinctly.
        """
        return

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle LLM reasoning or internal monologue. Example usage: if you want
        to capture chain-of-thought for debugging in a non-streaming scenario.
        """
        return

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle direct assistant messages. This class primarily handles them
        as function calls, so it's a no-op by default.
        """
        return

    def function_message(self, msg: str, msg_obj: Optional[Message] = None) -> None:
        """
        Handle function-related log messages, typically of the form:
        It's a no-op by default.
        """
        return

    def _process_chunk_to_openai_style(self, chunk: ChatCompletionChunkResponse) -> Optional[ChatCompletionChunk]:
        """
        Optionally transform an inbound OpenAI-style chunk so that partial
        content (especially from a 'send_message' tool) is exposed as text
        deltas in 'content'. Otherwise, pass through or yield finish reasons.

        Returns:
            A ChatCompletionChunk to stream to the client, or None when this
            chunk should be suppressed (e.g. a non-assistant tool call).
        """
        choice = chunk.choices[0]
        delta = choice.delta

        # If there's direct content, we usually let it stream as-is
        if delta.content is not None:
            # TODO: Eventually use all of the native OpenAI objects
            return ChatCompletionChunk(**chunk.model_dump(exclude_none=True))

        # If there's a function call, accumulate its name/args. If it's a known
        # text-producing function (like send_message), stream partial text.
        if delta.tool_calls:
            tool_call = delta.tool_calls[0]
            if tool_call.function.name:
                self.current_function_name += tool_call.function.name
            if tool_call.function.arguments:
                self.current_function_arguments.append(tool_call.function.arguments)

            # Only parse arguments for "send_message" to stream partial text
            if self.current_function_name.strip() == self.assistant_message_tool_name:
                combined_args = "".join(self.current_function_arguments)
                parsed_args = OptimisticJSONParser().parse(combined_args)

                # If we can see a "message" field, return it as partial content
                if self.assistant_message_tool_kwarg in parsed_args and parsed_args[self.assistant_message_tool_kwarg]:
                    # NOTE(review): this streams the latest raw JSON argument
                    # fragment, not the parsed delta of the message value —
                    # confirm this matches consumer expectations.
                    return ChatCompletionChunk(
                        id=chunk.id,
                        object=chunk.object,
                        created=chunk.created.timestamp(),
                        model=chunk.model,
                        choices=[
                            Choice(
                                index=choice.index,
                                delta=ChoiceDelta(content=self.current_function_arguments[-1], role=self.ASSISTANT_STR),
                                finish_reason=None,
                            )
                        ],
                    )

        # If there's a finish reason, pass that along
        if choice.finish_reason is not None:
            return ChatCompletionChunk(
                id=chunk.id,
                object=chunk.object,
                created=chunk.created.timestamp(),
                model=chunk.model,
                choices=[
                    Choice(
                        index=choice.index,
                        delta=ChoiceDelta(),
                        finish_reason=self.FINISH_REASON_STR,
                    )
                ],
            )

        return None

    def _reset_parsing_state(self) -> None:
        """Clears internal buffers for function call name/args."""
        self.current_function_name = ""
        self.current_function_arguments = []
||||
185
letta/server/rest_api/optimistic_json_parser.py
Normal file
185
letta/server/rest_api/optimistic_json_parser.py
Normal file
@@ -0,0 +1,185 @@
|
||||
import json
|
||||
|
||||
|
||||
class OptimisticJSONParser:
    """
    A JSON parser that attempts to parse a given string using `json.loads`,
    and if that fails, it parses as much valid JSON as possible while
    allowing extra tokens to remain. Those extra tokens can be retrieved
    from `self.last_parse_reminding`. If `strict` is False, the parser
    tries to tolerate incomplete strings and incomplete numbers.
    """

    def __init__(self, strict=True):
        # strict=True round-trips fragments through json.loads; strict=False
        # returns best-effort raw slices for incomplete strings.
        self.strict = strict
        # Dispatch table: first character of the remaining input selects the parser.
        self.parsers = {
            " ": self.parse_space,
            "\r": self.parse_space,
            "\n": self.parse_space,
            "\t": self.parse_space,
            "[": self.parse_array,
            "{": self.parse_object,
            '"': self.parse_string,
            "t": self.parse_true,
            "f": self.parse_false,
            "n": self.parse_null,
        }
        # Register number parser for digits and signs
        for char in "0123456789.-":
            self.parsers[char] = self.parse_number

        # Leftover (unparsed) text from the most recent parse() call.
        self.last_parse_reminding = None
        self.on_extra_token = self.default_on_extra_token

    def default_on_extra_token(self, text, data, reminding):
        """Default extra-token callback: ignore leftovers silently."""
        pass

    def parse(self, input_str):
        """
        Try to parse the entire `input_str` as JSON. If parsing fails,
        attempts a partial parse, storing leftover text in
        `self.last_parse_reminding`. A callback (`on_extra_token`) is
        triggered if extra tokens remain.
        """
        if len(input_str) >= 1:
            try:
                return json.loads(input_str)
            except json.JSONDecodeError as decode_error:
                data, reminding = self.parse_any(input_str, decode_error)
                self.last_parse_reminding = reminding
                if self.on_extra_token and reminding:
                    self.on_extra_token(input_str, data, reminding)
                return data
        else:
            # Empty input parses as an empty object rather than raising.
            return json.loads("{}")

    def parse_any(self, input_str, decode_error):
        """Determine which parser to use based on the first character."""
        if not input_str:
            raise decode_error
        parser = self.parsers.get(input_str[0])
        if parser is None:
            raise decode_error
        return parser(input_str, decode_error)

    def parse_space(self, input_str, decode_error):
        """Strip leading whitespace and parse again."""
        return self.parse_any(input_str.strip(), decode_error)

    def parse_array(self, input_str, decode_error):
        """Parse a JSON array, returning the list and remaining string."""
        # Skip the '['
        input_str = input_str[1:]
        array_values = []
        input_str = input_str.strip()
        while input_str:
            if input_str[0] == "]":
                # Skip the ']'
                input_str = input_str[1:]
                break
            value, input_str = self.parse_any(input_str, decode_error)
            array_values.append(value)
            input_str = input_str.strip()
            if input_str.startswith(","):
                # Skip the ','
                input_str = input_str[1:].strip()
        return array_values, input_str

    def parse_object(self, input_str, decode_error):
        """Parse a JSON object, returning the dict and remaining string."""
        # Skip the '{'
        input_str = input_str[1:]
        obj = {}
        input_str = input_str.strip()
        while input_str:
            if input_str[0] == "}":
                # Skip the '}'
                input_str = input_str[1:]
                break
            key, input_str = self.parse_any(input_str, decode_error)
            input_str = input_str.strip()

            # Truncated after the key: record a None value and stop.
            if not input_str or input_str[0] == "}":
                obj[key] = None
                break
            if input_str[0] != ":":
                raise decode_error

            # Skip ':'
            input_str = input_str[1:].strip()
            # Truncated after the colon: value is unknown, record None.
            if not input_str or input_str[0] in ",}":
                obj[key] = None
                if input_str.startswith(","):
                    input_str = input_str[1:]
                break

            value, input_str = self.parse_any(input_str, decode_error)
            obj[key] = value
            input_str = input_str.strip()
            if input_str.startswith(","):
                # Skip the ','
                input_str = input_str[1:].strip()
        return obj, input_str

    def parse_string(self, input_str, decode_error):
        """Parse a JSON string, respecting escaped quotes if present."""
        end = input_str.find('"', 1)
        # Advance past escaped quotes (preceded by a backslash).
        while end != -1 and input_str[end - 1] == "\\":
            end = input_str.find('"', end + 1)

        if end == -1:
            # Incomplete string
            if not self.strict:
                return input_str[1:], ""
            return json.loads(f'"{input_str[1:]}"'), ""

        str_val = input_str[: end + 1]
        input_str = input_str[end + 1 :]
        if not self.strict:
            return str_val[1:-1], input_str
        return json.loads(str_val), input_str

    def parse_number(self, input_str, decode_error):
        """
        Parse a number (int or float). Allows digits, '.', '-', but
        doesn't fully validate complex exponents unless they appear
        before a non-number character.
        """
        idx = 0
        while idx < len(input_str) and input_str[idx] in "0123456789.-":
            idx += 1

        num_str = input_str[:idx]
        remainder = input_str[idx:]

        # If it's only a sign or just '.', return as-is with empty remainder
        if not num_str or num_str in {"-", "."}:
            return num_str, ""

        try:
            if num_str.endswith("."):
                # Trailing dot means a truncated float; treat the integer part.
                num = int(num_str[:-1])
            else:
                num = float(num_str) if any(c in num_str for c in ".eE") else int(num_str)
        except ValueError:
            raise decode_error

        return num, remainder

    def parse_true(self, input_str, decode_error):
        """Parse a 'true' value."""
        if input_str.startswith(("t", "T")):
            return True, input_str[4:]
        raise decode_error

    def parse_false(self, input_str, decode_error):
        """Parse a 'false' value."""
        if input_str.startswith(("f", "F")):
            return False, input_str[5:]
        raise decode_error

    def parse_null(self, input_str, decode_error):
        """Parse a 'null' value."""
        if input_str.startswith("n"):
            return None, input_str[4:]
        raise decode_error
|
||||
@@ -0,0 +1,161 @@
|
||||
import asyncio
|
||||
from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast
|
||||
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException
|
||||
from fastapi.responses import StreamingResponse
|
||||
from openai.types.chat import ChatCompletionMessageParam
|
||||
from openai.types.chat.completion_create_params import CompletionCreateParams
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import Message
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface
|
||||
|
||||
# TODO this belongs in a controller!
|
||||
from letta.server.rest_api.utils import get_letta_server, sse_async_generator
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.server.server import SyncServer
|
||||
|
||||
router = APIRouter(prefix="/v1", tags=["chat_completions"])
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
@router.post(
    "/chat/completions",
    response_model=None,
    operation_id="create_chat_completions",
    responses={
        200: {
            "description": "Successful response",
            "content": {
                "text/event-stream": {"description": "Server-Sent Events stream"},
            },
        }
    },
)
async def create_chat_completions(
    completion_request: CompletionCreateParams = Body(...),
    server: "SyncServer" = Depends(get_letta_server),
    user_id: Optional[str] = Header(None, alias="user_id"),
):
    """
    Voice-compatible, OpenAI-style chat completions endpoint.

    Validates the request (last message must be a user message with string
    content, streaming must be enabled, agent_id must be supplied in the
    'user' field), then proxies the message to the agent as an SSE stream.

    Raises:
        HTTPException: 400 for any validation failure.
    """
    # Validate and process fields
    try:
        messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
    except KeyError:
        # Handle the case where "messages" is not present in the request
        raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
    except TypeError:
        # Handle the case where "messages" is not iterable
        raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
    except Exception as e:
        # Catch any other unexpected errors and include the exception message
        raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")

    # Guard against an empty list; messages[-1] below would raise IndexError (500).
    if not messages:
        raise HTTPException(status_code=400, detail="'messages' must contain at least one message.")

    if messages[-1]["role"] != "user":
        logger.error(f"The last message does not have a `user` role: {messages}")
        raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")

    input_message = messages[-1]
    if not isinstance(input_message["content"], str):
        logger.error(f"The input message does not have valid content: {input_message}")
        raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")

    # Process remaining fields. Use .get() so a missing 'stream' key yields the
    # same 400 as stream=False instead of a KeyError (500).
    if not completion_request.get("stream"):
        raise HTTPException(status_code=400, detail="Must be streaming request: `stream` was set to `False` in the request.")

    actor = server.user_manager.get_user_or_default(user_id=user_id)

    # BUG FIX: previously `str(completion_request.get("user", None))` turned a
    # missing field into the string "None", so the None check below never fired.
    agent_id = completion_request.get("user")
    if agent_id is None:
        error_msg = "Must pass agent_id in the 'user' field"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)
    agent_id = str(agent_id)

    letta_agent = server.load_agent(agent_id=agent_id, actor=actor)
    llm_config = letta_agent.agent_state.llm_config
    if llm_config.model_endpoint_type != "openai" or "inference.memgpt.ai" in llm_config.model_endpoint:
        error_msg = f"You can only use models with type 'openai' for chat completions. This agent {agent_id} has llm_config: \n{llm_config.model_dump_json(indent=4)}"
        logger.error(error_msg)
        raise HTTPException(status_code=400, detail=error_msg)

    # A mismatched model is only a warning; the agent's configured model wins.
    model = completion_request.get("model")
    if model != llm_config.model:
        warning_msg = f"The requested model {model} is different from the model specified in this agent's ({agent_id}) llm_config: \n{llm_config.model_dump_json(indent=4)}"
        logger.warning(f"Defaulting to {llm_config.model}...")
        logger.warning(warning_msg)

    logger.info(f"Received input message: {input_message}")

    return await send_message_to_agent_chat_completions(
        server=server,
        letta_agent=letta_agent,
        actor=actor,
        messages=[MessageCreate(role=input_message["role"], content=input_message["content"])],
    )
|
||||
|
||||
|
||||
async def send_message_to_agent_chat_completions(
    server: "SyncServer",
    letta_agent: Agent,
    actor: User,
    messages: Union[List[Message], List[MessageCreate]],
    assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
    assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
) -> StreamingResponse:
    """Split off into a separate function so that it can be imported in the /chat/completion proxy.

    Sends `messages` to the agent on a worker thread while the request
    coroutine returns an SSE StreamingResponse fed by the agent's
    ChatCompletionsStreamingInterface.

    Raises:
        HTTPException: 500 on any unexpected failure while setting up the stream.
    """
    # For streaming response
    try:
        # TODO: cleanup this logic
        llm_config = letta_agent.agent_state.llm_config

        # Create a new interface per request
        letta_agent.interface = ChatCompletionsStreamingInterface()
        streaming_interface = letta_agent.interface
        if not isinstance(streaming_interface, ChatCompletionsStreamingInterface):
            raise ValueError(f"Agent has wrong type of interface: {type(streaming_interface)}")

        # Allow AssistantMessage if desired by client
        streaming_interface.assistant_message_tool_name = assistant_message_tool_name
        streaming_interface.assistant_message_tool_kwarg = assistant_message_tool_kwarg

        # Related to JSON buffer reader
        streaming_interface.inner_thoughts_in_kwargs = (
            llm_config.put_inner_thoughts_in_kwargs if llm_config.put_inner_thoughts_in_kwargs is not None else False
        )

        # Offload the synchronous message_func to a separate thread
        streaming_interface.stream_start()
        # Keep a reference: a bare create_task result may be garbage-collected
        # before completion, silently cancelling the send.
        send_task = asyncio.create_task(  # noqa: F841
            asyncio.to_thread(
                server.send_messages,
                actor=actor,
                agent_id=letta_agent.agent_state.id,
                messages=messages,
                interface=streaming_interface,
            )
        )

        # return a stream
        return StreamingResponse(
            sse_async_generator(
                streaming_interface.get_generator(),
                usage_task=None,
                finish_message=True,
            ),
            media_type="text/event-stream",
        )

    except HTTPException:
        raise
    except Exception as e:
        # Log the full traceback via the module logger instead of printing
        # to stdout (previously: print(e) + traceback.print_exc()).
        logger.exception(f"Error setting up chat completions stream: {e}")
        raise HTTPException(status_code=500, detail=f"{e}")
||||
Reference in New Issue
Block a user