feat: fix new summarizer code and add more tests (#6461)

Sarah Wooders
2025-11-30 00:49:38 -08:00
committed by Caren Thomas
parent 86023db9b1
commit 91e3dd8b3e
25 changed files with 728 additions and 358 deletions

View File

@@ -47,7 +47,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
# Convert response to chat completion format
self.chat_completions_response = self.llm_client.convert_response_to_chat_completion(self.response_data, messages, self.llm_config)
self.chat_completions_response = await self.llm_client.convert_response_to_chat_completion(
self.response_data, messages, self.llm_config
)
# Extract reasoning content from the response
if self.chat_completions_response.choices[0].message.reasoning_content:
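
This hunk, and the matching ones in the adapters and agents below, turns convert_response_to_chat_completion into a coroutine, so every call site has to await it. A minimal sketch of the new calling pattern, assuming an already-constructed llm_client together with response_data, messages, and llm_config (not the adapter's full flow):

async def convert(llm_client, response_data, messages, llm_config):
    # The conversion is now a coroutine and must be awaited
    response = await llm_client.convert_response_to_chat_completion(
        response_data, messages, llm_config
    )
    # Callers keep reading the usual chat-completion shape afterwards
    return response.choices[0].message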

View File

@@ -47,7 +47,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
# Convert response to chat completion format
self.chat_completions_response = self.llm_client.convert_response_to_chat_completion(self.response_data, messages, self.llm_config)
self.chat_completions_response = await self.llm_client.convert_response_to_chat_completion(
self.response_data, messages, self.llm_config
)
# Extract reasoning content from the response
if self.chat_completions_response.choices[0].message.reasoning_content:

View File

@@ -87,7 +87,7 @@ class EphemeralSummaryAgent(BaseAgent):
request_data = llm_client.build_request_data(agent_state.agent_type, messages, agent_state.llm_config, tools=[])
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
response = llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)
response = await llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)
summary = response.choices[0].message.content.strip()
await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=summary), actor=self.actor)

View File

@@ -337,7 +337,7 @@ class LettaAgent(BaseAgent):
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
try:
response = llm_client.convert_response_to_chat_completion(
response = await llm_client.convert_response_to_chat_completion(
response_data, in_context_messages, agent_state.llm_config
)
except ValueError as e:
@@ -681,7 +681,7 @@ class LettaAgent(BaseAgent):
log_event("agent.step.llm_response.received") # [3^]
try:
response = llm_client.convert_response_to_chat_completion(
response = await llm_client.convert_response_to_chat_completion(
response_data, in_context_messages, agent_state.llm_config
)
except ValueError as e:

View File

@@ -309,7 +309,7 @@ class LettaAgentBatch(BaseAgent):
agent_state_map = {agent.id: agent for agent in agent_states}
# Process each agent's results
tool_call_results = self._process_agent_results(
tool_call_results = await self._process_agent_results(
agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id
)
@@ -324,7 +324,7 @@ class LettaAgentBatch(BaseAgent):
request_status_updates=tool_call_results.status_updates,
)
def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id):
async def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id):
"""
Process the results for each agent, extracting tool calls and determining continuation status.
@@ -347,7 +347,7 @@ class LettaAgentBatch(BaseAgent):
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
# Process tool calls
name, args, cont = self._extract_tool_call_from_result(item, result)
name, args, cont = await self._extract_tool_call_from_result(item, result)
name_map[aid], args_map[aid], cont_map[aid] = name, args, cont
return ToolCallResults(name_map, args_map, cont_map, request_status_updates)
@@ -363,7 +363,7 @@ class LettaAgentBatch(BaseAgent):
else:
return JobStatus.expired
def _extract_tool_call_from_result(self, item, result):
async def _extract_tool_call_from_result(self, item, result):
"""Extract tool call information from a result"""
llm_client = LLMClient.create(
provider_type=item.llm_config.model_endpoint_type,
@@ -375,13 +375,10 @@ class LettaAgentBatch(BaseAgent):
if not isinstance(result, BetaMessageBatchSucceededResult):
return None, None, False
tool_call = (
llm_client.convert_response_to_chat_completion(
response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config
)
.choices[0]
.message.tool_calls[0]
response = await llm_client.convert_response_to_chat_completion(
response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config
)
tool_call = response.choices[0].message.tool_calls[0]
return self._extract_tool_call_and_decide_continue(tool_call, item.step_state)

View File

@@ -1280,6 +1280,9 @@ class LettaAgentV3(LettaAgentV2):
force: bool = False,
) -> list[Message]:
trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window)
self.logger.info(
f"trigger_summarization: {trigger_summarization}, total_tokens: {total_tokens}, context_window: {self.agent_state.llm_config.context_window}"
)
if not trigger_summarization:
# just update the message_ids
# TODO: gross to handle this here: we should move persistence elsewhere
@@ -1301,6 +1304,7 @@ class LettaAgentV3(LettaAgentV2):
if summarizer_config.mode == "all":
summary_message_str = await summarize_all(
actor=self.actor,
llm_config=self.agent_state.llm_config,
summarizer_config=summarizer_config,
in_context_messages=in_context_messages,
new_messages=new_letta_messages,
@@ -1353,6 +1357,6 @@ class LettaAgentV3(LettaAgentV2):
message_ids=new_in_context_message_ids,
actor=self.actor,
)
self.agent_state.message_ids = new_in_context_messages
self.agent_state.message_ids = new_in_context_message_ids
return new_in_context_messages

View File

@@ -54,12 +54,12 @@ from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
from letta.server.rest_api.json_parser import OptimisticJSONParser
from letta.server.rest_api.utils import decrement_message_uuid
from letta.services.context_window_calculator.token_counter import create_token_counter
from letta.streaming_utils import (
FunctionArgumentsStreamHandler,
JSONInnerThoughtsExtractor,
sanitize_streamed_message_content,
)
from letta.utils import count_tokens
logger = get_logger(__name__)
@@ -83,6 +83,10 @@ class OpenAIStreamingInterface:
step_id: str | None = None,
):
self.use_assistant_message = use_assistant_message
# Create token counter for fallback token counting (when API doesn't return usage)
# Use openai endpoint type for approximate counting in streaming context
self._fallback_token_counter = create_token_counter(model_endpoint_type="openai")
self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
@@ -301,7 +305,8 @@ class OpenAIStreamingInterface:
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
if self.is_openai_proxy:
self.fallback_output_tokens += count_tokens(tool_call.function.arguments)
# Use approximate counting for fallback (sync method)
self.fallback_output_tokens += self._fallback_token_counter._approx_token_count(tool_call.function.arguments)
# If we have inner thoughts, we should output them as a chunk
if updates_inner_thoughts:
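
When the proxy endpoint does not report usage, the streaming interface now approximates output tokens through the shared token counter rather than the removed count_tokens helper. A rough sketch of that accumulation over streamed tool-call argument fragments, assuming _approx_token_count behaves as shown later in this diff:

from letta.services.context_window_calculator.token_counter import create_token_counter

counter = create_token_counter(model_endpoint_type="openai")
fallback_output_tokens = 0
for fragment in ('{"message": "Hel', 'lo there"}'):  # illustrative streamed argument chunks
    # Each fragment contributes roughly ceil(utf-8 bytes / 4) tokens
    fallback_output_tokens += counter._approx_token_count(fragment)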

View File

@@ -764,7 +764,7 @@ class AnthropicClient(LLMClientBase):
# TODO: Input messages doesn't get used here
# TODO: Clean up this interface
@trace_method
def convert_response_to_chat_completion(
async def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[PydanticMessage],

View File

@@ -401,7 +401,7 @@ class DeepseekClient(OpenAIClient):
return response_stream
@trace_method
def convert_response_to_chat_completion(
async def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[PydanticMessage], # Included for consistency, maybe used later
@@ -413,5 +413,5 @@ class DeepseekClient(OpenAIClient):
"""
response = ChatCompletionResponse(**response_data)
if response.choices[0].message.tool_calls:
return super().convert_response_to_chat_completion(response_data, input_messages, llm_config)
return await super().convert_response_to_chat_completion(response_data, input_messages, llm_config)
return convert_deepseek_response_to_chatcompletion(response)

View File

@@ -31,7 +31,6 @@ from letta.helpers.datetime_helpers import get_utc_time_int
from letta.helpers.json_helpers import json_dumps, json_loads
from letta.llm_api.llm_client_base import LLMClientBase
from letta.local_llm.json_parser import clean_json_string_extra_backslash
from letta.local_llm.utils import count_tokens
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.agent import AgentType
@@ -404,7 +403,7 @@ class GoogleVertexClient(LLMClientBase):
return request_data
@trace_method
def convert_response_to_chat_completion(
async def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[PydanticMessage],
@@ -661,10 +660,13 @@ class GoogleVertexClient(LLMClientBase):
completion_tokens_details=completion_tokens_details,
)
else:
# Count it ourselves
# Count it ourselves using the Gemini token counting API
assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
google_messages = PydanticMessage.to_google_dicts_from_list(input_messages, current_model=llm_config.model)
prompt_tokens = await self.count_tokens(messages=google_messages, model=llm_config.model)
# For completion tokens, wrap the response content in Google format
completion_content = [{"role": "model", "parts": [{"text": json_dumps(openai_response_message.model_dump())}]}]
completion_tokens = await self.count_tokens(messages=completion_content, model=llm_config.model)
total_tokens = prompt_tokens + completion_tokens
usage = UsageStatistics(
prompt_tokens=prompt_tokens,

View File

@@ -19,7 +19,7 @@ from letta.schemas.response_format import (
TextResponseFormat,
)
from letta.settings import summarizer_settings
from letta.utils import count_tokens, printd
from letta.utils import printd
logger = get_logger(__name__)
@@ -394,97 +394,3 @@ def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -
logger.warning(f"Did not find tool call in message: {str(message)}")
return rewritten_choice
def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts: List[int], logger: "logging.Logger") -> int:
if len(in_context_messages) != len(token_counts):
raise ValueError(
f"Given in_context_messages has different length from given token_counts: {len(in_context_messages)} != {len(token_counts)}"
)
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
if summarizer_settings.evict_all_messages:
logger.info("Evicting all messages...")
return len(in_context_messages)
else:
# Start at index 1 (past the system message),
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
# We do the inverse of `desired_memory_token_pressure` to get what we need to remove
desired_token_count_to_summarize = int(sum(token_counts) * (1 - summarizer_settings.desired_memory_token_pressure))
logger.info(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
tokens_so_far = 0
cutoff = 0
for i, msg in enumerate(in_context_messages_openai):
# Skip system
if i == 0:
continue
cutoff = i
tokens_so_far += token_counts[i]
if msg["role"] not in ["user", "tool", "function"] and tokens_so_far >= desired_token_count_to_summarize:
# Break if the role is NOT a user or tool/function and tokens_so_far is enough
break
elif len(in_context_messages) - cutoff - 1 <= summarizer_settings.keep_last_n_messages:
# Also break if we reached the `keep_last_n_messages` threshold
# NOTE: This may be on a user, tool, or function in theory
logger.warning(
f"Breaking summary cutoff early on role={msg['role']} because we hit the `keep_last_n_messages`={summarizer_settings.keep_last_n_messages}"
)
break
# includes the tool response to be summarized after a tool call so we don't have any hanging tool calls after trimming.
if i + 1 < len(in_context_messages_openai) and in_context_messages_openai[i + 1]["role"] == "tool":
cutoff += 1
logger.info(f"Evicting {cutoff}/{len(in_context_messages)} messages...")
return cutoff + 1
def get_token_counts_for_messages(in_context_messages: List[Message]) -> List[int]:
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
return token_counts
def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool:
"""Checks if an exception is due to context overflow (based on common OpenAI response messages)"""
from letta.utils import printd
match_string = OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
# Backwards compatibility with openai python package/client v0.28 (pre-v1 client migration)
if match_string in str(exception):
printd(f"Found '{match_string}' in str(exception)={(str(exception))}")
return True
# Based on python requests + OpenAI REST API (/v1)
elif isinstance(exception, requests.exceptions.HTTPError):
if exception.response is not None and "application/json" in exception.response.headers.get("Content-Type", ""):
try:
error_details = exception.response.json()
if "error" not in error_details:
printd(f"HTTPError occurred, but couldn't find error field: {error_details}")
return False
else:
error_details = error_details["error"]
# Check for the specific error code
if error_details.get("code") == "context_length_exceeded":
printd(f"HTTPError occurred, caught error code {error_details.get('code')}")
return True
# Soft-check for "maximum context length" inside of the message
elif error_details.get("message") and "maximum context length" in error_details.get("message"):
printd(f"HTTPError occurred, found '{match_string}' in error message contents ({error_details})")
return True
else:
printd(f"HTTPError occurred, but unknown error message: {error_details}")
return False
except ValueError:
# JSON decoding failed
printd(f"HTTPError occurred ({exception}), but no JSON error message.")
# Generic fail
else:
return False

View File

@@ -38,7 +38,7 @@ class LLMClientBase:
self.use_tool_naming = use_tool_naming
@trace_method
def send_llm_request(
async def send_llm_request(
self,
agent_type: AgentType,
messages: List[Message],
@@ -80,7 +80,7 @@ class LLMClientBase:
except Exception as e:
raise self.handle_llm_error(e)
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
@trace_method
async def send_llm_request_async(
@@ -114,7 +114,7 @@ class LLMClientBase:
except Exception as e:
raise self.handle_llm_error(e)
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
async def send_llm_batch_request_async(
self,
@@ -177,7 +177,7 @@ class LLMClientBase:
raise NotImplementedError
@abstractmethod
def convert_response_to_chat_completion(
async def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[Message],
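
Because the abstract method is now declared async, every provider client must implement it as a coroutine even when the body does no awaiting. A bare-bones sketch of a conforming subclass (class name is hypothetical, other abstract members are omitted, and the ChatCompletionResponse import path is assumed from its other uses in this diff):

from letta.llm_api.llm_client_base import LLMClientBase
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse

class ExampleProviderClient(LLMClientBase):
    async def convert_response_to_chat_completion(self, response_data, input_messages, llm_config):
        # The body can stay purely synchronous; only the coroutine signature
        # matters so base-class callers can uniformly await the conversion.
        return ChatCompletionResponse(**response_data)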

View File

@@ -609,7 +609,7 @@ class OpenAIClient(LLMClientBase):
return is_openai_reasoning_model(llm_config.model)
@trace_method
def convert_response_to_chat_completion(
async def convert_response_to_chat_completion(
self,
response_data: dict,
input_messages: List[PydanticMessage], # Included for consistency, maybe used later

View File

@@ -58,11 +58,11 @@ def load_grammar_file(grammar):
return grammar_str
# TODO: support tokenizers/tokenizer apis available in local models
def count_tokens(s: str, model: str = "gpt-4") -> int:
from letta.utils import count_tokens
return count_tokens(s, model)
## TODO: support tokenizers/tokenizer apis available in local models
# def count_tokens(s: str, model: str = "gpt-4") -> int:
# from letta.utils import count_tokens
#
# return count_tokens(s, model)
def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):

View File

@@ -50,59 +50,60 @@ def _format_summary_history(message_history: List[Message]):
return "\n".join([f"{m.role}: {get_message_text(m.content)}" for m in message_history])
@trace_method
def summarize_messages(
agent_state: AgentState,
message_sequence_to_summarize: List[Message],
actor: "User",
):
"""Summarize a message sequence using GPT"""
# we need the context_window
context_window = agent_state.llm_config.context_window
summary_prompt = SUMMARY_PROMPT_SYSTEM
summary_input = _format_summary_history(message_sequence_to_summarize)
summary_input_tkns = count_tokens(summary_input)
if summary_input_tkns > summarizer_settings.memory_warning_threshold * context_window:
trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure...
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
summary_input = str(
[summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)]
+ message_sequence_to_summarize[cutoff:]
)
dummy_agent_id = agent_state.id
message_sequence = [
Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]),
Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]),
Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]),
]
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
llm_client = LLMClient.create(
provider_type=agent_state.llm_config.model_endpoint_type,
put_inner_thoughts_first=False,
actor=actor,
)
# try to use new client, otherwise fallback to old flow
# TODO: we can just directly call the LLM here?
if llm_client:
response = llm_client.send_llm_request(
agent_type=agent_state.agent_type,
messages=message_sequence,
llm_config=llm_config_no_inner_thoughts,
)
else:
response = create(
llm_config=llm_config_no_inner_thoughts,
user_id=agent_state.created_by_id,
messages=message_sequence,
stream=False,
)
printd(f"summarize_messages gpt reply: {response.choices[0]}")
reply = response.choices[0].message.content
return reply
# @trace_method
# def summarize_messages(
# agent_state: AgentState,
# message_sequence_to_summarize: List[Message],
# actor: "User",
# ):
# """Summarize a message sequence using GPT"""
# # we need the context_window
# context_window = agent_state.llm_config.context_window
#
# summary_prompt = SUMMARY_PROMPT_SYSTEM
# summary_input = _format_summary_history(message_sequence_to_summarize)
# summary_input_tkns = count_tokens(summary_input)
# if summary_input_tkns > summarizer_settings.memory_warning_threshold * context_window:
# trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure...
# cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
# summary_input = str(
# [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)]
# + message_sequence_to_summarize[cutoff:]
# )
#
# dummy_agent_id = agent_state.id
# message_sequence = [
# Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]),
# Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]),
# Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]),
# ]
#
# # TODO: We need to eventually have a separate LLM config for the summarizer LLM
# llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
# llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
#
# llm_client = LLMClient.create(
# provider_type=agent_state.llm_config.model_endpoint_type,
# put_inner_thoughts_first=False,
# actor=actor,
# )
# # try to use new client, otherwise fallback to old flow
# # TODO: we can just directly call the LLM here?
# if llm_client:
# response = llm_client.send_llm_request(
# agent_type=agent_state.agent_type,
# messages=message_sequence,
# llm_config=llm_config_no_inner_thoughts,
# )
# else:
# response = create(
# llm_config=llm_config_no_inner_thoughts,
# user_id=agent_state.created_by_id,
# messages=message_sequence,
# stream=False,
# )
#
# printd(f"summarize_messages gpt reply: {response.choices[0]}")
# reply = response.choices[0].message.content
# return reply
#

View File

@@ -914,7 +914,7 @@ async def generate_tool_from_prompt(
tools=[tool],
)
response_data = await llm_client.request_async(request_data, llm_config)
response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
response = await llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()]

View File

@@ -670,20 +670,26 @@ class SyncServer(object):
async def insert_archival_memory_async(
self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
) -> List[Passage]:
from letta.services.context_window_calculator.token_counter import create_token_counter
from letta.settings import settings
from letta.utils import count_tokens
# Get the agent object (loaded in memory)
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
# Check token count against limit
token_count = count_tokens(memory_contents)
token_counter = create_token_counter(
model_endpoint_type=agent_state.llm_config.model_endpoint_type,
model=agent_state.llm_config.model,
actor=actor,
agent_id=agent_id,
)
token_count = await token_counter.count_text_tokens(memory_contents)
if token_count > settings.archival_memory_token_limit:
raise LettaInvalidArgumentError(
message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)",
argument_name="memory_contents",
)
# Get the agent object (loaded in memory)
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
# Use passage manager which handles dual-write to Turbopuffer if enabled
passages = await self.passage_manager.insert_passage(
agent_state=agent_state, text=memory_contents, tags=tags, actor=actor, created_at=created_at
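
The archival-memory gate is now asynchronous because the provider-aware counters may call a remote counting API. A condensed sketch of the check, with the limit hard-coded for illustration (the real value comes from settings.archival_memory_token_limit):

from letta.services.context_window_calculator.token_counter import create_token_counter

async def within_archival_limit(memory_contents: str, llm_config, actor, limit: int = 10_000) -> bool:
    token_counter = create_token_counter(
        model_endpoint_type=llm_config.model_endpoint_type,
        model=llm_config.model,
        actor=actor,
    )
    token_count = await token_counter.count_text_tokens(memory_contents)
    return token_count <= limit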

View File

@@ -80,7 +80,7 @@ from letta.server.db import db_registry
from letta.services.archive_manager import ArchiveManager
from letta.services.block_manager import BlockManager, validate_block_limit_constraint
from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, GeminiTokenCounter, TiktokenCounter
from letta.services.context_window_calculator.token_counter import create_token_counter
from letta.services.file_processor.chunker.line_chunker import LineChunker
from letta.services.files_agents_manager import FileAgentManager
from letta.services.helpers.agent_manager_helper import (
@@ -3286,49 +3286,14 @@ class AgentManager:
)
calculator = ContextWindowCalculator()
# Determine which token counter to use based on provider
model_endpoint_type = agent_state.llm_config.model_endpoint_type
# Use Gemini token counter for Google Vertex and Google AI
use_gemini = model_endpoint_type in ("google_vertex", "google_ai")
# Use Anthropic token counter if:
# 1. The model endpoint type is anthropic, OR
# 2. We're in PRODUCTION and anthropic_api_key is available (and not using Gemini)
use_anthropic = model_endpoint_type == "anthropic" or (
not use_gemini and settings.environment == "PRODUCTION" and model_settings.anthropic_api_key is not None
# Create the appropriate token counter based on model configuration
token_counter = create_token_counter(
model_endpoint_type=agent_state.llm_config.model_endpoint_type,
model=agent_state.llm_config.model,
actor=actor,
agent_id=agent_id,
)
if use_gemini:
# Use native Gemini token counting API
client = LLMClient.create(provider_type=agent_state.llm_config.model_endpoint_type, actor=actor)
model = agent_state.llm_config.model
token_counter = GeminiTokenCounter(client, model)
logger.info(
f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, "
f"model_endpoint_type={model_endpoint_type}, "
f"environment={settings.environment}"
)
elif use_anthropic:
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
model = agent_state.llm_config.model if model_endpoint_type == "anthropic" else None
token_counter = AnthropicTokenCounter(anthropic_client, model) # noqa
logger.info(
f"Using AnthropicTokenCounter for agent_id={agent_id}, model={model}, "
f"model_endpoint_type={model_endpoint_type}, "
f"environment={settings.environment}"
)
else:
token_counter = TiktokenCounter(agent_state.llm_config.model)
logger.info(
f"Using TiktokenCounter for agent_id={agent_id}, model={agent_state.llm_config.model}, "
f"model_endpoint_type={model_endpoint_type}, "
f"environment={settings.environment}"
)
try:
result = await calculator.calculate_context_window(
agent_state=agent_state,

View File

@@ -1,15 +1,22 @@
import hashlib
import json
from abc import ABC, abstractmethod
from typing import Any, Dict, List
from typing import TYPE_CHECKING, Any, Dict, List, Optional
from letta.helpers.decorators import async_redis_cache
from letta.llm_api.anthropic_client import AnthropicClient
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.log import get_logger
from letta.otel.tracing import trace_method
from letta.schemas.enums import ProviderType
from letta.schemas.message import Message
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
from letta.utils import count_tokens
if TYPE_CHECKING:
from letta.schemas.llm_config import LLMConfig
from letta.schemas.user import User
logger = get_logger(__name__)
class TokenCounter(ABC):
@@ -101,6 +108,22 @@ class ApproxTokenCounter(TokenCounter):
byte_len = len(text.encode("utf-8"))
return (byte_len + self.APPROX_BYTES_PER_TOKEN - 1) // self.APPROX_BYTES_PER_TOKEN
async def count_text_tokens(self, text: str) -> int:
if not text:
return 0
return self._approx_token_count(text)
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
if not messages:
return 0
return self._approx_token_count(json.dumps(messages))
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
if not tools:
return 0
functions = [t.model_dump() for t in tools]
return self._approx_token_count(json.dumps(functions))
def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
return Message.to_openai_dicts_from_list(messages)
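
The approximate counter sizes text purely from its UTF-8 byte length, rounding up so any non-empty text costs at least one token. A standalone sketch of that arithmetic, with APPROX_BYTES_PER_TOKEN = 4 assumed from the "bytes / 4" comment elsewhere in this diff:

APPROX_BYTES_PER_TOKEN = 4  # assumed value, per the "bytes / 4" comment in the summarizer

def approx_token_count(text: str) -> int:
    byte_len = len(text.encode("utf-8"))
    # Ceiling division: (n + d - 1) // d
    return (byte_len + APPROX_BYTES_PER_TOKEN - 1) // APPROX_BYTES_PER_TOKEN

assert approx_token_count("") == 0
assert approx_token_count("hi") == 1           # 2 bytes  -> 1 token
assert approx_token_count("hello world") == 3  # 11 bytes -> 3 tokens
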
@@ -178,7 +201,14 @@ class TiktokenCounter(TokenCounter):
logger.debug(f"TiktokenCounter.count_text_tokens: model={self.model}, text_length={text_length}, preview={repr(text_preview)}")
try:
result = count_tokens(text)
import tiktoken
try:
encoding = tiktoken.encoding_for_model(self.model)
except KeyError:
logger.debug(f"Model {self.model} not found in tiktoken. Using cl100k_base encoding.")
encoding = tiktoken.get_encoding("cl100k_base")
result = len(encoding.encode(text))
logger.debug(f"TiktokenCounter.count_text_tokens: completed successfully, tokens={result}")
return result
except Exception as e:
@@ -234,3 +264,54 @@ class TiktokenCounter(TokenCounter):
def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
return Message.to_openai_dicts_from_list(messages)
def create_token_counter(
model_endpoint_type: ProviderType,
model: Optional[str] = None,
actor: "User" = None,
agent_id: Optional[str] = None,
) -> "TokenCounter":
"""
Factory function to create the appropriate token counter based on model configuration.
Returns:
The appropriate TokenCounter instance
"""
from letta.llm_api.llm_client import LLMClient
from letta.settings import model_settings, settings
# Use Gemini token counter for Google Vertex and Google AI
use_gemini = model_endpoint_type in ("google_vertex", "google_ai")
# Use Anthropic token counter only when the model endpoint type is anthropic
use_anthropic = model_endpoint_type == "anthropic"
if use_gemini:
client = LLMClient.create(provider_type=model_endpoint_type, actor=actor)
token_counter = GeminiTokenCounter(client, model)
logger.info(
f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, "
f"model_endpoint_type={model_endpoint_type}, "
f"environment={settings.environment}"
)
elif use_anthropic:
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
counter_model = model if model_endpoint_type == "anthropic" else None
token_counter = AnthropicTokenCounter(anthropic_client, counter_model)
logger.info(
f"Using AnthropicTokenCounter for agent_id={agent_id}, model={counter_model}, "
f"model_endpoint_type={model_endpoint_type}, "
f"environment={settings.environment}"
)
else:
token_counter = ApproxTokenCounter()
logger.info(
f"Using ApproxTokenCounter for agent_id={agent_id}, model={model}, "
f"model_endpoint_type={model_endpoint_type}, "
f"environment={settings.environment}"
)
return token_counter
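
A short usage sketch of the factory: Gemini and Anthropic endpoints get their native counting APIs, everything else falls back to byte-based approximation. The endpoint-type strings follow this diff; the model names are placeholders:

from letta.services.context_window_calculator.token_counter import create_token_counter

async def pick_counters(actor):
    gemini_counter = create_token_counter(model_endpoint_type="google_vertex", model="gemini-example", actor=actor)
    anthropic_counter = create_token_counter(model_endpoint_type="anthropic", model="claude-example", actor=actor)
    approx_counter = create_token_counter(model_endpoint_type="openai")  # falls back to ApproxTokenCounter
    # All counters share the same async interface
    return await approx_counter.count_text_tokens("some archival text")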

View File

@@ -447,9 +447,12 @@ async def simple_summary(
]
input_messages_obj = [simple_message_wrapper(msg) for msg in input_messages]
# Build a local LLMConfig for v1-style summarization which uses native content and must not
# include inner thoughts in kwargs to avoid conflicts in Anthropic formatting
# include inner thoughts in kwargs to avoid conflicts in Anthropic formatting.
# We also disable enable_reasoner to avoid extended thinking requirements (Anthropic requires
# assistant messages to start with thinking blocks when extended thinking is enabled).
summarizer_llm_config = LLMConfig(**llm_config.model_dump())
summarizer_llm_config.put_inner_thoughts_in_kwargs = False
summarizer_llm_config.enable_reasoner = False
request_data = llm_client.build_request_data(AgentType.letta_v1_agent, input_messages_obj, summarizer_llm_config, tools=[])
try:
@@ -532,7 +535,7 @@ async def simple_summary(
logger.info(f"Full fallback summarization payload: {request_data}")
raise llm_client.handle_llm_error(fallback_error_b)
response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config)
response = await llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config)
if response.choices[0].message.content is None:
logger.warning("No content returned from summarizer")
# TODO raise an error instead?
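
The summarizer copies the agent's LLMConfig and switches off both inner-thoughts-in-kwargs and the extended reasoner, because Anthropic's extended thinking expects assistant turns to begin with thinking blocks that a plain summary request will not contain. A minimal sketch of that config preparation, using only fields shown in this hunk:

from letta.schemas.llm_config import LLMConfig

def build_summarizer_llm_config(llm_config: LLMConfig) -> LLMConfig:
    summarizer_llm_config = LLMConfig(**llm_config.model_dump())
    summarizer_llm_config.put_inner_thoughts_in_kwargs = False
    summarizer_llm_config.enable_reasoner = False
    return summarizer_llm_config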

View File

@@ -1,16 +1,11 @@
from typing import List, Tuple
from typing import List
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.message import Message, MessageCreate
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.user import User
from letta.services.message_manager import MessageManager
from letta.services.summarizer.summarizer import simple_summary
from letta.services.summarizer.summarizer_config import SummarizerConfig
from letta.system import package_summarize_message_no_counts
logger = get_logger(__name__)
@@ -18,6 +13,8 @@ logger = get_logger(__name__)
async def summarize_all(
# Required to tag LLM calls
actor: User,
# LLM config for the summarizer model
llm_config: LLMConfig,
# Actual summarization configuration
summarizer_config: SummarizerConfig,
in_context_messages: List[Message],
@@ -33,9 +30,9 @@ async def summarize_all(
summary_message_str = await simple_summary(
messages=all_in_context_messages,
llm_config=summarizer_config.summarizer_model,
llm_config=llm_config,
actor=actor,
include_ack=summarizer_config.prompt_acknowledgement,
include_ack=bool(summarizer_config.prompt_acknowledgement),
prompt=summarizer_config.prompt,
)

View File

@@ -1,19 +1,17 @@
from typing import List, Tuple
from letta.helpers.message_helper import convert_message_creates_to_messages
from letta.llm_api.llm_client import LLMClient
from letta.log import get_logger
from letta.schemas.agent import AgentState
from letta.schemas.enums import MessageRole, ProviderType
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message, MessageCreate
from letta.schemas.user import User
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, ApproxTokenCounter
from letta.services.context_window_calculator.token_counter import create_token_counter
from letta.services.message_manager import MessageManager
from letta.services.summarizer.summarizer import simple_summary
from letta.services.summarizer.summarizer_config import SummarizerConfig
from letta.settings import model_settings, settings
from letta.system import package_summarize_message_no_counts
logger = get_logger(__name__)
@@ -26,21 +24,21 @@ APPROX_TOKEN_SAFETY_MARGIN = 1.3
async def count_tokens(actor: User, llm_config: LLMConfig, messages: List[Message]) -> int:
# If the model is an Anthropic model, use the Anthropic token counter (accurate)
if llm_config.model_endpoint_type == "anthropic":
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
token_counter = AnthropicTokenCounter(anthropic_client, llm_config.model)
converted_messages = token_counter.convert_messages(messages)
return await token_counter.count_message_tokens(converted_messages)
"""Count tokens in messages using the appropriate token counter for the model configuration."""
token_counter = create_token_counter(
model_endpoint_type=llm_config.model_endpoint_type,
model=llm_config.model,
actor=actor,
)
converted_messages = token_counter.convert_messages(messages)
tokens = await token_counter.count_message_tokens(converted_messages)
else:
# Otherwise, use approximate count (bytes / 4) with safety margin
# This is much faster than tiktoken and doesn't require loading tokenizer models
token_counter = ApproxTokenCounter(llm_config.model)
converted_messages = token_counter.convert_messages(messages)
tokens = await token_counter.count_message_tokens(converted_messages)
# Apply safety margin to avoid underestimating and keeping too many messages
# Apply safety margin for approximate counting to avoid underestimating
from letta.services.context_window_calculator.token_counter import ApproxTokenCounter
if isinstance(token_counter, ApproxTokenCounter):
return int(tokens * APPROX_TOKEN_SAFETY_MARGIN)
return tokens
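
Only the approximate counter gets the 1.3x safety margin; provider-backed counts are exact, and padding them would just evict more history than needed. A small worked example of the margin (the constant comes from APPROX_TOKEN_SAFETY_MARGIN at the top of this file):

approx_tokens = 1000               # estimate from ApproxTokenCounter
padded = int(approx_tokens * 1.3)  # APPROX_TOKEN_SAFETY_MARGIN
assert padded == 1300              # the summarizer plans evictions against the padded figure
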
async def summarize_via_sliding_window(
@@ -110,9 +108,9 @@ async def summarize_via_sliding_window(
summary_message_str = await simple_summary(
messages=messages_to_summarize,
llm_config=summarizer_config.summarizer_model,
llm_config=llm_config,
actor=actor,
include_ack=summarizer_config.prompt_acknowledgement,
include_ack=bool(summarizer_config.prompt_acknowledgement),
prompt=summarizer_config.prompt,
)

View File

@@ -812,13 +812,13 @@ class OpenAIBackcompatUnpickler(pickle.Unpickler):
return super().find_class(module, name)
def count_tokens(s: str, model: str = "gpt-4") -> int:
try:
encoding = tiktoken.encoding_for_model(model)
except KeyError:
print("Falling back to cl100k base for token counting.")
encoding = tiktoken.get_encoding("cl100k_base")
return len(encoding.encode(s))
# def count_tokens(s: str, model: str = "gpt-4") -> int:
# try:
# encoding = tiktoken.encoding_for_model(model)
# except KeyError:
# print("Falling back to cl100k base for token counting.")
# encoding = tiktoken.get_encoding("cl100k_base")
# return len(encoding.encode(s))
def printd(*args, **kwargs):

View File

@@ -14,6 +14,7 @@ from typing import List
import pytest
from letta.agents.letta_agent_v2 import LettaAgentV2
from letta.agents.letta_agent_v3 import LettaAgentV3
from letta.config import LettaConfig
from letta.schemas.agent import CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
@@ -671,7 +672,12 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
@pytest.mark.asyncio
async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
@pytest.mark.parametrize(
"llm_config",
TESTED_LLM_CONFIGS,
ids=[c.model for c in TESTED_LLM_CONFIGS],
)
async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server: SyncServer, actor, llm_config: LLMConfig):
"""
Test that the sliding window summarizer correctly calculates cutoff indices.
@@ -685,35 +691,19 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
- max(..., 10) -> max(..., 0.10)
- += 10 -> += 0.10
- >= 100 -> >= 1.0
"""
from unittest.mock import MagicMock, patch
from letta.schemas.enums import MessageRole
from letta.schemas.letta_message_content import TextContent
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.schemas.user import User
This test uses the real token counter (via create_token_counter) to verify
the sliding window logic works with actual token counting.
"""
from letta.schemas.model import ModelSettings
from letta.services.summarizer.summarizer_config import get_default_summarizer_config
from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
# Create a mock user (using proper ID format pattern)
mock_actor = User(
id="user-00000000-0000-0000-0000-000000000000", name="Test User", organization_id="org-00000000-0000-0000-0000-000000000000"
)
# Create a mock LLM config
mock_llm_config = LLMConfig(
model="gpt-4",
model_endpoint_type="openai",
context_window=128000,
)
# Create a mock summarizer config with sliding_window_percentage = 0.3
mock_summarizer_config = MagicMock()
mock_summarizer_config.sliding_window_percentage = 0.3
mock_summarizer_config.summarizer_model = mock_llm_config
mock_summarizer_config.prompt = "Summarize the conversation."
mock_summarizer_config.prompt_acknowledgement = True
mock_summarizer_config.clip_chars = 2000
# Create a real summarizer config using the default factory
# Override sliding_window_percentage to 0.3 for this test
model_settings = ModelSettings() # Use defaults
summarizer_config = get_default_summarizer_config(model_settings)
summarizer_config.sliding_window_percentage = 0.3
# Create 65 messages (similar to the failing case in the bug report)
# Pattern: system + alternating user/assistant messages
@@ -741,59 +731,470 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}"
# Mock count_tokens to return a value that would trigger summarization
# Return a high token count so that the while loop continues
async def mock_count_tokens(actor, llm_config, messages):
# Return tokens that decrease as we cut off more messages
# This simulates the token count decreasing as we evict messages
return len(messages) * 100 # 100 tokens per message
# This should NOT raise "No assistant message found from indices 650 to 65"
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
try:
summary, remaining_messages = await summarize_via_sliding_window(
actor=actor,
llm_config=llm_config,
summarizer_config=summarizer_config,
in_context_messages=messages,
new_messages=[],
)
# Mock simple_summary to return a fake summary
async def mock_simple_summary(messages, llm_config, actor, include_ack, prompt):
return "This is a mock summary of the conversation."
# Verify the summary was generated (actual LLM response)
assert summary is not None
assert len(summary) > 0
with (
patch(
"letta.services.summarizer.summarizer_sliding_window.count_tokens",
side_effect=mock_count_tokens,
),
patch(
"letta.services.summarizer.summarizer_sliding_window.simple_summary",
side_effect=mock_simple_summary,
),
):
# This should NOT raise "No assistant message found from indices 650 to 65"
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
try:
summary, remaining_messages = await summarize_via_sliding_window(
actor=mock_actor,
llm_config=mock_llm_config,
summarizer_config=mock_summarizer_config,
in_context_messages=messages,
new_messages=[],
)
# Verify remaining messages is a valid subset
assert len(remaining_messages) < len(messages)
assert len(remaining_messages) > 0
# Verify the summary was generated
assert summary == "This is a mock summary of the conversation."
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
print(f"Using {llm_config.model_endpoint_type} token counter for model {llm_config.model}")
# Verify remaining messages is a valid subset
assert len(remaining_messages) < len(messages)
assert len(remaining_messages) > 0
except ValueError as e:
if "No assistant message found from indices" in str(e):
# Extract the indices from the error message
import re
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
match = re.search(r"from indices (\d+) to (\d+)", str(e))
if match:
start_idx, end_idx = int(match.group(1)), int(match.group(2))
pytest.fail(
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
f"Error: {e}"
)
raise
except ValueError as e:
if "No assistant message found from indices" in str(e):
# Extract the indices from the error message
import re
match = re.search(r"from indices (\d+) to (\d+)", str(e))
if match:
start_idx, end_idx = int(match.group(1)), int(match.group(2))
pytest.fail(
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
f"Error: {e}"
)
raise
# @pytest.mark.asyncio
# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor):
# """
# Test that a ContextWindowExceededError during a streaming LLM request
# properly triggers the summarizer and compacts the in-context messages.
#
# This test simulates:
# 1. An LLM streaming request that fails with ContextWindowExceededError
# 2. The summarizer being invoked to reduce context size
# 3. Verification that messages are compacted and summary message exists
#
# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
# """
# import uuid
# from unittest.mock import patch
#
# import openai
#
# from letta.schemas.message import MessageCreate
# from letta.schemas.run import Run
# from letta.services.run_manager import RunManager
#
# # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
# llm_config = get_llm_config("openai-gpt-4o-mini.json")
#
# # Create test messages - enough to have something to summarize
# messages = []
# for i in range(15):
# messages.append(
# PydanticMessage(
# role=MessageRole.user,
# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
# )
# )
# messages.append(
# PydanticMessage(
# role=MessageRole.assistant,
# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
# )
# )
#
# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
# original_message_count = len(agent_state.message_ids)
#
# # Create an input message to trigger the agent
# input_message = MessageCreate(
# role=MessageRole.user,
# content=[TextContent(type="text", text="Hello, please respond.")],
# )
#
# # Create a proper run record in the database
# run_manager = RunManager()
# test_run_id = f"run-{uuid.uuid4()}"
# test_run = Run(
# id=test_run_id,
# agent_id=agent_state.id,
# )
# await run_manager.create_run(test_run, actor)
#
# # Create the agent loop using LettaAgentV3
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
#
# # Track how many times stream_async is called
# call_count = 0
#
# # Store original stream_async method
# original_stream_async = agent_loop.llm_client.stream_async
#
# async def mock_stream_async_with_error(request_data, llm_config):
# nonlocal call_count
# call_count += 1
# if call_count == 1:
# # First call raises OpenAI BadRequestError with context_length_exceeded error code
# # This will be properly converted to ContextWindowExceededError by handle_llm_error
# from unittest.mock import MagicMock
#
# import httpx
#
# # Create a mock response with the required structure
# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
# mock_response = httpx.Response(
# status_code=400,
# request=mock_request,
# json={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
#
# raise openai.BadRequestError(
# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# response=mock_response,
# body={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
# # Subsequent calls use the real implementation
# return await original_stream_async(request_data, llm_config)
#
# # Patch the llm_client's stream_async to raise ContextWindowExceededError on first call
# with patch.object(agent_loop.llm_client, "stream_async", side_effect=mock_stream_async_with_error):
# # Execute a streaming step
# try:
# result_chunks = []
# async for chunk in agent_loop.stream(
# input_messages=[input_message],
# max_steps=1,
# stream_tokens=True,
# run_id=test_run_id,
# ):
# result_chunks.append(chunk)
# except Exception as e:
# # Some errors might happen due to real LLM calls after retry
# print(f"Exception during stream: {e}")
#
# # Reload agent state to get updated message_ids after summarization
# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
# updated_message_count = len(updated_agent_state.message_ids)
#
# # Fetch the updated in-context messages
# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
# message_ids=updated_agent_state.message_ids, actor=actor
# )
#
# # Convert to LettaMessage format for easier content inspection
# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
#
# # Verify a summary message exists with the correct format
# # The summary message has content with type="system_alert" and message containing:
# # "prior messages ... have been hidden" and "summary of the previous"
# import json
#
# summary_message_found = False
# summary_message_text = None
# for msg in letta_messages:
# # Not all message types have a content attribute (e.g., ReasoningMessage)
# if not hasattr(msg, "content"):
# continue
#
# content = msg.content
# # Content can be a string (JSON) or an object with type/message fields
# if isinstance(content, str):
# # Try to parse as JSON
# try:
# parsed = json.loads(content)
# if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
# text_to_check = parsed.get("message", "").lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = parsed.get("message")
# break
# except (json.JSONDecodeError, TypeError):
# pass
# # Check if content has system_alert type with the summary message (object form)
# elif hasattr(content, "type") and content.type == "system_alert":
# if hasattr(content, "message") and content.message:
# text_to_check = content.message.lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = content.message
# break
#
# assert summary_message_found, (
# "A summary message should exist in the in-context messages after summarization. "
# "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
# )
#
# # Verify we attempted multiple invocations (the failing one + retry after summarization)
# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
#
# # The original messages should have been compacted - the updated count should be less than
# # original + the new messages added (input + assistant response + tool results)
# # Since summarization should have removed most of the original 30 messages
# print("Test passed: Summary message found in context")
# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
# print(f"Total LLM invocations: {call_count}")
#
#
# @pytest.mark.asyncio
# async def test_context_window_overflow_triggers_summarization_in_blocking(server: SyncServer, actor):
# """
# Test that a ContextWindowExceededError during a blocking (non-streaming) LLM request
# properly triggers the summarizer and compacts the in-context messages.
#
# This test is similar to the streaming test but uses the blocking step() method.
#
# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
# """
# import uuid
# from unittest.mock import patch
#
# import openai
#
# from letta.schemas.message import MessageCreate
# from letta.schemas.run import Run
# from letta.services.run_manager import RunManager
#
# # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
# llm_config = get_llm_config("openai-gpt-4o-mini.json")
#
# # Create test messages
# messages = []
# for i in range(15):
# messages.append(
# PydanticMessage(
# role=MessageRole.user,
# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
# )
# )
# messages.append(
# PydanticMessage(
# role=MessageRole.assistant,
# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
# )
# )
#
# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
# original_message_count = len(agent_state.message_ids)
#
# # Create an input message to trigger the agent
# input_message = MessageCreate(
# role=MessageRole.user,
# content=[TextContent(type="text", text="Hello, please respond.")],
# )
#
# # Create a proper run record in the database
# run_manager = RunManager()
# test_run_id = f"run-{uuid.uuid4()}"
# test_run = Run(
# id=test_run_id,
# agent_id=agent_state.id,
# )
# await run_manager.create_run(test_run, actor)
#
# # Create the agent loop using LettaAgentV3
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
#
# # Track how many times request_async is called
# call_count = 0
#
# # Store original request_async method
# original_request_async = agent_loop.llm_client.request_async
#
# async def mock_request_async_with_error(request_data, llm_config):
# nonlocal call_count
# call_count += 1
# if call_count == 1:
# # First call raises OpenAI BadRequestError with context_length_exceeded error code
# # This will be properly converted to ContextWindowExceededError by handle_llm_error
# import httpx
#
# # Create a mock response with the required structure
# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
# mock_response = httpx.Response(
# status_code=400,
# request=mock_request,
# json={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
#
# raise openai.BadRequestError(
# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# response=mock_response,
# body={
# "error": {
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
# "type": "invalid_request_error",
# "code": "context_length_exceeded",
# }
# },
# )
# # Subsequent calls use the real implementation
# return await original_request_async(request_data, llm_config)
#
# # Patch the llm_client's request_async to raise ContextWindowExceededError on first call
# with patch.object(agent_loop.llm_client, "request_async", side_effect=mock_request_async_with_error):
# # Execute a blocking step
# try:
# result = await agent_loop.step(
# input_messages=[input_message],
# max_steps=1,
# run_id=test_run_id,
# )
# except Exception as e:
# # Some errors might happen due to real LLM calls after retry
# print(f"Exception during step: {e}")
#
# # Reload agent state to get updated message_ids after summarization
# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
# updated_message_count = len(updated_agent_state.message_ids)
#
# # Fetch the updated in-context messages
# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
# message_ids=updated_agent_state.message_ids, actor=actor
# )
#
# # Convert to LettaMessage format for easier content inspection
# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
#
# # Verify a summary message exists with the correct format
# # The summary message has content with type="system_alert" and message containing:
# # "prior messages ... have been hidden" and "summary of the previous"
# import json
#
# summary_message_found = False
# summary_message_text = None
# for msg in letta_messages:
# # Not all message types have a content attribute (e.g., ReasoningMessage)
# if not hasattr(msg, "content"):
# continue
#
# content = msg.content
# # Content can be a string (JSON) or an object with type/message fields
# if isinstance(content, str):
# # Try to parse as JSON
# try:
# parsed = json.loads(content)
# if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
# text_to_check = parsed.get("message", "").lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = parsed.get("message")
# break
# except (json.JSONDecodeError, TypeError):
# pass
# # Check if content has system_alert type with the summary message (object form)
# elif hasattr(content, "type") and content.type == "system_alert":
# if hasattr(content, "message") and content.message:
# text_to_check = content.message.lower()
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
# summary_message_found = True
# summary_message_text = content.message
# break
#
# assert summary_message_found, (
# "A summary message should exist in the in-context messages after summarization. "
# "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
# )
#
# # Verify we attempted multiple invocations (the failing one + retry after summarization)
# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
#
# # The original messages should have been compacted - the updated count should be less than
# # original + the new messages added (input + assistant response + tool results)
# print("Test passed: Summary message found in context (blocking mode)")
# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
# print(f"Total LLM invocations: {call_count}")
#
#
# @pytest.mark.asyncio
# @pytest.mark.parametrize(
# "llm_config",
# TESTED_LLM_CONFIGS,
# ids=[c.model for c in TESTED_LLM_CONFIGS],
# )
# async def test_summarize_all_with_real_llm(server: SyncServer, actor, llm_config: LLMConfig):
# """
# Test the summarize_all function with real LLM calls.
#
# This test verifies that the 'all' summarization mode works correctly,
# summarizing the entire conversation into a single summary string.
# """
# from letta.schemas.model import ModelSettings
# from letta.services.summarizer.summarizer_all import summarize_all
# from letta.services.summarizer.summarizer_config import get_default_summarizer_config
#
# # Create a summarizer config with "all" mode
# model_settings = ModelSettings()
# summarizer_config = get_default_summarizer_config(model_settings)
# summarizer_config.mode = "all"
#
# # Create test messages - a simple conversation
# messages = [
# PydanticMessage(
# role=MessageRole.system,
# content=[TextContent(type="text", text="You are a helpful assistant.")],
# )
# ]
#
# # Add 10 user-assistant pairs
# for i in range(10):
# messages.append(
# PydanticMessage(
# role=MessageRole.user,
# content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")],
# )
# )
# messages.append(
# PydanticMessage(
# role=MessageRole.assistant,
# content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")],
# )
# )
#
# assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}"
#
# # Call summarize_all with real LLM
# summary = await summarize_all(
# actor=actor,
# llm_config=llm_config,
# summarizer_config=summarizer_config,
# in_context_messages=messages,
# new_messages=[],
# )
#
# # Verify the summary was generated
# assert summary is not None
# assert len(summary) > 0
#
# print(f"Successfully summarized {len(messages)} messages using 'all' mode")
# print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
# print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
#

View File

@@ -169,8 +169,8 @@ async def test_count_empty_text_tokens(llm_config: LLMConfig):
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.services.context_window_calculator.token_counter import (
AnthropicTokenCounter,
ApproxTokenCounter,
GeminiTokenCounter,
TiktokenCounter,
)
if llm_config.model_endpoint_type == "anthropic":
@@ -179,7 +179,7 @@ async def test_count_empty_text_tokens(llm_config: LLMConfig):
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
token_counter = GeminiTokenCounter(client, llm_config.model)
else:
token_counter = TiktokenCounter(llm_config.model)
token_counter = ApproxTokenCounter()
token_count = await token_counter.count_text_tokens("")
assert token_count == 0
@@ -194,8 +194,8 @@ async def test_count_empty_messages_tokens(llm_config: LLMConfig):
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.services.context_window_calculator.token_counter import (
AnthropicTokenCounter,
ApproxTokenCounter,
GeminiTokenCounter,
TiktokenCounter,
)
if llm_config.model_endpoint_type == "anthropic":
@@ -204,7 +204,7 @@ async def test_count_empty_messages_tokens(llm_config: LLMConfig):
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
token_counter = GeminiTokenCounter(client, llm_config.model)
else:
token_counter = TiktokenCounter(llm_config.model)
token_counter = ApproxTokenCounter()
token_count = await token_counter.count_message_tokens([])
assert token_count == 0
@@ -219,8 +219,8 @@ async def test_count_empty_tools_tokens(llm_config: LLMConfig):
from letta.llm_api.google_vertex_client import GoogleVertexClient
from letta.services.context_window_calculator.token_counter import (
AnthropicTokenCounter,
ApproxTokenCounter,
GeminiTokenCounter,
TiktokenCounter,
)
if llm_config.model_endpoint_type == "anthropic":
@@ -229,7 +229,7 @@ async def test_count_empty_tools_tokens(llm_config: LLMConfig):
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
token_counter = GeminiTokenCounter(client, llm_config.model)
else:
token_counter = TiktokenCounter(llm_config.model)
token_counter = ApproxTokenCounter()
token_count = await token_counter.count_tool_tokens([])
assert token_count == 0