diff --git a/letta/metadata.py b/letta/metadata.py index d492fdbc..1b8f6a22 100644 --- a/letta/metadata.py +++ b/letta/metadata.py @@ -25,6 +25,7 @@ from letta.schemas.tool_rule import ( ToolRule, ) from letta.schemas.user import User +from letta.services.per_agent_lock_manager import PerAgentLockManager from letta.settings import settings from letta.utils import enforce_types, get_utc_time, printd @@ -383,7 +384,11 @@ class MetadataStore: session.commit() @enforce_types - def delete_agent(self, agent_id: str): + def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManager): + # TODO: Remove this once Agent is on the ORM + # TODO: To prevent unbounded growth + per_agent_lock_manager.clear_lock(agent_id) + with self.session_maker() as session: # delete agents diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index bdc6a577..3553760e 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -475,19 +475,21 @@ async def send_message( """ actor = server.get_user_or_default(user_id=user_id) - result = await send_message_to_agent( - server=server, - agent_id=agent_id, - user_id=actor.id, - messages=request.messages, - stream_steps=request.stream_steps, - stream_tokens=request.stream_tokens, - return_message_object=request.return_message_object, - # Support for AssistantMessage - use_assistant_message=request.use_assistant_message, - assistant_message_function_name=request.assistant_message_function_name, - assistant_message_function_kwarg=request.assistant_message_function_kwarg, - ) + agent_lock = server.per_agent_lock_manager.get_lock(agent_id) + async with agent_lock: + result = await send_message_to_agent( + server=server, + agent_id=agent_id, + user_id=actor.id, + messages=request.messages, + stream_steps=request.stream_steps, + stream_tokens=request.stream_tokens, + return_message_object=request.return_message_object, + # Support for 
AssistantMessage + use_assistant_message=request.use_assistant_message, + assistant_message_function_name=request.assistant_message_function_name, + assistant_message_function_kwarg=request.assistant_message_function_kwarg, + ) return result diff --git a/letta/server/server.py b/letta/server/server.py index be4b827a..f06689aa 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -3,6 +3,7 @@ import os import traceback import warnings from abc import abstractmethod +from asyncio import Lock from datetime import datetime from typing import Callable, Dict, List, Optional, Tuple, Union @@ -79,6 +80,7 @@ from letta.services.agents_tags_manager import AgentsTagsManager from letta.services.block_manager import BlockManager from letta.services.blocks_agents_manager import BlocksAgentsManager from letta.services.organization_manager import OrganizationManager +from letta.services.per_agent_lock_manager import PerAgentLockManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager @@ -231,6 +233,9 @@ class SyncServer(Server): self.credentials = LettaCredentials.load() + # Locks + self.send_message_lock = Lock() + # Initialize the metadata store config = LettaConfig.load() if settings.letta_pg_uri_no_default: @@ -252,6 +257,9 @@ class SyncServer(Server): self.blocks_agents_manager = BlocksAgentsManager() self.sandbox_config_manager = SandboxConfigManager(tool_settings) + # Managers that interface with parallelism + self.per_agent_lock_manager = PerAgentLockManager() + # Make default user and org if init_with_default_org_and_user: self.default_org = self.organization_manager.create_default_organization() @@ -925,7 +933,7 @@ class SyncServer(Server): logger.exception(e) try: if agent: - self.ms.delete_agent(agent_id=agent.agent_state.id) + self.ms.delete_agent(agent_id=agent.agent_state.id, 
class PerAgentLockManager:
    """Registry that hands out one ``asyncio.Lock`` per agent ID."""

    def __init__(self):
        # A defaultdict builds a fresh asyncio.Lock the first time an ID is indexed.
        self.locks = defaultdict(asyncio.Lock)

    def get_lock(self, agent_id: str) -> asyncio.Lock:
        """Return the lock registered for ``agent_id``, creating it on first use."""
        agent_lock = self.locks[agent_id]
        return agent_lock

    def clear_lock(self, agent_id: str):
        """Forget the lock stored for ``agent_id`` so the registry does not grow without bound.

        Does nothing when no lock has been registered for that ID.
        """
        # pop() with a default never triggers the defaultdict factory,
        # so an unknown ID leaves the registry untouched.
        self.locks.pop(agent_id, None)
@pytest.mark.asyncio
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
    """
    Send two messages to the same agent concurrently and require that neither call errors.

    Skipped unless the client is a RESTClient.
    """
    if not isinstance(client, RESTClient):
        pytest.skip("This test only runs when the server is enabled")

    async def send_message_task(message: str):
        # client.send_message is a synchronous call, so hand it to a worker
        # thread to let both sends be in flight at the same time.
        reply = await asyncio.to_thread(client.send_message, agent.id, message, role="user")
        assert reply, f"Sending message '{message}' failed"
        return reply

    messages = ["Test message 1", "Test message 2"]

    # return_exceptions=True collects failures as values instead of raising on
    # the first one, so every task gets reported on below.
    outcomes = await asyncio.gather(
        *(send_message_task(text) for text in messages),
        return_exceptions=True,
    )

    for index, outcome in enumerate(outcomes):
        if isinstance(outcome, Exception):
            pytest.fail(f"Task {index} failed with exception: {outcome}")
        assert outcome, f"Task {index} returned an invalid response: {outcome}"

    # One outcome per message sent.
    assert len(outcomes) == len(messages), "Not all messages were processed"
client.get_messages(agent_id=agent.id, limit=1) - assert len(messages_response) > 0, "Retrieving messages failed" - - def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState): if isinstance(client, LocalClient): pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")