fix summarizer
This commit is contained in:
@@ -10,9 +10,9 @@ import openai
|
||||
from .system import get_heartbeat, get_login_event, package_function_response, package_summarize_message, get_initial_boot_messages
|
||||
from .memory import CoreMemory as Memory, summarize_messages
|
||||
from .openai_tools import acompletions_with_backoff as acreate
|
||||
from .utils import get_local_time, parse_json, united_diff, printd
|
||||
from .utils import get_local_time, parse_json, united_diff, printd, count_tokens
|
||||
from .constants import \
|
||||
FIRST_MESSAGE_ATTEMPTS, MESSAGE_SUMMARY_CUTOFF_FRAC, MAX_PAUSE_HEARTBEATS, \
|
||||
FIRST_MESSAGE_ATTEMPTS, MAX_PAUSE_HEARTBEATS, \
|
||||
MESSAGE_CHATGPT_FUNCTION_MODEL, MESSAGE_CHATGPT_FUNCTION_SYSTEM_MESSAGE, MESSAGE_SUMMARY_WARNING_TOKENS, \
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT
|
||||
|
||||
@@ -539,7 +539,14 @@ class AgentAsync(object):
|
||||
|
||||
async def summarize_messages_inplace(self, cutoff=None):
|
||||
if cutoff is None:
|
||||
cutoff = round((len(self.messages) - 1) * MESSAGE_SUMMARY_CUTOFF_FRAC) # by default, trim the first 50% of messages
|
||||
tokens_so_far = 0 # Smart cutoff -- just below the max.
|
||||
cutoff = len(self.messages) - 1
|
||||
for m in reversed(self.messages):
|
||||
tokens_so_far += count_tokens(str(m), self.model)
|
||||
if tokens_so_far >= MESSAGE_SUMMARY_WARNING_TOKENS*0.2:
|
||||
break
|
||||
cutoff -= 1
|
||||
cutoff = min(len(self.messages) - 3, cutoff) # Always keep the last two messages too
|
||||
|
||||
# Try to make an assistant message come after the cutoff
|
||||
try:
|
||||
|
||||
@@ -12,7 +12,6 @@ STARTUP_QUOTES = [
|
||||
INITIAL_BOOT_MESSAGE_SEND_MESSAGE_FIRST_MSG = STARTUP_QUOTES[2]
|
||||
|
||||
# Constants to do with summarization / conversation length window
|
||||
MESSAGE_SUMMARY_CUTOFF_FRAC = 0.5
|
||||
MESSAGE_SUMMARY_WARNING_TOKENS = 7000 # the number of tokens consumed in a call before a system warning goes to the agent
|
||||
MESSAGE_SUMMARY_WARNING_STR = f"Warning: the conversation history will soon reach its maximum length and be trimmed. Make sure to save any important information from the conversation to your memory before it is removed."
|
||||
|
||||
|
||||
@@ -4,7 +4,8 @@ import re
|
||||
import faiss
|
||||
import numpy as np
|
||||
|
||||
from .utils import cosine_similarity, get_local_time, printd
|
||||
from .constants import MESSAGE_SUMMARY_WARNING_TOKENS
|
||||
from .utils import cosine_similarity, get_local_time, printd, count_tokens
|
||||
from .prompts.gpt_summarize import SYSTEM as SUMMARY_PROMPT_SYSTEM
|
||||
from .openai_tools import acompletions_with_backoff as acreate, async_get_embedding_with_backoff
|
||||
|
||||
@@ -105,6 +106,11 @@ async def summarize_messages(
|
||||
|
||||
summary_prompt = SUMMARY_PROMPT_SYSTEM
|
||||
summary_input = str(message_sequence_to_summarize)
|
||||
summary_input_tkns = count_tokens(summary_input, model)
|
||||
if summary_input_tkns > MESSAGE_SUMMARY_WARNING_TOKENS:
|
||||
trunc_ratio = (MESSAGE_SUMMARY_WARNING_TOKENS / summary_input_tkns) * 0.8 # For good measure...
|
||||
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
|
||||
summary_input = str([await summarize_messages(model, message_sequence_to_summarize[:cutoff])] + message_sequence_to_summarize[cutoff:])
|
||||
message_sequence = [
|
||||
{"role": "system", "content": summary_prompt},
|
||||
{"role": "user", "content": summary_input},
|
||||
|
||||
@@ -54,7 +54,7 @@ class InMemoryStateManager(PersistenceManager):
|
||||
|
||||
def trim_messages(self, num):
|
||||
# printd(f"InMemoryStateManager.trim_messages")
|
||||
self.messages = self.messages[num:]
|
||||
self.messages = [self.messages[0]] + self.messages[num:]
|
||||
|
||||
def prepend_to_messages(self, added_messages):
|
||||
# first tag with timestamps
|
||||
|
||||
@@ -6,7 +6,11 @@ import json
|
||||
import pytz
|
||||
import os
|
||||
import faiss
|
||||
import tiktoken
|
||||
|
||||
def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
return len(encoding.encode(s))
|
||||
|
||||
# DEBUG = True
|
||||
DEBUG = False
|
||||
|
||||
Reference in New Issue
Block a user