diff --git a/memgpt/agent.py b/memgpt/agent.py index e78e2b48..e15c9d84 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -18,6 +18,7 @@ from .constants import ( MESSAGE_SUMMARY_TRUNC_KEEP_N_LAST, CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT, + LLM_MAX_TOKENS, ) from .errors import LLMError from .functions.functions import load_all_function_sets @@ -605,7 +606,7 @@ class Agent(object): model=self.model, message_sequence=input_message_sequence, functions=self.functions, - context_window=self.config.context_window, + context_window=None if self.config.context_window is None else int(self.config.context_window), ) if self.verify_first_message_correctness(response, require_monologue=self.first_message_verify_mono): break @@ -619,7 +620,7 @@ class Agent(object): model=self.model, message_sequence=input_message_sequence, functions=self.functions, - context_window=self.config.context_window, + context_window=None if self.config.context_window is None else int(self.config.context_window), ) # Step 2: check if LLM wanted to call a function @@ -649,16 +650,28 @@ class Agent(object): # Check the memory pressure and potentially issue a memory pressure warning current_total_tokens = response["usage"]["total_tokens"] active_memory_warning = False - if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window: + # We can't do summarize logic properly if context_window is undefined + if self.config.context_window is None: + # Fallback if for some reason context_window is missing, just set to the default + print(f"WARNING: could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") + print(f"{self.config}") + self.config.context_window = ( + str(LLM_MAX_TOKENS[self.model]) + if (self.model is not None and self.model in LLM_MAX_TOKENS) + else str(LLM_MAX_TOKENS["DEFAULT"]) + ) + if current_total_tokens > MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window): printd( - f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}" + f"WARNING: last response total_tokens ({current_total_tokens}) > {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window)}" ) # Only deliver the alert if we haven't already (this period) if not self.agent_alerted_about_memory_pressure: active_memory_warning = True self.agent_alerted_about_memory_pressure = True # it's up to the outer loop to handle this else: - printd(f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * self.config.context_window}") + printd( + f"last response total_tokens ({current_total_tokens}) < {MESSAGE_SUMMARY_WARNING_FRAC * int(self.config.context_window)}" + ) self.append_to_messages(all_new_messages) return all_new_messages, heartbeat_request, function_failed, active_memory_warning @@ -729,8 +742,18 @@ class Agent(object): message_sequence_to_summarize = self.messages[1:cutoff] # do NOT get rid of the system message printd(f"Attempting to summarize {len(message_sequence_to_summarize)} messages [1:{cutoff}] of {len(self.messages)}") + # We can't do summarize logic properly if context_window is undefined + if self.config.context_window is None: + # Fallback if for some reason context_window is missing, just set to the default + print(f"WARNING: could not find context_window in config, setting to default {LLM_MAX_TOKENS['DEFAULT']}") + print(f"{self.config}") + self.config.context_window = ( + str(LLM_MAX_TOKENS[self.model]) + if (self.model is not None and self.model in LLM_MAX_TOKENS) + else str(LLM_MAX_TOKENS["DEFAULT"]) + ) summary = summarize_messages( - model=self.model, context_window=self.config.context_window, message_sequence_to_summarize=message_sequence_to_summarize + model=self.model, context_window=int(self.config.context_window), message_sequence_to_summarize=message_sequence_to_summarize ) printd(f"Got summary: {summary}") diff --git a/memgpt/cli/cli.py b/memgpt/cli/cli.py index 21896994..12b02b89 100644 --- a/memgpt/cli/cli.py +++ b/memgpt/cli/cli.py @@ -100,9 +100,9 @@ def run( sys.stdout = original_stdout # overwrite the context_window if specified - if context_window is not None and int(context_window) != config.context_window: + if context_window is not None and int(context_window) != int(config.context_window): typer.secho(f"Warning: Overriding existing context window {config.context_window} with {context_window}", fg=typer.colors.YELLOW) - config.context_window = context_window + config.context_window = str(context_window) # create agent config if agent and AgentConfig.exists(agent): # use existing agent