@trace_method
async def summarize_conversation_history(
    self,
    # The messages already in the context window
    in_context_messages: list[Message],
    # The messages produced by the agent in this step
    new_letta_messages: list[Message],
    # The token usage from the most recent LLM call (prompt + completion)
    total_tokens: int | None = None,
    # If force, then don't do any counting, just summarize
    force: bool = False,
) -> list[Message]:
    """Summarize the in-context history when it overflows the context window.

    If summarization is not triggered, the new messages are simply appended to
    the context and the agent's ``message_ids`` are persisted. Otherwise the
    configured summarizer mode ("all" or "sliding_window") produces a summary
    string, which is packaged as a user message, persisted, and spliced into
    the context right after the system prompt.

    Args:
        in_context_messages: Messages currently in the context window
            (``in_context_messages[0]`` is assumed to be the system prompt —
            TODO confirm against callers).
        new_letta_messages: Messages produced by the agent in this step.
        total_tokens: Prompt + completion tokens from the most recent LLM call.
        force: Skip token counting and always summarize.

    Returns:
        The new list of in-context messages (post-summarization if triggered).

    Raises:
        ValueError: If the summarizer config specifies an unknown mode.
    """
    # Normalize to a plain bool: `total_tokens` may be None.
    trigger_summarization = force or bool(total_tokens and total_tokens > self.agent_state.llm_config.context_window)
    if not trigger_summarization:
        # just update the message_ids
        # TODO: gross to handle this here: we should move persistence elsewhere
        new_in_context_messages = in_context_messages + new_letta_messages
        message_ids = [m.id for m in new_in_context_messages]
        await self.agent_manager.update_message_ids_async(
            agent_id=self.agent_state.id,
            message_ids=message_ids,
            actor=self.actor,
        )
        self.agent_state.message_ids = message_ids
        return new_in_context_messages

    # Use agent's summarizer_config if set, otherwise fall back to defaults
    # TODO: add this back
    # summarizer_config = self.agent_state.summarizer_config or get_default_summarizer_config(self.agent_state.llm_config)
    summarizer_config = get_default_summarizer_config(self.agent_state.llm_config._to_model_settings())

    if summarizer_config.mode == "all":
        # Collapse the entire history into one summary; nothing survives.
        summary_message_str = await summarize_all(
            actor=self.actor,
            summarizer_config=summarizer_config,
            in_context_messages=in_context_messages,
            new_messages=new_letta_messages,
        )
        new_in_context_messages = []
    elif summarizer_config.mode == "sliding_window":
        # Evict a prefix of the history into the summary, keep the tail.
        summary_message_str, new_in_context_messages = await summarize_via_sliding_window(
            actor=self.actor,
            llm_config=self.agent_state.llm_config,
            summarizer_config=summarizer_config,
            in_context_messages=in_context_messages,
            new_messages=new_letta_messages,
        )
    else:
        raise ValueError(f"Invalid summarizer mode: {summarizer_config.mode}")

    # Persist the summary message to DB
    summary_message_str_packed = package_summarize_message_no_counts(
        summary=summary_message_str,
        timezone=self.agent_state.timezone,
    )
    summary_message_obj = (
        await convert_message_creates_to_messages(
            message_creates=[
                MessageCreate(
                    role=MessageRole.user,
                    content=[TextContent(text=summary_message_str_packed)],
                )
            ],
            agent_id=self.agent_state.id,
            timezone=self.agent_state.timezone,
            # We already packed, don't pack again
            wrap_user_message=False,
            wrap_system_message=False,
            run_id=None,  # TODO: add this
        )
    )[0]
    await self.message_manager.create_many_messages_async(
        pydantic_msgs=[summary_message_obj],
        actor=self.actor,
        project_id=self.agent_state.project_id,
        template_id=self.agent_state.template_id,
    )

    # Update the message_ids in the agent state:
    # system prompt first, then the summary, then whatever survived eviction.
    new_in_context_messages = [in_context_messages[0], summary_message_obj] + new_in_context_messages
    new_in_context_message_ids = [m.id for m in new_in_context_messages]
    await self.agent_manager.update_message_ids_async(
        agent_id=self.agent_state.id,
        message_ids=new_in_context_message_ids,
        actor=self.actor,
    )
    # BUGFIX: previously assigned the Message objects themselves
    # (`new_in_context_messages`) to `message_ids`, desyncing the agent state
    # from what was just persisted. Assign the ID list instead.
    self.agent_state.message_ids = new_in_context_message_ids

    return new_in_context_messages
