* fix: update ContextWindowCalculator to parse new system message sections The context window calculator was using outdated position-based parsing that only handled 3 sections (base_instructions, memory_blocks, memory_metadata). The actual system message now includes additional sections that were not being tracked: - <memory_filesystem> (git-enabled agents) - <tool_usage_rules> (when tool rules configured) - <directories> (when sources attached) Changes: - Add _extract_tag_content() helper for proper XML tag extraction - Rewrite extract_system_components() to return a Dict with all 6 sections - Update calculate_context_window() to count tokens for new sections - Add new fields to ContextWindowOverview schema with backward-compatible defaults - Add unit tests for the extraction logic * update * generate * fix: check attached file in directories section instead of core_memory Files are rendered inside <directories> tags, not <memory_blocks>. Update validate_context_window_overview assertions accordingly. * fix: address review feedback for context window parser - Fix git-enabled agents regression: capture bare file blocks (e.g. <system/human.md>) rendered after </memory_filesystem> as core_memory via new _extract_git_core_memory() method - Make _extract_top_level_tag robust: scan all occurrences to find tag outside container, handling nested-first + top-level-later case - Document system_prompt tag inconsistency in docstring - Add TODO to base_agent.py extract_dynamic_section linking to ContextWindowCalculator to flag parallel parser tech debt - Add tests: git-enabled agent parsing, dual-occurrence tag extraction, pure text system prompt, git-enabled integration test
371 lines
17 KiB
Python
import functools
|
|
import os
|
|
import time
|
|
from typing import Any, Optional, Union
|
|
|
|
from letta_client import AsyncLetta, Letta
|
|
|
|
from letta.functions.functions import parse_source_code
|
|
from letta.functions.schema_generator import generate_schema
|
|
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
|
|
from letta.schemas.enums import MessageRole
|
|
from letta.schemas.file import FileAgent
|
|
from letta.schemas.memory import ContextWindowOverview
|
|
from letta.schemas.tool import Tool
|
|
from letta.schemas.user import User, User as PydanticUser
|
|
from letta.server.rest_api.routers.v1.agents import ImportedAgentsResponse
|
|
from letta.server.server import SyncServer
|
|
|
|
|
|
def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
    """
    Decorator to retry a test until a failure threshold is crossed.

    The wrapped test is executed exactly ``max_attempts`` times; afterwards the
    observed success rate must be at least ``threshold`` or an AssertionError
    is raised.

    :param threshold: Expected passing rate (e.g., 0.5 means 50% success rate expected).
    :param max_attempts: Maximum number of attempts to retry the test.
    :param sleep_time_seconds: Time to wait between attempts, in seconds.
    """

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            success_count = 0

            for attempt in range(max_attempts):
                try:
                    func(*args, **kwargs)
                    success_count += 1
                except Exception as e:
                    print(f"\033[93mAn attempt failed with error:\n{e}\033[0m")

                # Only sleep *between* attempts — sleeping after the final
                # attempt would just delay the verdict for no benefit.
                if attempt < max_attempts - 1:
                    time.sleep(sleep_time_seconds)

            rate = success_count / max_attempts
            if rate >= threshold:
                print(f"Test met expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}")
            else:
                raise AssertionError(
                    f"Test did not meet expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}"
                )

        return wrapper

    return decorator_retry
|
|
|
|
|
|
def retry_until_success(max_attempts=10, sleep_time_seconds=4):
    """
    Decorator that re-runs a function until it returns without raising,
    re-raising the last exception once ``max_attempts`` tries are exhausted.

    :param max_attempts: Maximum number of attempts to retry the function.
    :param sleep_time_seconds: Time to wait between attempts, in seconds.
    """

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempt = 0
            while True:
                attempt += 1
                try:
                    return func(*args, **kwargs)
                except Exception as e:
                    print(f"\033[93mAttempt {attempt} failed with error:\n{e}\033[0m")
                    # Out of attempts: surface the original exception.
                    if attempt == max_attempts:
                        raise
                    time.sleep(sleep_time_seconds)

        return wrapper

    return decorator_retry
|
|
|
|
|
|
async def cleanup_async(server: SyncServer, agent_uuid: str, actor: User):
    """Delete every agent whose name matches ``agent_uuid`` on behalf of ``actor``."""
    matching_agents = await server.agent_manager.list_agents_async(name=agent_uuid, actor=actor)
    for state in matching_agents:
        await server.agent_manager.delete_agent_async(agent_id=state.id, actor=actor)
