feat: Fix race condition with creating archival memories in parallel [LET-4205] (#4464)
* Test archive manager and add race condition handling * Fix client tests * Remove bad test
This commit is contained in:
@@ -4130,6 +4130,176 @@ async def test_search_agent_archival_memory_async(disable_turbopuffer, server: S
|
||||
await server.passage_manager.delete_agent_passage_by_id_async(passage_id=passage.id, actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Archive Manager Tests
|
||||
# ======================================================================================================================
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_manager_delete_archive_async(server: SyncServer, default_user):
|
||||
"""Test the delete_archive_async function."""
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="test_archive_to_delete", description="This archive will be deleted", actor=default_user
|
||||
)
|
||||
|
||||
retrieved_archive = await server.archive_manager.get_archive_by_id_async(archive_id=archive.id, actor=default_user)
|
||||
assert retrieved_archive.id == archive.id
|
||||
|
||||
await server.archive_manager.delete_archive_async(archive_id=archive.id, actor=default_user)
|
||||
|
||||
with pytest.raises(Exception):
|
||||
await server.archive_manager.get_archive_by_id_async(archive_id=archive.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_manager_get_agents_for_archive_async(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test getting all agents that have access to an archive."""
|
||||
archive = await server.archive_manager.create_archive_async(
|
||||
name="shared_archive", description="Archive shared by multiple agents", actor=default_user
|
||||
)
|
||||
|
||||
agent2 = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="test_agent_2",
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
await server.archive_manager.attach_agent_to_archive_async(
|
||||
agent_id=sarah_agent.id, archive_id=archive.id, is_owner=True, actor=default_user
|
||||
)
|
||||
|
||||
await server.archive_manager.attach_agent_to_archive_async(
|
||||
agent_id=agent2.id, archive_id=archive.id, is_owner=False, actor=default_user
|
||||
)
|
||||
|
||||
agent_ids = await server.archive_manager.get_agents_for_archive_async(archive_id=archive.id, actor=default_user)
|
||||
|
||||
assert len(agent_ids) == 2
|
||||
assert sarah_agent.id in agent_ids
|
||||
assert agent2.id in agent_ids
|
||||
|
||||
# Cleanup
|
||||
await server.agent_manager.delete_agent_async(agent2.id, actor=default_user)
|
||||
await server.archive_manager.delete_archive_async(archive.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_manager_race_condition_handling(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test that the race condition fix in get_or_create_default_archive_for_agent_async works."""
|
||||
from unittest.mock import patch
|
||||
|
||||
from sqlalchemy.exc import IntegrityError
|
||||
|
||||
agent = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="test_agent_race_condition",
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
created_archives = []
|
||||
original_create = server.archive_manager.create_archive_async
|
||||
|
||||
async def track_create(*args, **kwargs):
|
||||
result = await original_create(*args, **kwargs)
|
||||
created_archives.append(result)
|
||||
return result
|
||||
|
||||
# First, create an archive that will be attached by a "concurrent" request
|
||||
concurrent_archive = await server.archive_manager.create_archive_async(
|
||||
name=f"{agent.name}'s Archive", description="Default archive created automatically", actor=default_user
|
||||
)
|
||||
|
||||
call_count = 0
|
||||
original_attach = server.archive_manager.attach_agent_to_archive_async
|
||||
|
||||
async def failing_attach(*args, **kwargs):
|
||||
nonlocal call_count
|
||||
call_count += 1
|
||||
if call_count == 1:
|
||||
# Simulate another request already attached the agent to an archive
|
||||
await original_attach(agent_id=agent.id, archive_id=concurrent_archive.id, is_owner=True, actor=default_user)
|
||||
# Now raise the IntegrityError as if our attempt failed
|
||||
raise IntegrityError("duplicate key value violates unique constraint", None, None)
|
||||
# This shouldn't be called since we already have an archive
|
||||
raise Exception("Should not reach here")
|
||||
|
||||
with patch.object(server.archive_manager, "create_archive_async", side_effect=track_create):
|
||||
with patch.object(server.archive_manager, "attach_agent_to_archive_async", side_effect=failing_attach):
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=agent.id, agent_name=agent.name, actor=default_user
|
||||
)
|
||||
|
||||
assert archive is not None
|
||||
assert archive.id == concurrent_archive.id # Should return the existing archive
|
||||
assert archive.name == f"{agent.name}'s Archive"
|
||||
|
||||
# One archive was created in our attempt (but then deleted)
|
||||
assert len(created_archives) == 1
|
||||
|
||||
# Verify only one archive is attached to the agent
|
||||
archive_ids = await server.agent_manager.get_agent_archive_ids_async(agent_id=agent.id, actor=default_user)
|
||||
assert len(archive_ids) == 1
|
||||
assert archive_ids[0] == concurrent_archive.id
|
||||
|
||||
# Cleanup
|
||||
await server.agent_manager.delete_agent_async(agent.id, actor=default_user)
|
||||
await server.archive_manager.delete_archive_async(concurrent_archive.id, actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_archive_manager_get_agent_from_passage_async(server: SyncServer, default_user, sarah_agent):
|
||||
"""Test getting the agent ID that owns a passage through its archive."""
|
||||
archive = await server.archive_manager.get_or_create_default_archive_for_agent_async(
|
||||
agent_id=sarah_agent.id, agent_name=sarah_agent.name, actor=default_user
|
||||
)
|
||||
|
||||
passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
text="Test passage for agent ownership",
|
||||
archive_id=archive.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent_id = await server.archive_manager.get_agent_from_passage_async(passage_id=passage.id, actor=default_user)
|
||||
|
||||
assert agent_id == sarah_agent.id
|
||||
|
||||
orphan_archive = await server.archive_manager.create_archive_async(
|
||||
name="orphan_archive", description="Archive with no agents", actor=default_user
|
||||
)
|
||||
|
||||
orphan_passage = await server.passage_manager.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
text="Orphan passage",
|
||||
archive_id=orphan_archive.id,
|
||||
organization_id=default_user.organization_id,
|
||||
embedding=[0.1],
|
||||
embedding_config=DEFAULT_EMBEDDING_CONFIG,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent_id = await server.archive_manager.get_agent_from_passage_async(passage_id=orphan_passage.id, actor=default_user)
|
||||
assert agent_id is None
|
||||
|
||||
# Cleanup
|
||||
await server.passage_manager.delete_passage_by_id_async(passage.id, actor=default_user)
|
||||
await server.passage_manager.delete_passage_by_id_async(orphan_passage.id, actor=default_user)
|
||||
await server.archive_manager.delete_archive_async(orphan_archive.id, actor=default_user)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# User Manager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user