From 8b52b26c7ddbec2bf2002f68a9fc372b5ae12e14 Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Sat, 27 Jan 2024 17:07:45 -0800 Subject: [PATCH] fix: patch `tool_call` bug in summarizer (#935) --- memgpt/agent.py | 21 +++++++++++++++++++-- memgpt/interface.py | 9 +++++++++ 2 files changed, 28 insertions(+), 2 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index 879ae918..24432080 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -685,7 +685,7 @@ class Agent(object): printd(f"step() failed with an unrecognized exception: '{str(e)}'") raise e - def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True): + def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True): assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})" # Start at index 1 (past the system message), @@ -696,9 +696,20 @@ class Agent(object): desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC) candidate_messages_to_summarize = self.messages[1:] token_counts = token_counts[1:] + if preserve_last_N_messages: candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST] token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST] + + # if disallow_tool_as_first: + # # We have to make sure that a "tool" call is not sitting at the front (after system message), + # # otherwise we'll get an error from OpenAI (if using the OpenAI API) + # while len(candidate_messages_to_summarize) > 0: + # if candidate_messages_to_summarize[0]["role"] in ["tool", "function"]: + # candidate_messages_to_summarize.pop(0) + # else: + # break + printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}") printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}") printd(f"token_counts={token_counts}") @@ -734,8 +745,14 @@ class Agent(object): except IndexError: pass + # Make sure the cutoff isn't on a 'tool' or 'function' + if disallow_tool_as_first: + while self.messages[cutoff]["role"] in ["tool", "function"] and cutoff < len(self.messages): + printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...") + cutoff += 1 + message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message - if len(message_sequence_to_summarize) == 1: + if len(message_sequence_to_summarize) <= 1: # This prevents a potential infinite loop of summarizing the same message over and over raise LLMError( f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]" diff --git a/memgpt/interface.py b/memgpt/interface.py index bf57d2f1..f1fdd39c 100644 --- a/memgpt/interface.py +++ b/memgpt/interface.py @@ -253,12 +253,21 @@ class CLIInterface(AgentInterface): args = json.loads(msg["function_call"].get("arguments")) CLIInterface.assistant_message(args.get("message")) # assistant_message(content) + elif msg.get("tool_calls"): + if content is not None: + CLIInterface.internal_monologue(content) + function_obj = msg["tool_calls"][0].get("function") + if function_obj: + args = json.loads(function_obj.get("arguments")) + CLIInterface.assistant_message(args.get("message")) else: CLIInterface.internal_monologue(content) elif role == "user": CLIInterface.user_message(content, dump=dump) elif role == "function": CLIInterface.function_message(content, debug=dump) + elif role == "tool": + CLIInterface.function_message(content, debug=dump) else: print(f"Unknown role: {content}")