diff --git a/memgpt/agent.py b/memgpt/agent.py index 86295e92..4db48180 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -430,9 +430,6 @@ class Agent(object): response_message.tool_calls = [response_message.tool_calls[0]] assert response_message.tool_calls is not None and len(response_message.tool_calls) > 0 - # The content if then internal monologue, not chat - self.interface.internal_monologue(response_message.content) - # generate UUID for tool call if override_tool_call_id or response_message.function_call: tool_call_id = get_tool_call_id() # needs to be a string for JSON @@ -456,6 +453,9 @@ class Agent(object): ) # extend conversation with assistant's reply printd(f"Function call message: {messages[-1]}") + # The content if then internal monologue, not chat + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1]) + # Step 3: call the function # Note: the JSON response may not always be valid; be sure to handle errors @@ -483,7 +483,7 @@ class Agent(object): }, ) ) # extend conversation with function response - self.interface.function_message(f"Error: {error_msg}") + self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) return messages, False, True # force a heartbeat to allow agent to handle error # Failure case 2: function name is OK, but function args are bad JSON @@ -506,7 +506,7 @@ class Agent(object): }, ) ) # extend conversation with function response - self.interface.function_message(f"Error: {error_msg}") + self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) return messages, False, True # force a heartbeat to allow agent to handle error # (Still parsing function args) @@ -519,7 +519,9 @@ class Agent(object): heartbeat_request = False # Failure case 3: function failed during execution - self.interface.function_message(f"Running {function_name}({function_args})") + # NOTE: the msg_obj associated with the "Running " message is the prior assistant message, not the function/tool role message + # this is because the function/tool role message is only created once the function/tool has executed/returned + self.interface.function_message(f"Running {function_name}({function_args})", msg_obj=messages[-1]) try: spec = inspect.getfullargspec(function_to_call).annotations @@ -562,12 +564,12 @@ class Agent(object): }, ) ) # extend conversation with function response - self.interface.function_message(f"Error: {error_msg}") + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) + self.interface.function_message(f"Error: {error_msg}", msg_obj=messages[-1]) return messages, False, True # force a heartbeat to allow agent to handle error # If no failures happened along the way: ... # Step 4: send the info on the function call and function response to GPT - self.interface.function_message(f"Success: {function_response_string}") messages.append( Message.dict_to_message( agent_id=self.agent_state.id, @@ -581,10 +583,11 @@ class Agent(object): }, ) ) # extend conversation with function response + self.interface.function_message(f"Ran {function_name}({function_args})", msg_obj=messages[-1]) + self.interface.function_message(f"Success: {function_response_string}", msg_obj=messages[-1]) else: # Standard non-function reply - self.interface.internal_monologue(response_message.content) messages.append( Message.dict_to_message( agent_id=self.agent_state.id, @@ -593,6 +596,7 @@ class Agent(object): openai_message_dict=response_message.model_dump(), ) ) # extend conversation with assistant's reply + self.interface.internal_monologue(response_message.content, msg_obj=messages[-1]) heartbeat_request = False function_failed = False @@ -604,7 +608,8 @@ class Agent(object): first_message: bool = False, first_message_retry_limit: int = FIRST_MESSAGE_ATTEMPTS, skip_verify: bool = False, - ) -> Tuple[List[dict], bool, bool, bool]: + return_dicts: bool = True, # if True, return dicts, if False, return Message objects + ) -> Tuple[List[Union[dict, Message]], bool, bool, bool]: """Top-level event message handler for the MemGPT agent""" try: @@ -617,7 +622,6 @@ class Agent(object): else: raise ValueError(f"Bad type for user_message: {type(user_message)}") - self.interface.user_message(user_message_text) packed_user_message = {"role": "user", "content": user_message_text} # Special handling for AutoGen messages with 'name' field try: @@ -639,6 +643,7 @@ class Agent(object): model=self.model, openai_message_dict=packed_user_message, ) + self.interface.user_message(user_message_text, msg_obj=packed_user_message_obj) input_message_sequence = self.messages + [packed_user_message] # Alternatively, the requestor can send an empty user message @@ -729,8 +734,8 @@ class Agent(object): ) self._append_to_messages(all_new_messages) - all_new_messages_dicts = [msg.to_openai_dict() for msg in all_new_messages] - return all_new_messages_dicts, heartbeat_request, function_failed, active_memory_warning, response.usage.completion_tokens + messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages + return messages_to_return, heartbeat_request, function_failed, active_memory_warning, response.usage.completion_tokens except Exception as e: printd(f"step() failed\nuser_message = {user_message}\nerror = {e}") @@ -741,7 +746,7 @@ class Agent(object): self.summarize_messages_inplace() # Try step again - return self.step(user_message, first_message=first_message) + return self.step(user_message, first_message=first_message, return_dicts=return_dicts) else: printd(f"step() failed with an unrecognized exception: '{str(e)}'") raise e diff --git a/memgpt/autogen/interface.py b/memgpt/autogen/interface.py index 2b1b8706..504df9c8 100644 --- a/memgpt/autogen/interface.py +++ b/memgpt/autogen/interface.py @@ -1,8 +1,10 @@ import json import re +from typing import Optional from colorama import Fore, Style, init +from memgpt.data_types import Message from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT init(autoreset=True) @@ -64,7 +66,7 @@ class AutoGenInterface(object): """Clears the buffer. Call before every agent.step() when using MemGPT+AutoGen""" self.message_list = [] - def internal_monologue(self, msg): + def internal_monologue(self, msg: str, msg_obj: Optional[Message]): # NOTE: never gets appended if self.debug: print(f"inner thoughts :: {msg}") @@ -74,14 +76,14 @@ class AutoGenInterface(object): message = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {msg}{Style.RESET_ALL}" if self.fancy else f"[MemGPT agent's inner thoughts] {msg}" print(message) - def assistant_message(self, msg): + def assistant_message(self, msg: str, msg_obj: Optional[Message]): # NOTE: gets appended if self.debug: print(f"assistant :: {msg}") # message = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{msg}{Style.RESET_ALL}" if self.fancy else msg self.message_list.append(msg) - def memory_message(self, msg): + def memory_message(self, msg: str): # NOTE: never gets appended if self.debug: print(f"memory :: {msg}") @@ -90,7 +92,7 @@ class AutoGenInterface(object): ) print(message) - def system_message(self, msg): + def system_message(self, msg: str): # NOTE: gets appended if self.debug: print(f"system :: {msg}") @@ -98,7 +100,7 @@ class AutoGenInterface(object): print(message) self.message_list.append(msg) - def user_message(self, msg, raw=False): + def user_message(self, msg: str, msg_obj: Optional[Message], raw=False): if self.debug: print(f"user :: {msg}") if not self.show_user_message: @@ -136,7 +138,7 @@ class AutoGenInterface(object): # TODO should we ever be appending this? self.message_list.append(message) - def function_message(self, msg): + def function_message(self, msg: str, msg_obj: Optional[Message]): if self.debug: print(f"function :: {msg}") if not self.show_function_outputs: diff --git a/memgpt/functions/function_sets/base.py b/memgpt/functions/function_sets/base.py index a5974f88..a3df806d 100644 --- a/memgpt/functions/function_sets/base.py +++ b/memgpt/functions/function_sets/base.py @@ -4,13 +4,14 @@ import json import math from memgpt.constants import MAX_PAUSE_HEARTBEATS, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE, JSON_ENSURE_ASCII +from memgpt.agent import Agent ### Functions / tools the agent can use # All functions should return a response string (or None) # If the function fails, throw an exception -def send_message(self, message: str) -> Optional[str]: +def send_message(self: Agent, message: str) -> Optional[str]: """ Sends a message to the human user. @@ -20,7 +21,8 @@ def send_message(self, message: str) -> Optional[str]: Returns: Optional[str]: None is always returned as this function does not produce a response. """ - self.interface.assistant_message(message) + # FIXME passing of msg_obj here is a hack, unclear if guaranteed to be the correct reference + self.interface.assistant_message(message, msg_obj=self._messages[-1]) return None @@ -36,7 +38,7 @@ Returns: """ -def pause_heartbeats(self, minutes: int) -> Optional[str]: +def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]: minutes = min(MAX_PAUSE_HEARTBEATS, minutes) # Record the current time @@ -50,7 +52,7 @@ def pause_heartbeats(self, minutes: int) -> Optional[str]: pause_heartbeats.__doc__ = pause_heartbeats_docstring -def core_memory_append(self, name: str, content: str) -> Optional[str]: +def core_memory_append(self: Agent, name: str, content: str) -> Optional[str]: """ Append to the contents of core memory. @@ -66,7 +68,7 @@ def core_memory_append(self, name: str, content: str) -> Optional[str]: return None -def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]: +def core_memory_replace(self: Agent, name: str, old_content: str, new_content: str) -> Optional[str]: """ Replace the contents of core memory. To delete memories, use an empty string for new_content. @@ -83,7 +85,7 @@ def core_memory_replace(self, name: str, old_content: str, new_content: str) -> return None -def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[str]: +def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]: """ Search prior conversation history using case-insensitive string matching. @@ -112,7 +114,7 @@ def conversation_search(self, query: str, page: Optional[int] = 0) -> Optional[s return results_str -def conversation_search_date(self, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]: +def conversation_search_date(self: Agent, start_date: str, end_date: str, page: Optional[int] = 0) -> Optional[str]: """ Search prior conversation history using a date range. @@ -142,7 +144,7 @@ def conversation_search_date(self, start_date: str, end_date: str, page: Optiona return results_str -def archival_memory_insert(self, content: str) -> Optional[str]: +def archival_memory_insert(self: Agent, content: str) -> Optional[str]: """ Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later. @@ -156,7 +158,7 @@ def archival_memory_insert(self, content: str) -> Optional[str]: return None -def archival_memory_search(self, query: str, page: Optional[int] = 0) -> Optional[str]: +def archival_memory_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]: """ Search archival memory using semantic (embedding-based) search. diff --git a/memgpt/interface.py b/memgpt/interface.py index cbaec27a..edfd86a2 100644 --- a/memgpt/interface.py +++ b/memgpt/interface.py @@ -1,11 +1,13 @@ from abc import ABC, abstractmethod import json import re +from typing import List, Optional from colorama import Fore, Style, init from memgpt.utils import printd from memgpt.constants import CLI_WARNING_PREFIX, JSON_LOADS_STRICT +from memgpt.data_types import Message init(autoreset=True) @@ -16,25 +18,28 @@ STRIP_UI = False class AgentInterface(ABC): - """Interfaces handle MemGPT-related events (observer pattern)""" + """Interfaces handle MemGPT-related events (observer pattern) + + The 'msg' args provides the scoped message, and the optional Message arg can provide additional metadata. + """ @abstractmethod - def user_message(self, msg): + def user_message(self, msg: str, msg_obj: Optional[Message] = None): """MemGPT receives a user message""" raise NotImplementedError @abstractmethod - def internal_monologue(self, msg): + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): """MemGPT generates some internal monologue""" raise NotImplementedError @abstractmethod - def assistant_message(self, msg): + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): """MemGPT uses send_message""" raise NotImplementedError @abstractmethod - def function_message(self, msg): + def function_message(self, msg: str, msg_obj: Optional[Message] = None): """MemGPT calls a function""" raise NotImplementedError @@ -58,14 +63,14 @@ class CLIInterface(AgentInterface): """Basic interface for dumping agent events to the command-line""" @staticmethod - def important_message(msg): + def important_message(msg: str): fstr = f"{Fore.MAGENTA}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) @staticmethod - def warning_message(msg): + def warning_message(msg: str): fstr = f"{Fore.RED}{Style.BRIGHT}{{msg}}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" @@ -73,7 +78,7 @@ class CLIInterface(AgentInterface): print(fstr.format(msg=msg)) @staticmethod - def internal_monologue(msg): + def internal_monologue(msg: str, msg_obj: Optional[Message] = None): # ANSI escape code for italic is '\x1B[3m' fstr = f"\x1B[3m{Fore.LIGHTBLACK_EX}💭 {{msg}}{Style.RESET_ALL}" if STRIP_UI: @@ -81,28 +86,28 @@ class CLIInterface(AgentInterface): print(fstr.format(msg=msg)) @staticmethod - def assistant_message(msg): + def assistant_message(msg: str, msg_obj: Optional[Message] = None): fstr = f"{Fore.YELLOW}{Style.BRIGHT}🤖 {Fore.YELLOW}{{msg}}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) @staticmethod - def memory_message(msg): + def memory_message(msg: str, msg_obj: Optional[Message] = None): fstr = f"{Fore.LIGHTMAGENTA_EX}{Style.BRIGHT}🧠 {Fore.LIGHTMAGENTA_EX}{{msg}}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) @staticmethod - def system_message(msg): + def system_message(msg: str, msg_obj: Optional[Message] = None): fstr = f"{Fore.MAGENTA}{Style.BRIGHT}🖥️ [system] {Fore.MAGENTA}{msg}{Style.RESET_ALL}" if STRIP_UI: fstr = "{msg}" print(fstr.format(msg=msg)) @staticmethod - def user_message(msg, raw=False, dump=False, debug=DEBUG): + def user_message(msg: str, msg_obj: Optional[Message] = None, raw: bool = False, dump: bool = False, debug: bool = DEBUG): def print_user_message(icon, msg, printf=print): if STRIP_UI: printf(f"{icon} {msg}") @@ -148,7 +153,8 @@ class CLIInterface(AgentInterface): printd_user_message("🧑", msg_json) @staticmethod - def function_message(msg, debug=DEBUG): + def function_message(msg: str, msg_obj: Optional[Message] = None, debug: bool = DEBUG): + def print_function_message(icon, msg, color=Fore.RED, printf=print): if STRIP_UI: printf(f"⚡{icon} [function] {msg}") @@ -166,6 +172,9 @@ class CLIInterface(AgentInterface): printd_function_message("🟢", msg) elif msg.startswith("Error: "): printd_function_message("🔴", msg) + elif msg.startswith("Ran "): + # NOTE: ignore 'ran' messages that come post-execution + return elif msg.startswith("Running "): if debug: printd_function_message("", msg) @@ -230,7 +239,10 @@ class CLIInterface(AgentInterface): printd_function_message("", msg) @staticmethod - def print_messages(message_sequence, dump=False): + def print_messages(message_sequence: List[Message], dump=False): + # rewrite to dict format + message_sequence = [msg.to_openai_dict() for msg in message_sequence] + idx = len(message_sequence) for msg in message_sequence: if dump: @@ -270,7 +282,10 @@ class CLIInterface(AgentInterface): print(f"Unknown role: {content}") @staticmethod - def print_messages_simple(message_sequence): + def print_messages_simple(message_sequence: List[Message]): + # rewrite to dict format + message_sequence = [msg.to_openai_dict() for msg in message_sequence] + for msg in message_sequence: role = msg["role"] content = msg["content"] @@ -285,7 +300,10 @@ class CLIInterface(AgentInterface): print(f"Unknown role: {content}") @staticmethod - def print_messages_raw(message_sequence): + def print_messages_raw(message_sequence: List[Message]): + # rewrite to dict format + message_sequence = [msg.to_openai_dict() for msg in message_sequence] + for msg in message_sequence: print(msg) diff --git a/memgpt/main.py b/memgpt/main.py index 010af6af..54253185 100644 --- a/memgpt/main.py +++ b/memgpt/main.py @@ -155,13 +155,13 @@ def run_agent_loop(memgpt_agent, config: MemGPTConfig, first, ms: MetadataStore, command = user_input.strip().split() amount = int(command[1]) if len(command) > 1 and command[1].isdigit() else 0 if amount == 0: - interface.print_messages(memgpt_agent.messages, dump=True) + interface.print_messages(memgpt_agent._messages, dump=True) else: - interface.print_messages(memgpt_agent.messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) + interface.print_messages(memgpt_agent._messages[-min(amount, len(memgpt_agent.messages)) :], dump=True) continue elif user_input.lower() == "/dumpraw": - interface.print_messages_raw(memgpt_agent.messages) + interface.print_messages_raw(memgpt_agent._messages) continue elif user_input.lower() == "/memory": diff --git a/memgpt/server/rest_api/interface.py b/memgpt/server/rest_api/interface.py index c9abfa28..eb8f635a 100644 --- a/memgpt/server/rest_api/interface.py +++ b/memgpt/server/rest_api/interface.py @@ -1,10 +1,12 @@ import asyncio import queue from datetime import datetime +from typing import Optional import pytz from memgpt.interface import AgentInterface +from memgpt.data_types import Message class QueuingInterface(AgentInterface): @@ -38,7 +40,8 @@ class QueuingInterface(AgentInterface): message = self.buffer.get() if message == "STOP": break - yield message | {"date": datetime.now(tz=pytz.utc).isoformat()} + # yield message | {"date": datetime.now(tz=pytz.utc).isoformat()} + yield message else: await asyncio.sleep(0.1) # Small sleep to prevent a busy loop @@ -51,38 +54,73 @@ class QueuingInterface(AgentInterface): self.buffer.put({"internal_error": error}) self.buffer.put("STOP") - def user_message(self, msg: str): + def user_message(self, msg: str, msg_obj: Optional[Message] = None): """Handle reception of a user message""" + assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" - def internal_monologue(self, msg: str) -> None: + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None) -> None: """Handle the agent's internal monologue""" + assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" if self.debug: print(msg) - self.buffer.put({"internal_monologue": msg}) - def assistant_message(self, msg: str) -> None: + new_message = {"internal_monologue": msg} + + # add extra metadata + if msg_obj is not None: + new_message["id"] = str(msg_obj.id) + new_message["date"] = msg_obj.created_at.isoformat() + + self.buffer.put(new_message) + + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None) -> None: """Handle the agent sending a message""" + assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" if self.debug: print(msg) - self.buffer.put({"assistant_message": msg}) - def function_message(self, msg: str) -> None: + new_message = {"assistant_message": msg} + + # add extra metadata + if msg_obj is not None: + new_message["id"] = str(msg_obj.id) + new_message["date"] = msg_obj.created_at.isoformat() + + self.buffer.put(new_message) + + def function_message(self, msg: str, msg_obj: Optional[Message] = None, include_ran_messages: bool = False) -> None: """Handle the agent calling a function""" + # TODO handle 'function' messages that indicate the start of a function call + assert msg_obj is not None, "QueuingInterface requires msg_obj references for metadata" + if self.debug: print(msg) if msg.startswith("Running "): msg = msg.replace("Running ", "") - self.buffer.put({"function_call": msg}) + new_message = {"function_call": msg} + + elif msg.startswith("Ran "): + if not include_ran_messages: + return + msg = msg.replace("Ran ", "Function call returned: ") + new_message = {"function_call": msg} elif msg.startswith("Success: "): msg = msg.replace("Success: ", "") - self.buffer.put({"function_return": msg, "status": "success"}) + new_message = {"function_return": msg, "status": "success"} elif msg.startswith("Error: "): msg = msg.replace("Error: ", "") - self.buffer.put({"function_return": msg, "status": "error"}) + new_message = {"function_return": msg, "status": "error"} else: # NOTE: generic, should not happen - self.buffer.put({"function_message": msg}) + new_message = {"function_message": msg} + + # add extra metadata + if msg_obj is not None: + new_message["id"] = str(msg_obj.id) + new_message["date"] = msg_obj.created_at.isoformat() + + self.buffer.put(new_message) diff --git a/memgpt/server/server.py b/memgpt/server/server.py index bf739208..13c0c965 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -336,7 +336,10 @@ class SyncServer(LockingServer): counter = 0 while True: new_messages, heartbeat_request, function_failed, token_warning, tokens_accumulated = memgpt_agent.step( - next_input_message, first_message=False, skip_verify=no_verify + next_input_message, + first_message=False, + skip_verify=no_verify, + return_dicts=False, ) counter += 1