feat: add back support for using AssistantMessage subtype of LettaMessage (#1812)

Author: Charles Packer
Date: 2024-10-04 15:36:33 -07:00
Committed by: GitHub
Parent: e0442bd658
Commit: b17246a3b0
7 changed files with 300 additions and 53 deletions


@@ -1,10 +1,12 @@
 import asyncio
 import json
 import queue
+import warnings
 from collections import deque
 from datetime import datetime
 from typing import AsyncGenerator, Literal, Optional, Union

+from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
 from letta.interface import AgentInterface
 from letta.schemas.enums import MessageStreamStatus
 from letta.schemas.letta_message import (
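
For readers skimming the diff: the two constants imported above name the agent's user-facing tool and the keyword argument that carries its text. Their values (from letta/constants.py at this point in the project, reproduced here for reference) are:

DEFAULT_MESSAGE_TOOL = "send_message"  # the tool the agent calls to speak to the user
DEFAULT_MESSAGE_TOOL_KWARG = "message"  # the kwarg of that tool holding the user-facing text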
@@ -249,7 +251,7 @@ class QueuingInterface(AgentInterface):
 class FunctionArgumentsStreamHandler:
     """State machine that can process a stream of"""

-    def __init__(self, json_key="message"):
+    def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG):
         self.json_key = json_key
         self.reset()
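
A rough usage sketch of the handler, inferred from its call sites later in this diff: process_json_chunk returns None while it is consuming JSON scaffolding and returns cleaned text once the stream is inside the value of json_key. The chunk boundaries below are invented; real ones depend on how the provider tokenizes the stream.

handler = FunctionArgumentsStreamHandler(json_key="message")
handler.reset()
for fragment in ['{"', 'message', '": "', 'Hello', ' world', '"}']:
    cleaned = handler.process_json_chunk(fragment)  # None until inside the "message" value
    if cleaned is not None:
        print(cleaned, end="")  # emits only the message text: Hello world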
@@ -311,7 +313,13 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
     should maintain multiple generators and index them with the request ID
     """

-    def __init__(self, multi_step=True):
+    def __init__(
+        self,
+        multi_step=True,
+        use_assistant_message=False,
+        assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
+        assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
+    ):
         # If streaming mode, ignores base interface calls like .assistant_message, etc
         self.streaming_mode = False
         # NOTE: flag for supporting legacy 'stream' flag where send_message is treated specially
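
With the widened signature, server code can opt into the AssistantMessage subtype per interface instance. A hypothetical caller might look like this (the argument values simply mirror the defaults above):

interface = StreamingServerInterface(
    multi_step=True,
    use_assistant_message=True,  # trap send_message tool calls into AssistantMessage chunks
    assistant_message_function_name=DEFAULT_MESSAGE_TOOL,
    assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG,
)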
@@ -321,7 +329,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
         self.streaming_chat_completion_mode_function_name = None  # NOTE: sadly need to track state during stream
         # If chat completion mode, we need a special stream reader to
         # turn function argument to send_message into a normal text stream
-        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler()
+        self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg)

         self._chunks = deque()
         self._event = asyncio.Event()  # Use an event to notify when chunks are available
@@ -333,6 +341,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
         self.multi_step_indicator = MessageStreamStatus.done_step
         self.multi_step_gen_indicator = MessageStreamStatus.done_generation

+        # Support for AssistantMessage
+        self.use_assistant_message = use_assistant_message
+        self.assistant_message_function_name = assistant_message_function_name
+        self.assistant_message_function_kwarg = assistant_message_function_kwarg
+
         # extra prints
         self.debug = False
         self.timeout = 30
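
The net effect of the new flag, sketched as approximate wire shapes; the field names follow letta.schemas.letta_message at this commit, but the exact serialization is not shown in this diff:

# use_assistant_message=False: a send_message tool call streams as a FunctionCallMessage
#   {"id": "...", "date": "...", "message_type": "function_call",
#    "function_call": {"name": "send_message", "arguments": "{\"message\": \"Hi!\"}"}}
# use_assistant_message=True: the same call streams as an AssistantMessage
#   {"id": "...", "date": "...", "message_type": "assistant_message",
#    "assistant_message": "Hi!"}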
@@ -441,7 +454,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
     def _process_chunk_to_letta_style(
         self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime
-    ) -> Optional[Union[InternalMonologue, FunctionCallMessage]]:
+    ) -> Optional[Union[InternalMonologue, FunctionCallMessage, AssistantMessage]]:
         """
         Example data from non-streaming response looks like:
@@ -461,23 +474,83 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
                 date=message_date,
                 internal_monologue=message_delta.content,
             )

         # tool calls
         elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0:
             tool_call = message_delta.tool_calls[0]

-            tool_call_delta = {}
-            if tool_call.id:
-                tool_call_delta["id"] = tool_call.id
-            if tool_call.function:
-                if tool_call.function.arguments:
-                    tool_call_delta["arguments"] = tool_call.function.arguments
-                if tool_call.function.name:
-                    tool_call_delta["name"] = tool_call.function.name
-            processed_chunk = FunctionCallMessage(
-                id=message_id,
-                date=message_date,
-                function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
-            )
+            # special case for trapping `send_message`
+            if self.use_assistant_message and tool_call.function:
+
+                # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode
+
+                # Track the function name while streaming
+                # If we were previously on a 'send_message', we need to 'toggle' into 'content' mode
+                if tool_call.function.name:
+                    if self.streaming_chat_completion_mode_function_name is None:
+                        self.streaming_chat_completion_mode_function_name = tool_call.function.name
+                    else:
+                        self.streaming_chat_completion_mode_function_name += tool_call.function.name
+
+                    # If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk
+                    # TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit?
+                    # if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name:
+                    if tool_call.function.name == self.assistant_message_function_name:
+                        self.streaming_chat_completion_json_reader.reset()
+                        # early exit to turn into content mode
+                        return None
+
+                # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks
+                if (
+                    tool_call.function.arguments
+                    and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name
+                ):
+                    # Strip out any extras tokens
+                    cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments)
+
+                    # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk
+                    if cleaned_func_args is None:
+                        return None
+                    else:
+                        processed_chunk = AssistantMessage(
+                            id=message_id,
+                            date=message_date,
+                            assistant_message=cleaned_func_args,
+                        )
+
+                # otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage
+                else:
+                    tool_call_delta = {}
+                    if tool_call.id:
+                        tool_call_delta["id"] = tool_call.id
+                    if tool_call.function:
+                        if tool_call.function.arguments:
+                            tool_call_delta["arguments"] = tool_call.function.arguments
+                        if tool_call.function.name:
+                            tool_call_delta["name"] = tool_call.function.name
+
+                    processed_chunk = FunctionCallMessage(
+                        id=message_id,
+                        date=message_date,
+                        function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
+                    )
+
+            else:
+                tool_call_delta = {}
+                if tool_call.id:
+                    tool_call_delta["id"] = tool_call.id
+                if tool_call.function:
+                    if tool_call.function.arguments:
+                        tool_call_delta["arguments"] = tool_call.function.arguments
+                    if tool_call.function.name:
+                        tool_call_delta["name"] = tool_call.function.name
+
+                processed_chunk = FunctionCallMessage(
+                    id=message_id,
+                    date=message_date,
+                    function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")),
+                )

         elif choice.finish_reason is not None:
             # skip if there's a finish
             return None
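
To make the branch above concrete, here is a plausible chunk-by-chunk trace with use_assistant_message=True (the chunk contents are invented; providers split names and arguments across chunks differently):

# chunk 1: tool_call.function.name == "send_message"
#   -> matches assistant_message_function_name: the JSON reader is reset, return None (enter content mode)
# chunk 2: tool_call.function.arguments == '{"message": "Hello'
#   -> process_json_chunk strips the JSON scaffolding; returns None until message text appears, then "Hello"
# chunk 3: tool_call.function.arguments == ' world"}'
#   -> returns ' world', emitted as AssistantMessage(assistant_message=" world")
# any other function name falls through to the plain FunctionCallMessage passthrough branches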
@@ -663,14 +736,32 @@ class StreamingServerInterface(AgentChunkStreamingInterface):
             else:
-                processed_chunk = FunctionCallMessage(
-                    id=msg_obj.id,
-                    date=msg_obj.created_at,
-                    function_call=FunctionCall(
-                        name=function_call.function.name,
-                        arguments=function_call.function.arguments,
-                    ),
-                )
+                try:
+                    func_args = json.loads(function_call.function.arguments)
+                except:
+                    warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}")
+                    func_args = {}
+
+                if (
+                    self.use_assistant_message
+                    and function_call.function.name == self.assistant_message_function_name
+                    and self.assistant_message_function_kwarg in func_args
+                ):
+                    processed_chunk = AssistantMessage(
+                        id=msg_obj.id,
+                        date=msg_obj.created_at,
+                        assistant_message=func_args[self.assistant_message_function_kwarg],
+                    )
+                else:
+                    processed_chunk = FunctionCallMessage(
+                        id=msg_obj.id,
+                        date=msg_obj.created_at,
+                        function_call=FunctionCall(
+                            name=function_call.function.name,
+                            arguments=function_call.function.arguments,
+                        ),
+                    )

                 # processed_chunk = {
                 #     "function_call": {
                 #         "name": function_call.function.name,