feat: patch summarizer without changes to AgentState (#6450)
This commit is contained in:
committed by
Caren Thomas
parent
4af6465226
commit
1939a9d185
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import json
|
||||
import uuid
|
||||
from typing import Any, AsyncGenerator, Dict, Optional
|
||||
from typing import Any, AsyncGenerator, Dict, Literal, Optional
|
||||
|
||||
from opentelemetry.trace import Span
|
||||
|
||||
@@ -23,6 +23,7 @@ from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEAR
|
||||
from letta.errors import ContextWindowExceededError, LLMError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.helpers.tool_execution_helper import enable_strict_mode
|
||||
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
|
||||
from letta.otel.tracing import trace_method
|
||||
@@ -32,6 +33,7 @@ from letta.schemas.letta_message import ApprovalReturn, LettaErrorMessage, Letta
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.letta_stop_reason import LettaStopReason, StopReasonType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall, ToolCallDenial, UsageStatistics
|
||||
from letta.schemas.step import StepProgression
|
||||
@@ -44,8 +46,11 @@ from letta.server.rest_api.utils import (
|
||||
create_parallel_tool_messages_from_llm_response,
|
||||
)
|
||||
from letta.services.helpers.tool_parser_helper import runtime_override_tool_json_schema
|
||||
from letta.services.summarizer.summarizer_all import summarize_all
|
||||
from letta.services.summarizer.summarizer_config import SummarizerConfig, get_default_summarizer_config
|
||||
from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
|
||||
from letta.settings import settings, summarizer_settings
|
||||
from letta.system import package_function_response
|
||||
from letta.system import package_function_response, package_summarize_message_no_counts
|
||||
from letta.utils import log_telemetry, validate_function_response
|
||||
|
||||
|
||||
@@ -1262,3 +1267,93 @@ class LettaAgentV3(LettaAgentV2):
|
||||
terminal_tools=terminal_tool_names,
|
||||
)
|
||||
return allowed_tools
|
||||
|
||||
@trace_method
async def summarize_conversation_history(
    self,
    # The messages already in the context window
    in_context_messages: list[Message],
    # The messages produced by the agent in this step
    new_letta_messages: list[Message],
    # The token usage from the most recent LLM call (prompt + completion)
    total_tokens: int | None = None,
    # If force, then don't do any counting, just summarize
    force: bool = False,
) -> list[Message]:
    """Summarize the conversation history when it overflows the context window.

    If summarization is not triggered, this simply persists the updated
    in-context message IDs and returns the concatenated message list.
    Otherwise it runs the configured summarizer ("all" or "sliding_window"),
    persists the packed summary message to the DB, updates the agent's
    message IDs, and returns the new in-context message list
    (system prompt + summary + kept tail).

    Args:
        in_context_messages: Messages already in the context window.
        new_letta_messages: Messages produced by the agent in this step.
        total_tokens: Prompt + completion token usage from the latest LLM call.
        force: If True, skip the token check and always summarize.

    Returns:
        The new list of in-context messages.

    Raises:
        ValueError: If the summarizer config specifies an unknown mode.
    """
    trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window)
    if not trigger_summarization:
        # just update the message_ids
        # TODO: gross to handle this here: we should move persistence elsewhere
        new_in_context_messages = in_context_messages + new_letta_messages
        message_ids = [m.id for m in new_in_context_messages]
        await self.agent_manager.update_message_ids_async(
            agent_id=self.agent_state.id,
            message_ids=message_ids,
            actor=self.actor,
        )
        self.agent_state.message_ids = message_ids
        return new_in_context_messages

    # Use agent's summarizer_config if set, otherwise fall back to defaults
    # TODO: add this back
    # summarizer_config = self.agent_state.summarizer_config or get_default_summarizer_config(self.agent_state.llm_config)
    summarizer_config = get_default_summarizer_config(self.agent_state.llm_config._to_model_settings())

    if summarizer_config.mode == "all":
        # Collapse the entire history into one summary; no tail is kept.
        summary_message_str = await summarize_all(
            actor=self.actor,
            summarizer_config=summarizer_config,
            in_context_messages=in_context_messages,
            new_messages=new_letta_messages,
        )
        new_in_context_messages = []
    elif summarizer_config.mode == "sliding_window":
        # Summarize the oldest portion and keep a recent tail of messages.
        summary_message_str, new_in_context_messages = await summarize_via_sliding_window(
            actor=self.actor,
            llm_config=self.agent_state.llm_config,
            summarizer_config=summarizer_config,
            in_context_messages=in_context_messages,
            new_messages=new_letta_messages,
        )
    else:
        raise ValueError(f"Invalid summarizer mode: {summarizer_config.mode}")

    # Persist the summary message to DB
    summary_message_str_packed = package_summarize_message_no_counts(
        summary=summary_message_str,
        timezone=self.agent_state.timezone,
    )
    summary_message_obj = (
        await convert_message_creates_to_messages(
            message_creates=[
                MessageCreate(
                    role=MessageRole.user,
                    content=[TextContent(text=summary_message_str_packed)],
                )
            ],
            agent_id=self.agent_state.id,
            timezone=self.agent_state.timezone,
            # We already packed, don't pack again
            wrap_user_message=False,
            wrap_system_message=False,
            run_id=None,  # TODO: add this
        )
    )[0]
    await self.message_manager.create_many_messages_async(
        pydantic_msgs=[summary_message_obj],
        actor=self.actor,
        project_id=self.agent_state.project_id,
        template_id=self.agent_state.template_id,
    )

    # Update the message_ids in the agent state
    new_in_context_messages = [in_context_messages[0], summary_message_obj] + new_in_context_messages
    new_in_context_message_ids = [m.id for m in new_in_context_messages]
    await self.agent_manager.update_message_ids_async(
        agent_id=self.agent_state.id,
        message_ids=new_in_context_message_ids,
        actor=self.actor,
    )
    # BUG FIX: previously assigned the Message objects themselves; agent state
    # must hold message IDs (mirrors the non-summarization branch above).
    self.agent_state.message_ids = new_in_context_message_ids

    return new_in_context_messages
