feat: add function IDs to LettaMessage function calls and response (#1909)

This commit is contained in:
Charles Packer
2024-10-19 21:47:48 -07:00
committed by GitHub
parent 11bb2f1437
commit 8a9e6dddd3
6 changed files with 43 additions and 9 deletions

View File

@@ -41,7 +41,7 @@ from letta.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
)
from letta.utils import smart_urljoin
from letta.utils import get_tool_call_id, smart_urljoin
OPENAI_SSE_DONE = "[DONE]"
@@ -174,6 +174,7 @@ def openai_chat_completions_process_stream(
stream_interface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
create_message_id: bool = True,
create_message_datetime: bool = True,
override_tool_call_id: bool = True,
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
@@ -244,6 +245,14 @@ def openai_chat_completions_process_stream(
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
# NOTE: this assumes that the tool call ID will only appear in one of the chunks during the stream
if override_tool_call_id:
for choice in chat_completion_chunk.choices:
if choice.delta.tool_calls and len(choice.delta.tool_calls) > 0:
for tool_call in choice.delta.tool_calls:
if tool_call.id is not None:
tool_call.id = get_tool_call_id()
if stream_interface:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.process_chunk(
@@ -290,6 +299,7 @@ def openai_chat_completions_process_stream(
else:
accum_message.content += content_delta
# TODO(charles) make sure this works for parallel tool calling?
if message_delta.tool_calls is not None:
tool_calls_delta = message_delta.tool_calls
@@ -340,7 +350,7 @@ def openai_chat_completions_process_stream(
assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices])
assert all(
[
all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
for c in chat_completion_response.choices
]
)