feat: add function IDs to LettaMessage function calls and response (#1909)
This commit is contained in:
@@ -503,7 +503,7 @@ class Agent(BaseAgent):
|
||||
def _handle_ai_response(
|
||||
self,
|
||||
response_message: ChatCompletionMessage, # TODO should we eventually move the Message creation outside of this function?
|
||||
override_tool_call_id: bool = True,
|
||||
override_tool_call_id: bool = False,
|
||||
# If we are streaming, we needed to create a Message ID ahead of time,
|
||||
# and now we want to use it in the creation of the Message object
|
||||
# TODO figure out a cleaner way to do this
|
||||
@@ -530,6 +530,7 @@ class Agent(BaseAgent):
|
||||
|
||||
# generate UUID for tool call
|
||||
if override_tool_call_id or response_message.function_call:
|
||||
warnings.warn("Overriding the tool call can result in inconsistent tool call IDs during streaming")
|
||||
tool_call_id = get_tool_call_id() # needs to be a string for JSON
|
||||
response_message.tool_calls[0].id = tool_call_id
|
||||
else:
|
||||
|
||||
@@ -41,7 +41,7 @@ from letta.streaming_interface import (
|
||||
AgentChunkStreamingInterface,
|
||||
AgentRefreshStreamingInterface,
|
||||
)
|
||||
from letta.utils import smart_urljoin
|
||||
from letta.utils import get_tool_call_id, smart_urljoin
|
||||
|
||||
OPENAI_SSE_DONE = "[DONE]"
|
||||
|
||||
@@ -174,6 +174,7 @@ def openai_chat_completions_process_stream(
|
||||
stream_interface: Optional[Union[AgentChunkStreamingInterface, AgentRefreshStreamingInterface]] = None,
|
||||
create_message_id: bool = True,
|
||||
create_message_datetime: bool = True,
|
||||
override_tool_call_id: bool = True,
|
||||
) -> ChatCompletionResponse:
|
||||
"""Process a streaming completion response, and return a ChatCompletionRequest at the end.
|
||||
|
||||
@@ -244,6 +245,14 @@ def openai_chat_completions_process_stream(
|
||||
):
|
||||
assert isinstance(chat_completion_chunk, ChatCompletionChunkResponse), type(chat_completion_chunk)
|
||||
|
||||
# NOTE: this assumes that the tool call ID will only appear in one of the chunks during the stream
|
||||
if override_tool_call_id:
|
||||
for choice in chat_completion_chunk.choices:
|
||||
if choice.delta.tool_calls and len(choice.delta.tool_calls) > 0:
|
||||
for tool_call in choice.delta.tool_calls:
|
||||
if tool_call.id is not None:
|
||||
tool_call.id = get_tool_call_id()
|
||||
|
||||
if stream_interface:
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.process_chunk(
|
||||
@@ -290,6 +299,7 @@ def openai_chat_completions_process_stream(
|
||||
else:
|
||||
accum_message.content += content_delta
|
||||
|
||||
# TODO(charles) make sure this works for parallel tool calling?
|
||||
if message_delta.tool_calls is not None:
|
||||
tool_calls_delta = message_delta.tool_calls
|
||||
|
||||
@@ -340,7 +350,7 @@ def openai_chat_completions_process_stream(
|
||||
assert all([c.finish_reason != TEMP_STREAM_FINISH_REASON for c in chat_completion_response.choices])
|
||||
assert all(
|
||||
[
|
||||
all([tc != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
|
||||
all([tc.id != TEMP_STREAM_TOOL_CALL_ID for tc in c.message.tool_calls]) if c.message.tool_calls else True
|
||||
for c in chat_completion_response.choices
|
||||
]
|
||||
)
|
||||
|
||||
@@ -78,12 +78,14 @@ class FunctionCall(BaseModel):
|
||||
|
||||
name: str
|
||||
arguments: str
|
||||
function_call_id: str
|
||||
|
||||
|
||||
class FunctionCallDelta(BaseModel):
|
||||
|
||||
name: Optional[str]
|
||||
arguments: Optional[str]
|
||||
function_call_id: Optional[str]
|
||||
|
||||
# NOTE: this is a workaround to exclude None values from the JSON dump,
|
||||
# since the OpenAI style of returning chunks doesn't include keys with null values
|
||||
@@ -129,10 +131,10 @@ class FunctionCallMessage(LettaMessage):
|
||||
@classmethod
|
||||
def validate_function_call(cls, v):
|
||||
if isinstance(v, dict):
|
||||
if "name" in v and "arguments" in v:
|
||||
return FunctionCall(name=v["name"], arguments=v["arguments"])
|
||||
elif "name" in v or "arguments" in v:
|
||||
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"))
|
||||
if "name" in v and "arguments" in v and "function_call_id" in v:
|
||||
return FunctionCall(name=v["name"], arguments=v["arguments"], function_call_id=v["function_call_id"])
|
||||
elif "name" in v or "arguments" in v or "function_call_id" in v:
|
||||
return FunctionCallDelta(name=v.get("name"), arguments=v.get("arguments"), function_call_id=v.get("function_call_id"))
|
||||
else:
|
||||
raise ValueError("function_call must contain either 'name' or 'arguments'")
|
||||
return v
|
||||
@@ -147,11 +149,13 @@ class FunctionReturn(LettaMessage):
|
||||
status (Literal["success", "error"]): The status of the function call
|
||||
id (str): The ID of the message
|
||||
date (datetime): The date the message was created in ISO format
|
||||
function_call_id (str): A unique identifier for the function call that generated this message
|
||||
"""
|
||||
|
||||
message_type: Literal["function_return"] = "function_return"
|
||||
function_return: str
|
||||
status: Literal["success", "error"]
|
||||
function_call_id: str
|
||||
|
||||
|
||||
# Legacy Letta API had an additional type "assistant_message" and the "function_call" was a formatted string
|
||||
|
||||
@@ -178,6 +178,7 @@ class Message(BaseMessage):
|
||||
function_call=FunctionCall(
|
||||
name=tool_call.function.name,
|
||||
arguments=tool_call.function.arguments,
|
||||
function_call_id=tool_call.id,
|
||||
),
|
||||
)
|
||||
)
|
||||
@@ -203,6 +204,7 @@ class Message(BaseMessage):
|
||||
raise ValueError(f"Invalid status: {status}")
|
||||
except json.JSONDecodeError:
|
||||
raise ValueError(f"Failed to decode function return: {self.text}")
|
||||
assert self.tool_call_id is not None
|
||||
messages.append(
|
||||
# TODO make sure this is what the API returns
|
||||
# function_return may not match exactly...
|
||||
@@ -211,6 +213,7 @@ class Message(BaseMessage):
|
||||
date=self.created_at,
|
||||
function_return=self.text,
|
||||
status=status_enum,
|
||||
function_call_id=self.tool_call_id,
|
||||
)
|
||||
)
|
||||
elif self.role == MessageRole.user:
|
||||
|
||||
@@ -531,7 +531,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
|
||||
function_call=FunctionCallDelta(
|
||||
name=tool_call_delta.get("name"),
|
||||
arguments=tool_call_delta.get("arguments"),
|
||||
function_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
)
|
||||
|
||||
else:
|
||||
@@ -548,7 +552,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
processed_chunk = FunctionCallMessage(
|
||||
id=message_id,
|
||||
date=message_date,
|
||||
function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
|
||||
function_call=FunctionCallDelta(
|
||||
name=tool_call_delta.get("name"),
|
||||
arguments=tool_call_delta.get("arguments"),
|
||||
function_call_id=tool_call_delta.get("id"),
|
||||
),
|
||||
)
|
||||
|
||||
elif choice.finish_reason is not None:
|
||||
@@ -759,6 +767,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
function_call=FunctionCall(
|
||||
name=function_call.function.name,
|
||||
arguments=function_call.function.arguments,
|
||||
function_call_id=function_call.id,
|
||||
),
|
||||
)
|
||||
|
||||
@@ -786,21 +795,25 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
|
||||
elif msg.startswith("Success: "):
|
||||
msg = msg.replace("Success: ", "")
|
||||
# new_message = {"function_return": msg, "status": "success"}
|
||||
assert msg_obj.tool_call_id is not None
|
||||
new_message = FunctionReturn(
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
function_return=msg,
|
||||
status="success",
|
||||
function_call_id=msg_obj.tool_call_id,
|
||||
)
|
||||
|
||||
elif msg.startswith("Error: "):
|
||||
msg = msg.replace("Error: ", "")
|
||||
# new_message = {"function_return": msg, "status": "error"}
|
||||
assert msg_obj.tool_call_id is not None
|
||||
new_message = FunctionReturn(
|
||||
id=msg_obj.id,
|
||||
date=msg_obj.created_at,
|
||||
function_return=msg,
|
||||
status="error",
|
||||
function_call_id=msg_obj.tool_call_id,
|
||||
)
|
||||
|
||||
else:
|
||||
|
||||
@@ -488,6 +488,9 @@ def is_utc_datetime(dt: datetime) -> bool:
|
||||
|
||||
|
||||
def get_tool_call_id() -> str:
|
||||
# TODO(sarah) make this a slug-style string?
|
||||
# e.g. OpenAI: "call_xlIfzR1HqAW7xJPa3ExJSg3C"
|
||||
# or similar to agents: "call-xlIfzR1HqAW7xJPa3ExJSg3C"
|
||||
return str(uuid.uuid4())[:TOOL_CALL_ID_MAX_LEN]
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user