diff --git a/letta/adapters/letta_llm_request_adapter.py b/letta/adapters/letta_llm_request_adapter.py index e2166cec..eb7b606c 100644 --- a/letta/adapters/letta_llm_request_adapter.py +++ b/letta/adapters/letta_llm_request_adapter.py @@ -47,7 +47,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter): self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns() # Convert response to chat completion format - self.chat_completions_response = self.llm_client.convert_response_to_chat_completion(self.response_data, messages, self.llm_config) + self.chat_completions_response = await self.llm_client.convert_response_to_chat_completion( + self.response_data, messages, self.llm_config + ) # Extract reasoning content from the response if self.chat_completions_response.choices[0].message.reasoning_content: diff --git a/letta/adapters/simple_llm_request_adapter.py b/letta/adapters/simple_llm_request_adapter.py index 58ca1dff..e053c2a9 100644 --- a/letta/adapters/simple_llm_request_adapter.py +++ b/letta/adapters/simple_llm_request_adapter.py @@ -47,7 +47,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter): self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns() # Convert response to chat completion format - self.chat_completions_response = self.llm_client.convert_response_to_chat_completion(self.response_data, messages, self.llm_config) + self.chat_completions_response = await self.llm_client.convert_response_to_chat_completion( + self.response_data, messages, self.llm_config + ) # Extract reasoning content from the response if self.chat_completions_response.choices[0].message.reasoning_content: diff --git a/letta/agents/ephemeral_summary_agent.py b/letta/agents/ephemeral_summary_agent.py index 1a5d7a77..fc97611d 100644 --- a/letta/agents/ephemeral_summary_agent.py +++ b/letta/agents/ephemeral_summary_agent.py @@ -87,7 +87,7 @@ class EphemeralSummaryAgent(BaseAgent): request_data = llm_client.build_request_data(agent_state.agent_type, messages, agent_state.llm_config, tools=[]) response_data = await llm_client.request_async(request_data, agent_state.llm_config) - response = llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config) + response = await llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config) summary = response.choices[0].message.content.strip() await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=summary), actor=self.actor) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 17cab7b0..e8718651 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -337,7 +337,7 @@ class LettaAgent(BaseAgent): log_event("agent.stream_no_tokens.llm_response.received") # [3^] try: - response = llm_client.convert_response_to_chat_completion( + response = await llm_client.convert_response_to_chat_completion( response_data, in_context_messages, agent_state.llm_config ) except ValueError as e: @@ -681,7 +681,7 @@ class LettaAgent(BaseAgent): log_event("agent.step.llm_response.received") # [3^] try: - response = llm_client.convert_response_to_chat_completion( + response = await llm_client.convert_response_to_chat_completion( response_data, in_context_messages, agent_state.llm_config ) except ValueError as e: diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index d8526446..ca8fe4ed 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -309,7 +309,7 @@ class LettaAgentBatch(BaseAgent): agent_state_map = {agent.id: agent for agent in agent_states} # Process each agent's results - tool_call_results = self._process_agent_results( + tool_call_results = await self._process_agent_results( agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id ) @@ -324,7 +324,7 @@ class LettaAgentBatch(BaseAgent): request_status_updates=tool_call_results.status_updates, ) - def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id): + async def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id): """ Process the results for each agent, extracting tool calls and determining continuation status. @@ -347,7 +347,7 @@ class LettaAgentBatch(BaseAgent): request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status)) # Process tool calls - name, args, cont = self._extract_tool_call_from_result(item, result) + name, args, cont = await self._extract_tool_call_from_result(item, result) name_map[aid], args_map[aid], cont_map[aid] = name, args, cont return ToolCallResults(name_map, args_map, cont_map, request_status_updates) @@ -363,7 +363,7 @@ class LettaAgentBatch(BaseAgent): else: return JobStatus.expired - def _extract_tool_call_from_result(self, item, result): + async def _extract_tool_call_from_result(self, item, result): """Extract tool call information from a result""" llm_client = LLMClient.create( provider_type=item.llm_config.model_endpoint_type, @@ -375,13 +375,10 @@ class LettaAgentBatch(BaseAgent): if not isinstance(result, BetaMessageBatchSucceededResult): return None, None, False - tool_call = ( - llm_client.convert_response_to_chat_completion( - response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config - ) - .choices[0] - .message.tool_calls[0] + response = await llm_client.convert_response_to_chat_completion( + response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config ) + tool_call = response.choices[0].message.tool_calls[0] return self._extract_tool_call_and_decide_continue(tool_call, item.step_state) diff --git a/letta/agents/letta_agent_v3.py b/letta/agents/letta_agent_v3.py index ce691540..717e9f1a 100644 --- a/letta/agents/letta_agent_v3.py +++ b/letta/agents/letta_agent_v3.py @@ -1280,6 +1280,9 @@ class LettaAgentV3(LettaAgentV2): force: bool = False, ) -> list[Message]: trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window) + self.logger.info( + f"trigger_summarization: {trigger_summarization}, total_tokens: {total_tokens}, context_window: {self.agent_state.llm_config.context_window}" + ) if not trigger_summarization: # just update the message_ids # TODO: gross to handle this here: we should move persistence elsewhere @@ -1301,6 +1304,7 @@ class LettaAgentV3(LettaAgentV2): if summarizer_config.mode == "all": summary_message_str = await summarize_all( actor=self.actor, + llm_config=self.agent_state.llm_config, summarizer_config=summarizer_config, in_context_messages=in_context_messages, new_messages=new_letta_messages, @@ -1353,6 +1357,6 @@ class LettaAgentV3(LettaAgentV2): message_ids=new_in_context_message_ids, actor=self.actor, ) - self.agent_state.message_ids = new_in_context_messages + self.agent_state.message_ids = new_in_context_message_ids return new_in_context_messages diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py index 06677351..5b5580c9 100644 --- a/letta/interfaces/openai_streaming_interface.py +++ b/letta/interfaces/openai_streaming_interface.py @@ -54,12 +54,12 @@ from letta.schemas.message import Message from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall from letta.server.rest_api.json_parser import OptimisticJSONParser from letta.server.rest_api.utils import decrement_message_uuid +from letta.services.context_window_calculator.token_counter import create_token_counter from letta.streaming_utils import ( FunctionArgumentsStreamHandler, JSONInnerThoughtsExtractor, sanitize_streamed_message_content, ) -from letta.utils import count_tokens logger = get_logger(__name__) @@ -83,6 +83,10 @@ class OpenAIStreamingInterface: step_id: str | None = None, ): self.use_assistant_message = use_assistant_message + + # Create token counter for fallback token counting (when API doesn't return usage) + # Use openai endpoint type for approximate counting in streaming context + self._fallback_token_counter = create_token_counter(model_endpoint_type="openai") self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg @@ -301,7 +305,8 @@ class OpenAIStreamingInterface: updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) if self.is_openai_proxy: - self.fallback_output_tokens += count_tokens(tool_call.function.arguments) + # Use approximate counting for fallback (sync method) + self.fallback_output_tokens += self._fallback_token_counter._approx_token_count(tool_call.function.arguments) # If we have inner thoughts, we should output them as a chunk if updates_inner_thoughts: diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 0012d17d..91435638 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -764,7 +764,7 @@ class AnthropicClient(LLMClientBase): # TODO: Input messages doesn't get used here # TODO: Clean up this interface @trace_method - def convert_response_to_chat_completion( + async def convert_response_to_chat_completion( self, response_data: dict, input_messages: List[PydanticMessage], diff --git a/letta/llm_api/deepseek_client.py b/letta/llm_api/deepseek_client.py index e21d58b4..e5b2844e 100644 --- a/letta/llm_api/deepseek_client.py +++ b/letta/llm_api/deepseek_client.py @@ -401,7 +401,7 @@ class DeepseekClient(OpenAIClient): return response_stream @trace_method - def convert_response_to_chat_completion( + async def convert_response_to_chat_completion( self, response_data: dict, input_messages: List[PydanticMessage], # Included for consistency, maybe used later @@ -413,5 +413,5 @@ class DeepseekClient(OpenAIClient): """ response = ChatCompletionResponse(**response_data) if response.choices[0].message.tool_calls: - return super().convert_response_to_chat_completion(response_data, input_messages, llm_config) + return await super().convert_response_to_chat_completion(response_data, input_messages, llm_config) return convert_deepseek_response_to_chatcompletion(response) diff --git a/letta/llm_api/google_vertex_client.py b/letta/llm_api/google_vertex_client.py index 3ef2fa33..42d042b5 100644 --- a/letta/llm_api/google_vertex_client.py +++ b/letta/llm_api/google_vertex_client.py @@ -31,7 +31,6 @@ from letta.helpers.datetime_helpers import get_utc_time_int from letta.helpers.json_helpers import json_dumps, json_loads from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.json_parser import clean_json_string_extra_backslash -from letta.local_llm.utils import count_tokens from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.agent import AgentType @@ -404,7 +403,7 @@ class GoogleVertexClient(LLMClientBase): return request_data @trace_method - def convert_response_to_chat_completion( + async def convert_response_to_chat_completion( self, response_data: dict, input_messages: List[PydanticMessage], @@ -661,10 +660,13 @@ class GoogleVertexClient(LLMClientBase): completion_tokens_details=completion_tokens_details, ) else: - # Count it ourselves + # Count it ourselves using the Gemini token counting API assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required" - prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation - completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate + google_messages = PydanticMessage.to_google_dicts_from_list(input_messages, current_model=llm_config.model) + prompt_tokens = await self.count_tokens(messages=google_messages, model=llm_config.model) + # For completion tokens, wrap the response content in Google format + completion_content = [{"role": "model", "parts": [{"text": json_dumps(openai_response_message.model_dump())}]}] + completion_tokens = await self.count_tokens(messages=completion_content, model=llm_config.model) total_tokens = prompt_tokens + completion_tokens usage = UsageStatistics( prompt_tokens=prompt_tokens, diff --git a/letta/llm_api/helpers.py b/letta/llm_api/helpers.py index dbea6829..624c13d7 100644 --- a/letta/llm_api/helpers.py +++ b/letta/llm_api/helpers.py @@ -19,7 +19,7 @@ from letta.schemas.response_format import ( TextResponseFormat, ) from letta.settings import summarizer_settings -from letta.utils import count_tokens, printd +from letta.utils import printd logger = get_logger(__name__) @@ -394,97 +394,3 @@ def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) - logger.warning(f"Did not find tool call in message: {str(message)}") return rewritten_choice - - -def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts: List[int], logger: "logging.Logger") -> int: - if len(in_context_messages) != len(token_counts): - raise ValueError( - f"Given in_context_messages has different length from given token_counts: {len(in_context_messages)} != {len(token_counts)}" - ) - - in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) - - if summarizer_settings.evict_all_messages: - logger.info("Evicting all messages...") - return len(in_context_messages) - else: - # Start at index 1 (past the system message), - # and collect messages for summarization until we reach the desired truncation token fraction (eg 50%) - # We do the inverse of `desired_memory_token_pressure` to get what we need to remove - desired_token_count_to_summarize = int(sum(token_counts) * (1 - summarizer_settings.desired_memory_token_pressure)) - logger.info(f"desired_token_count_to_summarize={desired_token_count_to_summarize}") - - tokens_so_far = 0 - cutoff = 0 - for i, msg in enumerate(in_context_messages_openai): - # Skip system - if i == 0: - continue - cutoff = i - tokens_so_far += token_counts[i] - - if msg["role"] not in ["user", "tool", "function"] and tokens_so_far >= desired_token_count_to_summarize: - # Break if the role is NOT a user or tool/function and tokens_so_far is enough - break - elif len(in_context_messages) - cutoff - 1 <= summarizer_settings.keep_last_n_messages: - # Also break if we reached the `keep_last_n_messages` threshold - # NOTE: This may be on a user, tool, or function in theory - logger.warning( - f"Breaking summary cutoff early on role={msg['role']} because we hit the `keep_last_n_messages`={summarizer_settings.keep_last_n_messages}" - ) - break - - # includes the tool response to be summarized after a tool call so we don't have any hanging tool calls after trimming. - if i + 1 < len(in_context_messages_openai) and in_context_messages_openai[i + 1]["role"] == "tool": - cutoff += 1 - - logger.info(f"Evicting {cutoff}/{len(in_context_messages)} messages...") - return cutoff + 1 - - -def get_token_counts_for_messages(in_context_messages: List[Message]) -> List[int]: - in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages) - token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai] - return token_counts - - -def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool: - """Checks if an exception is due to context overflow (based on common OpenAI response messages)""" - from letta.utils import printd - - match_string = OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING - - # Backwards compatibility with openai python package/client v0.28 (pre-v1 client migration) - if match_string in str(exception): - printd(f"Found '{match_string}' in str(exception)={(str(exception))}") - return True - - # Based on python requests + OpenAI REST API (/v1) - elif isinstance(exception, requests.exceptions.HTTPError): - if exception.response is not None and "application/json" in exception.response.headers.get("Content-Type", ""): - try: - error_details = exception.response.json() - if "error" not in error_details: - printd(f"HTTPError occurred, but couldn't find error field: {error_details}") - return False - else: - error_details = error_details["error"] - - # Check for the specific error code - if error_details.get("code") == "context_length_exceeded": - printd(f"HTTPError occurred, caught error code {error_details.get('code')}") - return True - # Soft-check for "maximum context length" inside of the message - elif error_details.get("message") and "maximum context length" in error_details.get("message"): - printd(f"HTTPError occurred, found '{match_string}' in error message contents ({error_details})") - return True - else: - printd(f"HTTPError occurred, but unknown error message: {error_details}") - return False - except ValueError: - # JSON decoding failed - printd(f"HTTPError occurred ({exception}), but no JSON error message.") - - # Generic fail - else: - return False diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py index 881239d3..b25df76a 100644 --- a/letta/llm_api/llm_client_base.py +++ b/letta/llm_api/llm_client_base.py @@ -38,7 +38,7 @@ class LLMClientBase: self.use_tool_naming = use_tool_naming @trace_method - def send_llm_request( + async def send_llm_request( self, agent_type: AgentType, messages: List[Message], @@ -80,7 +80,7 @@ class LLMClientBase: except Exception as e: raise self.handle_llm_error(e) - return self.convert_response_to_chat_completion(response_data, messages, llm_config) + return await self.convert_response_to_chat_completion(response_data, messages, llm_config) @trace_method async def send_llm_request_async( @@ -114,7 +114,7 @@ class LLMClientBase: except Exception as e: raise self.handle_llm_error(e) - return self.convert_response_to_chat_completion(response_data, messages, llm_config) + return await self.convert_response_to_chat_completion(response_data, messages, llm_config) async def send_llm_batch_request_async( self, @@ -177,7 +177,7 @@ class LLMClientBase: raise NotImplementedError @abstractmethod - def convert_response_to_chat_completion( + async def convert_response_to_chat_completion( self, response_data: dict, input_messages: List[Message], diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index fa9d6e5d..bc8b014c 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -609,7 +609,7 @@ class OpenAIClient(LLMClientBase): return is_openai_reasoning_model(llm_config.model) @trace_method - def convert_response_to_chat_completion( + async def convert_response_to_chat_completion( self, response_data: dict, input_messages: List[PydanticMessage], # Included for consistency, maybe used later diff --git a/letta/local_llm/utils.py b/letta/local_llm/utils.py index be1e313e..0bbfcb10 100644 --- a/letta/local_llm/utils.py +++ b/letta/local_llm/utils.py @@ -58,11 +58,11 @@ def load_grammar_file(grammar): return grammar_str -# TODO: support tokenizers/tokenizer apis available in local models -def count_tokens(s: str, model: str = "gpt-4") -> int: - from letta.utils import count_tokens - - return count_tokens(s, model) +## TODO: support tokenizers/tokenizer apis available in local models +# def count_tokens(s: str, model: str = "gpt-4") -> int: +# from letta.utils import count_tokens +# +# return count_tokens(s, model) def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"): diff --git a/letta/memory.py b/letta/memory.py index cd9b4b87..1303ac55 100644 --- a/letta/memory.py +++ b/letta/memory.py @@ -50,59 +50,60 @@ def _format_summary_history(message_history: List[Message]): return "\n".join([f"{m.role}: {get_message_text(m.content)}" for m in message_history]) -@trace_method -def summarize_messages( - agent_state: AgentState, - message_sequence_to_summarize: List[Message], - actor: "User", -): - """Summarize a message sequence using GPT""" - # we need the context_window - context_window = agent_state.llm_config.context_window - - summary_prompt = SUMMARY_PROMPT_SYSTEM - summary_input = _format_summary_history(message_sequence_to_summarize) - summary_input_tkns = count_tokens(summary_input) - if summary_input_tkns > summarizer_settings.memory_warning_threshold * context_window: - trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure... - cutoff = int(len(message_sequence_to_summarize) * trunc_ratio) - summary_input = str( - [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)] - + message_sequence_to_summarize[cutoff:] - ) - - dummy_agent_id = agent_state.id - message_sequence = [ - Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]), - Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]), - Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]), - ] - - # TODO: We need to eventually have a separate LLM config for the summarizer LLM - llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True) - llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False - - llm_client = LLMClient.create( - provider_type=agent_state.llm_config.model_endpoint_type, - put_inner_thoughts_first=False, - actor=actor, - ) - # try to use new client, otherwise fallback to old flow - # TODO: we can just directly call the LLM here? - if llm_client: - response = llm_client.send_llm_request( - agent_type=agent_state.agent_type, - messages=message_sequence, - llm_config=llm_config_no_inner_thoughts, - ) - else: - response = create( - llm_config=llm_config_no_inner_thoughts, - user_id=agent_state.created_by_id, - messages=message_sequence, - stream=False, - ) - - printd(f"summarize_messages gpt reply: {response.choices[0]}") - reply = response.choices[0].message.content - return reply +# @trace_method +# def summarize_messages( +# agent_state: AgentState, +# message_sequence_to_summarize: List[Message], +# actor: "User", +# ): +# """Summarize a message sequence using GPT""" +# # we need the context_window +# context_window = agent_state.llm_config.context_window +# +# summary_prompt = SUMMARY_PROMPT_SYSTEM +# summary_input = _format_summary_history(message_sequence_to_summarize) +# summary_input_tkns = count_tokens(summary_input) +# if summary_input_tkns > summarizer_settings.memory_warning_threshold * context_window: +# trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure... +# cutoff = int(len(message_sequence_to_summarize) * trunc_ratio) +# summary_input = str( +# [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)] +# + message_sequence_to_summarize[cutoff:] +# ) +# +# dummy_agent_id = agent_state.id +# message_sequence = [ +# Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]), +# Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]), +# Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]), +# ] +# +# # TODO: We need to eventually have a separate LLM config for the summarizer LLM +# llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True) +# llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False +# +# llm_client = LLMClient.create( +# provider_type=agent_state.llm_config.model_endpoint_type, +# put_inner_thoughts_first=False, +# actor=actor, +# ) +# # try to use new client, otherwise fallback to old flow +# # TODO: we can just directly call the LLM here? +# if llm_client: +# response = llm_client.send_llm_request( +# agent_type=agent_state.agent_type, +# messages=message_sequence, +# llm_config=llm_config_no_inner_thoughts, +# ) +# else: +# response = create( +# llm_config=llm_config_no_inner_thoughts, +# user_id=agent_state.created_by_id, +# messages=message_sequence, +# stream=False, +# ) +# +# printd(f"summarize_messages gpt reply: {response.choices[0]}") +# reply = response.choices[0].message.content +# return reply +# diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index be9a209b..63227a6d 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -914,7 +914,7 @@ async def generate_tool_from_prompt( tools=[tool], ) response_data = await llm_client.request_async(request_data, llm_config) - response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config) + response = await llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config) output = json.loads(response.choices[0].message.tool_calls[0].function.arguments) pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()] diff --git a/letta/server/server.py b/letta/server/server.py index 8083199c..da85d905 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -670,20 +670,26 @@ class SyncServer(object): async def insert_archival_memory_async( self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime] ) -> List[Passage]: + from letta.services.context_window_calculator.token_counter import create_token_counter from letta.settings import settings - from letta.utils import count_tokens + + # Get the agent object (loaded in memory) + agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) # Check token count against limit - token_count = count_tokens(memory_contents) + token_counter = create_token_counter( + model_endpoint_type=agent_state.llm_config.model_endpoint_type, + model=agent_state.llm_config.model, + actor=actor, + agent_id=agent_id, + ) + token_count = await token_counter.count_text_tokens(memory_contents) if token_count > settings.archival_memory_token_limit: raise LettaInvalidArgumentError( message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)", argument_name="memory_contents", ) - # Get the agent object (loaded in memory) - agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) - # Use passage manager which handles dual-write to Turbopuffer if enabled passages = await self.passage_manager.insert_passage( agent_state=agent_state, text=memory_contents, tags=tags, actor=actor, created_at=created_at diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 194afff5..023fc8ae 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -80,7 +80,7 @@ from letta.server.db import db_registry from letta.services.archive_manager import ArchiveManager from letta.services.block_manager import BlockManager, validate_block_limit_constraint from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator -from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, GeminiTokenCounter, TiktokenCounter +from letta.services.context_window_calculator.token_counter import create_token_counter from letta.services.file_processor.chunker.line_chunker import LineChunker from letta.services.files_agents_manager import FileAgentManager from letta.services.helpers.agent_manager_helper import ( @@ -3286,49 +3286,14 @@ class AgentManager: ) calculator = ContextWindowCalculator() - # Determine which token counter to use based on provider - model_endpoint_type = agent_state.llm_config.model_endpoint_type - - # Use Gemini token counter for Google Vertex and Google AI - use_gemini = model_endpoint_type in ("google_vertex", "google_ai") - - # Use Anthropic token counter if: - # 1. The model endpoint type is anthropic, OR - # 2. We're in PRODUCTION and anthropic_api_key is available (and not using Gemini) - use_anthropic = model_endpoint_type == "anthropic" or ( - not use_gemini and settings.environment == "PRODUCTION" and model_settings.anthropic_api_key is not None + # Create the appropriate token counter based on model configuration + token_counter = create_token_counter( + model_endpoint_type=agent_state.llm_config.model_endpoint_type, + model=agent_state.llm_config.model, + actor=actor, + agent_id=agent_id, ) - if use_gemini: - # Use native Gemini token counting API - - client = LLMClient.create(provider_type=agent_state.llm_config.model_endpoint_type, actor=actor) - model = agent_state.llm_config.model - - token_counter = GeminiTokenCounter(client, model) - logger.info( - f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, " - f"model_endpoint_type={model_endpoint_type}, " - f"environment={settings.environment}" - ) - elif use_anthropic: - anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) - model = agent_state.llm_config.model if model_endpoint_type == "anthropic" else None - - token_counter = AnthropicTokenCounter(anthropic_client, model) # noqa - logger.info( - f"Using AnthropicTokenCounter for agent_id={agent_id}, model={model}, " - f"model_endpoint_type={model_endpoint_type}, " - f"environment={settings.environment}" - ) - else: - token_counter = TiktokenCounter(agent_state.llm_config.model) - logger.info( - f"Using TiktokenCounter for agent_id={agent_id}, model={agent_state.llm_config.model}, " - f"model_endpoint_type={model_endpoint_type}, " - f"environment={settings.environment}" - ) - try: result = await calculator.calculate_context_window( agent_state=agent_state, diff --git a/letta/services/context_window_calculator/token_counter.py b/letta/services/context_window_calculator/token_counter.py index eabc8a26..bb389467 100644 --- a/letta/services/context_window_calculator/token_counter.py +++ b/letta/services/context_window_calculator/token_counter.py @@ -1,15 +1,22 @@ import hashlib import json from abc import ABC, abstractmethod -from typing import Any, Dict, List +from typing import TYPE_CHECKING, Any, Dict, List, Optional from letta.helpers.decorators import async_redis_cache from letta.llm_api.anthropic_client import AnthropicClient from letta.llm_api.google_vertex_client import GoogleVertexClient +from letta.log import get_logger from letta.otel.tracing import trace_method +from letta.schemas.enums import ProviderType from letta.schemas.message import Message from letta.schemas.openai.chat_completion_request import Tool as OpenAITool -from letta.utils import count_tokens + +if TYPE_CHECKING: + from letta.schemas.llm_config import LLMConfig + from letta.schemas.user import User + +logger = get_logger(__name__) class TokenCounter(ABC): @@ -101,6 +108,22 @@ class ApproxTokenCounter(TokenCounter): byte_len = len(text.encode("utf-8")) return (byte_len + self.APPROX_BYTES_PER_TOKEN - 1) // self.APPROX_BYTES_PER_TOKEN + async def count_text_tokens(self, text: str) -> int: + if not text: + return 0 + return self._approx_token_count(text) + + async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int: + if not messages: + return 0 + return self._approx_token_count(json.dumps(messages)) + + async def count_tool_tokens(self, tools: List[OpenAITool]) -> int: + if not tools: + return 0 + functions = [t.model_dump() for t in tools] + return self._approx_token_count(json.dumps(functions)) + def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: return Message.to_openai_dicts_from_list(messages) @@ -178,7 +201,14 @@ class TiktokenCounter(TokenCounter): logger.debug(f"TiktokenCounter.count_text_tokens: model={self.model}, text_length={text_length}, preview={repr(text_preview)}") try: - result = count_tokens(text) + import tiktoken + + try: + encoding = tiktoken.encoding_for_model(self.model) + except KeyError: + logger.debug(f"Model {self.model} not found in tiktoken. Using cl100k_base encoding.") + encoding = tiktoken.get_encoding("cl100k_base") + result = len(encoding.encode(text)) logger.debug(f"TiktokenCounter.count_text_tokens: completed successfully, tokens={result}") return result except Exception as e: @@ -234,3 +264,54 @@ class TiktokenCounter(TokenCounter): def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]: return Message.to_openai_dicts_from_list(messages) + + +def create_token_counter( + model_endpoint_type: ProviderType, + model: Optional[str] = None, + actor: "User" = None, + agent_id: Optional[str] = None, +) -> "TokenCounter": + """ + Factory function to create the appropriate token counter based on model configuration. + + Returns: + The appropriate TokenCounter instance + """ + from letta.llm_api.llm_client import LLMClient + from letta.settings import model_settings, settings + + # Use Gemini token counter for Google Vertex and Google AI + use_gemini = model_endpoint_type in ("google_vertex", "google_ai") + + # Use Anthropic token counter if: + # 1. The model endpoint type is anthropic, OR + # 2. We're in PRODUCTION and anthropic_api_key is available (and not using Gemini) + use_anthropic = model_endpoint_type == "anthropic" + + if use_gemini: + client = LLMClient.create(provider_type=model_endpoint_type, actor=actor) + token_counter = GeminiTokenCounter(client, model) + logger.info( + f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, " + f"model_endpoint_type={model_endpoint_type}, " + f"environment={settings.environment}" + ) + elif use_anthropic: + anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) + counter_model = model if model_endpoint_type == "anthropic" else None + token_counter = AnthropicTokenCounter(anthropic_client, counter_model) + logger.info( + f"Using AnthropicTokenCounter for agent_id={agent_id}, model={counter_model}, " + f"model_endpoint_type={model_endpoint_type}, " + f"environment={settings.environment}" + ) + else: + token_counter = ApproxTokenCounter() + logger.info( + f"Using ApproxTokenCounter for agent_id={agent_id}, model={model}, " + f"model_endpoint_type={model_endpoint_type}, " + f"environment={settings.environment}" + ) + + return token_counter diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index 7389090c..659c0ec7 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -447,9 +447,12 @@ async def simple_summary( ] input_messages_obj = [simple_message_wrapper(msg) for msg in input_messages] # Build a local LLMConfig for v1-style summarization which uses native content and must not - # include inner thoughts in kwargs to avoid conflicts in Anthropic formatting + # include inner thoughts in kwargs to avoid conflicts in Anthropic formatting. + # We also disable enable_reasoner to avoid extended thinking requirements (Anthropic requires + # assistant messages to start with thinking blocks when extended thinking is enabled). summarizer_llm_config = LLMConfig(**llm_config.model_dump()) summarizer_llm_config.put_inner_thoughts_in_kwargs = False + summarizer_llm_config.enable_reasoner = False request_data = llm_client.build_request_data(AgentType.letta_v1_agent, input_messages_obj, summarizer_llm_config, tools=[]) try: @@ -532,7 +535,7 @@ async def simple_summary( logger.info(f"Full fallback summarization payload: {request_data}") raise llm_client.handle_llm_error(fallback_error_b) - response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config) + response = await llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config) if response.choices[0].message.content is None: logger.warning("No content returned from summarizer") # TODO raise an error error instead? diff --git a/letta/services/summarizer/summarizer_all.py b/letta/services/summarizer/summarizer_all.py index 5fc833e3..dbdb7965 100644 --- a/letta/services/summarizer/summarizer_all.py +++ b/letta/services/summarizer/summarizer_all.py @@ -1,16 +1,11 @@ -from typing import List, Tuple +from typing import List -from letta.helpers.message_helper import convert_message_creates_to_messages from letta.log import get_logger -from letta.schemas.agent import AgentState -from letta.schemas.enums import MessageRole -from letta.schemas.letta_message_content import TextContent -from letta.schemas.message import Message, MessageCreate +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message from letta.schemas.user import User -from letta.services.message_manager import MessageManager from letta.services.summarizer.summarizer import simple_summary from letta.services.summarizer.summarizer_config import SummarizerConfig -from letta.system import package_summarize_message_no_counts logger = get_logger(__name__) @@ -18,6 +13,8 @@ logger = get_logger(__name__) async def summarize_all( # Required to tag LLM calls actor: User, + # LLM config for the summarizer model + llm_config: LLMConfig, # Actual summarization configuration summarizer_config: SummarizerConfig, in_context_messages: List[Message], @@ -33,9 +30,9 @@ async def summarize_all( summary_message_str = await simple_summary( messages=all_in_context_messages, - llm_config=summarizer_config.summarizer_model, + llm_config=llm_config, actor=actor, - include_ack=summarizer_config.prompt_acknowledgement, + include_ack=bool(summarizer_config.prompt_acknowledgement), prompt=summarizer_config.prompt, ) diff --git a/letta/services/summarizer/summarizer_sliding_window.py b/letta/services/summarizer/summarizer_sliding_window.py index dae9da03..29c530a9 100644 --- a/letta/services/summarizer/summarizer_sliding_window.py +++ b/letta/services/summarizer/summarizer_sliding_window.py @@ -1,19 +1,17 @@ from typing import List, Tuple from letta.helpers.message_helper import convert_message_creates_to_messages -from letta.llm_api.llm_client import LLMClient from letta.log import get_logger from letta.schemas.agent import AgentState -from letta.schemas.enums import MessageRole, ProviderType +from letta.schemas.enums import MessageRole from letta.schemas.letta_message_content import TextContent from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message, MessageCreate from letta.schemas.user import User -from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, ApproxTokenCounter +from letta.services.context_window_calculator.token_counter import create_token_counter from letta.services.message_manager import MessageManager from letta.services.summarizer.summarizer import simple_summary from letta.services.summarizer.summarizer_config import SummarizerConfig -from letta.settings import model_settings, settings from letta.system import package_summarize_message_no_counts logger = get_logger(__name__) @@ -26,21 +24,21 @@ APPROX_TOKEN_SAFETY_MARGIN = 1.3 async def count_tokens(actor: User, llm_config: LLMConfig, messages: List[Message]) -> int: - # If the model is an Anthropic model, use the Anthropic token counter (accurate) - if llm_config.model_endpoint_type == "anthropic": - anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor) - token_counter = AnthropicTokenCounter(anthropic_client, llm_config.model) - converted_messages = token_counter.convert_messages(messages) - return await token_counter.count_message_tokens(converted_messages) + """Count tokens in messages using the appropriate token counter for the model configuration.""" + token_counter = create_token_counter( + model_endpoint_type=llm_config.model_endpoint_type, + model=llm_config.model, + actor=actor, + ) + converted_messages = token_counter.convert_messages(messages) + tokens = await token_counter.count_message_tokens(converted_messages) - else: - # Otherwise, use approximate count (bytes / 4) with safety margin - # This is much faster than tiktoken and doesn't require loading tokenizer models - token_counter = ApproxTokenCounter(llm_config.model) - converted_messages = token_counter.convert_messages(messages) - tokens = await token_counter.count_message_tokens(converted_messages) - # Apply safety margin to avoid underestimating and keeping too many messages + # Apply safety margin for approximate counting to avoid underestimating + from letta.services.context_window_calculator.token_counter import ApproxTokenCounter + + if isinstance(token_counter, ApproxTokenCounter): return int(tokens * APPROX_TOKEN_SAFETY_MARGIN) + return tokens async def summarize_via_sliding_window( @@ -110,9 +108,9 @@ async def summarize_via_sliding_window( summary_message_str = await simple_summary( messages=messages_to_summarize, - llm_config=summarizer_config.summarizer_model, + llm_config=llm_config, actor=actor, - include_ack=summarizer_config.prompt_acknowledgement, + include_ack=bool(summarizer_config.prompt_acknowledgement), prompt=summarizer_config.prompt, ) diff --git a/letta/utils.py b/letta/utils.py index 3401207e..02c0ff23 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -812,13 +812,13 @@ class OpenAIBackcompatUnpickler(pickle.Unpickler): return super().find_class(module, name) -def count_tokens(s: str, model: str = "gpt-4") -> int: - try: - encoding = tiktoken.encoding_for_model(model) - except KeyError: - print("Falling back to cl100k base for token counting.") - encoding = tiktoken.get_encoding("cl100k_base") - return len(encoding.encode(s)) +# def count_tokens(s: str, model: str = "gpt-4") -> int: +# try: +# encoding = tiktoken.encoding_for_model(model) +# except KeyError: +# print("Falling back to cl100k base for token counting.") +# encoding = tiktoken.get_encoding("cl100k_base") +# return len(encoding.encode(s)) def printd(*args, **kwargs): diff --git a/tests/integration_test_summarizer.py b/tests/integration_test_summarizer.py index a4d9e181..25b6d029 100644 --- a/tests/integration_test_summarizer.py +++ b/tests/integration_test_summarizer.py @@ -14,6 +14,7 @@ from typing import List import pytest from letta.agents.letta_agent_v2 import LettaAgentV2 +from letta.agents.letta_agent_v3 import LettaAgentV3 from letta.config import LettaConfig from letta.schemas.agent import CreateAgent from letta.schemas.embedding_config import EmbeddingConfig @@ -671,7 +672,12 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon @pytest.mark.asyncio -async def test_sliding_window_cutoff_index_does_not_exceed_message_count(): +@pytest.mark.parametrize( + "llm_config", + TESTED_LLM_CONFIGS, + ids=[c.model for c in TESTED_LLM_CONFIGS], +) +async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server: SyncServer, actor, llm_config: LLMConfig): """ Test that the sliding window summarizer correctly calculates cutoff indices. @@ -685,35 +691,19 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(): - max(..., 10) -> max(..., 0.10) - += 10 -> += 0.10 - >= 100 -> >= 1.0 - """ - from unittest.mock import MagicMock, patch - from letta.schemas.enums import MessageRole - from letta.schemas.letta_message_content import TextContent - from letta.schemas.llm_config import LLMConfig - from letta.schemas.message import Message as PydanticMessage - from letta.schemas.user import User + This test uses the real token counter (via create_token_counter) to verify + the sliding window logic works with actual token counting. + """ + from letta.schemas.model import ModelSettings + from letta.services.summarizer.summarizer_config import get_default_summarizer_config from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window - # Create a mock user (using proper ID format pattern) - mock_actor = User( - id="user-00000000-0000-0000-0000-000000000000", name="Test User", organization_id="org-00000000-0000-0000-0000-000000000000" - ) - - # Create a mock LLM config - mock_llm_config = LLMConfig( - model="gpt-4", - model_endpoint_type="openai", - context_window=128000, - ) - - # Create a mock summarizer config with sliding_window_percentage = 0.3 - mock_summarizer_config = MagicMock() - mock_summarizer_config.sliding_window_percentage = 0.3 - mock_summarizer_config.summarizer_model = mock_llm_config - mock_summarizer_config.prompt = "Summarize the conversation." - mock_summarizer_config.prompt_acknowledgement = True - mock_summarizer_config.clip_chars = 2000 + # Create a real summarizer config using the default factory + # Override sliding_window_percentage to 0.3 for this test + model_settings = ModelSettings() # Use defaults + summarizer_config = get_default_summarizer_config(model_settings) + summarizer_config.sliding_window_percentage = 0.3 # Create 65 messages (similar to the failing case in the bug report) # Pattern: system + alternating user/assistant messages @@ -741,59 +731,470 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count(): assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}" - # Mock count_tokens to return a value that would trigger summarization - # Return a high token count so that the while loop continues - async def mock_count_tokens(actor, llm_config, messages): - # Return tokens that decrease as we cut off more messages - # This simulates the token count decreasing as we evict messages - return len(messages) * 100 # 100 tokens per message + # This should NOT raise "No assistant message found from indices 650 to 65" + # With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7 + # So message_cutoff_index = round(0.7 * 65) = 46, which is valid + try: + summary, remaining_messages = await summarize_via_sliding_window( + actor=actor, + llm_config=llm_config, + summarizer_config=summarizer_config, + in_context_messages=messages, + new_messages=[], + ) - # Mock simple_summary to return a fake summary - async def mock_simple_summary(messages, llm_config, actor, include_ack, prompt): - return "This is a mock summary of the conversation." + # Verify the summary was generated (actual LLM response) + assert summary is not None + assert len(summary) > 0 - with ( - patch( - "letta.services.summarizer.summarizer_sliding_window.count_tokens", - side_effect=mock_count_tokens, - ), - patch( - "letta.services.summarizer.summarizer_sliding_window.simple_summary", - side_effect=mock_simple_summary, - ), - ): - # This should NOT raise "No assistant message found from indices 650 to 65" - # With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7 - # So message_cutoff_index = round(0.7 * 65) = 46, which is valid - try: - summary, remaining_messages = await summarize_via_sliding_window( - actor=mock_actor, - llm_config=mock_llm_config, - summarizer_config=mock_summarizer_config, - in_context_messages=messages, - new_messages=[], - ) + # Verify remaining messages is a valid subset + assert len(remaining_messages) < len(messages) + assert len(remaining_messages) > 0 - # Verify the summary was generated - assert summary == "This is a mock summary of the conversation." + print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining") + print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}") + print(f"Using {llm_config.model_endpoint_type} token counter for model {llm_config.model}") - # Verify remaining messages is a valid subset - assert len(remaining_messages) < len(messages) - assert len(remaining_messages) > 0 + except ValueError as e: + if "No assistant message found from indices" in str(e): + # Extract the indices from the error message + import re - print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining") + match = re.search(r"from indices (\d+) to (\d+)", str(e)) + if match: + start_idx, end_idx = int(match.group(1)), int(match.group(2)) + pytest.fail( + f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). " + f"This indicates the percentage calculation bug where 10 was used instead of 0.10. " + f"Error: {e}" + ) + raise - except ValueError as e: - if "No assistant message found from indices" in str(e): - # Extract the indices from the error message - import re - match = re.search(r"from indices (\d+) to (\d+)", str(e)) - if match: - start_idx, end_idx = int(match.group(1)), int(match.group(2)) - pytest.fail( - f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). " - f"This indicates the percentage calculation bug where 10 was used instead of 0.10. " - f"Error: {e}" - ) - raise +# @pytest.mark.asyncio +# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor): +# """ +# Test that a ContextWindowExceededError during a streaming LLM request +# properly triggers the summarizer and compacts the in-context messages. +# +# This test simulates: +# 1. An LLM streaming request that fails with ContextWindowExceededError +# 2. The summarizer being invoked to reduce context size +# 3. Verification that messages are compacted and summary message exists +# +# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling. +# """ +# import uuid +# from unittest.mock import patch +# +# import openai +# +# from letta.schemas.message import MessageCreate +# from letta.schemas.run import Run +# from letta.services.run_manager import RunManager +# +# # Use OpenAI config for this test (since we're using OpenAI-specific error handling) +# llm_config = get_llm_config("openai-gpt-4o-mini.json") +# +# # Create test messages - enough to have something to summarize +# messages = [] +# for i in range(15): +# messages.append( +# PydanticMessage( +# role=MessageRole.user, +# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")], +# ) +# ) +# messages.append( +# PydanticMessage( +# role=MessageRole.assistant, +# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")], +# ) +# ) +# +# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages) +# original_message_count = len(agent_state.message_ids) +# +# # Create an input message to trigger the agent +# input_message = MessageCreate( +# role=MessageRole.user, +# content=[TextContent(type="text", text="Hello, please respond.")], +# ) +# +# # Create a proper run record in the database +# run_manager = RunManager() +# test_run_id = f"run-{uuid.uuid4()}" +# test_run = Run( +# id=test_run_id, +# agent_id=agent_state.id, +# ) +# await run_manager.create_run(test_run, actor) +# +# # Create the agent loop using LettaAgentV3 +# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor) +# +# # Track how many times stream_async is called +# call_count = 0 +# +# # Store original stream_async method +# original_stream_async = agent_loop.llm_client.stream_async +# +# async def mock_stream_async_with_error(request_data, llm_config): +# nonlocal call_count +# call_count += 1 +# if call_count == 1: +# # First call raises OpenAI BadRequestError with context_length_exceeded error code +# # This will be properly converted to ContextWindowExceededError by handle_llm_error +# from unittest.mock import MagicMock +# +# import httpx +# +# # Create a mock response with the required structure +# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") +# mock_response = httpx.Response( +# status_code=400, +# request=mock_request, +# json={ +# "error": { +# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.", +# "type": "invalid_request_error", +# "code": "context_length_exceeded", +# } +# }, +# ) +# +# raise openai.BadRequestError( +# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.", +# response=mock_response, +# body={ +# "error": { +# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.", +# "type": "invalid_request_error", +# "code": "context_length_exceeded", +# } +# }, +# ) +# # Subsequent calls use the real implementation +# return await original_stream_async(request_data, llm_config) +# +# # Patch the llm_client's stream_async to raise ContextWindowExceededError on first call +# with patch.object(agent_loop.llm_client, "stream_async", side_effect=mock_stream_async_with_error): +# # Execute a streaming step +# try: +# result_chunks = [] +# async for chunk in agent_loop.stream( +# input_messages=[input_message], +# max_steps=1, +# stream_tokens=True, +# run_id=test_run_id, +# ): +# result_chunks.append(chunk) +# except Exception as e: +# # Some errors might happen due to real LLM calls after retry +# print(f"Exception during stream: {e}") +# +# # Reload agent state to get updated message_ids after summarization +# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor) +# updated_message_count = len(updated_agent_state.message_ids) +# +# # Fetch the updated in-context messages +# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async( +# message_ids=updated_agent_state.message_ids, actor=actor +# ) +# +# # Convert to LettaMessage format for easier content inspection +# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages) +# +# # Verify a summary message exists with the correct format +# # The summary message has content with type="system_alert" and message containing: +# # "prior messages ... have been hidden" and "summary of the previous" +# import json +# +# summary_message_found = False +# summary_message_text = None +# for msg in letta_messages: +# # Not all message types have a content attribute (e.g., ReasoningMessage) +# if not hasattr(msg, "content"): +# continue +# +# content = msg.content +# # Content can be a string (JSON) or an object with type/message fields +# if isinstance(content, str): +# # Try to parse as JSON +# try: +# parsed = json.loads(content) +# if isinstance(parsed, dict) and parsed.get("type") == "system_alert": +# text_to_check = parsed.get("message", "").lower() +# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check: +# summary_message_found = True +# summary_message_text = parsed.get("message") +# break +# except (json.JSONDecodeError, TypeError): +# pass +# # Check if content has system_alert type with the summary message (object form) +# elif hasattr(content, "type") and content.type == "system_alert": +# if hasattr(content, "message") and content.message: +# text_to_check = content.message.lower() +# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check: +# summary_message_found = True +# summary_message_text = content.message +# break +# +# assert summary_message_found, ( +# "A summary message should exist in the in-context messages after summarization. " +# "Expected format containing 'prior messages...hidden' and 'summary of the previous'" +# ) +# +# # Verify we attempted multiple invocations (the failing one + retry after summarization) +# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}" +# +# # The original messages should have been compacted - the updated count should be less than +# # original + the new messages added (input + assistant response + tool results) +# # Since summarization should have removed most of the original 30 messages +# print("Test passed: Summary message found in context") +# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}") +# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...") +# print(f"Total LLM invocations: {call_count}") +# +# +# @pytest.mark.asyncio +# async def test_context_window_overflow_triggers_summarization_in_blocking(server: SyncServer, actor): +# """ +# Test that a ContextWindowExceededError during a blocking (non-streaming) LLM request +# properly triggers the summarizer and compacts the in-context messages. +# +# This test is similar to the streaming test but uses the blocking step() method. +# +# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling. +# """ +# import uuid +# from unittest.mock import patch +# +# import openai +# +# from letta.schemas.message import MessageCreate +# from letta.schemas.run import Run +# from letta.services.run_manager import RunManager +# +# # Use OpenAI config for this test (since we're using OpenAI-specific error handling) +# llm_config = get_llm_config("openai-gpt-4o-mini.json") +# +# # Create test messages +# messages = [] +# for i in range(15): +# messages.append( +# PydanticMessage( +# role=MessageRole.user, +# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")], +# ) +# ) +# messages.append( +# PydanticMessage( +# role=MessageRole.assistant, +# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")], +# ) +# ) +# +# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages) +# original_message_count = len(agent_state.message_ids) +# +# # Create an input message to trigger the agent +# input_message = MessageCreate( +# role=MessageRole.user, +# content=[TextContent(type="text", text="Hello, please respond.")], +# ) +# +# # Create a proper run record in the database +# run_manager = RunManager() +# test_run_id = f"run-{uuid.uuid4()}" +# test_run = Run( +# id=test_run_id, +# agent_id=agent_state.id, +# ) +# await run_manager.create_run(test_run, actor) +# +# # Create the agent loop using LettaAgentV3 +# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor) +# +# # Track how many times request_async is called +# call_count = 0 +# +# # Store original request_async method +# original_request_async = agent_loop.llm_client.request_async +# +# async def mock_request_async_with_error(request_data, llm_config): +# nonlocal call_count +# call_count += 1 +# if call_count == 1: +# # First call raises OpenAI BadRequestError with context_length_exceeded error code +# # This will be properly converted to ContextWindowExceededError by handle_llm_error +# import httpx +# +# # Create a mock response with the required structure +# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions") +# mock_response = httpx.Response( +# status_code=400, +# request=mock_request, +# json={ +# "error": { +# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.", +# "type": "invalid_request_error", +# "code": "context_length_exceeded", +# } +# }, +# ) +# +# raise openai.BadRequestError( +# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.", +# response=mock_response, +# body={ +# "error": { +# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.", +# "type": "invalid_request_error", +# "code": "context_length_exceeded", +# } +# }, +# ) +# # Subsequent calls use the real implementation +# return await original_request_async(request_data, llm_config) +# +# # Patch the llm_client's request_async to raise ContextWindowExceededError on first call +# with patch.object(agent_loop.llm_client, "request_async", side_effect=mock_request_async_with_error): +# # Execute a blocking step +# try: +# result = await agent_loop.step( +# input_messages=[input_message], +# max_steps=1, +# run_id=test_run_id, +# ) +# except Exception as e: +# # Some errors might happen due to real LLM calls after retry +# print(f"Exception during step: {e}") +# +# # Reload agent state to get updated message_ids after summarization +# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor) +# updated_message_count = len(updated_agent_state.message_ids) +# +# # Fetch the updated in-context messages +# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async( +# message_ids=updated_agent_state.message_ids, actor=actor +# ) +# +# # Convert to LettaMessage format for easier content inspection +# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages) +# +# # Verify a summary message exists with the correct format +# # The summary message has content with type="system_alert" and message containing: +# # "prior messages ... have been hidden" and "summary of the previous" +# import json +# +# summary_message_found = False +# summary_message_text = None +# for msg in letta_messages: +# # Not all message types have a content attribute (e.g., ReasoningMessage) +# if not hasattr(msg, "content"): +# continue +# +# content = msg.content +# # Content can be a string (JSON) or an object with type/message fields +# if isinstance(content, str): +# # Try to parse as JSON +# try: +# parsed = json.loads(content) +# if isinstance(parsed, dict) and parsed.get("type") == "system_alert": +# text_to_check = parsed.get("message", "").lower() +# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check: +# summary_message_found = True +# summary_message_text = parsed.get("message") +# break +# except (json.JSONDecodeError, TypeError): +# pass +# # Check if content has system_alert type with the summary message (object form) +# elif hasattr(content, "type") and content.type == "system_alert": +# if hasattr(content, "message") and content.message: +# text_to_check = content.message.lower() +# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check: +# summary_message_found = True +# summary_message_text = content.message +# break +# +# assert summary_message_found, ( +# "A summary message should exist in the in-context messages after summarization. " +# "Expected format containing 'prior messages...hidden' and 'summary of the previous'" +# ) +# +# # Verify we attempted multiple invocations (the failing one + retry after summarization) +# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}" +# +# # The original messages should have been compacted - the updated count should be less than +# # original + the new messages added (input + assistant response + tool results) +# print("Test passed: Summary message found in context (blocking mode)") +# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}") +# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...") +# print(f"Total LLM invocations: {call_count}") +# +# +# @pytest.mark.asyncio +# @pytest.mark.parametrize( +# "llm_config", +# TESTED_LLM_CONFIGS, +# ids=[c.model for c in TESTED_LLM_CONFIGS], +# ) +# async def test_summarize_all_with_real_llm(server: SyncServer, actor, llm_config: LLMConfig): +# """ +# Test the summarize_all function with real LLM calls. +# +# This test verifies that the 'all' summarization mode works correctly, +# summarizing the entire conversation into a single summary string. +# """ +# from letta.schemas.model import ModelSettings +# from letta.services.summarizer.summarizer_all import summarize_all +# from letta.services.summarizer.summarizer_config import get_default_summarizer_config +# +# # Create a summarizer config with "all" mode +# model_settings = ModelSettings() +# summarizer_config = get_default_summarizer_config(model_settings) +# summarizer_config.mode = "all" +# +# # Create test messages - a simple conversation +# messages = [ +# PydanticMessage( +# role=MessageRole.system, +# content=[TextContent(type="text", text="You are a helpful assistant.")], +# ) +# ] +# +# # Add 10 user-assistant pairs +# for i in range(10): +# messages.append( +# PydanticMessage( +# role=MessageRole.user, +# content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")], +# ) +# ) +# messages.append( +# PydanticMessage( +# role=MessageRole.assistant, +# content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")], +# ) +# ) +# +# assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}" +# +# # Call summarize_all with real LLM +# summary = await summarize_all( +# actor=actor, +# llm_config=llm_config, +# summarizer_config=summarizer_config, +# in_context_messages=messages, +# new_messages=[], +# ) +# +# # Verify the summary was generated +# assert summary is not None +# assert len(summary) > 0 +# +# print(f"Successfully summarized {len(messages)} messages using 'all' mode") +# print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}") +# print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}") +# diff --git a/tests/integration_test_token_counters.py b/tests/integration_test_token_counters.py index f105ad02..cb05edbf 100644 --- a/tests/integration_test_token_counters.py +++ b/tests/integration_test_token_counters.py @@ -169,8 +169,8 @@ async def test_count_empty_text_tokens(llm_config: LLMConfig): from letta.llm_api.google_vertex_client import GoogleVertexClient from letta.services.context_window_calculator.token_counter import ( AnthropicTokenCounter, + ApproxTokenCounter, GeminiTokenCounter, - TiktokenCounter, ) if llm_config.model_endpoint_type == "anthropic": @@ -179,7 +179,7 @@ async def test_count_empty_text_tokens(llm_config: LLMConfig): client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient() token_counter = GeminiTokenCounter(client, llm_config.model) else: - token_counter = TiktokenCounter(llm_config.model) + token_counter = ApproxTokenCounter() token_count = await token_counter.count_text_tokens("") assert token_count == 0 @@ -194,8 +194,8 @@ async def test_count_empty_messages_tokens(llm_config: LLMConfig): from letta.llm_api.google_vertex_client import GoogleVertexClient from letta.services.context_window_calculator.token_counter import ( AnthropicTokenCounter, + ApproxTokenCounter, GeminiTokenCounter, - TiktokenCounter, ) if llm_config.model_endpoint_type == "anthropic": @@ -204,7 +204,7 @@ async def test_count_empty_messages_tokens(llm_config: LLMConfig): client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient() token_counter = GeminiTokenCounter(client, llm_config.model) else: - token_counter = TiktokenCounter(llm_config.model) + token_counter = ApproxTokenCounter() token_count = await token_counter.count_message_tokens([]) assert token_count == 0 @@ -219,8 +219,8 @@ async def test_count_empty_tools_tokens(llm_config: LLMConfig): from letta.llm_api.google_vertex_client import GoogleVertexClient from letta.services.context_window_calculator.token_counter import ( AnthropicTokenCounter, + ApproxTokenCounter, GeminiTokenCounter, - TiktokenCounter, ) if llm_config.model_endpoint_type == "anthropic": @@ -229,7 +229,7 @@ async def test_count_empty_tools_tokens(llm_config: LLMConfig): client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient() token_counter = GeminiTokenCounter(client, llm_config.model) else: - token_counter = TiktokenCounter(llm_config.model) + token_counter = ApproxTokenCounter() token_count = await token_counter.count_tool_tokens([]) assert token_count == 0