from letta.schemas.tool_rule import ToolRule +from letta.services.summarizer.summarizer_config import SummarizerConfig from letta.utils import calculate_file_defaults_based_on_context_window, create_random_username @@ -87,6 +88,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True): embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).") model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings used by the agent.") + # TODO: add this back + # summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.") + response_format: Optional[ResponseFormatUnion] = Field( None, description="The response format used by the agent", @@ -242,6 +246,9 @@ class CreateAgent(BaseModel, validate_assignment=True): # embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).") model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings for the agent.") + # TODO: add this back + # summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.") + context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.") embedding_chunk_size: Optional[int] = Field( DEFAULT_EMBEDDING_CHUNK_SIZE, description="Deprecated: No longer used. 
class ApproxTokenCounter(TokenCounter):
    """Fast approximate token counter using byte-based heuristic (bytes / 4).

    This is the same approach codex-cli uses - a simple approximation that assumes
    ~4 bytes per token on average for English text. Much faster than tiktoken
    and doesn't require loading tokenizer models into memory.

    Just serializes the input to JSON and divides byte length by 4.
    """

    # Heuristic: average bytes per token for English text.
    APPROX_BYTES_PER_TOKEN = 4

    def __init__(self, model: str | None = None):
        # Model is optional since we don't actually use a tokenizer;
        # kept only for interface parity with the other counters.
        self.model = model

    def _approx_token_count(self, text: str) -> int:
        """Approximate token count: ceil(byte_len / 4)."""
        if not text:
            return 0
        byte_len = len(text.encode("utf-8"))
        # Integer ceiling division without importing math.
        return (byte_len + self.APPROX_BYTES_PER_TOKEN - 1) // self.APPROX_BYTES_PER_TOKEN

    async def count_text_tokens(self, text: str) -> int:
        # NOTE(review): assumed to be part of the abstract TokenCounter
        # interface (mirroring the other counters) — confirm the base class
        # signature. Purely CPU-bound; async only for interface parity.
        return self._approx_token_count(text)

    async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
        # BUGFIX: callers (e.g. summarizer_sliding_window.count_tokens) await
        # `count_message_tokens` on this class, but no count method was
        # implemented here. Serialize to JSON and apply the bytes/4 heuristic,
        # exactly as the class docstring describes.
        # NOTE(review): confirm this matches TokenCounter's abstract interface.
        import json

        return self._approx_token_count(json.dumps(messages))

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        """Convert Letta messages to OpenAI-style dicts for counting."""
        return Message.to_openai_dicts_from_list(messages)
# --- letta/services/summarizer/summarizer_all.py ---


async def summarize_all(
    # Required to tag LLM calls
    actor: User,
    # Actual summarization configuration
    summarizer_config: SummarizerConfig,
    in_context_messages: List[Message],
    new_messages: List[Message],
) -> str:
    """
    Summarize the entire conversation history into a single summary.

    Args:
        actor: User used to attribute/tag the summarization LLM call.
        summarizer_config: Summarization settings (model, prompt, clipping).
        in_context_messages: Messages currently in the context window.
        new_messages: Messages produced in the current step.

    Returns:
        The summary string (clipped to ``clip_chars`` if configured).
    """
    all_in_context_messages = in_context_messages + new_messages

    summary_message_str = await simple_summary(
        messages=all_in_context_messages,
        # BUGFIX: SummarizerConfig has no `summarizer_model` attribute (it is
        # commented out); the existing field is `model_settings`.
        # NOTE(review): simple_summary annotates this parameter as LLMConfig
        # while `model_settings` is a ModelSettings — confirm the conversion.
        llm_config=summarizer_config.model_settings,
        actor=actor,
        # NOTE(review): `prompt_acknowledgement` is a string constant used
        # here for its truthiness — confirm against simple_summary's bool arg.
        include_ack=summarizer_config.prompt_acknowledgement,
        prompt=summarizer_config.prompt,
    )

    # Safety net: never let a runaway summary blow the context budget.
    if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
        logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
        summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"

    return summary_message_str


