feat: fix new summarizer code and add more tests (#6461)
This commit is contained in:
committed by
Caren Thomas
parent
86023db9b1
commit
91e3dd8b3e
@@ -47,7 +47,9 @@ class LettaLLMRequestAdapter(LettaLLMAdapter):
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
|
||||
# Convert response to chat completion format
|
||||
self.chat_completions_response = self.llm_client.convert_response_to_chat_completion(self.response_data, messages, self.llm_config)
|
||||
self.chat_completions_response = await self.llm_client.convert_response_to_chat_completion(
|
||||
self.response_data, messages, self.llm_config
|
||||
)
|
||||
|
||||
# Extract reasoning content from the response
|
||||
if self.chat_completions_response.choices[0].message.reasoning_content:
|
||||
|
||||
@@ -47,7 +47,9 @@ class SimpleLLMRequestAdapter(LettaLLMRequestAdapter):
|
||||
self.llm_request_finish_timestamp_ns = get_utc_timestamp_ns()
|
||||
|
||||
# Convert response to chat completion format
|
||||
self.chat_completions_response = self.llm_client.convert_response_to_chat_completion(self.response_data, messages, self.llm_config)
|
||||
self.chat_completions_response = await self.llm_client.convert_response_to_chat_completion(
|
||||
self.response_data, messages, self.llm_config
|
||||
)
|
||||
|
||||
# Extract reasoning content from the response
|
||||
if self.chat_completions_response.choices[0].message.reasoning_content:
|
||||
|
||||
@@ -87,7 +87,7 @@ class EphemeralSummaryAgent(BaseAgent):
|
||||
|
||||
request_data = llm_client.build_request_data(agent_state.agent_type, messages, agent_state.llm_config, tools=[])
|
||||
response_data = await llm_client.request_async(request_data, agent_state.llm_config)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)
|
||||
response = await llm_client.convert_response_to_chat_completion(response_data, messages, agent_state.llm_config)
|
||||
summary = response.choices[0].message.content.strip()
|
||||
|
||||
await self.block_manager.update_block_async(block_id=block.id, block_update=BlockUpdate(value=summary), actor=self.actor)
|
||||
|
||||
@@ -337,7 +337,7 @@ class LettaAgent(BaseAgent):
|
||||
log_event("agent.stream_no_tokens.llm_response.received") # [3^]
|
||||
|
||||
try:
|
||||
response = llm_client.convert_response_to_chat_completion(
|
||||
response = await llm_client.convert_response_to_chat_completion(
|
||||
response_data, in_context_messages, agent_state.llm_config
|
||||
)
|
||||
except ValueError as e:
|
||||
@@ -681,7 +681,7 @@ class LettaAgent(BaseAgent):
|
||||
log_event("agent.step.llm_response.received") # [3^]
|
||||
|
||||
try:
|
||||
response = llm_client.convert_response_to_chat_completion(
|
||||
response = await llm_client.convert_response_to_chat_completion(
|
||||
response_data, in_context_messages, agent_state.llm_config
|
||||
)
|
||||
except ValueError as e:
|
||||
|
||||
@@ -309,7 +309,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
agent_state_map = {agent.id: agent for agent in agent_states}
|
||||
|
||||
# Process each agent's results
|
||||
tool_call_results = self._process_agent_results(
|
||||
tool_call_results = await self._process_agent_results(
|
||||
agent_ids=agent_ids, batch_item_map=batch_item_map, provider_results=provider_results, llm_batch_id=llm_batch_id
|
||||
)
|
||||
|
||||
@@ -324,7 +324,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
request_status_updates=tool_call_results.status_updates,
|
||||
)
|
||||
|
||||
def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id):
|
||||
async def _process_agent_results(self, agent_ids, batch_item_map, provider_results, llm_batch_id):
|
||||
"""
|
||||
Process the results for each agent, extracting tool calls and determining continuation status.
|
||||
|
||||
@@ -347,7 +347,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
request_status_updates.append(RequestStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, request_status=status))
|
||||
|
||||
# Process tool calls
|
||||
name, args, cont = self._extract_tool_call_from_result(item, result)
|
||||
name, args, cont = await self._extract_tool_call_from_result(item, result)
|
||||
name_map[aid], args_map[aid], cont_map[aid] = name, args, cont
|
||||
|
||||
return ToolCallResults(name_map, args_map, cont_map, request_status_updates)
|
||||
@@ -363,7 +363,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
else:
|
||||
return JobStatus.expired
|
||||
|
||||
def _extract_tool_call_from_result(self, item, result):
|
||||
async def _extract_tool_call_from_result(self, item, result):
|
||||
"""Extract tool call information from a result"""
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=item.llm_config.model_endpoint_type,
|
||||
@@ -375,13 +375,10 @@ class LettaAgentBatch(BaseAgent):
|
||||
if not isinstance(result, BetaMessageBatchSucceededResult):
|
||||
return None, None, False
|
||||
|
||||
tool_call = (
|
||||
llm_client.convert_response_to_chat_completion(
|
||||
response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config
|
||||
)
|
||||
.choices[0]
|
||||
.message.tool_calls[0]
|
||||
response = await llm_client.convert_response_to_chat_completion(
|
||||
response_data=result.message.model_dump(), input_messages=[], llm_config=item.llm_config
|
||||
)
|
||||
tool_call = response.choices[0].message.tool_calls[0]
|
||||
|
||||
return self._extract_tool_call_and_decide_continue(tool_call, item.step_state)
|
||||
|
||||
|
||||
@@ -1280,6 +1280,9 @@ class LettaAgentV3(LettaAgentV2):
|
||||
force: bool = False,
|
||||
) -> list[Message]:
|
||||
trigger_summarization = force or (total_tokens and total_tokens > self.agent_state.llm_config.context_window)
|
||||
self.logger.info(
|
||||
f"trigger_summarization: {trigger_summarization}, total_tokens: {total_tokens}, context_window: {self.agent_state.llm_config.context_window}"
|
||||
)
|
||||
if not trigger_summarization:
|
||||
# just update the message_ids
|
||||
# TODO: gross to handle this here: we should move persistence elsewhere
|
||||
@@ -1301,6 +1304,7 @@ class LettaAgentV3(LettaAgentV2):
|
||||
if summarizer_config.mode == "all":
|
||||
summary_message_str = await summarize_all(
|
||||
actor=self.actor,
|
||||
llm_config=self.agent_state.llm_config,
|
||||
summarizer_config=summarizer_config,
|
||||
in_context_messages=in_context_messages,
|
||||
new_messages=new_letta_messages,
|
||||
@@ -1353,6 +1357,6 @@ class LettaAgentV3(LettaAgentV2):
|
||||
message_ids=new_in_context_message_ids,
|
||||
actor=self.actor,
|
||||
)
|
||||
self.agent_state.message_ids = new_in_context_messages
|
||||
self.agent_state.message_ids = new_in_context_message_ids
|
||||
|
||||
return new_in_context_messages
|
||||
|
||||
@@ -54,12 +54,12 @@ from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.server.rest_api.utils import decrement_message_uuid
|
||||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||||
from letta.streaming_utils import (
|
||||
FunctionArgumentsStreamHandler,
|
||||
JSONInnerThoughtsExtractor,
|
||||
sanitize_streamed_message_content,
|
||||
)
|
||||
from letta.utils import count_tokens
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -83,6 +83,10 @@ class OpenAIStreamingInterface:
|
||||
step_id: str | None = None,
|
||||
):
|
||||
self.use_assistant_message = use_assistant_message
|
||||
|
||||
# Create token counter for fallback token counting (when API doesn't return usage)
|
||||
# Use openai endpoint type for approximate counting in streaming context
|
||||
self._fallback_token_counter = create_token_counter(model_endpoint_type="openai")
|
||||
self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL
|
||||
self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG
|
||||
self.put_inner_thoughts_in_kwarg = put_inner_thoughts_in_kwarg
|
||||
@@ -301,7 +305,8 @@ class OpenAIStreamingInterface:
|
||||
updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments)
|
||||
|
||||
if self.is_openai_proxy:
|
||||
self.fallback_output_tokens += count_tokens(tool_call.function.arguments)
|
||||
# Use approximate counting for fallback (sync method)
|
||||
self.fallback_output_tokens += self._fallback_token_counter._approx_token_count(tool_call.function.arguments)
|
||||
|
||||
# If we have inner thoughts, we should output them as a chunk
|
||||
if updates_inner_thoughts:
|
||||
|
||||
@@ -764,7 +764,7 @@ class AnthropicClient(LLMClientBase):
|
||||
# TODO: Input messages doesn't get used here
|
||||
# TODO: Clean up this interface
|
||||
@trace_method
|
||||
def convert_response_to_chat_completion(
|
||||
async def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage],
|
||||
|
||||
@@ -401,7 +401,7 @@ class DeepseekClient(OpenAIClient):
|
||||
return response_stream
|
||||
|
||||
@trace_method
|
||||
def convert_response_to_chat_completion(
|
||||
async def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage], # Included for consistency, maybe used later
|
||||
@@ -413,5 +413,5 @@ class DeepseekClient(OpenAIClient):
|
||||
"""
|
||||
response = ChatCompletionResponse(**response_data)
|
||||
if response.choices[0].message.tool_calls:
|
||||
return super().convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
return await super().convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
return convert_deepseek_response_to_chatcompletion(response)
|
||||
|
||||
@@ -31,7 +31,6 @@ from letta.helpers.datetime_helpers import get_utc_time_int
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.llm_api.llm_client_base import LLMClientBase
|
||||
from letta.local_llm.json_parser import clean_json_string_extra_backslash
|
||||
from letta.local_llm.utils import count_tokens
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.agent import AgentType
|
||||
@@ -404,7 +403,7 @@ class GoogleVertexClient(LLMClientBase):
|
||||
return request_data
|
||||
|
||||
@trace_method
|
||||
def convert_response_to_chat_completion(
|
||||
async def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage],
|
||||
@@ -661,10 +660,13 @@ class GoogleVertexClient(LLMClientBase):
|
||||
completion_tokens_details=completion_tokens_details,
|
||||
)
|
||||
else:
|
||||
# Count it ourselves
|
||||
# Count it ourselves using the Gemini token counting API
|
||||
assert input_messages is not None, "Didn't get UsageMetadata from the API response, so input_messages is required"
|
||||
prompt_tokens = count_tokens(json_dumps(input_messages)) # NOTE: this is a very rough approximation
|
||||
completion_tokens = count_tokens(json_dumps(openai_response_message.model_dump())) # NOTE: this is also approximate
|
||||
google_messages = PydanticMessage.to_google_dicts_from_list(input_messages, current_model=llm_config.model)
|
||||
prompt_tokens = await self.count_tokens(messages=google_messages, model=llm_config.model)
|
||||
# For completion tokens, wrap the response content in Google format
|
||||
completion_content = [{"role": "model", "parts": [{"text": json_dumps(openai_response_message.model_dump())}]}]
|
||||
completion_tokens = await self.count_tokens(messages=completion_content, model=llm_config.model)
|
||||
total_tokens = prompt_tokens + completion_tokens
|
||||
usage = UsageStatistics(
|
||||
prompt_tokens=prompt_tokens,
|
||||
|
||||
@@ -19,7 +19,7 @@ from letta.schemas.response_format import (
|
||||
TextResponseFormat,
|
||||
)
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.utils import count_tokens, printd
|
||||
from letta.utils import printd
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -394,97 +394,3 @@ def unpack_inner_thoughts_from_kwargs(choice: Choice, inner_thoughts_key: str) -
|
||||
logger.warning(f"Did not find tool call in message: {str(message)}")
|
||||
|
||||
return rewritten_choice
|
||||
|
||||
|
||||
def calculate_summarizer_cutoff(in_context_messages: List[Message], token_counts: List[int], logger: "logging.Logger") -> int:
|
||||
if len(in_context_messages) != len(token_counts):
|
||||
raise ValueError(
|
||||
f"Given in_context_messages has different length from given token_counts: {len(in_context_messages)} != {len(token_counts)}"
|
||||
)
|
||||
|
||||
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
|
||||
|
||||
if summarizer_settings.evict_all_messages:
|
||||
logger.info("Evicting all messages...")
|
||||
return len(in_context_messages)
|
||||
else:
|
||||
# Start at index 1 (past the system message),
|
||||
# and collect messages for summarization until we reach the desired truncation token fraction (eg 50%)
|
||||
# We do the inverse of `desired_memory_token_pressure` to get what we need to remove
|
||||
desired_token_count_to_summarize = int(sum(token_counts) * (1 - summarizer_settings.desired_memory_token_pressure))
|
||||
logger.info(f"desired_token_count_to_summarize={desired_token_count_to_summarize}")
|
||||
|
||||
tokens_so_far = 0
|
||||
cutoff = 0
|
||||
for i, msg in enumerate(in_context_messages_openai):
|
||||
# Skip system
|
||||
if i == 0:
|
||||
continue
|
||||
cutoff = i
|
||||
tokens_so_far += token_counts[i]
|
||||
|
||||
if msg["role"] not in ["user", "tool", "function"] and tokens_so_far >= desired_token_count_to_summarize:
|
||||
# Break if the role is NOT a user or tool/function and tokens_so_far is enough
|
||||
break
|
||||
elif len(in_context_messages) - cutoff - 1 <= summarizer_settings.keep_last_n_messages:
|
||||
# Also break if we reached the `keep_last_n_messages` threshold
|
||||
# NOTE: This may be on a user, tool, or function in theory
|
||||
logger.warning(
|
||||
f"Breaking summary cutoff early on role={msg['role']} because we hit the `keep_last_n_messages`={summarizer_settings.keep_last_n_messages}"
|
||||
)
|
||||
break
|
||||
|
||||
# includes the tool response to be summarized after a tool call so we don't have any hanging tool calls after trimming.
|
||||
if i + 1 < len(in_context_messages_openai) and in_context_messages_openai[i + 1]["role"] == "tool":
|
||||
cutoff += 1
|
||||
|
||||
logger.info(f"Evicting {cutoff}/{len(in_context_messages)} messages...")
|
||||
return cutoff + 1
|
||||
|
||||
|
||||
def get_token_counts_for_messages(in_context_messages: List[Message]) -> List[int]:
|
||||
in_context_messages_openai = Message.to_openai_dicts_from_list(in_context_messages)
|
||||
token_counts = [count_tokens(str(msg)) for msg in in_context_messages_openai]
|
||||
return token_counts
|
||||
|
||||
|
||||
def is_context_overflow_error(exception: Union[requests.exceptions.RequestException, Exception]) -> bool:
|
||||
"""Checks if an exception is due to context overflow (based on common OpenAI response messages)"""
|
||||
from letta.utils import printd
|
||||
|
||||
match_string = OPENAI_CONTEXT_WINDOW_ERROR_SUBSTRING
|
||||
|
||||
# Backwards compatibility with openai python package/client v0.28 (pre-v1 client migration)
|
||||
if match_string in str(exception):
|
||||
printd(f"Found '{match_string}' in str(exception)={(str(exception))}")
|
||||
return True
|
||||
|
||||
# Based on python requests + OpenAI REST API (/v1)
|
||||
elif isinstance(exception, requests.exceptions.HTTPError):
|
||||
if exception.response is not None and "application/json" in exception.response.headers.get("Content-Type", ""):
|
||||
try:
|
||||
error_details = exception.response.json()
|
||||
if "error" not in error_details:
|
||||
printd(f"HTTPError occurred, but couldn't find error field: {error_details}")
|
||||
return False
|
||||
else:
|
||||
error_details = error_details["error"]
|
||||
|
||||
# Check for the specific error code
|
||||
if error_details.get("code") == "context_length_exceeded":
|
||||
printd(f"HTTPError occurred, caught error code {error_details.get('code')}")
|
||||
return True
|
||||
# Soft-check for "maximum context length" inside of the message
|
||||
elif error_details.get("message") and "maximum context length" in error_details.get("message"):
|
||||
printd(f"HTTPError occurred, found '{match_string}' in error message contents ({error_details})")
|
||||
return True
|
||||
else:
|
||||
printd(f"HTTPError occurred, but unknown error message: {error_details}")
|
||||
return False
|
||||
except ValueError:
|
||||
# JSON decoding failed
|
||||
printd(f"HTTPError occurred ({exception}), but no JSON error message.")
|
||||
|
||||
# Generic fail
|
||||
else:
|
||||
return False
|
||||
|
||||
@@ -38,7 +38,7 @@ class LLMClientBase:
|
||||
self.use_tool_naming = use_tool_naming
|
||||
|
||||
@trace_method
|
||||
def send_llm_request(
|
||||
async def send_llm_request(
|
||||
self,
|
||||
agent_type: AgentType,
|
||||
messages: List[Message],
|
||||
@@ -80,7 +80,7 @@ class LLMClientBase:
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
|
||||
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
|
||||
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
|
||||
|
||||
@trace_method
|
||||
async def send_llm_request_async(
|
||||
@@ -114,7 +114,7 @@ class LLMClientBase:
|
||||
except Exception as e:
|
||||
raise self.handle_llm_error(e)
|
||||
|
||||
return self.convert_response_to_chat_completion(response_data, messages, llm_config)
|
||||
return await self.convert_response_to_chat_completion(response_data, messages, llm_config)
|
||||
|
||||
async def send_llm_batch_request_async(
|
||||
self,
|
||||
@@ -177,7 +177,7 @@ class LLMClientBase:
|
||||
raise NotImplementedError
|
||||
|
||||
@abstractmethod
|
||||
def convert_response_to_chat_completion(
|
||||
async def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[Message],
|
||||
|
||||
@@ -609,7 +609,7 @@ class OpenAIClient(LLMClientBase):
|
||||
return is_openai_reasoning_model(llm_config.model)
|
||||
|
||||
@trace_method
|
||||
def convert_response_to_chat_completion(
|
||||
async def convert_response_to_chat_completion(
|
||||
self,
|
||||
response_data: dict,
|
||||
input_messages: List[PydanticMessage], # Included for consistency, maybe used later
|
||||
|
||||
@@ -58,11 +58,11 @@ def load_grammar_file(grammar):
|
||||
return grammar_str
|
||||
|
||||
|
||||
# TODO: support tokenizers/tokenizer apis available in local models
|
||||
def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
from letta.utils import count_tokens
|
||||
|
||||
return count_tokens(s, model)
|
||||
## TODO: support tokenizers/tokenizer apis available in local models
|
||||
# def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
# from letta.utils import count_tokens
|
||||
#
|
||||
# return count_tokens(s, model)
|
||||
|
||||
|
||||
def num_tokens_from_functions(functions: List[dict], model: str = "gpt-4"):
|
||||
|
||||
113
letta/memory.py
113
letta/memory.py
@@ -50,59 +50,60 @@ def _format_summary_history(message_history: List[Message]):
|
||||
return "\n".join([f"{m.role}: {get_message_text(m.content)}" for m in message_history])
|
||||
|
||||
|
||||
@trace_method
|
||||
def summarize_messages(
|
||||
agent_state: AgentState,
|
||||
message_sequence_to_summarize: List[Message],
|
||||
actor: "User",
|
||||
):
|
||||
"""Summarize a message sequence using GPT"""
|
||||
# we need the context_window
|
||||
context_window = agent_state.llm_config.context_window
|
||||
|
||||
summary_prompt = SUMMARY_PROMPT_SYSTEM
|
||||
summary_input = _format_summary_history(message_sequence_to_summarize)
|
||||
summary_input_tkns = count_tokens(summary_input)
|
||||
if summary_input_tkns > summarizer_settings.memory_warning_threshold * context_window:
|
||||
trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure...
|
||||
cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
|
||||
summary_input = str(
|
||||
[summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)]
|
||||
+ message_sequence_to_summarize[cutoff:]
|
||||
)
|
||||
|
||||
dummy_agent_id = agent_state.id
|
||||
message_sequence = [
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]),
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]),
|
||||
Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]),
|
||||
]
|
||||
|
||||
# TODO: We need to eventually have a separate LLM config for the summarizer LLM
|
||||
llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
|
||||
llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
put_inner_thoughts_first=False,
|
||||
actor=actor,
|
||||
)
|
||||
# try to use new client, otherwise fallback to old flow
|
||||
# TODO: we can just directly call the LLM here?
|
||||
if llm_client:
|
||||
response = llm_client.send_llm_request(
|
||||
agent_type=agent_state.agent_type,
|
||||
messages=message_sequence,
|
||||
llm_config=llm_config_no_inner_thoughts,
|
||||
)
|
||||
else:
|
||||
response = create(
|
||||
llm_config=llm_config_no_inner_thoughts,
|
||||
user_id=agent_state.created_by_id,
|
||||
messages=message_sequence,
|
||||
stream=False,
|
||||
)
|
||||
|
||||
printd(f"summarize_messages gpt reply: {response.choices[0]}")
|
||||
reply = response.choices[0].message.content
|
||||
return reply
|
||||
# @trace_method
|
||||
# def summarize_messages(
|
||||
# agent_state: AgentState,
|
||||
# message_sequence_to_summarize: List[Message],
|
||||
# actor: "User",
|
||||
# ):
|
||||
# """Summarize a message sequence using GPT"""
|
||||
# # we need the context_window
|
||||
# context_window = agent_state.llm_config.context_window
|
||||
#
|
||||
# summary_prompt = SUMMARY_PROMPT_SYSTEM
|
||||
# summary_input = _format_summary_history(message_sequence_to_summarize)
|
||||
# summary_input_tkns = count_tokens(summary_input)
|
||||
# if summary_input_tkns > summarizer_settings.memory_warning_threshold * context_window:
|
||||
# trunc_ratio = (summarizer_settings.memory_warning_threshold * context_window / summary_input_tkns) * 0.8 # For good measure...
|
||||
# cutoff = int(len(message_sequence_to_summarize) * trunc_ratio)
|
||||
# summary_input = str(
|
||||
# [summarize_messages(agent_state, message_sequence_to_summarize=message_sequence_to_summarize[:cutoff], actor=actor)]
|
||||
# + message_sequence_to_summarize[cutoff:]
|
||||
# )
|
||||
#
|
||||
# dummy_agent_id = agent_state.id
|
||||
# message_sequence = [
|
||||
# Message(agent_id=dummy_agent_id, role=MessageRole.system, content=[TextContent(text=summary_prompt)]),
|
||||
# Message(agent_id=dummy_agent_id, role=MessageRole.assistant, content=[TextContent(text=MESSAGE_SUMMARY_REQUEST_ACK)]),
|
||||
# Message(agent_id=dummy_agent_id, role=MessageRole.user, content=[TextContent(text=summary_input)]),
|
||||
# ]
|
||||
#
|
||||
# # TODO: We need to eventually have a separate LLM config for the summarizer LLM
|
||||
# llm_config_no_inner_thoughts = agent_state.llm_config.model_copy(deep=True)
|
||||
# llm_config_no_inner_thoughts.put_inner_thoughts_in_kwargs = False
|
||||
#
|
||||
# llm_client = LLMClient.create(
|
||||
# provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
# put_inner_thoughts_first=False,
|
||||
# actor=actor,
|
||||
# )
|
||||
# # try to use new client, otherwise fallback to old flow
|
||||
# # TODO: we can just directly call the LLM here?
|
||||
# if llm_client:
|
||||
# response = llm_client.send_llm_request(
|
||||
# agent_type=agent_state.agent_type,
|
||||
# messages=message_sequence,
|
||||
# llm_config=llm_config_no_inner_thoughts,
|
||||
# )
|
||||
# else:
|
||||
# response = create(
|
||||
# llm_config=llm_config_no_inner_thoughts,
|
||||
# user_id=agent_state.created_by_id,
|
||||
# messages=message_sequence,
|
||||
# stream=False,
|
||||
# )
|
||||
#
|
||||
# printd(f"summarize_messages gpt reply: {response.choices[0]}")
|
||||
# reply = response.choices[0].message.content
|
||||
# return reply
|
||||
#
|
||||
|
||||
@@ -914,7 +914,7 @@ async def generate_tool_from_prompt(
|
||||
tools=[tool],
|
||||
)
|
||||
response_data = await llm_client.request_async(request_data, llm_config)
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
response = await llm_client.convert_response_to_chat_completion(response_data, input_messages, llm_config)
|
||||
output = json.loads(response.choices[0].message.tool_calls[0].function.arguments)
|
||||
pip_requirements = [PipRequirement(name=k, version=v or None) for k, v in json.loads(output["pip_requirements_json"]).items()]
|
||||
|
||||
|
||||
@@ -670,20 +670,26 @@ class SyncServer(object):
|
||||
async def insert_archival_memory_async(
|
||||
self, agent_id: str, memory_contents: str, actor: User, tags: Optional[List[str]], created_at: Optional[datetime]
|
||||
) -> List[Passage]:
|
||||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||||
from letta.settings import settings
|
||||
from letta.utils import count_tokens
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Check token count against limit
|
||||
token_count = count_tokens(memory_contents)
|
||||
token_counter = create_token_counter(
|
||||
model_endpoint_type=agent_state.llm_config.model_endpoint_type,
|
||||
model=agent_state.llm_config.model,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
token_count = await token_counter.count_text_tokens(memory_contents)
|
||||
if token_count > settings.archival_memory_token_limit:
|
||||
raise LettaInvalidArgumentError(
|
||||
message=f"Archival memory content exceeds token limit of {settings.archival_memory_token_limit} tokens (found {token_count} tokens)",
|
||||
argument_name="memory_contents",
|
||||
)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Use passage manager which handles dual-write to Turbopuffer if enabled
|
||||
passages = await self.passage_manager.insert_passage(
|
||||
agent_state=agent_state, text=memory_contents, tags=tags, actor=actor, created_at=created_at
|
||||
|
||||
@@ -80,7 +80,7 @@ from letta.server.db import db_registry
|
||||
from letta.services.archive_manager import ArchiveManager
|
||||
from letta.services.block_manager import BlockManager, validate_block_limit_constraint
|
||||
from letta.services.context_window_calculator.context_window_calculator import ContextWindowCalculator
|
||||
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, GeminiTokenCounter, TiktokenCounter
|
||||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||||
from letta.services.file_processor.chunker.line_chunker import LineChunker
|
||||
from letta.services.files_agents_manager import FileAgentManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
@@ -3286,49 +3286,14 @@ class AgentManager:
|
||||
)
|
||||
calculator = ContextWindowCalculator()
|
||||
|
||||
# Determine which token counter to use based on provider
|
||||
model_endpoint_type = agent_state.llm_config.model_endpoint_type
|
||||
|
||||
# Use Gemini token counter for Google Vertex and Google AI
|
||||
use_gemini = model_endpoint_type in ("google_vertex", "google_ai")
|
||||
|
||||
# Use Anthropic token counter if:
|
||||
# 1. The model endpoint type is anthropic, OR
|
||||
# 2. We're in PRODUCTION and anthropic_api_key is available (and not using Gemini)
|
||||
use_anthropic = model_endpoint_type == "anthropic" or (
|
||||
not use_gemini and settings.environment == "PRODUCTION" and model_settings.anthropic_api_key is not None
|
||||
# Create the appropriate token counter based on model configuration
|
||||
token_counter = create_token_counter(
|
||||
model_endpoint_type=agent_state.llm_config.model_endpoint_type,
|
||||
model=agent_state.llm_config.model,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
if use_gemini:
|
||||
# Use native Gemini token counting API
|
||||
|
||||
client = LLMClient.create(provider_type=agent_state.llm_config.model_endpoint_type, actor=actor)
|
||||
model = agent_state.llm_config.model
|
||||
|
||||
token_counter = GeminiTokenCounter(client, model)
|
||||
logger.info(
|
||||
f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, "
|
||||
f"model_endpoint_type={model_endpoint_type}, "
|
||||
f"environment={settings.environment}"
|
||||
)
|
||||
elif use_anthropic:
|
||||
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
|
||||
model = agent_state.llm_config.model if model_endpoint_type == "anthropic" else None
|
||||
|
||||
token_counter = AnthropicTokenCounter(anthropic_client, model) # noqa
|
||||
logger.info(
|
||||
f"Using AnthropicTokenCounter for agent_id={agent_id}, model={model}, "
|
||||
f"model_endpoint_type={model_endpoint_type}, "
|
||||
f"environment={settings.environment}"
|
||||
)
|
||||
else:
|
||||
token_counter = TiktokenCounter(agent_state.llm_config.model)
|
||||
logger.info(
|
||||
f"Using TiktokenCounter for agent_id={agent_id}, model={agent_state.llm_config.model}, "
|
||||
f"model_endpoint_type={model_endpoint_type}, "
|
||||
f"environment={settings.environment}"
|
||||
)
|
||||
|
||||
try:
|
||||
result = await calculator.calculate_context_window(
|
||||
agent_state=agent_state,
|
||||
|
||||
@@ -1,15 +1,22 @@
|
||||
import hashlib
|
||||
import json
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, List
|
||||
from typing import TYPE_CHECKING, Any, Dict, List, Optional
|
||||
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.llm_api.anthropic_client import AnthropicClient
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.enums import ProviderType
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_request import Tool as OpenAITool
|
||||
from letta.utils import count_tokens
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.user import User
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class TokenCounter(ABC):
|
||||
@@ -101,6 +108,22 @@ class ApproxTokenCounter(TokenCounter):
|
||||
byte_len = len(text.encode("utf-8"))
|
||||
return (byte_len + self.APPROX_BYTES_PER_TOKEN - 1) // self.APPROX_BYTES_PER_TOKEN
|
||||
|
||||
async def count_text_tokens(self, text: str) -> int:
|
||||
if not text:
|
||||
return 0
|
||||
return self._approx_token_count(text)
|
||||
|
||||
async def count_message_tokens(self, messages: List[Dict[str, Any]]) -> int:
|
||||
if not messages:
|
||||
return 0
|
||||
return self._approx_token_count(json.dumps(messages))
|
||||
|
||||
async def count_tool_tokens(self, tools: List[OpenAITool]) -> int:
|
||||
if not tools:
|
||||
return 0
|
||||
functions = [t.model_dump() for t in tools]
|
||||
return self._approx_token_count(json.dumps(functions))
|
||||
|
||||
def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
return Message.to_openai_dicts_from_list(messages)
|
||||
|
||||
@@ -178,7 +201,14 @@ class TiktokenCounter(TokenCounter):
|
||||
logger.debug(f"TiktokenCounter.count_text_tokens: model={self.model}, text_length={text_length}, preview={repr(text_preview)}")
|
||||
|
||||
try:
|
||||
result = count_tokens(text)
|
||||
import tiktoken
|
||||
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(self.model)
|
||||
except KeyError:
|
||||
logger.debug(f"Model {self.model} not found in tiktoken. Using cl100k_base encoding.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
result = len(encoding.encode(text))
|
||||
logger.debug(f"TiktokenCounter.count_text_tokens: completed successfully, tokens={result}")
|
||||
return result
|
||||
except Exception as e:
|
||||
@@ -234,3 +264,54 @@ class TiktokenCounter(TokenCounter):
|
||||
|
||||
def convert_messages(self, messages: List[Any]) -> List[Dict[str, Any]]:
|
||||
return Message.to_openai_dicts_from_list(messages)
|
||||
|
||||
|
||||
def create_token_counter(
|
||||
model_endpoint_type: ProviderType,
|
||||
model: Optional[str] = None,
|
||||
actor: "User" = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> "TokenCounter":
|
||||
"""
|
||||
Factory function to create the appropriate token counter based on model configuration.
|
||||
|
||||
Returns:
|
||||
The appropriate TokenCounter instance
|
||||
"""
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.settings import model_settings, settings
|
||||
|
||||
# Use Gemini token counter for Google Vertex and Google AI
|
||||
use_gemini = model_endpoint_type in ("google_vertex", "google_ai")
|
||||
|
||||
# Use Anthropic token counter if:
|
||||
# 1. The model endpoint type is anthropic, OR
|
||||
# 2. We're in PRODUCTION and anthropic_api_key is available (and not using Gemini)
|
||||
use_anthropic = model_endpoint_type == "anthropic"
|
||||
|
||||
if use_gemini:
|
||||
client = LLMClient.create(provider_type=model_endpoint_type, actor=actor)
|
||||
token_counter = GeminiTokenCounter(client, model)
|
||||
logger.info(
|
||||
f"Using GeminiTokenCounter for agent_id={agent_id}, model={model}, "
|
||||
f"model_endpoint_type={model_endpoint_type}, "
|
||||
f"environment={settings.environment}"
|
||||
)
|
||||
elif use_anthropic:
|
||||
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
|
||||
counter_model = model if model_endpoint_type == "anthropic" else None
|
||||
token_counter = AnthropicTokenCounter(anthropic_client, counter_model)
|
||||
logger.info(
|
||||
f"Using AnthropicTokenCounter for agent_id={agent_id}, model={counter_model}, "
|
||||
f"model_endpoint_type={model_endpoint_type}, "
|
||||
f"environment={settings.environment}"
|
||||
)
|
||||
else:
|
||||
token_counter = ApproxTokenCounter()
|
||||
logger.info(
|
||||
f"Using ApproxTokenCounter for agent_id={agent_id}, model={model}, "
|
||||
f"model_endpoint_type={model_endpoint_type}, "
|
||||
f"environment={settings.environment}"
|
||||
)
|
||||
|
||||
return token_counter
|
||||
|
||||
@@ -447,9 +447,12 @@ async def simple_summary(
|
||||
]
|
||||
input_messages_obj = [simple_message_wrapper(msg) for msg in input_messages]
|
||||
# Build a local LLMConfig for v1-style summarization which uses native content and must not
|
||||
# include inner thoughts in kwargs to avoid conflicts in Anthropic formatting
|
||||
# include inner thoughts in kwargs to avoid conflicts in Anthropic formatting.
|
||||
# We also disable enable_reasoner to avoid extended thinking requirements (Anthropic requires
|
||||
# assistant messages to start with thinking blocks when extended thinking is enabled).
|
||||
summarizer_llm_config = LLMConfig(**llm_config.model_dump())
|
||||
summarizer_llm_config.put_inner_thoughts_in_kwargs = False
|
||||
summarizer_llm_config.enable_reasoner = False
|
||||
|
||||
request_data = llm_client.build_request_data(AgentType.letta_v1_agent, input_messages_obj, summarizer_llm_config, tools=[])
|
||||
try:
|
||||
@@ -532,7 +535,7 @@ async def simple_summary(
|
||||
logger.info(f"Full fallback summarization payload: {request_data}")
|
||||
raise llm_client.handle_llm_error(fallback_error_b)
|
||||
|
||||
response = llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config)
|
||||
response = await llm_client.convert_response_to_chat_completion(response_data, input_messages_obj, summarizer_llm_config)
|
||||
if response.choices[0].message.content is None:
|
||||
logger.warning("No content returned from summarizer")
|
||||
# TODO raise an error error instead?
|
||||
|
||||
@@ -1,16 +1,11 @@
|
||||
from typing import List, Tuple
|
||||
from typing import List
|
||||
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.user import User
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.summarizer.summarizer import simple_summary
|
||||
from letta.services.summarizer.summarizer_config import SummarizerConfig
|
||||
from letta.system import package_summarize_message_no_counts
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -18,6 +13,8 @@ logger = get_logger(__name__)
|
||||
async def summarize_all(
|
||||
# Required to tag LLM calls
|
||||
actor: User,
|
||||
# LLM config for the summarizer model
|
||||
llm_config: LLMConfig,
|
||||
# Actual summarization configuration
|
||||
summarizer_config: SummarizerConfig,
|
||||
in_context_messages: List[Message],
|
||||
@@ -33,9 +30,9 @@ async def summarize_all(
|
||||
|
||||
summary_message_str = await simple_summary(
|
||||
messages=all_in_context_messages,
|
||||
llm_config=summarizer_config.summarizer_model,
|
||||
llm_config=llm_config,
|
||||
actor=actor,
|
||||
include_ack=summarizer_config.prompt_acknowledgement,
|
||||
include_ack=bool(summarizer_config.prompt_acknowledgement),
|
||||
prompt=summarizer_config.prompt,
|
||||
)
|
||||
|
||||
|
||||
@@ -1,19 +1,17 @@
|
||||
from typing import List, Tuple
|
||||
|
||||
from letta.helpers.message_helper import convert_message_creates_to_messages
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.enums import MessageRole, ProviderType
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.services.context_window_calculator.token_counter import AnthropicTokenCounter, ApproxTokenCounter
|
||||
from letta.services.context_window_calculator.token_counter import create_token_counter
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.summarizer.summarizer import simple_summary
|
||||
from letta.services.summarizer.summarizer_config import SummarizerConfig
|
||||
from letta.settings import model_settings, settings
|
||||
from letta.system import package_summarize_message_no_counts
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -26,21 +24,21 @@ APPROX_TOKEN_SAFETY_MARGIN = 1.3
|
||||
|
||||
|
||||
async def count_tokens(actor: User, llm_config: LLMConfig, messages: List[Message]) -> int:
|
||||
# If the model is an Anthropic model, use the Anthropic token counter (accurate)
|
||||
if llm_config.model_endpoint_type == "anthropic":
|
||||
anthropic_client = LLMClient.create(provider_type=ProviderType.anthropic, actor=actor)
|
||||
token_counter = AnthropicTokenCounter(anthropic_client, llm_config.model)
|
||||
converted_messages = token_counter.convert_messages(messages)
|
||||
return await token_counter.count_message_tokens(converted_messages)
|
||||
"""Count tokens in messages using the appropriate token counter for the model configuration."""
|
||||
token_counter = create_token_counter(
|
||||
model_endpoint_type=llm_config.model_endpoint_type,
|
||||
model=llm_config.model,
|
||||
actor=actor,
|
||||
)
|
||||
converted_messages = token_counter.convert_messages(messages)
|
||||
tokens = await token_counter.count_message_tokens(converted_messages)
|
||||
|
||||
else:
|
||||
# Otherwise, use approximate count (bytes / 4) with safety margin
|
||||
# This is much faster than tiktoken and doesn't require loading tokenizer models
|
||||
token_counter = ApproxTokenCounter(llm_config.model)
|
||||
converted_messages = token_counter.convert_messages(messages)
|
||||
tokens = await token_counter.count_message_tokens(converted_messages)
|
||||
# Apply safety margin to avoid underestimating and keeping too many messages
|
||||
# Apply safety margin for approximate counting to avoid underestimating
|
||||
from letta.services.context_window_calculator.token_counter import ApproxTokenCounter
|
||||
|
||||
if isinstance(token_counter, ApproxTokenCounter):
|
||||
return int(tokens * APPROX_TOKEN_SAFETY_MARGIN)
|
||||
return tokens
|
||||
|
||||
|
||||
async def summarize_via_sliding_window(
|
||||
@@ -110,9 +108,9 @@ async def summarize_via_sliding_window(
|
||||
|
||||
summary_message_str = await simple_summary(
|
||||
messages=messages_to_summarize,
|
||||
llm_config=summarizer_config.summarizer_model,
|
||||
llm_config=llm_config,
|
||||
actor=actor,
|
||||
include_ack=summarizer_config.prompt_acknowledgement,
|
||||
include_ack=bool(summarizer_config.prompt_acknowledgement),
|
||||
prompt=summarizer_config.prompt,
|
||||
)
|
||||
|
||||
|
||||
@@ -812,13 +812,13 @@ class OpenAIBackcompatUnpickler(pickle.Unpickler):
|
||||
return super().find_class(module, name)
|
||||
|
||||
|
||||
def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
try:
|
||||
encoding = tiktoken.encoding_for_model(model)
|
||||
except KeyError:
|
||||
print("Falling back to cl100k base for token counting.")
|
||||
encoding = tiktoken.get_encoding("cl100k_base")
|
||||
return len(encoding.encode(s))
|
||||
# def count_tokens(s: str, model: str = "gpt-4") -> int:
|
||||
# try:
|
||||
# encoding = tiktoken.encoding_for_model(model)
|
||||
# except KeyError:
|
||||
# print("Falling back to cl100k base for token counting.")
|
||||
# encoding = tiktoken.get_encoding("cl100k_base")
|
||||
# return len(encoding.encode(s))
|
||||
|
||||
|
||||
def printd(*args, **kwargs):
|
||||
|
||||
@@ -14,6 +14,7 @@ from typing import List
|
||||
import pytest
|
||||
|
||||
from letta.agents.letta_agent_v2 import LettaAgentV2
|
||||
from letta.agents.letta_agent_v3 import LettaAgentV3
|
||||
from letta.config import LettaConfig
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -671,7 +672,12 @@ async def test_summarize_with_mode(server: SyncServer, actor, llm_config: LLMCon
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
|
||||
@pytest.mark.parametrize(
|
||||
"llm_config",
|
||||
TESTED_LLM_CONFIGS,
|
||||
ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
)
|
||||
async def test_sliding_window_cutoff_index_does_not_exceed_message_count(server: SyncServer, actor, llm_config: LLMConfig):
|
||||
"""
|
||||
Test that the sliding window summarizer correctly calculates cutoff indices.
|
||||
|
||||
@@ -685,35 +691,19 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
|
||||
- max(..., 10) -> max(..., 0.10)
|
||||
- += 10 -> += 0.10
|
||||
- >= 100 -> >= 1.0
|
||||
"""
|
||||
from unittest.mock import MagicMock, patch
|
||||
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.user import User
|
||||
This test uses the real token counter (via create_token_counter) to verify
|
||||
the sliding window logic works with actual token counting.
|
||||
"""
|
||||
from letta.schemas.model import ModelSettings
|
||||
from letta.services.summarizer.summarizer_config import get_default_summarizer_config
|
||||
from letta.services.summarizer.summarizer_sliding_window import summarize_via_sliding_window
|
||||
|
||||
# Create a mock user (using proper ID format pattern)
|
||||
mock_actor = User(
|
||||
id="user-00000000-0000-0000-0000-000000000000", name="Test User", organization_id="org-00000000-0000-0000-0000-000000000000"
|
||||
)
|
||||
|
||||
# Create a mock LLM config
|
||||
mock_llm_config = LLMConfig(
|
||||
model="gpt-4",
|
||||
model_endpoint_type="openai",
|
||||
context_window=128000,
|
||||
)
|
||||
|
||||
# Create a mock summarizer config with sliding_window_percentage = 0.3
|
||||
mock_summarizer_config = MagicMock()
|
||||
mock_summarizer_config.sliding_window_percentage = 0.3
|
||||
mock_summarizer_config.summarizer_model = mock_llm_config
|
||||
mock_summarizer_config.prompt = "Summarize the conversation."
|
||||
mock_summarizer_config.prompt_acknowledgement = True
|
||||
mock_summarizer_config.clip_chars = 2000
|
||||
# Create a real summarizer config using the default factory
|
||||
# Override sliding_window_percentage to 0.3 for this test
|
||||
model_settings = ModelSettings() # Use defaults
|
||||
summarizer_config = get_default_summarizer_config(model_settings)
|
||||
summarizer_config.sliding_window_percentage = 0.3
|
||||
|
||||
# Create 65 messages (similar to the failing case in the bug report)
|
||||
# Pattern: system + alternating user/assistant messages
|
||||
@@ -741,59 +731,470 @@ async def test_sliding_window_cutoff_index_does_not_exceed_message_count():
|
||||
|
||||
assert len(messages) == 65, f"Expected 65 messages, got {len(messages)}"
|
||||
|
||||
# Mock count_tokens to return a value that would trigger summarization
|
||||
# Return a high token count so that the while loop continues
|
||||
async def mock_count_tokens(actor, llm_config, messages):
|
||||
# Return tokens that decrease as we cut off more messages
|
||||
# This simulates the token count decreasing as we evict messages
|
||||
return len(messages) * 100 # 100 tokens per message
|
||||
# This should NOT raise "No assistant message found from indices 650 to 65"
|
||||
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
|
||||
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
|
||||
try:
|
||||
summary, remaining_messages = await summarize_via_sliding_window(
|
||||
actor=actor,
|
||||
llm_config=llm_config,
|
||||
summarizer_config=summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
|
||||
# Mock simple_summary to return a fake summary
|
||||
async def mock_simple_summary(messages, llm_config, actor, include_ack, prompt):
|
||||
return "This is a mock summary of the conversation."
|
||||
# Verify the summary was generated (actual LLM response)
|
||||
assert summary is not None
|
||||
assert len(summary) > 0
|
||||
|
||||
with (
|
||||
patch(
|
||||
"letta.services.summarizer.summarizer_sliding_window.count_tokens",
|
||||
side_effect=mock_count_tokens,
|
||||
),
|
||||
patch(
|
||||
"letta.services.summarizer.summarizer_sliding_window.simple_summary",
|
||||
side_effect=mock_simple_summary,
|
||||
),
|
||||
):
|
||||
# This should NOT raise "No assistant message found from indices 650 to 65"
|
||||
# With the fix, message_count_cutoff_percent starts at max(0.7, 0.10) = 0.7
|
||||
# So message_cutoff_index = round(0.7 * 65) = 46, which is valid
|
||||
try:
|
||||
summary, remaining_messages = await summarize_via_sliding_window(
|
||||
actor=mock_actor,
|
||||
llm_config=mock_llm_config,
|
||||
summarizer_config=mock_summarizer_config,
|
||||
in_context_messages=messages,
|
||||
new_messages=[],
|
||||
)
|
||||
# Verify remaining messages is a valid subset
|
||||
assert len(remaining_messages) < len(messages)
|
||||
assert len(remaining_messages) > 0
|
||||
|
||||
# Verify the summary was generated
|
||||
assert summary == "This is a mock summary of the conversation."
|
||||
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
|
||||
print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
|
||||
print(f"Using {llm_config.model_endpoint_type} token counter for model {llm_config.model}")
|
||||
|
||||
# Verify remaining messages is a valid subset
|
||||
assert len(remaining_messages) < len(messages)
|
||||
assert len(remaining_messages) > 0
|
||||
except ValueError as e:
|
||||
if "No assistant message found from indices" in str(e):
|
||||
# Extract the indices from the error message
|
||||
import re
|
||||
|
||||
print(f"Successfully summarized {len(messages)} messages to {len(remaining_messages)} remaining")
|
||||
match = re.search(r"from indices (\d+) to (\d+)", str(e))
|
||||
if match:
|
||||
start_idx, end_idx = int(match.group(1)), int(match.group(2))
|
||||
pytest.fail(
|
||||
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
|
||||
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise
|
||||
|
||||
except ValueError as e:
|
||||
if "No assistant message found from indices" in str(e):
|
||||
# Extract the indices from the error message
|
||||
import re
|
||||
|
||||
match = re.search(r"from indices (\d+) to (\d+)", str(e))
|
||||
if match:
|
||||
start_idx, end_idx = int(match.group(1)), int(match.group(2))
|
||||
pytest.fail(
|
||||
f"Bug detected: cutoff index ({start_idx}) exceeds message count ({end_idx}). "
|
||||
f"This indicates the percentage calculation bug where 10 was used instead of 0.10. "
|
||||
f"Error: {e}"
|
||||
)
|
||||
raise
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_context_window_overflow_triggers_summarization_in_streaming(server: SyncServer, actor):
|
||||
# """
|
||||
# Test that a ContextWindowExceededError during a streaming LLM request
|
||||
# properly triggers the summarizer and compacts the in-context messages.
|
||||
#
|
||||
# This test simulates:
|
||||
# 1. An LLM streaming request that fails with ContextWindowExceededError
|
||||
# 2. The summarizer being invoked to reduce context size
|
||||
# 3. Verification that messages are compacted and summary message exists
|
||||
#
|
||||
# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
|
||||
# """
|
||||
# import uuid
|
||||
# from unittest.mock import patch
|
||||
#
|
||||
# import openai
|
||||
#
|
||||
# from letta.schemas.message import MessageCreate
|
||||
# from letta.schemas.run import Run
|
||||
# from letta.services.run_manager import RunManager
|
||||
#
|
||||
# # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
|
||||
# llm_config = get_llm_config("openai-gpt-4o-mini.json")
|
||||
#
|
||||
# # Create test messages - enough to have something to summarize
|
||||
# messages = []
|
||||
# for i in range(15):
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.user,
|
||||
# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
|
||||
# )
|
||||
# )
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.assistant,
|
||||
# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
|
||||
# original_message_count = len(agent_state.message_ids)
|
||||
#
|
||||
# # Create an input message to trigger the agent
|
||||
# input_message = MessageCreate(
|
||||
# role=MessageRole.user,
|
||||
# content=[TextContent(type="text", text="Hello, please respond.")],
|
||||
# )
|
||||
#
|
||||
# # Create a proper run record in the database
|
||||
# run_manager = RunManager()
|
||||
# test_run_id = f"run-{uuid.uuid4()}"
|
||||
# test_run = Run(
|
||||
# id=test_run_id,
|
||||
# agent_id=agent_state.id,
|
||||
# )
|
||||
# await run_manager.create_run(test_run, actor)
|
||||
#
|
||||
# # Create the agent loop using LettaAgentV3
|
||||
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
#
|
||||
# # Track how many times stream_async is called
|
||||
# call_count = 0
|
||||
#
|
||||
# # Store original stream_async method
|
||||
# original_stream_async = agent_loop.llm_client.stream_async
|
||||
#
|
||||
# async def mock_stream_async_with_error(request_data, llm_config):
|
||||
# nonlocal call_count
|
||||
# call_count += 1
|
||||
# if call_count == 1:
|
||||
# # First call raises OpenAI BadRequestError with context_length_exceeded error code
|
||||
# # This will be properly converted to ContextWindowExceededError by handle_llm_error
|
||||
# from unittest.mock import MagicMock
|
||||
#
|
||||
# import httpx
|
||||
#
|
||||
# # Create a mock response with the required structure
|
||||
# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
|
||||
# mock_response = httpx.Response(
|
||||
# status_code=400,
|
||||
# request=mock_request,
|
||||
# json={
|
||||
# "error": {
|
||||
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
|
||||
# "type": "invalid_request_error",
|
||||
# "code": "context_length_exceeded",
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
#
|
||||
# raise openai.BadRequestError(
|
||||
# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
|
||||
# response=mock_response,
|
||||
# body={
|
||||
# "error": {
|
||||
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
|
||||
# "type": "invalid_request_error",
|
||||
# "code": "context_length_exceeded",
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
# # Subsequent calls use the real implementation
|
||||
# return await original_stream_async(request_data, llm_config)
|
||||
#
|
||||
# # Patch the llm_client's stream_async to raise ContextWindowExceededError on first call
|
||||
# with patch.object(agent_loop.llm_client, "stream_async", side_effect=mock_stream_async_with_error):
|
||||
# # Execute a streaming step
|
||||
# try:
|
||||
# result_chunks = []
|
||||
# async for chunk in agent_loop.stream(
|
||||
# input_messages=[input_message],
|
||||
# max_steps=1,
|
||||
# stream_tokens=True,
|
||||
# run_id=test_run_id,
|
||||
# ):
|
||||
# result_chunks.append(chunk)
|
||||
# except Exception as e:
|
||||
# # Some errors might happen due to real LLM calls after retry
|
||||
# print(f"Exception during stream: {e}")
|
||||
#
|
||||
# # Reload agent state to get updated message_ids after summarization
|
||||
# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
|
||||
# updated_message_count = len(updated_agent_state.message_ids)
|
||||
#
|
||||
# # Fetch the updated in-context messages
|
||||
# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
|
||||
# message_ids=updated_agent_state.message_ids, actor=actor
|
||||
# )
|
||||
#
|
||||
# # Convert to LettaMessage format for easier content inspection
|
||||
# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
|
||||
#
|
||||
# # Verify a summary message exists with the correct format
|
||||
# # The summary message has content with type="system_alert" and message containing:
|
||||
# # "prior messages ... have been hidden" and "summary of the previous"
|
||||
# import json
|
||||
#
|
||||
# summary_message_found = False
|
||||
# summary_message_text = None
|
||||
# for msg in letta_messages:
|
||||
# # Not all message types have a content attribute (e.g., ReasoningMessage)
|
||||
# if not hasattr(msg, "content"):
|
||||
# continue
|
||||
#
|
||||
# content = msg.content
|
||||
# # Content can be a string (JSON) or an object with type/message fields
|
||||
# if isinstance(content, str):
|
||||
# # Try to parse as JSON
|
||||
# try:
|
||||
# parsed = json.loads(content)
|
||||
# if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
|
||||
# text_to_check = parsed.get("message", "").lower()
|
||||
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
|
||||
# summary_message_found = True
|
||||
# summary_message_text = parsed.get("message")
|
||||
# break
|
||||
# except (json.JSONDecodeError, TypeError):
|
||||
# pass
|
||||
# # Check if content has system_alert type with the summary message (object form)
|
||||
# elif hasattr(content, "type") and content.type == "system_alert":
|
||||
# if hasattr(content, "message") and content.message:
|
||||
# text_to_check = content.message.lower()
|
||||
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
|
||||
# summary_message_found = True
|
||||
# summary_message_text = content.message
|
||||
# break
|
||||
#
|
||||
# assert summary_message_found, (
|
||||
# "A summary message should exist in the in-context messages after summarization. "
|
||||
# "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
|
||||
# )
|
||||
#
|
||||
# # Verify we attempted multiple invocations (the failing one + retry after summarization)
|
||||
# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
|
||||
#
|
||||
# # The original messages should have been compacted - the updated count should be less than
|
||||
# # original + the new messages added (input + assistant response + tool results)
|
||||
# # Since summarization should have removed most of the original 30 messages
|
||||
# print("Test passed: Summary message found in context")
|
||||
# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
|
||||
# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
|
||||
# print(f"Total LLM invocations: {call_count}")
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_context_window_overflow_triggers_summarization_in_blocking(server: SyncServer, actor):
|
||||
# """
|
||||
# Test that a ContextWindowExceededError during a blocking (non-streaming) LLM request
|
||||
# properly triggers the summarizer and compacts the in-context messages.
|
||||
#
|
||||
# This test is similar to the streaming test but uses the blocking step() method.
|
||||
#
|
||||
# Note: This test only runs with OpenAI since it uses OpenAI-specific error handling.
|
||||
# """
|
||||
# import uuid
|
||||
# from unittest.mock import patch
|
||||
#
|
||||
# import openai
|
||||
#
|
||||
# from letta.schemas.message import MessageCreate
|
||||
# from letta.schemas.run import Run
|
||||
# from letta.services.run_manager import RunManager
|
||||
#
|
||||
# # Use OpenAI config for this test (since we're using OpenAI-specific error handling)
|
||||
# llm_config = get_llm_config("openai-gpt-4o-mini.json")
|
||||
#
|
||||
# # Create test messages
|
||||
# messages = []
|
||||
# for i in range(15):
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.user,
|
||||
# content=[TextContent(type="text", text=f"User message {i}: This is test message number {i}.")],
|
||||
# )
|
||||
# )
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.assistant,
|
||||
# content=[TextContent(type="text", text=f"Assistant response {i}: I acknowledge message {i}.")],
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# agent_state, in_context_messages = await create_agent_with_messages(server, actor, llm_config, messages)
|
||||
# original_message_count = len(agent_state.message_ids)
|
||||
#
|
||||
# # Create an input message to trigger the agent
|
||||
# input_message = MessageCreate(
|
||||
# role=MessageRole.user,
|
||||
# content=[TextContent(type="text", text="Hello, please respond.")],
|
||||
# )
|
||||
#
|
||||
# # Create a proper run record in the database
|
||||
# run_manager = RunManager()
|
||||
# test_run_id = f"run-{uuid.uuid4()}"
|
||||
# test_run = Run(
|
||||
# id=test_run_id,
|
||||
# agent_id=agent_state.id,
|
||||
# )
|
||||
# await run_manager.create_run(test_run, actor)
|
||||
#
|
||||
# # Create the agent loop using LettaAgentV3
|
||||
# agent_loop = LettaAgentV3(agent_state=agent_state, actor=actor)
|
||||
#
|
||||
# # Track how many times request_async is called
|
||||
# call_count = 0
|
||||
#
|
||||
# # Store original request_async method
|
||||
# original_request_async = agent_loop.llm_client.request_async
|
||||
#
|
||||
# async def mock_request_async_with_error(request_data, llm_config):
|
||||
# nonlocal call_count
|
||||
# call_count += 1
|
||||
# if call_count == 1:
|
||||
# # First call raises OpenAI BadRequestError with context_length_exceeded error code
|
||||
# # This will be properly converted to ContextWindowExceededError by handle_llm_error
|
||||
# import httpx
|
||||
#
|
||||
# # Create a mock response with the required structure
|
||||
# mock_request = httpx.Request("POST", "https://api.openai.com/v1/chat/completions")
|
||||
# mock_response = httpx.Response(
|
||||
# status_code=400,
|
||||
# request=mock_request,
|
||||
# json={
|
||||
# "error": {
|
||||
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
|
||||
# "type": "invalid_request_error",
|
||||
# "code": "context_length_exceeded",
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
#
|
||||
# raise openai.BadRequestError(
|
||||
# message="This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
|
||||
# response=mock_response,
|
||||
# body={
|
||||
# "error": {
|
||||
# "message": "This model's maximum context length is 8000 tokens. However, your messages resulted in 12000 tokens.",
|
||||
# "type": "invalid_request_error",
|
||||
# "code": "context_length_exceeded",
|
||||
# }
|
||||
# },
|
||||
# )
|
||||
# # Subsequent calls use the real implementation
|
||||
# return await original_request_async(request_data, llm_config)
|
||||
#
|
||||
# # Patch the llm_client's request_async to raise ContextWindowExceededError on first call
|
||||
# with patch.object(agent_loop.llm_client, "request_async", side_effect=mock_request_async_with_error):
|
||||
# # Execute a blocking step
|
||||
# try:
|
||||
# result = await agent_loop.step(
|
||||
# input_messages=[input_message],
|
||||
# max_steps=1,
|
||||
# run_id=test_run_id,
|
||||
# )
|
||||
# except Exception as e:
|
||||
# # Some errors might happen due to real LLM calls after retry
|
||||
# print(f"Exception during step: {e}")
|
||||
#
|
||||
# # Reload agent state to get updated message_ids after summarization
|
||||
# updated_agent_state = await server.agent_manager.get_agent_by_id_async(agent_id=agent_state.id, actor=actor)
|
||||
# updated_message_count = len(updated_agent_state.message_ids)
|
||||
#
|
||||
# # Fetch the updated in-context messages
|
||||
# updated_in_context_messages = await server.message_manager.get_messages_by_ids_async(
|
||||
# message_ids=updated_agent_state.message_ids, actor=actor
|
||||
# )
|
||||
#
|
||||
# # Convert to LettaMessage format for easier content inspection
|
||||
# letta_messages = PydanticMessage.to_letta_messages_from_list(updated_in_context_messages)
|
||||
#
|
||||
# # Verify a summary message exists with the correct format
|
||||
# # The summary message has content with type="system_alert" and message containing:
|
||||
# # "prior messages ... have been hidden" and "summary of the previous"
|
||||
# import json
|
||||
#
|
||||
# summary_message_found = False
|
||||
# summary_message_text = None
|
||||
# for msg in letta_messages:
|
||||
# # Not all message types have a content attribute (e.g., ReasoningMessage)
|
||||
# if not hasattr(msg, "content"):
|
||||
# continue
|
||||
#
|
||||
# content = msg.content
|
||||
# # Content can be a string (JSON) or an object with type/message fields
|
||||
# if isinstance(content, str):
|
||||
# # Try to parse as JSON
|
||||
# try:
|
||||
# parsed = json.loads(content)
|
||||
# if isinstance(parsed, dict) and parsed.get("type") == "system_alert":
|
||||
# text_to_check = parsed.get("message", "").lower()
|
||||
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
|
||||
# summary_message_found = True
|
||||
# summary_message_text = parsed.get("message")
|
||||
# break
|
||||
# except (json.JSONDecodeError, TypeError):
|
||||
# pass
|
||||
# # Check if content has system_alert type with the summary message (object form)
|
||||
# elif hasattr(content, "type") and content.type == "system_alert":
|
||||
# if hasattr(content, "message") and content.message:
|
||||
# text_to_check = content.message.lower()
|
||||
# if "prior messages" in text_to_check and "hidden" in text_to_check and "summary of the previous" in text_to_check:
|
||||
# summary_message_found = True
|
||||
# summary_message_text = content.message
|
||||
# break
|
||||
#
|
||||
# assert summary_message_found, (
|
||||
# "A summary message should exist in the in-context messages after summarization. "
|
||||
# "Expected format containing 'prior messages...hidden' and 'summary of the previous'"
|
||||
# )
|
||||
#
|
||||
# # Verify we attempted multiple invocations (the failing one + retry after summarization)
|
||||
# assert call_count >= 2, f"Expected at least 2 LLM invocations (initial + retry), got {call_count}"
|
||||
#
|
||||
# # The original messages should have been compacted - the updated count should be less than
|
||||
# # original + the new messages added (input + assistant response + tool results)
|
||||
# print("Test passed: Summary message found in context (blocking mode)")
|
||||
# print(f"Original message count: {original_message_count}, Updated: {updated_message_count}")
|
||||
# print(f"Summary message: {summary_message_text[:200] if summary_message_text else 'N/A'}...")
|
||||
# print(f"Total LLM invocations: {call_count}")
|
||||
#
|
||||
#
|
||||
# @pytest.mark.asyncio
|
||||
# @pytest.mark.parametrize(
|
||||
# "llm_config",
|
||||
# TESTED_LLM_CONFIGS,
|
||||
# ids=[c.model for c in TESTED_LLM_CONFIGS],
|
||||
# )
|
||||
# async def test_summarize_all_with_real_llm(server: SyncServer, actor, llm_config: LLMConfig):
|
||||
# """
|
||||
# Test the summarize_all function with real LLM calls.
|
||||
#
|
||||
# This test verifies that the 'all' summarization mode works correctly,
|
||||
# summarizing the entire conversation into a single summary string.
|
||||
# """
|
||||
# from letta.schemas.model import ModelSettings
|
||||
# from letta.services.summarizer.summarizer_all import summarize_all
|
||||
# from letta.services.summarizer.summarizer_config import get_default_summarizer_config
|
||||
#
|
||||
# # Create a summarizer config with "all" mode
|
||||
# model_settings = ModelSettings()
|
||||
# summarizer_config = get_default_summarizer_config(model_settings)
|
||||
# summarizer_config.mode = "all"
|
||||
#
|
||||
# # Create test messages - a simple conversation
|
||||
# messages = [
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.system,
|
||||
# content=[TextContent(type="text", text="You are a helpful assistant.")],
|
||||
# )
|
||||
# ]
|
||||
#
|
||||
# # Add 10 user-assistant pairs
|
||||
# for i in range(10):
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.user,
|
||||
# content=[TextContent(type="text", text=f"User message {i}: What is {i} + {i}?")],
|
||||
# )
|
||||
# )
|
||||
# messages.append(
|
||||
# PydanticMessage(
|
||||
# role=MessageRole.assistant,
|
||||
# content=[TextContent(type="text", text=f"Assistant response {i}: {i} + {i} = {i * 2}.")],
|
||||
# )
|
||||
# )
|
||||
#
|
||||
# assert len(messages) == 21, f"Expected 21 messages, got {len(messages)}"
|
||||
#
|
||||
# # Call summarize_all with real LLM
|
||||
# summary = await summarize_all(
|
||||
# actor=actor,
|
||||
# llm_config=llm_config,
|
||||
# summarizer_config=summarizer_config,
|
||||
# in_context_messages=messages,
|
||||
# new_messages=[],
|
||||
# )
|
||||
#
|
||||
# # Verify the summary was generated
|
||||
# assert summary is not None
|
||||
# assert len(summary) > 0
|
||||
#
|
||||
# print(f"Successfully summarized {len(messages)} messages using 'all' mode")
|
||||
# print(f"Summary: {summary[:200]}..." if len(summary) > 200 else f"Summary: {summary}")
|
||||
# print(f"Using {llm_config.model_endpoint_type} for model {llm_config.model}")
|
||||
#
|
||||
|
||||
@@ -169,8 +169,8 @@ async def test_count_empty_text_tokens(llm_config: LLMConfig):
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
from letta.services.context_window_calculator.token_counter import (
|
||||
AnthropicTokenCounter,
|
||||
ApproxTokenCounter,
|
||||
GeminiTokenCounter,
|
||||
TiktokenCounter,
|
||||
)
|
||||
|
||||
if llm_config.model_endpoint_type == "anthropic":
|
||||
@@ -179,7 +179,7 @@ async def test_count_empty_text_tokens(llm_config: LLMConfig):
|
||||
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
|
||||
token_counter = GeminiTokenCounter(client, llm_config.model)
|
||||
else:
|
||||
token_counter = TiktokenCounter(llm_config.model)
|
||||
token_counter = ApproxTokenCounter()
|
||||
|
||||
token_count = await token_counter.count_text_tokens("")
|
||||
assert token_count == 0
|
||||
@@ -194,8 +194,8 @@ async def test_count_empty_messages_tokens(llm_config: LLMConfig):
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
from letta.services.context_window_calculator.token_counter import (
|
||||
AnthropicTokenCounter,
|
||||
ApproxTokenCounter,
|
||||
GeminiTokenCounter,
|
||||
TiktokenCounter,
|
||||
)
|
||||
|
||||
if llm_config.model_endpoint_type == "anthropic":
|
||||
@@ -204,7 +204,7 @@ async def test_count_empty_messages_tokens(llm_config: LLMConfig):
|
||||
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
|
||||
token_counter = GeminiTokenCounter(client, llm_config.model)
|
||||
else:
|
||||
token_counter = TiktokenCounter(llm_config.model)
|
||||
token_counter = ApproxTokenCounter()
|
||||
|
||||
token_count = await token_counter.count_message_tokens([])
|
||||
assert token_count == 0
|
||||
@@ -219,8 +219,8 @@ async def test_count_empty_tools_tokens(llm_config: LLMConfig):
|
||||
from letta.llm_api.google_vertex_client import GoogleVertexClient
|
||||
from letta.services.context_window_calculator.token_counter import (
|
||||
AnthropicTokenCounter,
|
||||
ApproxTokenCounter,
|
||||
GeminiTokenCounter,
|
||||
TiktokenCounter,
|
||||
)
|
||||
|
||||
if llm_config.model_endpoint_type == "anthropic":
|
||||
@@ -229,7 +229,7 @@ async def test_count_empty_tools_tokens(llm_config: LLMConfig):
|
||||
client = GoogleAIClient() if llm_config.model_endpoint_type == "google_ai" else GoogleVertexClient()
|
||||
token_counter = GeminiTokenCounter(client, llm_config.model)
|
||||
else:
|
||||
token_counter = TiktokenCounter(llm_config.model)
|
||||
token_counter = ApproxTokenCounter()
|
||||
|
||||
token_count = await token_counter.count_tool_tokens([])
|
||||
assert token_count == 0
|
||||
|
||||
Reference in New Issue
Block a user