from abc import ABC, abstractmethod
import json
import re
import sys
from typing import List, Optional

# from colorama import Fore, Style, init
from rich.console import Console
from rich.live import Live
from rich.markup import escape
from rich.style import Style
from rich.text import Text

from memgpt.utils import printd
from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT
from memgpt.data_types import Message
from memgpt.models.chat_completion_response import ChatCompletionChunkResponse, ChatCompletionResponse
from memgpt.interface import AgentInterface, CLIInterface

# init(autoreset=True)

# DEBUG = True  # puts full message outputs in the terminal
DEBUG = False  # only dumps important messages in the terminal

STRIP_UI = False


class AgentChunkStreamingInterface(ABC):
    """Interfaces handle MemGPT-related events (observer pattern)

    The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata.

    In addition to the message events, implementations receive the stream
    incrementally via `process_chunk`, bracketed by `stream_start` and
    `stream_end`.
    """

    @abstractmethod
    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT receives a user message"""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT generates some internal monologue"""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT uses send_message"""
        raise NotImplementedError

    @abstractmethod
    def function_message(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT calls a function"""
        raise NotImplementedError

    @abstractmethod
    def process_chunk(self, chunk: ChatCompletionChunkResponse):
        """Process a streaming chunk from an OpenAI-compatible server"""
        raise NotImplementedError

    @abstractmethod
    def stream_start(self):
        """Any setup required before streaming begins"""
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Any cleanup required after streaming ends"""
        raise NotImplementedError


class StreamingCLIInterface(AgentChunkStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and prints along the way.

    When a chunk is received, we write the delta to the buffer. If the buffer type has changed,
    we write out a newline + set the formatting for the new line.

    The two buffer types are:
      (1) content (inner thoughts)
      (2) tool_calls (function calling)

    NOTE: this assumes that the deltas received in the chunks are in-order, e.g.
    that once 'content' deltas stop streaming, they won't be received again. See notes
    on alternative version of the StreamingCLIInterface that does not have this same problem below:

    An alternative implementation could instead maintain the partial message state, and on each
    process chunk (1) update the partial message state, (2) refresh/rewrite the state to the screen.
    """

    # CLIInterface is static/stateless
    nonstreaming_interface = CLIInterface()

    def __init__(self):
        """The streaming CLI interface state for determining which buffer is currently being written to"""

        # One of None (nothing written yet), "content", or "tool_calls"
        self.streaming_buffer_type = None

    def _flush(self):
        """Placeholder hook; currently does nothing."""
        pass

    def process_chunk(self, chunk: ChatCompletionChunkResponse):
        """Print the delta carried by `chunk` to stdout.

        The chunk must contain exactly one choice, and its delta must carry
        either `content` or a single tool call (not both); these expectations
        are enforced with assertions.
        """
        assert len(chunk.choices) == 1, chunk

        message_delta = chunk.choices[0].delta

        # Starting a new buffer line: nothing has been written since the last
        # stream_start()/stream_end(), so emit the prefix for this delta type.
        # NOTE: no prefix is written when the buffer type later switches
        # mid-stream; only a newline is emitted in that case (see below).
        if not self.streaming_buffer_type:
            assert not (
                message_delta.content is not None and message_delta.tool_calls is not None and len(message_delta.tool_calls)
            ), f"Error: got both content and tool_calls in message stream\n{message_delta}"

            if message_delta.content is not None:
                # Write out the prefix for inner thoughts
                print("Inner thoughts: ", end="", flush=True)
            elif message_delta.tool_calls is not None:
                assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
                # Write out the prefix for function calling
                print("Calling function: ", end="", flush=True)

        # Write out the delta
        if message_delta.content is not None:
            # Switching away from another buffer type: move to a fresh line
            if self.streaming_buffer_type and self.streaming_buffer_type != "content":
                print()
            self.streaming_buffer_type = "content"

            # Simple, just write out to the buffer
            print(message_delta.content, end="", flush=True)

        elif message_delta.tool_calls is not None:
            # Switching away from another buffer type: move to a fresh line
            if self.streaming_buffer_type and self.streaming_buffer_type != "tool_calls":
                print()
            self.streaming_buffer_type = "tool_calls"

            assert len(message_delta.tool_calls) == 1, f"Error: got more than one tool call in response\n{message_delta}"
            function_call = message_delta.tool_calls[0].function

            # Slightly more complex - want to write parameters in a certain way (paren-style)
            # function_name(function_args)
            if function_call.name:
                # NOTE: need to account for closing the brace later
                print(f"{function_call.name}(", end="", flush=True)
            if function_call.arguments:
                print(function_call.arguments, end="", flush=True)

    def stream_start(self):
        """Reset the buffer tracker before a new stream begins."""
        # should be handled by stream_end(), but just in case
        self.streaming_buffer_type = None

    def stream_end(self):
        """Close any open tool-call paren, terminate the line, and reset the buffer tracker."""
        if self.streaming_buffer_type is not None:
            # TODO: should have a separate self.tool_call_open_paren flag
            if self.streaming_buffer_type == "tool_calls":
                print(")", end="", flush=True)

            print()  # newline to move the cursor
            self.streaming_buffer_type = None  # reset buffer tracker

    # The non-streaming events below are forwarded to the same-named method of
    # the shared CLIInterface instance.

    @staticmethod
    def important_message(msg: str):
        """Forward to `CLIInterface.important_message`."""
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        """Forward to `CLIInterface.warning_message`."""
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    @staticmethod
    def internal_monologue(msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.internal_monologue`."""
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    @staticmethod
    def assistant_message(msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.assistant_message`."""
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.memory_message`."""
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.system_message`."""
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        """Forward to `CLIInterface.user_message`; `raw`, `dump` and `debug` are accepted but not forwarded."""
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
        """Forward to `CLIInterface.function_message`; `debug` is accepted but not forwarded."""
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        """Forward to `CLIInterface.print_messages`."""
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        """Forward to `CLIInterface.print_messages_simple`."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        """Forward to `CLIInterface.print_messages_raw`."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        """No-op."""
        pass


class AgentRefreshStreamingInterface(ABC):
    """Same as the ChunkStreamingInterface, but the stream is delivered through
    `process_refresh`, which takes a ChatCompletionResponse rather than a
    ChatCompletionChunkResponse, and streaming can be switched on/off via
    `toggle_streaming`.

    The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata.
    """

    @abstractmethod
    def user_message(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT receives a user message"""
        raise NotImplementedError

    @abstractmethod
    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT generates some internal monologue"""
        raise NotImplementedError

    @abstractmethod
    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT uses send_message"""
        raise NotImplementedError

    @abstractmethod
    def function_message(self, msg: str, msg_obj: Optional[Message] = None):
        """MemGPT calls a function"""
        raise NotImplementedError

    @abstractmethod
    def process_refresh(self, response: ChatCompletionResponse):
        """Process the current (partial) response streamed from an OpenAI-compatible server"""
        raise NotImplementedError

    @abstractmethod
    def stream_start(self):
        """Any setup required before streaming begins"""
        raise NotImplementedError

    @abstractmethod
    def stream_end(self):
        """Any cleanup required after streaming ends"""
        raise NotImplementedError

    @abstractmethod
    def toggle_streaming(self, on: bool):
        """Toggle streaming on/off (off = regular CLI interface)"""
        raise NotImplementedError


class StreamingRefreshCLIInterface(AgentRefreshStreamingInterface):
    """Version of the CLI interface that attaches to a stream generator and refreshes a render of the message at every step.

    We maintain the partial message state in the interface state, and on each
    process chunk we:
        (1) update the partial message state,
        (2) refresh/rewrite the state to the screen.
    """

    # NOTE: the delegating methods below go through
    # StreamingCLIInterface.nonstreaming_interface, not this attribute.
    nonstreaming_interface = CLIInterface

    def __init__(self, fancy: bool = True, separate_send_message: bool = True, disable_inner_mono_call: bool = True):
        """Initialize the streaming CLI interface state.

        Args:
            fancy: Render rich markup (italics / emoji); when False the output
                is markup-escaped before rendering.
            separate_send_message: Render `send_message` calls as an assistant
                message while streaming, and skip `assistant_message` events.
            disable_inner_mono_call: Skip `internal_monologue` events.
        """
        self.console = Console()

        # Using `Live` with `refresh_per_second` parameter to limit the refresh rate, avoiding excessive updates
        self.live = Live("", console=self.console, refresh_per_second=10)
        # self.live.start()  # Start the Live display context and keep it running

        # Use italics / emoji?
        self.fancy = fancy

        self.streaming = True
        self.separate_send_message = separate_send_message
        self.disable_inner_mono_call = disable_inner_mono_call

    def toggle_streaming(self, on: bool):
        """Turn streaming on/off; both suppression flags are set to the same value as `on`."""
        self.streaming = on
        if on:
            self.separate_send_message = True
            self.disable_inner_mono_call = True
        else:
            self.separate_send_message = False
            self.disable_inner_mono_call = False

    def update_output(self, content: str):
        """Update the displayed output with new content."""
        # We use the `Live` object's update mechanism to refresh content without clearing the console
        if not self.fancy:
            content = escape(content)
        self.live.update(self.console.render_str(content), refresh=True)

    def process_refresh(self, response: ChatCompletionResponse):
        """Process the response to rewrite the current output buffer."""
        if not response.choices:
            self.update_output("💭 [italic]...[/italic]")
            return  # Early exit if there are no choices

        choice = response.choices[0]
        inner_thoughts = choice.message.content if choice.message.content else ""
        tool_calls = choice.message.tool_calls if choice.message.tool_calls else []

        if self.fancy:
            message_string = f"💭 [italic]{inner_thoughts}[/italic]" if inner_thoughts else ""
        else:
            message_string = "[inner thoughts] " + inner_thoughts if inner_thoughts else ""

        if tool_calls:
            # Only the first tool call is rendered
            function_call = tool_calls[0].function
            function_name = function_call.name  # Function name, can be an empty string
            function_args = function_call.arguments  # Function arguments, can be an empty string
            if message_string:
                message_string += "\n"
            # special case here for send_message
            if self.separate_send_message and function_name == "send_message":
                try:
                    message = json.loads(function_args)["message"]
                except (ValueError, KeyError, TypeError):
                    # The arguments are not (yet) a complete JSON object with a
                    # "message" key, so fall back to stripping the expected
                    # opening of the JSON text by hand.
                    prefix = '{\n "message": "'
                    if len(function_args) < len(prefix):
                        message = "..."
                    elif function_args.startswith(prefix):
                        message = function_args[len(prefix) :]
                    else:
                        message = function_args
                message_string += f"🤖 [bold yellow]{message}[/bold yellow]"
            else:
                message_string += f"{function_name}({function_args})"

        self.update_output(message_string)

    def stream_start(self):
        """Start the live display (when streaming is enabled) with a placeholder line."""
        if self.streaming:
            print()
            self.live.start()  # Start the Live display context and keep it running
            self.update_output("💭 [italic]...[/italic]")

    def stream_end(self):
        """Stop the live display (when streaming is enabled) and replace it with a fresh `Live`."""
        if self.streaming:
            if self.live.is_started:
                self.live.stop()
                print()
                self.live = Live("", console=self.console, refresh_per_second=10)

    @staticmethod
    def important_message(msg: str):
        """Forward to `CLIInterface.important_message`."""
        StreamingCLIInterface.nonstreaming_interface.important_message(msg)

    @staticmethod
    def warning_message(msg: str):
        """Forward to `CLIInterface.warning_message`."""
        StreamingCLIInterface.nonstreaming_interface.warning_message(msg)

    def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.internal_monologue` unless `disable_inner_mono_call` is set."""
        if self.disable_inner_mono_call:
            return
        StreamingCLIInterface.nonstreaming_interface.internal_monologue(msg, msg_obj)

    def assistant_message(self, msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.assistant_message` unless `separate_send_message` is set."""
        if self.separate_send_message:
            return
        StreamingCLIInterface.nonstreaming_interface.assistant_message(msg, msg_obj)

    @staticmethod
    def memory_message(msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.memory_message`."""
        StreamingCLIInterface.nonstreaming_interface.memory_message(msg, msg_obj)

    @staticmethod
    def system_message(msg: str, msg_obj: Optional[Message] = None):
        """Forward to `CLIInterface.system_message`."""
        StreamingCLIInterface.nonstreaming_interface.system_message(msg, msg_obj)

    @staticmethod
    def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG):
        """Forward to `CLIInterface.user_message`; `raw`, `dump` and `debug` are accepted but not forwarded."""
        StreamingCLIInterface.nonstreaming_interface.user_message(msg, msg_obj)

    @staticmethod
    def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG):
        """Forward to `CLIInterface.function_message`; `debug` is accepted but not forwarded."""
        StreamingCLIInterface.nonstreaming_interface.function_message(msg, msg_obj)

    @staticmethod
    def print_messages(message_sequence: List[Message], dump=False):
        """Forward to `CLIInterface.print_messages`."""
        StreamingCLIInterface.nonstreaming_interface.print_messages(message_sequence, dump)

    @staticmethod
    def print_messages_simple(message_sequence: List[Message]):
        """Forward to `CLIInterface.print_messages_simple`."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_simple(message_sequence)

    @staticmethod
    def print_messages_raw(message_sequence: List[Message]):
        """Forward to `CLIInterface.print_messages_raw`."""
        StreamingCLIInterface.nonstreaming_interface.print_messages_raw(message_sequence)

    @staticmethod
    def step_yield():
        """No-op."""
        pass