always cast config.context_window to int before use (#444)

* always cast config.context_window to int before use

* add extra code to be super safe: fall back to a default context window if self.config.context_window is somehow None
This commit is contained in:
Charles Packer
2023-11-14 15:12:00 -08:00
committed by GitHub
parent b86d3e8f96
commit 442a0ca8bf
2 changed files with 31 additions and 8 deletions

View File

@@ -18,6 +18,7 @@ from .constants import (
MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
LLM_MAX_TOKENS,
)
from .errors import LLMError
from .functions.functions import load_all_function_sets
@@ -605,7 +606,7 @@ class Agent(object):
model=self.model,
message_sequence=input_message_sequence,
functions=self.functions,
context_window=self.config.context_window,
context_window=None if self.config.context_window is None else int(self.config.context_window),
)
if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono):
break
@@ -619,7 +620,7 @@ class Agent(object):
model=self.model,
message_sequence=input_message_sequence,
functions=self.functions,
context_window=self.config.context_window,
context_window=None if self.config.context_window is None else int(self.config.context_window),
)
# Step 2: check if LLM wanted to call a function
@@ -649,16 +650,28 @@ class Agent(object):
# Check the memory pressure and potentially issue a memory pressure warning
current_total_tokens = response["usage"]["total_tokens"]
active_memory_warning = False
if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window:
# We can't do the memory pressure check properly if context_window is undefined
if self.config.context_window is None:
# Fallback if for some reason context_window is missing, just set to the default
print(f"WARNING: could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}")
print(f"{self.config}")
self.config.context_window = (
str(LLM_MAX_TOKENS[self.model])
if (self.model is not None and self.model in LLM_MAX_TOKENS)
else str(LLM_MAX_TOKENS["DEFAULT"])
)
if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window):
printd(
f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}"
f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window)}"
)
# Only deliver the alert if we haven't already (this period)
if not self.agent_alerted_about_memory_pressure:
active_memory_warning = True
self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this
else:
printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}")
printd(
f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window)}"
)
self.append_to_messages(all_new_messages)
return all_new_messages, heartbeat_request, function_failed, active_memory_warning
@@ -729,8 +742,18 @@ class Agent(object):
message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message
printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}")
# We can't do summarize logic properly if context_window is undefined
if self.config.context_window is None:
# Fallback if for some reason context_window is missing, just set to the default
print(f"WARNING: could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}")
print(f"{self.config}")
self.config.context_window = (
str(LLM_MAX_TOKENS[self.model])
if (self.model is not None and self.model in LLM_MAX_TOKENS)
else str(LLM_MAX_TOKENS["DEFAULT"])
)
summary = summarize_messages(
model=self.model, context_window=self.config.context_window, message_sequence_to_summarize=message_sequence_to_summarize
model=self.model, context_window=int(self.config.context_window), message_sequence_to_summarize=message_sequence_to_summarize
)
printd(f"Got summary: {summary}")

View File

@@ -100,9 +100,9 @@ def run(
sys.stdout = original_stdout
# overwrite the context_window if specified
if context_window is not None and int(context_window) != config.context_window:
if context_window is not None and int(context_window) != int(config.context_window):
typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW)
config.context_window = context_window
config.context_window = str(context_window)
# create agent config
if agent and AgentConfig.exists(agent): # use existing agent