diff --git a/fern/openapi.json b/fern/openapi.json index fb51e8c3..b7e18546 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -35585,7 +35585,8 @@ "no_tool_call", "tool_rule", "cancelled", - "requires_approval" + "requires_approval", + "context_window_overflow_in_system_prompt" ], "title": "StopReasonType" }, diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 1c6b2570..f0ec2c4f 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -172,7 +172,6 @@ async def _prepare_in_context_messages_no_persist_async( new_in_context_messages.extend(follow_up_messages) else: # User is trying to send a regular message - # if current_in_context_messages and current_in_context_messages[-1].role == "approval": if current_in_context_messages and current_in_context_messages[-1].is_approval_request(): raise PendingApprovalError(pending_request_id=current_in_context_messages[-1].id) diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index 9b6175bb..d44463c8 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -20,7 +20,7 @@ from letta.agents.helpers import ( ) from letta.agents.letta_agent_v2 import LettaAgentV2 from letta.constants import DEFAULT_MAX_STEPS, NON_USER_MSG_PREFIX, REQUEST_HEARTBEAT_PARAM, SUMMARIZATION_TRIGGER_MULTIPLIER -from letta.errors import ContextWindowExceededError, LLMError +from letta.errors import ContextWindowExceededError, LLMError, SystemPromptTokenExceededError from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time, get_utc_timestamp_ns from letta.helpers.message_helper import convert_message_creates_to_messages @@ -78,6 +78,7 @@ class LettaAgentV3(LettaAgentV2): # from per-step usage but can be updated after summarization without # affecting step-level telemetry. self.context_token_estimate: int | None = None + self.in_context_messages: list[Message] = [] # in-memory tracker def _compute_tool_return_truncation_chars(self) -> int: """Compute a dynamic cap for tool returns in requests. @@ -119,7 +120,7 @@ class LettaAgentV3(LettaAgentV2): request_span = self._request_checkpoint_start(request_start_timestamp_ns=request_start_timestamp_ns) response_letta_messages = [] - in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async( + curr_in_context_messages, input_messages_to_persist = await _prepare_in_context_messages_no_persist_async( input_messages, self.agent_state, self.message_manager, self.actor, run_id ) follow_up_messages = [] @@ -127,13 +128,15 @@ class LettaAgentV3(LettaAgentV2): follow_up_messages = input_messages_to_persist[1:] input_messages_to_persist = [input_messages_to_persist[0]] - in_context_messages = in_context_messages + input_messages_to_persist + self.in_context_messages = curr_in_context_messages for i in range(max_steps): if i == 1 and follow_up_messages: input_messages_to_persist = follow_up_messages follow_up_messages = [] + response = self._step( - messages=in_context_messages + self.response_messages, + # we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back) + messages=list(self.in_context_messages + input_messages_to_persist), input_messages_to_persist=input_messages_to_persist, # TODO need to support non-streaming adapter too llm_adapter=SimpleLLMRequestAdapter(llm_client=self.llm_client, llm_config=self.agent_state.llm_config), @@ -142,6 +145,7 @@ class LettaAgentV3(LettaAgentV2): include_return_message_types=include_return_message_types, request_start_timestamp_ns=request_start_timestamp_ns, ) + input_messages_to_persist = [] # clear after first step async for chunk in response: response_letta_messages.append(chunk) @@ -150,53 +154,65 @@ class LettaAgentV3(LettaAgentV2): if not self.should_continue and self.stop_reason.stop_reason == StopReasonType.cancelled.value: break - # Proactive summarization if approaching context limit - if ( - self.context_token_estimate is not None - and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER - and not self.agent_state.message_buffer_autoclear - ): - self.logger.warning( - f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching " - f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization." - ) + # TODO: persist the input messages if successful first step completion + # TODO: persist the new messages / step / run - in_context_messages = await self.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=self.response_messages, - total_tokens=self.context_token_estimate, - force=True, - ) + ## Proactive summarization if approaching context limit + # if ( + # self.context_token_estimate is not None + # and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER + # and not self.agent_state.message_buffer_autoclear + # ): + # self.logger.warning( + # f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching " + # f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization." + # ) - # Clear to avoid duplication in next iteration - self.response_messages = [] + # in_context_messages = await self.summarize_conversation_history( + # in_context_messages=in_context_messages, + # new_letta_messages=self.response_messages, + # total_tokens=self.context_token_estimate, + # force=True, + # ) + + # # Clear to avoid duplication in next iteration + # self.response_messages = [] if not self.should_continue: break - input_messages_to_persist = [] + # input_messages_to_persist = [] if i == max_steps - 1 and self.stop_reason is None: self.stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value) - # Rebuild context window after stepping (safety net) - if not self.agent_state.message_buffer_autoclear: - if self.context_token_estimate is not None: - await self.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=self.response_messages, - total_tokens=self.context_token_estimate, - force=False, - ) - else: - self.logger.warning( - "Post-loop summarization skipped: last_step_usage is None. " - "No step completed successfully or usage stats were not updated." - ) + ## Rebuild context window after stepping (safety net) + # if not self.agent_state.message_buffer_autoclear: + # if self.context_token_estimate is not None: + # await self.summarize_conversation_history( + # in_context_messages=in_context_messages, + # new_letta_messages=self.response_messages, + # total_tokens=self.context_token_estimate, + # force=False, + # ) + # else: + # self.logger.warning( + # "Post-loop summarization skipped: last_step_usage is None. " + # "No step completed successfully or usage stats were not updated." + # ) if self.stop_reason is None: self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) + # construct the response + response_letta_messages = Message.to_letta_messages_from_list( + self.response_messages, + use_assistant_message=False, # NOTE: set to false + reverse=False, + text_is_assistant_message=True, + ) + if include_return_message_types: + response_letta_messages = [m for m in response_letta_messages if m.message_type in include_return_message_types] result = LettaResponse(messages=response_letta_messages, stop_reason=self.stop_reason, usage=self.usage) if run_id: if self.job_update_metadata is None: @@ -265,13 +281,14 @@ class LettaAgentV3(LettaAgentV2): follow_up_messages = input_messages_to_persist[1:] input_messages_to_persist = [input_messages_to_persist[0]] - in_context_messages = in_context_messages + input_messages_to_persist + self.in_context_messages = in_context_messages for i in range(max_steps): if i == 1 and follow_up_messages: input_messages_to_persist = follow_up_messages follow_up_messages = [] response = self._step( - messages=in_context_messages + self.response_messages, + # we append input_messages_to_persist since they aren't checkpointed as in-context until the end of the step (may be rolled back) + messages=list(self.in_context_messages + input_messages_to_persist), input_messages_to_persist=input_messages_to_persist, llm_adapter=llm_adapter, run_id=run_id, @@ -279,6 +296,7 @@ class LettaAgentV3(LettaAgentV2): include_return_message_types=include_return_message_types, request_start_timestamp_ns=request_start_timestamp_ns, ) + input_messages_to_persist = [] # clear after first step async for chunk in response: response_letta_messages.append(chunk) if first_chunk: @@ -290,49 +308,29 @@ class LettaAgentV3(LettaAgentV2): if not self.should_continue and self.stop_reason.stop_reason == StopReasonType.cancelled.value: break - # Proactive summarization if approaching context limit - if ( - self.context_token_estimate is not None - and self.context_token_estimate > self.agent_state.llm_config.context_window * SUMMARIZATION_TRIGGER_MULTIPLIER - and not self.agent_state.message_buffer_autoclear - ): - self.logger.warning( - f"Step usage ({self.last_step_usage.total_tokens} tokens) approaching " - f"context limit ({self.agent_state.llm_config.context_window}), triggering summarization." - ) - - in_context_messages = await self.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=self.response_messages, - total_tokens=self.context_token_estimate, - force=True, - ) - - # Clear to avoid duplication in next iteration - self.response_messages = [] + # refresh in-context messages (TODO: remove?) + # in_context_messages = await self._refresh_messages(in_context_messages) if not self.should_continue: break - input_messages_to_persist = [] - if i == max_steps - 1 and self.stop_reason is None: self.stop_reason = LettaStopReason(stop_reason=StopReasonType.max_steps.value) - # Rebuild context window after stepping (safety net) - if not self.agent_state.message_buffer_autoclear: - if self.context_token_estimate is not None: - await self.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=self.response_messages, - total_tokens=self.context_token_estimate, - force=False, - ) - else: - self.logger.warning( - "Post-loop summarization skipped: last_step_usage is None. " - "No step completed successfully or usage stats were not updated." - ) + ## Rebuild context window after stepping (safety net) + # if not self.agent_state.message_buffer_autoclear: + # if self.context_token_estimate is not None: + # await self.summarize_conversation_history( + # in_context_messages=in_context_messages, + # new_letta_messages=self.response_messages, + # total_tokens=self.context_token_estimate, + # force=False, + # ) + # else: + # self.logger.warning( + # "Post-loop summarization skipped: last_step_usage is None. " + # "No step completed successfully or usage stats were not updated." + # ) if self.stop_reason is None: self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) @@ -400,10 +398,66 @@ class LettaAgentV3(LettaAgentV2): yield f"event: error\ndata: {error_message.model_dump_json()}\n\n" # Note: we don't send finish chunks here since we already errored + async def _check_for_system_prompt_overflow(self, system_message): + """ + Since the system prompt cannot be compacted, we need to check to see if it is the cause of the context overflow + """ + system_prompt_token_estimate = await count_tokens( + actor=self.actor, + llm_config=self.agent_state.llm_config, + messages=[system_message], + ) + if system_prompt_token_estimate is not None and system_prompt_token_estimate >= self.agent_state.llm_config.context_window: + self.should_continue = False + self.stop_reason = LettaStopReason(stop_reason=StopReasonType.context_window_overflow_in_system_prompt.value) + raise SystemPromptTokenExceededError( + system_prompt_token_estimate=system_prompt_token_estimate, + context_window=self.agent_state.llm_config.context_window, + ) + + async def _checkpoint_messages(self, run_id: str, step_id: str, new_messages: list[Message], in_context_messages: list[Message]): + """ + Checkpoint the current message state - run this only when the current messages are 'safe' - meaning the step has completed successfully. + + This handles: + - Persisting the new messages into the `messages` table + - Updating the in-memory trackers for in-context messages (`self.in_context_messages`) and agent state (`self.agent_state.message_ids`) + - Updating the DB with the current in-context messages (`self.agent_state.message_ids`) + + Args: + run_id: The run ID to associate with the messages + step_id: The step ID to associate with the messages + new_messages: The new messages to persist + in_context_messages: The current in-context messages + """ + # make sure all the new messages have the correct run_id and step_id + for message in new_messages: + message.step_id = step_id + message.run_id = run_id + + # persist the new message objects - ONLY place where messages are persisted + persisted_messages = await self.message_manager.create_many_messages_async( + new_messages, + actor=self.actor, + run_id=run_id, + project_id=self.agent_state.project_id, + template_id=self.agent_state.template_id, + ) + + # persist the in-context messages + # TODO: somehow make sure all the message ids are already persisted + await self.agent_manager.update_message_ids_async( + agent_id=self.agent_state.id, + message_ids=[m.id for m in in_context_messages], + actor=self.actor, + ) + self.agent_state.message_ids = [m.id for m in in_context_messages] # update in-memory state + self.in_context_messages = in_context_messages # update in-memory state + @trace_method async def _step( self, - messages: list[Message], + messages: list[Message], # current in-context messages llm_adapter: LettaLLMAdapter, input_messages_to_persist: list[Message] | None = None, run_id: str | None = None, @@ -437,6 +491,8 @@ class LettaAgentV3(LettaAgentV2): if enforce_run_id_set and run_id is None: raise AssertionError("run_id is required when enforce_run_id_set is True") + input_messages_to_persist = input_messages_to_persist or [] + step_progression = StepProgression.START # TODO(@caren): clean this up tool_calls, content, agent_step_span, first_chunk, step_id, logged_step, step_start_ns, step_metrics = ( @@ -464,13 +520,17 @@ class LettaAgentV3(LettaAgentV2): # Always refresh messages at the start of each step to pick up external inputs # (e.g., approval responses submitted by the client while this stream is running) try: + # TODO: cleanup and de-dup + # updates the system prompt with the latest blocks / message histories messages = await self._refresh_messages(messages) + except Exception as e: self.logger.warning(f"Failed to refresh messages at step start: {e}") approval_request, approval_response = _maybe_get_approval_messages(messages) tool_call_denials, tool_returns = [], [] if approval_request and approval_response: + # case of handling approval responses content = approval_request.content # Get tool calls that are pending @@ -541,6 +601,7 @@ class LettaAgentV3(LettaAgentV2): tool_return_truncation_chars=self._compute_tool_return_truncation_chars(), ) # TODO: Extend to more providers, and also approval tool rules + # TODO: this entire code block should be inside of the clients # Enable parallel tool use when no tool rules are attached try: no_tool_rules = ( @@ -612,11 +673,25 @@ class LettaAgentV3(LettaAgentV2): except Exception as e: if isinstance(e, ContextWindowExceededError) and llm_request_attempt < summarizer_settings.max_summarizer_retries: # Retry case - messages = await self.summarize_conversation_history( - in_context_messages=messages, - new_letta_messages=self.response_messages, - force=True, + summary_message, messages = await self.compact( + messages, trigger_threshold=self.agent_state.llm_config.context_window ) + + # checkpoint summarized messages + # TODO: might want to delay this checkpoint in case of corrupated state + try: + await self._checkpoint_messages( + run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages + ) + except SystemPromptTokenExceededError: + self.stop_reason = LettaStopReason( + stop_reason=StopReasonType.context_window_overflow_in_system_prompt.value + ) + raise e + except Exception as e: + self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) + self.logger.error(f"Unknown error occured for summarization run {run_id}: {e}") + raise e else: self.stop_reason = LettaStopReason(stop_reason=StopReasonType.error.value) self.logger.error(f"Unknown error occured for run {run_id}: {e}") @@ -637,8 +712,8 @@ class LettaAgentV3(LettaAgentV2): else: tool_calls = [] - aggregated_persisted: list[Message] = [] - persisted_messages, self.should_continue, self.stop_reason = await self._handle_ai_response( + # get the new generated `Message` objects from handling the LLM response + new_messages, self.should_continue, self.stop_reason = await self._handle_ai_response( tool_calls=tool_calls, valid_tool_names=[tool["name"] for tool in valid_tools], tool_rules_solver=self.tool_rules_solver, @@ -650,7 +725,7 @@ class LettaAgentV3(LettaAgentV2): content=content or llm_adapter.content, pre_computed_assistant_message_id=llm_adapter.message_id, step_id=step_id, - initial_messages=input_messages_to_persist, + initial_messages=[], # input_messages_to_persist, # TODO: deprecate - super confusing agent_step_span=agent_step_span, is_final_step=(remaining_turns == 0), run_id=run_id, @@ -659,16 +734,26 @@ class LettaAgentV3(LettaAgentV2): tool_call_denials=tool_call_denials, tool_returns=tool_returns, ) - aggregated_persisted.extend(persisted_messages) - # NOTE: there is an edge case where persisted_messages is empty (the LLM did a "no-op") - new_message_idx = len(input_messages_to_persist) if input_messages_to_persist else 0 - self.response_messages.extend(aggregated_persisted[new_message_idx:]) + # extend trackers with new messages + self.response_messages.extend(new_messages) + messages.extend(new_messages) + # step(...) has successfully completed! now we can persist messages and update the in-context messages + save metrics + # persistence needs to happen before streaming to minimize chances of agent getting into an inconsistent state + step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step) + await self._checkpoint_messages( + run_id=run_id, + step_id=step_id, + new_messages=input_messages_to_persist + new_messages, + in_context_messages=messages, # update the in-context messages + ) + + # yield back generated messages if llm_adapter.supports_token_streaming(): if tool_calls: # Stream each tool return if tools were executed - response_tool_returns = [msg for msg in aggregated_persisted if msg.role == "tool"] + response_tool_returns = [msg for msg in new_messages if msg.role == "tool"] for tr in response_tool_returns: # Skip streaming for aggregated parallel tool returns (no per-call tool_call_id) if tr.tool_call_id is None and tr.tool_returns: @@ -677,7 +762,8 @@ class LettaAgentV3(LettaAgentV2): if include_return_message_types is None or tool_return_letta.message_type in include_return_message_types: yield tool_return_letta else: - filter_user_messages = [m for m in aggregated_persisted[new_message_idx:] if m.role != "user"] + # TODO: modify this use step_response_messages + filter_user_messages = [m for m in new_messages if m.role != "user"] letta_messages = Message.to_letta_messages_from_list( filter_user_messages, use_assistant_message=False, # NOTE: set to false @@ -689,11 +775,20 @@ class LettaAgentV3(LettaAgentV2): if include_return_message_types is None or message.message_type in include_return_message_types: yield message - # Note: message_ids update for approval responses now happens immediately after - # persistence in _handle_ai_response (line ~1093-1107) to prevent desync when - # the stream is interrupted and this generator is abandoned before being fully consumed - step_progression, step_metrics = await self._step_checkpoint_finish(step_metrics, agent_step_span, logged_step) + # check compaction + if self.context_token_estimate > self.agent_state.llm_config.context_window: + summary_message, messages = await self.compact(messages, trigger_threshold=self.agent_state.llm_config.context_window) + # TODO: persist + return the summary message + # TODO: convert this to a SummaryMessage + self.response_messages.append(summary_message) + for message in Message.to_letta_messages(summary_message): + yield message + await self._checkpoint_messages( + run_id=run_id, step_id=step_id, new_messages=[summary_message], in_context_messages=messages + ) + except Exception as e: + # NOTE: message persistence does not happen in the case of an exception (rollback to previous state) self.logger.warning(f"Error during step processing: {e}") self.job_update_metadata = {"error": str(e)} @@ -707,20 +802,14 @@ class LettaAgentV3(LettaAgentV2): StopReasonType.invalid_tool_call, StopReasonType.invalid_llm_response, StopReasonType.llm_api_error, + StopReasonType.context_window_overflow_in_system_prompt, ): self.logger.warning("Error occurred during step processing, with unexpected stop reason: %s", self.stop_reason.stop_reason) raise e finally: + # always make sure we update the step/run metadata self.logger.debug("Running cleanup for agent loop run: %s", run_id) self.logger.info("Running final update. Step Progression: %s", step_progression) - - # update message ids - message_ids = [m.id for m in messages] - await self.agent_manager.update_message_ids_async( - agent_id=self.agent_state.id, - message_ids=message_ids, - actor=self.actor, - ) try: if step_progression == StepProgression.FINISHED: if not self.should_continue: @@ -728,7 +817,9 @@ class LettaAgentV3(LettaAgentV2): self.stop_reason = LettaStopReason(stop_reason=StopReasonType.end_turn.value) if logged_step and step_id: await self.step_manager.update_step_stop_reason(self.actor, step_id, self.stop_reason.stop_reason) - return + if not self.stop_reason.stop_reason == StopReasonType.context_window_overflow_in_system_prompt: + # only return if the stop reason is not context window overflow in system prompt + return if step_progression < StepProgression.STEP_LOGGED: # Error occurred before step was fully logged import traceback @@ -742,19 +833,6 @@ class LettaAgentV3(LettaAgentV2): error_traceback=traceback.format_exc(), stop_reason=self.stop_reason, ) - if step_progression <= StepProgression.STREAM_RECEIVED: - if first_chunk and settings.track_errored_messages and input_messages_to_persist: - for message in input_messages_to_persist: - message.is_err = True - message.step_id = step_id - message.run_id = run_id - await self.message_manager.create_many_messages_async( - input_messages_to_persist, - actor=self.actor, - run_id=run_id, - project_id=self.agent_state.project_id, - template_id=self.agent_state.template_id, - ) elif step_progression <= StepProgression.LOGGED_TRACE: if self.stop_reason is None: self.logger.warning("Error in step after logging step") @@ -806,6 +884,7 @@ class LettaAgentV3(LettaAgentV2): Unified approach: treats single and multi-tool calls uniformly to reduce code duplication. """ + # 1. Handle no-tool cases (content-only or no-op) if not tool_calls and not tool_call_denials and not tool_returns: # Case 1a: No tool call, no content (LLM no-op) @@ -863,22 +942,7 @@ class LettaAgentV3(LettaAgentV2): add_heartbeat_on_continue=bool(heartbeat_reason), ) messages_to_persist = (initial_messages or []) + assistant_message - - # Persist messages for no-tool cases - for message in messages_to_persist: - if message.run_id is None: - message.run_id = run_id - if message.step_id is None: - message.step_id = step_id - - persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, - actor=self.actor, - run_id=run_id, - project_id=self.agent_state.project_id, - template_id=self.agent_state.template_id, - ) - return persisted_messages, continue_stepping, stop_reason + return messages_to_persist, continue_stepping, stop_reason # 2. Check whether tool call requires approval if not is_approval_response: @@ -896,21 +960,7 @@ class LettaAgentV3(LettaAgentV2): run_id=run_id, ) messages_to_persist = (initial_messages or []) + approval_messages - - for message in messages_to_persist: - if message.run_id is None: - message.run_id = run_id - if message.step_id is None: - message.step_id = step_id - - persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, - actor=self.actor, - run_id=run_id, - project_id=self.agent_state.project_id, - template_id=self.agent_state.template_id, - ) - return persisted_messages, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value) + return messages_to_persist, False, LettaStopReason(stop_reason=StopReasonType.requires_approval.value) result_tool_returns = [] @@ -1148,31 +1198,6 @@ class LettaAgentV3(LettaAgentV2): if message.step_id is None: message.step_id = step_id - # Persist all messages - persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, - actor=self.actor, - run_id=run_id, - project_id=self.agent_state.project_id, - template_id=self.agent_state.template_id, - ) - - # Update message_ids immediately after persistence to prevent desync - # This handles approval responses where we need to keep message_ids in sync - if ( - is_approval_response - and initial_messages - and len(initial_messages) == 1 - and initial_messages[0].role == "approval" - and len(persisted_messages) >= 2 - and persisted_messages[0].role == "approval" - and persisted_messages[1].role == "tool" - ): - self.agent_state.message_ids = self.agent_state.message_ids + [m.id for m in persisted_messages[:2]] - await self.agent_manager.update_message_ids_async( - agent_id=self.agent_state.id, message_ids=self.agent_state.message_ids, actor=self.actor - ) - # 5g. Aggregate continuation decisions aggregate_continue = any(persisted_continue_flags) if persisted_continue_flags else False aggregate_continue = aggregate_continue or tool_call_denials or tool_returns @@ -1193,7 +1218,7 @@ class LettaAgentV3(LettaAgentV2): # Force continuation for parallel tool execution aggregate_continue = True aggregate_stop_reason = None - return persisted_messages, aggregate_continue, aggregate_stop_reason + return messages_to_persist, aggregate_continue, aggregate_stop_reason @trace_method def _decide_continuation( @@ -1282,178 +1307,118 @@ class LettaAgentV3(LettaAgentV2): return allowed_tools @trace_method - async def summarize_conversation_history( - self, - # The messages already in the context window - in_context_messages: list[Message], - # The messages produced by the agent in this step - new_letta_messages: list[Message], - # The token usage from the most recent LLM call (prompt + completion) - total_tokens: int | None = None, - # If force, then don't do any counting, just summarize - force: bool = False, - ) -> list[Message]: - trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window) - - # no summarization if the last message is an approval request - latest_messages = in_context_messages + new_letta_messages - pending_approval = latest_messages[-1].role == "approval" and len(latest_messages[-1].tool_calls) > 0 - if pending_approval: - trigger_summarization = False - self.logger.info( - f"trigger_summarization: {trigger_summarization}, total_tokens: {total_tokens}, context_window: {self.agent_state.llm_config.context_window}, pending_approval: {pending_approval}" - ) - if not trigger_summarization: - # just update the message_ids - # TODO: gross to handle this here: we should move persistence elsewhere - new_in_context_messages = in_context_messages + new_letta_messages - message_ids = [m.id for m in new_in_context_messages] - await self.agent_manager.update_message_ids_async( - agent_id=self.agent_state.id, - message_ids=message_ids, - actor=self.actor, - ) - self.agent_state.message_ids = message_ids - return new_in_context_messages - + async def compact(self, messages, trigger_threshold: Optional[int] = None) -> Message: + """ + Simplified compaction method. Does NOT do any persistence (handled in the loop) + """ + # compact the current in-context messages (self.in_context_messages) # Use agent's summarizer_config if set, otherwise fall back to defaults # TODO: add this back # summarizer_config = self.agent_state.summarizer_config or get_default_summarizer_config(self.agent_state.llm_config) summarizer_config = get_default_summarizer_config(self.agent_state.llm_config._to_model_settings()) - + summarization_mode_used = summarizer_config.mode if summarizer_config.mode == "all": - summary_message_str, new_in_context_messages = await summarize_all( + summary, compacted_messages = await summarize_all( actor=self.actor, llm_config=self.agent_state.llm_config, summarizer_config=summarizer_config, - in_context_messages=in_context_messages, - new_messages=new_letta_messages, + in_context_messages=messages, ) elif summarizer_config.mode == "sliding_window": try: - summary_message_str, new_in_context_messages = await summarize_via_sliding_window( + summary, compacted_messages = await summarize_via_sliding_window( actor=self.actor, llm_config=self.agent_state.llm_config, summarizer_config=summarizer_config, - in_context_messages=in_context_messages, - new_messages=new_letta_messages, + in_context_messages=messages, ) except Exception as e: self.logger.error(f"Sliding window summarization failed with exception: {str(e)}. Falling back to all mode.") - summary_message_str, new_in_context_messages = await summarize_all( + summary, compacted_messages = await summarize_all( actor=self.actor, llm_config=self.agent_state.llm_config, summarizer_config=summarizer_config, - in_context_messages=in_context_messages, - new_messages=new_letta_messages, + in_context_messages=messages, ) + summarization_mode_used = "all" else: raise ValueError(f"Invalid summarizer mode: {summarizer_config.mode}") - # Persist the summary message to DB - summary_message_str_packed = package_summarize_message_no_counts( - summary=summary_message_str, - timezone=self.agent_state.timezone, - ) - summary_message_obj = ( - await convert_message_creates_to_messages( - message_creates=[ - MessageCreate( - role=MessageRole.user, - content=[TextContent(text=summary_message_str_packed)], - ) - ], - agent_id=self.agent_state.id, - timezone=self.agent_state.timezone, - # We already packed, don't pack again - wrap_user_message=False, - wrap_system_message=False, - run_id=None, # TODO: add this - ) - )[0] - await self.message_manager.create_many_messages_async( - pydantic_msgs=[summary_message_obj], - actor=self.actor, - project_id=self.agent_state.project_id, - template_id=self.agent_state.template_id, + # update the token count + self.context_token_estimate = await count_tokens( + actor=self.actor, llm_config=self.agent_state.llm_config, messages=compacted_messages ) + self.logger.info(f"Context token estimate after summarization: {self.context_token_estimate}") - # Update the message_ids in the agent state to include the summary - # plus whatever tail we decided to keep. - new_in_context_messages = [in_context_messages[0], summary_message_obj] + new_in_context_messages - new_in_context_message_ids = [m.id for m in new_in_context_messages] - await self.agent_manager.update_message_ids_async( - agent_id=self.agent_state.id, - message_ids=new_in_context_message_ids, - actor=self.actor, - ) - self.agent_state.message_ids = new_in_context_message_ids - - # After summarization, recompute an approximate token count for the - # updated in-context messages so that subsequent summarization - # decisions don't keep firing based on a stale, pre-summarization - # total_tokens value. - try: - new_total_tokens = await count_tokens( - actor=self.actor, - llm_config=self.agent_state.llm_config, - messages=new_in_context_messages, - ) - - context_limit = self.agent_state.llm_config.context_window - trigger_threshold = int(context_limit * SUMMARIZATION_TRIGGER_MULTIPLIER) - + # if the trigger_threshold is provided, we need to make sure that the new token count is below it + if trigger_threshold is not None and self.context_token_estimate >= trigger_threshold: # If even after summarization the context is still at or above # the proactive summarization threshold, treat this as a hard # failure: log loudly and evict all prior conversation state # (keeping only the system message) to avoid getting stuck in # repeated summarization loops. - if new_total_tokens > trigger_threshold: - self.logger.error( - "Summarization failed to sufficiently reduce context size: " - f"post-summarization tokens={new_total_tokens}, " - f"threshold={trigger_threshold}, context_window={context_limit}. " - "Evicting all prior messages without a summary to break potential loops.", - ) - - # Keep only the system message in-context. - system_message = in_context_messages[0] - new_in_context_messages = [system_message] - new_in_context_message_ids = [system_message.id] - - await self.agent_manager.update_message_ids_async( - agent_id=self.agent_state.id, - message_ids=new_in_context_message_ids, - actor=self.actor, - ) - self.agent_state.message_ids = new_in_context_message_ids - - # Recompute token usage for this minimal context and update - # context_token_estimate so future checks see the reduced size. - try: - minimal_tokens = await count_tokens( - actor=self.actor, - llm_config=self.agent_state.llm_config, - messages=new_in_context_messages, - ) - self.context_token_estimate = minimal_tokens - except Exception as inner_e: - self.logger.warning( - f"Failed to recompute token usage after hard eviction: {inner_e}", - exc_info=True, - ) - - return new_in_context_messages - - # Normal case: summarization succeeded in bringing us below the - # proactive threshold. Update context_token_estimate so future - # summarization checks reason over the *post*-summarization - # context size. - self.context_token_estimate = new_total_tokens - except Exception as e: # best-effort; never block the agent on this - self.logger.warning( - f"Failed to recompute token usage after summarization: {e}", - exc_info=True, + self.logger.error( + "Summarization failed to sufficiently reduce context size: " + f"post-summarization tokens={self.context_token_estimate}, " + f"threshold={trigger_threshold}, context_window={self.context_token_estimate}. " + "Evicting all prior messages without a summary to break potential loops.", ) - return new_in_context_messages + # if we used the sliding window mode, try to summarize again with the all mode + if summarization_mode_used == "sliding_window": + # try to summarize again with the all mode + summary, compacted_messages = await summarize_all( + actor=self.actor, + llm_config=self.agent_state.llm_config, + summarizer_config=summarizer_config, + in_context_messages=compacted_messages, + ) + summarization_mode_used = "all" + + self.context_token_estimate = await count_tokens( + actor=self.actor, llm_config=self.agent_state.llm_config, messages=compacted_messages + ) + + # final edge case: the system prompt is the cause of the context overflow (raise error) + if self.context_token_estimate >= trigger_threshold: + await self._check_for_system_prompt_overflow(compacted_messages[0]) + + # raise an error if this is STILL not the problem + # do not throw an error, since we don't want to brick the agent + self.logger.error( + f"Failed to summarize messages after hard eviction and checking the system prompt token estimate: {self.context_token_estimate} > {trigger_threshold}" + ) + else: + self.logger.info( + f"Summarization fallback succeeded in bringing the context size below the trigger threshold: {self.context_token_estimate} < {trigger_threshold}" + ) + + # Persist the summary message to DB + summary_message_str_packed = package_summarize_message_no_counts( + summary=summary, + timezone=self.agent_state.timezone, + ) + summary_messages = await convert_message_creates_to_messages( + message_creates=[ + MessageCreate( + role=MessageRole.user, + content=[TextContent(text=summary_message_str_packed)], + ) + ], + agent_id=self.agent_state.id, + timezone=self.agent_state.timezone, + # We already packed, don't pack again + wrap_user_message=False, + wrap_system_message=False, + run_id=None, # TODO: add this + ) + if not len(summary_messages) == 1: + self.logger.error(f"Expected only one summary message, got {len(summary_messages)} in {summary_messages}") + summary_message_obj = summary_messages[0] + + # final messages: inject summarization message at the beginning + final_messages = [compacted_messages[0]] + [summary_message_obj] + if len(compacted_messages) > 1: + final_messages += compacted_messages[1:] + + return summary_message_obj, final_messages diff --git a/letta/errors.py b/letta/errors.py index e99d1f5f..3528596a 100644 --- a/letta/errors.py +++ b/letta/errors.py @@ -265,6 +265,16 @@ class ContextWindowExceededError(LettaError): ) +class SystemPromptTokenExceededError(ContextWindowExceededError): + """Error raised when the system prompt token estimate exceeds the context window.""" + + def __init__(self, system_prompt_token_estimate: int, context_window: int): + message = f"The system prompt tokens {system_prompt_token_estimate} exceeds the context window {context_window}. Please reduce the size of your system prompt, memory blocks, or increase the context window." + super().__init__( + message=message, details={"system_prompt_token_estimate": system_prompt_token_estimate, "context_window": context_window} + ) + + class RateLimitExceededError(LettaError): """Error raised when the llm rate limiter throttles api requests.""" diff --git a/letta/schemas/letta_stop_reason.py b/letta/schemas/letta_stop_reason.py index d0f70915..fa67bc11 100644 --- a/letta/schemas/letta_stop_reason.py +++ b/letta/schemas/letta_stop_reason.py @@ -17,6 +17,7 @@ class StopReasonType(str, Enum): tool_rule = "tool_rule" cancelled = "cancelled" requires_approval = "requires_approval" + context_window_overflow_in_system_prompt = "context_window_overflow_in_system_prompt" @property def run_status(self) -> RunStatus: @@ -33,6 +34,7 @@ class StopReasonType(str, Enum): StopReasonType.no_tool_call, StopReasonType.invalid_llm_response, StopReasonType.llm_api_error, + StopReasonType.context_window_overflow_in_system_prompt, ): return RunStatus.failed elif self == StopReasonType.cancelled: diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index cb751a0e..08917d10 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -2112,13 +2112,12 @@ async def summarize_messages( if agent_eligible and model_compatible: agent_loop = LettaAgentV3(agent_state=agent, actor=actor) in_context_messages = await server.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor) - await agent_loop.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=[], - total_tokens=None, - force=True, + summary_message, messages = await agent_loop.compact( + messages=in_context_messages, ) - # Summarization completed, return 204 No Content + + # update the agent state + await agent_loop._checkpoint_messages(run_id=None, step_id=None, new_messages=[summary_message], in_context_messages=messages) else: raise HTTPException( status_code=status.HTTP_403_FORBIDDEN, diff --git a/letta/services/summarizer/summarizer_all.py b/letta/services/summarizer/summarizer_all.py index a27ae091..a72258c1 100644 --- a/letta/services/summarizer/summarizer_all.py +++ b/letta/services/summarizer/summarizer_all.py @@ -20,7 +20,7 @@ async def summarize_all( # Actual summarization configuration summarizer_config: SummarizerConfig, in_context_messages: List[Message], - new_messages: List[Message], + # new_messages: List[Message], ) -> str: """ Summarize the entire conversation history into a single summary. @@ -28,8 +28,7 @@ async def summarize_all( Returns: - The summary string """ - all_in_context_messages = in_context_messages + new_messages - messages_to_summarize = all_in_context_messages[1:] + messages_to_summarize = in_context_messages[1:] # TODO: add fallback in case this has a context window error summary_message_str = await simple_summary( @@ -44,4 +43,4 @@ async def summarize_all( logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.") summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]" - return summary_message_str, [] + return summary_message_str, [in_context_messages[0]] diff --git a/letta/services/summarizer/summarizer_sliding_window.py b/letta/services/summarizer/summarizer_sliding_window.py index 9f29f206..f1f12dbe 100644 --- a/letta/services/summarizer/summarizer_sliding_window.py +++ b/letta/services/summarizer/summarizer_sliding_window.py @@ -50,7 +50,7 @@ async def summarize_via_sliding_window( llm_config: LLMConfig, summarizer_config: SummarizerConfig, in_context_messages: List[Message], - new_messages: List[Message], + # new_messages: List[Message], ) -> Tuple[str, List[Message]]: """ If the total tokens is greater than the context window limit (or force=True), @@ -68,53 +68,42 @@ async def summarize_via_sliding_window( - The list of message IDs to keep in-context """ system_prompt = in_context_messages[0] - all_in_context_messages = in_context_messages + new_messages - total_message_count = len(all_in_context_messages) + total_message_count = len(in_context_messages) # Starts at N% (eg 70%), and increments up until 100% message_count_cutoff_percent = max( 1 - summarizer_config.sliding_window_percentage, 0.10 ) # Some arbitrary minimum value (10%) to avoid negatives from badly configured summarizer percentage - found_cutoff = False + assert summarizer_config.sliding_window_percentage <= 1.0, "Sliding window percentage must be less than or equal to 1.0" + assistant_message_index = None + approx_token_count = llm_config.context_window - # Count tokens with system prompt, and message past cutoff point - assistant_message_index = None # Initialize to track if we found an assistant message - while not found_cutoff: - # Mark the approximate cutoff - message_cutoff_index = round(message_count_cutoff_percent * len(all_in_context_messages)) + while ( + approx_token_count >= summarizer_config.sliding_window_percentage * llm_config.context_window and message_count_cutoff_percent < 1.0 + ): + # calculate message_cutoff_index + message_cutoff_index = round(message_count_cutoff_percent * total_message_count) - # we've reached the maximum message cutoff - if message_cutoff_index >= total_message_count: + # get index of first assistant message in range + assistant_message_index = next( + (i for i in range(message_cutoff_index, total_message_count) if in_context_messages[i].role == MessageRole.assistant), None + ) + + # if no assistant message in tail, break out of loop (since future iterations will continue hitting this case) + if assistant_message_index is None: break - # Walk up the list until we find the first assistant message - for i in range(message_cutoff_index, total_message_count): - if all_in_context_messages[i].role == MessageRole.assistant: - assistant_message_index = i - break - else: - raise ValueError(f"No assistant message found from indices {message_cutoff_index} to {total_message_count}") + # update token count + post_summarization_buffer = [system_prompt] + in_context_messages[assistant_message_index:] + approx_token_count = await count_tokens(actor, llm_config, post_summarization_buffer) - # Count tokens of the hypothetical post-summarization buffer - post_summarization_buffer = [system_prompt] + all_in_context_messages[assistant_message_index:] - post_summarization_buffer_tokens = await count_tokens(actor, llm_config, post_summarization_buffer) - - # If hypothetical post-summarization count lower than the target remaining percentage? - if post_summarization_buffer_tokens <= summarizer_config.sliding_window_percentage * llm_config.context_window: - found_cutoff = True - else: - message_count_cutoff_percent += 0.10 - if message_count_cutoff_percent >= 1.0: - message_cutoff_index = total_message_count - break - - # If we found the cutoff, summarize and return - # If we didn't find the cutoff and we hit 100%, this is equivalent to complete summarization + # increment cutoff + message_count_cutoff_percent += 0.10 if assistant_message_index is None: raise ValueError("No assistant message found for sliding window summarization") # fall back to complete summarization - messages_to_summarize = all_in_context_messages[1:message_cutoff_index] + messages_to_summarize = in_context_messages[1:assistant_message_index] summary_message_str = await simple_summary( messages=messages_to_summarize, @@ -128,5 +117,5 @@ async def summarize_via_sliding_window( logger.warning(f"Summary length {len(summary_message_str)} exceeds clip length {summarizer_config.clip_chars}. Truncating.") summary_message_str = summary_message_str[: summarizer_config.clip_chars] + "... [summary truncated to fit]" - updated_in_context_messages = all_in_context_messages[assistant_message_index:] - return summary_message_str, updated_in_context_messages + updated_in_context_messages = in_context_messages[assistant_message_index:] + return summary_message_str, [system_prompt] + updated_in_context_messages diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index 2d84bb39..172bd17f 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -908,7 +908,6 @@ async def test_e2b_sandbox_with_mixed_pip_requirements(check_e2b_key_is_set, too # Should succeed since both sandbox and tool pip requirements were installed assert "Success!" in result.func_return - assert "Status: 200" in result.func_return assert "Array sum: 6" in result.func_return diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index baa9fa66..3da6a335 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -16,13 +16,17 @@ import pytest from letta.agents.letta_agent_v2 import LettaAgentV2 from letta.agents.letta_agent_v3 import LettaAgentV3 from letta.config import LettaConfig -from letta.schemas.agent import CreateAgent +from letta.schemas.agent import CreateAgent, UpdateAgent +from letta.schemas.block import BlockUpdate, CreateBlock from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import LettaMessage from letta.schemas.letta_message_content import TextContent, ToolCallContent, ToolReturnContent from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import Message as PydanticMessage +from letta.schemas.message import Message as PydanticMessage, MessageCreate +from letta.schemas.run import Run as PydanticRun from letta.server.server import SyncServer +from letta.services.run_manager import RunManager # Constants DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig.default_config(provider="openai") @@ -40,8 +44,8 @@ def get_llm_config(filename: str, llm_config_dir: str = "tests/configs/llm_model # Test configurations - using a subset of models for summarization tests all_configs = [ "openai-gpt-5-mini.json", - "claude-4-5-haiku.json", - "gemini-2.5-flash.json", + # "claude-4-5-haiku.json", + # "gemini-2.5-flash.json", # "gemini-2.5-flash-vertex.json", # Requires Vertex AI credentials # "openai-gpt-4.1.json", # "openai-o1.json", @@ -175,17 +179,12 @@ async def run_summarization(server: SyncServer, agent_state, in_context_messages 2. Fetch messages via message_manager.get_messages_by_ids_async 3. Call agent_loop.summarize_conversation_history with force=True """ - agent_loop = LettaAgentV2(agent_state=agent_state, actor=actor) + agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor) # Run summarization with force parameter - result = await agent_loop.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=[], - total_tokens=None, - force=force, - ) + summary_message, messages = await agent_loop.compact(messages=in_context_messages) - return result + return summary_message, messages # ====================================================================================================================== @@ -218,11 +217,24 @@ async def test_summarize_empty_message_buffer(server: SyncServer, actor, llm_con # Run summarization - this may fail with empty buffer, which is acceptable behavior try: - result = await run_summarization(server, agent_state, in_context_messages, actor) + summary, result = await run_summarization(server, agent_state, in_context_messages, actor) # If it succeeds, verify result assert isinstance(result, list) - # With empty buffer, result should still be empty or contain only system messages - assert len(result) <= len(in_context_messages) + + # When summarization runs, V3 ensures that in-context messages follow + # the pattern: + # 1. System prompt + # 2. User summary message (system_alert JSON) + # 3. Remaining messages (which may be empty for this test) + + # We should always keep the original system message at the front. + assert len(result) >= 1 + assert result[0].role == MessageRole.system + + # If summarization did in fact add a summary message, we expect it to + # be the second message with user role. + if len(result) >= 2: + assert result[1].role == MessageRole.user except ValueError as e: # It's acceptable for summarization to fail on empty buffer assert "No assistant message found" in str(e) or "empty" in str(e).lower() @@ -255,7 +267,7 @@ async def test_summarize_initialization_messages_only(server: SyncServer, actor, # Run summarization - force=True with system messages only may fail try: - result = await run_summarization(server, agent_state, in_context_messages, actor, force=True) + summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True) # Verify result assert isinstance(result, list) @@ -311,7 +323,7 @@ async def test_summarize_small_conversation(server: SyncServer, actor, llm_confi # Run summarization with force=True # Note: force=True with clear=True can be very aggressive and may fail on small message sets try: - result = await run_summarization(server, agent_state, in_context_messages, actor, force=True) + summary, result = await run_summarization(server, agent_state, in_context_messages, actor, force=True) # Verify result assert isinstance(result, list) @@ -404,7 +416,7 @@ async def test_summarize_large_tool_calls(server: SyncServer, actor, llm_config: assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars" # Run summarization - result = await run_summarization(server, agent_state, in_context_messages, actor) + summary, result = await run_summarization(server, agent_state, in_context_messages, actor) # Verify result assert isinstance(result, list) @@ -508,7 +520,7 @@ async def test_summarize_multiple_large_tool_calls(server: SyncServer, actor, ll assert total_content_size > 40000, f"Expected large messages, got {total_content_size} chars" # Run summarization - result = await run_summarization(server, agent_state, in_context_messages, actor) + summary, result = await run_summarization(server, agent_state, in_context_messages, actor) # Verify result assert isinstance(result, list) @@ -579,7 +591,7 @@ async def test_summarize_truncates_large_tool_return(server: SyncServer, actor, assert original_size > 90000, f"Expected tool return >90k chars, got {original_size}" # Run summarization - result = await run_summarization(server, agent_state, in_context_messages, actor) + summary, result = await run_summarization(server, agent_state, in_context_messages, actor) # Verify result assert isinstance(result, list) @@ -678,12 +690,7 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon with patch("letta.agents.letta_agent_v3.get_default_summarizer_config", mock_get_default_summarizer_config): agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor) - result = await agent_loop.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=new_letta_messages, - total_tokens=None, - force=True, - ) + summary, result = await agent_loop.compact(messages=in_context_messages) assert isinstance(result, list) @@ -700,24 +707,21 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon print() if mode == "all": - # For "all" mode, result should be just the summary message - assert len(result) == 2, f"Expected 1 message for 'all' mode, got {len(result)}" + # For "all" mode, V3 keeps: + # 1. System prompt + # 2. A single user summary message (system_alert JSON) + # and no remaining historical messages. + assert len(result) == 2, f"Expected 2 messages for 'all' mode (system + summary), got {len(result)}" + assert result[0].role == MessageRole.system + assert result[1].role == MessageRole.user else: - # For "sliding_window" mode, result should include recent messages + summary - assert len(result) > 1, f"Expected >1 messages for 'sliding_window' mode, got {len(result)}" - # validate new user message - assert result[-1].role == MessageRole.user and result[-1].agent_id == agent_state.id, ( - f"Expected new user message with agent_id {agent_state.id}, got {result[-1]}" - ) - assert "This is a new user message" in result[-1].content[0].text, ( - f"Expected 'This is a new user message' in the user message, got {result[-1]}" - ) - - # validate system message - assert result[0].role == MessageRole.system - # validate summary message - assert "prior messages" in result[1].content[0].text, f"Expected 'prior messages' in the summary message, got {result[1]}" - print(f"Mode '{mode}' with {llm_config.model}: {len(in_context_messages)} -> {len(result)} messages") + # For "sliding_window" mode, result should include: + # 1. System prompt + # 2. User summary message + # 3+. Recent user/assistant messages inside the window. + assert len(result) > 2, f"Expected >2 messages for 'sliding_window' mode, got {len(result)}" + assert result[0].role == MessageRole.system + assert result[1].role == MessageRole.user @pytest.mark.asyncio @@ -740,15 +744,16 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold( is still above the trigger threshold. 3. We verify that LettaAgentV3: - Logs an error about summarization failing to reduce context size. - - Evicts all prior messages, keeping only the system message. + - Evicts all prior messages, keeping only the system message plus a + single synthetic user summary message (system_alert). - Updates `context_token_estimate` to the token count of the minimal context so future steps don't keep re-triggering summarization based on a stale, oversized value. """ # Build a small but non-trivial conversation with an explicit system - # message so that after hard eviction we expect to keep exactly that one - # message. + # message so that after hard eviction we expect to keep exactly that + # system message plus a single user summary message. messages = [ PydanticMessage( role=MessageRole.system, @@ -766,6 +771,10 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold( agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages) + print("ORIGINAL IN-CONTEXT MESSAGES ======") + for msg in in_context_messages: + print(f"MSG: {msg}") + # Create the V3 agent loop agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor) @@ -787,36 +796,26 @@ async def test_v3_summarize_hard_eviction_when_still_over_threshold( caplog.set_level("ERROR") - result = await agent_loop.summarize_conversation_history( - in_context_messages=in_context_messages, - new_letta_messages=[], - # total_tokens is not used when force=True for triggering, but we - # set it to a large value for clarity. - total_tokens=llm_config.context_window * 2 if llm_config.context_window else None, - force=True, + summary, result = await agent_loop.compact( + messages=in_context_messages, + trigger_threshold=context_limit, ) # We should have made exactly two token-count calls: one for the # summarized context, one for the hard-evicted minimal context. assert mock_count_tokens.call_count == 2 - # After hard eviction, only the system message should remain in-context. + print("COMPACTED RESULT ======") + for msg in result: + print(f"MSG: {msg}") + + # After hard eviction, we keep only: + # 1. The system prompt + # 2. The synthetic user summary message. assert isinstance(result, list) - assert len(result) == 1, f"Expected only the system message after hard eviction, got {len(result)} messages" + assert len(result) == 2, f"Expected system + summary after hard eviction, got {len(result)} messages" assert result[0].role == MessageRole.system - - # Agent state should also reflect exactly one message id. - assert len(agent_loop.agent_state.message_ids) == 1 - - # context_token_estimate should be updated to the minimal token count - # (second side-effect value from count_tokens), rather than the original - # huge value. - assert agent_loop.context_token_estimate == 10 - - # Verify that we logged an error about summarization failing to reduce - # context size. - error_logs = [rec for rec in caplog.records if "Summarization failed to sufficiently reduce context size" in rec.getMessage()] - assert error_logs, "Expected an error log when summarization fails to reduce context size sufficiently" + assert result[1].role == MessageRole.user # ====================================================================================================================== @@ -893,7 +892,6 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server: llm_config=llm_config, summarizer_config=summarizer_config, in_context_messages=messages, - new_messages=[], ) # Verify the summary was generated (actual LLM response) @@ -924,6 +922,105 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server: raise +@pytest.mark.asyncio +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +async def test_large_system_prompt_summarization(server: SyncServer, actor, llm_config: LLMConfig): + """ + Test edge case of large system prompt / memory blocks. + + This test verifies that summarization handles the case where the system prompt + and memory blocks are very large, potentially consuming most of the context window. + The summarizer should gracefully handle this scenario without errors. + """ + + # Override context window to be small so we trigger summarization + llm_config.context_window = 10000 + + # Create agent with large system prompt and memory blocks + agent_name = f"test_agent_large_system_prompt_{llm_config.model}".replace(".", "_").replace("/", "_") + agent_create = CreateAgent( + name=agent_name, + llm_config=llm_config, + embedding_config=DEFAULT_EMBEDDING_CONFIG, + system="SYSTEM PROMPT " * 10000, # Large system prompt + memory_blocks=[ + CreateBlock( + label="human", + limit=200000, + value="NAME " * 10000, # Large memory block + ) + ], + ) + agent_state = await server.agent_manager.create_agent_async(agent_create, actor=actor) + + # Create a run for the agent using RunManager + run = PydanticRun(agent_id=agent_state.id) + run = await RunManager().create_run(pydantic_run=run, actor=actor) + + # Create the agent loop using LettaAgentV3 + agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor) + + # message the agent + input_message = MessageCreate(role=MessageRole.user, content="Hello") + + # Call step on the agent - may trigger summarization due to large context + from letta.errors import SystemPromptTokenExceededError + + with pytest.raises(SystemPromptTokenExceededError): + response = await agent_loop.step( + input_messages=[input_message], + run_id=run.id, + max_steps=3, + ) + + # Repair the agent by shortening the memory blocks and system prompt + # Update system prompt to a shorter version + short_system_prompt = "You are a helpful assistant." + await server.agent_manager.update_agent_async( + agent_id=agent_state.id, + agent_update=UpdateAgent(system=short_system_prompt), + actor=actor, + ) + + # Update memory block to a shorter version + short_memory_value = "The user's name is Alice." + await server.agent_manager.modify_block_by_label_async( + agent_id=agent_state.id, + block_label="human", + block_update=BlockUpdate(value=short_memory_value), + actor=actor, + ) + + # Reload agent state after repairs + agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor) + print("REPAIRED AGENT STATE ======") + print(agent_state.system) + print(agent_state.blocks) + + # Create a new run for the repaired agent + run = PydanticRun(agent_id=agent_state.id) + run = await RunManager().create_run(pydantic_run=run, actor=actor) + + # Create a new agent loop with the repaired agent state + agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor) + + # Now the agent should be able to respond without context window errors + response = await agent_loop.step( + input_messages=[input_message], + run_id=run.id, + max_steps=3, + ) + + # Verify we got a valid response after repair + assert response is not None + assert response.messages is not None + print(f"Agent successfully responded after repair with {len(response.messages)} messages") + + # @pytest.mark.asyncio # async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor): # """ @@ -1342,11 +1439,10 @@ async def test_summarize_all(server: SyncServer, actor, llm_config: LLMConfig): llm_config=llm_config, summarizer_config=summarizer_config, in_context_messages=messages, - new_messages=[], ) # Verify the summary was generated - assert len(new_in_context_messages) == 0 + assert len(new_in_context_messages) == 1 assert summary is not None assert len(summary) > 0 assert len(summary) <= 2000