From aeae2e2cfbfff5090696ab5b8dc130c08e349036 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 27 Feb 2025 14:51:48 -0800 Subject: [PATCH] feat: Low Latency Agent (#1157) --- letta/interfaces/__init__.py | 0 ...ai_chat_completions_streaming_interface.py | 109 +++++++ letta/interfaces/utils.py | 11 + letta/low_latency_agent.py | 286 ++++++++++++++++++ letta/server/rest_api/routers/v1/voice.py | 270 ++--------------- letta/server/rest_api/utils.py | 5 +- letta/services/agent_manager.py | 1 + tests/integration_test_chat_completions.py | 5 +- 8 files changed, 429 insertions(+), 258 deletions(-) create mode 100644 letta/interfaces/__init__.py create mode 100644 letta/interfaces/openai_chat_completions_streaming_interface.py create mode 100644 letta/interfaces/utils.py create mode 100644 letta/low_latency_agent.py diff --git a/letta/interfaces/__init__.py b/letta/interfaces/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/letta/interfaces/openai_chat_completions_streaming_interface.py b/letta/interfaces/openai_chat_completions_streaming_interface.py new file mode 100644 index 00000000..6e45667d --- /dev/null +++ b/letta/interfaces/openai_chat_completions_streaming_interface.py @@ -0,0 +1,109 @@ +from typing import Any, AsyncGenerator, Dict, List, Optional + +from openai import AsyncStream +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta + +from letta.constants import PRE_EXECUTION_MESSAGE_ARG +from letta.interfaces.utils import _format_sse_chunk +from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser + + +class OpenAIChatCompletionsStreamingInterface: + """ + Encapsulates the logic for streaming responses from OpenAI. + This class handles parsing of partial tokens, pre-execution messages, + and detection of tool call events. + """ + + def __init__(self, stream_pre_execution_message: bool = True): + self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() + self.stream_pre_execution_message: bool = stream_pre_execution_message + + self.current_parsed_json_result: Dict[str, Any] = {} + self.content_buffer: List[str] = [] + self.tool_call_happened: bool = False + self.finish_reason_stop: bool = False + + self.tool_call_name: Optional[str] = None + self.tool_call_args_str: str = "" + self.tool_call_id: Optional[str] = None + + async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[str, None]: + """ + Iterates over the OpenAI stream, yielding SSE events. + It also collects tokens and detects if a tool call is triggered. + """ + async with stream: + async for chunk in stream: + choice = chunk.choices[0] + delta = choice.delta + finish_reason = choice.finish_reason + + async for sse_chunk in self._process_content(delta, chunk): + yield sse_chunk + + async for sse_chunk in self._process_tool_calls(delta, chunk): + yield sse_chunk + + if self._handle_finish_reason(finish_reason): + break + + async def _process_content(self, delta: ChoiceDelta, chunk: ChatCompletionChunk) -> AsyncGenerator[str, None]: + """Processes regular content tokens and streams them.""" + if delta.content: + self.content_buffer.append(delta.content) + yield _format_sse_chunk(chunk) + + async def _process_tool_calls(self, delta: ChoiceDelta, chunk: ChatCompletionChunk) -> AsyncGenerator[str, None]: + """Handles tool call initiation and streaming of pre-execution messages.""" + if not delta.tool_calls: + return + + tool_call = delta.tool_calls[0] + self._update_tool_call_info(tool_call) + + if self.stream_pre_execution_message and tool_call.function.arguments: + self.tool_call_args_str += tool_call.function.arguments + async for sse_chunk in self._stream_pre_execution_message(chunk, tool_call): + yield sse_chunk + + def _update_tool_call_info(self, tool_call: Any) -> None: + """Updates tool call-related attributes.""" + if tool_call.function.name: + self.tool_call_name = tool_call.function.name + if tool_call.id: + self.tool_call_id = tool_call.id + + async def _stream_pre_execution_message(self, chunk: ChatCompletionChunk, tool_call: Any) -> AsyncGenerator[str, None]: + """Parses and streams pre-execution messages if they have changed.""" + parsed_args = self.optimistic_json_parser.parse(self.tool_call_args_str) + + if parsed_args.get(PRE_EXECUTION_MESSAGE_ARG) and self.current_parsed_json_result.get(PRE_EXECUTION_MESSAGE_ARG) != parsed_args.get( + PRE_EXECUTION_MESSAGE_ARG + ): + if parsed_args != self.current_parsed_json_result: + self.current_parsed_json_result = parsed_args + synthetic_chunk = ChatCompletionChunk( + id=chunk.id, + object=chunk.object, + created=chunk.created, + model=chunk.model, + choices=[ + Choice( + index=0, + delta=ChoiceDelta(content=tool_call.function.arguments, role="assistant"), + finish_reason=None, + ) + ], + ) + yield _format_sse_chunk(synthetic_chunk) + + def _handle_finish_reason(self, finish_reason: Optional[str]) -> bool: + """Handles the finish reason and determines if streaming should stop.""" + if finish_reason == "tool_calls": + self.tool_call_happened = True + return True + if finish_reason == "stop": + self.finish_reason_stop = True + return True + return False diff --git a/letta/interfaces/utils.py b/letta/interfaces/utils.py new file mode 100644 index 00000000..4fa34327 --- /dev/null +++ b/letta/interfaces/utils.py @@ -0,0 +1,11 @@ +import json + +from openai.types.chat import ChatCompletionChunk + + +def _format_sse_error(error_payload: dict) -> str: + return f"data: {json.dumps(error_payload)}\n\n" + + +def _format_sse_chunk(chunk: ChatCompletionChunk) -> str: + return f"data: {chunk.model_dump_json()}\n\n" diff --git a/letta/low_latency_agent.py b/letta/low_latency_agent.py new file mode 100644 index 00000000..4b7e5c82 --- /dev/null +++ b/letta/low_latency_agent.py @@ -0,0 +1,286 @@ +import json +import uuid +from typing import Any, AsyncGenerator, Dict, List + +import openai +from starlette.concurrency import run_in_threadpool + +from letta.constants import NON_USER_MSG_PREFIX +from letta.helpers.datetime_helpers import get_utc_time +from letta.helpers.tool_execution_helper import ( + add_pre_execution_message, + enable_strict_mode, + execute_external_tool, + remove_request_heartbeat, +) +from letta.interfaces.openai_chat_completions_streaming_interface import OpenAIChatCompletionsStreamingInterface +from letta.log import get_logger +from letta.orm.enums import ToolType +from letta.schemas.agent import AgentState +from letta.schemas.message import Message, MessageUpdate +from letta.schemas.openai.chat_completion_request import ( + AssistantMessage, + ChatCompletionRequest, + Tool, + ToolCall, + ToolCallFunction, + ToolMessage, + UserMessage, +) +from letta.schemas.user import User +from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser +from letta.server.rest_api.utils import ( + convert_letta_messages_to_openai, + create_assistant_messages_from_openai_response, + create_tool_call_messages_from_openai_response, + create_user_message, +) +from letta.services.agent_manager import AgentManager +from letta.services.helpers.agent_manager_helper import compile_system_message +from letta.services.message_manager import MessageManager +from letta.utils import united_diff + +logger = get_logger(__name__) + + +class LowLatencyAgent: + """ + A function-calling loop for streaming OpenAI responses with tool execution. + This agent: + - Streams partial tokens in real-time for low-latency output. + - Detects tool calls and invokes external tools. + - Gracefully handles OpenAI API failures (429, etc.) and streams errors. + """ + + def __init__( + self, + agent_id: str, + openai_client: openai.AsyncClient, + message_manager: MessageManager, + agent_manager: AgentManager, + actor: User, + ): + self.agent_id = agent_id + self.openai_client = openai_client + + # DB access related fields + self.message_manager = message_manager + self.agent_manager = agent_manager + self.actor = actor + + # Internal conversation state + self.optimistic_json_parser = OptimisticJSONParser(strict=True) + self.current_parsed_json_result: Dict[str, Any] = {} + + async def step(self, input_message: Dict[str, str]) -> AsyncGenerator[str, None]: + """ + Async generator that yields partial tokens as SSE events, handles tool calls, + and streams error messages if OpenAI API failures occur. + """ + agent_state = self.agent_manager.get_agent_by_id(agent_id=self.agent_id, actor=self.actor) + letta_message_db_queue = [create_user_message(input_message=input_message, agent_id=agent_state.id, actor=self.actor)] + in_memory_message_history = [input_message] + + while True: + # Build context and request + openai_messages = self._build_context_window(in_memory_message_history, agent_state) + request = self._build_openai_request(openai_messages, agent_state) + + # Execute the request + stream = await self.openai_client.chat.completions.create(**request.model_dump(exclude_unset=True)) + streaming_interface = OpenAIChatCompletionsStreamingInterface(stream_pre_execution_message=True) + + async for sse in streaming_interface.process(stream): + yield sse + + # Process the AI response (buffered messages, tool execution, etc.) + continue_execution = await self.handle_ai_response( + streaming_interface, agent_state, in_memory_message_history, letta_message_db_queue + ) + + if not continue_execution: + break + + # Persist messages to the database asynchronously + await run_in_threadpool( + self.agent_manager.append_to_in_context_messages, + letta_message_db_queue, + agent_id=agent_state.id, + actor=self.actor, + ) + + yield "data: [DONE]\n\n" + + async def handle_ai_response( + self, + streaming_interface: OpenAIChatCompletionsStreamingInterface, + agent_state: AgentState, + in_memory_message_history: List[Dict[str, Any]], + letta_message_db_queue: List[Any], + ) -> bool: + """ + Handles AI response processing, including buffering messages, detecting tool calls, + executing tools, and deciding whether to continue execution. + + Returns: + bool: True if execution should continue, False if the step loop should terminate. + """ + # Handle assistant message buffering + if streaming_interface.content_buffer: + content = "".join(streaming_interface.content_buffer) + in_memory_message_history.append({"role": "assistant", "content": content}) + + assistant_msgs = create_assistant_messages_from_openai_response( + response_text=content, + agent_id=agent_state.id, + model=agent_state.llm_config.model, + actor=self.actor, + ) + letta_message_db_queue.extend(assistant_msgs) + + # Handle tool execution if a tool call occurred + if streaming_interface.tool_call_happened: + try: + tool_args = json.loads(streaming_interface.tool_call_args_str) + except json.JSONDecodeError: + tool_args = {} + + tool_call_id = streaming_interface.tool_call_id or f"call_{uuid.uuid4().hex[:8]}" + + assistant_tool_call_msg = AssistantMessage( + content=None, + tool_calls=[ + ToolCall( + id=tool_call_id, + function=ToolCallFunction( + name=streaming_interface.tool_call_name, + arguments=streaming_interface.tool_call_args_str, + ), + ) + ], + ) + in_memory_message_history.append(assistant_tool_call_msg.model_dump()) + + tool_result, function_call_success = await self._execute_tool( + tool_name=streaming_interface.tool_call_name, + tool_args=tool_args, + agent_state=agent_state, + ) + + tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id) + in_memory_message_history.append(tool_message.model_dump()) + + heartbeat_user_message = UserMessage( + content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user." + ) + in_memory_message_history.append(heartbeat_user_message.model_dump()) + + tool_call_messages = create_tool_call_messages_from_openai_response( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + function_name=streaming_interface.tool_call_name, + function_arguments=tool_args, + tool_call_id=tool_call_id, + function_call_success=function_call_success, + function_response=tool_result, + actor=self.actor, + add_heartbeat_request_system_message=True, + ) + letta_message_db_queue.extend(tool_call_messages) + + # Continue execution by restarting the loop with updated context + return True + + # Exit the loop if finish_reason_stop or no tool call occurred + return not streaming_interface.finish_reason_stop + + def _build_context_window(self, in_memory_message_history: List[Dict[str, Any]], agent_state: AgentState) -> List[Dict]: + # Build in_context_messages + in_context_messages = self.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=self.actor) + in_context_messages = self._rebuild_memory(in_context_messages=in_context_messages, agent_state=agent_state) + + # Convert Letta messages to OpenAI messages + openai_messages = convert_letta_messages_to_openai(in_context_messages) + openai_messages.extend(in_memory_message_history) + return openai_messages + + def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]: + # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this + curr_system_message = in_context_messages[0] + curr_memory_str = agent_state.memory.compile() + if curr_memory_str in curr_system_message.text: + # NOTE: could this cause issues if a block is removed? (substring match would still work) + logger.debug( + f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild" + ) + return in_context_messages + + memory_edit_timestamp = get_utc_time() + new_system_message_str = compile_system_message( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + ) + + diff = united_diff(curr_system_message.text, new_system_message_str) + if len(diff) > 0: + logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}") + + new_system_message = self.message_manager.update_message_by_id( + curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + ) + + # Skip pulling down the agent's memory again to save on a db call + return [new_system_message] + in_context_messages[1:] + + else: + return in_context_messages + + def _build_openai_request(self, openai_messages: List[Dict], agent_state: AgentState) -> ChatCompletionRequest: + tool_schemas = self._build_tool_schemas(agent_state) + tool_choice = "auto" if tool_schemas else None + + openai_request = ChatCompletionRequest( + model=agent_state.llm_config.model, + messages=openai_messages, + tools=self._build_tool_schemas(agent_state), + tool_choice=tool_choice, + user=self.actor.id, + max_completion_tokens=agent_state.llm_config.max_tokens, + temperature=agent_state.llm_config.temperature, + stream=True, + ) + return openai_request + + def _build_tool_schemas(self, agent_state: AgentState, external_tools_only=True) -> List[Tool]: + if external_tools_only: + tools = [t for t in agent_state.tools if t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}] + else: + tools = agent_state.tools + + # TODO: Customize whether or not to have heartbeats, pre_exec_message, etc. + return [ + Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema)))) + for t in tools + ] + + async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> (str, bool): + """ + Executes a tool and returns (result, success_flag). + """ + target_tool = next((x for x in agent_state.tools if x.name == tool_name), None) + if not target_tool: + return f"Tool not found: {tool_name}", False + + try: + tool_result, _ = execute_external_tool( + agent_state=agent_state, + function_name=tool_name, + function_args=tool_args, + target_letta_tool=target_tool, + actor=self.actor, + allow_agent_state_modifications=False, + ) + return tool_result, True + except Exception as e: + return f"Failed to call tool. Error: {e}", False diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index 2e8c54f8..1ecbde00 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -1,42 +1,14 @@ -import json -import uuid from typing import TYPE_CHECKING, Optional import httpx import openai from fastapi import APIRouter, Body, Depends, Header, HTTPException from fastapi.responses import StreamingResponse -from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta from openai.types.chat.completion_create_params import CompletionCreateParams -from starlette.concurrency import run_in_threadpool -from letta.constants import LETTA_TOOL_SET, NON_USER_MSG_PREFIX, PRE_EXECUTION_MESSAGE_ARG -from letta.helpers.tool_execution_helper import ( - add_pre_execution_message, - enable_strict_mode, - execute_external_tool, - remove_request_heartbeat, -) from letta.log import get_logger -from letta.orm.enums import ToolType -from letta.schemas.openai.chat_completion_request import ( - AssistantMessage, - ChatCompletionRequest, - Tool, - ToolCall, - ToolCallFunction, - ToolMessage, - UserMessage, -) -from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser -from letta.server.rest_api.utils import ( - convert_letta_messages_to_openai, - create_assistant_messages_from_openai_response, - create_tool_call_messages_from_openai_response, - create_user_message, - get_letta_server, - get_messages_from_completion_request, -) +from letta.low_latency_agent import LowLatencyAgent +from letta.server.rest_api.utils import get_letta_server, get_messages_from_completion_request from letta.settings import model_settings if TYPE_CHECKING: @@ -72,42 +44,14 @@ async def create_voice_chat_completions( if agent_id is None: raise HTTPException(status_code=400, detail="Must pass agent_id in the 'user' field") - agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) - if agent_state.llm_config.model_endpoint_type != "openai": - raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.") + # agent_state = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + # if agent_state.llm_config.model_endpoint_type != "openai": + # raise HTTPException(status_code=400, detail="Only OpenAI models are supported by this endpoint.") - # Convert Letta messages to OpenAI messages - in_context_messages = server.message_manager.get_messages_by_ids(message_ids=agent_state.message_ids, actor=actor) - openai_messages = convert_letta_messages_to_openai(in_context_messages) - - # Also parse user input from completion_request and append + # Also parse the user's new input input_message = get_messages_from_completion_request(completion_request)[-1] - openai_messages.append(input_message) - # Tools we allow this agent to call - tools = [t for t in agent_state.tools if t.name not in LETTA_TOOL_SET and t.tool_type in {ToolType.EXTERNAL_COMPOSIO, ToolType.CUSTOM}] - - # Initial request - openai_request = ChatCompletionRequest( - model=agent_state.llm_config.model, - messages=openai_messages, - # TODO: This nested thing here is so ugly, need to refactor - tools=( - [ - Tool(type="function", function=enable_strict_mode(add_pre_execution_message(remove_request_heartbeat(t.json_schema)))) - for t in tools - ] - if tools - else None - ), - tool_choice="auto", - user=user_id, - max_completion_tokens=agent_state.llm_config.max_tokens, - temperature=agent_state.llm_config.temperature, - stream=True, - ) - - # Create the OpenAI async client + # Create OpenAI async client client = openai.AsyncClient( api_key=model_settings.openai_api_key, max_retries=0, @@ -122,194 +66,14 @@ async def create_voice_chat_completions( ), ) - # The messages we want to persist to the Letta agent - user_message = create_user_message(input_message=input_message, agent_id=agent_id, actor=actor) - message_db_queue = [user_message] + # Instantiate our LowLatencyAgent + agent = LowLatencyAgent( + agent_id=agent_id, + openai_client=client, + message_manager=server.message_manager, + agent_manager=server.agent_manager, + actor=actor, + ) - async def event_stream(): - """ - A function-calling loop: - - We stream partial tokens. - - If we detect a tool call (finish_reason="tool_calls"), we parse it, - add two messages to the conversation: - (a) assistant message with tool_calls referencing the same ID - (b) a tool message referencing that ID, containing the tool result. - - Re-invoke the OpenAI request with updated conversation, streaming again. - - End when finish_reason="stop" or no more tool calls. - """ - - # We'll keep updating this conversation in a loop - conversation = openai_messages[:] - - while True: - # Make the streaming request to OpenAI - stream = await client.chat.completions.create(**openai_request.model_dump(exclude_unset=True)) - - content_buffer = [] - tool_call_name = None - tool_call_args_str = "" - tool_call_id = None - tool_call_happened = False - finish_reason_stop = False - optimistic_json_parser = OptimisticJSONParser(strict=True) - current_parsed_json_result = {} - - async with stream: - async for chunk in stream: - choice = chunk.choices[0] - delta = choice.delta - finish_reason = choice.finish_reason # "tool_calls", "stop", or None - - if delta.content: - content_buffer.append(delta.content) - yield f"data: {chunk.model_dump_json()}\n\n" - - # CASE B: Partial tool call info - if delta.tool_calls: - # Typically there's only one in delta.tool_calls - tc = delta.tool_calls[0] - if tc.function.name: - tool_call_name = tc.function.name - if tc.function.arguments: - tool_call_args_str += tc.function.arguments - - # See if we can stream out the pre-execution message - parsed_args = optimistic_json_parser.parse(tool_call_args_str) - if parsed_args.get( - PRE_EXECUTION_MESSAGE_ARG - ) and current_parsed_json_result.get( # Ensure key exists and is not None/empty - PRE_EXECUTION_MESSAGE_ARG - ) != parsed_args.get( - PRE_EXECUTION_MESSAGE_ARG - ): - # Only stream if there's something new to stream - # We do this way to avoid hanging JSON at the end of the stream, e.g. '}' - if parsed_args != current_parsed_json_result: - current_parsed_json_result = parsed_args - synthetic_chunk = ChatCompletionChunk( - id=chunk.id, - object=chunk.object, - created=chunk.created, - model=chunk.model, - choices=[ - Choice( - index=choice.index, - delta=ChoiceDelta(content=tc.function.arguments, role="assistant"), - finish_reason=None, - ) - ], - ) - - yield f"data: {synthetic_chunk.model_dump_json()}\n\n" - - # We might generate a unique ID for the tool call - if tc.id: - tool_call_id = tc.id - - # Check finish_reason - if finish_reason == "tool_calls": - tool_call_happened = True - break - elif finish_reason == "stop": - finish_reason_stop = True - break - - if content_buffer: - # We treat that partial text as an assistant message - content = "".join(content_buffer) - conversation.append({"role": "assistant", "content": content}) - - # Create an assistant message here to persist later - assistant_messages = create_assistant_messages_from_openai_response( - response_text=content, agent_id=agent_id, model=agent_state.llm_config.model, actor=actor - ) - message_db_queue.extend(assistant_messages) - - if tool_call_happened: - # Parse the tool call arguments - try: - tool_args = json.loads(tool_call_args_str) - except json.JSONDecodeError: - tool_args = {} - - if not tool_call_id: - # If no tool_call_id given by the model, generate one - tool_call_id = f"call_{uuid.uuid4().hex[:8]}" - - # 1) Insert the "assistant" message with the tool_calls field - # referencing the same tool_call_id - assistant_tool_call_msg = AssistantMessage( - content=None, - tool_calls=[ToolCall(id=tool_call_id, function=ToolCallFunction(name=tool_call_name, arguments=tool_call_args_str))], - ) - - conversation.append(assistant_tool_call_msg.model_dump()) - - # 2) Execute the tool - target_tool = next((x for x in tools if x.name == tool_call_name), None) - if not target_tool: - # Tool not found, handle error - yield f"data: {json.dumps({'error': 'Tool not found', 'tool': tool_call_name})}\n\n" - break - - try: - tool_result, _ = execute_external_tool( - agent_state=agent_state, - function_name=tool_call_name, - function_args=tool_args, - target_letta_tool=target_tool, - actor=actor, - allow_agent_state_modifications=False, - ) - function_call_success = True - except Exception as e: - tool_result = f"Failed to call tool. Error: {e}" - function_call_success = False - - # 3) Insert the "tool" message referencing the same tool_call_id - tool_message = ToolMessage(content=json.dumps({"result": tool_result}), tool_call_id=tool_call_id) - - conversation.append(tool_message.model_dump()) - - # 4) Add a user message prompting the tool call result summarization - heartbeat_user_message = UserMessage( - content=f"{NON_USER_MSG_PREFIX} Tool finished executing. Summarize the result for the user.", - ) - conversation.append(heartbeat_user_message.model_dump()) - - # Now, re-invoke OpenAI with the updated conversation - openai_request.messages = conversation - - # Create a tool call message and append to message_db_queue - tool_call_messages = create_tool_call_messages_from_openai_response( - agent_id=agent_state.id, - model=agent_state.llm_config.model, - function_name=tool_call_name, - function_arguments=tool_args, - tool_call_id=tool_call_id, - function_call_success=function_call_success, - function_response=tool_result, - actor=actor, - add_heartbeat_request_system_message=True, - ) - message_db_queue.extend(tool_call_messages) - - continue # Start the while loop again - - if finish_reason_stop: - break - - # If we reach here, no tool call, no "stop", but we've ended streaming - # Possibly a model error or some other finish reason. We'll just end. - break - - await run_in_threadpool( - server.agent_manager.append_to_in_context_messages, - message_db_queue, - agent_id=agent_id, - actor=actor, - ) - - yield "data: [DONE]\n\n" - - return StreamingResponse(event_stream(), media_type="text/event-stream") + # Return the streaming generator + return StreamingResponse(agent.step(input_message=input_message), media_type="text/event-stream") diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 104318fd..b349e32c 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -13,7 +13,7 @@ from openai.types.chat.chat_completion_message_tool_call import Function as Open from openai.types.chat.completion_create_params import CompletionCreateParams from pydantic import BaseModel -from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, REQ_HEARTBEAT_MESSAGE +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, FUNC_FAILED_HEARTBEAT_MESSAGE, REQ_HEARTBEAT_MESSAGE from letta.errors import ContextWindowExceededError, RateLimitExceededError from letta.helpers.datetime_helpers import get_utc_time from letta.log import get_logger @@ -216,9 +216,10 @@ def create_tool_call_messages_from_openai_response( messages.append(tool_message) if add_heartbeat_request_system_message: + text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE heartbeat_system_message = Message( role=MessageRole.user, - content=[TextContent(text=get_heartbeat(REQ_HEARTBEAT_MESSAGE))], + content=[TextContent(text=get_heartbeat(text_content))], organization_id=actor.organization_id, agent_id=agent_id, model=model, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 8d9743ea..ada1f7c1 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -529,6 +529,7 @@ class AgentManager: model=agent_state.llm_config.model, openai_message_dict={"role": "system", "content": new_system_message_str}, ) + # TODO: This seems kind of silly, why not just update the message? message = self.message_manager.create_message(message, actor=actor) message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system) return self._set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor) diff --git a/tests/integration_test_chat_completions.py b/tests/integration_test_chat_completions.py index 495a3ecd..465b322d 100644 --- a/tests/integration_test_chat_completions.py +++ b/tests/integration_test_chat_completions.py @@ -153,7 +153,7 @@ def _assert_valid_chunk(chunk, idx, chunks): @pytest.mark.asyncio -@pytest.mark.parametrize("message", ["Tell me something interesting about bananas."]) +@pytest.mark.parametrize("message", ["What's the weather in SF?"]) @pytest.mark.parametrize("endpoint", ["v1/voice"]) async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint): """Tests chat completion streaming using the Async OpenAI client.""" @@ -163,8 +163,7 @@ async def test_latency(mock_e2b_api_key_none, client, agent, message, endpoint): stream = await async_client.chat.completions.create(**request.model_dump(exclude_none=True)) async with stream: async for chunk in stream: - assert isinstance(chunk, ChatCompletionChunk), f"Unexpected chunk type: {type(chunk)}" - assert chunk.choices, "Each ChatCompletionChunk should have at least one choice." + print(chunk) @pytest.mark.asyncio