# --- letta/services/summarizer/summarizer_config.py ---


class SummarizerConfig(BaseModel):
    """Configuration for conversation-history summarization."""

    # summarizer_model: LLMConfig = Field(default=..., description="The model to use for summarization.")
    model_settings: ModelSettings = Field(default=..., description="The model settings to use for summarization.")
    prompt: str = Field(default=..., description="The prompt to use for summarization.")
    # NOTE(review): typed str but described as a toggle — the value is the ack
    # message itself and is consumed truthily by simple_summary's include_ack.
    prompt_acknowledgement: str = Field(
        default=..., description="The acknowledgement post-prompt to include, if any (helps prevent non-summary outputs)."
    )
    clip_chars: int | None = Field(
        default=2000, description="The maximum length of the summary in characters. If none, no clipping is performed."
    )

    mode: Literal["all", "sliding_window"] = Field(default="sliding_window", description="The type of summarization technique to use.")
    sliding_window_percentage: float = Field(
        default=0.3, description="The percentage of the context window to keep post-summarization (only used in sliding window mode)."
    )


def get_default_summarizer_config(model_settings: ModelSettings) -> SummarizerConfig:
    """Build a default SummarizerConfig from global settings for backward compatibility.

    Args:
        model_settings: The model settings to use for the summarizer model
            (typically derived from the agent's llm_config).

    Returns:
        A SummarizerConfig with default values from global settings.
    """
    # Imported lazily to avoid circular imports at module load time.
    from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
    from letta.prompts import gpt_summarize
    from letta.settings import summarizer_settings

    return SummarizerConfig(
        mode="sliding_window",
        model_settings=model_settings,
        prompt=gpt_summarize.SYSTEM,
        prompt_acknowledgement=MESSAGE_SUMMARY_REQUEST_ACK,
        clip_chars=2000,
        sliding_window_percentage=summarizer_settings.partial_evict_summarizer_percentage,
    )
logger = get_logger(__name__)


# Safety margin for approximate token counting.
# The bytes/4 heuristic underestimates by ~25-35% for JSON-serialized messages
# due to structural overhead (brackets, quotes, colons) each becoming tokens.
APPROX_TOKEN_SAFETY_MARGIN = 1.3

# Cutoff search parameters — FRACTIONS of the message list, not whole percents.
CUTOFF_STEP = 0.1
MIN_EVICT_FRACTION = 0.1


async def count_tokens(actor: User, llm_config: LLMConfig, messages: List[Message]) -> int:
    """Count tokens for `messages`, exactly for Anthropic, approximately otherwise.

    Args:
        actor: User used to tag/authorize provider API calls.
        llm_config: Model configuration (selects the counting strategy).
        messages: Letta messages to count.

    Returns:
        Token count; approximate counts include a safety margin so we never
        underestimate and keep too many messages.
    """
    if llm_config.model_endpoint_type == "anthropic":
        # Anthropic exposes an accurate server-side token counting API.
        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
        token_counter = AnthropicTokenCounter(anthropic_client, llm_config.model)
        converted_messages = token_counter.convert_messages(messages)
        return await token_counter.count_message_tokens(converted_messages)

    # Otherwise, use approximate count (bytes / 4) with safety margin.
    # This is much faster than tiktoken and doesn't require loading tokenizer models.
    token_counter = ApproxTokenCounter(llm_config.model)
    converted_messages = token_counter.convert_messages(messages)
    tokens = await token_counter.count_message_tokens(converted_messages)
    return int(tokens * APPROX_TOKEN_SAFETY_MARGIN)


