diff --git a/tests/managers/test_agent_manager.py b/tests/managers/test_agent_manager.py index 94c7cb92..a864a4a8 100644 --- a/tests/managers/test_agent_manager.py +++ b/tests/managers/test_agent_manager.py @@ -1020,3 +1020,353 @@ async def test_agent_environment_variables_update_encryption(server: SyncServer, assert decrypted == "new-value-xyz" else: pytest.fail(f"Unexpected key: {env_var.key}") + + +@pytest.mark.asyncio +async def test_agent_state_schema_unchanged(server: SyncServer): + """ + Test that the AgentState pydantic schema structure has not changed. + This test validates all fields including nested pydantic objects to ensure + the schema remains stable across changes. + """ + from letta.schemas.agent import AgentState, AgentType + from letta.schemas.block import Block + from letta.schemas.embedding_config import EmbeddingConfig + from letta.schemas.environment_variables import AgentEnvironmentVariable + from letta.schemas.group import Group + from letta.schemas.llm_config import LLMConfig + from letta.schemas.memory import Memory + from letta.schemas.response_format import ResponseFormatUnion + from letta.schemas.source import Source + from letta.schemas.tool import Tool + from letta.schemas.tool_rule import ToolRule + + # Define the expected schema structure + expected_schema = { + # Core identification + "id": str, + "name": str, + # Tool rules + "tool_rules": (list, type(None)), + # In-context memory + "message_ids": (list, type(None)), + # System prompt + "system": str, + # Agent configuration + "agent_type": AgentType, + # LLM information + "llm_config": LLMConfig, + "embedding_config": EmbeddingConfig, + "response_format": (ResponseFormatUnion, type(None)), + # State fields + "description": (str, type(None)), + "metadata": (dict, type(None)), + # Memory and tools + "memory": Memory, + "tools": list, + "sources": list, + "tags": list, + "tool_exec_environment_variables": list, + "secrets": list, + # Project and template fields + "project_id": (str, type(None)), + "template_id": (str, type(None)), + "base_template_id": (str, type(None)), + "deployment_id": (str, type(None)), + "entity_id": (str, type(None)), + "identity_ids": list, + # Advanced configuration + "message_buffer_autoclear": bool, + "enable_sleeptime": (bool, type(None)), + # Multi-agent + "multi_agent_group": (Group, type(None)), + # Run metrics + "last_run_completion": (datetime, type(None)), + "last_run_duration_ms": (int, type(None)), + # Timezone + "timezone": (str, type(None)), + # File controls + "max_files_open": (int, type(None)), + "per_file_view_window_char_limit": (int, type(None)), + # Indexing controls + "hidden": (bool, type(None)), + # Metadata fields (from OrmMetadataBase) + "created_by_id": (str, type(None)), + "last_updated_by_id": (str, type(None)), + "created_at": (datetime, type(None)), + "updated_at": (datetime, type(None)), + } + + # Get the actual schema fields from AgentState + agent_state_fields = AgentState.model_fields + actual_field_names = set(agent_state_fields.keys()) + expected_field_names = set(expected_schema.keys()) + + # Check for added fields + added_fields = actual_field_names - expected_field_names + if added_fields: + pytest.fail( + f"New fields detected in AgentState schema: {sorted(added_fields)}. " + "This test must be updated to include these fields, and the schema change must be intentional." + ) + + # Check for removed fields + removed_fields = expected_field_names - actual_field_names + if removed_fields: + pytest.fail( + f"Fields removed from AgentState schema: {sorted(removed_fields)}. " + "This test must be updated to remove these fields, and the schema change must be intentional." + ) + + # Validate field types + import typing + + for field_name, expected_type in expected_schema.items(): + field = agent_state_fields[field_name] + annotation = field.annotation + + # Helper function to check if annotation matches expected type + def check_type_match(annotation, expected): + origin = typing.get_origin(annotation) + args = typing.get_args(annotation) + + # Direct match + if annotation == expected: + return True + + # Handle list type (List[X] should match list) + if expected is list and origin is list: + return True + + # Handle dict type (Dict[X, Y] should match dict) + if expected is dict and origin is dict: + return True + + # Handle Optional types + if origin is typing.Union: + # Check if expected type is in the union + if expected in args: + return True + # Handle list case within Union (e.g., Union[List[X], None]) + if expected is list: + for arg in args: + if typing.get_origin(arg) is list: + return True + # Handle dict case within Union + if expected is dict: + for arg in args: + if typing.get_origin(arg) is dict: + return True + + return False + + # Handle tuple of expected types (Optional) + if isinstance(expected_type, tuple): + valid = any(check_type_match(annotation, exp_t) for exp_t in expected_type) + if not valid: + pytest.fail( + f"Field '{field_name}' type changed. Expected one of {expected_type}, " + f"but got {annotation}. Schema changes must be intentional." + ) + else: + # Single expected type + valid = check_type_match(annotation, expected_type) + if not valid: + pytest.fail( + f"Field '{field_name}' type changed. Expected {expected_type}, " + f"but got {annotation}. Schema changes must be intentional." + ) + + # Validate nested object schemas + # Memory schema + memory_fields = Memory.model_fields + expected_memory_fields = {"agent_type", "blocks", "file_blocks", "prompt_template"} + actual_memory_fields = set(memory_fields.keys()) + if actual_memory_fields != expected_memory_fields: + pytest.fail( + f"Memory schema changed. Expected fields: {expected_memory_fields}, " + f"Got: {actual_memory_fields}. Schema changes must be intentional." + ) + + # Block schema + block_fields = Block.model_fields + expected_block_fields = { + "id", + "value", + "limit", + "project_id", + "template_name", + "is_template", + "template_id", + "base_template_id", + "deployment_id", + "entity_id", + "preserve_on_migration", + "label", + "read_only", + "description", + "metadata", + "hidden", + "created_by_id", + "last_updated_by_id", + } + actual_block_fields = set(block_fields.keys()) + if actual_block_fields != expected_block_fields: + pytest.fail( + f"Block schema changed. Expected fields: {expected_block_fields}, " + f"Got: {actual_block_fields}. Schema changes must be intentional." + ) + + # Tool schema + tool_fields = Tool.model_fields + expected_tool_fields = { + "id", + "tool_type", + "description", + "source_type", + "name", + "tags", + "source_code", + "json_schema", + "args_json_schema", + "return_char_limit", + "pip_requirements", + "npm_requirements", + "default_requires_approval", + "enable_parallel_execution", + "created_by_id", + "last_updated_by_id", + "metadata_", + } + actual_tool_fields = set(tool_fields.keys()) + if actual_tool_fields != expected_tool_fields: + pytest.fail( + f"Tool schema changed. Expected fields: {expected_tool_fields}, Got: {actual_tool_fields}. Schema changes must be intentional." + ) + + # Source schema + source_fields = Source.model_fields + expected_source_fields = { + "id", + "name", + "description", + "instructions", + "metadata", + "embedding_config", + "organization_id", + "vector_db_provider", + "created_by_id", + "last_updated_by_id", + "created_at", + "updated_at", + } + actual_source_fields = set(source_fields.keys()) + if actual_source_fields != expected_source_fields: + pytest.fail( + f"Source schema changed. Expected fields: {expected_source_fields}, " + f"Got: {actual_source_fields}. Schema changes must be intentional." + ) + + # LLMConfig schema + llm_config_fields = LLMConfig.model_fields + expected_llm_config_fields = { + "model", + "display_name", + "model_endpoint_type", + "model_endpoint", + "provider_name", + "provider_category", + "model_wrapper", + "context_window", + "put_inner_thoughts_in_kwargs", + "handle", + "temperature", + "max_tokens", + "enable_reasoner", + "reasoning_effort", + "max_reasoning_tokens", + "frequency_penalty", + "compatibility_type", + "verbosity", + "tier", + "parallel_tool_calls", + } + actual_llm_config_fields = set(llm_config_fields.keys()) + if actual_llm_config_fields != expected_llm_config_fields: + pytest.fail( + f"LLMConfig schema changed. Expected fields: {expected_llm_config_fields}, " + f"Got: {actual_llm_config_fields}. Schema changes must be intentional." + ) + + # EmbeddingConfig schema + embedding_config_fields = EmbeddingConfig.model_fields + expected_embedding_config_fields = { + "embedding_endpoint_type", + "embedding_endpoint", + "embedding_model", + "embedding_dim", + "embedding_chunk_size", + "handle", + "batch_size", + "azure_endpoint", + "azure_version", + "azure_deployment", + } + actual_embedding_config_fields = set(embedding_config_fields.keys()) + if actual_embedding_config_fields != expected_embedding_config_fields: + pytest.fail( + f"EmbeddingConfig schema changed. Expected fields: {expected_embedding_config_fields}, " + f"Got: {actual_embedding_config_fields}. Schema changes must be intentional." + ) + + # AgentEnvironmentVariable schema + agent_env_var_fields = AgentEnvironmentVariable.model_fields + expected_agent_env_var_fields = { + "id", + "key", + "value", + "description", + "organization_id", + "value_enc", + "agent_id", + # From OrmMetadataBase + "created_by_id", + "last_updated_by_id", + "created_at", + "updated_at", + } + actual_agent_env_var_fields = set(agent_env_var_fields.keys()) + if actual_agent_env_var_fields != expected_agent_env_var_fields: + pytest.fail( + f"AgentEnvironmentVariable schema changed. Expected fields: {expected_agent_env_var_fields}, " + f"Got: {actual_agent_env_var_fields}. Schema changes must be intentional." + ) + + # Group schema + group_fields = Group.model_fields + expected_group_fields = { + "id", + "manager_type", + "agent_ids", + "description", + "project_id", + "template_id", + "base_template_id", + "deployment_id", + "shared_block_ids", + "manager_agent_id", + "termination_token", + "max_turns", + "sleeptime_agent_frequency", + "turns_counter", + "last_processed_message_id", + "max_message_buffer_length", + "min_message_buffer_length", + "hidden", + } + actual_group_fields = set(group_fields.keys()) + if actual_group_fields != expected_group_fields: + pytest.fail( + f"Group schema changed. Expected fields: {expected_group_fields}, " + f"Got: {actual_group_fields}. Schema changes must be intentional." + )