fix: fix bug triggered by using ada embeddings (#1915)

This commit is contained in:
Charles Packer
2024-10-22 16:42:02 -07:00
committed by GitHub
parent 7871eeb9c2
commit a70dea15a4
3 changed files with 177 additions and 12 deletions

View File

@@ -312,11 +312,20 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
self.function_name_buffer = None
self.function_args_buffer = None
self.function_id_buffer = None
# extra prints
self.debug = False
self.timeout = 30
def _reset_inner_thoughts_json_reader(self):
# A buffer for accumulating function arguments (we want to buffer keys and run checks on each one)
self.function_args_reader = JSONInnerThoughtsExtractor(inner_thoughts_key=self.inner_thoughts_kwarg, wait_for_first_key=True)
# Two buffers used to make sure that the 'name' comes after the inner thoughts stream (if inner_thoughts_in_kwargs)
self.function_name_buffer = None
self.function_args_buffer = None
self.function_id_buffer = None
async def _create_generator(self) -> AsyncGenerator[Union[LettaMessage, LegacyLettaMessage, MessageStreamStatus], None]:
"""An asynchronous generator that yields chunks as they become available."""
while self._active:
@@ -376,6 +385,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if not self.streaming_chat_completion_mode and not self.nonstreaming_legacy_mode:
self._push_to_buffer(self.multi_step_gen_indicator)
# Wipe the inner thoughts buffers
self._reset_inner_thoughts_json_reader()
def step_complete(self):
"""Signal from the agent that one 'step' finished (step = LLM response + tool execution)"""
if not self.multi_step:
@@ -386,6 +398,9 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# signal that a new step has started in the stream
self._push_to_buffer(self.multi_step_indicator)
# Wipe the inner thoughts buffers
self._reset_inner_thoughts_json_reader()
def step_yield(self):
"""If multi_step, this is the true 'stream_end' function."""
self._active = False
@@ -498,6 +513,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
else:
self.function_name_buffer += tool_call.function.name
if tool_call.id:
# Buffer until next time
if self.function_id_buffer is None:
self.function_id_buffer = tool_call.id
else:
self.function_id_buffer += tool_call.id
if tool_call.function.arguments:
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
@@ -518,6 +540,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# If we have main_json, we should output a FunctionCallMessage
elif updates_main_json:
# If there's something in the function_name buffer, we should release it first
# NOTE: we could output it as part of a chunk that has both name and args,
# however the frontend may expect name first, then args, so to be
@@ -526,18 +549,23 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=self.function_name_buffer, arguments=None),
function_call=FunctionCallDelta(
name=self.function_name_buffer,
arguments=None,
function_call_id=self.function_id_buffer,
),
)
# Clear the buffer
self.function_name_buffer = None
self.function_id_buffer = None
# Since we're clearing the name buffer, we should store
# any updates to the arguments inside a separate buffer
if updates_main_json:
# Add any main_json updates to the arguments buffer
if self.function_args_buffer is None:
self.function_args_buffer = updates_main_json
else:
self.function_args_buffer += updates_main_json
# Add any main_json updates to the arguments buffer
if self.function_args_buffer is None:
self.function_args_buffer = updates_main_json
else:
self.function_args_buffer += updates_main_json
# If there was nothing in the name buffer, we can proceed to
# output the arguments chunk as a FunctionCallMessage
@@ -550,17 +578,27 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=None, arguments=combined_chunk),
function_call=FunctionCallDelta(
name=None,
arguments=combined_chunk,
function_call_id=self.function_id_buffer,
),
)
# clear buffer
self.function_args_buffer = None
self.function_id_buffer = None
else:
# If there's no buffer to clear, just output a new chunk with new data
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=None, arguments=updates_main_json),
function_call=FunctionCallDelta(
name=None,
arguments=updates_main_json,
function_call_id=self.function_id_buffer,
),
)
self.function_id_buffer = None
# # If there's something in the main_json buffer, we should add if to the arguments and release it together
# tool_call_delta = {}