feat: refactor summarization and message persistence code [LET-6464] (#6561)

This commit is contained in:
Sarah Wooders
2025-12-09 16:34:06 -08:00
committed by Caren Thomas
parent b23722e4a1
commit bbd52e291c
10 changed files with 493 additions and 434 deletions

View File

@@ -20,7 +20,7 @@ from letta.agents.helpers import (
)
from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM, SUMMARIZATION_TRIGGER_MULTIPLIER
from letta.errors import ContextWindowExceededError, LLMError
from letta.errors import ContextWindowExceededError, LLMError, SystemPromptTokenExceededError
from letta.helpers import ToolRulesSolver
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
from letta.helpers.message_helper import convert_message_creates_to_messages
@@ -78,6 +78,7 @@ class LettaAgentV3(LettaAgentV2):
# from per-step usage but can be updated after summarization without
# affecting step-level telemetry.
self.context_token_estimate: int | None = None
self.in_context_messages: list[Message] = [] # in-memory tracker
def _compute_tool_return_truncation_chars(self) -> int:
"""Compute a dynamic cap for tool returns in requests.
@@ -119,7 +120,7 @@ class LettaAgentV3(LettaAgentV2):
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
response_letta_messages = []
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
curr_in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
input_messages, self.agent_state, self.message_manager, self.actor, run_id
)
follow_up_messages = []
@@ -127,13 +128,15 @@ class LettaAgentV3(LettaAgentV2):
follow_up_messages = input_messages_to_persist[1:]
input_messages_to_persist = [input_messages_to_persist[0]]
in_context_messages = in_context_messages + input_messages_to_persist
self.in_context_messages = curr_in_context_messages
for i in range(max_steps):
if i == 1 and follow_up_messages:
input_messages_to_persist = follow_up_messages
follow_up_messages = []
response = self._step(
messages=in_context_messages + self.response_messages,
# we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back)
messages=list(self.in_context_messages + input_messages_to_persist),
input_messages_to_persist=input_messages_to_persist,
# TODO need to support non-streaming adapter too
llm_adapter=SimpleLLMRequestAdapter(llm_client=self.llm_client, llm_config=self.agent_state.llm_config),
@@ -142,6 +145,7 @@ class LettaAgentV3(LettaAgentV2):
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
)
input_messages_to_persist = [] # clear after first step
async for chunk in response:
response_letta_messages.append(chunk)
@@ -150,53 +154,65 @@ class LettaAgentV3(LettaAgentV2):
if not self.should_continue and self.stop_reason.stop_reason == StopReasonType.cancelled.value:
break
# Proactive summarization if approaching context limit
if (
self.context_token_estimate is not None
and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
and not self.agent_state.message_buffer_autoclear
):
self.logger.warning(
f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching "
f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization."
)
# TODO: persist the input messages if successful first step completion
# TODO: persist the new messages / step / run
in_context_messages = await self.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=self.response_messages,
total_tokens=self.context_token_estimate,
force=True,
)
## Proactive summarization if approaching context limit
# if (
# self.context_token_estimate is not None
# and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
# and not self.agent_state.message_buffer_autoclear
# ):
# self.logger.warning(
# f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching "
# f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization."
# )
# Clear to avoid duplication in next iteration
self.response_messages = []
# in_context_messages = await self.summarize_conversation_history(
# in_context_messages=in_context_messages,
# new_letta_messages=self.response_messages,
# total_tokens=self.context_token_estimate,
# force=True,
# )
# # Clear to avoid duplication in next iteration
# self.response_messages = []
if not self.should_continue:
break
input_messages_to_persist = []
# input_messages_to_persist = []
if i == max_steps - 1 and self.stop_reason is None:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
# Rebuild context window after stepping (safety net)
if not self.agent_state.message_buffer_autoclear:
if self.context_token_estimate is not None:
await self.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=self.response_messages,
total_tokens=self.context_token_estimate,
force=False,
)
else:
self.logger.warning(
"Post-loop summarization skipped: last_step_usage is None. "
"No step completed successfully or usage stats were not updated."
)
## Rebuild context window after stepping (safety net)
# if not self.agent_state.message_buffer_autoclear:
# if self.context_token_estimate is not None:
# await self.summarize_conversation_history(
# in_context_messages=in_context_messages,
# new_letta_messages=self.response_messages,
# total_tokens=self.context_token_estimate,
# force=False,
# )
# else:
# self.logger.warning(
# "Post-loop summarization skipped: last_step_usage is None. "
# "No step completed successfully or usage stats were not updated."
# )
if self.stop_reason is None:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
# construct the response
response_letta_messages = Message.to_letta_messages_from_list(
self.response_messages,
use_assistant_message=False, # NOTE: set to false
reverse=False,
text_is_assistant_message=True,
)
if include_return_message_types:
response_letta_messages = [m for m in response_letta_messages if m.message_type in include_return_message_types]
result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
if run_id:
if self.job_update_metadata is None:
@@ -265,13 +281,14 @@ class LettaAgentV3(LettaAgentV2):
follow_up_messages = input_messages_to_persist[1:]
input_messages_to_persist = [input_messages_to_persist[0]]
in_context_messages = in_context_messages + input_messages_to_persist
self.in_context_messages = in_context_messages
for i in range(max_steps):
if i == 1 and follow_up_messages:
input_messages_to_persist = follow_up_messages
follow_up_messages = []
response = self._step(
messages=in_context_messages + self.response_messages,
# we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back)
messages=list(self.in_context_messages + input_messages_to_persist),
input_messages_to_persist=input_messages_to_persist,
llm_adapter=llm_adapter,
run_id=run_id,
@@ -279,6 +296,7 @@ class LettaAgentV3(LettaAgentV2):
include_return_message_types=include_return_message_types,
request_start_timestamp_ns=request_start_timestamp_ns,
)
input_messages_to_persist = [] # clear after first step
async for chunk in response:
response_letta_messages.append(chunk)
if first_chunk:
@@ -290,49 +308,29 @@ class LettaAgentV3(LettaAgentV2):
if not self.should_continue and self.stop_reason.stop_reason == StopReasonType.cancelled.value:
break
# Proactive summarization if approaching context limit
if (
self.context_token_estimate is not None
and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
and not self.agent_state.message_buffer_autoclear
):
self.logger.warning(
f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching "
f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization."
)
in_context_messages = await self.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=self.response_messages,
total_tokens=self.context_token_estimate,
force=True,
)
# Clear to avoid duplication in next iteration
self.response_messages = []
# refresh in-context messages (TODO: remove?)
# in_context_messages = await self._refresh_messages(in_context_messages)
if not self.should_continue:
break
input_messages_to_persist = []
if i == max_steps - 1 and self.stop_reason is None:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
# Rebuild context window after stepping (safety net)
if not self.agent_state.message_buffer_autoclear:
if self.context_token_estimate is not None:
await self.summarize_conversation_history(
in_context_messages=in_context_messages,
new_letta_messages=self.response_messages,
total_tokens=self.context_token_estimate,
force=False,
)
else:
self.logger.warning(
"Post-loop summarization skipped: last_step_usage is None. "
"No step completed successfully or usage stats were not updated."
)
## Rebuild context window after stepping (safety net)
# if not self.agent_state.message_buffer_autoclear:
# if self.context_token_estimate is not None:
# await self.summarize_conversation_history(
# in_context_messages=in_context_messages,
# new_letta_messages=self.response_messages,
# total_tokens=self.context_token_estimate,
# force=False,
# )
# else:
# self.logger.warning(
# "Post-loop summarization skipped: last_step_usage is None. "
# "No step completed successfully or usage stats were not updated."
# )
if self.stop_reason is None:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
@@ -400,10 +398,66 @@ class LettaAgentV3(LettaAgentV2):
yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
# Note: we don't send finish chunks here since we already errored
async def _check_for_system_prompt_overflow(self, system_message):
"""
Since the system prompt cannot be compacted, we need to check to see if it is the cause of the context overflow
"""
system_prompt_token_estimate = await count_tokens(
actor=self.actor,
llm_config=self.agent_state.llm_config,
messages=[system_message],
)
if system_prompt_token_estimate is not None and system_prompt_token_estimate >= self.agent_state.llm_config.context_window:
self.should_continue = False
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.context_window_overflow_in_system_prompt.value)
raise SystemPromptTokenExceededError(
system_prompt_token_estimate=system_prompt_token_estimate,
context_window=self.agent_state.llm_config.context_window,
)
async def _checkpoint_messages(self, run_id: str, step_id: str, new_messages: list[Message], in_context_messages: list[Message]):
"""
Checkpoint the current message state - run this only when the current messages are 'safe' - meaning the step has completed successfully.
This handles:
- Persisting the new messages into the `messages` table
- Updating the in-memory trackers for in-context messages (`self.in_context_messages`) and agent state (`self.agent_state.message_ids`)
- Updating the DB with the current in-context messages (`self.agent_state.message_ids`)
Args:
run_id: The run ID to associate with the messages
step_id: The step ID to associate with the messages
new_messages: The new messages to persist
in_context_messages: The current in-context messages
"""
# make sure all the new messages have the correct run_id and step_id
for message in new_messages:
message.step_id = step_id
message.run_id = run_id
# persist the new message objects - ONLY place where messages are persisted
persisted_messages = await self.message_manager.create_many_messages_async(
new_messages,
actor=self.actor,
run_id=run_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
# persist the in-context messages
# TODO: somehow make sure all the message ids are already persisted
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_state.id,
message_ids=[m.id for m in in_context_messages],
actor=self.actor,
)
self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state
self.in_context_messages = in_context_messages # update in-memory state
@trace_method
async def _step(
self,
messages: list[Message],
messages: list[Message], # current in-context messages
llm_adapter: LettaLLMAdapter,
input_messages_to_persist: list[Message] | None = None,
run_id: str | None = None,
@@ -437,6 +491,8 @@ class LettaAgentV3(LettaAgentV2):
if enforce_run_id_set and run_id is None:
raise AssertionError("run_id is required when enforce_run_id_set is True")
input_messages_to_persist = input_messages_to_persist or []
step_progression = StepProgression.START
# TODO(@caren): clean this up
tool_calls, content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
@@ -464,13 +520,17 @@ class LettaAgentV3(LettaAgentV2):
# Always refresh messages at the start of each step to pick up external inputs
# (e.g., approval responses submitted by the client while this stream is running)
try:
# TODO: cleanup and de-dup
# updates the system prompt with the latest blocks / message histories
messages = await self._refresh_messages(messages)
except Exception as e:
self.logger.warning(f"Failed to refresh messages at step start: {e}")
approval_request, approval_response = _maybe_get_approval_messages(messages)
tool_call_denials, tool_returns = [], []
if approval_request and approval_response:
# case of handling approval responses
content = approval_request.content
# Get tool calls that are pending
@@ -541,6 +601,7 @@ class LettaAgentV3(LettaAgentV2):
tool_return_truncation_chars=self._compute_tool_return_truncation_chars(),
)
# TODO: Extend to more providers, and also approval tool rules
# TODO: this entire code block should be inside of the clients
# Enable parallel tool use when no tool rules are attached
try:
no_tool_rules = (
@@ -612,11 +673,25 @@ class LettaAgentV3(LettaAgentV2):
except Exception as e:
if isinstance(e, ContextWindowExceededError) and llm_request_attempt < summarizer_settings.max_summarizer_retries:
# Retry case
messages = await self.summarize_conversation_history(
in_context_messages=messages,
new_letta_messages=self.response_messages,
force=True,
summary_message, messages = await self.compact(
messages, trigger_threshold=self.agent_state.llm_config.context_window
)
# checkpoint summarized messages
# TODO: might want to delay this checkpoint in case of corrupated state
try:
await self._checkpoint_messages(
run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
)
except SystemPromptTokenExceededError:
self.stop_reason = LettaStopReason(
stop_reason=StopReasonType.context_window_overflow_in_system_prompt.value
)
raise e
except Exception as e:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
self.logger.error(f"Unknown error occured for summarization run {run_id}: {e}")
raise e
else:
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
self.logger.error(f"Unknown error occured for run {run_id}: {e}")
@@ -637,8 +712,8 @@ class LettaAgentV3(LettaAgentV2):
else:
tool_calls = []
aggregated_persisted: list[Message] = []
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
# get the new generated `Message` objects from handling the LLM response
new_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
tool_calls=tool_calls,
valid_tool_names=[tool["name"] for tool in valid_tools],
tool_rules_solver=self.tool_rules_solver,
@@ -650,7 +725,7 @@ class LettaAgentV3(LettaAgentV2):
content=content or llm_adapter.content,
pre_computed_assistant_message_id=llm_adapter.message_id,
step_id=step_id,
initial_messages=input_messages_to_persist,
initial_messages=[], # input_messages_to_persist, # TODO: deprecate - super confusing
agent_step_span=agent_step_span,
is_final_step=(remaining_turns == 0),
run_id=run_id,
@@ -659,16 +734,26 @@ class LettaAgentV3(LettaAgentV2):
tool_call_denials=tool_call_denials,
tool_returns=tool_returns,
)
aggregated_persisted.extend(persisted_messages)
# NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op")
new_message_idx = len(input_messages_to_persist) if input_messages_to_persist else 0
self.response_messages.extend(aggregated_persisted[new_message_idx:])
# extend trackers with new messages
self.response_messages.extend(new_messages)
messages.extend(new_messages)
# step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics
# persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
await self._checkpoint_messages(
run_id=run_id,
step_id=step_id,
new_messages=input_messages_to_persist + new_messages,
in_context_messages=messages, # update the in-context messages
)
# yield back generated messages
if llm_adapter.supports_token_streaming():
if tool_calls:
# Stream each tool return if tools were executed
response_tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"]
response_tool_returns = [msg for msg in new_messages if msg.role == "tool"]
for tr in response_tool_returns:
# Skip streaming for aggregated parallel tool returns (no per-call tool_call_id)
if tr.tool_call_id is None and tr.tool_returns:
@@ -677,7 +762,8 @@ class LettaAgentV3(LettaAgentV2):
if include_return_message_types is None or tool_return_letta.message_type in include_return_message_types:
yield tool_return_letta
else:
filter_user_messages = [m for m in aggregated_persisted[new_message_idx:] if m.role != "user"]
# TODO: modify this use step_response_messages
filter_user_messages = [m for m in new_messages if m.role != "user"]
letta_messages = Message.to_letta_messages_from_list(
filter_user_messages,
use_assistant_message=False, # NOTE: set to false
@@ -689,11 +775,20 @@ class LettaAgentV3(LettaAgentV2):
if include_return_message_types is None or message.message_type in include_return_message_types:
yield message
# Note: message_ids update for approval responses now happens immediately after
# persistence in _handle_ai_response (line ~1093-1107) to prevent desync when
# the stream is interrupted and this generator is abandoned before being fully consumed
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
# check compaction
if self.context_token_estimate > self.agent_state.llm_config.context_window:
summary_message, messages = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
# TODO: persist + return the summary message
# TODO: convert this to a SummaryMessage
self.response_messages.append(summary_message)
for message in Message.to_letta_messages(summary_message):
yield message
await self._checkpoint_messages(
run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
)
except Exception as e:
# NOTE: message persistence does not happen in the case of an exception (rollback to previous state)
self.logger.warning(f"Error during step processing: {e}")
self.job_update_metadata = {"error": str(e)}
@@ -707,20 +802,14 @@ class LettaAgentV3(LettaAgentV2):
StopReasonType.invalid_tool_call,
StopReasonType.invalid_llm_response,
StopReasonType.llm_api_error,
StopReasonType.context_window_overflow_in_system_prompt,
):
self.logger.warning("Error occurred during step processing, with unexpected stop reason: %s", self.stop_reason.stop_reason)
raise e
finally:
# always make sure we update the step/run metadata
self.logger.debug("Running cleanup for agent loop run: %s", run_id)
self.logger.info("Running final update. Step Progression: %s", step_progression)
# update message ids
message_ids = [m.id for m in messages]
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_state.id,
message_ids=message_ids,
actor=self.actor,
)
try:
if step_progression == StepProgression.FINISHED:
if not self.should_continue:
@@ -728,7 +817,9 @@ class LettaAgentV3(LettaAgentV2):
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
if logged_step and step_id:
await self.step_manager.update_step_stop_reason(self.actor, step_id, self.stop_reason.stop_reason)
return
if not self.stop_reason.stop_reason == StopReasonType.context_window_overflow_in_system_prompt:
# only return if the stop reason is not context window overflow in system prompt
return
if step_progression < StepProgression.STEP_LOGGED:
# Error occurred before step was fully logged
import traceback
@@ -742,19 +833,6 @@ class LettaAgentV3(LettaAgentV2):
error_traceback=traceback.format_exc(),
stop_reason=self.stop_reason,
)
if step_progression <= StepProgression.STREAM_RECEIVED:
if first_chunk and settings.track_errored_messages and input_messages_to_persist:
for message in input_messages_to_persist:
message.is_err = True
message.step_id = step_id
message.run_id = run_id
await self.message_manager.create_many_messages_async(
input_messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
if self.stop_reason is None:
self.logger.warning("Error in step after logging step")
@@ -806,6 +884,7 @@ class LettaAgentV3(LettaAgentV2):
Unified approach: treats single and multi-tool calls uniformly to reduce code duplication.
"""
# 1. Handle no-tool cases (content-only or no-op)
if not tool_calls and not tool_call_denials and not tool_returns:
# Case 1a: No tool call, no content (LLM no-op)
@@ -863,22 +942,7 @@ class LettaAgentV3(LettaAgentV2):
add_heartbeat_on_continue=bool(heartbeat_reason),
)
messages_to_persist = (initial_messages or []) + assistant_message
# Persist messages for no-tool cases
for message in messages_to_persist:
if message.run_id is None:
message.run_id = run_id
if message.step_id is None:
message.step_id = step_id
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
return persisted_messages, continue_stepping, stop_reason
return messages_to_persist, continue_stepping, stop_reason
# 2. Check whether tool call requires approval
if not is_approval_response:
@@ -896,21 +960,7 @@ class LettaAgentV3(LettaAgentV2):
run_id=run_id,
)
messages_to_persist = (initial_messages or []) + approval_messages
for message in messages_to_persist:
if message.run_id is None:
message.run_id = run_id
if message.step_id is None:
message.step_id = step_id
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
return messages_to_persist, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
result_tool_returns = []
@@ -1148,31 +1198,6 @@ class LettaAgentV3(LettaAgentV2):
if message.step_id is None:
message.step_id = step_id
# Persist all messages
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist,
actor=self.actor,
run_id=run_id,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
)
# Update message_ids immediately after persistence to prevent desync
# This handles approval responses where we need to keep message_ids in sync
if (
is_approval_response
and initial_messages
and len(initial_messages) == 1
and initial_messages[0].role == "approval"
and len(persisted_messages) >= 2
and persisted_messages[0].role == "approval"
and persisted_messages[1].role == "tool"
):
self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]]
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor
)
# 5g. Aggregate continuation decisions
aggregate_continue = any(persisted_continue_flags) if persisted_continue_flags else False
aggregate_continue = aggregate_continue or tool_call_denials or tool_returns
@@ -1193,7 +1218,7 @@ class LettaAgentV3(LettaAgentV2):
# Force continuation for parallel tool execution
aggregate_continue = True
aggregate_stop_reason = None
return persisted_messages, aggregate_continue, aggregate_stop_reason
return messages_to_persist, aggregate_continue, aggregate_stop_reason
@trace_method
def _decide_continuation(
@@ -1282,178 +1307,118 @@ class LettaAgentV3(LettaAgentV2):
return allowed_tools
@trace_method
async def summarize_conversation_history(
self,
# The messages already in the context window
in_context_messages: list[Message],
# The messages produced by the agent in this step
new_letta_messages: list[Message],
# The token usage from the most recent LLM call (prompt + completion)
total_tokens: int | None = None,
# If force, then don't do any counting, just summarize
force: bool = False,
) -> list[Message]:
trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window)
# no summarization if the last message is an approval request
latest_messages = in_context_messages + new_letta_messages
pending_approval = latest_messages[-1].role == "approval" and len(latest_messages[-1].tool_calls) > 0
if pending_approval:
trigger_summarization = False
self.logger.info(
f"trigger_summarization: {trigger_summarization}, total_tokens: {total_tokens}, context_window: {self.agent_state.llm_config.context_window}, pending_approval: {pending_approval}"
)
if not trigger_summarization:
# just update the message_ids
# TODO: gross to handle this here: we should move persistence elsewhere
new_in_context_messages = in_context_messages + new_letta_messages
message_ids = [m.id for m in new_in_context_messages]
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_state.id,
message_ids=message_ids,
actor=self.actor,
)
self.agent_state.message_ids = message_ids
return new_in_context_messages
async def compact(self, messages, trigger_threshold: Optional[int] = None) -> Message:
"""
Simplified compaction method. Does NOT do any persistence (handled in the loop)
"""
# compact the current in-context messages (self.in_context_messages)
# Use agent's summarizer_config if set, otherwise fall back to defaults
# TODO: add this back
# summarizer_config = self.agent_state.summarizer_config or get_default_summarizer_config(self.agent_state.llm_config)
summarizer_config = get_default_summarizer_config(self.agent_state.llm_config._to_model_settings())
summarization_mode_used = summarizer_config.mode
if summarizer_config.mode == "all":
summary_message_str, new_in_context_messages = await summarize_all(
summary, compacted_messages = await summarize_all(
actor=self.actor,
llm_config=self.agent_state.llm_config,
summarizer_config=summarizer_config,
in_context_messages=in_context_messages,
new_messages=new_letta_messages,
in_context_messages=messages,
)
elif summarizer_config.mode == "sliding_window":
try:
summary_message_str, new_in_context_messages = await summarize_via_sliding_window(
summary, compacted_messages = await summarize_via_sliding_window(
actor=self.actor,
llm_config=self.agent_state.llm_config,
summarizer_config=summarizer_config,
in_context_messages=in_context_messages,
new_messages=new_letta_messages,
in_context_messages=messages,
)
except Exception as e:
self.logger.error(f"Sliding window summarization failed with exception: {str(e)}. Falling back to all mode.")
summary_message_str, new_in_context_messages = await summarize_all(
summary, compacted_messages = await summarize_all(
actor=self.actor,
llm_config=self.agent_state.llm_config,
summarizer_config=summarizer_config,
in_context_messages=in_context_messages,
new_messages=new_letta_messages,
in_context_messages=messages,
)
summarization_mode_used = "all"
else:
raise ValueError(f"Invalid summarizer mode: {summarizer_config.mode}")
# Persist the summary message to DB
summary_message_str_packed = package_summarize_message_no_counts(
summary=summary_message_str,
timezone=self.agent_state.timezone,
)
summary_message_obj = (
await convert_message_creates_to_messages(
message_creates=[
MessageCreate(
role=MessageRole.user,
content=[TextContent(text=summary_message_str_packed)],
)
],
agent_id=self.agent_state.id,
timezone=self.agent_state.timezone,
# We already packed, don't pack again
wrap_user_message=False,
wrap_system_message=False,
run_id=None, # TODO: add this
)
)[0]
await self.message_manager.create_many_messages_async(
pydantic_msgs=[summary_message_obj],
actor=self.actor,
project_id=self.agent_state.project_id,
template_id=self.agent_state.template_id,
# update the token count
self.context_token_estimate = await count_tokens(
actor=self.actor, llm_config=self.agent_state.llm_config, messages=compacted_messages
)
self.logger.info(f"Context token estimate after summarization: {self.context_token_estimate}")
# Update the message_ids in the agent state to include the summary
# plus whatever tail we decided to keep.
new_in_context_messages = [in_context_messages[0], summary_message_obj] + new_in_context_messages
new_in_context_message_ids = [m.id for m in new_in_context_messages]
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_state.id,
message_ids=new_in_context_message_ids,
actor=self.actor,
)
self.agent_state.message_ids = new_in_context_message_ids
# After summarization, recompute an approximate token count for the
# updated in-context messages so that subsequent summarization
# decisions don't keep firing based on a stale, pre-summarization
# total_tokens value.
try:
new_total_tokens = await count_tokens(
actor=self.actor,
llm_config=self.agent_state.llm_config,
messages=new_in_context_messages,
)
context_limit = self.agent_state.llm_config.context_window
trigger_threshold = int(context_limit * SUMMARIZATION_TRIGGER_MULTIPLIER)
# if the trigger_threshold is provided, we need to make sure that the new token count is below it
if trigger_threshold is not None and self.context_token_estimate >= trigger_threshold:
# If even after summarization the context is still at or above
# the proactive summarization threshold, treat this as a hard
# failure: log loudly and evict all prior conversation state
# (keeping only the system message) to avoid getting stuck in
# repeated summarization loops.
if new_total_tokens > trigger_threshold:
self.logger.error(
"Summarization failed to sufficiently reduce context size: "
f"post-summarization tokens={new_total_tokens}, "
f"threshold={trigger_threshold}, context_window={context_limit}. "
"Evicting all prior messages without a summary to break potential loops.",
)
# Keep only the system message in-context.
system_message = in_context_messages[0]
new_in_context_messages = [system_message]
new_in_context_message_ids = [system_message.id]
await self.agent_manager.update_message_ids_async(
agent_id=self.agent_state.id,
message_ids=new_in_context_message_ids,
actor=self.actor,
)
self.agent_state.message_ids = new_in_context_message_ids
# Recompute token usage for this minimal context and update
# context_token_estimate so future checks see the reduced size.
try:
minimal_tokens = await count_tokens(
actor=self.actor,
llm_config=self.agent_state.llm_config,
messages=new_in_context_messages,
)
self.context_token_estimate = minimal_tokens
except Exception as inner_e:
self.logger.warning(
f"Failed to recompute token usage after hard eviction: {inner_e}",
exc_info=True,
)
return new_in_context_messages
# Normal case: summarization succeeded in bringing us below the
# proactive threshold. Update context_token_estimate so future
# summarization checks reason over the *post*-summarization
# context size.
self.context_token_estimate = new_total_tokens
except Exception as e: # best-effort; never block the agent on this
self.logger.warning(
f"Failed to recompute token usage after summarization: {e}",
exc_info=True,
self.logger.error(
"Summarization failed to sufficiently reduce context size: "
f"post-summarization tokens={self.context_token_estimate}, "
f"threshold={trigger_threshold}, context_window={self.context_token_estimate}. "
"Evicting all prior messages without a summary to break potential loops.",
)
return new_in_context_messages
# if we used the sliding window mode, try to summarize again with the all mode
if summarization_mode_used == "sliding_window":
# try to summarize again with the all mode
summary, compacted_messages = await summarize_all(
actor=self.actor,
llm_config=self.agent_state.llm_config,
summarizer_config=summarizer_config,
in_context_messages=compacted_messages,
)
summarization_mode_used = "all"
self.context_token_estimate = await count_tokens(
actor=self.actor, llm_config=self.agent_state.llm_config, messages=compacted_messages
)
# final edge case: the system prompt is the cause of the context overflow (raise error)
if self.context_token_estimate >= trigger_threshold:
await self._check_for_system_prompt_overflow(compacted_messages[0])
# raise an error if this is STILL not the problem
# do not throw an error, since we don't want to brick the agent
self.logger.error(
f"Failed to summarize messages after hard eviction and checking the system prompt token estimate: {self.context_token_estimate} > {trigger_threshold}"
)
else:
self.logger.info(
f"Summarization fallback succeeded in bringing the context size below the trigger threshold: {self.context_token_estimate} < {trigger_threshold}"
)
# Persist the summary message to DB
summary_message_str_packed = package_summarize_message_no_counts(
summary=summary,
timezone=self.agent_state.timezone,
)
summary_messages = await convert_message_creates_to_messages(
message_creates=[
MessageCreate(
role=MessageRole.user,
content=[TextContent(text=summary_message_str_packed)],
)
],
agent_id=self.agent_state.id,
timezone=self.agent_state.timezone,
# We already packed, don't pack again
wrap_user_message=False,
wrap_system_message=False,
run_id=None, # TODO: add this
)
if not len(summary_messages) == 1:
self.logger.error(f"Expected only one summary message, got {len(summary_messages)} in {summary_messages}")
summary_message_obj = summary_messages[0]
# final messages: inject summarization message at the beginning
final_messages = [compacted_messages[0]] + [summary_message_obj]
if len(compacted_messages) > 1:
final_messages += compacted_messages[1:]
return summary_message_obj, final_messages