feat: patch summarizer without changes to AgentState (#6450)

This commit is contained in:
Sarah Wooders
2025-11-29 15:29:50 -08:00
committed by Caren Thomas
parent 4af6465226
commit 1939a9d185
9 changed files with 424 additions and 4 deletions

View File

@@ -1,7 +1,7 @@
import asyncio
import json
import uuid
from typing import Any, AsyncGenerator, Dict, Optional
from typing import Any, AsyncGenerator, Dict, Literal, Optional
from opentelemetry.trace import Span
@@ -23,6 +23,7 @@ from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEAR
from letta.errors import ContextWindowExceededError, LLMError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.helpers.tool_execution_helper import enable_strict_mode
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.otel.tracing import trace_method
@@ -32,6 +33,7 @@ from letta.schemas.letta_message import ApprovalReturn, LettaErrorMessage, Letta
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
from letta.schemas.letta_response import LettaResponse
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate, ToolReturn
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
from letta.schemas.step import StepProgression
@@ -44,8 +46,11 @@ from letta.server.rest_api.utils import (
create_parallel_tool_messages_from_llm_response,
)
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
from letta.services.summarizer.summarizer_all import summarize_all
from letta.services.summarizer.summarizer_config import SummarizerConfig, get_default_summarizer_config
from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
from letta.settings import settings, summarizer_settings
from letta.system import package_function_response
from letta.system import package_function_response, package_summarize_message_no_counts
from letta.utils import log_telemetry, validate_function_response
@@ -1262,3 +1267,93 @@ class LettaAgentV3(LettaAgentV2):
terminal_tools=terminal_tool_names,
)
return allowed_tools
@trace_method
async def summarize_conversation_history(
    self,
    # The messages already in the context window
    in_context_messages: list[Message],
    # The messages produced by the agent in this step
    new_letta_messages: list[Message],
    # The token usage from the most recent LLM call (prompt + completion)
    total_tokens: int | None = None,
    # If force, then don't do any counting, just summarize
    force: bool = False,
) -> list[Message]:
    """Summarize the conversation history when it exceeds the context window (or when forced).

    When summarization is not triggered, this only persists the updated in-context
    message IDs. When it is triggered, a summary message is generated (mode "all" or
    "sliding_window"), persisted to the DB, and spliced into the context as
    [system prompt, summary, remaining messages].

    Args:
        in_context_messages: Messages already in the context window (index 0 is the system prompt).
        new_letta_messages: Messages produced by the agent in this step.
        total_tokens: Prompt + completion token usage from the most recent LLM call.
        force: If True, skip token counting and summarize unconditionally.

    Returns:
        The updated list of in-context messages.

    Raises:
        ValueError: If the summarizer config specifies an unknown mode.
    """
    trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window)
    if not trigger_summarization:
        # just update the message_ids
        # TODO: gross to handle this here: we should move persistence elsewhere
        new_in_context_messages = in_context_messages + new_letta_messages
        message_ids = [m.id for m in new_in_context_messages]
        await self.agent_manager.update_message_ids_async(
            agent_id=self.agent_state.id,
            message_ids=message_ids,
            actor=self.actor,
        )
        self.agent_state.message_ids = message_ids
        return new_in_context_messages

    # Use agent's summarizer_config if set, otherwise fall back to defaults
    # TODO: add this back
    # summarizer_config = self.agent_state.summarizer_config or get_default_summarizer_config(self.agent_state.llm_config)
    summarizer_config = get_default_summarizer_config(self.agent_state.llm_config._to_model_settings())

    if summarizer_config.mode == "all":
        # Everything is folded into the summary; nothing (besides the summary) stays in context
        summary_message_str = await summarize_all(
            actor=self.actor,
            summarizer_config=summarizer_config,
            in_context_messages=in_context_messages,
            new_messages=new_letta_messages,
        )
        new_in_context_messages = []
    elif summarizer_config.mode == "sliding_window":
        # Keep a recent tail of messages; summarize the rest
        summary_message_str, new_in_context_messages = await summarize_via_sliding_window(
            actor=self.actor,
            llm_config=self.agent_state.llm_config,
            summarizer_config=summarizer_config,
            in_context_messages=in_context_messages,
            new_messages=new_letta_messages,
        )
    else:
        raise ValueError(f"Invalid summarizer mode: {summarizer_config.mode}")

    # Persist the summary message to DB
    summary_message_str_packed = package_summarize_message_no_counts(
        summary=summary_message_str,
        timezone=self.agent_state.timezone,
    )
    summary_message_obj = (
        await convert_message_creates_to_messages(
            message_creates=[
                MessageCreate(
                    role=MessageRole.user,
                    content=[TextContent(text=summary_message_str_packed)],
                )
            ],
            agent_id=self.agent_state.id,
            timezone=self.agent_state.timezone,
            # We already packed, don't pack again
            wrap_user_message=False,
            wrap_system_message=False,
            run_id=None,  # TODO: add this
        )
    )[0]
    await self.message_manager.create_many_messages_async(
        pydantic_msgs=[summary_message_obj],
        actor=self.actor,
        project_id=self.agent_state.project_id,
        template_id=self.agent_state.template_id,
    )

    # Update the message_ids in the agent state
    new_in_context_messages = [in_context_messages[0], summary_message_obj] + new_in_context_messages
    new_in_context_message_ids = [m.id for m in new_in_context_messages]
    await self.agent_manager.update_message_ids_async(
        agent_id=self.agent_state.id,
        message_ids=new_in_context_message_ids,
        actor=self.actor,
    )
    # Fix: store the ID list on the cached agent_state (previously assigned the
    # Message objects themselves, leaving agent_state.message_ids inconsistent
    # with what was just persisted)
    self.agent_state.message_ids = new_in_context_message_ids
    return new_in_context_messages

