feat: Change block access queries (#3247)

This commit is contained in:
Matthew Zhou
2025-07-09 13:22:22 -07:00
committed by GitHub
parent 4c540cf717
commit fcb894a4e3
2 changed files with 267 additions and 8 deletions

View File

@@ -4,7 +4,7 @@ from datetime import datetime, timezone
from typing import Dict, List, Optional, Set, Tuple
import sqlalchemy as sa
from sqlalchemy import delete, func, insert, literal, or_, select
from sqlalchemy import delete, func, insert, literal, or_, select, tuple_
from sqlalchemy.dialects.postgresql import insert as pg_insert
from letta.constants import (
@@ -224,13 +224,44 @@ class AgentManager:
@staticmethod
async def _replace_pivot_rows_async(session, table, agent_id: str, rows: list[dict]):
"""
Replace all pivot rows for an agent with *exactly* the provided list.
Uses two bulk statements (DELETE + INSERT ... ON CONFLICT DO NOTHING).
Replace all pivot rows for an agent atomically using MERGE pattern.
"""
# delete all existing rows for this agent
await session.execute(delete(table).where(table.c.agent_id == agent_id))
if rows:
await AgentManager._bulk_insert_pivot_async(session, table, rows)
dialect = session.bind.dialect.name
if dialect == "postgresql":
if rows:
# separate upsert and delete operations
stmt = pg_insert(table).values(rows)
stmt = stmt.on_conflict_do_nothing()
await session.execute(stmt)
# delete rows not in new set
pk_names = [c.name for c in table.primary_key.columns]
new_keys = [tuple(r[c] for c in pk_names) for r in rows]
await session.execute(
delete(table).where(table.c.agent_id == agent_id, ~tuple_(*[table.c[c] for c in pk_names]).in_(new_keys))
)
else:
# if no rows to insert, just delete all
await session.execute(delete(table).where(table.c.agent_id == agent_id))
elif dialect == "sqlite":
if rows:
stmt = sa.insert(table).values(rows).prefix_with("OR REPLACE")
await session.execute(stmt)
if rows:
primary_key_cols = [table.c[c.name] for c in table.primary_key.columns]
new_keys = [tuple(r[c.name] for c in table.primary_key.columns) for r in rows]
await session.execute(delete(table).where(table.c.agent_id == agent_id, ~tuple_(*primary_key_cols).in_(new_keys)))
else:
await session.execute(delete(table).where(table.c.agent_id == agent_id))
else:
# fallback: use original DELETE + INSERT pattern
await session.execute(delete(table).where(table.c.agent_id == agent_id))
if rows:
await AgentManager._bulk_insert_pivot_async(session, table, rows)
# ======================================================================================================================
# Basic CRUD operations

View File

@@ -3902,7 +3902,7 @@ async def test_bulk_update_return_hydrated_true(server: SyncServer, default_user
mgr = BlockManager()
# create a block
b = mgr.create_or_update_block(
b = await mgr.create_or_update_block_async(
PydanticBlock(label="persona", value="foo", limit=20),
actor=default_user,
)
@@ -7813,3 +7813,231 @@ async def test_attach_files_bulk_oversized_bulk(server, default_user, sarah_agen
# All files should be attached (some open, some closed)
all_files_after = await server.file_agent_manager.list_files_for_agent(sarah_agent.id, actor=default_user)
assert len(all_files_after) == MAX_FILES_OPEN + 3
# ======================================================================================================================
# Race Condition Tests - Blocks
# ======================================================================================================================
@pytest.mark.asyncio
async def test_concurrent_block_updates_race_condition(
server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
"""Test that concurrent block updates don't cause race conditions."""
agent, _ = comprehensive_test_agent_fixture
# Create multiple blocks to use in concurrent updates
blocks = []
for i in range(5):
block = await server.block_manager.create_or_update_block_async(
PydanticBlock(label=f"test_block_{i}", value=f"Test block content {i}", limit=1000), actor=default_user
)
blocks.append(block)
# Test concurrent updates with different block combinations
async def update_agent_blocks(block_subset):
"""Update agent with a specific subset of blocks."""
update_request = UpdateAgent(block_ids=[b.id for b in block_subset])
try:
return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
except Exception as e:
# Capture any errors that occur during concurrent updates
return {"error": str(e)}
# Run concurrent updates with different block combinations
tasks = [
update_agent_blocks(blocks[:2]), # blocks 0, 1
update_agent_blocks(blocks[1:3]), # blocks 1, 2
update_agent_blocks(blocks[2:4]), # blocks 2, 3
update_agent_blocks(blocks[3:5]), # blocks 3, 4
update_agent_blocks(blocks[:1]), # block 0 only
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Verify no exceptions occurred
errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
assert len(errors) == 0, f"Concurrent updates failed with errors: {errors}"
# Verify all results are valid agent states
valid_results = [r for r in results if not isinstance(r, Exception) and not (isinstance(r, dict) and "error" in r)]
assert len(valid_results) == 5, "All concurrent updates should succeed"
# Verify final state is consistent
final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
assert final_agent is not None
assert len(final_agent.memory.blocks) > 0
# Clean up
for block in blocks:
await server.block_manager.delete_block_async(block.id, actor=default_user)
@pytest.mark.asyncio
async def test_concurrent_same_block_updates_race_condition(
server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
"""Test that multiple concurrent updates to the same block configuration don't cause issues."""
agent, _ = comprehensive_test_agent_fixture
# Create a single block configuration to use in all updates
block = await server.block_manager.create_or_update_block_async(
PydanticBlock(label="shared_block", value="Shared block content", limit=1000), actor=default_user
)
# Test multiple concurrent updates with the same block configuration
async def update_agent_with_same_blocks():
"""Update agent with the same block configuration."""
update_request = UpdateAgent(block_ids=[block.id])
try:
return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
except Exception as e:
return {"error": str(e)}
# Run 10 concurrent identical updates
tasks = [update_agent_with_same_blocks() for _ in range(10)]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Verify no exceptions occurred
errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
assert len(errors) == 0, f"Concurrent identical updates failed with errors: {errors}"
# Verify final state is consistent
final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
assert len(final_agent.memory.blocks) == 1
assert final_agent.memory.blocks[0].id == block.id
# Clean up
await server.block_manager.delete_block_async(block.id, actor=default_user)
@pytest.mark.asyncio
async def test_concurrent_empty_block_updates_race_condition(
server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
"""Test concurrent updates that remove all blocks."""
agent, _ = comprehensive_test_agent_fixture
# Test concurrent updates that clear all blocks
async def clear_agent_blocks():
"""Update agent to have no blocks."""
update_request = UpdateAgent(block_ids=[])
try:
return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
except Exception as e:
return {"error": str(e)}
# Run concurrent clear operations
tasks = [clear_agent_blocks() for _ in range(5)]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Verify no exceptions occurred
errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
assert len(errors) == 0, f"Concurrent clear operations failed with errors: {errors}"
# Verify final state is consistent (no blocks)
final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
assert len(final_agent.memory.blocks) == 0
@pytest.mark.asyncio
async def test_concurrent_mixed_block_operations_race_condition(
server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop
):
"""Test mixed concurrent operations: some adding blocks, some removing."""
agent, _ = comprehensive_test_agent_fixture
# Create test blocks
blocks = []
for i in range(3):
block = await server.block_manager.create_or_update_block_async(
PydanticBlock(label=f"mixed_block_{i}", value=f"Mixed block content {i}", limit=1000), actor=default_user
)
blocks.append(block)
# Mix of operations: add blocks, remove blocks, clear all
async def mixed_operation(operation_type):
"""Perform different types of block operations."""
if operation_type == "add_all":
update_request = UpdateAgent(block_ids=[b.id for b in blocks])
elif operation_type == "add_subset":
update_request = UpdateAgent(block_ids=[blocks[0].id])
elif operation_type == "clear":
update_request = UpdateAgent(block_ids=[])
else:
update_request = UpdateAgent(block_ids=[blocks[1].id, blocks[2].id])
try:
return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
except Exception as e:
return {"error": str(e)}
# Run mixed concurrent operations
tasks = [
mixed_operation("add_all"),
mixed_operation("add_subset"),
mixed_operation("clear"),
mixed_operation("add_two"),
mixed_operation("add_all"),
]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Verify no exceptions occurred
errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
assert len(errors) == 0, f"Mixed concurrent operations failed with errors: {errors}"
# Verify final state is consistent (any valid state is acceptable)
final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
assert final_agent is not None
# Clean up
for block in blocks:
await server.block_manager.delete_block_async(block.id, actor=default_user)
@pytest.mark.asyncio
async def test_high_concurrency_stress_test(server: SyncServer, comprehensive_test_agent_fixture, default_user: PydanticUser, event_loop):
"""Stress test with high concurrency to catch race conditions."""
agent, _ = comprehensive_test_agent_fixture
# Create many blocks for stress testing
blocks = []
for i in range(10):
block = await server.block_manager.create_or_update_block_async(
PydanticBlock(label=f"stress_block_{i}", value=f"Stress test content {i}", limit=1000), actor=default_user
)
blocks.append(block)
# Create many concurrent update tasks
async def stress_update(task_id):
"""Perform a random block update operation."""
import random
# Random subset of blocks
num_blocks = random.randint(0, len(blocks))
selected_blocks = random.sample(blocks, num_blocks)
update_request = UpdateAgent(block_ids=[b.id for b in selected_blocks])
try:
return await server.agent_manager.update_agent_async(agent.id, update_request, actor=default_user)
except Exception as e:
return {"error": str(e), "task_id": task_id}
# Run 20 concurrent stress updates
tasks = [stress_update(i) for i in range(20)]
results = await asyncio.gather(*tasks, return_exceptions=True)
# Verify no exceptions occurred
errors = [r for r in results if isinstance(r, Exception) or (isinstance(r, dict) and "error" in r)]
assert len(errors) == 0, f"High concurrency stress test failed with errors: {errors}"
# Verify final state is consistent
final_agent = await server.agent_manager.get_agent_by_id_async(agent.id, actor=default_user)
assert final_agent is not None
# Clean up
for block in blocks:
await server.block_manager.delete_block_async(block.id, actor=default_user)