|
|
|
|
|
|
# Utility functions
|
|
def create_tool_from_func(func: callable):
    """Build a Tool record from a plain Python function by parsing its source
    and generating a JSON schema from its signature."""
    source = parse_source_code(func)
    schema = generate_schema(func, None)
    return Tool(
        name=func.__name__,
        description="",
        source_type="python",
        tags=[],
        source_code=source,
        json_schema=schema,
    )
|
|
|
|
|
|
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent], actor: PydanticUser):
    """
    Assert that an agent state faithfully reflects the create/update request
    that produced it.

    Covers scalar fields, tool-exec environment variables and secrets, agent
    type, LLM/embedding configs, memory blocks, tools, sources, tags, tool
    rules, and ``message_buffer_autoclear``.

    :param agent: The agent state returned by the server.
    :param request: The CreateAgent or UpdateAgent request to compare against.
    :param actor: The acting user; used to verify organization ownership.
    :raises AssertionError: on any mismatch.
    """
    # Assert scalar fields
    assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
    assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
    assert agent.metadata == request.metadata, f"Metadata mismatch: {agent.metadata} != {request.metadata}"

    # Assert agent env vars: each env var on the agent must come from the
    # request and belong to the actor's organization.
    if hasattr(request, "tool_exec_environment_variables") and request.tool_exec_environment_variables:
        for agent_env_var in agent.tool_exec_environment_variables:
            assert agent_env_var.key in request.tool_exec_environment_variables
            assert request.tool_exec_environment_variables[agent_env_var.key] == agent_env_var.value
            assert agent_env_var.organization_id == actor.organization_id
    if hasattr(request, "secrets") and request.secrets:
        for agent_env_var in agent.secrets:
            assert agent_env_var.key in request.secrets
            assert request.secrets[agent_env_var.key] == agent_env_var.value
            assert agent_env_var.organization_id == actor.organization_id

    # Assert agent type
    if hasattr(request, "agent_type"):
        assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}"

    # Assert LLM configuration
    assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}"

    # Assert embedding configuration
    assert agent.embedding_config == request.embedding_config, (
        f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
    )

    # Assert memory blocks: expected count includes inline blocks plus any
    # referenced block_ids. Guard block_ids — it may be missing or None.
    if hasattr(request, "memory_blocks"):
        block_ids = getattr(request, "block_ids", None) or []
        assert len(agent.memory.blocks) == len(request.memory_blocks) + len(block_ids), (
            f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(block_ids)}"
        )
        memory_block_values = {block.value for block in agent.memory.blocks}
        expected_block_values = {block.value for block in request.memory_blocks}
        assert expected_block_values.issubset(memory_block_values), (
            f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
        )

    # Assert tools
    assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}"
    assert {tool.id for tool in agent.tools} == set(request.tool_ids), (
        f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
    )

    # Assert sources
    assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}"
    assert {source.id for source in agent.sources} == set(request.source_ids), (
        f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
    )

    # Assert tags
    assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"

    # Assert tool rules (leftover debug prints removed)
    if request.tool_rules:
        assert len(agent.tool_rules) == len(request.tool_rules), (
            f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
        )
        assert all(any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules), (
            f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"
        )

    # Assert message_buffer_autoclear
    if request.message_buffer_autoclear is not None:
        assert agent.message_buffer_autoclear == request.message_buffer_autoclear
