diff --git a/letta/constants.py b/letta/constants.py index 6f793a25..ff77ff69 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -328,15 +328,15 @@ READ_ONLY_BLOCK_EDIT_ERROR = f"{ERROR_MESSAGE_PREFIX} This block is read-only an MESSAGE_SUMMARY_REQUEST_ACK = "Understood, I will respond with a summary of the message (and only the summary, nothing else) once I receive the conversation history. I'm ready." # Maximum length of an error message -MAX_ERROR_MESSAGE_CHAR_LIMIT = 500 +MAX_ERROR_MESSAGE_CHAR_LIMIT = 1000 # Default memory limits -CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 5000 -CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 5000 -CORE_MEMORY_BLOCK_CHAR_LIMIT: int = 5000 +CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 20000 +CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 20000 +CORE_MEMORY_BLOCK_CHAR_LIMIT: int = 20000 # Function return limits -FUNCTION_RETURN_CHAR_LIMIT = 6000 # ~300 words +FUNCTION_RETURN_CHAR_LIMIT = 50000 # ~2500 words BASE_FUNCTION_RETURN_CHAR_LIMIT = 1000000 # very high (we rely on implementation) FILE_IS_TRUNCATED_WARNING = "# NOTE: This block is truncated, use functions to view the full content." 
@@ -394,5 +394,7 @@ PINECONE_THROTTLE_DELAY = 0.75 # seconds base delay between batches WEB_SEARCH_MODEL_ENV_VAR_NAME = "LETTA_BUILTIN_WEBSEARCH_OPENAI_MODEL_NAME" WEB_SEARCH_MODEL_ENV_VAR_DEFAULT_VALUE = "gpt-4.1-mini-2025-04-14" -# Excluded providers from base tool rules -EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES = {"anthropic", "openai", "google_ai", "google_vertex"} +# Excluded model keywords from base tool rules +EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES = ["claude-4-sonnet", "claude-3-5-sonnet", "gpt-5", "gemini-2.5-pro"] +# But include models with these keywords in base tool rules (overrides exclusion) +INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES = ["-mini"] diff --git a/letta/schemas/providers/openai.py b/letta/schemas/providers/openai.py index 24c5d029..ed9d5988 100644 --- a/letta/schemas/providers/openai.py +++ b/letta/schemas/providers/openai.py @@ -12,7 +12,7 @@ from letta.schemas.providers.base import Provider logger = get_logger(__name__) ALLOWED_PREFIXES = {"gpt-4", "gpt-5", "o1", "o3", "o4"} -DISALLOWED_KEYWORDS = {"transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro"} +DISALLOWED_KEYWORDS = {"transcribe", "search", "realtime", "tts", "audio", "computer", "o1-mini", "o1-preview", "o1-pro", "chat"} DEFAULT_EMBEDDING_BATCH_SIZE = 1024 diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index e04f146b..d4c99dae 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -19,8 +19,9 @@ from letta.constants import ( DEFAULT_MAX_FILES_OPEN, DEFAULT_TIMEZONE, DEPRECATED_LETTA_TOOLS, - EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES, + EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES, FILES_TOOLS, + INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES, ) from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time @@ -117,6 +118,21 @@ class AgentManager: self.identity_manager = IdentityManager() self.file_agent_manager = FileAgentManager() + @staticmethod + 
def _should_exclude_model_from_base_tool_rules(model: str) -> bool: + """Check if a model should be excluded from base tool rules based on model keywords.""" + # First check if model contains any include keywords (overrides exclusion) + for include_keyword in INCLUDE_MODEL_KEYWORDS_BASE_TOOL_RULES: + if include_keyword in model: + return False + + # Then check if model contains any exclude keywords + for exclude_keyword in EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES: + if exclude_keyword in model: + return True + + return False + @staticmethod def _resolve_tools(session, names: Set[str], ids: Set[str], org_id: str) -> Tuple[Dict[str, str], Dict[str, str]]: """ @@ -334,16 +350,16 @@ class AgentManager: tool_rules = list(agent_create.tool_rules or []) - # Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True + # Override include_base_tool_rules to False if model matches exclusion keywords and include_base_tool_rules is not explicitly set to True if ( ( - agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES + self._should_exclude_model_from_base_tool_rules(agent_create.llm_config.model) and agent_create.include_base_tool_rules is None ) and agent_create.agent_type != AgentType.sleeptime_agent ) or agent_create.include_base_tool_rules is False: agent_create.include_base_tool_rules = False - logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}") + logger.info(f"Overriding include_base_tool_rules to False for model: {agent_create.llm_config.model}") else: agent_create.include_base_tool_rules = True @@ -543,16 +559,16 @@ class AgentManager: tool_names = set(name_to_id.keys()) # now canonical tool_rules = list(agent_create.tool_rules or []) - # Override include_base_tool_rules to False if provider is not in excluded set and include_base_tool_rules is not explicitly set to True + # Override 
include_base_tool_rules to False if model matches exclusion keywords and include_base_tool_rules is not explicitly set to True if ( ( - agent_create.llm_config.model_endpoint_type in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES + self._should_exclude_model_from_base_tool_rules(agent_create.llm_config.model) and agent_create.include_base_tool_rules is None ) and agent_create.agent_type != AgentType.sleeptime_agent ) or agent_create.include_base_tool_rules is False: agent_create.include_base_tool_rules = False - logger.info(f"Overriding include_base_tool_rules to False for provider: {agent_create.llm_config.model_endpoint_type}") + logger.info(f"Overriding include_base_tool_rules to False for model: {agent_create.llm_config.model}") else: agent_create.include_base_tool_rules = True diff --git a/tests/sdk/blocks_test.py b/tests/sdk/blocks_test.py index 3d67b1be..301f2fd6 100644 --- a/tests/sdk/blocks_test.py +++ b/tests/sdk/blocks_test.py @@ -1,9 +1,11 @@ from conftest import create_test_module from letta_client.errors import UnprocessableEntityError +from letta.constants import CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT + BLOCKS_CREATE_PARAMS = [ - ("human_block", {"label": "human", "value": "test"}, {"limit": 5000}, None), - ("persona_block", {"label": "persona", "value": "test1"}, {"limit": 5000}, None), + ("human_block", {"label": "human", "value": "test"}, {"limit": CORE_MEMORY_HUMAN_CHAR_LIMIT}, None), + ("persona_block", {"label": "persona", "value": "test1"}, {"limit": CORE_MEMORY_PERSONA_CHAR_LIMIT}, None), ] BLOCKS_MODIFY_PARAMS = [ diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py deleted file mode 100644 index 5a88de87..00000000 --- a/tests/test_base_functions.py +++ /dev/null @@ -1,231 +0,0 @@ -import asyncio -import os -import threading - -import pytest -from dotenv import load_dotenv -from letta_client import Letta - -import letta.functions.function_sets.base as base_functions -from letta.config import LettaConfig 
-from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.llm_config import LLMConfig -from letta.schemas.message import MessageCreate -from letta.server.server import SyncServer -from tests.test_tool_schema_parsing_files.expected_base_tool_schemas import ( - get_finish_rethinking_memory_schema, - get_rethink_user_memory_schema, - get_search_memory_schema, - get_store_memories_schema, -) -from tests.utils import wait_for_server - - -def _run_server(): - """Starts the Letta server in a background thread.""" - load_dotenv() - from letta.server.rest_api.app import start_server - - start_server(debug=True) - - -@pytest.fixture(scope="module") -def server(): - """ - Creates a SyncServer instance for testing. - - Loads and saves config to ensure proper initialization. - """ - config = LettaConfig.load() - - config.save() - - server = SyncServer(init_with_default_org_and_user=True) - yield server - - -@pytest.fixture(scope="session") -def server_url(): - """Ensures a server is running and returns its base URL.""" - url = os.getenv("LETTA_SERVER_URL", "http://localhost:8283") - - if not os.getenv("LETTA_SERVER_URL"): - thread = threading.Thread(target=_run_server, daemon=True) - thread.start() - wait_for_server(url) - - return url - - -@pytest.fixture(scope="session") -def letta_client(server_url): - """Creates a REST client for testing.""" - client = Letta(base_url=server_url) - client.tools.upsert_base_tools() - return client - - -@pytest.fixture(scope="function") -def agent_obj(letta_client, server): - """Create a test agent that we can call functions on""" - send_message_to_agent_and_wait_for_reply_tool_id = letta_client.tools.list(name="send_message_to_agent_and_wait_for_reply")[0].id - agent_state = letta_client.agents.create( - tool_ids=[send_message_to_agent_and_wait_for_reply_tool_id], - include_base_tools=True, - memory_blocks=[ - { - "label": "human", - "value": "Name: Matt", - }, - { - "label": "persona", - "value": "Friendly agent", - }, 
- ], - llm_config=LLMConfig.default_config(model_name="gpt-4o-mini"), - embedding_config=EmbeddingConfig.default_config(provider="openai"), - ) - actor = server.user_manager.get_user_or_default() - agent_obj = server.load_agent(agent_id=agent_state.id, actor=actor) - yield agent_obj - - -def query_in_search_results(search_results, query): - for result in search_results: - if query.lower() in result["content"].lower(): - return True - return False - - -@pytest.mark.asyncio -async def test_archival(agent_obj): - """Test archival memory functions comprehensively.""" - # Test 1: Basic insertion and retrieval - await base_functions.archival_memory_insert(agent_obj, "The cat sleeps on the mat") - await asyncio.sleep(0.1) # Small delay to ensure session cleanup - await base_functions.archival_memory_insert(agent_obj, "The dog plays in the park") - await asyncio.sleep(0.1) - await base_functions.archival_memory_insert(agent_obj, "Python is a programming language") - await asyncio.sleep(0.1) - - # Test exact text search - results, _ = await base_functions.archival_memory_search(agent_obj, "cat") - assert query_in_search_results(results, "cat") - await asyncio.sleep(0.1) - - # Test semantic search (should return animal-related content) - results, _ = await base_functions.archival_memory_search(agent_obj, "animal pets") - assert query_in_search_results(results, "cat") or query_in_search_results(results, "dog") - await asyncio.sleep(0.1) - - # Test unrelated search (should not return animal content) - results, _ = await base_functions.archival_memory_search(agent_obj, "programming computers") - assert query_in_search_results(results, "python") - await asyncio.sleep(0.1) - - # Test 2: Test pagination - # Insert more items to test pagination - for i in range(10): - await base_functions.archival_memory_insert(agent_obj, f"Test passage number {i}") - await asyncio.sleep(0.05) # Shorter delay for bulk operations - - # Get first page - page0_results, next_page = await 
base_functions.archival_memory_search(agent_obj, "Test passage", page=0) - await asyncio.sleep(0.1) - # Get second page - page1_results, _ = await base_functions.archival_memory_search(agent_obj, "Test passage", page=1, start=next_page) - await asyncio.sleep(0.1) - - assert page0_results != page1_results - assert query_in_search_results(page0_results, "Test passage") - assert query_in_search_results(page1_results, "Test passage") - - # Test 3: Test complex text patterns - await base_functions.archival_memory_insert(agent_obj, "Important meeting on 2024-01-15 with John") - await base_functions.archival_memory_insert(agent_obj, "Follow-up meeting scheduled for next week") - await base_functions.archival_memory_insert(agent_obj, "Project deadline is approaching") - - # Search for meeting-related content - results, _ = await base_functions.archival_memory_search(agent_obj, "meeting schedule") - assert query_in_search_results(results, "meeting") - assert query_in_search_results(results, "2024-01-15") or query_in_search_results(results, "next week") - - # Test 4: Test error handling - # Test invalid page number - try: - await base_functions.archival_memory_search(agent_obj, "test", page="invalid") - assert False, "Should have raised ValueError" - except ValueError: - pass - - -def test_recall(server, agent_obj, default_user): - """Test that an agent can recall messages using a keyword via conversation search.""" - keyword = "banana" - "".join(reversed(keyword)) - - # Send messages - for msg in ["hello", keyword, "tell me a fun fact"]: - server.send_messages( - actor=default_user, - agent_id=agent_obj.agent_state.id, - input_messages=[MessageCreate(role="user", content=msg)], - ) - - # Search memory - result = base_functions.conversation_search(agent_obj, "banana") - assert keyword in result - - -def test_get_rethink_user_memory_parsing(letta_client): - tool = letta_client.tools.list(name="rethink_user_memory")[0] - json_schema = tool.json_schema - # Remove 
`request_heartbeat` from properties - json_schema["parameters"]["properties"].pop("request_heartbeat", None) - - # Remove it from the required list if present - required = json_schema["parameters"].get("required", []) - if "request_heartbeat" in required: - required.remove("request_heartbeat") - - assert json_schema == get_rethink_user_memory_schema() - - -def test_get_finish_rethinking_memory_parsing(letta_client): - tool = letta_client.tools.list(name="finish_rethinking_memory")[0] - json_schema = tool.json_schema - # Remove `request_heartbeat` from properties - json_schema["parameters"]["properties"].pop("request_heartbeat", None) - - # Remove it from the required list if present - required = json_schema["parameters"].get("required", []) - if "request_heartbeat" in required: - required.remove("request_heartbeat") - - assert json_schema == get_finish_rethinking_memory_schema() - - -def test_store_memories_parsing(letta_client): - tool = letta_client.tools.list(name="store_memories")[0] - json_schema = tool.json_schema - # Remove `request_heartbeat` from properties - json_schema["parameters"]["properties"].pop("request_heartbeat", None) - - # Remove it from the required list if present - required = json_schema["parameters"].get("required", []) - if "request_heartbeat" in required: - required.remove("request_heartbeat") - assert json_schema == get_store_memories_schema() - - -def test_search_memory_parsing(letta_client): - tool = letta_client.tools.list(name="search_memory")[0] - json_schema = tool.json_schema - # Remove `request_heartbeat` from properties - json_schema["parameters"]["properties"].pop("request_heartbeat", None) - - # Remove it from the required list if present - required = json_schema["parameters"].get("required", []) - if "request_heartbeat" in required: - required.remove("request_heartbeat") - assert json_schema == get_search_memory_schema() diff --git a/tests/test_managers.py b/tests/test_managers.py index 319d9ce8..866af301 100644 --- 
a/tests/test_managers.py +++ b/tests/test_managers.py @@ -848,7 +848,7 @@ async def test_create_agent_base_tool_rules_non_excluded_providers(server: SyncS memory_blocks=memory_blocks, llm_config=LLMConfig( model="llama-3.1-8b-instruct", - model_endpoint_type="together", # Not in EXCLUDED_PROVIDERS_FROM_BASE_TOOL_RULES + model_endpoint_type="together", # Model doesn't match EXCLUDE_MODEL_KEYWORDS_FROM_BASE_TOOL_RULES model_endpoint="https://api.together.xyz", context_window=8192, ), diff --git a/tests/test_memory.py b/tests/test_memory.py index 87e02a0e..c4675b94 100644 --- a/tests/test_memory.py +++ b/tests/test_memory.py @@ -19,10 +19,10 @@ def test_create_chat_memory(): def test_memory_limit_validation(sample_memory: Memory): """Test exceeding memory limit""" with pytest.raises(ValueError): - ChatMemory(persona="x " * 10000, human="y " * 10000) + ChatMemory(persona="x " * 50000, human="y " * 50000) with pytest.raises(ValueError): - sample_memory.get_block("persona").value = "x " * 10000 + sample_memory.get_block("persona").value = "x " * 50000 def test_memory_jinja2_set_template(sample_memory: Memory):