diff --git a/letta/schemas/agent_file.py b/letta/schemas/agent_file.py index bb4de9a3..ae669b92 100644 --- a/letta/schemas/agent_file.py +++ b/letta/schemas/agent_file.py @@ -217,3 +217,6 @@ class AgentFileSchema(BaseModel): sources: List[SourceSchema] = Field(..., description="List of sources in this agent file") tools: List[ToolSchema] = Field(..., description="List of tools in this agent file") # mcp_servers: List[MCPServerSchema] = Field(..., description="List of MCP servers in this agent file") + metadata: Dict[str, str] = Field( + default_factory=dict, description="Metadata for this agent file, including revision_id and other export information." + ) diff --git a/letta/services/agent_file_manager.py b/letta/services/agent_file_manager.py index dcbdd29d..d58c3d41 100644 --- a/letta/services/agent_file_manager.py +++ b/letta/services/agent_file_manager.py @@ -27,6 +27,7 @@ from letta.services.mcp_manager import MCPManager from letta.services.message_manager import MessageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager +from letta.utils import get_latest_alembic_revision logger = get_logger(__name__) @@ -233,6 +234,7 @@ class AgentFileManager: sources=[], # TODO: Extract and convert sources tools=tool_schemas, # mcp_servers=[], # TODO: Extract and convert MCP servers + metadata={"revision_id": await get_latest_alembic_revision()}, ) except Exception as e: diff --git a/letta/utils.py b/letta/utils.py index 4ddf0a61..91218cd4 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -23,6 +23,7 @@ from urllib.parse import urljoin, urlparse import demjson3 as demjson import tiktoken from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename +from sqlalchemy import text import letta from letta.constants import ( @@ -35,8 +36,12 @@ from letta.constants import ( TOOL_CALL_ID_MAX_LEN, ) from letta.helpers.json_helpers import json_dumps, json_loads +from letta.log import get_logger from letta.schemas.openai.chat_completion_response import ChatCompletionResponse +logger = get_logger(__name__) + + DEBUG = False if "LOG_LEVEL" in os.environ: if os.environ["LOG_LEVEL"] == "DEBUG": @@ -1182,3 +1187,22 @@ class NullCancellationSignal(CancellationSignal): async def check_and_raise_if_cancelled(self): pass + + +async def get_latest_alembic_revision() -> str: + """Get the current alembic revision ID from the alembic_version table.""" + from letta.server.db import db_registry + + try: + async with db_registry.async_session() as session: + result = await session.execute(text("SELECT version_num FROM alembic_version")) + row = result.fetchone() + + if row: + return row[0] + else: + return "unknown" + + except Exception as e: + logger.error(f"Error getting latest alembic revision: {e}") + return "unknown" diff --git a/tests/test_agent_serialization_v2.py b/tests/test_agent_serialization_v2.py index b213996d..d36098fd 100644 --- a/tests/test_agent_serialization_v2.py +++ b/tests/test_agent_serialization_v2.py @@ -525,6 +525,11 @@ class TestAgentFileExport: assert len(agent_file.tools) > 0 # Should include base tools + weather tool assert len(agent_file.blocks) > 0 # Should include memory blocks + test block + # Validate revision_id is automatically set in metadata + assert agent_file.metadata.get("revision_id") is not None + assert agent_file.metadata.get("revision_id") != "unknown" + assert len(agent_file.metadata.get("revision_id")) > 0 + # Validate ID formats assert validate_id_format(agent_file) @@ -620,6 +625,26 @@ class TestAgentFileExport: with pytest.raises(AgentFileExportError): # Should raise AgentFileExportError for non-existent agent await agent_file_manager.export(["non-existent-id"], default_user) + async def test_revision_id_automatic_setting(self, agent_file_manager, test_agent, default_user): + """Test that revision_id is automatically set to the latest alembic revision.""" + # Export the agent + agent_file = await agent_file_manager.export([test_agent.id], default_user) + + # Get the expected revision ID from the function + from letta.utils import get_latest_alembic_revision + + expected_revision = await get_latest_alembic_revision() + + # Validate that the revision_id matches the latest alembic revision + assert agent_file.metadata.get("revision_id") == expected_revision + + # Validate that it's not the fallback "unknown" value + assert agent_file.metadata.get("revision_id") != "unknown" + + # Validate that it looks like a valid revision ID (12 hex characters) + assert len(agent_file.metadata.get("revision_id")) == 12 + assert all(c in "0123456789abcdef" for c in agent_file.metadata.get("revision_id")) + class TestAgentFileImport: """Tests for agent file import functionality.""" @@ -715,8 +740,14 @@ class TestAgentFileImport: async def test_import_validation_errors(self, agent_file_manager, other_user): """Test import validation catches errors.""" + # Get current revision for test + from letta.utils import get_latest_alembic_revision + + current_revision = await get_latest_alembic_revision() + # Create invalid agent file with duplicate IDs invalid_agent_file = AgentFileSchema( + metadata={"revision_id": current_revision}, agents=[ AgentSchema(id="agent-0", name="agent1"), AgentSchema(id="agent-0", name="agent2"), # Duplicate ID @@ -855,8 +886,12 @@ class TestAgentFileValidation: def test_agent_file_schema_validation(self, test_agent): """Test AgentFileSchema validation.""" + # Use a dummy revision for this test since we can't await in sync test + current_revision = "495f3f474131" # Use a known valid revision format + # Valid schema valid_schema = AgentFileSchema( + metadata={"revision_id": current_revision}, agents=[AgentSchema(id="agent-0", name="test")], groups=[], blocks=[], @@ -868,6 +903,7 @@ class TestAgentFileValidation: # Should not raise assert valid_schema.agents[0].id == "agent-0" + assert valid_schema.metadata.get("revision_id") == current_revision def test_message_schema_conversion(self, test_agent, server, default_user): """Test MessageSchema.from_message conversion.""" @@ -887,8 +923,12 @@ class TestAgentFileValidation: def test_id_format_validation(self): """Test ID format validation helper.""" + # Use a dummy revision for this test since we can't await in sync test + current_revision = "495f3f474131" # Use a known valid revision format + # Valid schema valid_schema = AgentFileSchema( + metadata={"revision_id": current_revision}, agents=[AgentSchema(id="agent-0", name="test")], groups=[], blocks=[BlockSchema(id="block-0", label="test", value="test")], @@ -909,6 +949,7 @@ class TestAgentFileValidation: # Invalid schema invalid_schema = AgentFileSchema( + metadata={"revision_id": current_revision}, agents=[AgentSchema(id="invalid-id-format", name="test")], groups=[], blocks=[], diff --git a/tests/test_utils.py b/tests/test_utils.py index 5da46f61..f0ef20a9 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,3 +1,5 @@ +import asyncio + import pytest from letta.constants import MAX_FILENAME_LENGTH @@ -508,3 +510,49 @@ def test_line_chunker_only_start_parameter(): # Test invalid start only with pytest.raises(ValueError, match="File test.py has only 3 lines, but requested offset 4 is out of range"): chunker.chunk_text(file, start=3, validate_range=True) + + +# ---------------------- Alembic Revision TESTS ---------------------- # + + +@pytest.fixture(scope="module") +def event_loop(): + """ + Create an event loop for the entire test session. + Ensures all async tasks use the same loop, avoiding cross-loop errors. + """ + loop = asyncio.new_event_loop() + yield loop + loop.close() + + +@pytest.mark.asyncio +async def test_get_latest_alembic_revision(event_loop): + """Test that get_latest_alembic_revision returns a valid revision ID from the database.""" + from letta.utils import get_latest_alembic_revision + + # Get the revision ID + revision_id = await get_latest_alembic_revision() + + # Validate that it's not the fallback "unknown" value + assert revision_id != "unknown" + + # Validate that it looks like a valid revision ID (12 hex characters) + assert len(revision_id) == 12 + assert all(c in "0123456789abcdef" for c in revision_id) + + # Validate that it's a string + assert isinstance(revision_id, str) + + +@pytest.mark.asyncio +async def test_get_latest_alembic_revision_consistency(event_loop): + """Test that get_latest_alembic_revision returns the same value on multiple calls.""" + from letta.utils import get_latest_alembic_revision + + # Get the revision ID twice + revision_id1 = await get_latest_alembic_revision() + revision_id2 = await get_latest_alembic_revision() + + # They should be identical + assert revision_id1 == revision_id2