feat: Make multi agent broadcast directly invoke step (#1355)
This commit is contained in:
@@ -522,7 +522,7 @@ class Agent(BaseAgent):
|
||||
openai_message_dict=response_message.model_dump(),
|
||||
)
|
||||
) # extend conversation with assistant's reply
|
||||
self.logger.info(f"Function call message: {messages[-1]}")
|
||||
self.logger.debug(f"Function call message: {messages[-1]}")
|
||||
|
||||
nonnull_content = False
|
||||
if response_message.content:
|
||||
@@ -786,6 +786,7 @@ class Agent(BaseAgent):
|
||||
total_usage = UsageStatistics()
|
||||
step_count = 0
|
||||
function_failed = False
|
||||
steps_messages = []
|
||||
while True:
|
||||
kwargs["first_message"] = False
|
||||
kwargs["step_count"] = step_count
|
||||
@@ -800,6 +801,7 @@ class Agent(BaseAgent):
|
||||
function_failed = step_response.function_failed
|
||||
token_warning = step_response.in_context_memory_warning
|
||||
usage = step_response.usage
|
||||
steps_messages.append(step_response.messages)
|
||||
|
||||
step_count += 1
|
||||
total_usage += usage
|
||||
@@ -859,9 +861,9 @@ class Agent(BaseAgent):
|
||||
break
|
||||
|
||||
if self.agent_state.message_buffer_autoclear:
|
||||
self.agent_manager.trim_all_in_context_messages_except_system(self.agent_state.id, actor=self.user)
|
||||
self.agent_state = self.agent_manager.trim_all_in_context_messages_except_system(self.agent_state.id, actor=self.user)
|
||||
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count)
|
||||
return LettaUsageStatistics(**total_usage.model_dump(), step_count=step_count, steps_messages=steps_messages)
|
||||
|
||||
def inner_step(
|
||||
self,
|
||||
|
||||
@@ -274,7 +274,7 @@ class VoiceAgent(BaseAgent):
|
||||
|
||||
diff = united_diff(curr_system_message_text, new_system_message_str)
|
||||
if len(diff) > 0:
|
||||
logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
new_system_message = self.message_manager.update_message_by_id(
|
||||
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
|
||||
|
||||
@@ -1,16 +1,19 @@
|
||||
import asyncio
|
||||
import json
|
||||
from concurrent.futures import ThreadPoolExecutor, as_completed
|
||||
from typing import TYPE_CHECKING, List
|
||||
|
||||
from letta.functions.helpers import (
|
||||
_send_message_to_agents_matching_tags_async,
|
||||
_send_message_to_all_agents_in_group_async,
|
||||
execute_send_message_to_agent,
|
||||
extract_send_message_from_steps_messages,
|
||||
fire_and_forget_send_to_agent,
|
||||
)
|
||||
from letta.helpers.message_helper import prepare_input_message_create
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.message import MessageCreate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.utils import log_telemetry
|
||||
from letta.settings import settings
|
||||
|
||||
if TYPE_CHECKING:
|
||||
from letta.agent import Agent
|
||||
@@ -87,51 +90,59 @@ def send_message_to_agents_matching_tags(self: "Agent", message: str, match_all:
|
||||
response corresponds to a single agent. Agents that do not respond will not have an entry
|
||||
in the returned list.
|
||||
"""
|
||||
log_telemetry(
|
||||
self.logger,
|
||||
"_send_message_to_agents_matching_tags_async start",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
server = get_letta_server()
|
||||
|
||||
augmented_message = (
|
||||
f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, "
|
||||
f"[Incoming message from external Letta agent - to reply to this message, "
|
||||
f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] "
|
||||
f"{message}"
|
||||
)
|
||||
|
||||
# Retrieve up to 100 matching agents
|
||||
log_telemetry(
|
||||
self.logger,
|
||||
"_send_message_to_agents_matching_tags_async listing agents start",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
# Find matching agents
|
||||
matching_agents = server.agent_manager.list_agents_matching_tags(actor=self.user, match_all=match_all, match_some=match_some)
|
||||
if not matching_agents:
|
||||
return []
|
||||
|
||||
log_telemetry(
|
||||
self.logger,
|
||||
"_send_message_to_agents_matching_tags_async listing agents finish",
|
||||
message=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
def process_agent(agent_id: str) -> str:
|
||||
"""Loads an agent, formats the message, and executes .step()"""
|
||||
actor = self.user # Ensure correct actor context
|
||||
agent = server.load_agent(agent_id=agent_id, interface=None, actor=actor)
|
||||
|
||||
# Create a system message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
|
||||
# Prepare the message
|
||||
messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)]
|
||||
input_messages = [prepare_input_message_create(m, agent_id) for m in messages]
|
||||
|
||||
result = asyncio.run(_send_message_to_agents_matching_tags_async(self, server, messages, matching_agents))
|
||||
log_telemetry(
|
||||
self.logger,
|
||||
"_send_message_to_agents_matching_tags_async finish",
|
||||
messages=message,
|
||||
match_all=match_all,
|
||||
match_some=match_some,
|
||||
)
|
||||
return result
|
||||
# Run .step() and return the response
|
||||
usage_stats = agent.step(
|
||||
messages=input_messages,
|
||||
chaining=True,
|
||||
max_chaining_steps=None,
|
||||
stream=False,
|
||||
skip_verify=True,
|
||||
metadata=None,
|
||||
put_inner_thoughts_first=True,
|
||||
)
|
||||
|
||||
send_messages = extract_send_message_from_steps_messages(usage_stats.steps_messages, logger=agent.logger)
|
||||
response_data = {
|
||||
"agent_id": agent_id,
|
||||
"response_messages": send_messages if send_messages else ["<no response>"],
|
||||
}
|
||||
|
||||
return json.dumps(response_data, indent=2)
|
||||
|
||||
# Use ThreadPoolExecutor for parallel execution
|
||||
results = []
|
||||
with ThreadPoolExecutor(max_workers=settings.multi_agent_concurrent_sends) as executor:
|
||||
future_to_agent = {executor.submit(process_agent, agent_state.id): agent_state for agent_state in matching_agents}
|
||||
|
||||
for future in as_completed(future_to_agent):
|
||||
try:
|
||||
results.append(future.result()) # Collect results
|
||||
except Exception as e:
|
||||
# Log or handle failure for specific agents if needed
|
||||
self.logger.exception(f"Error processing agent {future_to_agent[future]}: {e}")
|
||||
|
||||
return results
|
||||
|
||||
|
||||
def send_message_to_all_agents_in_group(self: "Agent", message: str) -> List[str]:
|
||||
|
||||
@@ -1,4 +1,6 @@
|
||||
import asyncio
|
||||
import json
|
||||
import logging
|
||||
import threading
|
||||
from random import uniform
|
||||
from typing import Any, Dict, List, Optional, Type, Union
|
||||
@@ -17,7 +19,6 @@ from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.settings import settings
|
||||
from letta.utils import log_telemetry
|
||||
|
||||
|
||||
# TODO: This is kind of hacky, as this is used to search up the action later on composio's side
|
||||
@@ -386,15 +387,9 @@ async def async_send_message_with_retries(
|
||||
logging_prefix: Optional[str] = None,
|
||||
) -> str:
|
||||
logging_prefix = logging_prefix or "[async_send_message_with_retries]"
|
||||
log_telemetry(sender_agent.logger, f"async_send_message_with_retries start", target_agent_id=target_agent_id)
|
||||
|
||||
for attempt in range(1, max_retries + 1):
|
||||
try:
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
f"async_send_message_with_retries -> asyncio wait for send_message_to_agent_no_stream start",
|
||||
target_agent_id=target_agent_id,
|
||||
)
|
||||
response = await asyncio.wait_for(
|
||||
send_message_to_agent_no_stream(
|
||||
server=server,
|
||||
@@ -404,24 +399,15 @@ async def async_send_message_with_retries(
|
||||
),
|
||||
timeout=timeout,
|
||||
)
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
f"async_send_message_with_retries -> asyncio wait for send_message_to_agent_no_stream finish",
|
||||
target_agent_id=target_agent_id,
|
||||
)
|
||||
|
||||
# Then parse out the assistant message
|
||||
assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response)
|
||||
if assistant_message:
|
||||
sender_agent.logger.info(f"{logging_prefix} - {assistant_message}")
|
||||
log_telemetry(
|
||||
sender_agent.logger, f"async_send_message_with_retries finish with assistant message", target_agent_id=target_agent_id
|
||||
)
|
||||
return assistant_message
|
||||
else:
|
||||
msg = f"(No response from agent {target_agent_id})"
|
||||
sender_agent.logger.info(f"{logging_prefix} - {msg}")
|
||||
log_telemetry(sender_agent.logger, f"async_send_message_with_retries finish no response", target_agent_id=target_agent_id)
|
||||
return msg
|
||||
|
||||
except asyncio.TimeoutError:
|
||||
@@ -439,12 +425,6 @@ async def async_send_message_with_retries(
|
||||
await asyncio.sleep(backoff)
|
||||
else:
|
||||
sender_agent.logger.error(f"{logging_prefix} - Fatal error: {error_msg}")
|
||||
log_telemetry(
|
||||
sender_agent.logger,
|
||||
f"async_send_message_with_retries finish fatal error",
|
||||
target_agent_id=target_agent_id,
|
||||
error_msg=error_msg,
|
||||
)
|
||||
raise Exception(error_msg)
|
||||
|
||||
|
||||
@@ -673,3 +653,27 @@ def _get_field_type(field_schema: Dict[str, Any], nested_models: Dict[str, Type[
|
||||
else:
|
||||
return Union[tuple(types)]
|
||||
raise ValueError(f"Unable to convert pydantic field schema to type: {field_schema}")
|
||||
|
||||
|
||||
def extract_send_message_from_steps_messages(
|
||||
steps_messages: List[List[Message]],
|
||||
agent_send_message_tool_name: str = DEFAULT_MESSAGE_TOOL,
|
||||
agent_send_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
logger: Optional[logging.Logger] = None,
|
||||
) -> List[str]:
|
||||
extracted_messages = []
|
||||
|
||||
for step in steps_messages:
|
||||
for message in step:
|
||||
if message.tool_calls:
|
||||
for tool_call in message.tool_calls:
|
||||
if tool_call.function.name == agent_send_message_tool_name:
|
||||
try:
|
||||
# Parse arguments to extract the "message" field
|
||||
arguments = json.loads(tool_call.function.arguments)
|
||||
if agent_send_message_tool_kwarg in arguments:
|
||||
extracted_messages.append(arguments[agent_send_message_tool_kwarg])
|
||||
except json.JSONDecodeError:
|
||||
logger.error(f"Failed to parse arguments for tool call: {tool_call.id}")
|
||||
|
||||
return extracted_messages
|
||||
|
||||
41
letta/helpers/message_helper.py
Normal file
41
letta/helpers/message_helper.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from letta import system
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
|
||||
|
||||
def prepare_input_message_create(
|
||||
message: MessageCreate,
|
||||
agent_id: str,
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
) -> Message:
|
||||
"""Converts a MessageCreate object into a Message object, applying wrapping if needed."""
|
||||
# TODO: This seems like extra boilerplate with little benefit
|
||||
assert isinstance(message, MessageCreate)
|
||||
|
||||
# Extract message content
|
||||
if isinstance(message.content, str):
|
||||
message_content = message.content
|
||||
elif message.content and len(message.content) > 0 and isinstance(message.content[0], TextContent):
|
||||
message_content = message.content[0].text
|
||||
else:
|
||||
raise ValueError("Message content is empty or invalid")
|
||||
|
||||
# Apply wrapping if needed
|
||||
if message.role == MessageRole.user and wrap_user_message:
|
||||
message_content = system.package_user_message(user_message=message_content)
|
||||
elif message.role == MessageRole.system and wrap_system_message:
|
||||
message_content = system.package_system_message(system_message=message_content)
|
||||
elif message.role not in {MessageRole.user, MessageRole.system}:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
return Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message_content)] if message_content else [],
|
||||
name=message.name,
|
||||
model=None, # assigned later?
|
||||
tool_calls=None, # irrelevant
|
||||
tool_call_id=None,
|
||||
)
|
||||
@@ -361,14 +361,12 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
if identifier_set != results_set:
|
||||
# Construct a detailed error message based on query conditions
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
logger.warning(
|
||||
f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}"
|
||||
)
|
||||
logger.debug(f"{cls.__name__} not found with {conditions_str}. Queried ids: {identifier_set}, Found ids: {results_set}")
|
||||
return results
|
||||
|
||||
# Construct a detailed error message based on query conditions
|
||||
conditions_str = ", ".join(query_conditions) if query_conditions else "no specific conditions"
|
||||
logger.warning(f"{cls.__name__} not found with {conditions_str}")
|
||||
logger.debug(f"{cls.__name__} not found with {conditions_str}")
|
||||
return []
|
||||
|
||||
@handle_db_timeout
|
||||
|
||||
@@ -1,7 +1,9 @@
|
||||
from typing import Literal
|
||||
from typing import List, Literal, Optional
|
||||
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from letta.schemas.message import Message
|
||||
|
||||
|
||||
class LettaUsageStatistics(BaseModel):
|
||||
"""
|
||||
@@ -19,3 +21,5 @@ class LettaUsageStatistics(BaseModel):
|
||||
prompt_tokens: int = Field(0, description="The number of tokens in the prompt.")
|
||||
total_tokens: int = Field(0, description="The total number of tokens processed by the agent.")
|
||||
step_count: int = Field(0, description="The number of steps taken by the agent.")
|
||||
# TODO: Optional for now. This field makes everyone's lives easier
|
||||
steps_messages: Optional[List[List[Message]]] = Field(None, description="The messages generated per step")
|
||||
|
||||
@@ -26,6 +26,7 @@ from letta.functions.mcp_client.stdio_client import StdioMCPClient
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.helpers.message_helper import prepare_input_message_create
|
||||
|
||||
# TODO use custom interface
|
||||
from letta.interface import AgentInterface # abstract
|
||||
@@ -48,7 +49,7 @@ from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.letta_response import LettaResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ArchivalMemorySummary, ContextWindowOverview, Memory, RecallMemorySummary
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage, PassageUpdate
|
||||
from letta.schemas.providers import (
|
||||
@@ -85,7 +86,6 @@ from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
@@ -210,9 +210,6 @@ class SyncServer(Server):
|
||||
self.identity_manager = IdentityManager()
|
||||
self.group_manager = GroupManager()
|
||||
|
||||
# Managers that interface with parallelism
|
||||
self.per_agent_lock_manager = PerAgentLockManager()
|
||||
|
||||
# Make default user and org
|
||||
if init_with_default_org_and_user:
|
||||
self.default_org = self.organization_manager.create_default_organization()
|
||||
@@ -353,21 +350,19 @@ class SyncServer(Server):
|
||||
|
||||
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
|
||||
with agent_lock:
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.multi_agent_group:
|
||||
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
if agent_state.multi_agent_group:
|
||||
return self.load_multi_agent(agent_state.multi_agent_group, actor, interface, agent_state)
|
||||
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
agent = Agent(agent_state=agent_state, interface=interface, user=actor, mcp_clients=self.mcp_clients)
|
||||
elif agent_state.agent_type == AgentType.offline_memory_agent:
|
||||
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
agent = Agent(agent_state=agent_state, interface=interface, user=actor, mcp_clients=self.mcp_clients)
|
||||
elif agent_state.agent_type == AgentType.offline_memory_agent:
|
||||
agent = OfflineMemoryAgent(agent_state=agent_state, interface=interface, user=actor)
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
|
||||
|
||||
return agent
|
||||
return agent
|
||||
|
||||
def load_multi_agent(
|
||||
self, group: Group, actor: User, interface: Union[AgentInterface, None] = None, agent_state: Optional[AgentState] = None
|
||||
@@ -702,63 +697,22 @@ class SyncServer(Server):
|
||||
actor: User,
|
||||
agent_id: str,
|
||||
messages: Union[List[MessageCreate], List[Message]],
|
||||
# whether or not to wrap user and system message as MemGPT-style stringified JSON
|
||||
wrap_user_message: bool = True,
|
||||
wrap_system_message: bool = True,
|
||||
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed to getting responses
|
||||
interface: Union[AgentInterface, ChatCompletionsStreamingInterface, None] = None, # needed for responses
|
||||
metadata: Optional[dict] = None, # Pass through metadata to interface
|
||||
put_inner_thoughts_first: bool = True,
|
||||
) -> LettaUsageStatistics:
|
||||
"""Send a list of messages to the agent
|
||||
"""Send a list of messages to the agent.
|
||||
|
||||
If the messages are of type MessageCreate, we need to turn them into
|
||||
Message objects first before sending them through step.
|
||||
|
||||
Otherwise, we can pass them in directly.
|
||||
If messages are of type MessageCreate, convert them to Message objects before sending.
|
||||
"""
|
||||
message_objects: List[Message] = []
|
||||
|
||||
if all(isinstance(m, MessageCreate) for m in messages):
|
||||
for message in messages:
|
||||
assert isinstance(message, MessageCreate)
|
||||
|
||||
# If wrapping is enabled, wrap with metadata before placing content inside the Message object
|
||||
if isinstance(message.content, str):
|
||||
message_content = message.content
|
||||
elif message.content and len(message.content) > 0 and isinstance(message.content[0], TextContent):
|
||||
message_content = message.content[0].text
|
||||
else:
|
||||
assert message_content is not None, "Message content is empty"
|
||||
|
||||
if message.role == MessageRole.user and wrap_user_message:
|
||||
message_content = system.package_user_message(user_message=message_content)
|
||||
elif message.role == MessageRole.system and wrap_system_message:
|
||||
message_content = system.package_system_message(system_message=message_content)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
# Create the Message object
|
||||
message_objects.append(
|
||||
Message(
|
||||
agent_id=agent_id,
|
||||
role=message.role,
|
||||
content=[TextContent(text=message_content)] if message_content else [],
|
||||
name=message.name,
|
||||
# assigned later?
|
||||
model=None,
|
||||
# irrelevant
|
||||
tool_calls=None,
|
||||
tool_call_id=None,
|
||||
)
|
||||
)
|
||||
|
||||
message_objects = [prepare_input_message_create(m, agent_id, wrap_user_message, wrap_system_message) for m in messages]
|
||||
elif all(isinstance(m, Message) for m in messages):
|
||||
for message in messages:
|
||||
assert isinstance(message, Message)
|
||||
message_objects.append(message)
|
||||
|
||||
message_objects = messages
|
||||
else:
|
||||
raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(message) for message in messages]}")
|
||||
raise ValueError(f"All messages must be of type Message or MessageCreate, got {[type(m) for m in messages]}")
|
||||
|
||||
# Store metadata in interface if provided
|
||||
if metadata and hasattr(interface, "metadata"):
|
||||
|
||||
@@ -639,7 +639,7 @@ class AgentManager:
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
if len(diff) > 0: # there was a diff
|
||||
logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Swap the system message out (only if there is a diff)
|
||||
message = PydanticMessage.dict_to_message(
|
||||
|
||||
@@ -47,7 +47,7 @@ def retry_until_threshold(threshold=0.5, max_attempts=10, sleep_time_seconds=4):
|
||||
return decorator_retry
|
||||
|
||||
|
||||
def retry_until_success(max_attempts=10, sleep_time_seconds=4):
|
||||
def retry_until_success(max_attempts=10, sleep_time_seconds=4, flush_tables_in_between: bool = False):
|
||||
"""
|
||||
Decorator to retry a function until it succeeds or the maximum number of attempts is reached.
|
||||
|
||||
@@ -58,13 +58,25 @@ def retry_until_success(max_attempts=10, sleep_time_seconds=4):
|
||||
def decorator_retry(func):
|
||||
@functools.wraps(func)
|
||||
def wrapper(*args, **kwargs):
|
||||
from letta.orm.base import Base
|
||||
from letta.server.db import db_context
|
||||
|
||||
for attempt in range(1, max_attempts + 1):
|
||||
try:
|
||||
return func(*args, **kwargs)
|
||||
except Exception as e:
|
||||
print(f"\033[93mAttempt {attempt} failed with error:\n{e}\033[0m")
|
||||
|
||||
# Clear tables before retrying
|
||||
if flush_tables_in_between:
|
||||
with db_context() as session:
|
||||
for table in reversed(Base.metadata.sorted_tables): # Reverse to avoid FK issues
|
||||
session.execute(table.delete()) # Truncate table
|
||||
session.commit()
|
||||
|
||||
if attempt == max_attempts:
|
||||
raise
|
||||
|
||||
time.sleep(sleep_time_seconds)
|
||||
|
||||
return wrapper
|
||||
|
||||
@@ -10,12 +10,11 @@ from letta.schemas.letta_message import SystemMessage, ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import ChatMemory
|
||||
from letta.schemas.tool import Tool
|
||||
from tests.helpers.utils import retry_until_success
|
||||
from tests.utils import wait_for_incoming_message
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables():
|
||||
def truncate_database():
|
||||
from letta.server.db import db_context
|
||||
|
||||
with db_context() as session:
|
||||
@@ -86,7 +85,6 @@ def roll_dice_tool(client):
|
||||
yield tool
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
secret_word = "banana"
|
||||
|
||||
@@ -125,7 +123,6 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj):
|
||||
print(response.messages)
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_simple(client):
|
||||
worker_tags_123 = ["worker", "user-123"]
|
||||
worker_tags_456 = ["worker", "user-456"]
|
||||
@@ -141,20 +138,20 @@ def test_send_message_to_agents_with_tags_simple(client):
|
||||
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
manager_agent_state = client.create_agent(tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent_state = client.create_agent(name="manager_agent", tool_ids=[send_message_to_agents_matching_tags_tool_id])
|
||||
manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user)
|
||||
|
||||
# Create 3 non-matching worker agents (These should NOT get the message)
|
||||
worker_agents_123 = []
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_123)
|
||||
for idx in range(2):
|
||||
worker_agent_state = client.create_agent(name=f"not_worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_123)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents_123.append(worker_agent)
|
||||
|
||||
# Create 3 worker agents that should get the message
|
||||
worker_agents_456 = []
|
||||
for _ in range(3):
|
||||
worker_agent_state = client.create_agent(include_multi_agent_tools=False, tags=worker_tags_456)
|
||||
for idx in range(2):
|
||||
worker_agent_state = client.create_agent(name=f"worker_{idx}", include_multi_agent_tools=False, tags=worker_tags_456)
|
||||
worker_agent = client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user)
|
||||
worker_agents_456.append(worker_agent)
|
||||
|
||||
@@ -203,7 +200,6 @@ def test_send_message_to_agents_with_tags_simple(client):
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_tool):
|
||||
worker_tags = ["dice-rollers"]
|
||||
|
||||
@@ -252,7 +248,6 @@ def test_send_message_to_agents_with_tags_complex_tool_use(client, roll_dice_too
|
||||
client.delete_agent(agent.agent_state.id)
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
|
||||
# Create "manager" agent
|
||||
send_message_to_agents_matching_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_tags")
|
||||
@@ -285,7 +280,6 @@ def test_send_message_to_sub_agents_auto_clear_message_buffer(client):
|
||||
assert "banana" in worker_agent_state.memory.compile().lower()
|
||||
|
||||
|
||||
@retry_until_success(max_attempts=3, sleep_time_seconds=2)
|
||||
def test_agents_async_simple(client):
|
||||
"""
|
||||
Test two agents with multi-agent tools sending messages back and forth to count to 5.
|
||||
|
||||
Reference in New Issue
Block a user