View File

@@ -24,6 +24,7 @@ from letta.schemas.response_format import ResponseFormatUnion
from letta.schemas.source import Source
from letta.schemas.tool import Tool
from letta.schemas.tool_rule import ToolRule
from letta.services.summarizer.summarizer_config import SummarizerConfig
from letta.utils import calculate_file_defaults_based_on_context_window, create_random_username
@@ -87,6 +88,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings used by the agent.")
# TODO: add this back
# summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.")
response_format: Optional[ResponseFormatUnion] = Field(
None,
description="The response format used by the agent",
@@ -242,6 +246,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings for the agent.")
# TODO: add this back
# summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.")
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
embedding_chunk_size: Optional[int] = Field(
DEFAULT_EMBEDDING_CHUNK_SIZE, description="Deprecated: No longer used. The embedding chunk size used by the agent.", deprecated=True
@@ -434,6 +441,10 @@ class UpdateAgent(BaseModel):
)
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings for the agent.")
# TODO: add this back
# summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.")
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
reasoning: Optional[bool] = Field(
None,

View File

@@ -78,6 +78,33 @@ class AnthropicTokenCounter(TokenCounter):
return Message.to_anthropic_dicts_from_list(messages, current_model=self.model)
class ApproxTokenCounter(TokenCounter):
    """Fast approximate token counter using a byte-based heuristic (bytes / 4).

    Same approach codex-cli uses: assume ~4 bytes per token on average for
    English text. Much faster than tiktoken and avoids loading tokenizer
    models into memory — the input is just serialized to JSON and its byte
    length divided by 4.
    """

    APPROX_BYTES_PER_TOKEN = 4

    def __init__(self, model: str | None = None):
        # The model name is accepted for interface parity only; no tokenizer is loaded.
        self.model = model

    def _approx_token_count(self, text: str) -> int:
        """Approximate token count: ceil(byte_len / 4)."""
        if not text:
            return 0
        # Ceiling division without pulling in math.ceil
        return -(-len(text.encode("utf-8")) // self.APPROX_BYTES_PER_TOKEN)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        return Message.to_openai_dicts_from_list(messages)
class GeminiTokenCounter(TokenCounter):
"""Token counter using Google's Gemini token counting API"""

View File

@@ -30,6 +30,7 @@ from letta.utils import safe_create_task
logger = get_logger(__name__)
# NOTE: legacy, new version is functional
class Summarizer:
"""
Handles summarization or trimming of conversation messages based on
@@ -407,7 +408,13 @@ def simple_message_wrapper(openai_msg: dict) -> Message:
raise ValueError(f"Unknown role: {openai_msg['role']}")
async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor: User, include_ack: bool = True) -> str:
async def simple_summary(
messages: List[Message],
llm_config: LLMConfig,
actor: User,
include_ack: bool = True,
prompt: str | None = None,
) -> str:
"""Generate a simple summary from a list of messages.
Intentionally kept functional due to the simplicity of the prompt.
@@ -422,7 +429,7 @@ async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor:
assert llm_client is not None
# Prepare the messages payload to send to the LLM
system_prompt = gpt_summarize.SYSTEM
system_prompt = prompt or gpt_summarize.SYSTEM
# Build the initial transcript without clamping to preserve fidelity
# TODO proactively clip here?
summary_transcript = simple_formatter(messages)

View File

@@ -0,0 +1,46 @@
from typing import List, Tuple
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.services.message_manager import MessageManager
from letta.services.summarizer.summarizer import simple_summary
from letta.services.summarizer.summarizer_config import SummarizerConfig
from letta.system import package_summarize_message_no_counts
logger = get_logger(__name__)
async def summarize_all(
    # Required to tag LLM calls
    actor: User,
    # Actual summarization configuration
    summarizer_config: SummarizerConfig,
    in_context_messages: List[Message],
    new_messages: List[Message],
) -> str:
    """
    Summarize the entire conversation history into a single summary.

    Args:
        actor: The user on whose behalf the summarization LLM call is made.
        summarizer_config: Prompt, ack, clipping, and model configuration for the summarizer.
        in_context_messages: Messages already in the context window.
        new_messages: Messages produced in the current step, appended after the in-context ones.

    Returns:
        - The summary string
    """
    # Summarize over the full history: existing context plus this step's new messages
    all_in_context_messages = in_context_messages + new_messages
    # NOTE(review): SummarizerConfig currently defines `model_settings`; its
    # `summarizer_model` field is commented out, so this attribute access looks
    # like it would fail at runtime — confirm against the SummarizerConfig schema.
    summary_message_str = await simple_summary(
        messages=all_in_context_messages,
        llm_config=summarizer_config.summarizer_model,
        actor=actor,
        include_ack=summarizer_config.prompt_acknowledgement,
        prompt=summarizer_config.prompt,
    )
    # Safety clamp: truncate overly long summaries so they can't blow up the context again
    if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
        logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
        summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"
    return summary_message_str

View File

@@ -0,0 +1,46 @@
from typing import Literal
from pydantic import BaseModel, Field
from letta.schemas.llm_config import LLMConfig
from letta.schemas.model import ModelSettings
class SummarizerConfig(BaseModel):
    """Configuration for conversation-history summarization (prompt, model, clipping, and mode)."""

    # summarizer_model: LLMConfig = Field(default=..., description="The model to use for summarization.")
    model_settings: ModelSettings = Field(default=..., description="The model settings to use for summarization.")
    prompt: str = Field(default=..., description="The prompt to use for summarization.")
    # NOTE(review): declared as str but the description reads like a boolean toggle;
    # callers pass the ack *message* string — confirm the intended type/semantics.
    prompt_acknowledgement: str = Field(
        default=..., description="Whether to include an acknowledgement post-prompt (helps prevent non-summary outputs)."
    )
    clip_chars: int | None = Field(
        default=2000, description="The maximum length of the summary in characters. If none, no clipping is performed."
    )
    # "all" folds the entire history into one summary; "sliding_window" keeps a recent tail
    mode: Literal["all", "sliding_window"] = Field(default="sliding_window", description="The type of summarization technique to use.")
    sliding_window_percentage: float = Field(
        default=0.3, description="The percentage of the context window to keep post-summarization (only used in sliding window mode)."
    )
def get_default_summarizer_config(model_settings: ModelSettings) -> SummarizerConfig:
    """Build a default SummarizerConfig from global settings for backward compatibility.

    Args:
        model_settings: The model settings to use for the summarizer model
            (typically derived from the agent's llm_config).

    Returns:
        A SummarizerConfig with default values from global settings.
    """
    # Local imports to avoid import cycles with constants/prompts/settings
    from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
    from letta.prompts import gpt_summarize
    from letta.settings import summarizer_settings

    return SummarizerConfig(
        mode="sliding_window",
        model_settings=model_settings,
        prompt=gpt_summarize.SYSTEM,
        # NOTE(review): prompt_acknowledgement's description reads like a boolean,
        # but the ack *message* string is passed here — confirm intended semantics.
        prompt_acknowledgement=MESSAGE_SUMMARY_REQUEST_ACK,
        clip_chars=2000,
        sliding_window_percentage=summarizer_settings.partial_evict_summarizer_percentage,
    )

View File

@@ -0,0 +1,124 @@
from typing import List, Tuple
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, ApproxTokenCounter
from letta.services.message_manager import MessageManager
from letta.services.summarizer.summarizer import simple_summary
from letta.services.summarizer.summarizer_config import SummarizerConfig
from letta.settings import model_settings, settings
from letta.system import package_summarize_message_no_counts
logger = get_logger(__name__)
# Safety margin for approximate token counting.
# The bytes/4 heuristic underestimates by ~25-35% for JSON-serialized messages
# due to structural overhead (brackets, quotes, colons) each becoming tokens.
APPROX_TOKEN_SAFETY_MARGIN = 1.3
async def count_tokens(actor: User, llm_config: LLMConfig, messages: List[Message]) -> int:
    """Count tokens for a list of messages.

    Anthropic models use the provider's token-counting API (accurate); everything
    else uses the fast byte-based approximation with a safety margin applied so we
    don't underestimate and keep too many messages.
    """
    # Accurate path: Anthropic exposes a token counting API
    if llm_config.model_endpoint_type == "anthropic":
        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
        counter = AnthropicTokenCounter(anthropic_client, llm_config.model)
        return await counter.count_message_tokens(counter.convert_messages(messages))

    # Approximate path: bytes / 4 — much faster than tiktoken, no tokenizer models to load
    counter = ApproxTokenCounter(llm_config.model)
    approx_tokens = await counter.count_message_tokens(counter.convert_messages(messages))
    # Inflate by the safety margin to avoid underestimating
    return int(approx_tokens * APPROX_TOKEN_SAFETY_MARGIN)
async def summarize_via_sliding_window(
    # Required to tag LLM calls
    actor: User,
    # Actual summarization configuration
    llm_config: LLMConfig,
    summarizer_config: SummarizerConfig,
    in_context_messages: List[Message],
    new_messages: List[Message],
) -> Tuple[str, List[Message]]:
    """
    Summarize the older portion of the conversation, keeping a recent tail in context.

    Finding the summarization cutoff point (target is that the post-summarization
    buffer fits in N% of the configured context window, N = sliding_window_percentage):
    1. Start evicting a fraction (1 - N) of the messages
    2. Snap the cutoff forward to the first assistant message at/after that point
    3. Count tokens of [system prompt] + messages[cutoff:]
    4. If the count fits in N% of the context window, stop; otherwise evict 10% more and repeat
    5. If the eviction fraction reaches 100%, summarize everything (complete summarization)

    Returns:
        - The summary string
        - The list of messages to keep in-context (excluding the system prompt)

    Raises:
        ValueError: If no assistant message exists at/after the cutoff point.
    """
    system_prompt = in_context_messages[0]
    all_in_context_messages = in_context_messages + new_messages
    total_message_count = len(all_in_context_messages)

    # Fraction of messages to evict. Starts at (1 - N), e.g. 0.7 when keeping 30%.
    # Fix: the previous clamp was max(..., 10), which forced the *fraction* to 10 and
    # pushed the cutoff index past the end of the list. Clamp to a small positive
    # minimum instead, to guard against a misconfigured percentage going negative.
    evict_fraction = max(1 - summarizer_config.sliding_window_percentage, 0.1)

    # Index into all_in_context_messages: [1:cutoff_index] is summarized,
    # [cutoff_index:] is kept. Defaults to "evict everything".
    cutoff_index = total_message_count
    while evict_fraction < 1.0:
        # Mark the approximate cutoff for this eviction fraction
        approx_cutoff = round(evict_fraction * total_message_count)
        # Walk up the list until we find the first assistant message
        for i in range(approx_cutoff, total_message_count):
            if all_in_context_messages[i].role == MessageRole.assistant:
                candidate_index = i
                break
        else:
            raise ValueError(f"No assistant message found from indices {approx_cutoff} to {total_message_count}")

        # Count tokens of the hypothetical post-summarization buffer
        post_summarization_buffer = [system_prompt] + all_in_context_messages[candidate_index:]
        post_summarization_buffer_tokens = await count_tokens(actor, llm_config, post_summarization_buffer)
        # Does the hypothetical post-summarization buffer fit in the target percentage?
        if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
            cutoff_index = candidate_index
            break
        # Fix: the fraction was previously incremented by 10 (whole percents) and
        # compared against 100 — increment the fraction by 0.1 instead.
        evict_fraction += 0.1

    # Summarize exactly the evicted span so summarized + kept covers every message
    # (previously the span between the raw cutoff and the assistant-message index
    # was neither summarized nor kept, i.e. silently dropped).
    messages_to_summarize = all_in_context_messages[1:cutoff_index]
    # NOTE(review): SummarizerConfig currently defines `model_settings`; its
    # `summarizer_model` field is commented out — confirm which model config
    # simple_summary should receive here.
    summary_message_str = await simple_summary(
        messages=messages_to_summarize,
        llm_config=summarizer_config.summarizer_model,
        actor=actor,
        include_ack=summarizer_config.prompt_acknowledgement,
        prompt=summarizer_config.prompt,
    )
    # Safety clamp: truncate overly long summaries so they can't blow up the context again
    if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
        logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
        summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"

    updated_in_context_messages = all_in_context_messages[cutoff_index:]
    return summary_message_str, updated_in_context_messages

View File

@@ -205,6 +205,9 @@ def test_send_user_message_with_pending_request(client, agent):
agent_id=agent.id,
messages=USER_MESSAGE_TEST_APPROVAL,
)
print("RESPONSE", response)
for message in response.messages:
print("MESSAGE", message)
with pytest.raises(APIError, match="Please approve or deny the pending request before continuing"):
client.agents.messages.create(

View File

@@ -602,3 +602,64 @@ async def test_summarize_truncates_large_tool_return(server: SyncServer, actor,
# (they may have been completely removed during aggressive summarization)
if not tool_returns_found:
print("Tool returns were completely removed during summarization")
# ======================================================================================================================
# SummarizerConfig Mode Tests (with pytest.patch)
# ======================================================================================================================
from letta.services.summarizer.enums import SummarizationMode
SUMMARIZATION_MODES = [
SummarizationMode.STATIC_MESSAGE_BUFFER,
SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER,
]
@pytest.mark.asyncio
@pytest.mark.parametrize("mode", SUMMARIZATION_MODES, ids=[m.value for m in SUMMARIZATION_MODES])
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMConfig, mode: SummarizationMode):
"""
Test summarization with different modes and LLM configurations.
"""
from unittest.mock import patch
# Create a conversation with enough messages to trigger summarization
messages = []
for i in range(10):
messages.append(
PydanticMessage(
role=MessageRole.user,
content=[TextContent(type="text", text=f"User message {i}: Test message {i}.")],
)
)
messages.append(
PydanticMessage(
role=MessageRole.assistant,
content=[TextContent(type="text", text=f"Assistant response {i}: Acknowledged message {i}.")],
)
)
agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
with patch("letta.agents.letta_agent_v2.summarizer_settings") as mock_settings:
mock_settings.mode = mode
mock_settings.message_buffer_limit = 10
mock_settings.message_buffer_min = 3
mock_settings.partial_evict_summarizer_percentage = 0.30
mock_settings.max_summarizer_retries = 3
agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor)
assert agent_loop.summarizer.mode == mode
result = await agent_loop.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=[],
total_tokens=None,
force=True,
)
assert isinstance(result, list)
assert len(result) >= 1
print(f"{mode.value} with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")