feat: rename function to tool in sdk (#2288)

Co-authored-by: Caren Thomas <caren@caren-mac.local>
This commit is contained in:
cthomas
2024-12-19 12:12:58 -08:00
committed by GitHub
parent 5f2ba44e93
commit 7d5be32a59
16 changed files with 202 additions and 164 deletions

View File

@@ -12,10 +12,10 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
AssistantMessage,
FunctionCall,
FunctionCallDelta,
FunctionCallMessage,
FunctionReturn,
ToolCall,
ToolCallDelta,
ToolCallMessage,
ToolReturnMessage,
InternalMonologue,
LegacyFunctionCallMessage,
LegacyLettaMessage,
@@ -411,7 +411,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
def _process_chunk_to_letta_style(
self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime
) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]:
) -> Optional[Union[InternalMonologue, ToolCallMessage, AssistantMessage]]:
"""
Example data from non-streaming response looks like:
@@ -442,7 +442,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if self.inner_thoughts_in_kwargs:
raise NotImplementedError("inner_thoughts_in_kwargs with use_assistant_message not yet supported")
# If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode
# If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard ToolCallMessage passthrough mode
# Track the function name while streaming
# If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
@@ -474,7 +474,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
assistant_message=cleaned_func_args,
)
# otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage
# otherwise we just do a regular passthrough of a ToolCallDelta via a ToolCallMessage
else:
tool_call_delta = {}
if tool_call.id:
@@ -485,13 +485,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if tool_call.function.name:
tool_call_delta["name"] = tool_call.function.name
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
tool_call_id=tool_call_delta.get("id"),
),
)
@@ -531,7 +531,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
else:
self.function_args_buffer += updates_main_json
# If we have main_json, we should output a FunctionCallMessage
# If we have main_json, we should output a ToolCallMessage
elif updates_main_json:
# If there's something in the function_name buffer, we should release it first
@@ -539,13 +539,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# however the frontend may expect name first, then args, so to be
# safe we'll output name first in a separate chunk
if self.function_name_buffer:
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=self.function_name_buffer,
arguments=None,
function_call_id=self.function_id_buffer,
tool_call_id=self.function_id_buffer,
),
)
# Clear the buffer
@@ -561,20 +561,20 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.function_args_buffer += updates_main_json
# If there was nothing in the name buffer, we can proceed to
# output the arguments chunk as a FunctionCallMessage
# output the arguments chunk as a ToolCallMessage
else:
# There may be a buffer from a previous chunk, for example
# if the previous chunk had arguments but we needed to flush name
if self.function_args_buffer:
# In this case, we should release the buffer + new data at once
combined_chunk = self.function_args_buffer + updates_main_json
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=None,
arguments=combined_chunk,
function_call_id=self.function_id_buffer,
tool_call_id=self.function_id_buffer,
),
)
# clear buffer
@@ -582,13 +582,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
self.function_id_buffer = None
else:
# If there's no buffer to clear, just output a new chunk with new data
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=None,
arguments=updates_main_json,
function_call_id=self.function_id_buffer,
tool_call_id=self.function_id_buffer,
),
)
self.function_id_buffer = None
@@ -608,10 +608,10 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# # if tool_call.function.name:
# # tool_call_delta["name"] = tool_call.function.name
# processed_chunk = FunctionCallMessage(
# processed_chunk = ToolCallMessage(
# id=message_id,
# date=message_date,
# function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# tool_call=ToolCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# )
else:
@@ -642,10 +642,10 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# if tool_call.function.name:
# tool_call_delta["name"] = tool_call.function.name
# processed_chunk = FunctionCallMessage(
# processed_chunk = ToolCallMessage(
# id=message_id,
# date=message_date,
# function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# tool_call=ToolCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
# )
# elif False and self.inner_thoughts_in_kwargs and tool_call.function:
@@ -682,13 +682,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
# If it does match, start processing the value (stringified-JSON string
# And with each new chunk, output it as a chunk of type InternalMonologue
# If the key doesn't match, then flush the buffer as a single FunctionCallMessage chunk
# If the key doesn't match, then flush the buffer as a single ToolCallMessage chunk
# If we're reading a value
# If we're reading the inner thoughts value, we output chunks of type InternalMonologue
# Otherwise, do simple chunks of FunctionCallMessage
# Otherwise, do simple chunks of ToolCallMessage
else:
@@ -701,13 +701,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
if tool_call.function.name:
tool_call_delta["name"] = tool_call.function.name
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(
tool_call=ToolCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
tool_call_id=tool_call_delta.get("id"),
),
)
@@ -911,13 +911,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
assistant_message=func_args[self.assistant_message_tool_kwarg],
)
else:
processed_chunk = FunctionCallMessage(
processed_chunk = ToolCallMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_call=FunctionCall(
tool_call=ToolCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
function_call_id=function_call.id,
tool_call_id=function_call.id,
),
)
@@ -942,24 +942,24 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
new_message = ToolReturnMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
tool_return=msg,
status="success",
function_call_id=msg_obj.tool_call_id,
tool_call_id=msg_obj.tool_call_id,
)
elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
# new_message = {"function_return": msg, "status": "error"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
new_message = ToolReturnMessage(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
tool_return=msg,
status="error",
function_call_id=msg_obj.tool_call_id,
tool_call_id=msg_obj.tool_call_id,
)
else: