feat: add OpenAI streaming interface for new agent loop (#2191)
This commit is contained in:
@@ -12,6 +12,7 @@ from letta.agents.helpers import _create_letta_response, _prepare_in_context_mes
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface
|
||||
from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
@@ -125,7 +126,7 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
@trace_method
|
||||
async def step_stream(
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True
|
||||
self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False
|
||||
) -> AsyncGenerator[str, None]:
|
||||
"""
|
||||
Main streaming loop that yields partial tokens.
|
||||
@@ -153,9 +154,16 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
# TODO: THIS IS INCREDIBLY UGLY
|
||||
# TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs
|
||||
)
|
||||
if agent_state.llm_config.model_endpoint_type == "anthropic":
|
||||
interface = AnthropicStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
elif agent_state.llm_config.model_endpoint_type == "openai":
|
||||
interface = OpenAIStreamingInterface(
|
||||
use_assistant_message=use_assistant_message,
|
||||
put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs,
|
||||
)
|
||||
async for chunk in interface.process(stream):
|
||||
yield f"data: {chunk.model_dump_json()}\n\n"
|
||||
|
||||
|
||||
305
letta/interfaces/openai_streaming_interface.py
Normal file
305
letta/interfaces/openai_streaming_interface.py
Normal file
@@ -0,0 +1,305 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.interfaces.utils import _format_sse_chunk
|
||||
from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.streaming_utils import JSONInnerThoughtsExtractor
|
||||
|
||||
|
||||
class OpenAIStreamingInterface:
|
||||
"""
|
||||
Encapsulates the logic for streaming responses from OpenAI.
|
||||
This class handles parsing of partial tokens, pre-execution messages,
|
||||
and detection of tool call events.
|
||||
"""
|
||||
|
||||
def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False):
|
||||
self.use_assistant_message = use_assistant_message
|
||||
self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
|
||||
self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
|
||||
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True) # TODO: pass in kward
|
||||
self.function_name_buffer = None
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
self.last_flushed_function_name = None
|
||||
|
||||
# Buffer to hold function arguments until inner thoughts are complete
|
||||
self.current_function_arguments = ""
|
||||
self.current_json_parse_result = {}
|
||||
|
||||
# Premake IDs for database writes
|
||||
self.letta_assistant_message_id = Message.generate_id()
|
||||
self.letta_tool_message_id = Message.generate_id()
|
||||
|
||||
# token counters
|
||||
self.input_tokens = 0
|
||||
self.output_tokens = 0
|
||||
|
||||
self.content_buffer: List[str] = []
|
||||
self.tool_call_name: Optional[str] = None
|
||||
self.tool_call_id: Optional[str] = None
|
||||
self.reasoning_messages = []
|
||||
|
||||
def get_reasoning_content(self) -> List[TextContent]:
|
||||
content = "".join(self.reasoning_messages)
|
||||
return [TextContent(text=content)]
|
||||
|
||||
def get_tool_call_object(self) -> ToolCall:
|
||||
"""Useful for agent loop"""
|
||||
return ToolCall(
|
||||
id=self.letta_tool_message_id,
|
||||
function=FunctionCall(arguments=self.current_function_arguments, name=self.last_flushed_function_name),
|
||||
)
|
||||
|
||||
async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[LettaMessage, None]:
|
||||
"""
|
||||
Iterates over the OpenAI stream, yielding SSE events.
|
||||
It also collects tokens and detects if a tool call is triggered.
|
||||
"""
|
||||
async with stream:
|
||||
prev_message_type = None
|
||||
message_index = 0
|
||||
async for chunk in stream:
|
||||
# track usage
|
||||
if chunk.usage:
|
||||
self.input_tokens += len(chunk.usage.prompt_tokens)
|
||||
self.output_tokens += len(chunk.usage.completion_tokens)
|
||||
|
||||
if chunk.choices:
|
||||
choice = chunk.choices[0]
|
||||
message_delta = choice.delta
|
||||
|
||||
if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
|
||||
tool_call = message_delta.tool_calls[0]
|
||||
|
||||
if tool_call.function.name:
|
||||
# If we're waiting for the first key, then we should hold back the name
|
||||
# ie add it to a buffer instead of returning it as a chunk
|
||||
if self.function_name_buffer is None:
|
||||
self.function_name_buffer = tool_call.function.name
|
||||
else:
|
||||
self.function_name_buffer += tool_call.function.name
|
||||
|
||||
if tool_call.id:
|
||||
# Buffer until next time
|
||||
if self.function_id_buffer is None:
|
||||
self.function_id_buffer = tool_call.id
|
||||
else:
|
||||
self.function_id_buffer += tool_call.id
|
||||
|
||||
if tool_call.function.arguments:
|
||||
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
self.current_function_arguments += tool_call.function.arguments
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(
|
||||
tool_call.function.arguments
|
||||
)
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
if updates_inner_thoughts:
|
||||
if prev_message_type and prev_message_type != "reasoning_message":
|
||||
message_index += 1
|
||||
self.reasoning_messages.append(updates_inner_thoughts)
|
||||
reasoning_message = ReasoningMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
reasoning=updates_inner_thoughts,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
|
||||
)
|
||||
prev_message_type = reasoning_message.message_type
|
||||
yield reasoning_message
|
||||
|
||||
# Additionally inner thoughts may stream back with a chunk of main JSON
|
||||
# In that case, since we can only return a chunk at a time, we should buffer it
|
||||
if updates_main_json:
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If we have main_json, we should output a ToolCallMessage
|
||||
elif updates_main_json:
|
||||
|
||||
# If there's something in the function_name buffer, we should release it first
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
# safe we'll output name first in a separate chunk
|
||||
if self.function_name_buffer:
|
||||
|
||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name:
|
||||
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
|
||||
else:
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
self.tool_call_name = str(self.function_name_buffer)
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=None,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
|
||||
# Record what the last function name we flushed was
|
||||
self.last_flushed_function_name = self.function_name_buffer
|
||||
# Clear the buffer
|
||||
self.function_name_buffer = None
|
||||
self.function_id_buffer = None
|
||||
# Since we're clearing the name buffer, we should store
|
||||
# any updates to the arguments inside a separate buffer
|
||||
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a ToolCallMessage
|
||||
else:
|
||||
|
||||
# use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..."
|
||||
if self.use_assistant_message and (
|
||||
self.last_flushed_function_name is not None
|
||||
and self.last_flushed_function_name == self.assistant_message_tool_name
|
||||
):
|
||||
# do an additional parse on the updates_main_json
|
||||
if self.function_args_buffer:
|
||||
updates_main_json = self.function_args_buffer + updates_main_json
|
||||
self.function_args_buffer = None
|
||||
|
||||
# Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix
|
||||
match_str = '{"' + self.assistant_message_tool_kwarg + '":"'
|
||||
if updates_main_json == match_str:
|
||||
updates_main_json = None
|
||||
|
||||
else:
|
||||
# Some hardcoding to strip off the trailing "}"
|
||||
if updates_main_json in ["}", '"}']:
|
||||
updates_main_json = None
|
||||
if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"':
|
||||
updates_main_json = updates_main_json[:-1]
|
||||
|
||||
if not updates_main_json:
|
||||
# early exit to turn into content mode
|
||||
continue
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_message = AssistantMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
content=combined_chunk,
|
||||
otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
# TODO: THIS IS HORRIBLE
|
||||
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
|
||||
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
|
||||
parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments)
|
||||
|
||||
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
|
||||
self.assistant_message_tool_kwarg
|
||||
) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
|
||||
new_content = parsed_args.get(self.assistant_message_tool_kwarg)
|
||||
prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
|
||||
# TODO: Assumes consistent state and that prev_content is subset of new_content
|
||||
diff = new_content.replace(prev_content, "", 1)
|
||||
self.current_json_parse_result = parsed_args
|
||||
if prev_message_type and prev_message_type != "assistant_message":
|
||||
message_index += 1
|
||||
assistant_message = AssistantMessage(
|
||||
id=self.letta_assistant_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
content=diff,
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index),
|
||||
)
|
||||
prev_message_type = assistant_message.message_type
|
||||
yield assistant_message
|
||||
|
||||
# Store the ID of the tool call so allow skipping the corresponding response
|
||||
if self.function_id_buffer:
|
||||
self.prev_assistant_message_id = self.function_id_buffer
|
||||
# clear buffers
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
|
||||
# There may be a buffer from a previous chunk, for example
|
||||
# if the previous chunk had arguments but we needed to flush name
|
||||
if self.function_args_buffer:
|
||||
# In this case, we should release the buffer + new data at once
|
||||
combined_chunk = self.function_args_buffer + updates_main_json
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
arguments=combined_chunk,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
if prev_message_type and prev_message_type != "tool_call_message":
|
||||
message_index += 1
|
||||
tool_call_msg = ToolCallMessage(
|
||||
id=self.letta_tool_message_id,
|
||||
date=datetime.now(timezone.utc),
|
||||
tool_call=ToolCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
tool_call_id=self.function_id_buffer,
|
||||
),
|
||||
# name=name,
|
||||
otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index),
|
||||
)
|
||||
prev_message_type = tool_call_msg.message_type
|
||||
yield tool_call_msg
|
||||
self.function_id_buffer = None
|
||||
@@ -694,9 +694,9 @@ async def send_message_streaming(
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false"
|
||||
feature_enabled = settings.use_experimental or experimental_header.lower() == "true"
|
||||
model_compatible = agent.llm_config.model_endpoint_type == "anthropic"
|
||||
model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"]
|
||||
|
||||
if agent_eligible and feature_enabled and model_compatible and request.stream_tokens:
|
||||
if agent_eligible and feature_enabled and model_compatible:
|
||||
experimental_agent = LettaAgent(
|
||||
agent_id=agent_id,
|
||||
message_manager=server.message_manager,
|
||||
@@ -707,7 +707,9 @@ async def send_message_streaming(
|
||||
)
|
||||
|
||||
result = StreamingResponse(
|
||||
experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message),
|
||||
experimental_agent.step_stream(
|
||||
request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens
|
||||
),
|
||||
media_type="text/event-stream",
|
||||
)
|
||||
else:
|
||||
|
||||
Reference in New Issue
Block a user