feat: refactor summarization and message persistence code [LET-6464] (#6561)
This commit is contained in:
committed by
Caren Thomas
parent
b23722e4a1
commit
bbd52e291c
@@ -50,7 +50,7 @@ async def summarize_via_sliding_window(
|
||||
llm_config: LLMConfig,
|
||||
summarizer_config: SummarizerConfig,
|
||||
in_context_messages: List[Message],
|
||||
new_messages: List[Message],
|
||||
# new_messages: List[Message],
|
||||
) -> Tuple[str, List[Message]]:
|
||||
"""
|
||||
If the total tokens is greater than the context window limit (or force=True),
|
||||
@@ -68,53 +68,42 @@ async def summarize_via_sliding_window(
|
||||
- The list of message IDs to keep in-context
|
||||
"""
|
||||
system_prompt = in_context_messages[0]
|
||||
all_in_context_messages = in_context_messages + new_messages
|
||||
total_message_count = len(all_in_context_messages)
|
||||
total_message_count = len(in_context_messages)
|
||||
|
||||
# Starts at N% (eg 70%), and increments up until 100%
|
||||
message_count_cutoff_percent = max(
|
||||
1 - summarizer_config.sliding_window_percentage, 0.10
|
||||
) # Some arbitrary minimum value (10%) to avoid negatives from badly configured summarizer percentage
|
||||
found_cutoff = False
|
||||
assert summarizer_config.sliding_window_percentage <= 1.0, "Sliding window percentage must be less than or equal to 1.0"
|
||||
assistant_message_index = None
|
||||
approx_token_count = llm_config.context_window
|
||||
|
||||
# Count tokens with system prompt, and message past cutoff point
|
||||
assistant_message_index = None # Initialize to track if we found an assistant message
|
||||
while not found_cutoff:
|
||||
# Mark the approximate cutoff
|
||||
message_cutoff_index = round(message_count_cutoff_percent * len(all_in_context_messages))
|
||||
while (
|
||||
approx_token_count >= summarizer_config.sliding_window_percentage * llm_config.context_window and message_count_cutoff_percent < 1.0
|
||||
):
|
||||
# calculate message_cutoff_index
|
||||
message_cutoff_index = round(message_count_cutoff_percent * total_message_count)
|
||||
|
||||
# we've reached the maximum message cutoff
|
||||
if message_cutoff_index >= total_message_count:
|
||||
# get index of first assistant message in range
|
||||
assistant_message_index = next(
|
||||
(i for i in range(message_cutoff_index, total_message_count) if in_context_messages[i].role == MessageRole.assistant), None
|
||||
)
|
||||
|
||||
# if no assistant message in tail, break out of loop (since future iterations will continue hitting this case)
|
||||
if assistant_message_index is None:
|
||||
break
|
||||
|
||||
# Walk up the list until we find the first assistant message
|
||||
for i in range(message_cutoff_index, total_message_count):
|
||||
if all_in_context_messages[i].role == MessageRole.assistant:
|
||||
assistant_message_index = i
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"No assistant message found from indices {message_cutoff_index} to {total_message_count}")
|
||||
# update token count
|
||||
post_summarization_buffer = [system_prompt] + in_context_messages[assistant_message_index:]
|
||||
approx_token_count = await count_tokens(actor, llm_config, post_summarization_buffer)
|
||||
|
||||
# Count tokens of the hypothetical post-summarization buffer
|
||||
post_summarization_buffer = [system_prompt] + all_in_context_messages[assistant_message_index:]
|
||||
post_summarization_buffer_tokens = await count_tokens(actor, llm_config, post_summarization_buffer)
|
||||
|
||||
# If hypothetical post-summarization count lower than the target remaining percentage?
|
||||
if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
|
||||
found_cutoff = True
|
||||
else:
|
||||
message_count_cutoff_percent += 0.10
|
||||
if message_count_cutoff_percent >= 1.0:
|
||||
message_cutoff_index = total_message_count
|
||||
break
|
||||
|
||||
# If we found the cutoff, summarize and return
|
||||
# If we didn't find the cutoff and we hit 100%, this is equivalent to complete summarization
|
||||
# increment cutoff
|
||||
message_count_cutoff_percent += 0.10
|
||||
|
||||
if assistant_message_index is None:
|
||||
raise ValueError("No assistant message found for sliding window summarization") # fall back to complete summarization
|
||||
|
||||
messages_to_summarize = all_in_context_messages[1:message_cutoff_index]
|
||||
messages_to_summarize = in_context_messages[1:assistant_message_index]
|
||||
|
||||
summary_message_str = await simple_summary(
|
||||
messages=messages_to_summarize,
|
||||
@@ -128,5 +117,5 @@ async def summarize_via_sliding_window(
|
||||
logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
|
||||
summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"
|
||||
|
||||
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
|
||||
return summary_message_str, updated_in_context_messages
|
||||
updated_in_context_messages = in_context_messages[assistant_message_index:]
|
||||
return summary_message_str, [system_prompt] + updated_in_context_messages
|
||||
|
||||
Reference in New Issue
Block a user