feat: Add per-agent locking to send message (#2109)
This commit is contained in:
@@ -25,6 +25,7 @@ from letta.schemas.tool_rule import (
|
||||
ToolRule,
|
||||
)
|
||||
from letta.schemas.user import User
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types, get_utc_time, printd
|
||||
|
||||
@@ -383,7 +384,11 @@ class MetadataStore:
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def delete_agent(self, agent_id: str):
|
||||
def delete_agent(self, agent_id: str, per_agent_lock_manager: PerAgentLockManager):
|
||||
# TODO: Remove this once Agent is on the ORM
|
||||
# TODO: To prevent unbounded growth
|
||||
per_agent_lock_manager.clear_lock(agent_id)
|
||||
|
||||
with self.session_maker() as session:
|
||||
|
||||
# delete agents
|
||||
|
||||
@@ -475,19 +475,21 @@ async def send_message(
|
||||
"""
|
||||
actor = server.get_user_or_default(user_id=user_id)
|
||||
|
||||
result = await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
messages=request.messages,
|
||||
stream_steps=request.stream_steps,
|
||||
stream_tokens=request.stream_tokens,
|
||||
return_message_object=request.return_message_object,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_function_name=request.assistant_message_function_name,
|
||||
assistant_message_function_kwarg=request.assistant_message_function_kwarg,
|
||||
)
|
||||
agent_lock = server.per_agent_lock_manager.get_lock(agent_id)
|
||||
async with agent_lock:
|
||||
result = await send_message_to_agent(
|
||||
server=server,
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
messages=request.messages,
|
||||
stream_steps=request.stream_steps,
|
||||
stream_tokens=request.stream_tokens,
|
||||
return_message_object=request.return_message_object,
|
||||
# Support for AssistantMessage
|
||||
use_assistant_message=request.use_assistant_message,
|
||||
assistant_message_function_name=request.assistant_message_function_name,
|
||||
assistant_message_function_kwarg=request.assistant_message_function_kwarg,
|
||||
)
|
||||
return result
|
||||
|
||||
|
||||
|
||||
@@ -3,6 +3,7 @@ import os
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import abstractmethod
|
||||
from asyncio import Lock
|
||||
from datetime import datetime
|
||||
from typing import Callable, Dict, List, Optional, Tuple, Union
|
||||
|
||||
@@ -79,6 +80,7 @@ from letta.services.agents_tags_manager import AgentsTagsManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
@@ -231,6 +233,9 @@ class SyncServer(Server):
|
||||
|
||||
self.credentials = LettaCredentials.load()
|
||||
|
||||
# Locks
|
||||
self.send_message_lock = Lock()
|
||||
|
||||
# Initialize the metadata store
|
||||
config = LettaConfig.load()
|
||||
if settings.letta_pg_uri_no_default:
|
||||
@@ -252,6 +257,9 @@ class SyncServer(Server):
|
||||
self.blocks_agents_manager = BlocksAgentsManager()
|
||||
self.sandbox_config_manager = SandboxConfigManager(tool_settings)
|
||||
|
||||
# Managers that interface with parallelism
|
||||
self.per_agent_lock_manager = PerAgentLockManager()
|
||||
|
||||
# Make default user and org
|
||||
if init_with_default_org_and_user:
|
||||
self.default_org = self.organization_manager.create_default_organization()
|
||||
@@ -925,7 +933,7 @@ class SyncServer(Server):
|
||||
logger.exception(e)
|
||||
try:
|
||||
if agent:
|
||||
self.ms.delete_agent(agent_id=agent.agent_state.id)
|
||||
self.ms.delete_agent(agent_id=agent.agent_state.id, per_agent_lock_manager=self.per_agent_lock_manager)
|
||||
except Exception as delete_e:
|
||||
logger.exception(f"Failed to delete_agent:\n{delete_e}")
|
||||
raise e
|
||||
@@ -1522,7 +1530,7 @@ class SyncServer(Server):
|
||||
|
||||
# Next, attempt to delete it from the actual database
|
||||
try:
|
||||
self.ms.delete_agent(agent_id=agent_id)
|
||||
self.ms.delete_agent(agent_id=agent_id, per_agent_lock_manager=self.per_agent_lock_manager)
|
||||
except Exception as e:
|
||||
logger.exception(f"Failed to delete agent {agent_id} via ID with:\n{str(e)}")
|
||||
raise ValueError(f"Failed to delete agent {agent_id} in database")
|
||||
|
||||
18
letta/services/per_agent_lock_manager.py
Normal file
18
letta/services/per_agent_lock_manager.py
Normal file
@@ -0,0 +1,18 @@
|
||||
import asyncio
|
||||
from collections import defaultdict
|
||||
|
||||
|
||||
class PerAgentLockManager:
|
||||
"""Manages per-agent locks."""
|
||||
|
||||
def __init__(self):
|
||||
self.locks = defaultdict(asyncio.Lock)
|
||||
|
||||
def get_lock(self, agent_id: str) -> asyncio.Lock:
|
||||
"""Retrieve the lock for a specific agent_id."""
|
||||
return self.locks[agent_id]
|
||||
|
||||
def clear_lock(self, agent_id: str):
|
||||
"""Optionally remove a lock if no longer needed (to prevent unbounded growth)."""
|
||||
if agent_id in self.locks:
|
||||
del self.locks[agent_id]
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -295,3 +296,45 @@ def test_update_agent_memory_limit(client: Union[LocalClient, RESTClient], agent
|
||||
|
||||
finally:
|
||||
client.delete_agent(agent.id)
|
||||
|
||||
|
||||
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
assert send_message_response, "Sending message failed"
|
||||
|
||||
messages_response = client.get_messages(agent_id=agent.id, limit=1)
|
||||
assert len(messages_response) > 0, "Retrieving messages failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_parallel(client: Union[LocalClient, RESTClient], agent: AgentState, request):
|
||||
"""
|
||||
Test that sending two messages in parallel does not error.
|
||||
"""
|
||||
if not isinstance(client, RESTClient):
|
||||
pytest.skip("This test only runs when the server is enabled")
|
||||
|
||||
# Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
|
||||
async def send_message_task(message: str):
|
||||
response = await asyncio.to_thread(client.send_message, agent.id, message, role="user")
|
||||
assert response, f"Sending message '{message}' failed"
|
||||
return response
|
||||
|
||||
# Prepare two tasks with different messages
|
||||
messages = ["Test message 1", "Test message 2"]
|
||||
tasks = [send_message_task(message) for message in messages]
|
||||
|
||||
# Run the tasks concurrently
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check for exceptions and validate responses
|
||||
for i, response in enumerate(responses):
|
||||
if isinstance(response, Exception):
|
||||
pytest.fail(f"Task {i} failed with exception: {response}")
|
||||
else:
|
||||
assert response, f"Task {i} returned an invalid response: {response}"
|
||||
|
||||
# Ensure both tasks completed
|
||||
assert len(responses) == len(messages), "Not all messages were processed"
|
||||
|
||||
@@ -223,16 +223,6 @@ def test_core_memory(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
assert "Timber" in memory.get_block("human").value, f"Updating core memory failed: {memory.get_block('human').value}"
|
||||
|
||||
|
||||
def test_messages(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
send_message_response = client.send_message(agent_id=agent.id, message="Test message", role="user")
|
||||
assert send_message_response, "Sending message failed"
|
||||
|
||||
messages_response = client.get_messages(agent_id=agent.id, limit=1)
|
||||
assert len(messages_response) > 0, "Retrieving messages failed"
|
||||
|
||||
|
||||
def test_streaming_send_message(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
if isinstance(client, LocalClient):
|
||||
pytest.skip("Skipping test_streaming_send_message because LocalClient does not support streaming")
|
||||
|
||||
Reference in New Issue
Block a user