diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index bb355756..c335c6cb 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -307,15 +307,31 @@ def openai_chat_completions_process_stream( warnings.warn( f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}" ) + # force index 0 + # accum_message.tool_calls[0].id = tool_call_delta.id else: accum_message.tool_calls[tool_call_delta.index].id = tool_call_delta.id if tool_call_delta.function is not None: if tool_call_delta.function.name is not None: # TODO assert that we're not overwriting? # TODO += instead of =? - accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name + if tool_call_delta.index not in range(len(accum_message.tool_calls)): + warnings.warn( + f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}" + ) + # force index 0 + # accum_message.tool_calls[0].function.name = tool_call_delta.function.name + else: + accum_message.tool_calls[tool_call_delta.index].function.name = tool_call_delta.function.name if tool_call_delta.function.arguments is not None: - accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments + if tool_call_delta.index not in range(len(accum_message.tool_calls)): + warnings.warn( + f"Tool call index out of range ({tool_call_delta.index})\ncurrent tool calls: {accum_message.tool_calls}\ncurrent delta: {tool_call_delta}" + ) + # force index 0 + # accum_message.tool_calls[0].function.arguments += tool_call_delta.function.arguments + else: + accum_message.tool_calls[tool_call_delta.index].function.arguments += tool_call_delta.function.arguments if message_delta.function_call is not None: raise NotImplementedError(f"Old function_call style not support with stream=True")