feat: Add per-agent locking to send message (#2109)

This commit is contained in:
Matthew Zhou
2024-11-26 13:30:58 -08:00
committed by GitHub
parent 4d9b4eef9d
commit 056cbb0eec
6 changed files with 92 additions and 26 deletions

View File

@@ -25,6 +25,7 @@ from letta.schemas.tool_rule import (
ToolRule,
)
from letta.schemas.user import User
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.settings import settings
from letta.utils import enforce_types, get_utc_time, printd
@@ -383,7 +384,11 @@ class MetadataStore:
session.commit()
@enforce_types
def delete_agent(self, agent_id: str):
def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManager):
# TODO: Remove this once Agent is on the ORM
# TODO: To prevent unbounded growth
per_agent_lock_manager.clear_lock(agent_id)
with self.session_maker() as session:
# delete agents

View File

@@ -475,19 +475,21 @@ async def send_message(
"""
actor = server.get_user_or_default(user_id=user_id)
result = await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=actor.id,
messages=request.messages,
stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens,
return_message_object=request.return_message_object,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_function_name=request.assistant_message_function_name,
assistant_message_function_kwarg=request.assistant_message_function_kwarg,
)
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
async with agent_lock:
result = await send_message_to_agent(
server=server,
agent_id=agent_id,
user_id=actor.id,
messages=request.messages,
stream_steps=request.stream_steps,
stream_tokens=request.stream_tokens,
return_message_object=request.return_message_object,
# Support for AssistantMessage
use_assistant_message=request.use_assistant_message,
assistant_message_function_name=request.assistant_message_function_name,
assistant_message_function_kwarg=request.assistant_message_function_kwarg,
)
return result

View File

@@ -3,6 +3,7 @@ import os
import traceback
import warnings
from abc import abstractmethod
from asyncio import Lock
from datetime import datetime
from typing import Callable, Dict, List, Optional, Tuple, Union
@@ -79,6 +80,7 @@ from letta.services.agents_tags_manager import AgentsTagsManager
from letta.services.block_manager import BlockManager
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.services.organization_manager import OrganizationManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
from letta.services.tool_manager import ToolManager
@@ -231,6 +233,9 @@ class SyncServer(Server):
self.credentials = LettaCredentials.load()
# Locks
self.send_message_lock = Lock()
# Initialize the metadata store
config = LettaConfig.load()
if settings.letta_pg_uri_no_default:
@@ -252,6 +257,9 @@ class SyncServer(Server):
self.blocks_agents_manager = BlocksAgentsManager()
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
# Managers that interface with parallelism
self.per_agent_lock_manager = PerAgentLockManager()
# Make default user and org
if init_with_default_org_and_user:
self.default_org = self.organization_manager.create_default_organization()
@@ -925,7 +933,7 @@ class SyncServer(Server):
logger.exception(e)
try:
if agent:
self.ms.delete_agent(agent_id=agent.agent_state.id)
self.ms.delete_agent(agent_id=agent.agent_state.id, per_agent_lock_manager=self.per_agent_lock_manager)
except Exception as delete_e:
logger.exception(f"Failed to delete_agent:\n{delete_e}")
raise e
@@ -1522,7 +1530,7 @@ class SyncServer(Server):
# Next, attempt to delete it from the actual database
try:
self.ms.delete_agent(agent_id=agent_id)
self.ms.delete_agent(agent_id=agent_id, per_agent_lock_manager=self.per_agent_lock_manager)
except Exception as e:
logger.exception(f"Failed to delete agent {agent_id} via ID with:\n{str(e)}")
raise ValueError(f"Failed to delete agent {agent_id} in database")

View File

@@ -0,0 +1,18 @@
import asyncio
from collections import defaultdict
class PerAgentLockManager:
"""Manages per-agent locks."""
def __init__(self):
self.locks = defaultdict(asyncio.Lock)
def get_lock(self, agent_id: str) -> asyncio.Lock:
"""Retrieve the lock for a specific agent_id."""
return self.locks[agent_id]
def clear_lock(self, agent_id: str):
"""Optionally remove a lock if no longer needed (to prevent unbounded growth)."""
if agent_id in self.locks:
del self.locks[agent_id]

View File

@@ -1,3 +1,4 @@
import asyncio
import os
import threading
import time
@@ -295,3 +296,45 @@ def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent
finally:
client.delete_agent(agent.id)
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"
messages_response = client.get_messages(agent_id=agent.id, limit=1)
assert len(messages_response) > 0, "Retrieving messages failed"
@pytest.mark.asyncio
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
"""
Test that sending two messages in parallel does not error.
"""
if not isinstance(client, RESTClient):
pytest.skip("This test only runs when the server is enabled")
# Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
async def send_message_task(message: str):
response = await asyncio.to_thread(client.send_message, agent.id, message, role="user")
assert response, f"Sending message '{message}' failed"
return response
# Prepare two tasks with different messages
messages = ["Test message 1", "Test message 2"]
tasks = [send_message_task(message) for message in messages]
# Run the tasks concurrently
responses = await asyncio.gather(*tasks, return_exceptions=True)
# Check for exceptions and validate responses
for i, response in enumerate(responses):
if isinstance(response, Exception):
pytest.fail(f"Task {i} failed with exception: {response}")
else:
assert response, f"Task {i} returned an invalid response: {response}"
# Ensure both tasks completed
assert len(responses) == len(messages), "Not all messages were processed"

View File

@@ -223,16 +223,6 @@ def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
assert send_message_response, "Sending message failed"
messages_response = client.get_messages(agent_id=agent.id, limit=1)
assert len(messages_response) > 0, "Retrieving messages failed"
def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState):
if isinstance(client, LocalClient):
pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")