feat: refactor summarization and message persistence code [LET-6464] (#6561)
This commit is contained in:
committed by
Caren Thomas
parent
b23722e4a1
commit
bbd52e291c
@@ -35585,7 +35585,8 @@
|
||||
"no_tool_call",
|
||||
"tool_rule",
|
||||
"cancelled",
|
||||
"requires_approval"
|
||||
"requires_approval",
|
||||
"context_window_overflow_in_system_prompt"
|
||||
],
|
||||
"title": "StopReasonType"
|
||||
},
|
||||
|
||||
@@ -172,7 +172,6 @@ async def _prepare_in_context_messages_no_persist_async(
|
||||
new_in_context_messages.extend(follow_up_messages)
|
||||
else:
|
||||
# User is trying to send a regular message
|
||||
# if current_in_context_messages and current_in_context_messages[-1].role == "approval":
|
||||
if current_in_context_messages and current_in_context_messages[-1].is_approval_request():
|
||||
raise PendingApprovalError(pending_request_id=current_in_context_messages[-1].id)
|
||||
|
||||
|
||||
@@ -20,7 +20,7 @@ from letta.agents.helpers import (
|
||||
)
|
||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM, SUMMARIZATION_TRIGGER_MULTIPLIER
|
||||
from letta.errors import ContextWindowExceededError, LLMError
|
||||
from letta.errors import ContextWindowExceededError, LLMError, SystemPromptTokenExceededError
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
@@ -78,6 +78,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# from per-step usage but can be updated after summarization without
|
||||
# affecting step-level telemetry.
|
||||
self.context_token_estimate: int | None = None
|
||||
self.in_context_messages: list[Message] = [] # in-memory tracker
|
||||
|
||||
def _compute_tool_return_truncation_chars(self) -> int:
|
||||
"""Compute a dynamic cap for tool returns in requests.
|
||||
@@ -119,7 +120,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns)
|
||||
response_letta_messages = []
|
||||
|
||||
in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
curr_in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async(
|
||||
input_messages, self.agent_state, self.message_manager, self.actor, run_id
|
||||
)
|
||||
follow_up_messages = []
|
||||
@@ -127,13 +128,15 @@ class LettaAgentV3(LettaAgentV2):
|
||||
follow_up_messages = input_messages_to_persist[1:]
|
||||
input_messages_to_persist = [input_messages_to_persist[0]]
|
||||
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
self.in_context_messages = curr_in_context_messages
|
||||
for i in range(max_steps):
|
||||
if i == 1 and follow_up_messages:
|
||||
input_messages_to_persist = follow_up_messages
|
||||
follow_up_messages = []
|
||||
|
||||
response = self._step(
|
||||
messages=in_context_messages + self.response_messages,
|
||||
# we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back)
|
||||
messages=list(self.in_context_messages + input_messages_to_persist),
|
||||
input_messages_to_persist=input_messages_to_persist,
|
||||
# TODO need to support non-streaming adapter too
|
||||
llm_adapter=SimpleLLMRequestAdapter(llm_client=self.llm_client, llm_config=self.agent_state.llm_config),
|
||||
@@ -142,6 +145,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
)
|
||||
input_messages_to_persist = [] # clear after first step
|
||||
|
||||
async for chunk in response:
|
||||
response_letta_messages.append(chunk)
|
||||
@@ -150,53 +154,65 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if not self.should_continue and self.stop_reason.stop_reason == StopReasonType.cancelled.value:
|
||||
break
|
||||
|
||||
# Proactive summarization if approaching context limit
|
||||
if (
|
||||
self.context_token_estimate is not None
|
||||
and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
|
||||
and not self.agent_state.message_buffer_autoclear
|
||||
):
|
||||
self.logger.warning(
|
||||
f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching "
|
||||
f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization."
|
||||
)
|
||||
# TODO: persist the input messages if successful first step completion
|
||||
# TODO: persist the new messages / step / run
|
||||
|
||||
in_context_messages = await self.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=self.response_messages,
|
||||
total_tokens=self.context_token_estimate,
|
||||
force=True,
|
||||
)
|
||||
## Proactive summarization if approaching context limit
|
||||
# if (
|
||||
# self.context_token_estimate is not None
|
||||
# and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
|
||||
# and not self.agent_state.message_buffer_autoclear
|
||||
# ):
|
||||
# self.logger.warning(
|
||||
# f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching "
|
||||
# f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization."
|
||||
# )
|
||||
|
||||
# Clear to avoid duplication in next iteration
|
||||
self.response_messages = []
|
||||
# in_context_messages = await self.summarize_conversation_history(
|
||||
# in_context_messages=in_context_messages,
|
||||
# new_letta_messages=self.response_messages,
|
||||
# total_tokens=self.context_token_estimate,
|
||||
# force=True,
|
||||
# )
|
||||
|
||||
# # Clear to avoid duplication in next iteration
|
||||
# self.response_messages = []
|
||||
|
||||
if not self.should_continue:
|
||||
break
|
||||
|
||||
input_messages_to_persist = []
|
||||
# input_messages_to_persist = []
|
||||
|
||||
if i == max_steps - 1 and self.stop_reason is None:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
||||
|
||||
# Rebuild context window after stepping (safety net)
|
||||
if not self.agent_state.message_buffer_autoclear:
|
||||
if self.context_token_estimate is not None:
|
||||
await self.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=self.response_messages,
|
||||
total_tokens=self.context_token_estimate,
|
||||
force=False,
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Post-loop summarization skipped: last_step_usage is None. "
|
||||
"No step completed successfully or usage stats were not updated."
|
||||
)
|
||||
## Rebuild context window after stepping (safety net)
|
||||
# if not self.agent_state.message_buffer_autoclear:
|
||||
# if self.context_token_estimate is not None:
|
||||
# await self.summarize_conversation_history(
|
||||
# in_context_messages=in_context_messages,
|
||||
# new_letta_messages=self.response_messages,
|
||||
# total_tokens=self.context_token_estimate,
|
||||
# force=False,
|
||||
# )
|
||||
# else:
|
||||
# self.logger.warning(
|
||||
# "Post-loop summarization skipped: last_step_usage is None. "
|
||||
# "No step completed successfully or usage stats were not updated."
|
||||
# )
|
||||
|
||||
if self.stop_reason is None:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
|
||||
# construct the response
|
||||
response_letta_messages = Message.to_letta_messages_from_list(
|
||||
self.response_messages,
|
||||
use_assistant_message=False, # NOTE: set to false
|
||||
reverse=False,
|
||||
text_is_assistant_message=True,
|
||||
)
|
||||
if include_return_message_types:
|
||||
response_letta_messages = [m for m in response_letta_messages if m.message_type in include_return_message_types]
|
||||
result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage)
|
||||
if run_id:
|
||||
if self.job_update_metadata is None:
|
||||
@@ -265,13 +281,14 @@ class LettaAgentV3(LettaAgentV2):
|
||||
follow_up_messages = input_messages_to_persist[1:]
|
||||
input_messages_to_persist = [input_messages_to_persist[0]]
|
||||
|
||||
in_context_messages = in_context_messages + input_messages_to_persist
|
||||
self.in_context_messages = in_context_messages
|
||||
for i in range(max_steps):
|
||||
if i == 1 and follow_up_messages:
|
||||
input_messages_to_persist = follow_up_messages
|
||||
follow_up_messages = []
|
||||
response = self._step(
|
||||
messages=in_context_messages + self.response_messages,
|
||||
# we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back)
|
||||
messages=list(self.in_context_messages + input_messages_to_persist),
|
||||
input_messages_to_persist=input_messages_to_persist,
|
||||
llm_adapter=llm_adapter,
|
||||
run_id=run_id,
|
||||
@@ -279,6 +296,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
include_return_message_types=include_return_message_types,
|
||||
request_start_timestamp_ns=request_start_timestamp_ns,
|
||||
)
|
||||
input_messages_to_persist = [] # clear after first step
|
||||
async for chunk in response:
|
||||
response_letta_messages.append(chunk)
|
||||
if first_chunk:
|
||||
@@ -290,49 +308,29 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if not self.should_continue and self.stop_reason.stop_reason == StopReasonType.cancelled.value:
|
||||
break
|
||||
|
||||
# Proactive summarization if approaching context limit
|
||||
if (
|
||||
self.context_token_estimate is not None
|
||||
and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER
|
||||
and not self.agent_state.message_buffer_autoclear
|
||||
):
|
||||
self.logger.warning(
|
||||
f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching "
|
||||
f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization."
|
||||
)
|
||||
|
||||
in_context_messages = await self.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=self.response_messages,
|
||||
total_tokens=self.context_token_estimate,
|
||||
force=True,
|
||||
)
|
||||
|
||||
# Clear to avoid duplication in next iteration
|
||||
self.response_messages = []
|
||||
# refresh in-context messages (TODO: remove?)
|
||||
# in_context_messages = await self._refresh_messages(in_context_messages)
|
||||
|
||||
if not self.should_continue:
|
||||
break
|
||||
|
||||
input_messages_to_persist = []
|
||||
|
||||
if i == max_steps - 1 and self.stop_reason is None:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value)
|
||||
|
||||
# Rebuild context window after stepping (safety net)
|
||||
if not self.agent_state.message_buffer_autoclear:
|
||||
if self.context_token_estimate is not None:
|
||||
await self.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=self.response_messages,
|
||||
total_tokens=self.context_token_estimate,
|
||||
force=False,
|
||||
)
|
||||
else:
|
||||
self.logger.warning(
|
||||
"Post-loop summarization skipped: last_step_usage is None. "
|
||||
"No step completed successfully or usage stats were not updated."
|
||||
)
|
||||
## Rebuild context window after stepping (safety net)
|
||||
# if not self.agent_state.message_buffer_autoclear:
|
||||
# if self.context_token_estimate is not None:
|
||||
# await self.summarize_conversation_history(
|
||||
# in_context_messages=in_context_messages,
|
||||
# new_letta_messages=self.response_messages,
|
||||
# total_tokens=self.context_token_estimate,
|
||||
# force=False,
|
||||
# )
|
||||
# else:
|
||||
# self.logger.warning(
|
||||
# "Post-loop summarization skipped: last_step_usage is None. "
|
||||
# "No step completed successfully or usage stats were not updated."
|
||||
# )
|
||||
|
||||
if self.stop_reason is None:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
@@ -400,10 +398,66 @@ class LettaAgentV3(LettaAgentV2):
|
||||
yield f"event: error\ndata: {error_message.model_dump_json()}\n\n"
|
||||
# Note: we don't send finish chunks here since we already errored
|
||||
|
||||
async def _check_for_system_prompt_overflow(self, system_message):
    """
    Check whether the system prompt alone is the cause of a context overflow.

    The system prompt cannot be compacted, so when its token estimate already
    reaches the configured context window there is nothing left to summarize.

    Args:
        system_message: The system message whose tokens are counted.

    Raises:
        SystemPromptTokenExceededError: If the token estimate is greater than or
            equal to the context window. Before raising, `self.should_continue`
            is set to False and `self.stop_reason` is set to
            `context_window_overflow_in_system_prompt`.
    """
    estimated_tokens = await count_tokens(
        actor=self.actor,
        llm_config=self.agent_state.llm_config,
        messages=[system_message],
    )

    # No estimate available: nothing to compare against the window
    if estimated_tokens is None:
        return

    context_window = self.agent_state.llm_config.context_window
    if estimated_tokens < context_window:
        return

    # Record the terminal state on the agent before surfacing the error
    self.should_continue = False
    self.stop_reason = LettaStopReason(stop_reason=StopReasonType.context_window_overflow_in_system_prompt.value)
    raise SystemPromptTokenExceededError(
        system_prompt_token_estimate=estimated_tokens,
        context_window=context_window,
    )
|
||||
|
||||
async def _checkpoint_messages(self, run_id: str, step_id: str, new_messages: list[Message], in_context_messages: list[Message]) -> None:
    """
    Checkpoint the current message state - run this only when the current messages are 'safe' - meaning the step has completed successfully.

    This handles:
    - Persisting the new messages into the `messages` table
    - Updating the in-memory trackers for in-context messages (`self.in_context_messages`) and agent state (`self.agent_state.message_ids`)
    - Updating the DB with the current in-context messages (`self.agent_state.message_ids`)

    Args:
        run_id: The run ID to associate with the messages
        step_id: The step ID to associate with the messages
        new_messages: The new messages to persist
        in_context_messages: The current in-context messages
    """
    # make sure all the new messages have the correct run_id and step_id
    for message in new_messages:
        message.step_id = step_id
        message.run_id = run_id

    # persist the new message objects - ONLY place where messages are persisted
    # (the returned objects are not needed here, so the result is not bound)
    await self.message_manager.create_many_messages_async(
        new_messages,
        actor=self.actor,
        run_id=run_id,
        project_id=self.agent_state.project_id,
        template_id=self.agent_state.template_id,
    )

    # persist the in-context messages
    # TODO: somehow make sure all the message ids are already persisted
    in_context_ids = [m.id for m in in_context_messages]
    await self.agent_manager.update_message_ids_async(
        agent_id=self.agent_state.id,
        message_ids=in_context_ids,
        actor=self.actor,
    )

    # update in-memory state only after the DB write succeeded;
    # store a separate list so the agent state does not alias the list handed to the manager
    self.agent_state.message_ids = list(in_context_ids)
    self.in_context_messages = in_context_messages
|
||||
|
||||
@trace_method
|
||||
async def _step(
|
||||
self,
|
||||
messages: list[Message],
|
||||
messages: list[Message], # current in-context messages
|
||||
llm_adapter: LettaLLMAdapter,
|
||||
input_messages_to_persist: list[Message] | None = None,
|
||||
run_id: str | None = None,
|
||||
@@ -437,6 +491,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if enforce_run_id_set and run_id is None:
|
||||
raise AssertionError("run_id is required when enforce_run_id_set is True")
|
||||
|
||||
input_messages_to_persist = input_messages_to_persist or []
|
||||
|
||||
step_progression = StepProgression.START
|
||||
# TODO(@caren): clean this up
|
||||
tool_calls, content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = (
|
||||
@@ -464,13 +520,17 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# Always refresh messages at the start of each step to pick up external inputs
|
||||
# (e.g., approval responses submitted by the client while this stream is running)
|
||||
try:
|
||||
# TODO: cleanup and de-dup
|
||||
# updates the system prompt with the latest blocks / message histories
|
||||
messages = await self._refresh_messages(messages)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.warning(f"Failed to refresh messages at step start: {e}")
|
||||
|
||||
approval_request, approval_response = _maybe_get_approval_messages(messages)
|
||||
tool_call_denials, tool_returns = [], []
|
||||
if approval_request and approval_response:
|
||||
# case of handling approval responses
|
||||
content = approval_request.content
|
||||
|
||||
# Get tool calls that are pending
|
||||
@@ -541,6 +601,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
tool_return_truncation_chars=self._compute_tool_return_truncation_chars(),
|
||||
)
|
||||
# TODO: Extend to more providers, and also approval tool rules
|
||||
# TODO: this entire code block should be inside of the clients
|
||||
# Enable parallel tool use when no tool rules are attached
|
||||
try:
|
||||
no_tool_rules = (
|
||||
@@ -612,11 +673,25 @@ class LettaAgentV3(LettaAgentV2):
|
||||
except Exception as e:
|
||||
if isinstance(e, ContextWindowExceededError) and llm_request_attempt < summarizer_settings.max_summarizer_retries:
|
||||
# Retry case
|
||||
messages = await self.summarize_conversation_history(
|
||||
in_context_messages=messages,
|
||||
new_letta_messages=self.response_messages,
|
||||
force=True,
|
||||
summary_message, messages = await self.compact(
|
||||
messages, trigger_threshold=self.agent_state.llm_config.context_window
|
||||
)
|
||||
|
||||
# checkpoint summarized messages
|
||||
# TODO: might want to delay this checkpoint in case of corrupted state
|
||||
try:
|
||||
await self._checkpoint_messages(
|
||||
run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
|
||||
)
|
||||
except SystemPromptTokenExceededError:
|
||||
self.stop_reason = LettaStopReason(
|
||||
stop_reason=StopReasonType.context_window_overflow_in_system_prompt.value
|
||||
)
|
||||
raise e
|
||||
except Exception as e:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
self.logger.error(f"Unknown error occured for summarization run {run_id}: {e}")
|
||||
raise e
|
||||
else:
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value)
|
||||
self.logger.error(f"Unknown error occured for run {run_id}: {e}")
|
||||
@@ -637,8 +712,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
else:
|
||||
tool_calls = []
|
||||
|
||||
aggregated_persisted: list[Message] = []
|
||||
persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
|
||||
# get the new generated `Message` objects from handling the LLM response
|
||||
new_messages, self.should_continue, self.stop_reason = await self._handle_ai_response(
|
||||
tool_calls=tool_calls,
|
||||
valid_tool_names=[tool["name"] for tool in valid_tools],
|
||||
tool_rules_solver=self.tool_rules_solver,
|
||||
@@ -650,7 +725,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
content=content or llm_adapter.content,
|
||||
pre_computed_assistant_message_id=llm_adapter.message_id,
|
||||
step_id=step_id,
|
||||
initial_messages=input_messages_to_persist,
|
||||
initial_messages=[], # input_messages_to_persist, # TODO: deprecate - super confusing
|
||||
agent_step_span=agent_step_span,
|
||||
is_final_step=(remaining_turns == 0),
|
||||
run_id=run_id,
|
||||
@@ -659,16 +734,26 @@ class LettaAgentV3(LettaAgentV2):
|
||||
tool_call_denials=tool_call_denials,
|
||||
tool_returns=tool_returns,
|
||||
)
|
||||
aggregated_persisted.extend(persisted_messages)
|
||||
# NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op")
|
||||
|
||||
new_message_idx = len(input_messages_to_persist) if input_messages_to_persist else 0
|
||||
self.response_messages.extend(aggregated_persisted[new_message_idx:])
|
||||
# extend trackers with new messages
|
||||
self.response_messages.extend(new_messages)
|
||||
messages.extend(new_messages)
|
||||
|
||||
# step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics
|
||||
# persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state
|
||||
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
|
||||
await self._checkpoint_messages(
|
||||
run_id=run_id,
|
||||
step_id=step_id,
|
||||
new_messages=input_messages_to_persist + new_messages,
|
||||
in_context_messages=messages, # update the in-context messages
|
||||
)
|
||||
|
||||
# yield back generated messages
|
||||
if llm_adapter.supports_token_streaming():
|
||||
if tool_calls:
|
||||
# Stream each tool return if tools were executed
|
||||
response_tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"]
|
||||
response_tool_returns = [msg for msg in new_messages if msg.role == "tool"]
|
||||
for tr in response_tool_returns:
|
||||
# Skip streaming for aggregated parallel tool returns (no per-call tool_call_id)
|
||||
if tr.tool_call_id is None and tr.tool_returns:
|
||||
@@ -677,7 +762,8 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if include_return_message_types is None or tool_return_letta.message_type in include_return_message_types:
|
||||
yield tool_return_letta
|
||||
else:
|
||||
filter_user_messages = [m for m in aggregated_persisted[new_message_idx:] if m.role != "user"]
|
||||
# TODO: modify this to use step_response_messages
|
||||
filter_user_messages = [m for m in new_messages if m.role != "user"]
|
||||
letta_messages = Message.to_letta_messages_from_list(
|
||||
filter_user_messages,
|
||||
use_assistant_message=False, # NOTE: set to false
|
||||
@@ -689,11 +775,20 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if include_return_message_types is None or message.message_type in include_return_message_types:
|
||||
yield message
|
||||
|
||||
# Note: message_ids update for approval responses now happens immediately after
|
||||
# persistence in _handle_ai_response (line ~1093-1107) to prevent desync when
|
||||
# the stream is interrupted and this generator is abandoned before being fully consumed
|
||||
step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step)
|
||||
# check compaction
|
||||
if self.context_token_estimate > self.agent_state.llm_config.context_window:
|
||||
summary_message, messages = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window)
|
||||
# TODO: persist + return the summary message
|
||||
# TODO: convert this to a SummaryMessage
|
||||
self.response_messages.append(summary_message)
|
||||
for message in Message.to_letta_messages(summary_message):
|
||||
yield message
|
||||
await self._checkpoint_messages(
|
||||
run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages
|
||||
)
|
||||
|
||||
except Exception as e:
|
||||
# NOTE: message persistence does not happen in the case of an exception (rollback to previous state)
|
||||
self.logger.warning(f"Error during step processing: {e}")
|
||||
self.job_update_metadata = {"error": str(e)}
|
||||
|
||||
@@ -707,20 +802,14 @@ class LettaAgentV3(LettaAgentV2):
|
||||
StopReasonType.invalid_tool_call,
|
||||
StopReasonType.invalid_llm_response,
|
||||
StopReasonType.llm_api_error,
|
||||
StopReasonType.context_window_overflow_in_system_prompt,
|
||||
):
|
||||
self.logger.warning("Error occurred during step processing, with unexpected stop reason: %s", self.stop_reason.stop_reason)
|
||||
raise e
|
||||
finally:
|
||||
# always make sure we update the step/run metadata
|
||||
self.logger.debug("Running cleanup for agent loop run: %s", run_id)
|
||||
self.logger.info("Running final update. Step Progression: %s", step_progression)
|
||||
|
||||
# update message ids
|
||||
message_ids = [m.id for m in messages]
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=message_ids,
|
||||
actor=self.actor,
|
||||
)
|
||||
try:
|
||||
if step_progression == StepProgression.FINISHED:
|
||||
if not self.should_continue:
|
||||
@@ -728,7 +817,9 @@ class LettaAgentV3(LettaAgentV2):
|
||||
self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value)
|
||||
if logged_step and step_id:
|
||||
await self.step_manager.update_step_stop_reason(self.actor, step_id, self.stop_reason.stop_reason)
|
||||
return
|
||||
if not self.stop_reason.stop_reason == StopReasonType.context_window_overflow_in_system_prompt:
|
||||
# only return if the stop reason is not context window overflow in system prompt
|
||||
return
|
||||
if step_progression < StepProgression.STEP_LOGGED:
|
||||
# Error occurred before step was fully logged
|
||||
import traceback
|
||||
@@ -742,19 +833,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
error_traceback=traceback.format_exc(),
|
||||
stop_reason=self.stop_reason,
|
||||
)
|
||||
if step_progression <= StepProgression.STREAM_RECEIVED:
|
||||
if first_chunk and settings.track_errored_messages and input_messages_to_persist:
|
||||
for message in input_messages_to_persist:
|
||||
message.is_err = True
|
||||
message.step_id = step_id
|
||||
message.run_id = run_id
|
||||
await self.message_manager.create_many_messages_async(
|
||||
input_messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if self.stop_reason is None:
|
||||
self.logger.warning("Error in step after logging step")
|
||||
@@ -806,6 +884,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
|
||||
Unified approach: treats single and multi-tool calls uniformly to reduce code duplication.
|
||||
"""
|
||||
|
||||
# 1. Handle no-tool cases (content-only or no-op)
|
||||
if not tool_calls and not tool_call_denials and not tool_returns:
|
||||
# Case 1a: No tool call, no content (LLM no-op)
|
||||
@@ -863,22 +942,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
add_heartbeat_on_continue=bool(heartbeat_reason),
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + assistant_message
|
||||
|
||||
# Persist messages for no-tool cases
|
||||
for message in messages_to_persist:
|
||||
if message.run_id is None:
|
||||
message.run_id = run_id
|
||||
if message.step_id is None:
|
||||
message.step_id = step_id
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
return messages_to_persist, continue_stepping, stop_reason
|
||||
|
||||
# 2. Check whether tool call requires approval
|
||||
if not is_approval_response:
|
||||
@@ -896,21 +960,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
run_id=run_id,
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + approval_messages
|
||||
|
||||
for message in messages_to_persist:
|
||||
if message.run_id is None:
|
||||
message.run_id = run_id
|
||||
if message.step_id is None:
|
||||
message.step_id = step_id
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
|
||||
return messages_to_persist, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value)
|
||||
|
||||
result_tool_returns = []
|
||||
|
||||
@@ -1148,31 +1198,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if message.step_id is None:
|
||||
message.step_id = step_id
|
||||
|
||||
# Persist all messages
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist,
|
||||
actor=self.actor,
|
||||
run_id=run_id,
|
||||
project_id=self.agent_state.project_id,
|
||||
template_id=self.agent_state.template_id,
|
||||
)
|
||||
|
||||
# Update message_ids immediately after persistence to prevent desync
|
||||
# This handles approval responses where we need to keep message_ids in sync
|
||||
if (
|
||||
is_approval_response
|
||||
and initial_messages
|
||||
and len(initial_messages) == 1
|
||||
and initial_messages[0].role == "approval"
|
||||
and len(persisted_messages) >= 2
|
||||
and persisted_messages[0].role == "approval"
|
||||
and persisted_messages[1].role == "tool"
|
||||
):
|
||||
self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]]
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor
|
||||
)
|
||||
|
||||
# 5g. Aggregate continuation decisions
|
||||
aggregate_continue = any(persisted_continue_flags) if persisted_continue_flags else False
|
||||
aggregate_continue = aggregate_continue or tool_call_denials or tool_returns
|
||||
@@ -1193,7 +1218,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
# Force continuation for parallel tool execution
|
||||
aggregate_continue = True
|
||||
aggregate_stop_reason = None
|
||||
return persisted_messages, aggregate_continue, aggregate_stop_reason
|
||||
return messages_to_persist, aggregate_continue, aggregate_stop_reason
|
||||
|
||||
@trace_method
|
||||
def _decide_continuation(
|
||||
@@ -1282,178 +1307,118 @@ class LettaAgentV3(LettaAgentV2):
|
||||
return allowed_tools
|
||||
|
||||
@trace_method
|
||||
async def summarize_conversation_history(
|
||||
self,
|
||||
# The messages already in the context window
|
||||
in_context_messages: list[Message],
|
||||
# The messages produced by the agent in this step
|
||||
new_letta_messages: list[Message],
|
||||
# The token usage from the most recent LLM call (prompt + completion)
|
||||
total_tokens: int | None = None,
|
||||
# If force, then don't do any counting, just summarize
|
||||
force: bool = False,
|
||||
) -> list[Message]:
|
||||
trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window)
|
||||
|
||||
# no summarization if the last message is an approval request
|
||||
latest_messages = in_context_messages + new_letta_messages
|
||||
pending_approval = latest_messages[-1].role == "approval" and len(latest_messages[-1].tool_calls) > 0
|
||||
if pending_approval:
|
||||
trigger_summarization = False
|
||||
self.logger.info(
|
||||
f"trigger_summarization: {trigger_summarization}, total_tokens: {total_tokens}, context_window: {self.agent_state.llm_config.context_window}, pending_approval: {pending_approval}"
|
||||
)
|
||||
if not trigger_summarization:
|
||||
# just update the message_ids
|
||||
# TODO: gross to handle this here: we should move persistence elsewhere
|
||||
new_in_context_messages = in_context_messages + new_letta_messages
|
||||
message_ids = [m.id for m in new_in_context_messages]
|
||||
await self.agent_manager.update_message_ids_async(
|
||||
agent_id=self.agent_state.id,
|
||||
message_ids=message_ids,
|
||||
actor=self.actor,
|
||||
)
|
||||
self.agent_state.message_ids = message_ids
|
||||
return new_in_context_messages
|
||||
|
||||
async def compact(self, messages: list[Message], trigger_threshold: Optional[int] = None) -> tuple[Message, list[Message]]:
    """
    Compact (summarize) a list of in-context messages.

    Simplified compaction method. Does NOT do any persistence (handled in the loop):
    the caller is responsible for checkpointing the returned messages.

    Args:
        messages: The current in-context messages. The first element is
            treated as the system message and is always kept at the front.
        trigger_threshold: Optional token threshold. When provided and the
            post-summarization token estimate is still at or above it, a
            fallback pass is attempted (see below).

    Returns:
        A tuple ``(summary_message, final_messages)`` where
        ``final_messages`` is ``[system, summary_message, *kept_tail]``.

    Raises:
        ValueError: If the summarizer config reports an unknown mode.
        Whatever ``_check_for_system_prompt_overflow`` raises when the
        context is still over the threshold after the fallback
        (presumably ``SystemPromptTokenExceededError`` - confirm against
        that helper).
    """
    # Use agent's summarizer_config if set, otherwise fall back to defaults
    # TODO: add this back
    # summarizer_config = self.agent_state.summarizer_config or get_default_summarizer_config(self.agent_state.llm_config)
    summarizer_config = get_default_summarizer_config(self.agent_state.llm_config._to_model_settings())

    # Track which mode actually produced the summary, since sliding-window
    # can silently fall back to "all" on failure.
    summarization_mode_used = summarizer_config.mode
    if summarizer_config.mode == "all":
        summary, compacted_messages = await summarize_all(
            actor=self.actor,
            llm_config=self.agent_state.llm_config,
            summarizer_config=summarizer_config,
            in_context_messages=messages,
        )
    elif summarizer_config.mode == "sliding_window":
        try:
            summary, compacted_messages = await summarize_via_sliding_window(
                actor=self.actor,
                llm_config=self.agent_state.llm_config,
                summarizer_config=summarizer_config,
                in_context_messages=messages,
            )
        except Exception as e:
            # Deliberate best-effort: any sliding-window failure degrades to a full summary.
            self.logger.error(f"Sliding window summarization failed with exception: {str(e)}. Falling back to all mode.")
            summary, compacted_messages = await summarize_all(
                actor=self.actor,
                llm_config=self.agent_state.llm_config,
                summarizer_config=summarizer_config,
                in_context_messages=messages,
            )
            summarization_mode_used = "all"
    else:
        raise ValueError(f"Invalid summarizer mode: {summarizer_config.mode}")

    # update the token count
    self.context_token_estimate = await count_tokens(
        actor=self.actor, llm_config=self.agent_state.llm_config, messages=compacted_messages
    )
    self.logger.info(f"Context token estimate after summarization: {self.context_token_estimate}")

    # if the trigger_threshold is provided, we need to make sure that the new token count is below it
    if trigger_threshold is not None and self.context_token_estimate >= trigger_threshold:
        # Even after summarization the context is still at or above the
        # proactive summarization threshold: log loudly and try a more
        # aggressive pass to avoid getting stuck in repeated
        # summarization loops.
        self.logger.error(
            "Summarization failed to sufficiently reduce context size: "
            f"post-summarization tokens={self.context_token_estimate}, "
            f"threshold={trigger_threshold}, context_window={self.agent_state.llm_config.context_window}. "
            "Evicting all prior messages without a summary to break potential loops.",
        )

        # if we used the sliding window mode, try to summarize again with the all mode
        if summarization_mode_used == "sliding_window":
            # NOTE(review): this re-summarizes only the already-compacted tail,
            # so the summary of the earlier messages is replaced - confirm
            # that dropping it is intended.
            summary, compacted_messages = await summarize_all(
                actor=self.actor,
                llm_config=self.agent_state.llm_config,
                summarizer_config=summarizer_config,
                in_context_messages=compacted_messages,
            )
            summarization_mode_used = "all"

        self.context_token_estimate = await count_tokens(
            actor=self.actor, llm_config=self.agent_state.llm_config, messages=compacted_messages
        )

        # final edge case: the system prompt is the cause of the context overflow (raise error)
        if self.context_token_estimate >= trigger_threshold:
            await self._check_for_system_prompt_overflow(compacted_messages[0])

            # The system prompt check passed, so the overflow is still unexplained.
            # Do not throw an error, since we don't want to brick the agent.
            self.logger.error(
                f"Failed to summarize messages after hard eviction and checking the system prompt token estimate: {self.context_token_estimate} > {trigger_threshold}"
            )
        else:
            self.logger.info(
                f"Summarization fallback succeeded in bringing the context size below the trigger threshold: {self.context_token_estimate} < {trigger_threshold}"
            )

    # Build the summary message object; persistence is left to the caller
    # (per this method's contract).
    summary_message_str_packed = package_summarize_message_no_counts(
        summary=summary,
        timezone=self.agent_state.timezone,
    )
    summary_messages = await convert_message_creates_to_messages(
        message_creates=[
            MessageCreate(
                role=MessageRole.user,
                content=[TextContent(text=summary_message_str_packed)],
            )
        ],
        agent_id=self.agent_state.id,
        timezone=self.agent_state.timezone,
        # We already packed, don't pack again
        wrap_user_message=False,
        wrap_system_message=False,
        run_id=None,  # TODO: add this
    )
    if len(summary_messages) != 1:
        self.logger.error(f"Expected only one summary message, got {len(summary_messages)} in {summary_messages}")
    summary_message_obj = summary_messages[0]

    # final messages: inject summarization message right after the system message
    final_messages = [compacted_messages[0], summary_message_obj] + compacted_messages[1:]

    return summary_message_obj, final_messages