|
|
|
|
|
|
def validate_context_window_overview(
    agent_state: AgentState, overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None
) -> None:
    """Validate common sense assertions for ContextWindowOverview.

    :param agent_state: Agent whose context window was computed (currently
        unused; kept for signature compatibility with existing callers).
    :param overview: The context window breakdown to validate.
    :param attached_file: If given, additionally assert the file is rendered
        inside the <directories> section of the system message.
    :raises AssertionError: on any consistency violation.
    """

    # 1. Current context size should not exceed maximum
    assert overview.context_window_size_current <= overview.context_window_size_max, (
        f"Current context size ({overview.context_window_size_current}) exceeds maximum ({overview.context_window_size_max})"
    )

    # 2. All token counts should be non-negative
    assert overview.num_tokens_system >= 0, "System token count cannot be negative"
    assert overview.num_tokens_core_memory >= 0, "Core memory token count cannot be negative"
    assert overview.num_tokens_memory_filesystem >= 0, "Memory filesystem token count cannot be negative"
    assert overview.num_tokens_tool_usage_rules >= 0, "Tool usage rules token count cannot be negative"
    assert overview.num_tokens_directories >= 0, "Directories token count cannot be negative"
    assert overview.num_tokens_external_memory_summary >= 0, "External memory summary token count cannot be negative"
    assert overview.num_tokens_summary_memory >= 0, "Summary memory token count cannot be negative"
    assert overview.num_tokens_messages >= 0, "Messages token count cannot be negative"
    assert overview.num_tokens_functions_definitions >= 0, "Functions definitions token count cannot be negative"

    # 3. Token components should sum to total
    expected_total = (
        overview.num_tokens_system
        + overview.num_tokens_core_memory
        + overview.num_tokens_memory_filesystem
        + overview.num_tokens_tool_usage_rules
        + overview.num_tokens_directories
        + overview.num_tokens_external_memory_summary
        + overview.num_tokens_summary_memory
        + overview.num_tokens_messages
        + overview.num_tokens_functions_definitions
    )
    assert overview.context_window_size_current == expected_total, (
        f"Token sum ({expected_total}) doesn't match current size ({overview.context_window_size_current})"
    )

    # 4. Message count should match messages list length
    assert len(overview.messages) == overview.num_messages, (
        f"Messages list length ({len(overview.messages)}) doesn't match num_messages ({overview.num_messages})"
    )

    # 5. If summary_memory is None, its token count should be 0
    if overview.summary_memory is None:
        assert overview.num_tokens_summary_memory == 0, "Summary memory is None but has non-zero token count"

    # 6. External memory summary consistency
    assert overview.num_tokens_external_memory_summary > 0, "External memory summary exists but has zero token count"

    # 7. System prompt consistency
    assert overview.num_tokens_system > 0, "System prompt exists but has zero token count"

    # 8. Core memory consistency
    assert overview.num_tokens_core_memory > 0, "Core memory exists but has zero token count"

    # 9. Functions definitions consistency
    assert overview.num_tokens_functions_definitions > 0, "Functions definitions exist but have zero token count"
    assert len(overview.functions_definitions) > 0, "Functions definitions list should not be empty"

    # 10. Memory counts should be non-negative
    assert overview.num_archival_memory >= 0, "Archival memory count cannot be negative"
    assert overview.num_recall_memory >= 0, "Recall memory count cannot be negative"

    # 11. Context window max should be positive
    assert overview.context_window_size_max > 0, "Maximum context window size must be positive"

    # 12. If there are messages, check basic structure.
    # At least one message should be system message (typical pattern);
    # this is a soft assertion — log a warning instead of failing.
    has_system_message = any(msg.role == MessageRole.system for msg in overview.messages)
    if not has_system_message:
        print("Warning: No system message found in messages list")

    # 13. Average tokens per message should be reasonable. Guard against an
    # empty message list — dividing by num_messages == 0 raised
    # ZeroDivisionError before this fix.
    if overview.num_messages:
        avg_tokens_per_message = overview.num_tokens_messages / overview.num_messages
        assert avg_tokens_per_message >= 0, "Average tokens per message should be non-negative"

    # 14. Check attached file is visible in the directories section
    if attached_file:
        assert overview.directories is not None, "Directories section must exist when files are attached"
        assert attached_file.visible_content in overview.directories, "File must be attached in directories"
        assert '<file status="open"' in overview.directories
        assert "</file>" in overview.directories
        assert "max_files_open" in overview.directories, "Max files should be set in directories"
        assert "current_files_open" in overview.directories, "Current files should be set in directories"

    # Check for tools
    assert overview.num_tokens_functions_definitions > 0
    assert len(overview.functions_definitions) > 0
|
|
|
|
|
|
# Changed this from server_url to client since client may be authenticated or not
|
|
def upload_test_agentfile_from_disk(client: Letta, filename: str) -> ImportedAgentsResponse:
    """
    Upload a given .af file to live FastAPI server.
    """
    # Resolve the sibling test_agent_files directory relative to this helpers module.
    helpers_dir = os.path.dirname(__file__)
    agent_files_dir = helpers_dir.removesuffix("/helpers") + "/test_agent_files"
    target_path = os.path.join(agent_files_dir, filename)

    with open(target_path, "rb") as fh:
        return client.agents.import_file(file=fh, append_copy_suffix=True, override_existing_tools=False)