async def summarize_via_sliding_window(
    # Required to tag LLM calls
    actor: User,
    # Actual summarization configuration
    llm_config: LLMConfig,
    summarizer_config: SummarizerConfig,
    in_context_messages: List[Message],
    new_messages: List[Message],
) -> Tuple[str, List[Message]]:
    """Summarize a prefix of the conversation, keeping a recent tail in context.

    Finding the summarization cutoff point (target post-summarization size is
    N% of the configured context window):
    1. Start the cutoff at fraction (1 - N%) of the message list.
    2. Count tokens of [system prompt] + messages past the cutoff, where the
       cutoff is snapped forward to the next assistant message so a tool-call /
       tool-return pair is never split.
    3. If the count fits within N% of the context window, summarize everything
       before the cutoff; otherwise advance the cutoff by 10% and repeat.
    If no cutoff fits (or no assistant message exists past the cutoff), fall
    back to complete summarization (evict everything but the system prompt).

    Returns:
        Tuple of (summary string, messages to keep in context — system prompt
        excluded; the caller re-prepends it).
    """
    # Assumes in_context_messages[0] is the system prompt — TODO confirm.
    system_prompt = in_context_messages[0]
    all_in_context_messages = in_context_messages + new_messages
    total_message_count = len(all_in_context_messages)

    # BUGFIX: the previous implementation mixed fraction-scale values (0.7)
    # with whole-percent values (10, 100): `max(1 - 0.3, 10)` was always 10,
    # so the computed index overran the list and the assistant scan always
    # failed. All cutoff arithmetic is now consistently fraction-scale.
    evict_fraction = max(1.0 - summarizer_config.sliding_window_percentage, MIN_EVICT_FRACTION)

    # Default when no fitting cutoff is found: evict everything
    # (equivalent to complete summarization).
    keep_from_index = total_message_count

    while evict_fraction < 1.0:
        message_cutoff_index = round(evict_fraction * total_message_count)

        # Snap forward to the first assistant message at/after the cutoff.
        candidate_index = None
        for i in range(message_cutoff_index, total_message_count):
            if all_in_context_messages[i].role == MessageRole.assistant:
                candidate_index = i
                break
        if candidate_index is None:
            # ROBUSTNESS: previously raised ValueError; no assistant message
            # remains past the cutoff, so fall back to full eviction.
            break

        # Count tokens of the hypothetical post-summarization buffer.
        post_summarization_buffer = [system_prompt] + all_in_context_messages[candidate_index:]
        post_summarization_buffer_tokens = await count_tokens(actor, llm_config, post_summarization_buffer)

        # Does the hypothetical buffer fit within the target percentage?
        if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
            keep_from_index = candidate_index
            break
        evict_fraction += CUTOFF_STEP

    # BUGFIX: summarize exactly the evicted prefix. The previous code
    # summarized [1:message_cutoff_index] while keeping [assistant_index:],
    # silently dropping the messages in between without summarizing them.
    messages_to_summarize = all_in_context_messages[1:keep_from_index]

    summary_message_str = await simple_summary(
        messages=messages_to_summarize,
        # BUGFIX: SummarizerConfig has no `summarizer_model` attribute; the
        # existing field is `model_settings`. NOTE(review): simple_summary
        # annotates this as LLMConfig — confirm the conversion.
        llm_config=summarizer_config.model_settings,
        actor=actor,
        include_ack=summarizer_config.prompt_acknowledgement,
        prompt=summarizer_config.prompt,
    )

    # Safety net: clamp the summary length so it cannot blow the budget.
    if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
        logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
        summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"

    updated_in_context_messages = all_in_context_messages[keep_from_index:]
    return summary_message_str, updated_in_context_messages
test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMConfig, mode: SummarizationMode): + """ + Test summarization with different modes and LLM configurations. + """ + from unittest.mock import patch + + # Create a conversation with enough messages to trigger summarization + messages = [] + for i in range(10): + messages.append( + PydanticMessage( + role=MessageRole.user, + content=[TextContent(type="text", text=f"User message {i}: Test message {i}.")], + ) + ) + messages.append( + PydanticMessage( + role=MessageRole.assistant, + content=[TextContent(type="text", text=f"Assistant response {i}: Acknowledged message {i}.")], + ) + ) + + agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages) + + with patch("letta.agents.letta_agent_v2.summarizer_settings") as mock_settings: + mock_settings.mode = mode + mock_settings.message_buffer_limit = 10 + mock_settings.message_buffer_min = 3 + mock_settings.partial_evict_summarizer_percentage = 0.30 + mock_settings.max_summarizer_retries = 3 + + agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor) + assert agent_loop.summarizer.mode == mode + + result = await agent_loop.summarize_conversation_history( + in_context_messages=in_context_messages, + new_letta_messages=[], + total_tokens=None, + force=True, + ) + + assert isinstance(result, list) + assert len(result) >= 1 + print(f"{mode.value} with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")