diff --git a/memgpt/autogen/examples/agent_groupchat.py b/memgpt/autogen/examples/agent_groupchat.py index 2fd97646..e4d8224e 100644 --- a/memgpt/autogen/examples/agent_groupchat.py +++ b/memgpt/autogen/examples/agent_groupchat.py @@ -54,9 +54,9 @@ else: coder = create_autogen_memgpt_agent( "MemGPT_coder", persona_description="I am a 10x engineer, trained in Python. I was the first engineer at Uber (which I make sure to tell everyone I work with).", - user_description="You are participating in a group chat with a user and a product manager (PM).", + user_description=f"You are participating in a group chat with a user ({user_proxy.name}) and a product manager ({pm.name}).", # extra options - interface_kwargs={"debug": True}, + # interface_kwargs={"debug": True}, ) # Initialize the group chat between the user and two LLM agents (PM and coder) diff --git a/memgpt/autogen/memgpt_agent.py b/memgpt/autogen/memgpt_agent.py index c8ac7416..68e1b5df 100644 --- a/memgpt/autogen/memgpt_agent.py +++ b/memgpt/autogen/memgpt_agent.py @@ -69,15 +69,69 @@ def create_autogen_memgpt_agent( class MemGPTAgent(ConversableAgent): - def __init__(self, name: str, agent: AgentAsync, skip_verify=False): + def __init__( + self, + name: str, + agent: AgentAsync, + skip_verify=False, + concat_other_agent_messages=False, + ): super().__init__(name) self.agent = agent self.skip_verify = skip_verify + self.concat_other_agent_messages = concat_other_agent_messages self.register_reply( [Agent, None], MemGPTAgent._a_generate_reply_for_user_message ) self.register_reply([Agent, None], MemGPTAgent._generate_reply_for_user_message) + def format_other_agent_message(self, msg): + if "name" in msg: + user_message = f"{msg['name']}: {msg['content']}" + else: + user_message = msg["content"] + return user_message + + def find_last_user_message(self): + last_user_message = None + for msg in self.agent.messages: + if msg["role"] == "user": + last_user_message = msg["content"] + return last_user_message + + def find_new_messages(self, entire_message_list): + """Extract the subset of messages that's actually new""" + + if len(self.agent.messages) <= 1: + # if len == 1, it's only the system message, so everything must be new + return entire_message_list + + # Find where the last message was in the message history + last_seen_message = self.find_last_user_message() + # print( + # f"XXX there are {len(entire_message_list)} total messages, the last seen message was:\n{last_seen_message}" + # ) + new_message_idx = 0 + for i, msg in enumerate(entire_message_list): + user_message = system.package_user_message( + self.format_other_agent_message(msg) + ) + # Once we see the "final message" in the entire message list, this is where the history stops + if self.concat_other_agent_messages: + # Check if the message is inside + # FIXME hacky, doesn't handle repeat message scenarios + if self.format_other_agent_message(msg) in last_seen_message: + new_message_idx = i + 1 + else: + if user_message == last_seen_message: + new_message_idx = i + 1 + # print(f"the new message index is {new_message_idx}") + + # New messages + # TODO handle index error + new_messages = entire_message_list[new_message_idx:] + return new_messages + def _generate_reply_for_user_message( self, messages: Optional[List[Dict]] = None, @@ -101,33 +155,45 @@ class MemGPTAgent(ConversableAgent): # print(f"a_gen_reply messages:\n{messages}") self.agent.interface.reset_message_list() - for msg in messages: - if "name" in msg: - user_message_raw = f"{msg['name']}: {msg['content']}" - else: - user_message_raw = msg["content"] - user_message = system.package_user_message(user_message_raw) - while True: - ( - new_messages, - heartbeat_request, - function_failed, - token_warning, - ) = await self.agent.step( - user_message, first_message=False, skip_verify=self.skip_verify + new_messages = self.find_new_messages(messages) + if len(new_messages) > 1: + if self.concat_other_agent_messages: + # Combine all the other messages into one message + user_message = "\n".join( + [self.format_other_agent_message(m) for m in new_messages] ) - # ret.extend(new_messages) - # Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control - if token_warning: - user_message = system.get_token_limit_warning() - elif function_failed: - user_message = system.get_heartbeat( - constants.FUNC_FAILED_HEARTBEAT_MESSAGE - ) - elif heartbeat_request: - user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) - else: - break + else: + # Extend the MemGPT message list with multiple 'user' messages, then push the last one with agent.step() + self.agent.messages.extend(new_messages[:-1]) + user_message = new_messages[-1] + else: + user_message = new_messages[0] + + # Package the user message + user_message = system.package_user_message(user_message) + + # Send a single message into MemGPT + while True: + ( + new_messages, + heartbeat_request, + function_failed, + token_warning, + ) = await self.agent.step( + user_message, first_message=False, skip_verify=self.skip_verify + ) + # ret.extend(new_messages) + # Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control + if token_warning: + user_message = system.get_token_limit_warning() + elif function_failed: + user_message = system.get_heartbeat( + constants.FUNC_FAILED_HEARTBEAT_MESSAGE + ) + elif heartbeat_request: + user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE) + else: + break # Pass back to AutoGen the pretty-printed calls MemGPT made to the interface pretty_ret = MemGPTAgent.pretty_concat(self.agent.interface.message_list)