diff --git a/letta/agent.py b/letta/agent.py
index 15d10de2..61d69343 100644
--- a/letta/agent.py
+++ b/letta/agent.py
@@ -750,7 +750,6 @@ class Agent(BaseAgent):
         put_inner_thoughts_first: bool = True,
     ) -> AgentStepResponse:
         """Runs a single step in the agent loop (generates at most one LLM call)"""
-
         try:
             # Extract job_id from metadata if present
diff --git a/letta/server/rest_api/chat_completions_interface.py b/letta/server/rest_api/chat_completions_interface.py
index 753f6460..6a1190a6 100644
--- a/letta/server/rest_api/chat_completions_interface.py
+++ b/letta/server/rest_api/chat_completions_interface.py
@@ -161,10 +161,7 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
         Called externally with a ChatCompletionChunkResponse. Transforms
         it if necessary, then enqueues partial messages for streaming back.
         """
-        # print(chunk)
         processed_chunk = self._process_chunk_to_openai_style(chunk)
-        # print(processed_chunk)
-        # print("++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++++")
         if processed_chunk is not None:
             self._push_to_buffer(processed_chunk)
diff --git a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py
index 053ba004..13fd2347 100644
--- a/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py
+++ b/letta/server/rest_api/routers/openai/chat_completions/chat_completions.py
@@ -1,21 +1,30 @@
 import asyncio
-from typing import TYPE_CHECKING, Iterable, List, Optional, Union, cast
+from typing import TYPE_CHECKING, List, Optional, Union

+import httpx
+import openai
 from fastapi import APIRouter, Body, Depends, Header, HTTPException
 from fastapi.responses import StreamingResponse
-from openai.types.chat import ChatCompletionMessageParam
 from openai.types.chat.completion_create_params import CompletionCreateParams
+from starlette.concurrency import run_in_threadpool

 from letta.agent import Agent
 from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.log import get_logger
-from letta.schemas.message import MessageCreate
-from letta.schemas.openai.chat_completion_response import Message
+from letta.schemas.message import Message, MessageCreate
 from letta.schemas.user import User
 from letta.server.rest_api.chat_completions_interface import ChatCompletionsStreamingInterface

 # TODO this belongs in a controller!
-from letta.server.rest_api.utils import get_letta_server, sse_async_generator
+from letta.server.rest_api.utils import (
+    convert_letta_messages_to_openai,
+    create_assistant_message_from_openai_response,
+    create_user_message,
+    get_letta_server,
+    get_messages_from_completion_request,
+    sse_async_generator,
+)
+from letta.settings import model_settings

 if TYPE_CHECKING:
     from letta.server.server import SyncServer

@@ -25,6 +34,88 @@ router = APIRouter(prefix="/v1", tags=["chat_completions"])

 logger = get_logger(__name__)

+@router.post(
+    "/fast/chat/completions",
+    response_model=None,
+    operation_id="create_fast_chat_completions",
+    responses={
+        200: {
+            "description": "Successful response",
+            "content": {
+                "text/event-stream": {"description": "Server-Sent Events stream"},
+            },
+        }
+    },
+)
+async def create_fast_chat_completions(
+    completion_request: CompletionCreateParams = Body(...),
+    server: "SyncServer" = Depends(get_letta_server),
+    user_id: Optional[str] = Header(None, alias="user_id"),
+):
+    # TODO: This is necessary for now; we need to factor out CompletionCreateParams due to its weird behavior
+    agent_id = completion_request.get("user")
+    if agent_id is None:
+        error_msg = "Must pass agent_id in the 'user' field"
+        logger.error(error_msg)
+        raise HTTPException(status_code=400, detail=error_msg)
+    model = completion_request.get("model")
+
+    actor = server.user_manager.get_user_or_default(user_id=user_id)
+    client = openai.AsyncClient(
+        api_key=model_settings.openai_api_key,
+        max_retries=0,
+        http_client=httpx.AsyncClient(
+            timeout=httpx.Timeout(connect=15.0, read=5.0, write=5.0, pool=5.0),
+            follow_redirects=True,
+            limits=httpx.Limits(
+                max_connections=50,
+                max_keepalive_connections=50,
+                keepalive_expiry=120,
+            ),
+        ),
+    )
+
+    # Pull the latest user message out of the request; the rest of the request is forwarded as-is
+    input_message = get_messages_from_completion_request(completion_request)[-1]
+    completion_request.pop("messages")
+
+    # Get the agent's in-context messages and convert them to OpenAI format
+    in_context_messages = server.agent_manager.get_in_context_messages(agent_id=agent_id, actor=actor)
+    openai_dict_in_context_messages = convert_letta_messages_to_openai(in_context_messages)
+    openai_dict_in_context_messages.append(input_message)
+
+    async def event_stream():
+        # TODO: Factor this out into a separate interface
+        response_accumulator = []
+
+        stream = await client.chat.completions.create(**completion_request, messages=openai_dict_in_context_messages)
+
+        async with stream:
+            async for chunk in stream:
+                if chunk.choices and chunk.choices[0].delta.content:
+                    # TODO: This does not support tool calling right now
+                    response_accumulator.append(chunk.choices[0].delta.content)
+                yield f"data: {chunk.model_dump_json()}\n\n"
+
+        # Construct messages
+        user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor)
+        assistant_message = create_assistant_message_from_openai_response(
+            response_text="".join(response_accumulator), agent_id=agent_id, model=str(model), actor=actor
+        )
+
+        # Persist both in one synchronous DB call, run in a threadpool to avoid blocking the event loop
+        await run_in_threadpool(
+            server.agent_manager.append_to_in_context_messages,
+            [user_message, assistant_message],
+            agent_id=agent_id,
+            actor=actor,
+        )
+
+        yield "data: [DONE]\n\n"
+
+    return StreamingResponse(event_stream(), media_type="text/event-stream")
+
+
 @router.post(
     "/chat/completions",
     response_model=None,
@@ -44,26 +135,8 @@ async def create_chat_completions(
     user_id: Optional[str] = Header(None, alias="user_id"),
 ):
     # Validate and process fields
-    try:
-        messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
-    except KeyError:
-        # Handle the case where "messages" is not present in the request
-        raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
-    except TypeError:
-        # Handle the case where "messages" is not iterable
-        raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
-    except Exception as e:
-        # Catch any other unexpected errors and include the exception message
-        raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")
-
-    if messages[-1]["role"] != "user":
-        logger.error(f"The last message does not have a `user` role: {messages}")
-        raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")
-
+    messages = get_messages_from_completion_request(completion_request)
     input_message = messages[-1]
-    if not isinstance(input_message["content"], str):
-        logger.error(f"The input message does not have valid content: {input_message}")
-        raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")

     # Process remaining fields
     if not completion_request["stream"]:
diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py
index 6a3ee583..d5bf4520 100644
--- a/letta/server/rest_api/utils.py
+++ b/letta/server/rest_api/utils.py
@@ -1,23 +1,33 @@
 import asyncio
 import json
 import os
+import uuid
 import warnings
+from datetime import datetime, timezone
 from enum import Enum
-from typing import TYPE_CHECKING, AsyncGenerator, Optional, Union
+from typing import TYPE_CHECKING, AsyncGenerator, Dict, Iterable, List, Optional, Union, cast

-from fastapi import Header
+import pytz
+from fastapi import Header, HTTPException
+from openai.types.chat import ChatCompletionMessageParam
+from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMessageToolCall as OpenAIToolCall
+from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
+from openai.types.chat.completion_create_params import CompletionCreateParams
 from pydantic import BaseModel

+from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.errors import ContextWindowExceededError, RateLimitExceededError
 from letta.log import get_logger
+from letta.schemas.enums import MessageRole
+from letta.schemas.letta_message import TextContent
+from letta.schemas.message import Message
 from letta.schemas.usage import LettaUsageStatistics
+from letta.schemas.user import User
 from letta.server.rest_api.interface import StreamingServerInterface

 if TYPE_CHECKING:
     from letta.server.server import SyncServer

-# from letta.orm.user import User
-# from letta.orm.utilities import get_db_session

 SSE_PREFIX = "data: "
 SSE_SUFFIX = "\n\n"
@@ -128,3 +138,172 @@ def log_error_to_sentry(e):
     import sentry_sdk

     sentry_sdk.capture_exception(e)
+
+
+def create_user_message(input_message: dict, agent_id: str, actor: User) -> Message:
+    """
+    Converts a user input message into the internal structured format.
+ """ + # Generate timestamp in the correct format + now = datetime.now(pytz.timezone("US/Pacific")).strftime("%Y-%m-%d %I:%M:%S %p %Z%z") + + # Format message as structured JSON + structured_message = {"type": "user_message", "message": input_message["content"], "time": now} + + # Construct the Message object + user_message = Message( + id=f"message-{uuid.uuid4()}", + role=MessageRole.user, + content=[TextContent(text=json.dumps(structured_message, indent=2))], # Store structured JSON + organization_id=actor.organization_id, + agent_id=agent_id, + model=None, + tool_calls=None, + tool_call_id=None, + created_at=datetime.now(timezone.utc), + ) + + return user_message + + +def create_assistant_message_from_openai_response( + response_text: str, + agent_id: str, + model: str, + actor: User, +) -> Message: + """ + Converts an OpenAI response into a Message that follows the internal + paradigm where LLM responses are structured as tool calls instead of content. + """ + tool_call_id = str(uuid.uuid4()) + + # Construct the tool call with the assistant's message + tool_call = OpenAIToolCall( + id=tool_call_id, + function=OpenAIFunction( + name=DEFAULT_MESSAGE_TOOL, + arguments='{\n "' + DEFAULT_MESSAGE_TOOL_KWARG + '": ' + f'"{response_text}",\n "request_heartbeat": true\n' + "}", + ), + type="function", + ) + + # Construct the Message object + assistant_message = Message( + id=f"message-{uuid.uuid4()}", + role=MessageRole.assistant, + content=[], + organization_id=actor.organization_id, + agent_id=agent_id, + model=model, + tool_calls=[tool_call], + tool_call_id=None, + created_at=datetime.now(timezone.utc), + ) + + return assistant_message + + +def convert_letta_messages_to_openai(messages: List[Message]) -> List[dict]: + """ + Flattens Letta's messages (with system, user, assistant, tool roles, etc.) + into standard OpenAI chat messages (system, user, assistant). + + Transformation rules: + 1. Assistant + send_message tool_call => content = tool_call's "message" + 2. Tool (role=tool) referencing send_message => skip + 3. User messages might store actual text inside JSON => parse that into content + 4. System => pass through as normal + """ + + openai_messages = [] + + for msg in messages: + # 1. Assistant + 'send_message' tool_calls => flatten + if msg.role == MessageRole.assistant and msg.tool_calls: + # Find any 'send_message' tool_calls + send_message_calls = [tc for tc in msg.tool_calls if tc.function.name == "send_message"] + if send_message_calls: + # If we have multiple calls, just pick the first or merge them + # Typically there's only one. + tc = send_message_calls[0] + arguments = json.loads(tc.function.arguments) + # Extract the "message" string + extracted_text = arguments.get("message", "") + + # Create a new content with the extracted text + msg = Message( + id=msg.id, + role=msg.role, + content=[TextContent(text=extracted_text)], + organization_id=msg.organization_id, + agent_id=msg.agent_id, + model=msg.model, + name=msg.name, + tool_calls=None, # no longer needed + tool_call_id=None, + created_at=msg.created_at, + ) + + # 2. If role=tool and it's referencing send_message => skip + if msg.role == MessageRole.tool and msg.name == "send_message": + # Usually 'tool' messages with `send_message` are just status/OK messages + # that OpenAI doesn't need to see. So skip them. + continue + + # 3. 
+        if msg.role == MessageRole.user:
+            # Example: content=[TextContent(text='{"type": "user_message","message":"Hello"}')]
+            # Attempt to parse JSON and extract "message"
+            if msg.content and msg.content[0].text.strip().startswith("{"):
+                try:
+                    parsed = json.loads(msg.content[0].text)
+                    # If there's a "message" field, use that as the content
+                    if "message" in parsed:
+                        actual_user_text = parsed["message"]
+                        msg = Message(
+                            id=msg.id,
+                            role=msg.role,
+                            content=[TextContent(text=actual_user_text)],
+                            organization_id=msg.organization_id,
+                            agent_id=msg.agent_id,
+                            model=msg.model,
+                            name=msg.name,
+                            tool_calls=msg.tool_calls,
+                            tool_call_id=msg.tool_call_id,
+                            created_at=msg.created_at,
+                        )
+                except json.JSONDecodeError:
+                    pass  # It's not JSON, leave as-is
+
+        # 4. System is left as-is (or any other role that doesn't need special handling)
+        #
+        # Finally, convert to an OpenAI-style dict via Message.to_openai_dict()
+        openai_messages.append(msg.to_openai_dict())
+
+    return openai_messages
+
+
+def get_messages_from_completion_request(completion_request: CompletionCreateParams) -> List[Dict]:
+    try:
+        messages = list(cast(Iterable[ChatCompletionMessageParam], completion_request["messages"]))
+    except KeyError:
+        # Handle the case where "messages" is not present in the request
+        raise HTTPException(status_code=400, detail="The 'messages' field is missing in the request.")
+    except TypeError:
+        # Handle the case where "messages" is not iterable
+        raise HTTPException(status_code=400, detail="The 'messages' field must be an iterable.")
+    except Exception as e:
+        # Catch any other unexpected errors and include the exception message
+        raise HTTPException(status_code=400, detail=f"An error occurred while processing 'messages': {str(e)}")
+
+    if not messages or messages[-1]["role"] != "user":
+        logger.error(f"The last message does not have a `user` role: {messages}")
+        raise HTTPException(status_code=400, detail="'messages[-1].role' must be a 'user'")
+
+    input_message = messages[-1]
+    if not isinstance(input_message["content"], str):
+        logger.error(f"The input message does not have valid content: {input_message}")
+        raise HTTPException(status_code=400, detail="'messages[-1].content' must be a 'string'")
+
+    return messages
diff --git a/tests/integration_test_chat_completions.py b/tests/integration_test_chat_completions.py
index ad48075c..4ab3b1d8 100644
--- a/tests/integration_test_chat_completions.py
+++ b/tests/integration_test_chat_completions.py
@@ -112,13 +112,12 @@ def _assert_valid_chunk(chunk, idx, chunks):


 @pytest.mark.parametrize("message", ["Tell me something interesting about bananas."])
-def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message):
+@pytest.mark.parametrize("endpoint", ["chat/completions", "fast/chat/completions"])
+def test_chat_completions_streaming(mock_e2b_api_key_none, client, agent, message, endpoint):
     """Tests chat completion streaming via SSE."""
     request = _get_chat_request(agent.id, message)

-    response = _sse_post(
-        f"{client.base_url}/openai/{client.api_prefix}/chat/completions", request.model_dump(exclude_none=True), client.headers
-    )
+    response = _sse_post(f"{client.base_url}/openai/{client.api_prefix}/{endpoint}", request.model_dump(exclude_none=True), client.headers)

     try:
         chunks = list(response)
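
Reviewer note: below is a minimal sketch of a client for smoke-testing the new fast route by hand. Everything in it is an assumption for illustration, not part of this patch: the local port 8283, the /openai/v1 mount implied by the integration test above, the model string, and the placeholder agent ID. Since the OpenAI SDK appends /chat/completions to whatever base_url it is given, pointing base_url at the /fast prefix should reach the new endpoint.

import openai

# Hypothetical connection details; adjust base_url and the agent ID for your deployment.
client = openai.OpenAI(
    base_url="http://localhost:8283/openai/v1/fast",  # assumed local Letta server; SDK appends /chat/completions
    api_key="unused",  # assumed: the server identifies the caller via its own headers, not this key
)

stream = client.chat.completions.create(
    model="gpt-4o-mini",  # any model string the server accepts
    user="agent-00000000-0000-0000-0000-000000000000",  # placeholder: the agent_id rides in the `user` field
    messages=[{"role": "user", "content": "Tell me something interesting about bananas."}],
    stream=True,
)

# Print the assistant's text as SSE chunks arrive
for chunk in stream:
    if chunk.choices and chunk.choices[0].delta.content:
        print(chunk.choices[0].delta.content, end="", flush=True)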