merge this (#4759)

* wait I forgot to commit locally

* cp the entire core directory and then rm the .git subdir
This commit is contained in:
Kian Jones
2025-09-17 15:47:40 -07:00
committed by GitHub
parent 22f70ca07c
commit b8e9a80d93
1240 changed files with 235556 additions and 0 deletions

View File

@@ -0,0 +1,33 @@
import time
from letta import RESTClient
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job
from letta.schemas.source import Source
def upload_file_using_client(client: RESTClient, source: Source, filename: str) -> Job:
    """Upload a file to a source as a non-blocking job and wait for it to finish.

    Args:
        client: REST client used to talk to the server.
        source: Source the file is loaded into.
        filename: Path of the file to upload.

    Returns:
        The upload job object (re-fetch via ``client.get_job`` for final state).

    Raises:
        ValueError: If the job fails, or does not finish within the timeout.
    """
    # load a file into a source (non-blocking job)
    upload_job = client.load_file_to_source(filename=filename, source_id=source.id, blocking=False)
    print("Upload job", upload_job, upload_job.status, upload_job.metadata)

    # view active jobs
    active_jobs = client.list_active_jobs()
    jobs = client.list_jobs()
    assert upload_job.id in [j.id for j in jobs]
    assert len(active_jobs) == 1
    assert active_jobs[0].metadata["source_id"] == source.id

    # wait for job to finish (with timeout)
    timeout = 240
    start_time = time.time()
    while True:
        status = client.get_job(upload_job.id).status
        print(f"\r{status}", end="", flush=True)
        if status == JobStatus.completed:
            break
        # Fail fast on a terminal failure instead of spinning until the timeout.
        if status == JobStatus.failed:
            raise ValueError(f"Job {upload_job.id} failed during file upload")
        if time.time() - start_time > timeout:
            raise ValueError("Job did not finish in time")
        time.sleep(1)
    return upload_job

View File

@@ -0,0 +1,245 @@
import json
import logging
import uuid
from typing import Callable, List, Optional, Sequence
from letta.llm_api.helpers import unpack_inner_thoughts_from_kwargs
from letta.schemas.block import CreateBlock
from letta.schemas.tool_rule import BaseToolRule
from letta.server.server import SyncServer
logging.basicConfig(level=logging.DEBUG)
logger = logging.getLogger(__name__)
from letta.config import LettaConfig
from letta.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from letta.errors import InvalidInnerMonologueError, InvalidToolCallError, MissingInnerMonologueError, MissingToolCallError
from letta.llm_api.llm_client import LLMClient
from letta.local_llm.constants import INNER_THOUGHTS_KWARG
from letta.schemas.agent import AgentState, CreateAgent
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.letta_message import LettaMessage, ReasoningMessage, ToolCallMessage
from letta.schemas.letta_response import LettaResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.openai.chat_completion_response import Choice, FunctionCall, Message
from letta.utils import get_human_text, get_persona_text
# Generate uuid for agent name for this example
namespace = uuid.NAMESPACE_DNS
agent_uuid = str(uuid.uuid5(namespace, "test-endpoints-agent"))
# defaults (letta hosted)
EMBEDDING_CONFIG_PATH = "tests/configs/embedding_model_configs/letta-hosted.json"
LLM_CONFIG_PATH = "tests/configs/llm_model_configs/letta-hosted.json"
# ======================================================================================================================
# Section: Test Setup
# These functions help setup the test
# ======================================================================================================================
def setup_agent(
    server: SyncServer,
    filename: str,
    memory_human_str: str = get_human_text(DEFAULT_HUMAN),
    memory_persona_str: str = get_persona_text(DEFAULT_PERSONA),
    tool_ids: Optional[List[str]] = None,
    tool_rules: Optional[List[BaseToolRule]] = None,
    agent_uuid: str = agent_uuid,
    include_base_tools: bool = True,
    include_base_tool_rules: bool = True,
) -> AgentState:
    """Create an agent on *server* using the LLM config stored at *filename*."""
    # Load the LLM and embedding configurations from their JSON files.
    with open(filename, "r") as llm_file:
        llm_config = LLMConfig(**json.load(llm_file))
    with open(EMBEDDING_CONFIG_PATH, "r") as emb_file:
        embedding_config = EmbeddingConfig(**json.load(emb_file))

    # Persist the chosen configs as the process-wide defaults.
    config = LettaConfig()
    config.default_llm_config = llm_config
    config.default_embedding_config = embedding_config
    config.save()

    # Build the creation request with human/persona memory blocks.
    memory_blocks = [
        CreateBlock(label="human", value=memory_human_str),
        CreateBlock(label="persona", value=memory_persona_str),
    ]
    request = CreateAgent(
        name=agent_uuid,
        llm_config=llm_config,
        embedding_config=embedding_config,
        memory_blocks=memory_blocks,
        tool_ids=tool_ids,
        tool_rules=tool_rules,
        include_base_tools=include_base_tools,
        include_base_tool_rules=include_base_tool_rules,
    )
    actor = server.user_manager.get_user_or_default()
    return server.create_agent(request=request, actor=actor)
# ======================================================================================================================
# Section: Complex E2E Tests
# These functions describe individual testing scenarios.
# ======================================================================================================================
async def run_embedding_endpoint(filename, actor=None):
    """Smoke-test the embedding endpoint described by the JSON config at *filename*."""
    # Read the embedding configuration from disk.
    with open(filename, "r") as cfg_file:
        raw_config = json.load(cfg_file)
    print(raw_config)
    embedding_config = EmbeddingConfig(**raw_config)

    # Use the new LLMClient for embeddings
    client = LLMClient.create(
        provider_type=embedding_config.embedding_endpoint_type,
        actor=actor,
    )

    # Embed a trivial query and sanity-check the returned vector.
    query_text = "hello"
    vectors = await client.request_embeddings([query_text], embedding_config)
    first_vector = vectors[0]
    print("vector dim", len(first_vector))
    assert first_vector is not None
# ======================================================================================================================
# Section: Letta Message Assertions
# These functions are validating elements of parsed Letta Message
# ======================================================================================================================
def assert_sanity_checks(response: LettaResponse):
assert response is not None, response
assert response.messages is not None, response
assert len(response.messages) > 0, response
def assert_invoked_send_message_with_keyword(messages: Sequence[LettaMessage], keyword: str, case_sensitive: bool = False) -> None:
    """Assert that some message invoked `send_message` with *keyword* in its `message` argument.

    Args:
        messages: Parsed Letta messages to scan.
        keyword: Substring that must appear in the `message` argument.
        case_sensitive: When False (default), the comparison is case-insensitive.

    Raises:
        MissingToolCallError: No `send_message` tool call was found.
        InvalidToolCallError: The call had malformed arguments, lacked the
            `message` field, or the keyword was absent.
    """
    # Find first instance of send_message
    target_message = None
    for message in messages:
        if isinstance(message, ToolCallMessage) and message.tool_call.name == "send_message":
            target_message = message
            break

    # No messages found invoking `send_message`
    if target_message is None:
        raise MissingToolCallError(messages=messages, explanation="Missing `send_message` function call")

    send_message_function_call = target_message.tool_call
    try:
        arguments = json.loads(send_message_function_call.arguments)
    except (json.JSONDecodeError, TypeError) as e:
        # Narrowed from a bare `except:` so unrelated errors (e.g. KeyboardInterrupt)
        # propagate; TypeError covers non-string `arguments`. Chain the cause for debugging.
        raise InvalidToolCallError(messages=[target_message], explanation="Function call arguments could not be loaded into JSON") from e

    # Message field not in send_message
    if "message" not in arguments:
        raise InvalidToolCallError(
            messages=[target_message], explanation="send_message function call does not have required field `message`"
        )

    # Check that the keyword is in the message arguments
    if not case_sensitive:
        keyword = keyword.lower()
        arguments["message"] = arguments["message"].lower()
    if keyword not in arguments["message"]:
        raise InvalidToolCallError(messages=[target_message], explanation=f"Message argument did not contain keyword={keyword}")
def assert_invoked_function_call(messages: Sequence[LettaMessage], function_name: str) -> None:
    """Assert that at least one message is a tool call invoking *function_name*."""
    invoked = any(
        isinstance(message, ToolCallMessage) and message.tool_call.name == function_name
        for message in messages
    )
    if not invoked:
        raise MissingToolCallError(messages=messages, explanation=f"No messages were found invoking function call with name: {function_name}")
def assert_inner_monologue_is_present_and_valid(messages: List[LettaMessage]) -> None:
    """Assert that at least one ReasoningMessage (inner monologue) is present."""
    if not any(isinstance(message, ReasoningMessage) for message in messages):
        raise MissingInnerMonologueError(messages=messages)
# ======================================================================================================================
# Section: Raw API Assertions
# These functions are validating elements of the (close to) raw LLM API's response
# ======================================================================================================================
def assert_contains_valid_function_call(
message: Message,
function_call_validator: Optional[Callable[[FunctionCall], bool]] = None,
validation_failure_summary: Optional[str] = None,
) -> None:
"""
Helper function to check that a message contains a valid function call.
There is an Optional parameter `function_call_validator` that specifies a validator function.
This function gets called on the resulting function_call to validate the function is what we expect.
"""
if (hasattr(message, "function_call") and message.function_call is not None) and (
hasattr(message, "tool_calls") and message.tool_calls is not None
):
raise InvalidToolCallError(messages=[message], explanation="Both function_call and tool_calls is present in the message")
elif hasattr(message, "function_call") and message.function_call is not None:
function_call = message.function_call
elif hasattr(message, "tool_calls") and message.tool_calls is not None:
# Note: We only take the first one for now. Is this a problem? @charles
# This seems to be standard across the repo
function_call = message.tool_calls[0].function
else:
# Throw a missing function call error
raise MissingToolCallError(messages=[message])
if function_call_validator and not function_call_validator(function_call):
raise InvalidToolCallError(messages=[message], explanation=validation_failure_summary)
def assert_inner_monologue_is_valid(message: Message) -> None:
"""
Helper function to check that the inner monologue is valid.
"""
# Sometimes the syntax won't be correct and internal syntax will leak into message
invalid_phrases = ["functions", "send_message", "arguments"]
monologue = message.content
for phrase in invalid_phrases:
if phrase in monologue:
raise InvalidInnerMonologueError(messages=[message], explanation=f"{phrase} is in monologue")
def assert_contains_correct_inner_monologue(
choice: Choice,
inner_thoughts_in_kwargs: bool,
validate_inner_monologue_contents: bool = True,
) -> None:
"""
Helper function to check that the inner monologue exists and is valid.
"""
# Unpack inner thoughts out of function kwargs, and repackage into choice
if inner_thoughts_in_kwargs:
choice = unpack_inner_thoughts_from_kwargs(choice, INNER_THOUGHTS_KWARG)
message = choice.message
monologue = message.content
if not monologue or monologue is None or monologue == "":
raise MissingInnerMonologueError(messages=[message])
if validate_inner_monologue_contents:
assert_inner_monologue_is_valid(message)

View File

@@ -0,0 +1,20 @@
from letta.data_sources.redis_client import get_redis_client
from letta.services.agent_manager import AgentManager
async def is_experimental_okay(feature_name: str, **kwargs) -> bool:
    """Resolve whether the experimental feature *feature_name* is enabled for these kwargs."""
    print(feature_name, kwargs)
    if feature_name == "test_pass_with_kwarg":
        return isinstance(kwargs["agent_manager"], AgentManager)
    elif feature_name == "test_just_pass":
        return True
    elif feature_name == "test_fail":
        return False
    elif feature_name == "test_override_kwarg":
        return kwargs["bool_val"]
    elif feature_name == "test_redis_flag":
        redis = await get_redis_client()
        member = kwargs["user_id"]
        return await redis.check_inclusion_and_exclusion(member=member, group="TEST_GROUP")
    # Err on safety here, disabling experimental if not handled here.
    return False

353
tests/helpers/utils.py Normal file
View File

@@ -0,0 +1,353 @@
import functools
import os
import time
from typing import Optional, Union
from letta_client import AsyncLetta, Letta
from letta.functions.functions import parse_source_code
from letta.functions.schema_generator import generate_schema
from letta.schemas.agent import AgentState, CreateAgent, UpdateAgent
from letta.schemas.enums import MessageRole
from letta.schemas.file import FileAgent
from letta.schemas.memory import ContextWindowOverview
from letta.schemas.tool import Tool
from letta.schemas.user import User, User as PydanticUser
from letta.server.rest_api.routers.v1.agents import ImportedAgentsResponse
from letta.server.server import SyncServer
def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
    """
    Decorator to retry a test until a failure threshold is crossed.

    :param threshold: Expected passing rate (e.g., 0.5 means 50% success rate expected).
    :param max_attempts: Maximum number of attempts to retry the test.
    :param sleep_time_seconds: Time to wait between attempts, in seconds.
    """

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            success_count = 0
            for attempt in range(max_attempts):
                try:
                    func(*args, **kwargs)
                    success_count += 1
                except Exception as e:
                    print(f"\033[93mAn attempt failed with error:\n{e}\033[0m")
                # Only sleep BETWEEN attempts; the original also slept after the
                # final attempt, wasting sleep_time_seconds per decorated test.
                if attempt < max_attempts - 1:
                    time.sleep(sleep_time_seconds)

            rate = success_count / max_attempts
            if rate >= threshold:
                print(f"Test met expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}")
            else:
                raise AssertionError(
                    f"Test did not meet expected passing rate of {threshold:.2f}. Actual rate: {success_count}/{max_attempts}"
                )

        return wrapper

    return decorator_retry
def retry_until_success(max_attempts=10, sleep_time_seconds=4):
    """
    Decorator to retry a function until it succeeds or the maximum number of attempts is reached.

    :param max_attempts: Maximum number of attempts to retry the function.
    :param sleep_time_seconds: Time to wait between attempts, in seconds.
    """

    def decorator_retry(func):
        @functools.wraps(func)
        def wrapper(*args, **kwargs):
            attempt = 0
            while True:
                attempt += 1
                try:
                    return func(*args, **kwargs)
                except Exception as err:
                    print(f"\033[93mAttempt {attempt} failed with error:\n{err}\033[0m")
                    # Re-raise the last failure once the budget is exhausted.
                    if attempt >= max_attempts:
                        raise
                    time.sleep(sleep_time_seconds)

        return wrapper

    return decorator_retry
def cleanup(server: SyncServer, agent_uuid: str, actor: User):
# Clear all agents
agent_states = server.agent_manager.list_agents(name=agent_uuid, actor=actor)
for agent_state in agent_states:
server.agent_manager.delete_agent(agent_id=agent_state.id, actor=actor)
# Utility functions
def create_tool_from_func(func: callable):
    """Wrap a plain Python function in a Tool record (parsed source + generated JSON schema)."""
    source = parse_source_code(func)
    schema = generate_schema(func, None)
    return Tool(
        name=func.__name__,
        description="",
        source_type="python",
        tags=[],
        source_code=source,
        json_schema=schema,
    )
def comprehensive_agent_checks(agent: AgentState, request: Union[CreateAgent, UpdateAgent], actor: PydanticUser):
    """Assert that *agent* reflects every field of the create/update *request*.

    Raises:
        AssertionError: On the first field that does not match.
    """
    # Assert scalar fields
    assert agent.system == request.system, f"System prompt mismatch: {agent.system} != {request.system}"
    assert agent.description == request.description, f"Description mismatch: {agent.description} != {request.description}"
    assert agent.metadata == request.metadata, f"Metadata mismatch: {agent.metadata} != {request.metadata}"

    # Assert agent env vars (both the legacy and current field names are checked)
    if hasattr(request, "tool_exec_environment_variables") and request.tool_exec_environment_variables:
        for agent_env_var in agent.tool_exec_environment_variables:
            assert agent_env_var.key in request.tool_exec_environment_variables
            assert request.tool_exec_environment_variables[agent_env_var.key] == agent_env_var.value
            assert agent_env_var.organization_id == actor.organization_id
    if hasattr(request, "secrets") and request.secrets:
        for agent_env_var in agent.secrets:
            assert agent_env_var.key in request.secrets
            assert request.secrets[agent_env_var.key] == agent_env_var.value
            assert agent_env_var.organization_id == actor.organization_id

    # Assert agent type
    if hasattr(request, "agent_type"):
        assert agent.agent_type == request.agent_type, f"Agent type mismatch: {agent.agent_type} != {request.agent_type}"

    # Assert LLM configuration
    assert agent.llm_config == request.llm_config, f"LLM config mismatch: {agent.llm_config} != {request.llm_config}"

    # Assert embedding configuration
    assert agent.embedding_config == request.embedding_config, (
        f"Embedding config mismatch: {agent.embedding_config} != {request.embedding_config}"
    )

    # Assert memory blocks: request-supplied blocks plus attached block_ids
    if hasattr(request, "memory_blocks"):
        assert len(agent.memory.blocks) == len(request.memory_blocks) + len(request.block_ids), (
            f"Memory blocks count mismatch: {len(agent.memory.blocks)} != {len(request.memory_blocks) + len(request.block_ids)}"
        )
        memory_block_values = {block.value for block in agent.memory.blocks}
        expected_block_values = {block.value for block in request.memory_blocks}
        assert expected_block_values.issubset(memory_block_values), (
            f"Memory blocks mismatch: {expected_block_values} not in {memory_block_values}"
        )

    # Assert tools
    assert len(agent.tools) == len(request.tool_ids), f"Tools count mismatch: {len(agent.tools)} != {len(request.tool_ids)}"
    assert {tool.id for tool in agent.tools} == set(request.tool_ids), (
        f"Tools mismatch: {set(tool.id for tool in agent.tools)} != {set(request.tool_ids)}"
    )

    # Assert sources
    assert len(agent.sources) == len(request.source_ids), f"Sources count mismatch: {len(agent.sources)} != {len(request.source_ids)}"
    assert {source.id for source in agent.sources} == set(request.source_ids), (
        f"Sources mismatch: {set(source.id for source in agent.sources)} != {set(request.source_ids)}"
    )

    # Assert tags
    assert set(agent.tags) == set(request.tags), f"Tags mismatch: {set(agent.tags)} != {set(request.tags)}"

    # Assert tool rules (leftover TOOLRULES/AGENTTOOLRULES debug prints removed)
    if request.tool_rules:
        assert len(agent.tool_rules) == len(request.tool_rules), (
            f"Tool rules count mismatch: {len(agent.tool_rules)} != {len(request.tool_rules)}"
        )
        assert all(any(rule.tool_name == req_rule.tool_name for rule in agent.tool_rules) for req_rule in request.tool_rules), (
            f"Tool rules mismatch: {agent.tool_rules} != {request.tool_rules}"
        )

    # Assert message_buffer_autoclear
    if request.message_buffer_autoclear is not None:
        assert agent.message_buffer_autoclear == request.message_buffer_autoclear
def validate_context_window_overview(
    agent_state: AgentState, overview: ContextWindowOverview, attached_file: Optional[FileAgent] = None
) -> None:
    """Validate common sense assertions for ContextWindowOverview

    Checks size bounds, non-negative token counts, that the token components
    sum to the reported current size, message-count consistency, and (when
    *attached_file* is given) that the file is rendered into core memory.
    Raises AssertionError on the first violated invariant.
    Note: *agent_state* is not referenced in the body; kept for signature stability.
    """
    # 1. Current context size should not exceed maximum
    assert overview.context_window_size_current <= overview.context_window_size_max, (
        f"Current context size ({overview.context_window_size_current}) exceeds maximum ({overview.context_window_size_max})"
    )
    # 2. All token counts should be non-negative
    assert overview.num_tokens_system >= 0, "System token count cannot be negative"
    assert overview.num_tokens_core_memory >= 0, "Core memory token count cannot be negative"
    assert overview.num_tokens_external_memory_summary >= 0, "External memory summary token count cannot be negative"
    assert overview.num_tokens_summary_memory >= 0, "Summary memory token count cannot be negative"
    assert overview.num_tokens_messages >= 0, "Messages token count cannot be negative"
    assert overview.num_tokens_functions_definitions >= 0, "Functions definitions token count cannot be negative"
    # 3. Token components should sum to total
    expected_total = (
        overview.num_tokens_system
        + overview.num_tokens_core_memory
        + overview.num_tokens_external_memory_summary
        + overview.num_tokens_summary_memory
        + overview.num_tokens_messages
        + overview.num_tokens_functions_definitions
    )
    assert overview.context_window_size_current == expected_total, (
        f"Token sum ({expected_total}) doesn't match current size ({overview.context_window_size_current})"
    )
    # 4. Message count should match messages list length
    assert len(overview.messages) == overview.num_messages, (
        f"Messages list length ({len(overview.messages)}) doesn't match num_messages ({overview.num_messages})"
    )
    # 5. If summary_memory is None, its token count should be 0
    if overview.summary_memory is None:
        assert overview.num_tokens_summary_memory == 0, "Summary memory is None but has non-zero token count"
    # 7. External memory summary consistency
    # NOTE(review): checks 7-10 require strictly-positive counts, i.e. they
    # assume these sections are always present in the agents under test.
    assert overview.num_tokens_external_memory_summary > 0, "External memory summary exists but has zero token count"
    # 8. System prompt consistency
    assert overview.num_tokens_system > 0, "System prompt exists but has zero token count"
    # 9. Core memory consistency
    assert overview.num_tokens_core_memory > 0, "Core memory exists but has zero token count"
    # 10. Functions definitions consistency
    assert overview.num_tokens_functions_definitions > 0, "Functions definitions exist but have zero token count"
    assert len(overview.functions_definitions) > 0, "Functions definitions list should not be empty"
    # 11. Memory counts should be non-negative
    assert overview.num_archival_memory >= 0, "Archival memory count cannot be negative"
    assert overview.num_recall_memory >= 0, "Recall memory count cannot be negative"
    # 12. Context window max should be positive
    assert overview.context_window_size_max > 0, "Maximum context window size must be positive"
    # 13. If there are messages, check basic structure
    # At least one message should be system message (typical pattern)
    has_system_message = any(msg.role == MessageRole.system for msg in overview.messages)
    # This is a soft assertion - log warning instead of failing
    if not has_system_message:
        print("Warning: No system message found in messages list")
    # Average tokens per message should be reasonable (typically > 0)
    # NOTE(review): divides by num_messages — assumes at least one message;
    # would raise ZeroDivisionError on an empty overview. Confirm callers.
    avg_tokens_per_message = overview.num_tokens_messages / overview.num_messages
    assert avg_tokens_per_message >= 0, "Average tokens per message should be non-negative"
    # 16. Check attached file is visible in the rendered core memory markup
    if attached_file:
        assert attached_file.visible_content in overview.core_memory, "File must be attached in core memory"
        assert '<file status="open"' in overview.core_memory
        assert "</file>" in overview.core_memory
        assert "max_files_open" in overview.core_memory, "Max files should be set in core memory"
        assert "current_files_open" in overview.core_memory, "Current files should be set in core memory"
        # Check for tools
        assert overview.num_tokens_functions_definitions > 0
        assert len(overview.functions_definitions) > 0
# Changed this from server_url to client since client may be authenticated or not
def upload_test_agentfile_from_disk(client: Letta, filename: str) -> ImportedAgentsResponse:
    """
    Upload a given .af file to live FastAPI server.
    """
    # Agent files live in tests/test_agent_files, a sibling of this helpers dir.
    helpers_dir = os.path.dirname(__file__)
    agent_files_dir = helpers_dir.removesuffix("/helpers") + "/test_agent_files"
    with open(os.path.join(agent_files_dir, filename), "rb") as f:
        return client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False)