|
|
|
|
|
|
async def upload_test_agentfile_from_disk_async(client: AsyncLetta, filename: str) -> ImportedAgentsResponse:
    """
    Upload a given .af file to live FastAPI server.
    """
    # Resolve the sibling test_agent_files directory relative to this helpers module.
    helpers_dir = os.path.dirname(__file__)
    agent_files_dir = helpers_dir.removesuffix("/helpers") + "/test_agent_files"
    target_path = os.path.join(agent_files_dir, filename)

    with open(target_path, "rb") as fh:
        return await client.agents.import_file(file=fh, append_copy_suffix=True, override_existing_tools=False)
|
|
|
|
|
|
def upload_file_and_wait(
    client: Letta,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: Optional[str] = None,
):
    """Helper function to upload a file and wait for processing to complete.

    Uploads ``file_path`` to the folder/source ``source_id``, then polls the
    file's processing status once per second until it reaches "completed" or
    "error" (raising TimeoutError after ``max_wait`` seconds, RuntimeError on
    an error status). Returns the final file metadata as a dict.

    NOTE(review): the upload goes through ``client.folders.files`` but polling
    hits ``/v1/sources/...`` — presumably folders and sources share IDs here;
    confirm against the API.
    """
    with open(file_path, "rb") as f:
        if duplicate_handling:
            file_metadata = client.folders.files.upload(folder_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name)
        else:
            file_metadata = client.folders.files.upload(folder_id=source_id, file=f, name=name)

    # wait for the file to be processed
    start_time = time.time()
    file_metadata_id = file_metadata.id
    # The initial upload response is a model object (attribute access); each
    # poll below replaces it with a raw dict from client.get, so the loop and
    # the final checks must handle both representations.
    processing_status = file_metadata.processing_status
    while processing_status != "completed" and processing_status != "error":
        if time.time() - start_time > max_wait:
            raise TimeoutError(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)
        file_metadata = client.get(
            path=f"/v1/sources/{source_id}/files/{file_metadata_id}",
            cast_to=dict[str, Any],
        )
        print("Waiting for file processing to complete...", file_metadata["processing_status"])
        processing_status = file_metadata["processing_status"]

    # file_metadata may be either a dict (polled at least once) or the original
    # model object (completed on the very first check) — cover both.
    if isinstance(file_metadata, dict) and file_metadata["processing_status"] == "error":
        raise RuntimeError(f"File processing failed: {file_metadata['error_message']}")
    elif hasattr(file_metadata, "processing_status") and file_metadata.processing_status == "error":
        raise RuntimeError(f"File processing failed: {file_metadata.error_message}")

    # Normalize the return value to a plain dict.
    if not isinstance(file_metadata, dict):
        file_metadata = file_metadata.model_dump()
    return file_metadata
|
|
|
|
|
|
def upload_file_and_wait_list_files(
    client: Letta,
    source_id: str,
    file_path: str,
    name: Optional[str] = None,
    max_wait: int = 60,
    duplicate_handling: Optional[str] = None,
):
    """Helper function to upload a file and wait for processing using list_files instead of get_file_metadata"""
    with open(file_path, "rb") as fh:
        upload_kwargs = {"source_id": source_id, "file": fh, "name": name}
        if duplicate_handling:
            upload_kwargs["duplicate_handling"] = duplicate_handling
        file_metadata = client.sources.files.upload(**upload_kwargs)

    # Poll list_files once per second until the file leaves the processing state.
    deadline = time.time() + max_wait
    while file_metadata.processing_status not in ("completed", "error"):
        if time.time() > deadline:
            raise TimeoutError(f"File processing timed out after {max_wait} seconds")
        time.sleep(1)

        # use list_files to get all files and find our specific file
        files = client.sources.files.list(source_id=source_id, limit=100)
        refreshed = next((f for f in files if f.id == file_metadata.id), None)
        if refreshed is None:
            raise RuntimeError(f"File {file_metadata.id} not found in source files list")
        file_metadata = refreshed

        print("Waiting for file processing to complete (via list_files)...", file_metadata.processing_status)

    if file_metadata.processing_status == "error":
        raise RuntimeError(f"File processing failed: {file_metadata.error_message}")

    return file_metadata
|