diff --git a/alembic/versions/dfafcf8210ca_add_model_endpoint_to_steps_table.py b/alembic/versions/dfafcf8210ca_add_model_endpoint_to_steps_table.py new file mode 100644 index 00000000..df3b4278 --- /dev/null +++ b/alembic/versions/dfafcf8210ca_add_model_endpoint_to_steps_table.py @@ -0,0 +1,31 @@ +"""add model endpoint to steps table + +Revision ID: dfafcf8210ca +Revises: f922ca16e42c +Create Date: 2025-02-04 16:45:34.132083 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "dfafcf8210ca" +down_revision: Union[str, None] = "f922ca16e42c" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.add_column("steps", sa.Column("model_endpoint", sa.String(), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.drop_column("steps", "model_endpoint") + # ### end Alembic commands ### diff --git a/letta/agent.py b/letta/agent.py index 4284e86f..bb082fbf 100644 --- a/letta/agent.py +++ b/letta/agent.py @@ -505,8 +505,9 @@ class Agent(BaseAgent): function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool) if sandbox_run_result and sandbox_run_result.status == "error": - error_msg = f"Error calling function {function_name} with args {function_args}: {sandbox_run_result.stderr}" - messages = self._handle_function_error_response(error_msg, tool_call_id, function_name, function_response, messages) + messages = self._handle_function_error_response( + function_response, tool_call_id, function_name, function_response, messages + ) return messages, False, True # force a heartbeat to allow agent to handle error # handle trunction @@ -790,6 +791,7 @@ class Agent(BaseAgent): actor=self.user, provider_name=self.agent_state.llm_config.model_endpoint_type, model=self.agent_state.llm_config.model, + model_endpoint=self.agent_state.llm_config.model_endpoint, context_window_limit=self.agent_state.llm_config.context_window, usage=response.usage, # TODO(@caren): Add full provider support - this line is a workaround for v0 BYOK feature diff --git a/letta/client/client.py b/letta/client/client.py index 413e6b64..485cc6f9 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -463,7 +463,7 @@ class RESTClient(AbstractClient): if token: self.headers = {"accept": "application/json", "Authorization": f"Bearer {token}"} elif password: - self.headers = {"accept": "application/json", "X-BARE-PASSWORD": f"password {password}"} + self.headers = {"accept": "application/json", "Authorization": f"Bearer {password}"} else: self.headers = {"accept": "application/json"} if headers: diff --git a/letta/constants.py b/letta/constants.py index ea42306a..1269dc84 100644 --- a/letta/constants.py +++ b/letta/constants.py @@ -53,6 
+53,7 @@ BASE_MEMORY_TOOLS = ["core_memory_append", "core_memory_replace"] MULTI_AGENT_TOOLS = ["send_message_to_agent_and_wait_for_reply", "send_message_to_agents_matching_all_tags", "send_message_to_agent_async"] MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES = 3 MULTI_AGENT_SEND_MESSAGE_TIMEOUT = 20 * 60 +MULTI_AGENT_CONCURRENT_SENDS = 15 # The name of the tool used to send message to the user # May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...) diff --git a/letta/functions/function_sets/multi_agent.py b/letta/functions/function_sets/multi_agent.py index ef607713..bd8f7a94 100644 --- a/letta/functions/function_sets/multi_agent.py +++ b/letta/functions/function_sets/multi_agent.py @@ -1,11 +1,13 @@ import asyncio from typing import TYPE_CHECKING, List -from letta.constants import MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT -from letta.functions.helpers import async_send_message_with_retries, execute_send_message_to_agent, fire_and_forget_send_to_agent +from letta.functions.helpers import ( + _send_message_to_agents_matching_all_tags_async, + execute_send_message_to_agent, + fire_and_forget_send_to_agent, +) from letta.schemas.enums import MessageRole from letta.schemas.message import MessageCreate -from letta.server.rest_api.utils import get_letta_server if TYPE_CHECKING: from letta.agent import Agent @@ -22,12 +24,13 @@ def send_message_to_agent_and_wait_for_reply(self: "Agent", message: str, other_ Returns: str: The response from the target agent. 
""" - message = ( + augmented_message = ( f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, " f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " f"{message}" ) - messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)] + messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=self.agent_state.name)] + return execute_send_message_to_agent( sender_agent=self, messages=messages, @@ -81,33 +84,4 @@ def send_message_to_agents_matching_all_tags(self: "Agent", message: str, tags: have an entry in the returned list. """ - server = get_letta_server() - - message = ( - f"[Incoming message from agent with ID '{self.agent_state.id}' - to reply to this message, " - f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " - f"{message}" - ) - - # Retrieve agents that match ALL specified tags - matching_agents = server.agent_manager.list_agents(actor=self.user, tags=tags, match_all_tags=True, limit=100) - messages = [MessageCreate(role=MessageRole.system, content=message, name=self.agent_state.name)] - - async def send_messages_to_all_agents(): - tasks = [ - async_send_message_with_retries( - server=server, - sender_agent=self, - target_agent_id=agent_state.id, - messages=messages, - max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, - timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, - logging_prefix="[send_message_to_agents_matching_all_tags]", - ) - for agent_state in matching_agents - ] - # Run all tasks in parallel - return await asyncio.gather(*tasks) - - # Run the async function and return results - return asyncio.run(send_messages_to_all_agents()) + return asyncio.run(_send_message_to_agents_matching_all_tags_async(self, message, tags)) diff --git a/letta/functions/helpers.py b/letta/functions/helpers.py index 8c232cd5..fe179e4a 100644 --- a/letta/functions/helpers.py 
+++ b/letta/functions/helpers.py @@ -1,5 +1,4 @@ import asyncio -import json import threading from random import uniform from typing import Any, List, Optional, Union @@ -12,13 +11,17 @@ from letta.constants import ( COMPOSIO_ENTITY_ENV_VAR_KEY, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, + MULTI_AGENT_CONCURRENT_SENDS, MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, MULTI_AGENT_SEND_MESSAGE_TIMEOUT, ) +from letta.functions.interface import MultiAgentMessagingInterface from letta.orm.errors import NoResultFound -from letta.schemas.letta_message import AssistantMessage, ReasoningMessage, ToolCallMessage +from letta.schemas.enums import MessageRole +from letta.schemas.letta_message import AssistantMessage from letta.schemas.letta_response import LettaResponse -from letta.schemas.message import MessageCreate +from letta.schemas.message import Message, MessageCreate +from letta.schemas.user import User from letta.server.rest_api.utils import get_letta_server @@ -249,29 +252,48 @@ def generate_import_code(module_attr_map: Optional[dict]): def parse_letta_response_for_assistant_message( target_agent_id: str, letta_response: LettaResponse, - assistant_message_tool_name: str = DEFAULT_MESSAGE_TOOL, - assistant_message_tool_kwarg: str = DEFAULT_MESSAGE_TOOL_KWARG, ) -> Optional[str]: messages = [] - # This is not ideal, but we would like to return something rather than nothing - fallback_reasoning = [] for m in letta_response.messages: if isinstance(m, AssistantMessage): messages.append(m.content) - elif isinstance(m, ToolCallMessage) and m.tool_call.name == assistant_message_tool_name: - try: - messages.append(json.loads(m.tool_call.arguments)[assistant_message_tool_kwarg]) - except Exception: # TODO: Make this more specific - continue - elif isinstance(m, ReasoningMessage): - fallback_reasoning.append(m.reasoning) if messages: messages_str = "\n".join(messages) - return f"Agent {target_agent_id} said: '{messages_str}'" + return f"{target_agent_id} said: '{messages_str}'" 
else: - messages_str = "\n".join(fallback_reasoning) - return f"Agent {target_agent_id}'s inner thoughts: '{messages_str}'" + return f"No response from {target_agent_id}" + + +async def async_execute_send_message_to_agent( + sender_agent: "Agent", + messages: List[MessageCreate], + other_agent_id: str, + log_prefix: str, +) -> Optional[str]: + """ + Async helper to: + 1) validate the target agent exists & is in the same org, + 2) send a message via async_send_message_with_retries. + """ + server = get_letta_server() + + # 1. Validate target agent + try: + server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user) + except NoResultFound: + raise ValueError(f"Target agent {other_agent_id} either does not exist or is not in org " f"({sender_agent.user.organization_id}).") + + # 2. Use your async retry logic + return await async_send_message_with_retries( + server=server, + sender_agent=sender_agent, + target_agent_id=other_agent_id, + messages=messages, + max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, + timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, + logging_prefix=log_prefix, + ) def execute_send_message_to_agent( @@ -281,53 +303,43 @@ def execute_send_message_to_agent( log_prefix: str, ) -> Optional[str]: """ - Helper function to send a message to a specific Letta agent. - - Args: - sender_agent ("Agent"): The sender agent object. - message (str): The message to send. - other_agent_id (str): The identifier of the target Letta agent. - log_prefix (str): Logging prefix for retries. - - Returns: - Optional[str]: The response from the Letta agent if required by the caller. + Synchronous wrapper that calls `async_execute_send_message_to_agent` using asyncio.run. + This function must be called from a synchronous context (i.e., no running event loop). 
""" - server = get_letta_server() + return asyncio.run(async_execute_send_message_to_agent(sender_agent, messages, other_agent_id, log_prefix)) - # Ensure the target agent is in the same org - try: - server.agent_manager.get_agent_by_id(agent_id=other_agent_id, actor=sender_agent.user) - except NoResultFound: - raise ValueError( - f"The passed-in agent_id {other_agent_id} either does not exist, " - f"or does not belong to the same org ({sender_agent.user.organization_id})." - ) - # Async logic to send a message with retries and timeout - async def async_send(): - return await async_send_message_with_retries( - server=server, - sender_agent=sender_agent, - target_agent_id=other_agent_id, - messages=messages, - max_retries=MULTI_AGENT_SEND_MESSAGE_MAX_RETRIES, - timeout=MULTI_AGENT_SEND_MESSAGE_TIMEOUT, - logging_prefix=log_prefix, - ) +async def send_message_to_agent_no_stream( + server: "SyncServer", + agent_id: str, + actor: User, + messages: Union[List[Message], List[MessageCreate]], + metadata: Optional[dict] = None, +) -> LettaResponse: + """ + A simpler helper to send messages to a single agent WITHOUT streaming. + Returns a LettaResponse containing the final messages. 
+ """ + interface = MultiAgentMessagingInterface() + if metadata: + interface.metadata = metadata - # Run in the current event loop or create one if needed - try: - return asyncio.run(async_send()) - except RuntimeError: - loop = asyncio.get_event_loop() - if loop.is_running(): - return loop.run_until_complete(async_send()) - else: - raise + # Offload the synchronous `send_messages` call + usage_stats = await asyncio.to_thread( + server.send_messages, + actor=actor, + agent_id=agent_id, + messages=messages, + interface=interface, + metadata=metadata, + ) + + final_messages = interface.get_captured_send_messages() + return LettaResponse(messages=final_messages, usage=usage_stats) async def async_send_message_with_retries( - server, + server: "SyncServer", sender_agent: "Agent", target_agent_id: str, messages: List[MessageCreate], @@ -335,57 +347,34 @@ async def async_send_message_with_retries( timeout: int, logging_prefix: Optional[str] = None, ) -> str: - """ - Shared helper coroutine to send a message to an agent with retries and a timeout. - Args: - server: The Letta server instance (from get_letta_server()). - sender_agent (Agent): The agent initiating the send action. - target_agent_id (str): The ID of the agent to send the message to. - message_text (str): The text to send as the user message. - max_retries (int): Maximum number of retries for the request. - timeout (int): Maximum time to wait for a response (in seconds). - logging_prefix (str): A prefix to append to logging - Returns: - str: The response or an error message. 
- """ logging_prefix = logging_prefix or "[async_send_message_with_retries]" for attempt in range(1, max_retries + 1): try: - # Wrap in a timeout response = await asyncio.wait_for( - server.send_message_to_agent( + send_message_to_agent_no_stream( + server=server, agent_id=target_agent_id, actor=sender_agent.user, messages=messages, - stream_steps=False, - stream_tokens=False, - use_assistant_message=True, - assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, - assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, ), timeout=timeout, ) - # Extract assistant message - assistant_message = parse_letta_response_for_assistant_message( - target_agent_id, - response, - assistant_message_tool_name=DEFAULT_MESSAGE_TOOL, - assistant_message_tool_kwarg=DEFAULT_MESSAGE_TOOL_KWARG, - ) + # Then parse out the assistant message + assistant_message = parse_letta_response_for_assistant_message(target_agent_id, response) if assistant_message: sender_agent.logger.info(f"{logging_prefix} - {assistant_message}") return assistant_message else: msg = f"(No response from agent {target_agent_id})" sender_agent.logger.info(f"{logging_prefix} - {msg}") - sender_agent.logger.info(f"{logging_prefix} - raw response: {response.model_dump_json(indent=4)}") - sender_agent.logger.info(f"{logging_prefix} - parsed assistant message: {assistant_message}") return msg + except asyncio.TimeoutError: error_msg = f"(Timeout on attempt {attempt}/{max_retries} for agent {target_agent_id})" sender_agent.logger.warning(f"{logging_prefix} - {error_msg}") + except Exception as e: error_msg = f"(Error on attempt {attempt}/{max_retries} for agent {target_agent_id}: {e})" sender_agent.logger.warning(f"{logging_prefix} - {error_msg}") @@ -393,10 +382,10 @@ async def async_send_message_with_retries( # Exponential backoff before retrying if attempt < max_retries: backoff = uniform(0.5, 2) * (2**attempt) - sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent to agent send_message...sleeping for 
{backoff}") + sender_agent.logger.warning(f"{logging_prefix} - Retrying the agent-to-agent send_message...sleeping for {backoff}") await asyncio.sleep(backoff) else: - sender_agent.logger.error(f"{logging_prefix} - Fatal error during agent to agent send_message: {error_msg}") + sender_agent.logger.error(f"{logging_prefix} - Fatal error: {error_msg}") raise Exception(error_msg) @@ -482,3 +471,43 @@ def fire_and_forget_send_to_agent( except RuntimeError: # Means no event loop is running in this thread run_in_background_thread(background_task()) + + +async def _send_message_to_agents_matching_all_tags_async(sender_agent: "Agent", message: str, tags: List[str]) -> List[str]: + server = get_letta_server() + + augmented_message = ( + f"[Incoming message from agent with ID '{sender_agent.agent_state.id}' - to reply to this message, " + f"make sure to use the 'send_message' at the end, and the system will notify the sender of your response] " + f"{message}" + ) + + # Retrieve up to 100 matching agents + matching_agents = server.agent_manager.list_agents(actor=sender_agent.user, tags=tags, match_all_tags=True, limit=100) + + # Create a system message + messages = [MessageCreate(role=MessageRole.system, content=augmented_message, name=sender_agent.agent_state.name)] + + # Possibly limit concurrency to avoid meltdown: + sem = asyncio.Semaphore(MULTI_AGENT_CONCURRENT_SENDS) + + async def _send_single(agent_state): + async with sem: + return await async_send_message_with_retries( + server=server, + sender_agent=sender_agent, + target_agent_id=agent_state.id, + messages=messages, + max_retries=3, + timeout=30, + ) + + tasks = [asyncio.create_task(_send_single(agent_state)) for agent_state in matching_agents] + results = await asyncio.gather(*tasks, return_exceptions=True) + final = [] + for r in results: + if isinstance(r, Exception): + final.append(str(r)) + else: + final.append(r) + return final diff --git a/letta/functions/interface.py b/letta/functions/interface.py new file 
mode 100644 index 00000000..82bf229e --- /dev/null +++ b/letta/functions/interface.py @@ -0,0 +1,75 @@ +import json +from typing import List, Optional + +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.interface import AgentInterface +from letta.schemas.letta_message import AssistantMessage, LettaMessage +from letta.schemas.message import Message + + +class MultiAgentMessagingInterface(AgentInterface): + """ + A minimal interface that captures *only* calls to the 'send_message' function + by inspecting msg_obj.tool_calls. We parse out the 'message' field from the + JSON function arguments and store it as an AssistantMessage. + """ + + def __init__(self): + self._captured_messages: List[AssistantMessage] = [] + self.metadata = {} + + def internal_monologue(self, msg: str, msg_obj: Optional[Message] = None): + """Ignore internal monologue.""" + + def assistant_message(self, msg: str, msg_obj: Optional[Message] = None): + """Ignore normal assistant messages (only capturing send_message calls).""" + + def function_message(self, msg: str, msg_obj: Optional[Message] = None): + """ + Called whenever the agent logs a function call. 
We'll inspect msg_obj.tool_calls: + - If tool_calls include a function named 'send_message', parse its arguments + - Extract the 'message' field + - Save it as an AssistantMessage in self._captured_messages + """ + if not msg_obj or not msg_obj.tool_calls: + return + + for tool_call in msg_obj.tool_calls: + if not tool_call.function: + continue + if tool_call.function.name != DEFAULT_MESSAGE_TOOL: + # Skip any other function calls + continue + + # Now parse the JSON in tool_call.function.arguments + func_args_str = tool_call.function.arguments or "" + try: + data = json.loads(func_args_str) + # Extract the 'message' key if present + content = data.get(DEFAULT_MESSAGE_TOOL_KWARG, str(data)) + except json.JSONDecodeError: + # If we can't parse, store the raw string + content = func_args_str + + # Store as an AssistantMessage + new_msg = AssistantMessage( + id=msg_obj.id, + date=msg_obj.created_at, + content=content, + ) + self._captured_messages.append(new_msg) + + def user_message(self, msg: str, msg_obj: Optional[Message] = None): + """Ignore user messages.""" + + def step_complete(self): + """No streaming => no step boundaries.""" + + def step_yield(self): + """No streaming => no final yield needed.""" + + def get_captured_send_messages(self) -> List[LettaMessage]: + """ + Returns only the messages extracted from 'send_message' calls. 
+ """ + return self._captured_messages diff --git a/letta/orm/step.py b/letta/orm/step.py index 8ea5f313..e5c33347 100644 --- a/letta/orm/step.py +++ b/letta/orm/step.py @@ -35,6 +35,7 @@ class Step(SqlalchemyBase): ) provider_name: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the provider used for this step.") model: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The name of the model used for this step.") + model_endpoint: Mapped[Optional[str]] = mapped_column(None, nullable=True, doc="The model endpoint url used for this step.") context_window_limit: Mapped[Optional[int]] = mapped_column( None, nullable=True, doc="The context window limit configured for this step." ) diff --git a/letta/schemas/step.py b/letta/schemas/step.py index c3482878..98bc51c7 100644 --- a/letta/schemas/step.py +++ b/letta/schemas/step.py @@ -20,6 +20,7 @@ class Step(StepBase): ) provider_name: Optional[str] = Field(None, description="The name of the provider used for this step.") model: Optional[str] = Field(None, description="The name of the model used for this step.") + model_endpoint: Optional[str] = Field(None, description="The model endpoint url used for this step.") context_window_limit: Optional[int] = Field(None, description="The context window limit configured for this step.") completion_tokens: Optional[int] = Field(None, description="The number of tokens generated by the agent during this step.") prompt_tokens: Optional[int] = Field(None, description="The number of tokens in the prompt during this step.") diff --git a/letta/server/rest_api/interface.py b/letta/server/rest_api/interface.py index ded9d749..a9e617f7 100644 --- a/letta/server/rest_api/interface.py +++ b/letta/server/rest_api/interface.py @@ -315,7 +315,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): # extra prints self.debug = False - self.timeout = 30 + self.timeout = 10 * 60 # 10 minute timeout def _reset_inner_thoughts_json_reader(self): # A 
buffer for accumulating function arguments (we want to buffer keys and run checks on each one) @@ -330,7 +330,7 @@ class StreamingServerInterface(AgentChunkStreamingInterface): while self._active: try: # Wait until there is an item in the deque or the stream is deactivated - await asyncio.wait_for(self._event.wait(), timeout=self.timeout) # 30 second timeout + await asyncio.wait_for(self._event.wait(), timeout=self.timeout) except asyncio.TimeoutError: break # Exit the loop if we timeout diff --git a/letta/services/step_manager.py b/letta/services/step_manager.py index 49dbf316..a316eda6 100644 --- a/letta/services/step_manager.py +++ b/letta/services/step_manager.py @@ -55,6 +55,7 @@ class StepManager: actor: PydanticUser, provider_name: str, model: str, + model_endpoint: Optional[str], context_window_limit: int, usage: UsageStatistics, provider_id: Optional[str] = None, @@ -66,6 +67,7 @@ class StepManager: "provider_id": provider_id, "provider_name": provider_name, "model": model, + "model_endpoint": model_endpoint, "context_window_limit": context_window_limit, "completion_tokens": usage.completion_tokens, "prompt_tokens": usage.prompt_tokens, diff --git a/tests/manual_test_multi_agent_broadcast_large.py b/tests/manual_test_multi_agent_broadcast_large.py new file mode 100644 index 00000000..4adcfa07 --- /dev/null +++ b/tests/manual_test_multi_agent_broadcast_large.py @@ -0,0 +1,91 @@ +import json +import os + +import pytest +from tqdm import tqdm + +from letta import create_client +from letta.functions.functions import derive_openai_json_schema, parse_source_code +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.schemas.tool import Tool +from tests.integration_test_summarizer import LLM_CONFIG_DIR + + +@pytest.fixture(scope="function") +def client(): + filename = os.path.join(LLM_CONFIG_DIR, "claude-3-5-haiku.json") + config_data = json.load(open(filename, "r")) + llm_config = 
LLMConfig(**config_data) + client = create_client() + client.set_default_llm_config(llm_config) + client.set_default_embedding_config(EmbeddingConfig.default_config(provider="openai")) + + yield client + + +@pytest.fixture +def roll_dice_tool(client): + def roll_dice(): + """ + Rolls a 6 sided die. + + Returns: + str: The roll result. + """ + return "Rolled a 5!" + + # Set up tool details + source_code = parse_source_code(roll_dice) + source_type = "python" + description = "test_description" + tags = ["test"] + + tool = Tool(description=description, tags=tags, source_code=source_code, source_type=source_type) + derived_json_schema = derive_openai_json_schema(source_code=tool.source_code, name=tool.name) + + derived_name = derived_json_schema["name"] + tool.json_schema = derived_json_schema + tool.name = derived_name + + tool = client.server.tool_manager.create_or_update_tool(tool, actor=client.user) + + # Yield the created tool + yield tool + + +def test_multi_agent_large(client, roll_dice_tool): + manager_tags = ["manager"] + worker_tags = ["helpers"] + + # Clean up first from possibly failed tests + prev_worker_agents = client.server.agent_manager.list_agents(client.user, tags=worker_tags + manager_tags, match_all_tags=True) + for agent in prev_worker_agents: + client.delete_agent(agent.id) + + # Create "manager" agent + send_message_to_agents_matching_all_tags_tool_id = client.get_tool_id(name="send_message_to_agents_matching_all_tags") + manager_agent_state = client.create_agent( + name="manager", tool_ids=[send_message_to_agents_matching_all_tags_tool_id], tags=manager_tags + ) + manager_agent = client.server.load_agent(agent_id=manager_agent_state.id, actor=client.user) + + # Create 3 worker agents + worker_agents = [] + num_workers = 50 + for idx in tqdm(range(num_workers)): + worker_agent_state = client.create_agent( + name=f"worker-{idx}", include_multi_agent_tools=False, tags=worker_tags, tool_ids=[roll_dice_tool.id] + ) + worker_agent = 
client.server.load_agent(agent_id=worker_agent_state.id, actor=client.user) + worker_agents.append(worker_agent) + + # Encourage the manager to send a message to the other agent_obj with the secret string + broadcast_message = f"Send a message to all agents with tags {worker_tags} asking them to roll a dice for you!" + client.send_message( + agent_id=manager_agent.agent_state.id, + role="user", + message=broadcast_message, + ) + + # Please manually inspect the agent results diff --git a/tests/test_base_functions.py b/tests/test_base_functions.py index 30ba8ab6..b5ce5104 100644 --- a/tests/test_base_functions.py +++ b/tests/test_base_functions.py @@ -173,7 +173,7 @@ def test_send_message_to_agent(client, agent_obj, other_agent_obj): # Search the sender agent for the response from another agent in_context_messages = agent_obj.agent_manager.get_in_context_messages(agent_id=agent_obj.agent_state.id, actor=agent_obj.user) found = False - target_snippet = f"Agent {other_agent_obj.agent_state.id} said:" + target_snippet = f"{other_agent_obj.agent_state.id} said:" for m in in_context_messages: if target_snippet in m.text: diff --git a/tests/test_client.py b/tests/test_client.py index c9cfae4a..b727f77a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -458,16 +458,16 @@ def test_function_return_limit(client: Union[LocalClient, RESTClient]): def test_function_always_error(client: Union[LocalClient, RESTClient]): """Test to see if function that errors works correctly""" - def always_error(): + def testing_method(): """ Always throw an error. 
""" return 5 / 0 - tool = client.create_or_update_tool(func=always_error) + tool = client.create_or_update_tool(func=testing_method) agent = client.create_agent(tool_ids=[tool.id]) # get function response - response = client.send_message(agent_id=agent.id, message="call the always_error function", role="user") + response = client.send_message(agent_id=agent.id, message="call the testing_method function and tell me the result", role="user") print(response.messages) response_message = None @@ -480,14 +480,11 @@ def test_function_always_error(client: Union[LocalClient, RESTClient]): assert response_message.status == "error" if isinstance(client, RESTClient): - assert ( - response_message.tool_return.startswith("Error calling function always_error") - and "ZeroDivisionError" in response_message.tool_return - ) + assert response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero" else: response_json = json.loads(response_message.tool_return) assert response_json["status"] == "Failed" - assert "Error calling function always_error" in response_json["message"] and "ZeroDivisionError" in response_json["message"] + assert response_json["message"] == "Error executing function testing_method: ZeroDivisionError: division by zero" client.delete_agent(agent_id=agent.id) diff --git a/tests/test_managers.py b/tests/test_managers.py index a4d8adce..43ffbaa7 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -3128,6 +3128,7 @@ def test_job_usage_stats_add_and_get(server: SyncServer, default_job, default_us step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, job_id=default_job.id, usage=UsageStatistics( @@ -3169,6 +3170,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, 
job_id=default_job.id, usage=UsageStatistics( @@ -3183,6 +3185,7 @@ def test_job_usage_stats_add_multiple(server: SyncServer, default_job, default_u step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, job_id=default_job.id, usage=UsageStatistics( @@ -3219,6 +3222,7 @@ def test_job_usage_stats_add_nonexistent_job(server: SyncServer, default_user): step_manager.log_step( provider_name="openai", model="gpt-4", + model_endpoint="https://api.openai.com/v1", context_window_limit=8192, job_id="nonexistent_job", usage=UsageStatistics( diff --git a/tests/test_sdk_client.py b/tests/test_sdk_client.py index f01f431e..7c2325bd 100644 --- a/tests/test_sdk_client.py +++ b/tests/test_sdk_client.py @@ -410,13 +410,13 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState): """Test to see if function that errors works correctly""" - def always_error(): + def testing_method(): """ - Always throw an error. + A method that has test functionality. 
""" return 5 / 0 - tool = client.tools.upsert_from_function(func=always_error, return_char_limit=1000) + tool = client.tools.upsert_from_function(func=testing_method, return_char_limit=1000) client.agents.tools.attach(agent_id=agent.id, tool_id=tool.id) @@ -426,10 +426,9 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState): messages=[ MessageCreate( role="user", - content="call the always_error function", + content="call the testing_method function and tell me the result", ), ], - use_assistant_message=False, ) response_message = None @@ -441,10 +440,7 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState): assert response_message, "ToolReturnMessage message not found in response" assert response_message.status == "error" - # TODO try and get this format back, need to fix e2b return parsing - # assert response_message.tool_return == "Error executing function always_error: ZeroDivisionError: division by zero" - - assert response_message.tool_return.startswith("Error calling function always_error") + assert response_message.tool_return == "Error executing function testing_method: ZeroDivisionError: division by zero" assert "ZeroDivisionError" in response_message.tool_return