diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 55c312a9..18d03e79 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -12,6 +12,7 @@ from letta.agents.helpers import _create_letta_response, _prepare_in_context_mes from letta.helpers import ToolRulesSolver from letta.helpers.tool_execution_helper import enable_strict_mode from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface +from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface from letta.llm_api.llm_client import LLMClient from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG @@ -125,7 +126,7 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True + self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False ) -> AsyncGenerator[str, None]: """ Main streaming loop that yields partial tokens. @@ -153,9 +154,16 @@ class LettaAgent(BaseAgent): ) # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED - interface = AnthropicStreamingInterface( - use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs - ) + if agent_state.llm_config.model_endpoint_type == "anthropic": + interface = AnthropicStreamingInterface( + use_assistant_message=use_assistant_message, + put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, + ) + elif agent_state.llm_config.model_endpoint_type == "openai": + interface = OpenAIStreamingInterface( + use_assistant_message=use_assistant_message, + put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, + ) async for chunk in interface.process(stream): yield f"data: {chunk.model_dump_json()}\n\n" diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py new file mode 100644 index 00000000..5b4fade4 --- /dev/null +++ b/letta/interfaces/openai_streaming_interface.py @@ -0,0 +1,305 @@ +from datetime import datetime, timezone +from typing import Any, AsyncGenerator, Dict, List, Optional + +from openai import AsyncStream +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta + +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, PRE_EXECUTION_MESSAGE_ARG +from letta.interfaces.utils import _format_sse_chunk +from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage +from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent +from letta.schemas.message import Message +from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall +from letta.schemas.usage import LettaUsageStatistics +from letta.server.rest_api.json_parser import OptimisticJSONParser +from letta.streaming_utils import JSONInnerThoughtsExtractor + + +class OpenAIStreamingInterface: + """ + Encapsulates the logic for streaming responses from OpenAI. + This class handles parsing of partial tokens, pre-execution messages, + and detection of tool call events. + """ + + def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False): + self.use_assistant_message = use_assistant_message + self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL + self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG + + self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() + self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True) # TODO: pass in kward + self.function_name_buffer = None + self.function_args_buffer = None + self.function_id_buffer = None + self.last_flushed_function_name = None + + # Buffer to hold function arguments until inner thoughts are complete + self.current_function_arguments = "" + self.current_json_parse_result = {} + + # Premake IDs for database writes + self.letta_assistant_message_id = Message.generate_id() + self.letta_tool_message_id = Message.generate_id() + + # token counters + self.input_tokens = 0 + self.output_tokens = 0 + + self.content_buffer: List[str] = [] + self.tool_call_name: Optional[str] = None + self.tool_call_id: Optional[str] = None + self.reasoning_messages = [] + + def get_reasoning_content(self) -> List[TextContent]: + content = "".join(self.reasoning_messages) + return [TextContent(text=content)] + + def get_tool_call_object(self) -> ToolCall: + """Useful for agent loop""" + return ToolCall( + id=self.letta_tool_message_id, + function=FunctionCall(arguments=self.current_function_arguments, name=self.last_flushed_function_name), + ) + + async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[LettaMessage, None]: + """ + Iterates over the OpenAI stream, yielding SSE events. + It also collects tokens and detects if a tool call is triggered. + """ + async with stream: + prev_message_type = None + message_index = 0 + async for chunk in stream: + # track usage + if chunk.usage: + self.input_tokens += len(chunk.usage.prompt_tokens) + self.output_tokens += len(chunk.usage.completion_tokens) + + if chunk.choices: + choice = chunk.choices[0] + message_delta = choice.delta + + if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: + tool_call = message_delta.tool_calls[0] + + if tool_call.function.name: + # If we're waiting for the first key, then we should hold back the name + # ie add it to a buffer instead of returning it as a chunk + if self.function_name_buffer is None: + self.function_name_buffer = tool_call.function.name + else: + self.function_name_buffer += tool_call.function.name + + if tool_call.id: + # Buffer until next time + if self.function_id_buffer is None: + self.function_id_buffer = tool_call.id + else: + self.function_id_buffer += tool_call.id + + if tool_call.function.arguments: + # updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) + self.current_function_arguments += tool_call.function.arguments + updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment( + tool_call.function.arguments + ) + + # If we have inner thoughts, we should output them as a chunk + if updates_inner_thoughts: + if prev_message_type and prev_message_type != "reasoning_message": + message_index += 1 + self.reasoning_messages.append(updates_inner_thoughts) + reasoning_message = ReasoningMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + reasoning=updates_inner_thoughts, + # name=name, + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = reasoning_message.message_type + yield reasoning_message + + # Additionally inner thoughts may stream back with a chunk of main JSON + # In that case, since we can only return a chunk at a time, we should buffer it + if updates_main_json: + if self.function_args_buffer is None: + self.function_args_buffer = updates_main_json + else: + self.function_args_buffer += updates_main_json + + # If we have main_json, we should output a ToolCallMessage + elif updates_main_json: + + # If there's something in the function_name buffer, we should release it first + # NOTE: we could output it as part of a chunk that has both name and args, + # however the frontend may expect name first, then args, so to be + # safe we'll output name first in a separate chunk + if self.function_name_buffer: + + # use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..." + if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name: + + # Store the ID of the tool call so allow skipping the corresponding response + if self.function_id_buffer: + self.prev_assistant_message_id = self.function_id_buffer + + else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 + self.tool_call_name = str(self.function_name_buffer) + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=self.function_name_buffer, + arguments=None, + tool_call_id=self.function_id_buffer, + ), + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = tool_call_msg.message_type + yield tool_call_msg + + # Record what the last function name we flushed was + self.last_flushed_function_name = self.function_name_buffer + # Clear the buffer + self.function_name_buffer = None + self.function_id_buffer = None + # Since we're clearing the name buffer, we should store + # any updates to the arguments inside a separate buffer + + # Add any main_json updates to the arguments buffer + if self.function_args_buffer is None: + self.function_args_buffer = updates_main_json + else: + self.function_args_buffer += updates_main_json + + # If there was nothing in the name buffer, we can proceed to + # output the arguments chunk as a ToolCallMessage + else: + + # use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..." + if self.use_assistant_message and ( + self.last_flushed_function_name is not None + and self.last_flushed_function_name == self.assistant_message_tool_name + ): + # do an additional parse on the updates_main_json + if self.function_args_buffer: + updates_main_json = self.function_args_buffer + updates_main_json + self.function_args_buffer = None + + # Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix + match_str = '{"' + self.assistant_message_tool_kwarg + '":"' + if updates_main_json == match_str: + updates_main_json = None + + else: + # Some hardcoding to strip off the trailing "}" + if updates_main_json in ["}", '"}']: + updates_main_json = None + if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"': + updates_main_json = updates_main_json[:-1] + + if not updates_main_json: + # early exit to turn into content mode + continue + + # There may be a buffer from a previous chunk, for example + # if the previous chunk had arguments but we needed to flush name + if self.function_args_buffer: + # In this case, we should release the buffer + new data at once + combined_chunk = self.function_args_buffer + updates_main_json + + if prev_message_type and prev_message_type != "assistant_message": + message_index += 1 + assistant_message = AssistantMessage( + id=self.letta_assistant_message_id, + date=datetime.now(timezone.utc), + content=combined_chunk, + otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + ) + prev_message_type = assistant_message.message_type + yield assistant_message + # Store the ID of the tool call so allow skipping the corresponding response + if self.function_id_buffer: + self.prev_assistant_message_id = self.function_id_buffer + # clear buffer + self.function_args_buffer = None + self.function_id_buffer = None + + else: + # If there's no buffer to clear, just output a new chunk with new data + # TODO: THIS IS HORRIBLE + # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER + # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE + parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments) + + if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( + self.assistant_message_tool_kwarg + ) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg): + new_content = parsed_args.get(self.assistant_message_tool_kwarg) + prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "") + # TODO: Assumes consistent state and that prev_content is subset of new_content + diff = new_content.replace(prev_content, "", 1) + self.current_json_parse_result = parsed_args + if prev_message_type and prev_message_type != "assistant_message": + message_index += 1 + assistant_message = AssistantMessage( + id=self.letta_assistant_message_id, + date=datetime.now(timezone.utc), + content=diff, + # name=name, + otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + ) + prev_message_type = assistant_message.message_type + yield assistant_message + + # Store the ID of the tool call so allow skipping the corresponding response + if self.function_id_buffer: + self.prev_assistant_message_id = self.function_id_buffer + # clear buffers + self.function_id_buffer = None + else: + + # There may be a buffer from a previous chunk, for example + # if the previous chunk had arguments but we needed to flush name + if self.function_args_buffer: + # In this case, we should release the buffer + new data at once + combined_chunk = self.function_args_buffer + updates_main_json + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=None, + arguments=combined_chunk, + tool_call_id=self.function_id_buffer, + ), + # name=name, + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = tool_call_msg.message_type + yield tool_call_msg + # clear buffer + self.function_args_buffer = None + self.function_id_buffer = None + else: + # If there's no buffer to clear, just output a new chunk with new data + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=None, + arguments=updates_main_json, + tool_call_id=self.function_id_buffer, + ), + # name=name, + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = tool_call_msg.message_type + yield tool_call_msg + self.function_id_buffer = None diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index c3cf94f0..24226996 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -694,9 +694,9 @@ async def send_message_streaming( agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false" feature_enabled = settings.use_experimental or experimental_header.lower() == "true" - model_compatible = agent.llm_config.model_endpoint_type == "anthropic" + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"] - if agent_eligible and feature_enabled and model_compatible and request.stream_tokens: + if agent_eligible and feature_enabled and model_compatible: experimental_agent = LettaAgent( agent_id=agent_id, message_manager=server.message_manager, @@ -707,7 +707,9 @@ async def send_message_streaming( ) result = StreamingResponse( - experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message), + experimental_agent.step_stream( + request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens + ), media_type="text/event-stream", ) else: