From f5be308f54e066a4d46d233057bb32075a551bd5 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Mon, 8 Sep 2025 12:01:35 -0700 Subject: [PATCH] feat: Fix race condition with creating archival memories in parallel [LET-4205] (#4464) * Test archive manager and add race condition handling * Fix client tests * Remove bad test --- letta/services/archive_manager.py | 67 ++++++++-- tests/integration_test_turbopuffer.py | 17 --- tests/test_managers.py | 170 ++++++++++++++++++++++++++ 3 files changed, 229 insertions(+), 25 deletions(-) diff --git a/letta/services/archive_manager.py b/letta/services/archive_manager.py index 9f98721d..d18266c5 100644 --- a/letta/services/archive_manager.py +++ b/letta/services/archive_manager.py @@ -5,6 +5,7 @@ from sqlalchemy import select from letta.helpers.tpuf_client import should_use_tpuf from letta.log import get_logger from letta.orm import ArchivalPassage, Archive as ArchiveModel, ArchivesAgents +from letta.otel.tracing import trace_method from letta.schemas.archive import Archive as PydanticArchive from letta.schemas.enums import VectorDBProvider from letta.schemas.user import User as PydanticUser @@ -19,6 +20,7 @@ class ArchiveManager: """Manager class to handle business logic related to Archives.""" @enforce_types + @trace_method def create_archive( self, name: str, @@ -44,6 +46,7 @@ class ArchiveManager: raise @enforce_types + @trace_method async def create_archive_async( self, name: str, @@ -69,6 +72,7 @@ class ArchiveManager: raise @enforce_types + @trace_method async def get_archive_by_id_async( self, archive_id: str, @@ -84,6 +88,7 @@ class ArchiveManager: return archive.to_pydantic() @enforce_types + @trace_method def attach_agent_to_archive( self, agent_id: str, @@ -113,6 +118,7 @@ class ArchiveManager: session.commit() @enforce_types + @trace_method async def attach_agent_to_archive_async( self, agent_id: str, @@ -148,6 +154,7 @@ class ArchiveManager: await session.commit() @enforce_types + @trace_method async def get_default_archive_for_agent_async( self, agent_id: str, @@ -179,6 +186,24 @@ class ArchiveManager: return None @enforce_types + @trace_method + async def delete_archive_async( + self, + archive_id: str, + actor: PydanticUser = None, + ) -> None: + """Delete an archive permanently.""" + async with db_registry.async_session() as session: + archive_model = await ArchiveModel.read_async( + db_session=session, + identifier=archive_id, + actor=actor, + ) + await archive_model.hard_delete_async(session, actor=actor) + logger.info(f"Deleted archive {archive_id}") + + @enforce_types + @trace_method async def get_or_create_default_archive_for_agent_async( self, agent_id: str, @@ -187,6 +212,8 @@ class ArchiveManager: ) -> PydanticArchive: """Get the agent's default archive, creating one if it doesn't exist.""" # First check if agent has any archives + from sqlalchemy.exc import IntegrityError + from letta.services.agent_manager import AgentManager agent_manager = AgentManager() @@ -215,17 +242,38 @@ class ArchiveManager: actor=actor, ) - # Attach the agent to the archive as owner - await self.attach_agent_to_archive_async( - agent_id=agent_id, - archive_id=archive.id, - is_owner=True, - actor=actor, - ) + try: + # Attach the agent to the archive as owner + await self.attach_agent_to_archive_async( + agent_id=agent_id, + archive_id=archive.id, + is_owner=True, + actor=actor, + ) + return archive + except IntegrityError: + # race condition: another concurrent request already created and attached an archive + # clean up the orphaned archive we just created + logger.info(f"Race condition detected for agent {agent_id}, cleaning up orphaned archive {archive.id}") + await self.delete_archive_async(archive_id=archive.id, actor=actor) - return archive + # fetch the existing archive that was created by the concurrent request + archive_ids = await agent_manager.get_agent_archive_ids_async( + agent_id=agent_id, + actor=actor, + ) + if archive_ids: + archive = await self.get_archive_by_id_async( + archive_id=archive_ids[0], + actor=actor, + ) + return archive + else: + # this shouldn't happen, but if it does, re-raise + raise @enforce_types + @trace_method def get_or_create_default_archive_for_agent( self, agent_id: str, @@ -269,6 +317,7 @@ class ArchiveManager: return archive_model.to_pydantic() @enforce_types + @trace_method async def get_agents_for_archive_async( self, archive_id: str, @@ -280,6 +329,7 @@ class ArchiveManager: return [row[0] for row in result.fetchall()] @enforce_types + @trace_method async def get_agent_from_passage_async( self, passage_id: str, @@ -309,6 +359,7 @@ class ArchiveManager: return agent_ids[0] @enforce_types + @trace_method async def get_or_set_vector_db_namespace_async( self, archive_id: str, diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index b820f701..e1c4bf92 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -1933,23 +1933,6 @@ class TestNamespaceTracking: namespace2 = await server.archive_manager.get_or_set_vector_db_namespace_async(archive.id) assert namespace == namespace2 - @pytest.mark.asyncio - async def test_agent_namespace_tracking(self, server, default_user, sarah_agent, enable_message_embedding): - """Test that agent message namespaces are properly tracked in database""" - # Get namespace - should be generated and stored - namespace = await server.agent_manager.get_or_set_vector_db_namespace_async(default_user.organization_id) - - # Should have messages_org_ prefix and environment suffix - expected_prefix = "messages_" - assert namespace.startswith(expected_prefix) - assert default_user.organization_id in namespace - if settings.environment: - assert settings.environment.lower() in namespace - - # Call again - should return same namespace from database - namespace2 = await server.agent_manager.get_or_set_vector_db_namespace_async(default_user.organization_id) - assert namespace == namespace2 - @pytest.mark.asyncio async def test_namespace_consistency_with_tpuf_client(self, server, default_user, enable_turbopuffer): """Test that the namespace from managers matches what tpuf_client would generate""" diff --git a/tests/test_managers.py b/tests/test_managers.py index eb05514f..8a668189 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -4130,6 +4130,176 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S await server.passage_manager.delete_agent_passage_by_id_async(passage_id=passage.id, actor=default_user) +# ====================================================================================================================== +# Archive Manager Tests +# ====================================================================================================================== +@pytest.mark.asyncio +async def test_archive_manager_delete_archive_async(server: SyncServer, default_user): + """Test the delete_archive_async function.""" + archive = await server.archive_manager.create_archive_async( + name="test_archive_to_delete", description="This archive will be deleted", actor=default_user + ) + + retrieved_archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive.id, actor=default_user) + assert retrieved_archive.id == archive.id + + await server.archive_manager.delete_archive_async(archive_id=archive.id, actor=default_user) + + with pytest.raises(Exception): + await server.archive_manager.get_archive_by_id_async(archive_id=archive.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_archive_manager_get_agents_for_archive_async(server: SyncServer, default_user, sarah_agent): + """Test getting all agents that have access to an archive.""" + archive = await server.archive_manager.create_archive_async( + name="shared_archive", description="Archive shared by multiple agents", actor=default_user + ) + + agent2 = await server.agent_manager.create_agent_async( + agent_create=CreateAgent( + name="test_agent_2", + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, + ), + actor=default_user, + ) + + await server.archive_manager.attach_agent_to_archive_async( + agent_id=sarah_agent.id, archive_id=archive.id, is_owner=True, actor=default_user + ) + + await server.archive_manager.attach_agent_to_archive_async( + agent_id=agent2.id, archive_id=archive.id, is_owner=False, actor=default_user + ) + + agent_ids = await server.archive_manager.get_agents_for_archive_async(archive_id=archive.id, actor=default_user) + + assert len(agent_ids) == 2 + assert sarah_agent.id in agent_ids + assert agent2.id in agent_ids + + # Cleanup + await server.agent_manager.delete_agent_async(agent2.id, actor=default_user) + await server.archive_manager.delete_archive_async(archive.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_archive_manager_race_condition_handling(server: SyncServer, default_user, sarah_agent): + """Test that the race condition fix in get_or_create_default_archive_for_agent_async works.""" + from unittest.mock import patch + + from sqlalchemy.exc import IntegrityError + + agent = await server.agent_manager.create_agent_async( + agent_create=CreateAgent( + name="test_agent_race_condition", + memory_blocks=[], + llm_config=LLMConfig.default_config("gpt-4o-mini"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + include_base_tools=False, + ), + actor=default_user, + ) + + created_archives = [] + original_create = server.archive_manager.create_archive_async + + async def track_create(*args, **kwargs): + result = await original_create(*args, **kwargs) + created_archives.append(result) + return result + + # First, create an archive that will be attached by a "concurrent" request + concurrent_archive = await server.archive_manager.create_archive_async( + name=f"{agent.name}'s Archive", description="Default archive created automatically", actor=default_user + ) + + call_count = 0 + original_attach = server.archive_manager.attach_agent_to_archive_async + + async def failing_attach(*args, **kwargs): + nonlocal call_count + call_count += 1 + if call_count == 1: + # Simulate another request already attached the agent to an archive + await original_attach(agent_id=agent.id, archive_id=concurrent_archive.id, is_owner=True, actor=default_user) + # Now raise the IntegrityError as if our attempt failed + raise IntegrityError("duplicate key value violates unique constraint", None, None) + # This shouldn't be called since we already have an archive + raise Exception("Should not reach here") + + with patch.object(server.archive_manager, "create_archive_async", side_effect=track_create): + with patch.object(server.archive_manager, "attach_agent_to_archive_async", side_effect=failing_attach): + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=agent.id, agent_name=agent.name, actor=default_user + ) + + assert archive is not None + assert archive.id == concurrent_archive.id # Should return the existing archive + assert archive.name == f"{agent.name}'s Archive" + + # One archive was created in our attempt (but then deleted) + assert len(created_archives) == 1 + + # Verify only one archive is attached to the agent + archive_ids = await server.agent_manager.get_agent_archive_ids_async(agent_id=agent.id, actor=default_user) + assert len(archive_ids) == 1 + assert archive_ids[0] == concurrent_archive.id + + # Cleanup + await server.agent_manager.delete_agent_async(agent.id, actor=default_user) + await server.archive_manager.delete_archive_async(concurrent_archive.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_archive_manager_get_agent_from_passage_async(server: SyncServer, default_user, sarah_agent): + """Test getting the agent ID that owns a passage through its archive.""" + archive = await server.archive_manager.get_or_create_default_archive_for_agent_async( + agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user + ) + + passage = await server.passage_manager.create_agent_passage_async( + PydanticPassage( + text="Test passage for agent ownership", + archive_id=archive.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + agent_id = await server.archive_manager.get_agent_from_passage_async(passage_id=passage.id, actor=default_user) + + assert agent_id == sarah_agent.id + + orphan_archive = await server.archive_manager.create_archive_async( + name="orphan_archive", description="Archive with no agents", actor=default_user + ) + + orphan_passage = await server.passage_manager.create_agent_passage_async( + PydanticPassage( + text="Orphan passage", + archive_id=orphan_archive.id, + organization_id=default_user.organization_id, + embedding=[0.1], + embedding_config=DEFAULT_EMBEDDING_CONFIG, + ), + actor=default_user, + ) + + agent_id = await server.archive_manager.get_agent_from_passage_async(passage_id=orphan_passage.id, actor=default_user) + assert agent_id is None + + # Cleanup + await server.passage_manager.delete_passage_by_id_async(passage.id, actor=default_user) + await server.passage_manager.delete_passage_by_id_async(orphan_passage.id, actor=default_user) + await server.archive_manager.delete_archive_async(orphan_archive.id, actor=default_user) + + # ====================================================================================================================== # User Manager Tests # ======================================================================================================================