diff --git a/letta/constants.py b/letta/constants.py index dc0a17c0..84fa0a76 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -46,6 +46,12 @@ BASE_TOOLS = [ "archival_memory_search", ] +# The name of the tool used to send messages to the user +# May not be relevant in cases where the agent has multiple ways to message the user (send_imessage, send_discord_message, ...) +# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent) +DEFAULT_MESSAGE_TOOL = "send_message" +DEFAULT_MESSAGE_TOOL_KWARG = "message" + # LOGGER_LOG_LEVEL is use to convert Text to Logging level value for logging mostly for Cli input to setting level LOGGER_LOG_LEVELS = {"CRITICAL": CRITICAL, "ERROR": ERROR, "WARN": WARN, "WARNING": WARNING, "INFO": INFO, "DEBUG": DEBUG, "NOTSET": NOTSET} diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py index b690b47b..a6e49d8b 100644 --- a/letta/schemas/letta_request.py +++ b/letta/schemas/letta_request.py @@ -2,6 +2,7 @@ from typing import List from pydantic import BaseModel, Field +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.message import MessageCreate @@ -21,3 +22,19 @@ class LettaRequest(BaseModel): default=False, description="Set True to return the raw Message object. Set False to return the Message in the format of the Letta API.", ) + + # Flags to support the use of AssistantMessage message types + + use_assistant_message: bool = Field( + default=False, + description="[Only applicable if return_message_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. 
If false, return FunctionCallMessage objects for all tool calls.", + ) + + assistant_message_function_name: str = Field( + default=DEFAULT_MESSAGE_TOOL, + description="[Only applicable if use_assistant_message is True] The name of the designated message tool.", + ) + assistant_message_function_kwarg: str = Field( + default=DEFAULT_MESSAGE_TOOL_KWARG, + description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", + ) diff --git a/letta/schemas/message.py b/letta/schemas/message.py index d3879c0c..70aa9df9 100644 --- a/letta/schemas/message.py +++ b/letta/schemas/message.py @@ -6,11 +6,16 @@ from typing import List, Optional from pydantic import Field, field_validator -from letta.constants import TOOL_CALL_ID_MAX_LEN +from letta.constants import ( + DEFAULT_MESSAGE_TOOL, + DEFAULT_MESSAGE_TOOL_KWARG, + TOOL_CALL_ID_MAX_LEN, +) from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.schemas.enums import MessageRole from letta.schemas.letta_base import LettaBase from letta.schemas.letta_message import ( + AssistantMessage, FunctionCall, FunctionCallMessage, FunctionReturn, @@ -122,7 +127,12 @@ class Message(BaseMessage): json_message["created_at"] = self.created_at.isoformat() return json_message - def to_letta_message(self) -> List[LettaMessage]: + def to_letta_message( + self, + assistant_message: bool = False, + assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, + ) -> List[LettaMessage]: """Convert message object (in DB format) to the style used by the original Letta API""" messages = [] @@ -140,16 +150,33 @@ class Message(BaseMessage): if self.tool_calls is not None: # This is type FunctionCall for tool_call in self.tool_calls: - messages.append( - FunctionCallMessage( - id=self.id, - date=self.created_at, - function_call=FunctionCall( - name=tool_call.function.name, - 
arguments=tool_call.function.arguments, - ), + # If we're supporting using assistant message, + # then we want to treat certain function calls as a special case + if assistant_message and tool_call.function.name == assistant_message_function_name: + # We need to unpack the actual message contents from the function call + try: + func_args = json.loads(tool_call.function.arguments) + message_string = func_args[assistant_message_function_kwarg] + except KeyError: + raise ValueError(f"Function call {tool_call.function.name} missing {assistant_message_function_kwarg} argument") + messages.append( + AssistantMessage( + id=self.id, + date=self.created_at, + assistant_message=message_string, + ) + ) + else: + messages.append( + FunctionCallMessage( + id=self.id, + date=self.created_at, + function_call=FunctionCall( + name=tool_call.function.name, + arguments=tool_call.function.arguments, + ), + ) ) - ) elif self.role == MessageRole.tool: # This is type FunctionReturn # Try to interpret the function return, recall that this is how we packaged: diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index 0715b901..b8b06d78 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -1,10 +1,12 @@ import asyncio import json import queue +import warnings from collections import deque from datetime import datetime from typing import AsyncGenerator, Literal, Optional, Union +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.interface import AgentInterface from letta.schemas.enums import MessageStreamStatus from letta.schemas.letta_message import ( @@ -249,7 +251,7 @@ class QueuingInterface(AgentInterface): class FunctionArgumentsStreamHandler: """State machine that can process a stream of""" - def __init__(self, json_key="message"): + def __init__(self, json_key=DEFAULT_MESSAGE_TOOL_KWARG): self.json_key = json_key self.reset() @@ -311,7 +313,13 @@ class 
StreamingServerInterface(AgentChunkStreamingInterface): should maintain multiple generators and index them with the request ID """ - def __init__(self, multi_step=True): + def __init__( + self, + multi_step=True, + use_assistant_message=False, + assistant_message_function_name=DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, + ): # If streaming mode, ignores base interface calls like .assistant_message, etc self.streaming_mode = False # NOTE: flag for supporting legacy 'stream' flag where send_message is treated specially @@ -321,7 +329,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.streaming_chat_completion_mode_function_name = None # NOTE: sadly need to track state during stream # If chat completion mode, we need a special stream reader to # turn function argument to send_message into a normal text stream - self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler() + self.streaming_chat_completion_json_reader = FunctionArgumentsStreamHandler(json_key=assistant_message_function_kwarg) self._chunks = deque() self._event = asyncio.Event() # Use an event to notify when chunks are available @@ -333,6 +341,11 @@ class StreamingServerInterface(AgentChunkStreamingInterface): self.multi_step_indicator = MessageStreamStatus.done_step self.multi_step_gen_indicator = MessageStreamStatus.done_generation + # Support for AssistantMessage + self.use_assistant_message = use_assistant_message + self.assistant_message_function_name = assistant_message_function_name + self.assistant_message_function_kwarg = assistant_message_function_kwarg + # extra prints self.debug = False self.timeout = 30 @@ -441,7 +454,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): def _process_chunk_to_letta_style( self, chunk: ChatCompletionChunkResponse, message_id: str, message_date: datetime - ) -> Optional[Union[InternalMonologue, FunctionCallMessage]]: + ) -> Optional[Union[InternalMonologue, 
FunctionCallMessage, AssistantMessage]]: """ Example data from non-streaming response looks like: @@ -461,23 +474,83 @@ class StreamingServerInterface(AgentChunkStreamingInterface): date=message_date, internal_monologue=message_delta.content, ) + + # tool calls elif message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: tool_call = message_delta.tool_calls[0] - tool_call_delta = {} - if tool_call.id: - tool_call_delta["id"] = tool_call.id - if tool_call.function: - if tool_call.function.arguments: - tool_call_delta["arguments"] = tool_call.function.arguments - if tool_call.function.name: - tool_call_delta["name"] = tool_call.function.name + # special case for trapping `send_message` + if self.use_assistant_message and tool_call.function: + + # If we just received a chunk with the message in it, we either enter "send_message" mode, or we do standard FunctionCallMessage passthrough mode + + # Track the function name while streaming + # If we were previously on a 'send_message', we need to 'toggle' into 'content' mode + if tool_call.function.name: + if self.streaming_chat_completion_mode_function_name is None: + self.streaming_chat_completion_mode_function_name = tool_call.function.name + else: + self.streaming_chat_completion_mode_function_name += tool_call.function.name + + # If we get a "hit" on the special keyword we're looking for, we want to skip to the next chunk + # TODO I don't think this handles the function name in multi-pieces problem. Instead, we should probably reset the streaming_chat_completion_mode_function_name when we make this hit? 
+ # if self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name: + if tool_call.function.name == self.assistant_message_function_name: + self.streaming_chat_completion_json_reader.reset() + # early exit to turn into content mode + return None + + # if we're in the middle of parsing a send_message, we'll keep processing the JSON chunks + if ( + tool_call.function.arguments + and self.streaming_chat_completion_mode_function_name == self.assistant_message_function_name + ): + # Strip out any extras tokens + cleaned_func_args = self.streaming_chat_completion_json_reader.process_json_chunk(tool_call.function.arguments) + # In the case that we just have the prefix of something, no message yet, then we should early exit to move to the next chunk + if cleaned_func_args is None: + return None + else: + processed_chunk = AssistantMessage( + id=message_id, + date=message_date, + assistant_message=cleaned_func_args, + ) + + # otherwise we just do a regular passthrough of a FunctionCallDelta via a FunctionCallMessage + else: + tool_call_delta = {} + if tool_call.id: + tool_call_delta["id"] = tool_call.id + if tool_call.function: + if tool_call.function.arguments: + tool_call_delta["arguments"] = tool_call.function.arguments + if tool_call.function.name: + tool_call_delta["name"] = tool_call.function.name + + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), + ) + + else: + + tool_call_delta = {} + if tool_call.id: + tool_call_delta["id"] = tool_call.id + if tool_call.function: + if tool_call.function.arguments: + tool_call_delta["arguments"] = tool_call.function.arguments + if tool_call.function.name: + tool_call_delta["name"] = tool_call.function.name + + processed_chunk = FunctionCallMessage( + id=message_id, + date=message_date, + function_call=FunctionCallDelta(name=tool_call_delta.get("name"), 
arguments=tool_call_delta.get("arguments")), + ) - processed_chunk = FunctionCallMessage( - id=message_id, - date=message_date, - function_call=FunctionCallDelta(name=tool_call_delta.get("name"), arguments=tool_call_delta.get("arguments")), - ) elif choice.finish_reason is not None: # skip if there's a finish return None @@ -663,14 +736,32 @@ class StreamingServerInterface(AgentChunkStreamingInterface): else: - processed_chunk = FunctionCallMessage( - id=msg_obj.id, - date=msg_obj.created_at, - function_call=FunctionCall( - name=function_call.function.name, - arguments=function_call.function.arguments, - ), - ) + try: + func_args = json.loads(function_call.function.arguments) + except (TypeError, ValueError): + warnings.warn(f"Failed to parse function arguments: {function_call.function.arguments}") + func_args = {} + + if ( + self.use_assistant_message + and function_call.function.name == self.assistant_message_function_name + and self.assistant_message_function_kwarg in func_args + ): + processed_chunk = AssistantMessage( + id=msg_obj.id, + date=msg_obj.created_at, + assistant_message=func_args[self.assistant_message_function_kwarg], + ) + else: + processed_chunk = FunctionCallMessage( + id=msg_obj.id, + date=msg_obj.created_at, + function_call=FunctionCall( + name=function_call.function.name, + arguments=function_call.function.arguments, + ), + ) + # processed_chunk = { # "function_call": { # "name": function_call.function.name, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 514db4c0..cf4a8a64 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from fastapi.responses import JSONResponse, StreamingResponse from starlette.responses import StreamingResponse +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.agent import AgentState, CreateAgent, 
UpdateAgentState from letta.schemas.enums import MessageRole, MessageStreamStatus from letta.schemas.letta_message import ( @@ -254,6 +255,19 @@ def get_agent_messages( before: Optional[str] = Query(None, description="Message before which to retrieve the returned messages."), limit: int = Query(10, description="Maximum number of messages to retrieve."), msg_object: bool = Query(False, description="If true, returns Message objects. If false, return LettaMessage objects."), + # Flags to support the use of AssistantMessage message types + use_assistant_message: bool = Query( + False, + description="[Only applicable if msg_object is False] If true, returns AssistantMessage objects when the agent calls a designated message tool. If false, return FunctionCallMessage objects for all tool calls.", + ), + assistant_message_function_name: str = Query( + DEFAULT_MESSAGE_TOOL, + description="[Only applicable if use_assistant_message is True] The name of the designated message tool.", + ), + assistant_message_function_kwarg: str = Query( + DEFAULT_MESSAGE_TOOL_KWARG, + description="[Only applicable if use_assistant_message is True] The name of the message argument in the designated message tool.", + ), ): """ Retrieve message history for an agent. 
@@ -267,6 +281,9 @@ def get_agent_messages( limit=limit, reverse=True, return_message_object=msg_object, + use_assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, ) @@ -310,6 +327,10 @@ async def send_message( stream_steps=request.stream_steps, stream_tokens=request.stream_tokens, return_message_object=request.return_message_object, + # Support for AssistantMessage + use_assistant_message=request.use_assistant_message, + assistant_message_function_name=request.assistant_message_function_name, + assistant_message_function_kwarg=request.assistant_message_function_kwarg, ) @@ -322,12 +343,17 @@ async def send_message_to_agent( message: str, stream_steps: bool, stream_tokens: bool, - return_message_object: bool, # Should be True for Python Client, False for REST API - chat_completion_mode: Optional[bool] = False, - timestamp: Optional[datetime] = None, # related to whether or not we return `LettaMessage`s or `Message`s + return_message_object: bool, # Should be True for Python Client, False for REST API + chat_completion_mode: bool = False, + timestamp: Optional[datetime] = None, + # Support for AssistantMessage + use_assistant_message: bool = False, + assistant_message_function_name: str = DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[StreamingResponse, LettaResponse]: """Split off into a separate function so that it can be imported in the /chat/completion proxy.""" + # TODO: @charles is this the correct way to handle? 
include_final_message = True @@ -368,6 +394,12 @@ async def send_message_to_agent( # streaming_interface.allow_assistant_message = stream # streaming_interface.function_call_legacy_mode = stream + # Allow AssistantMessage if desired by client + streaming_interface.use_assistant_message = use_assistant_message + streaming_interface.assistant_message_function_name = assistant_message_function_name + streaming_interface.assistant_message_function_kwarg = assistant_message_function_kwarg + streaming_interface.streaming_chat_completion_json_reader.json_key = assistant_message_function_kwarg # NOTE(review): reader was built with the constructor-time kwarg; assumes it reads json_key on each call - confirm + # Offload the synchronous message_func to a separate thread streaming_interface.stream_start() task = asyncio.create_task( @@ -408,6 +440,7 @@ async def send_message_to_agent( message_ids = [m.id for m in filtered_stream] message_ids = deduplicate(message_ids) message_objs = [server.get_agent_message(agent_id=agent_id, message_id=m_id) for m_id in message_ids] + message_objs = [m for m in message_objs if m is not None] return LettaResponse(messages=message_objs, usage=usage) else: return LettaResponse(messages=filtered_stream, usage=usage) diff --git a/letta/server/server.py b/letta/server/server.py index 80b4c4f1..454f9881 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1262,6 +1262,9 @@ class SyncServer(Server): order: Optional[str] = "asc", reverse: Optional[bool] = False, return_message_object: bool = True, + use_assistant_message: bool = False, + assistant_message_function_name: str = constants.DEFAULT_MESSAGE_TOOL, + assistant_message_function_kwarg: str = constants.DEFAULT_MESSAGE_TOOL_KWARG, ) -> Union[List[Message], List[LettaMessage]]: if self.ms.get_user(user_id=user_id) is None: raise ValueError(f"User user_id={user_id} does not exist") @@ -1281,9 +1284,25 @@ class SyncServer(Server): if not return_message_object: # If we're GETing messages in reverse, we need to reverse the inner list (generated by to_letta_message) if reverse: - records = [msg for m in records for msg in m.to_letta_message()[::-1]] + records = [ + msg + for m in records + for msg in 
m.to_letta_message( + assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, + )[::-1] + ] else: - records = [msg for m in records for msg in m.to_letta_message()] + records = [ + msg + for m in records + for msg in m.to_letta_message( + assistant_message=use_assistant_message, + assistant_message_function_name=assistant_message_function_name, + assistant_message_function_kwarg=assistant_message_function_kwarg, + ) + ] return records diff --git a/tests/test_server.py b/tests/test_server.py index 67fa58ad..440e9833 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,16 +1,18 @@ import json import uuid +import warnings import pytest import letta.utils as utils -from letta.constants import BASE_TOOLS +from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG from letta.schemas.enums import MessageRole utils.DEBUG = True from letta.config import LettaConfig from letta.schemas.agent import CreateAgent from letta.schemas.letta_message import ( + AssistantMessage, FunctionCallMessage, FunctionReturn, InternalMonologue, @@ -236,7 +238,14 @@ def test_get_archival_memory(server, user_id, agent_id): assert len(passage_none) == 0 -def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): +def _test_get_messages_letta_format( + server, + user_id, + agent_id, + reverse=False, + # flag that determines whether or not to use AssistantMessage, or just FunctionCallMessage universally + use_assistant_message=False, +): """Reverse is off by default, the GET goes in chronological order""" messages = server.get_agent_recall_cursor( @@ -244,6 +253,8 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): agent_id=agent_id, limit=1000, reverse=reverse, + return_message_object=True, + use_assistant_message=use_assistant_message, ) # messages = 
server.get_agent_messages(agent_id=agent_id, start=0, count=1000) assert all(isinstance(m, Message) for m in messages) @@ -254,6 +265,7 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): limit=1000, reverse=reverse, return_message_object=False, + use_assistant_message=use_assistant_message, ) # letta_messages = server.get_agent_messages(agent_id=agent_id, start=0, count=1000, return_message_object=False) assert all(isinstance(m, LettaMessage) for m in letta_messages) @@ -316,9 +328,30 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages if message.tool_calls is not None: for tool_call in message.tool_calls: - assert isinstance(letta_message, FunctionCallMessage) - letta_message_index += 1 - letta_message = letta_messages[letta_message_index] + + # Try to parse the tool call args + try: + func_args = json.loads(tool_call.function.arguments) + except (TypeError, ValueError): + warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}") + func_args = {} + + # If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool + if ( + use_assistant_message + and tool_call.function.name == DEFAULT_MESSAGE_TOOL + and DEFAULT_MESSAGE_TOOL_KWARG in func_args + ): + assert isinstance(letta_message, AssistantMessage) + assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] + + # Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage + else: + assert isinstance(letta_message, FunctionCallMessage) + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] if message.text is not None: assert isinstance(letta_message, InternalMonologue) @@ -341,11 +374,32 @@ def _test_get_messages_letta_format(server, 
user_id, agent_id, reverse=False): # If there are multiple tool calls, we should have multiple back to back FunctionCallMessages if message.tool_calls is not None: for tool_call in message.tool_calls: - assert isinstance(letta_message, FunctionCallMessage) - assert tool_call.function.name == letta_message.function_call.name - assert tool_call.function.arguments == letta_message.function_call.arguments - letta_message_index += 1 - letta_message = letta_messages[letta_message_index] + + # Try to parse the tool call args + try: + func_args = json.loads(tool_call.function.arguments) + except (TypeError, ValueError): + warnings.warn(f"Function call arguments are not valid JSON: {tool_call.function.arguments}") + func_args = {} + + # If assistant_message is True, we expect FunctionCallMessage to be AssistantMessage if the tool call is the assistant message tool + if ( + use_assistant_message + and tool_call.function.name == DEFAULT_MESSAGE_TOOL + and DEFAULT_MESSAGE_TOOL_KWARG in func_args + ): + assert isinstance(letta_message, AssistantMessage) + assert func_args[DEFAULT_MESSAGE_TOOL_KWARG] == letta_message.assistant_message + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] + + # Otherwise, we expect even a "send_message" tool call to be a FunctionCallMessage + else: + assert isinstance(letta_message, FunctionCallMessage) + assert tool_call.function.name == letta_message.function_call.name + assert tool_call.function.arguments == letta_message.function_call.arguments + letta_message_index += 1 + letta_message = letta_messages[letta_message_index] elif message.role == MessageRole.user: print(f"i={i}, M=user, MM={type(letta_message)}") @@ -374,8 +428,9 @@ def _test_get_messages_letta_format(server, user_id, agent_id, reverse=False): def test_get_messages_letta_format(server, user_id, agent_id): - _test_get_messages_letta_format(server, user_id, agent_id, reverse=False) - _test_get_messages_letta_format(server, user_id, agent_id, reverse=True) + for reverse in 
[False, True]: + for assistant_message in [False, True]: + _test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse, use_assistant_message=assistant_message) def test_agent_rethink_rewrite_retry(server, user_id, agent_id):