feat: add function IDs to LettaMessage function calls and response (#1909)

This commit is contained in:
Charles Packer
2024-10-19 21:47:48 -07:00
committed by GitHub
parent 11bb2f1437
commit 8a9e6dddd3
6 changed files with 43 additions and 9 deletions

View File

@@ -503,7 +503,7 @@ class Agent(BaseAgent):
def _handle_ai_response(
self,
response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function?
override_tool_call_id: bool = True,
override_tool_call_id: bool = False,
# If we are streaming, we needed to create a Message ID ahead of time,
# and now we want to use it in the creation of the Message object
# TODO figure out a cleaner way to do this
@@ -530,6 +530,7 @@ class Agent(BaseAgent):
# generate UUID for tool call
if override_tool_call_id or response_message.function_call:
warnings.warn("Overriding the tool call can result in inconsistent tool call IDs during streaming")
tool_call_id = get_tool_call_id() # needs to be a string for JSON
response_message.tool_calls[0].id = tool_call_id
else:

View File

@@ -41,7 +41,7 @@ from letta.streaming_interface import (
AgentChunkStreamingInterface,
AgentRefreshStreamingInterface,
)
from letta.utils import smart_urljoin
from letta.utils import get_tool_call_id, smart_urljoin
OPENAI_SSE_DONE = "[DONE]"
@@ -174,6 +174,7 @@ def openai_chat_completions_process_stream(
stream_interface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
create_message_id: bool = True,
create_message_datetime: bool = True,
override_tool_call_id: bool = True,
) -> ChatCompletionResponse:
"""Process a streaming completion response, and return a ChatCompletionResponse at the end.
@@ -244,6 +245,14 @@ def openai_chat_completions_process_stream(
):
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
# NOTE: this assumes that the tool call ID will only appear in one of the chunks during the stream
if override_tool_call_id:
for choice in chat_completion_chunk.choices:
if choice.delta.tool_calls and len(choice.delta.tool_calls) > 0:
for tool_call in choice.delta.tool_calls:
if tool_call.id is not None:
tool_call.id = get_tool_call_id()
if stream_interface:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.process_chunk(
@@ -290,6 +299,7 @@ def openai_chat_completions_process_stream(
else:
accum_message.content += content_delta
# TODO(charles) make sure this works for parallel tool calling?
if message_delta.tool_calls is not None:
tool_calls_delta = message_delta.tool_calls
@@ -340,7 +350,7 @@ def openai_chat_completions_process_stream(
assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices])
assert all(
[
all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
for c in chat_completion_response.choices
]
)

View File

@@ -78,12 +78,14 @@ class FunctionCall(BaseModel):
name: str
arguments: str
function_call_id: str
class FunctionCallDelta(BaseModel):
name: Optional[str]
arguments: Optional[str]
function_call_id: Optional[str]
# NOTE: this is a workaround to exclude None values from the JSON dump,
# since the OpenAI style of returning chunks doesn't include keys with null values
@@ -129,10 +131,10 @@ class FunctionCallMessage(LettaMessage):
@classmethod
def validate_function_call(cls, v):
if isinstance(v, dict):
if "name" in v and "arguments" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"])
elif "name" in v or "arguments" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"))
if "name" in v and "arguments" in v and "function_call_id" in v:
return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"])
elif "name" in v or "arguments" in v or "function_call_id" in v:
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id"))
else:
raise ValueError("function_call must contain either 'name' or 'arguments'")
return v
@@ -147,11 +149,13 @@ class FunctionReturn(LettaMessage):
status (Literal["success", "error"]): The status of the function call
id (str): The ID of the message
date (datetime): The date the message was created in ISO format
function_call_id (str): A unique identifier for the function call that generated this message
"""
message_type: Literal["function_return"] = "function_return"
function_return: str
status: Literal["success", "error"]
function_call_id: str
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string

View File

@@ -178,6 +178,7 @@ class Message(BaseMessage):
function_call=FunctionCall(
name=tool_call.function.name,
arguments=tool_call.function.arguments,
function_call_id=tool_call.id,
),
)
)
@@ -203,6 +204,7 @@ class Message(BaseMessage):
raise ValueError(f"Invalid status: {status}")
except json.JSONDecodeError:
raise ValueError(f"Failed to decode function return: {self.text}")
assert self.tool_call_id is not None
messages.append(
# TODO make sure this is what the API returns
# function_return may not match exactly...
@@ -211,6 +213,7 @@ class Message(BaseMessage):
date=self.created_at,
function_return=self.text,
status=status_enum,
function_call_id=self.tool_call_id,
)
)
elif self.role == MessageRole.user:

View File

@@ -531,7 +531,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
function_call=FunctionCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
),
)
else:
@@ -548,7 +552,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
processed_chunk = FunctionCallMessage(
id=message_id,
date=message_date,
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
function_call=FunctionCallDelta(
name=tool_call_delta.get("name"),
arguments=tool_call_delta.get("arguments"),
function_call_id=tool_call_delta.get("id"),
),
)
elif choice.finish_reason is not None:
@@ -759,6 +767,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
function_call=FunctionCall(
name=function_call.function.name,
arguments=function_call.function.arguments,
function_call_id=function_call.id,
),
)
@@ -786,21 +795,25 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
elif msg.startswith("Success: "):
msg = msg.replace("Success: ", "")
# new_message = {"function_return": msg, "status": "success"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="success",
function_call_id=msg_obj.tool_call_id,
)
elif msg.startswith("Error: "):
msg = msg.replace("Error: ", "")
# new_message = {"function_return": msg, "status": "error"}
assert msg_obj.tool_call_id is not None
new_message = FunctionReturn(
id=msg_obj.id,
date=msg_obj.created_at,
function_return=msg,
status="error",
function_call_id=msg_obj.tool_call_id,
)
else:

View File

@@ -488,6 +488,9 @@ def is_utc_datetime(dt: datetime) -> bool:
def get_tool_call_id() -> str:
# TODO(sarah) make this a slug-style string?
# e.g. OpenAI: "call_xlIfzR1HqAW7xJPa3ExJSg3C"
# or similar to agents: "call-xlIfzR1HqAW7xJPa3ExJSg3C"
return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]