feat: refactor summarization and message persistence code [LET-6464] (#6561)

This commit is contained in:
Sarah Wooders
2025-12-09 16:34:06 -08:00
committed by Caren Thomas
parent b23722e4a1
commit bbd52e291c
10 changed files with 493 additions and 434 deletions

View File

@@ -50,7 +50,7 @@ async def summarize_via_sliding_window(
llm_config: LLMConfig,
summarizer_config: SummarizerConfig,
in_context_messages: List[Message],
new_messages: List[Message],
# new_messages: List[Message],
) -> Tuple[str, List[Message]]:
"""
If the total tokens is greater than the context window limit (or force=True),
@@ -68,53 +68,42 @@ async def summarize_via_sliding_window(
- The list of message IDs to keep in-context
"""
system_prompt = in_context_messages[0]
all_in_context_messages = in_context_messages + new_messages
total_message_count = len(all_in_context_messages)
total_message_count = len(in_context_messages)
# Starts at N% (eg 70%), and increments up until 100%
message_count_cutoff_percent = max(
1 - summarizer_config.sliding_window_percentage, 0.10
) # Some arbitrary minimum value (10%) to avoid negatives from badly configured summarizer percentage
found_cutoff = False
assert summarizer_config.sliding_window_percentage <= 1.0, "Sliding window percentage must be less than or equal to 1.0"
assistant_message_index = None
approx_token_count = llm_config.context_window
# Count tokens with system prompt, and message past cutoff point
assistant_message_index = None # Initialize to track if we found an assistant message
while not found_cutoff:
# Mark the approximate cutoff
message_cutoff_index = round(message_count_cutoff_percent * len(all_in_context_messages))
while (
approx_token_count >= summarizer_config.sliding_window_percentage * llm_config.context_window and message_count_cutoff_percent < 1.0
):
# calculate message_cutoff_index
message_cutoff_index = round(message_count_cutoff_percent * total_message_count)
# we've reached the maximum message cutoff
if message_cutoff_index >= total_message_count:
# get index of first assistant message in range
assistant_message_index = next(
(i for i in range(message_cutoff_index, total_message_count) if in_context_messages[i].role == MessageRole.assistant), None
)
# if no assistant message in tail, break out of loop (since future iterations will continue hitting this case)
if assistant_message_index is None:
break
# Walk up the list until we find the first assistant message
for i in range(message_cutoff_index, total_message_count):
if all_in_context_messages[i].role == MessageRole.assistant:
assistant_message_index = i
break
else:
raise ValueError(f"No assistant message found from indices {message_cutoff_index} to {total_message_count}")
# update token count
post_summarization_buffer = [system_prompt] + in_context_messages[assistant_message_index:]
approx_token_count = await count_tokens(actor, llm_config, post_summarization_buffer)
# Count tokens of the hypothetical post-summarization buffer
post_summarization_buffer = [system_prompt] + all_in_context_messages[assistant_message_index:]
post_summarization_buffer_tokens = await count_tokens(actor, llm_config, post_summarization_buffer)
# If hypothetical post-summarization count lower than the target remaining percentage?
if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
found_cutoff = True
else:
message_count_cutoff_percent += 0.10
if message_count_cutoff_percent >= 1.0:
message_cutoff_index = total_message_count
break
# If we found the cutoff, summarize and return
# If we didn't find the cutoff and we hit 100%, this is equivalent to complete summarization
# increment cutoff
message_count_cutoff_percent += 0.10
if assistant_message_index is None:
raise ValueError("No assistant message found for sliding window summarization") # fall back to complete summarization
messages_to_summarize = all_in_context_messages[1:message_cutoff_index]
messages_to_summarize = in_context_messages[1:assistant_message_index]
summary_message_str = await simple_summary(
messages=messages_to_summarize,
@@ -128,5 +117,5 @@ async def summarize_via_sliding_window(
logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
return summary_message_str, updated_in_context_messages
updated_in_context_messages = in_context_messages[assistant_message_index:]
return summary_message_str, [system_prompt] + updated_in_context_messages