updated example and patched messages bug
This commit is contained in:
@@ -54,9 +54,9 @@ else:
|
||||
coder = create_autogen_memgpt_agent(
|
||||
"MemGPT_coder",
|
||||
persona_description="I am a 10x engineer, trained in Python. I was the first engineer at Uber (which I make sure to tell everyone I work with).",
|
||||
user_description="You are participating in a group chat with a user and a product manager (PM).",
|
||||
user_description=f"You are participating in a group chat with a user ({user_proxy.name}) and a product manager ({pm.name}).",
|
||||
# extra options
|
||||
interface_kwargs={"debug": True},
|
||||
# interface_kwargs={"debug": True},
|
||||
)
|
||||
|
||||
# Initialize the group chat between the user and two LLM agents (PM and coder)
|
||||
|
||||
@@ -69,15 +69,69 @@ def create_autogen_memgpt_agent(
|
||||
|
||||
|
||||
class MemGPTAgent(ConversableAgent):
|
||||
def __init__(self, name: str, agent: AgentAsync, skip_verify=False):
|
||||
def __init__(
|
||||
self,
|
||||
name: str,
|
||||
agent: AgentAsync,
|
||||
skip_verify=False,
|
||||
concat_other_agent_messages=False,
|
||||
):
|
||||
super().__init__(name)
|
||||
self.agent = agent
|
||||
self.skip_verify = skip_verify
|
||||
self.concat_other_agent_messages = concat_other_agent_messages
|
||||
self.register_reply(
|
||||
[Agent, None], MemGPTAgent._a_generate_reply_for_user_message
|
||||
)
|
||||
self.register_reply([Agent, None], MemGPTAgent._generate_reply_for_user_message)
|
||||
|
||||
def format_other_agent_message(self, msg):
|
||||
if "name" in msg:
|
||||
user_message = f"{msg['name']}: {msg['content']}"
|
||||
else:
|
||||
user_message = msg["content"]
|
||||
return user_message
|
||||
|
||||
def find_last_user_message(self):
|
||||
last_user_message = None
|
||||
for msg in self.agent.messages:
|
||||
if msg["role"] == "user":
|
||||
last_user_message = msg["content"]
|
||||
return last_user_message
|
||||
|
||||
def find_new_messages(self, entire_message_list):
|
||||
"""Extract the subset of messages that's actually new"""
|
||||
|
||||
if len(self.agent.messages) <= 1:
|
||||
# if len == 1, it's only the system message, so everything must be new
|
||||
return entire_message_list
|
||||
|
||||
# Find where the last message was in the message history
|
||||
last_seen_message = self.find_last_user_message()
|
||||
# print(
|
||||
# f"XXX there are {len(entire_message_list)} total messages, the last seen message was:\n{last_seen_message}"
|
||||
# )
|
||||
new_message_idx = 0
|
||||
for i, msg in enumerate(entire_message_list):
|
||||
user_message = system.package_user_message(
|
||||
self.format_other_agent_message(msg)
|
||||
)
|
||||
# Once we see the "final message" in the entire message list, this is where the history stops
|
||||
if self.concat_other_agent_messages:
|
||||
# Check if the message is inside
|
||||
# FIXME hacky, doesn't handle repeat message scenarios
|
||||
if self.format_other_agent_message(msg) in last_seen_message:
|
||||
new_message_idx = i + 1
|
||||
else:
|
||||
if user_message == last_seen_message:
|
||||
new_message_idx = i + 1
|
||||
# print(f"the new message index is {new_message_idx}")
|
||||
|
||||
# New messages
|
||||
# TODO handle index error
|
||||
new_messages = entire_message_list[new_message_idx:]
|
||||
return new_messages
|
||||
|
||||
def _generate_reply_for_user_message(
|
||||
self,
|
||||
messages: Optional[List[Dict]] = None,
|
||||
@@ -101,33 +155,45 @@ class MemGPTAgent(ConversableAgent):
|
||||
# print(f"a_gen_reply messages:\n{messages}")
|
||||
self.agent.interface.reset_message_list()
|
||||
|
||||
for msg in messages:
|
||||
if "name" in msg:
|
||||
user_message_raw = f"{msg['name']}: {msg['content']}"
|
||||
else:
|
||||
user_message_raw = msg["content"]
|
||||
user_message = system.package_user_message(user_message_raw)
|
||||
while True:
|
||||
(
|
||||
new_messages,
|
||||
heartbeat_request,
|
||||
function_failed,
|
||||
token_warning,
|
||||
) = await self.agent.step(
|
||||
user_message, first_message=False, skip_verify=self.skip_verify
|
||||
new_messages = self.find_new_messages(messages)
|
||||
if len(new_messages) > 1:
|
||||
if self.concat_other_agent_messages:
|
||||
# Combine all the other messages into one message
|
||||
user_message = "\n".join(
|
||||
[self.format_other_agent_message(m) for m in new_messages]
|
||||
)
|
||||
# ret.extend(new_messages)
|
||||
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
|
||||
if token_warning:
|
||||
user_message = system.get_token_limit_warning()
|
||||
elif function_failed:
|
||||
user_message = system.get_heartbeat(
|
||||
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
|
||||
)
|
||||
elif heartbeat_request:
|
||||
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
|
||||
else:
|
||||
break
|
||||
else:
|
||||
# Extend the MemGPT message list with multiple 'user' messages, then push the last one with agent.step()
|
||||
self.agent.messages.extend(new_messages[:-1])
|
||||
user_message = new_messages[-1]
|
||||
else:
|
||||
user_message = new_messages[0]
|
||||
|
||||
# Package the user message
|
||||
user_message = system.package_user_message(user_message)
|
||||
|
||||
# Send a single message into MemGPT
|
||||
while True:
|
||||
(
|
||||
new_messages,
|
||||
heartbeat_request,
|
||||
function_failed,
|
||||
token_warning,
|
||||
) = await self.agent.step(
|
||||
user_message, first_message=False, skip_verify=self.skip_verify
|
||||
)
|
||||
# ret.extend(new_messages)
|
||||
# Skip user inputs if there's a memory warning, function execution failed, or the agent asked for control
|
||||
if token_warning:
|
||||
user_message = system.get_token_limit_warning()
|
||||
elif function_failed:
|
||||
user_message = system.get_heartbeat(
|
||||
constants.FUNC_FAILED_HEARTBEAT_MESSAGE
|
||||
)
|
||||
elif heartbeat_request:
|
||||
user_message = system.get_heartbeat(constants.REQ_HEARTBEAT_MESSAGE)
|
||||
else:
|
||||
break
|
||||
|
||||
# Pass back to AutoGen the pretty-printed calls MemGPT made to the interface
|
||||
pretty_ret = MemGPTAgent.pretty_concat(self.agent.interface.message_list)
|
||||
|
||||
Reference in New Issue
Block a user