diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py
index ac41f37d..44a8eb6f 100644
--- a/letta/services/agent_manager.py
+++ b/letta/services/agent_manager.py
@@ -4,7 +4,7 @@ from datetime import datetime, timezone
 from typing import Dict, List, Optional, Set, Tuple
 
 import sqlalchemy as sa
-from sqlalchemy import delete, func, insert, literal, or_, select
+from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
 from sqlalchemy.dialects.postgresql import insert as pg_insert
 
 from letta.constants import (
@@ -224,13 +224,41 @@ class AgentManager:
     @staticmethod
     async def _replace_pivot_rows_async(session, table, agent_id: str, rows: list[dict]):
         """
-        Replace all pivot rows for an agent with *exactly* the provided list.
-        Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING).
+        Replace the pivot rows for an agent so they match the provided list.
+
+        On PostgreSQL and SQLite the new rows are upserted first (ON CONFLICT DO
+        NOTHING / INSERT OR REPLACE) and the agent's rows whose primary key is not
+        in ``rows`` are deleted afterwards. These are two separate statements, not
+        a single MERGE, so the replacement is only as atomic as the enclosing
+        transaction. Any other dialect falls back to DELETE + bulk INSERT.
+
+        Every dict in ``rows`` must contain all primary key columns of ``table``.
         """
-        # delete all existing rows for this agent
-        await session.execute(delete(table).where(table.c.agent_id == agent_id))
-        if rows:
-            await AgentManager._bulk_insert_pivot_async(session, table, rows)
+        dialect = session.bind.dialect.name
+
+        if dialect not in ("postgresql", "sqlite"):
+            # fallback: use original DELETE + INSERT pattern
+            await session.execute(delete(table).where(table.c.agent_id == agent_id))
+            if rows:
+                await AgentManager._bulk_insert_pivot_async(session, table, rows)
+            return
+
+        if not rows:
+            # if no rows to insert, just delete all
+            await session.execute(delete(table).where(table.c.agent_id == agent_id))
+            return
+
+        # upsert the new set first; only the INSERT statement differs per dialect
+        if dialect == "postgresql":
+            stmt = pg_insert(table).values(rows).on_conflict_do_nothing()
+        else:
+            stmt = sa.insert(table).values(rows).prefix_with("OR REPLACE")
+        await session.execute(stmt)
+
+        # delete rows not in new set
+        pk_cols = list(table.primary_key.columns)
+        new_keys = [tuple(r[c.name] for c in pk_cols) for r in rows]
+        await session.execute(delete(table).where(table.c.agent_id == agent_id, ~tuple_(*pk_cols).in_(new_keys)))
 
     # ======================================================================================================================
     # Basic CRUD operations
diff --git a/tests/test_managers.py b/tests/test_managers.py
index 8b88000a..be30c032 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -3902,7 +3902,7 @@ async def test_bulk_update_return_hydrated_true(server: SyncServer, default_user
     mgr = BlockManager()
 
     # create a block
-    b = mgr.create_or_update_block(
+    b = await mgr.create_or_update_block_async(
         PydanticBlock(label="persona", value="foo", limit=20),
         actor=default_user,
     )
@@ -7813,3 +7813,231 @@ async def test_attach_files_bulk_oversized_bulk(server, default_user, sarah_agen
     # All files should be attached (some open, some closed)
     all_files_after = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user)
     assert len(all_files_after) == MAX_FILES_OPEN + 3
+
+
+# ======================================================================================================================
+# Race Condition Tests - Blocks
+# ======================================================================================================================
+
+
+@pytest.mark.asyncio
+async def test_concurrent_block_updates_race_condition(
+    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
+):
+    """Test that concurrent block updates don't cause race conditions."""
+    agent, _ = comprehensive_test_agent_fixture
+
+    # Create multiple blocks to use in concurrent updates
+    blocks = []
+    for i in range(5):
+        block = await server.block_manager.create_or_update_block_async(
+            PydanticBlock(label=f"test_block_{i}", value=f"Test block content {i}", limit=1000), actor=default_user
+        )
+        blocks.append(block)
+
+    # Test concurrent updates with different block combinations
+    async def update_agent_blocks(block_subset):
+        """Update agent with a specific subset of blocks."""
+        update_request = UpdateAgent(block_ids=[b.id for b in block_subset])
+        try:
+            return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
+        except Exception as e:
+            # Capture any errors that occur during concurrent updates
+            return {"error": str(e)}
+
+    # Run concurrent updates with different block combinations
+    tasks = [
+        update_agent_blocks(blocks[:2]),  # blocks 0, 1
+        update_agent_blocks(blocks[1:3]),  # blocks 1, 2
+        update_agent_blocks(blocks[2:4]),  # blocks 2, 3
+        update_agent_blocks(blocks[3:5]),  # blocks 3, 4
+        update_agent_blocks(blocks[:1]),  # block 0 only
+    ]
+
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Verify no exceptions occurred
+    errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
+    assert len(errors) == 0, f"Concurrent updates failed with errors: {errors}"
+
+    # Verify all results are valid agent states
+    valid_results = [r for r in results if not isinstance(r, Exception) and not (isinstance(r, dict) and "error" in r)]
+    assert len(valid_results) == 5, "All concurrent updates should succeed"
+
+    # Verify final state is consistent
+    final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
+    assert final_agent is not None
+    assert len(final_agent.memory.blocks) > 0
+
+    # Clean up
+    for block in blocks:
+        await server.block_manager.delete_block_async(block.id, actor=default_user)
+
+
+@pytest.mark.asyncio
+async def test_concurrent_same_block_updates_race_condition(
+    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
+):
+    """Test that multiple concurrent updates to the same block configuration don't cause issues."""
+    agent, _ = comprehensive_test_agent_fixture
+
+    # Create a single block configuration to use in all updates
+    block = await server.block_manager.create_or_update_block_async(
+        PydanticBlock(label="shared_block", value="Shared block content", limit=1000), actor=default_user
+    )
+
+    # Test multiple concurrent updates with the same block configuration
+    async def update_agent_with_same_blocks():
+        """Update agent with the same block configuration."""
+        update_request = UpdateAgent(block_ids=[block.id])
+        try:
+            return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
+        except Exception as e:
+            return {"error": str(e)}
+
+    # Run 10 concurrent identical updates
+    tasks = [update_agent_with_same_blocks() for _ in range(10)]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Verify no exceptions occurred
+    errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
+    assert len(errors) == 0, f"Concurrent identical updates failed with errors: {errors}"
+
+    # Verify final state is consistent
+    final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
+    assert len(final_agent.memory.blocks) == 1
+    assert final_agent.memory.blocks[0].id == block.id
+
+    # Clean up
+    await server.block_manager.delete_block_async(block.id, actor=default_user)
+
+
+@pytest.mark.asyncio
+async def test_concurrent_empty_block_updates_race_condition(
+    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
+):
+    """Test concurrent updates that remove all blocks."""
+    agent, _ = comprehensive_test_agent_fixture
+
+    # Test concurrent updates that clear all blocks
+    async def clear_agent_blocks():
+        """Update agent to have no blocks."""
+        update_request = UpdateAgent(block_ids=[])
+        try:
+            return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
+        except Exception as e:
+            return {"error": str(e)}
+
+    # Run concurrent clear operations
+    tasks = [clear_agent_blocks() for _ in range(5)]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Verify no exceptions occurred
+    errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
+    assert len(errors) == 0, f"Concurrent clear operations failed with errors: {errors}"
+
+    # Verify final state is consistent (no blocks)
+    final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
+    assert len(final_agent.memory.blocks) == 0
+
+
+@pytest.mark.asyncio
+async def test_concurrent_mixed_block_operations_race_condition(
+    server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
+):
+    """Test mixed concurrent operations: some adding blocks, some removing."""
+    agent, _ = comprehensive_test_agent_fixture
+
+    # Create test blocks
+    blocks = []
+    for i in range(3):
+        block = await server.block_manager.create_or_update_block_async(
+            PydanticBlock(label=f"mixed_block_{i}", value=f"Mixed block content {i}", limit=1000), actor=default_user
+        )
+        blocks.append(block)
+
+    # Mix of operations: add blocks, remove blocks, clear all
+    async def mixed_operation(operation_type):
+        """Perform different types of block operations."""
+        if operation_type == "add_all":
+            update_request = UpdateAgent(block_ids=[b.id for b in blocks])
+        elif operation_type == "add_subset":
+            update_request = UpdateAgent(block_ids=[blocks[0].id])
+        elif operation_type == "clear":
+            update_request = UpdateAgent(block_ids=[])
+        else:
+            update_request = UpdateAgent(block_ids=[blocks[1].id, blocks[2].id])
+
+        try:
+            return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
+        except Exception as e:
+            return {"error": str(e)}
+
+    # Run mixed concurrent operations
+    tasks = [
+        mixed_operation("add_all"),
+        mixed_operation("add_subset"),
+        mixed_operation("clear"),
+        mixed_operation("add_two"),
+        mixed_operation("add_all"),
+    ]
+
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Verify no exceptions occurred
+    errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
+    assert len(errors) == 0, f"Mixed concurrent operations failed with errors: {errors}"
+
+    # Verify final state is consistent (any valid state is acceptable)
+    final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
+    assert final_agent is not None
+
+    # Clean up
+    for block in blocks:
+        await server.block_manager.delete_block_async(block.id, actor=default_user)
+
+
+@pytest.mark.asyncio
+async def test_high_concurrency_stress_test(server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop):
+    """Stress test with high concurrency to catch race conditions."""
+    agent, _ = comprehensive_test_agent_fixture
+
+    # Create many blocks for stress testing
+    blocks = []
+    for i in range(10):
+        block = await server.block_manager.create_or_update_block_async(
+            PydanticBlock(label=f"stress_block_{i}", value=f"Stress test content {i}", limit=1000), actor=default_user
+        )
+        blocks.append(block)
+
+    # Create many concurrent update tasks
+    async def stress_update(task_id):
+        """Perform a random block update operation."""
+        import random
+
+        # Random subset of blocks
+        num_blocks = random.randint(0, len(blocks))
+        selected_blocks = random.sample(blocks, num_blocks)
+
+        update_request = UpdateAgent(block_ids=[b.id for b in selected_blocks])
+
+        try:
+            return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
+        except Exception as e:
+            return {"error": str(e), "task_id": task_id}
+
+    # Run 20 concurrent stress updates
+    tasks = [stress_update(i) for i in range(20)]
+    results = await asyncio.gather(*tasks, return_exceptions=True)
+
+    # Verify no exceptions occurred
+    errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
+    assert len(errors) == 0, f"High concurrency stress test failed with errors: {errors}"
+
+    # Verify final state is consistent
+    final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
+    assert final_agent is not None
+
+    # Clean up
+    for block in blocks:
+        await server.block_manager.delete_block_async(block.id, actor=default_user)