From 86d52c4cdfc35dfcef5d899f4f7bc5b48d989799 Mon Sep 17 00:00:00 2001 From: Vivian Fang Date: Sun, 15 Oct 2023 21:07:45 -0700 Subject: [PATCH] fix summarizer --- memgpt/agent.py | 13 ++++++++++--- memgpt/constants.py | 1 - memgpt/memory.py | 8 +++++++- memgpt/persistence_manager.py | 2 +- memgpt/utils.py | 4 ++++ 5 files changed, 22 insertions(+), 6 deletions(-) diff --git a/memgpt/agent.py b/memgpt/agent.py index a64c29ec..8c3a5209 100644 --- a/memgpt/agent.py +++ b/memgpt/agent.py @@ -10,9 +10,9 @@ import openai from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages from .memory import CoreMemory as Memory, summarize_messages from .openai_tools import acompletions_with_backoff as acreate -from .utils import get_local_time, parse_json, united_diff, printd +from .utils import get_local_time, parse_json, united_diff, printd, count_tokens from .constants import \ - FIRST_MESSAGE_ATTEMPTS, MESSAGE_SUMMARY_CUTOFF_FRAC, MAX_PAUSE_HEARTBEATS, \ + FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, \ MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MESSAGE_SUMMARY_WARNING_TOKENS, \ CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT @@ -539,7 +539,14 @@ class AgentAsync(object): async def summarize_messages_inplace(self, cutoff=None): if cutoff is None: - cutoff = round((len(self.messages) - 1) * MESSAGE_SUMMARY_CUTOFF_FRAC) # by default, trim the first 50% of messages + tokens_so_far = 0 # Smart cutoff -- just below the max. + cutoff = len(self.messages) - 1 + for m in reversed(self.messages): + tokens_so_far += count_tokens(str(m), self.model) + if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS*0.2: + break + cutoff -= 1 + cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too # Try to make an assistant message come after the cutoff try: diff --git a/memgpt/constants.py b/memgpt/constants.py index 184847cd..33924e47 100644 --- a/memgpt/constants.py +++ b/memgpt/constants.py @@ -12,7 +12,6 @@ STARTUP_QUOTES = [ INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG = STARTUP_QUOTES[2] # Constants to do with summarization / conversation length window -MESSAGE_SUMMARY_CUTOFF_FRAC = 0.5 MESSAGE_SUMMARY_WARNING_TOKENS = 7000 # the number of tokens consumed in a call before a system warning goes to the agent MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed." diff --git a/memgpt/memory.py b/memgpt/memory.py index fb064959..c36dabdb 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -4,7 +4,8 @@ import re import faiss import numpy as np -from .utils import cosine_similarity, get_local_time, printd +from .constants import MESSAGE_SUMMARY_WARNING_TOKENS +from .utils import cosine_similarity, get_local_time, printd, count_tokens from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM from .openai_tools import acompletions_with_backoff as acreate, async_get_embedding_with_backoff @@ -105,6 +106,11 @@ async def summarize_messages( summary_prompt = SUMMARY_PROMPT_SYSTEM summary_input = str(message_sequence_to_summarize) + summary_input_tkns = count_tokens(summary_input, model) + if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS: + trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure... + cutoff = int(len(message_sequence_to_summarize) * trunc_ratio) + summary_input = str([await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:]) message_sequence = [ {"role": "system", "content": summary_prompt}, {"role": "user", "content": summary_input}, diff --git a/memgpt/persistence_manager.py b/memgpt/persistence_manager.py index 575741b3..84a92fe1 100644 --- a/memgpt/persistence_manager.py +++ b/memgpt/persistence_manager.py @@ -54,7 +54,7 @@ class InMemoryStateManager(PersistenceManager): def trim_messages(self, num): # printd(f"InMemoryStateManager.trim_messages") - self.messages = self.messages[num:] + self.messages = [self.messages[0]] + self.messages[num:] def prepend_to_messages(self, added_messages): # first tag with timestamps diff --git a/memgpt/utils.py b/memgpt/utils.py index e008cd93..a67b45f1 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -6,7 +6,11 @@ import json import pytz import os import faiss +import tiktoken +def count_tokens(s: str, model: str = "gpt-4") -> int: + encoding = tiktoken.encoding_for_model(model) + return len(encoding.encode(s)) # DEBUG = True DEBUG = False