async def upload_test_agentfile_from_disk_async(client: AsyncLetta, filename: str) -> ImportedAgentsResponse:
    """
    Upload a given .af file to live FastAPI server.
    """
    # Agent files live in tests/test_agent_files, a sibling of this helpers dir.
    helpers_dir = os.path.dirname(__file__)
    agent_files_dir = helpers_dir.removesuffix("/helpers") + "/test_agent_files"
    with open(os.path.join(agent_files_dir, filename), "rb") as f:
        return await client.agents.import_file(file=f, append_copy_suffix=True, override_existing_tools=False)
def upload_file_and_wait(
client: Letta,
source_id: str,
file_path: str,
name: Optional[str] = None,
max_wait: int = 60,
duplicate_handling: Optional[str] = None,
):
"""Helper function to upload a file and wait for processing to complete"""
with open(file_path, "rb") as f:
if duplicate_handling:
file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name)
else:
file_metadata = client.sources.files.upload(source_id=source_id, file=f, name=name)
# wait for the file to be processed
start_time = time.time()
while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
if time.time() - start_time > max_wait:
raise TimeoutError(f"File processing timed out after {max_wait} seconds")
time.sleep(1)
file_metadata = client.sources.get_file_metadata(source_id=source_id, file_id=file_metadata.id)
print("Waiting for file processing to complete...", file_metadata.processing_status)
if file_metadata.processing_status == "error":
raise RuntimeError(f"File processing failed: {file_metadata.error_message}")
return file_metadata
def upload_file_and_wait_list_files(
client: Letta,
source_id: str,
file_path: str,
name: Optional[str] = None,
max_wait: int = 60,
duplicate_handling: Optional[str] = None,
):
"""Helper function to upload a file and wait for processing using list_files instead of get_file_metadata"""
with open(file_path, "rb") as f:
if duplicate_handling:
file_metadata = client.sources.files.upload(source_id=source_id, file=f, duplicate_handling=duplicate_handling, name=name)
else:
file_metadata = client.sources.files.upload(source_id=source_id, file=f, name=name)
# wait for the file to be processed using list_files
start_time = time.time()
while file_metadata.processing_status != "completed" and file_metadata.processing_status != "error":
if time.time() - start_time > max_wait:
raise TimeoutError(f"File processing timed out after {max_wait} seconds")
time.sleep(1)
# use list_files to get all files and find our specific file
files = client.sources.files.list(source_id=source_id, limit=100)
# find the file with matching id
for file in files:
if file.id == file_metadata.id:
file_metadata = file
break
else:
raise RuntimeError(f"File {file_metadata.id} not found in source files list")
print("Waiting for file processing to complete (via list_files)...", file_metadata.processing_status)
if file_metadata.processing_status == "error":
raise RuntimeError(f"File processing failed: {file_metadata.error_message}")
return file_metadata