feat: Patch broken assistant messages for both Haiku and Sonnet (#1252)

This commit is contained in:
Matthew Zhou
2025-03-12 10:19:41 -07:00
committed by GitHub
parent beccaa8939
commit a86c268926
3 changed files with 49 additions and 14 deletions

View File

@@ -54,7 +54,7 @@ DEVELOPMENT_LOGGING = {
"propagate": True, # Let logs bubble up to root
},
"uvicorn": {
"level": "INFO",
"level": "CRITICAL",
"handlers": ["console"],
"propagate": True,
},

View File

@@ -267,3 +267,5 @@ class ChatCompletionsStreamingInterface(AgentChunkStreamingInterface):
"""Clears internal buffers for function call name/args."""
self.current_function_name = ""
self.current_function_arguments = []
self.current_json_parse_result = {}
self._found_message_tool_kwarg = False

View File

@@ -24,6 +24,7 @@ from letta.schemas.letta_message import (
)
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import ChatCompletionChunkResponse
from letta.server.rest_api.optimistic_json_parser import OptimisticJSONParser
from letta.streaming_interface import AgentChunkStreamingInterface
from letta.streaming_utils import FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor
@@ -282,6 +283,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# turn function argument to send_message into a normal text stream
self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_tool_kwarg)
# @matt's changes here, adopting new optimistic json parser
self.current_function_arguments = []
self.optimistic_json_parser = OptimisticJSONParser()
self.current_json_parse_result = {}
# Store metadata passed from server
self.metadata = {}
@@ -374,6 +380,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_start(self):
"""Initialize streaming by activating the generator and clearing any old chunks."""
self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = []
self.current_json_parse_result = {}
if not self._active:
self._active = True
@@ -383,6 +391,8 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def stream_end(self):
"""Clean up the stream by deactivating and clearing chunks."""
self.streaming_chat_completion_mode_function_name = None
self.current_function_arguments = []
self.current_json_parse_result = {}
# if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
# self._push_to_buffer(self.multi_step_gen_indicator)
@@ -568,20 +578,27 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.streaming_chat_completion_json_reader.reset()
# early exit to turn into content mode
return None
if tool_call.function.arguments:
self.current_function_arguments.append(tool_call.function.arguments)
# if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
if tool_call.function.arguments and self.streaming_chat_completion_mode_function_name == self.assistant_message_tool_name:
# Strip out any extra tokens
cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments)
# In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
if cleaned_func_args is None:
return None
combined_args = "".join(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg
) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
new_content = parsed_args.get(self.assistant_message_tool_kwarg)
prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
# TODO: Assumes consistent state and that prev_content is subset of new_content
diff = new_content.replace(prev_content, "", 1)
self.current_json_parse_result = parsed_args
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff)
else:
processed_chunk = AssistantMessage(
id=message_id,
date=message_date,
content=cleaned_func_args,
)
return None
# otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage
else:
@@ -637,6 +654,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# updates_inner_thoughts = ""
# else: # OpenAI
# updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
self.current_function_arguments.append(tool_call.function.arguments)
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
# If we have inner thoughts, we should output them as a chunk
@@ -731,6 +749,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if self.function_args_buffer:
# In this case, we should release the buffer + new data at once
combined_chunk = self.function_args_buffer + updates_main_json
processed_chunk = AssistantMessage(
id=message_id,
date=message_date,
@@ -745,11 +764,24 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
else:
# If there's no buffer to clear, just output a new chunk with new data
processed_chunk = AssistantMessage(
id=message_id,
date=message_date,
content=updates_main_json,
)
# TODO: THIS IS HORRIBLE
# TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER
# TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE
combined_args = "".join(self.current_function_arguments)
parsed_args = self.optimistic_json_parser.parse(combined_args)
if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get(
self.assistant_message_tool_kwarg
) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg):
new_content = parsed_args.get(self.assistant_message_tool_kwarg)
prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "")
# TODO: Assumes consistent state and that prev_content is subset of new_content
diff = new_content.replace(prev_content, "", 1)
self.current_json_parse_result = parsed_args
processed_chunk = AssistantMessage(id=message_id, date=message_date, content=diff)
else:
return None
# Store the ID of the tool call to allow skipping the corresponding response
if self.function_id_buffer:
self.prev_assistant_message_id = self.function_id_buffer
@@ -1018,6 +1050,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
message_date=message_date,
expect_reasoning_content=expect_reasoning_content,
)
if processed_chunk is None:
return