feat: add back support for using AssistantMessage subtype of LettaMessage (#1812)
@@ -1,10 +1,12 @@
import asyncio
import json
import queue
import warnings
from collections import deque
from datetime import datetime
from typing import AsyncGenerator, Literal, Optional, Union

from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.interface import AgentInterface
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import (
@@ -249,7 +251,7 @@ class QueuingInterface(AgentInterface):
class FunctionArgumentsStreamHandler:
    """State machine that can process a stream of"""

-    def __init__(self, json_key="message"):
+    def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG):
        self.json_key = json_key
        self.reset()
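For context, the handler above is driven elsewhere in this file through reset() and process_json_chunk(). A minimal usage sketch, assuming process_json_chunk() returns None for pure JSON scaffolding or incomplete prefixes and otherwise returns the cleaned text belonging to json_key (the fragments below are hypothetical):

    # Hypothetical fragments of a streamed send_message argument payload
    fragments = ['{"', 'message', '": "', 'Hello', ' there', '"}']

    handler = FunctionArgumentsStreamHandler(json_key=DEFAULT_MESSAGE_TOOL_KWARG)
    handler.reset()

    pieces = []
    for fragment in fragments:
        cleaned = handler.process_json_chunk(fragment)
        if cleaned is not None:  # None means "nothing to surface yet"
            pieces.append(cleaned)

    print("".join(pieces))  # expected to approximate: Hello there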
@@ -311,7 +313,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
    should maintain multiple generators and index them with the request ID
    """

-    def __init__(self, multi_step=True):
+    def __init__(
+        self,
+        multi_step=True,
+        use_assistant_message=False,
+        assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
+        assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
+    ):
        # If streaming mode, ignores base interface calls like .assistant_message, etc
        self.streaming_mode = False
        # NOTE: flag for supporting legacy 'stream' flag where send_message is treated specially
@@ -321,7 +329,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        self.streaming_chat_completion_mode_function_name = None  # NOTE: sadly need to track state during stream
        # If chat completion mode, we need a special stream reader to
        # turn function argument to send_message into a normal text stream
-        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler()
+        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg)

        self._chunks = deque()
        self._event = asyncio.Event()  # Use an event to notify when chunks are available
@@ -333,6 +341,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
        self.multi_step_indicator = MessageStreamStatus.done_step
        self.multi_step_gen_indicator = MessageStreamStatus.done_generation

+        # Support for AssistantMessage
+        self.use_assistant_message = use_assistant_message
+        self.assistant_message_function_name = assistant_message_function_name
+        self.assistant_message_function_kwarg = assistant_message_function_kwarg
+
        # extra prints
        self.debug = False
        self.timeout = 30
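Taken together, the constructor changes let a caller opt in to AssistantMessage emission. A sketch of what that call might look like (the keyword names come from the diff; the surrounding setup is illustrative):

    # Illustrative only: emit AssistantMessage for the default send_message tool
    interface = StreamingServerInterface(
        multi_step=True,
        use_assistant_message=True,
        assistant_message_function_name=DEFAULT_MESSAGE_TOOL,         # presumably "send_message"
        assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,  # presumably "message"
    )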
@@ -441,7 +454,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
    def _process_chunk_to_letta_style(
        self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime
-    ) -> Optional[Union[InternalMonologue, FunctionCallMessage]]:
+    ) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]:
        """
        Example data from non-streaming response looks like:
@@ -461,23 +474,83 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
                date=message_date,
                internal_monologue=message_delta.content,
            )

        # tool calls
        elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
            tool_call = message_delta.tool_calls[0]

-            tool_call_delta = {}
-            if tool_call.id:
-                tool_call_delta["id"] = tool_call.id
-            if tool_call.function:
-                if tool_call.function.arguments:
-                    tool_call_delta["arguments"] = tool_call.function.arguments
-                if tool_call.function.name:
-                    tool_call_delta["name"] = tool_call.function.name
+            # special case for trapping `send_message`
+            if self.use_assistant_message and tool_call.function:

+                # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode

+                # Track the function name while streaming
+                # If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
+                if tool_call.function.name:
+                    if self.streaming_chat_completion_mode_function_name is None:
+                        self.streaming_chat_completion_mode_function_name = tool_call.function.name
+                    else:
+                        self.streaming_chat_completion_mode_function_name += tool_call.function.name

+                    # If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk
+                    # TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit?
+                    # if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name:
+                    if tool_call.function.name == self.assistant_message_function_name:
+                        self.streaming_chat_completion_json_reader.reset()
+                        # early exit to turn into content mode
+                        return None

+                # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
+                if (
+                    tool_call.function.arguments
+                    and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name
+                ):
+                    # Strip out any extras tokens
+                    cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments)
+                    # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
+                    if cleaned_func_args is None:
+                        return None
+                    else:
+                        processed_chunk = AssistantMessage(
+                            id=message_id,
+                            date=message_date,
+                            assistant_message=cleaned_func_args,
+                        )

+                # otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage
+                else:
+                    tool_call_delta = {}
+                    if tool_call.id:
+                        tool_call_delta["id"] = tool_call.id
+                    if tool_call.function:
+                        if tool_call.function.arguments:
+                            tool_call_delta["arguments"] = tool_call.function.arguments
+                        if tool_call.function.name:
+                            tool_call_delta["name"] = tool_call.function.name

+                    processed_chunk = FunctionCallMessage(
+                        id=message_id,
+                        date=message_date,
+                        function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
+                    )

+            else:

+                tool_call_delta = {}
+                if tool_call.id:
+                    tool_call_delta["id"] = tool_call.id
+                if tool_call.function:
+                    if tool_call.function.arguments:
+                        tool_call_delta["arguments"] = tool_call.function.arguments
+                    if tool_call.function.name:
+                        tool_call_delta["name"] = tool_call.function.name

+                processed_chunk = FunctionCallMessage(
+                    id=message_id,
+                    date=message_date,
+                    function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
+                )

-            processed_chunk = FunctionCallMessage(
-                id=message_id,
-                date=message_date,
-                function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
-            )
        elif choice.finish_reason is not None:
            # skip if there's a finish
            return None
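To make the effect of the new branch concrete, here is a hedged sketch of the two shapes a streamed send_message delta can take. The constructors mirror the ones used in the hunk above; the id and argument values are made up, and the message classes are assumed to come from the truncated letta.schemas.letta_message import at the top of the diff:

    from datetime import datetime, timezone

    now = datetime.now(timezone.utc)

    # use_assistant_message=False (previous behavior): raw function-call passthrough
    chunk_as_function_call = FunctionCallMessage(
        id="message-123",  # hypothetical message id
        date=now,
        function_call=FunctionCallDelta(name="send_message", arguments='{"message": "Hi'),
    )

    # use_assistant_message=True: the cleaned argument text is surfaced directly
    chunk_as_assistant_message = AssistantMessage(
        id="message-123",
        date=now,
        assistant_message="Hi",
    )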
@@ -663,14 +736,32 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
            else:

-                processed_chunk = FunctionCallMessage(
-                    id=msg_obj.id,
-                    date=msg_obj.created_at,
-                    function_call=FunctionCall(
-                        name=function_call.function.name,
-                        arguments=function_call.function.arguments,
-                    ),
-                )
+                try:
+                    func_args = json.loads(function_call.function.arguments)
+                except:
+                    warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}")
+                    func_args = {}

+                if (
+                    self.use_assistant_message
+                    and function_call.function.name == self.assistant_message_function_name
+                    and self.assistant_message_function_kwarg in func_args
+                ):
+                    processed_chunk = AssistantMessage(
+                        id=msg_obj.id,
+                        date=msg_obj.created_at,
+                        assistant_message=func_args[self.assistant_message_function_kwarg],
+                    )
+                else:
+                    processed_chunk = FunctionCallMessage(
+                        id=msg_obj.id,
+                        date=msg_obj.created_at,
+                        function_call=FunctionCall(
+                            name=function_call.function.name,
+                            arguments=function_call.function.arguments,
+                        ),
+                    )

                # processed_chunk = {
                #     "function_call": {
                #         "name": function_call.function.name,
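The same decision rule, distilled from the non-streaming hunk above into a standalone helper for illustration (the helper is hypothetical and not part of this change; default names follow the presumed values of the letta.constants imports):

    import json
    import warnings

    def assistant_message_text(use_assistant_message: bool, function_name: str, raw_arguments: str,
                               assistant_function_name: str = "send_message",
                               assistant_kwarg: str = "message"):
        """Return the text to surface as an AssistantMessage, or None to fall back
        to a plain FunctionCallMessage (mirrors the branch in the diff above)."""
        if not use_assistant_message or function_name != assistant_function_name:
            return None
        try:
            func_args = json.loads(raw_arguments)
        except json.JSONDecodeError:
            warnings.warn(f"Failed to parse function arguments: {raw_arguments}")
            return None
        return func_args.get(assistant_kwarg)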