fix: patch tool_call bug in summarizer (#935)

This commit is contained in:
Charles Packer
2024-01-27 17:07:45 -08:00
committed by GitHub
parent 405b3429a0
commit 8b52b26c7d
2 changed files with 28 additions and 2 deletions

View File

@@ -685,7 +685,7 @@ class Agent(object):
printd(f"step() failed with an unrecognized exception: '{str(e)}'")
raise e
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True):
def summarize_messages_inplace(self, cutoff=None, preserve_last_N_messages=True, disallow_tool_as_first=True):
assert self.messages[0]["role"] == "system", f"self.messages[0] should be system (instead got {self.messages[0]})"
# Start at index 1 (past the system message),
@@ -696,9 +696,20 @@ class Agent(object):
desired_token_count_to_summarize = int(message_buffer_token_count * MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC)
candidate_messages_to_summarize = self.messages[1:]
token_counts = token_counts[1:]
if preserve_last_N_messages:
candidate_messages_to_summarize = candidate_messages_to_summarize[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
token_counts = token_counts[:-MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST]
# if disallow_tool_as_first:
# # We have to make sure that a "tool" call is not sitting at the front (after system message),
# # otherwise we'll get an error from OpenAI (if using the OpenAI API)
# while len(candidate_messages_to_summarize) > 0:
# if candidate_messages_to_summarize[0]["role"] in ["tool", "function"]:
# candidate_messages_to_summarize.pop(0)
# else:
# break
printd(f"MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC={MESSAGE_SUMMARY_TRUNC_TOKEN_FRAC}")
printd(f"MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST={MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST}")
printd(f"token_counts={token_counts}")
@@ -734,8 +745,14 @@ class Agent(object):
except IndexError:
pass
# Make sure the cutoff isn't on a 'tool' or 'function'
if disallow_tool_as_first:
    # Shift the cutoff forward past any 'tool'/'function' result messages so the
    # retained history does not begin with one.
    # The bounds check must come FIRST: `and` short-circuits, so testing
    # `cutoff < len(self.messages)` before indexing prevents an IndexError when
    # the cutoff walks off the end of the message list.
    while cutoff < len(self.messages) and self.messages[cutoff]["role"] in ["tool", "function"]:
        printd(f"Selected cutoff {cutoff} was a 'tool', shifting one...")
        cutoff += 1
message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
if len(message_sequence_to_summarize) == 1:
if len(message_sequence_to_summarize) <= 1:
# This prevents a potential infinite loop of summarizing the same message over and over
raise LLMError(
f"Summarize error: tried to run summarize, but couldn't find enough messages to compress [len={len(message_sequence_to_summarize)} <= 1]"

View File

@@ -253,12 +253,21 @@ class CLIInterface(AgentInterface):
args = json.loads(msg["function_call"].get("arguments"))
CLIInterface.assistant_message(args.get("message"))
# assistant_message(content)
elif msg.get("tool_calls"):
if content is not None:
CLIInterface.internal_monologue(content)
function_obj = msg["tool_calls"][0].get("function")
if function_obj:
args = json.loads(function_obj.get("arguments"))
CLIInterface.assistant_message(args.get("message"))
else:
CLIInterface.internal_monologue(content)
elif role == "user":
CLIInterface.user_message(content, dump=dump)
elif role == "function":
CLIInterface.function_message(content, debug=dump)
elif role == "tool":
CLIInterface.function_message(content, debug=dump)
else:
print(f"Unknown role: {content}")