|
||||
|
||||
@@ -265,6 +265,16 @@ class ContextWindowExceededError(LettaError):
|
||||
)
|
||||
|
||||
|
||||
class SystemPromptTokenExceededError(ContextWindowExceededError):
    """Raised when the estimated system prompt token count exceeds the context window.

    The token estimate and the context window are forwarded to the parent
    error in ``details`` under the keys ``system_prompt_token_estimate``
    and ``context_window``.
    """

    def __init__(self, system_prompt_token_estimate: int, context_window: int):
        # Structured payload passed alongside the human-readable message.
        details = {
            "system_prompt_token_estimate": system_prompt_token_estimate,
            "context_window": context_window,
        }
        message = f"The system prompt tokens {system_prompt_token_estimate} exceeds the context window {context_window}. Please reduce the size of your system prompt, memory blocks, or increase the context window."
        super().__init__(message=message, details=details)
|
||||
|
||||
|
||||
class RateLimitExceededError(LettaError):
|
||||
"""Error raised when the llm rate limiter throttles api requests."""
|
||||
|
||||
|
||||
@@ -17,6 +17,7 @@ class StopReasonType(str, Enum):
|
||||
tool_rule = "tool_rule"
|
||||
cancelled = "cancelled"
|
||||
requires_approval = "requires_approval"
|
||||
context_window_overflow_in_system_prompt = "context_window_overflow_in_system_prompt"
|
||||
|
||||
@property
|
||||
def run_status(self) -> RunStatus:
|
||||
@@ -33,6 +34,7 @@ class StopReasonType(str, Enum):
|
||||
StopReasonType.no_tool_call,
|
||||
StopReasonType.invalid_llm_response,
|
||||
StopReasonType.llm_api_error,
|
||||
StopReasonType.context_window_overflow_in_system_prompt,
|
||||
):
|
||||
return RunStatus.failed
|
||||
elif self == StopReasonType.cancelled:
|
||||
|
||||
@@ -2112,13 +2112,12 @@ async def summarize_messages(
|
||||
if agent_eligible and model_compatible:
|
||||
agent_loop = LettaAgentV3(agent_state=agent, actor=actor)
|
||||
in_context_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor)
|
||||
await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=[],
|
||||
total_tokens=None,
|
||||
force=True,
|
||||
summary_message, messages = await agent_loop.compact(
|
||||
messages=in_context_messages,
|
||||
)
|
||||
# Summarization completed, return 204 No Content
|
||||
|
||||
# update the agent state
|
||||
await agent_loop._checkpoint_messages(run_id=None, step_id=None, new_messages=[summary_message], in_context_messages=messages)
|
||||
else:
|
||||
raise HTTPException(
|
||||
status_code=status.HTTP_403_FORBIDDEN,
|
||||
|
||||
@@ -20,7 +20,7 @@ async def summarize_all(
|
||||
# Actual summarization configuration
|
||||
summarizer_config: SummarizerConfig,
|
||||
in_context_messages: List[Message],
|
||||
new_messages: List[Message],
|
||||
# new_messages: List[Message],
|
||||
) -> str:
|
||||
"""
|
||||
Summarize the entire conversation history into a single summary.
|
||||
@@ -28,8 +28,7 @@ async def summarize_all(
|
||||
Returns:
|
||||
- The summary string
|
||||
"""
|
||||
all_in_context_messages = in_context_messages + new_messages
|
||||
messages_to_summarize = all_in_context_messages[1:]
|
||||
messages_to_summarize = in_context_messages[1:]
|
||||
|
||||
# TODO: add fallback in case this has a context window error
|
||||
summary_message_str = await simple_summary(
|
||||
@@ -44,4 +43,4 @@ async def summarize_all(
|
||||
logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
|
||||
summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"
|
||||
|
||||
return summary_message_str, []
|
||||
return summary_message_str, [in_context_messages[0]]
|
||||
|
||||
@@ -50,7 +50,7 @@ async def summarize_via_sliding_window(
|
||||
llm_config: LLMConfig,
|
||||
summarizer_config: SummarizerConfig,
|
||||
in_context_messages: List[Message],
|
||||
new_messages: List[Message],
|
||||
# new_messages: List[Message],
|
||||
) -> Tuple[str, List[Message]]:
|
||||
"""
|
||||
If the total tokens is greater than the context window limit (or force=True),
|
||||
@@ -68,53 +68,42 @@ async def summarize_via_sliding_window(
|
||||
- The list of message IDs to keep in-context
|
||||
"""
|
||||
system_prompt = in_context_messages[0]
|
||||
all_in_context_messages = in_context_messages + new_messages
|
||||
total_message_count = len(all_in_context_messages)
|
||||
total_message_count = len(in_context_messages)
|
||||
|
||||
# Starts at N% (eg 70%), and increments up until 100%
|
||||
message_count_cutoff_percent = max(
|
||||
1 - summarizer_config.sliding_window_percentage, 0.10
|
||||
) # Some arbitrary minimum value (10%) to avoid negatives from badly configured summarizer percentage
|
||||
found_cutoff = False
|
||||
assert summarizer_config.sliding_window_percentage <= 1.0, "Sliding window percentage must be less than or equal to 1.0"
|
||||
assistant_message_index = None
|
||||
approx_token_count = llm_config.context_window
|
||||
|
||||
# Count tokens with system prompt, and message past cutoff point
|
||||
assistant_message_index = None # Initialize to track if we found an assistant message
|
||||
while not found_cutoff:
|
||||
# Mark the approximate cutoff
|
||||
message_cutoff_index = round(message_count_cutoff_percent * len(all_in_context_messages))
|
||||
while (
|
||||
approx_token_count >= summarizer_config.sliding_window_percentage * llm_config.context_window and message_count_cutoff_percent < 1.0
|
||||
):
|
||||
# calculate message_cutoff_index
|
||||
message_cutoff_index = round(message_count_cutoff_percent * total_message_count)
|
||||
|
||||
# we've reached the maximum message cutoff
|
||||
if message_cutoff_index >= total_message_count:
|
||||
# get index of first assistant message in range
|
||||
assistant_message_index = next(
|
||||
(i for i in range(message_cutoff_index, total_message_count) if in_context_messages[i].role == MessageRole.assistant), None
|
||||
)
|
||||
|
||||
# if no assistant message in tail, break out of loop (since future iterations will continue hitting this case)
|
||||
if assistant_message_index is None:
|
||||
break
|
||||
|
||||
# Walk up the list until we find the first assistant message
|
||||
for i in range(message_cutoff_index, total_message_count):
|
||||
if all_in_context_messages[i].role == MessageRole.assistant:
|
||||
assistant_message_index = i
|
||||
break
|
||||
else:
|
||||
raise ValueError(f"No assistant message found from indices {message_cutoff_index} to {total_message_count}")
|
||||
# update token count
|
||||
post_summarization_buffer = [system_prompt] + in_context_messages[assistant_message_index:]
|
||||
approx_token_count = await count_tokens(actor, llm_config, post_summarization_buffer)
|
||||
|
||||
# Count tokens of the hypothetical post-summarization buffer
|
||||
post_summarization_buffer = [system_prompt] + all_in_context_messages[assistant_message_index:]
|
||||
post_summarization_buffer_tokens = await count_tokens(actor, llm_config, post_summarization_buffer)
|
||||
|
||||
# If hypothetical post-summarization count lower than the target remaining percentage?
|
||||
if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window:
|
||||
found_cutoff = True
|
||||
else:
|
||||
message_count_cutoff_percent += 0.10
|
||||
if message_count_cutoff_percent >= 1.0:
|
||||
message_cutoff_index = total_message_count
|
||||
break
|
||||
|
||||
# If we found the cutoff, summarize and return
|
||||
# If we didn't find the cutoff and we hit 100%, this is equivalent to complete summarization
|
||||
# increment cutoff
|
||||
message_count_cutoff_percent += 0.10
|
||||
|
||||
if assistant_message_index is None:
|
||||
raise ValueError("No assistant message found for sliding window summarization") # fall back to complete summarization
|
||||
|
||||
messages_to_summarize = all_in_context_messages[1:message_cutoff_index]
|
||||
messages_to_summarize = in_context_messages[1:assistant_message_index]
|
||||
|
||||
summary_message_str = await simple_summary(
|
||||
messages=messages_to_summarize,
|
||||
@@ -128,5 +117,5 @@ async def summarize_via_sliding_window(
|
||||
logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.")
|
||||
summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]"
|
||||
|
||||
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
|
||||
return summary_message_str, updated_in_context_messages
|
||||
updated_in_context_messages = in_context_messages[assistant_message_index:]
|
||||
return summary_message_str, [system_prompt] + updated_in_context_messages
|
||||
|
||||
@@ -908,7 +908,6 @@ async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, too
|
||||
|
||||
# Should succeed since both sandbox and tool pip requirements were installed
|
||||
assert "Success!" in result.func_return
|
||||
assert "Status: 200" in result.func_return
|
||||
assert "Array sum: 6" in result.func_return
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,17 @@ import pytest
|
||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.agent import CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessage
|
||||
from letta.schemas.letta_message_content import TextContent, ToolCallContent, ToolReturnContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.message import Message as PydanticMessage, MessageCreate
|
||||
from letta.schemas.run import Run as PydanticRun
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.run_manager import RunManager
|
||||
|
||||
# Constants
|
||||
DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig.default_config(provider="openai")
|
||||
@@ -40,8 +44,8 @@ def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model
|
||||
# Test configurations - using a subset of models for summarization tests
|
||||
all_configs = [
|
||||
"openai-gpt-5-mini.json",
|
||||
"claude-4-5-haiku.json",
|
||||
"gemini-2.5-flash.json",
|
||||
# "claude-4-5-haiku.json",
|
||||
# "gemini-2.5-flash.json",
|
||||
# "gemini-2.5-flash-vertex.json", # Requires Vertex AI credentials
|
||||
# "openai-gpt-4.1.json",
|
||||
# "openai-o1.json",
|
||||
@@ -175,17 +179,12 @@ async def run_summarization(server: SyncServer, agent_state, in_context_messages
|
||||
2. Fetch messages via message_manager.get_messages_by_ids_async
|
||||
3. Call agent_loop.compact on the fetched in-context messages
|
||||
"""
|
||||
agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor)
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
# Run summarization with force parameter
|
||||
result = await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=[],
|
||||
total_tokens=None,
|
||||
force=force,
|
||||
)
|
||||
summary_message, messages = await agent_loop.compact(messages=in_context_messages)
|
||||
|
||||
return result
|
||||
return summary_message, messages
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -218,11 +217,24 @@ async def test_summarize_empty_message_buffer(server: SyncServer, actor, llm_con
|
||||
|
||||
# Run summarization - this may fail with empty buffer, which is acceptable behavior
|
||||
try:
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
# If it succeeds, verify result
|
||||
assert isinstance(result, list)
|
||||
# With empty buffer, result should still be empty or contain only system messages
|
||||
assert len(result) <= len(in_context_messages)
|
||||
|
||||
# When summarization runs, V3 ensures that in-context messages follow
|
||||
# the pattern:
|
||||
# 1. System prompt
|
||||
# 2. User summary message (system_alert JSON)
|
||||
# 3. Remaining messages (which may be empty for this test)
|
||||
|
||||
# We should always keep the original system message at the front.
|
||||
assert len(result) >= 1
|
||||
assert result[0].role == MessageRole.system
|
||||
|
||||
# If summarization did in fact add a summary message, we expect it to
|
||||
# be the second message with user role.
|
||||
if len(result) >= 2:
|
||||
assert result[1].role == MessageRole.user
|
||||
except ValueError as e:
|
||||
# It's acceptable for summarization to fail on empty buffer
|
||||
assert "No assistant message found" in str(e) or "empty" in str(e).lower()
|
||||
@@ -255,7 +267,7 @@ async def test_summarize_initialization_messages_only(server: SyncServer, actor,
|
||||
|
||||
# Run summarization - force=True with system messages only may fail
|
||||
try:
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -311,7 +323,7 @@ async def test_summarize_small_conversation(server: SyncServer, actor, llm_confi
|
||||
# Run summarization with force=True
|
||||
# Note: force=True with clear=True can be very aggressive and may fail on small message sets
|
||||
try:
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -404,7 +416,7 @@ async def test_summarize_large_tool_calls(server: SyncServer, actor, llm_config:
|
||||
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
|
||||
|
||||
# Run summarization
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -508,7 +520,7 @@ async def test_summarize_multiple_large_tool_calls(server: SyncServer, actor, ll
|
||||
assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars"
|
||||
|
||||
# Run summarization
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -579,7 +591,7 @@ async def test_summarize_truncates_large_tool_return(server: SyncServer, actor,
|
||||
assert original_size > 90000, f"Expected tool return >90k chars, got {original_size}"
|
||||
|
||||
# Run summarization
|
||||
result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
summary, result = await run_summarization(server, agent_state, in_context_messages, actor)
|
||||
|
||||
# Verify result
|
||||
assert isinstance(result, list)
|
||||
@@ -678,12 +690,7 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
with patch("letta.agents.letta_agent_v3.get_default_summarizer_config", mock_get_default_summarizer_config):
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
result = await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=new_letta_messages,
|
||||
total_tokens=None,
|
||||
force=True,
|
||||
)
|
||||
summary, result = await agent_loop.compact(messages=in_context_messages)
|
||||
|
||||
assert isinstance(result, list)
|
||||
|
||||
@@ -700,24 +707,21 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
print()
|
||||
|
||||
if mode == "all":
|
||||
# For "all" mode, result should be just the summary message
|
||||
assert len(result) == 2, f"Expected 1 message for 'all' mode, got {len(result)}"
|
||||
# For "all" mode, V3 keeps:
|
||||
# 1. System prompt
|
||||
# 2. A single user summary message (system_alert JSON)
|
||||
# and no remaining historical messages.
|
||||
assert len(result) == 2, f"Expected 2 messages for 'all' mode (system + summary), got {len(result)}"
|
||||
assert result[0].role == MessageRole.system
|
||||
assert result[1].role == MessageRole.user
|
||||
else:
|
||||
# For "sliding_window" mode, result should include recent messages + summary
|
||||
assert len(result) > 1, f"Expected >1 messages for 'sliding_window' mode, got {len(result)}"
|
||||
# validate new user message
|
||||
assert result[-1].role == MessageRole.user and result[-1].agent_id == agent_state.id, (
|
||||
f"Expected new user message with agent_id {agent_state.id}, got {result[-1]}"
|
||||
)
|
||||
assert "This is a new user message" in result[-1].content[0].text, (
|
||||
f"Expected 'This is a new user message' in the user message, got {result[-1]}"
|
||||
)
|
||||
|
||||
# validate system message
|
||||
assert result[0].role == MessageRole.system
|
||||
# validate summary message
|
||||
assert "prior messages" in result[1].content[0].text, f"Expected 'prior messages' in the summary message, got {result[1]}"
|
||||
print(f"Mode '{mode}' with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages")
|
||||
# For "sliding_window" mode, result should include:
|
||||
# 1. System prompt
|
||||
# 2. User summary message
|
||||
# 3+. Recent user/assistant messages inside the window.
|
||||
assert len(result) > 2, f"Expected >2 messages for 'sliding_window' mode, got {len(result)}"
|
||||
assert result[0].role == MessageRole.system
|
||||
assert result[1].role == MessageRole.user
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -740,15 +744,16 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
|
||||
is still above the trigger threshold.
|
||||
3. We verify that LettaAgentV3:
|
||||
- Logs an error about summarization failing to reduce context size.
|
||||
- Evicts all prior messages, keeping only the system message.
|
||||
- Evicts all prior messages, keeping only the system message plus a
|
||||
single synthetic user summary message (system_alert).
|
||||
- Updates `context_token_estimate` to the token count of the minimal
|
||||
context so future steps don't keep re-triggering summarization based
|
||||
on a stale, oversized value.
|
||||
"""
|
||||
|
||||
# Build a small but non-trivial conversation with an explicit system
|
||||
# message so that after hard eviction we expect to keep exactly that one
|
||||
# message.
|
||||
# message so that after hard eviction we expect to keep exactly that
|
||||
# system message plus a single user summary message.
|
||||
messages = [
|
||||
PydanticMessage(
|
||||
role=MessageRole.system,
|
||||
@@ -766,6 +771,10 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
|
||||
|
||||
agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
|
||||
|
||||
print("ORIGINAL IN-CONTEXT MESSAGES ======")
|
||||
for msg in in_context_messages:
|
||||
print(f"MSG: {msg}")
|
||||
|
||||
# Create the V3 agent loop
|
||||
agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
|
||||
@@ -787,36 +796,26 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold(
|
||||
|
||||
caplog.set_level("ERROR")
|
||||
|
||||
result = await agent_loop.summarize_conversation_history(
|
||||
in_context_messages=in_context_messages,
|
||||
new_letta_messages=[],
|
||||
# total_tokens is not used when force=True for triggering, but we
|
||||
# set it to a large value for clarity.
|
||||
total_tokens=llm_config.context_window * 2 if llm_config.context_window else None,
|
||||
force=True,
|
||||
summary, result = await agent_loop.compact(
|
||||
messages=in_context_messages,
|
||||
trigger_threshold=context_limit,
|
||||
)
|
||||
|
||||
# We should have made exactly two token-count calls: one for the
|
||||
# summarized context, one for the hard-evicted minimal context.
|
||||
assert mock_count_tokens.call_count == 2
|
||||
|
||||
# After hard eviction, only the system message should remain in-context.
|
||||
print("COMPACTED RESULT ======")
|
||||
for msg in result:
|
||||
print(f"MSG: {msg}")
|
||||
|
||||
# After hard eviction, we keep only:
|
||||
# 1. The system prompt
|
||||
# 2. The synthetic user summary message.
|
||||
assert isinstance(result, list)
|
||||
assert len(result) == 1, f"Expected only the system message after hard eviction, got {len(result)} messages"
|
||||
assert len(result) == 2, f"Expected system + summary after hard eviction, got {len(result)} messages"
|
||||
assert result[0].role == MessageRole.system
|
||||
|
||||
# Agent state should also reflect exactly one message id.
|
||||
assert len(agent_loop.agent_state.message_ids) == 1
|
||||
|
||||
# context_token_estimate should be updated to the minimal token count
|
||||
# (second side-effect value from count_tokens), rather than the original
|
||||
# huge value.
|
||||
assert agent_loop.context_token_estimate == 10
|
||||
|
||||
# Verify that we logged an error about summarization failing to reduce
|
||||
# context size.
|
||||
error_logs = [rec for rec in caplog.records if "Summarization failed to sufficiently reduce context size" in rec.getMessage()]
|
||||
assert error_logs, "Expected an error log when summarization fails to reduce context size sufficiently"
|
||||
assert result[1].role == MessageRole.user
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
@@ -893,7 +892,6 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
|
||||
llm_config=llm_config,
|
||||
summarizer_config=summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
|
||||
# Verify the summary was generated (actual LLM response)
|
||||
@@ -924,6 +922,105 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server:
|
||||
raise
|
||||
|
||||
|
||||
@pytest.mark.asyncio
@pytest.mark.parametrize(
    "llm_config",
    TESTED_LLM_CONFIGS,
    ids=[c.model for c in TESTED_LLM_CONFIGS],
)
async def test_large_system_prompt_summarization(server: SyncServer, actor, llm_config: LLMConfig):
    """
    Test the edge case of an oversized system prompt / memory blocks.

    The agent is created with a system prompt and a memory block that are far
    larger than the (artificially small) context window. The test checks two
    things:

    1. Stepping the agent in that state raises ``SystemPromptTokenExceededError``.
    2. After the system prompt and memory block are shortened and the agent
       state is reloaded, a fresh agent loop can step and return messages.
    """
    import copy

    from letta.errors import SystemPromptTokenExceededError

    # Work on a private copy: the parametrized config object comes from the
    # module-level TESTED_LLM_CONFIGS list, so mutating it in place would leak
    # the tiny context window into every other test using the same config.
    llm_config = copy.deepcopy(llm_config)

    # Override context window to be small so the system prompt overflows it
    llm_config.context_window = 10000

    # Create agent with large system prompt and memory blocks
    agent_name = f"test_agent_large_system_prompt_{llm_config.model}".replace(".", "_").replace("/", "_")
    agent_create = CreateAgent(
        name=agent_name,
        llm_config=llm_config,
        embedding_config=DEFAULT_EMBEDDING_CONFIG,
        system="SYSTEM PROMPT " * 10000,  # Large system prompt
        memory_blocks=[
            CreateBlock(
                label="human",
                limit=200000,
                value="NAME " * 10000,  # Large memory block
            )
        ],
    )
    agent_state = await server.agent_manager.create_agent_async(agent_create, actor=actor)

    # Create a run for the agent using RunManager
    run = PydanticRun(agent_id=agent_state.id)
    run = await RunManager().create_run(pydantic_run=run, actor=actor)

    # Create the agent loop using LettaAgentV3
    agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)

    # message the agent
    input_message = MessageCreate(role=MessageRole.user, content="Hello")

    # Stepping the agent is expected to fail: the test asserts that the
    # oversized system prompt surfaces as SystemPromptTokenExceededError.
    with pytest.raises(SystemPromptTokenExceededError):
        await agent_loop.step(
            input_messages=[input_message],
            run_id=run.id,
            max_steps=3,
        )

    # Repair the agent by shortening the memory blocks and system prompt
    # Update system prompt to a shorter version
    short_system_prompt = "You are a helpful assistant."
    await server.agent_manager.update_agent_async(
        agent_id=agent_state.id,
        agent_update=UpdateAgent(system=short_system_prompt),
        actor=actor,
    )

    # Update memory block to a shorter version
    short_memory_value = "The user's name is Alice."
    await server.agent_manager.modify_block_by_label_async(
        agent_id=agent_state.id,
        block_label="human",
        block_update=BlockUpdate(value=short_memory_value),
        actor=actor,
    )

    # Reload agent state after repairs
    agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
    print("REPAIRED AGENT STATE ======")
    print(agent_state.system)
    print(agent_state.blocks)

    # Create a new run for the repaired agent
    run = PydanticRun(agent_id=agent_state.id)
    run = await RunManager().create_run(pydantic_run=run, actor=actor)

    # Create a new agent loop with the repaired agent state
    agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)

    # Now the agent should be able to respond without context window errors
    response = await agent_loop.step(
        input_messages=[input_message],
        run_id=run.id,
        max_steps=3,
    )

    # Verify we got a valid response after repair
    assert response is not None
    assert response.messages is not None
    print(f"Agent successfully responded after repair with {len(response.messages)} messages")
|
||||
|
||||
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor):
|
||||
# """
|
||||
@@ -1342,11 +1439,10 @@ async def test_summarize_all(server: SyncServer, actor, llm_config: LLMConfig):
|
||||
llm_config=llm_config,
|
||||
summarizer_config=summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
|
||||
# Verify the summary was generated
|
||||
assert len(new_in_context_messages) == 0
|
||||
assert len(new_in_context_messages) == 1
|
||||
assert summary is not None
|
||||
assert len(summary) > 0
|
||||
assert len(summary) <= 2000
|
||||
|
||||
Reference in New Issue
Block a user