|
||||
|
||||
@@ -24,6 +24,7 @@ from letta.schemas.response_format import ResponseFormatUnion
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import ToolRule
|
||||
from letta.services.summarizer.summarizer_config import SummarizerConfig
|
||||
from letta.utils import calculate_file_defaults_based_on_context_window, create_random_username
|
||||
|
||||
|
||||
@@ -87,6 +88,9 @@ class AgentState(OrmMetadataBase, validate_assignment=True):
|
||||
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
|
||||
model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings used by the agent.")
|
||||
|
||||
# TODO: add this back
|
||||
# summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.")
|
||||
|
||||
response_format: Optional[ResponseFormatUnion] = Field(
|
||||
None,
|
||||
description="The response format used by the agent",
|
||||
@@ -242,6 +246,9 @@ class CreateAgent(BaseModel, validate_assignment=True): #
|
||||
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
|
||||
model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings for the agent.")
|
||||
|
||||
# TODO: add this back
|
||||
# summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.")
|
||||
|
||||
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
||||
embedding_chunk_size: Optional[int] = Field(
|
||||
DEFAULT_EMBEDDING_CHUNK_SIZE, description="Deprecated: No longer used. The embedding chunk size used by the agent.", deprecated=True
|
||||
@@ -434,6 +441,10 @@ class UpdateAgent(BaseModel):
|
||||
)
|
||||
embedding: Optional[str] = Field(None, description="The embedding model handle used by the agent (format: provider/model-name).")
|
||||
model_settings: Optional[ModelSettingsUnion] = Field(None, description="The model settings for the agent.")
|
||||
|
||||
# TODO: add this back
|
||||
# summarizer_config: Optional[SummarizerConfig] = Field(None, description="The summarizer configuration used by the agent.")
|
||||
|
||||
context_window_limit: Optional[int] = Field(None, description="The context window limit used by the agent.")
|
||||
reasoning: Optional[bool] = Field(
|
||||
None,
|
||||
|
||||
@@ -78,6 +78,33 @@ class AnthropicTokenCounter(TokenCounter):
|
||||
return Message.to_anthropic_dicts_from_list(messages, current_model=self.model)
|
||||
|
||||
|
||||
class ApproxTokenCounter(TokenCounter):
    """Fast approximate token counter based on a bytes/4 heuristic.

    Mirrors the codex-cli approach: assume roughly four UTF-8 bytes per token
    on average for English text. No tokenizer model is ever loaded, so this is
    much cheaper than tiktoken — at the cost of accuracy.

    The input is serialized to JSON and the byte length is divided by four.
    """

    # Assumed average number of UTF-8 bytes per token.
    APPROX_BYTES_PER_TOKEN = 4

    def __init__(self, model: str | None = None):
        # The model name is accepted only for interface parity with other
        # counters; it is never used to pick a tokenizer.
        self.model = model

    def _approx_token_count(self, text: str) -> int:
        """Return ceil(byte_length(text) / APPROX_BYTES_PER_TOKEN)."""
        if not text:
            return 0
        encoded = text.encode("utf-8")
        # Ceiling division via negation, avoiding a math import.
        return -(-len(encoded) // self.APPROX_BYTES_PER_TOKEN)

    def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
        # Any canonical serialization works here since we only measure byte
        # length; the OpenAI dict format is used for consistency.
        return Message.to_openai_dicts_from_list(messages)
|
||||
|
||||
|
||||
class GeminiTokenCounter(TokenCounter):
|
||||
"""Token counter using Google's Gemini token counting API"""
|
||||
|
||||
|
||||
@@ -30,6 +30,7 @@ from letta.utils import safe_create_task
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# NOTE: legacy, new version is functional
|
||||
class Summarizer:
|
||||
"""
|
||||
Handles summarization or trimming of conversation messages based on
|
||||
@@ -407,7 +408,13 @@ def simple_message_wrapper(openai_msg: dict) -> Message:
|
||||
raise ValueError(f"Unknown role: {openai_msg['role']}")
|
||||
|
||||
|
||||
async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor: User, include_ack: bool = True) -> str:
|
||||
async def simple_summary(
|
||||
messages: List[Message],
|
||||
llm_config: LLMConfig,
|
||||
actor: User,
|
||||
include_ack: bool = True,
|
||||
prompt: str | None = None,
|
||||
) -> str:
|
||||
"""Generate a simple summary from a list of messages.
|
||||
|
||||
Intentionally kept functional due to the simplicity of the prompt.
|
||||
@@ -422,7 +429,7 @@ async def simple_summary(messages: List[Message], llm_config: LLMConfig, actor:
|
||||
assert llm_client is not None
|
||||
|
||||
# Prepare the messages payload to send to the LLM
|
||||
system_prompt = gpt_summarize.SYSTEM
|
||||
system_prompt = prompt or gpt_summarize.SYSTEM
|
||||
# Build the initial transcript without clamping to preserve fidelity
|
||||
# TODO proactively clip here?
|
||||
summary_transcript = simple_formatter(messages)
|
||||
|
||||
46
letta/services/summarizer/summarizer_all.py
Normal file
46
letta/services/summarizer/summarizer_all.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.summarizer.summarizer import simple_summary
|
||||
from letta.services.summarizer.summarizer_config import SummarizerConfig
|
||||
from letta.system import package_summarize_message_no_counts
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
async def summarize_all(
    # Required to tag LLM calls
    actor: User,
    # Actual summarization configuration
    summarizer_config: SummarizerConfig,
    in_context_messages: List[Message],
    new_messages: List[Message],
) -> str:
    """
    Summarize the entire conversation history into a single summary.

    Args:
        actor: The user on whose behalf LLM calls are made (for tagging).
        summarizer_config: Controls the summarizer model, prompt, and clipping.
        in_context_messages: Messages already in the context window.
        new_messages: Messages produced by the agent in the current step.

    Returns:
        - The summary string (clipped to ``summarizer_config.clip_chars`` if set)
    """
    all_in_context_messages = in_context_messages + new_messages

    summary_message_str = await simple_summary(
        messages=all_in_context_messages,
        # BUG FIX: SummarizerConfig no longer declares `summarizer_model` (it
        # was replaced by `model_settings`), so the old attribute access raised
        # AttributeError at runtime.
        # NOTE(review): `simple_summary` is annotated to take an LLMConfig —
        # confirm it accepts ModelSettings, or convert before passing.
        llm_config=summarizer_config.model_settings,
        actor=actor,
        include_ack=summarizer_config.prompt_acknowledgement,
        prompt=summarizer_config.prompt,
    )

    # Safety clip: the summary itself must not blow up the context window.
    if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
        logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
        summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"

    return summary_message_str
|
||||
46
letta/services/summarizer/summarizer_config.py
Normal file
46
letta/services/summarizer/summarizer_config.py
Normal file
@@ -0,0 +1,46 @@
|
||||
from typing import Literal
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.model import ModelSettings
|
||||
|
||||
|
||||
class SummarizerConfig(BaseModel):
    """Configuration for conversation-history summarization."""

    # summarizer_model: LLMConfig = Field(default=..., description="The model to use for summarization.")
    # Model settings used for the summarization LLM call.
    model_settings: ModelSettings = Field(default=..., description="The model settings to use for summarization.")
    # System prompt driving the summarization call.
    prompt: str = Field(default=..., description="The prompt to use for summarization.")
    # NOTE(review): declared as `str`, but the description (and the downstream
    # `include_ack` parameter it feeds) reads like a boolean flag — confirm
    # the intended type before tightening it.
    prompt_acknowledgement: str = Field(
        default=..., description="Whether to include an acknowledgement post-prompt (helps prevent non-summary outputs)."
    )
    # None disables clipping entirely.
    clip_chars: int | None = Field(
        default=2000, description="The maximum length of the summary in characters. If none, no clipping is performed."
    )

    # "all" collapses the whole history into one summary; "sliding_window"
    # keeps a recent tail of messages alongside the summary.
    mode: Literal["all", "sliding_window"] = Field(default="sliding_window", description="The type of summarization technique to use.")
    sliding_window_percentage: float = Field(
        default=0.3, description="The percentage of the context window to keep post-summarization (only used in sliding window mode)."
    )
|
||||
|
||||
|
||||
def get_default_summarizer_config(model_settings: ModelSettings) -> SummarizerConfig:
    """Build a default SummarizerConfig from global settings for backward compatibility.

    Args:
        model_settings: The model settings to use for the summarizer model
            (typically derived from the agent's llm_config).

    Returns:
        A SummarizerConfig with default values from global settings.
    """
    # Imported lazily to avoid import cycles with constants/prompts/settings.
    from letta.constants import MESSAGE_SUMMARY_REQUEST_ACK
    from letta.prompts import gpt_summarize
    from letta.settings import summarizer_settings

    return SummarizerConfig(
        mode="sliding_window",
        model_settings=model_settings,
        prompt=gpt_summarize.SYSTEM,
        prompt_acknowledgement=MESSAGE_SUMMARY_REQUEST_ACK,
        clip_chars=2000,
        sliding_window_percentage=summarizer_settings.partial_evict_summarizer_percentage,
    )
|
||||
124
letta/services/summarizer/summarizer_sliding_window.py
Normal file
124
letta/services/summarizer/summarizer_sliding_window.py
Normal file
@@ -0,0 +1,124 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole, ProviderType
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, ApproxTokenCounter
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.summarizer.summarizer import simple_summary
|
||||
from letta.services.summarizer.summarizer_config import SummarizerConfig
|
||||
from letta.settings import model_settings, settings
|
||||
from letta.system import package_summarize_message_no_counts
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# Safety margin for approximate token counting.
|
||||
# The bytes/4 heuristic underestimates by ~25-35% for JSON-serialized messages
|
||||
# due to structural overhead (brackets, quotes, colons) each becoming tokens.
|
||||
APPROX_TOKEN_SAFETY_MARGIN = 1.3
|
||||
|
||||
|
||||
async def count_tokens(actor: User, llm_config: LLMConfig, messages: List[Message]) -> int:
    """Count the tokens in `messages` under the given model configuration.

    Anthropic models get an accurate count from the provider's token-counting
    API; all other models fall back to a fast byte-based approximation with a
    safety margin applied.
    """
    if llm_config.model_endpoint_type == "anthropic":
        # Accurate path: ask Anthropic's counting API directly.
        anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
        counter = AnthropicTokenCounter(anthropic_client, llm_config.model)
        return await counter.count_message_tokens(counter.convert_messages(messages))

    # Approximate path (bytes / 4): much faster than tiktoken and requires no
    # tokenizer model in memory. Inflate the estimate by a safety margin so we
    # never underestimate and keep too many messages.
    counter = ApproxTokenCounter(llm_config.model)
    raw_estimate = await counter.count_message_tokens(counter.convert_messages(messages))
    return int(raw_estimate * APPROX_TOKEN_SAFETY_MARGIN)
|
||||
|
||||
|
||||
async def summarize_via_sliding_window(
    # Required to tag LLM calls
    actor: User,
    # Actual summarization configuration
    llm_config: LLMConfig,
    summarizer_config: SummarizerConfig,
    in_context_messages: List[Message],
    new_messages: List[Message],
) -> Tuple[str, List[Message]]:
    """
    Summarize the oldest portion of the conversation and keep a recent tail.

    Finding the summarization cutoff point (target of final post-summarize count is N% of configured context window):
    1. Start at a message index cutoff (1-N%)
    2. Count tokens with system prompt, prior summary (if it exists), and messages past cutoff point (messages[0] + messages[cutoff:])
    3. Is count(post_sum_messages) <= N% of configured context window?
        3a. Yes -> create new summary with [prior summary, cutoff:], and safety truncate summary with char count
        3b. No -> increment cutoff by 10%, and repeat

    Returns:
        - The summary string
        - The list of messages to keep in-context (summary not included)

    Raises:
        ValueError: If no assistant message exists at or after the cutoff.
    """
    system_prompt = in_context_messages[0]
    all_in_context_messages = in_context_messages + new_messages
    total_message_count = len(all_in_context_messages)

    # Starts at (1 - N), e.g. 0.7 when keeping 30%, and increments up to 1.0.
    # BUG FIX: the fraction must stay in (0, 1]; the previous floor of 10 (and
    # the += 10 / >= 100 steps below) treated the value as a whole-number
    # percentage, pushing the cutoff index far past the end of the list.
    # Floor at 0.1 to avoid a zero/negative start from a badly configured
    # sliding_window_percentage.
    message_count_cutoff_percent = max(1 - summarizer_config.sliding_window_percentage, 0.1)
    found_cutoff = False

    # Count tokens with system prompt, and messages past the cutoff point
    while not found_cutoff:
        # Mark the approximate cutoff
        message_cutoff_index = round(message_count_cutoff_percent * total_message_count)

        # Walk forward until the first assistant message at/after the cutoff;
        # the kept tail must start on an assistant turn.
        for i in range(message_cutoff_index, total_message_count):
            if all_in_context_messages[i].role == MessageRole.assistant:
                assistant_message_index = i
                break
        else:
            raise ValueError(f"No assistant message found from indices {message_cutoff_index} to {total_message_count}")

        # Count tokens of the hypothetical post-summarization buffer
        post_summarization_buffer = [system_prompt] + all_in_context_messages[assistant_message_index:]
        post_summarization_buffer_tokens = await count_tokens(actor, llm_config, post_summarization_buffer)

        # Is the hypothetical post-summarization count within the target budget?
        if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
            found_cutoff = True
        else:
            # Evict another 10% of the message list and retry.
            message_count_cutoff_percent += 0.1
            if message_count_cutoff_percent >= 1.0:
                # Degenerate case: evict everything but the system prompt —
                # equivalent to complete summarization (empty kept tail).
                assistant_message_index = total_message_count
                break

    # Summarize everything between the system prompt and the kept tail.
    # BUG FIX: slice to assistant_message_index rather than the raw cutoff so
    # that no messages between the cutoff and the first kept assistant turn
    # are silently dropped (neither summarized nor kept).
    messages_to_summarize = all_in_context_messages[1:assistant_message_index]

    summary_message_str = await simple_summary(
        messages=messages_to_summarize,
        # BUG FIX: SummarizerConfig declares `model_settings`, not
        # `summarizer_model`; the old attribute access raised AttributeError.
        # NOTE(review): `simple_summary` is annotated to take an LLMConfig —
        # confirm it accepts ModelSettings, or convert before passing.
        llm_config=summarizer_config.model_settings,
        actor=actor,
        include_ack=summarizer_config.prompt_acknowledgement,
        prompt=summarizer_config.prompt,
    )

    # Safety clip: the summary itself must not blow up the context window.
    if summarizer_config.clip_chars is not None and len(summary_message_str) > summarizer_config.clip_chars:
        logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
        summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"

    updated_in_context_messages = all_in_context_messages[assistant_message_index:]
    return summary_message_str, updated_in_context_messages
|
||||
@@ -205,6 +205,9 @@ def test_send_user_message_with_pending_request(client, agent):
|
||||
agent_id=agent.id,
|
||||
messages=USER_MESSAGE_TEST_APPROVAL,
|
||||
)
|
||||
print("RESPONSE", response)
|
||||
for message in response.messages:
|
||||
print("MESSAGE", message)
|
||||
|
||||
with pytest.raises(APIError, match="Please approve or deny the pending request before continuing"):
|
||||
client.agents.messages.create(
|
||||
|
||||
@@ -602,3 +602,64 @@ async def test_summarize_truncates_large_tool_return(server: SyncServer, actor,
|
||||
# (they may have been completely removed during aggressive summarization)
|
||||
if not tool_returns_found:
|
||||
print("Tool returns were completely removed during summarization")
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# SummarizerConfig Mode Tests (with pytest.patch)
|
||||
# ======================================================================================================================
|
||||
|
||||
from letta.services.summarizer.enums import SummarizationMode
|
||||
|
||||
SUMMARIZATION_MODES = [
|
||||
SummarizationMode.STATIC_MESSAGE_BUFFER,
|
||||
SummarizationMode.PARTIAL_EVICT_MESSAGE_BUFFER,
|
||||
]
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize("mode", SUMMARIZATION_MODES, ids=[m.value for m in SUMMARIZATION_MODES])
@pytest.mark.parametrize("llm_config", TESTED_LLM_CONFIGS, ids=[c.model for c in TESTED_LLM_CONFIGS])
async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMConfig, mode: SummarizationMode):
    """
    Test summarization with different modes and LLM configurations.
    """
    from unittest.mock import patch

    # Create a conversation with enough messages to trigger summarization
    # (10 user/assistant pairs = 20 messages total).
    messages = []
    for i in range(10):
        messages.append(
            PydanticMessage(
                role=MessageRole.user,
                content=[TextContent(type="text", text=f"User message {i}: Test message {i}.")],
            )
        )
        messages.append(
            PydanticMessage(
                role=MessageRole.assistant,
                content=[TextContent(type="text", text=f"Assistant response {i}: Acknowledged message {i}.")],
            )
        )

    agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)

    # Patch the module-level summarizer settings seen by the agent so this
    # test controls the mode and buffer sizes.
    # NOTE(review): assumes LettaAgentV2 reads `summarizer_settings` at
    # construction time (the mode assertion below relies on that) — confirm.
    with patch("letta.agents.letta_agent_v2.summarizer_settings") as mock_settings:
        mock_settings.mode = mode
        mock_settings.message_buffer_limit = 10
        mock_settings.message_buffer_min = 3
        mock_settings.partial_evict_summarizer_percentage = 0.30
        mock_settings.max_summarizer_retries = 3

        agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor)
        # Sanity check: the agent picked up the patched mode.
        assert agent_loop.summarizer.mode == mode

        # force=True bypasses token counting so summarization always runs.
        result = await agent_loop.summarize_conversation_history(
            in_context_messages=in_context_messages,
            new_letta_messages=[],
            total_tokens=None,
            force=True,
        )

    # The result must be a non-empty in-context message list.
    assert isinstance(result, list)
    assert len(result) >= 1
    print(f"{mode.value} with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
|
||||
|
||||
Reference in New Issue
Block a user