fix: fix bug triggered by using ada embeddings (#1915)
This commit is contained in:
@@ -312,11 +312,20 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
|
||||
self.function_name_buffer = None
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
|
||||
# extra prints
|
||||
self.debug = False
|
||||
self.timeout = 30
|
||||
|
||||
def _reset_inner_thoughts_json_reader(self):
|
||||
# A buffer for accumulating function arguments (we want to buffer keys and run checks on each one)
|
||||
self.function_args_reader = JSONInnerThoughtsExtractor(inner_thoughts_key=self.inner_thoughts_kwarg, wait_for_first_key=True)
|
||||
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
|
||||
self.function_name_buffer = None
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
|
||||
async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
|
||||
"""An asynchronous generator that yields chunks as they become available."""
|
||||
while self._active:
|
||||
@@ -376,6 +385,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
|
||||
self._push_to_buffer(self.multi_step_gen_indicator)
|
||||
|
||||
# Wipe the inner thoughts buffers
|
||||
self._reset_inner_thoughts_json_reader()
|
||||
|
||||
def step_complete(self):
|
||||
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
|
||||
if not self.multi_step:
|
||||
@@ -386,6 +398,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
# signal that a new step has started in the stream
|
||||
self._push_to_buffer(self.multi_step_indicator)
|
||||
|
||||
# Wipe the inner thoughts buffers
|
||||
self._reset_inner_thoughts_json_reader()
|
||||
|
||||
def step_yield(self):
|
||||
"""If multi_step, this is the true 'stream_end' function."""
|
||||
self._active = False
|
||||
@@ -498,6 +513,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
else:
|
||||
self.function_name_buffer += tool_call.function.name
|
||||
|
||||
if tool_call.id:
|
||||
# Buffer until next time
|
||||
if self.function_id_buffer is None:
|
||||
self.function_id_buffer = tool_call.id
|
||||
else:
|
||||
self.function_id_buffer += tool_call.id
|
||||
|
||||
if tool_call.function.arguments:
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
@@ -518,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
|
||||
# If we have main_json, we should output a FunctionCallMessage
|
||||
elif updates_main_json:
|
||||
|
||||
# If there's something in the function_name buffer, we should release it first
|
||||
# NOTE: we could output it as part of a chunk that has both name and args,
|
||||
# however the frontend may expect name first, then args, so to be
|
||||
@@ -526,18 +549,23 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=self.function_name_buffer, arguments=None),
|
||||
function_call=FunctionCallDelta(
|
||||
name=self.function_name_buffer,
|
||||
arguments=None,
|
||||
function_call_id=self.function_id_buffer,
|
||||
),
|
||||
)
|
||||
# Clear the buffer
|
||||
self.function_name_buffer = None
|
||||
self.function_id_buffer = None
|
||||
# Since we're clearing the name buffer, we should store
|
||||
# any updates to the arguments inside a separate buffer
|
||||
if updates_main_json:
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# Add any main_json updates to the arguments buffer
|
||||
if self.function_args_buffer is None:
|
||||
self.function_args_buffer = updates_main_json
|
||||
else:
|
||||
self.function_args_buffer += updates_main_json
|
||||
|
||||
# If there was nothing in the name buffer, we can proceed to
|
||||
# output the arguments chunk as a FunctionCallMessage
|
||||
@@ -550,17 +578,27 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=None, arguments=combined_chunk),
|
||||
function_call=FunctionCallDelta(
|
||||
name=None,
|
||||
arguments=combined_chunk,
|
||||
function_call_id=self.function_id_buffer,
|
||||
),
|
||||
)
|
||||
# clear buffer
|
||||
self.function_args_buffer = None
|
||||
self.function_id_buffer = None
|
||||
else:
|
||||
# If there's no buffer to clear, just output a new chunk with new data
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=None, arguments=updates_main_json),
|
||||
function_call=FunctionCallDelta(
|
||||
name=None,
|
||||
arguments=updates_main_json,
|
||||
function_call_id=self.function_id_buffer,
|
||||
),
|
||||
)
|
||||
self.function_id_buffer = None
|
||||
|
||||
# # If there's something in the main_json buffer, we should add if to the arguments and release it together
|
||||
# tool_call_delta = {}
|
||||
|
||||
Reference in New Issue
Block a user