feat: Refactor letta tool execution to not require agent class (#1384)
This commit is contained in:
122
letta/agent.py
122
letta/agent.py
@@ -3,21 +3,28 @@ import time
|
||||
import traceback
|
||||
import warnings
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Dict, List, Optional, Tuple, Union
|
||||
from typing import Any, Dict, List, Optional, Tuple, Union
|
||||
|
||||
from openai.types.beta.function_tool import FunctionTool as OpenAITool
|
||||
|
||||
from letta.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
COMPOSIO_ENTITY_ENV_VAR_KEY,
|
||||
ERROR_MESSAGE_PREFIX,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
FUNC_FAILED_HEARTBEAT_MESSAGE,
|
||||
LETTA_CORE_TOOL_MODULE_NAME,
|
||||
LETTA_MULTI_AGENT_TOOL_MODULE_NAME,
|
||||
LLM_MAX_TOKENS,
|
||||
REQ_HEARTBEAT_MESSAGE,
|
||||
)
|
||||
from letta.errors import ContextWindowExceededError
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.functions import get_function_from_module
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
from letta.helpers.json_helpers import json_dumps, json_loads
|
||||
from letta.interface import AgentInterface
|
||||
@@ -28,6 +35,7 @@ from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_mes
|
||||
from letta.log import get_logger
|
||||
from letta.memory import summarize_messages
|
||||
from letta.orm import User
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse, UpdateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
@@ -38,6 +46,8 @@ from letta.schemas.message import Message, ToolReturn
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
from letta.schemas.openai.chat_completion_response import Message as ChatCompletionMessage
|
||||
from letta.schemas.openai.chat_completion_response import UsageStatistics
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.tool_rule import TerminalToolRule
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -48,7 +58,7 @@ from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
from letta.services.step_manager import StepManager
|
||||
from letta.services.tool_executor.tool_executor import ToolExecutionManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import summarizer_settings
|
||||
from letta.streaming_interface import StreamingRefreshCLIInterface
|
||||
@@ -502,12 +512,7 @@ class Agent(BaseAgent):
|
||||
},
|
||||
)
|
||||
|
||||
# TODO: Make this at the __init__ level
|
||||
# TODO: Add refresh agent_state logic to ToolExecutionManager, either by passing in or retrieving from db
|
||||
tool_execution_manager = ToolExecutionManager(self)
|
||||
function_response, sandbox_run_result = tool_execution_manager.execute_tool(
|
||||
function_name=function_name, function_args=function_args, tool=target_letta_tool
|
||||
)
|
||||
function_response, sandbox_run_result = self.execute_tool_and_persist_state(function_name, function_args, target_letta_tool)
|
||||
|
||||
log_event(
|
||||
"tool_call_ended",
|
||||
@@ -1199,6 +1204,107 @@ class Agent(BaseAgent):
|
||||
context_window_breakdown = self.get_context_window()
|
||||
return context_window_breakdown.context_window_size_current
|
||||
|
||||
# TODO: Refactor into separate class v.s. large if/elses here
|
||||
def execute_tool_and_persist_state(
    self, function_name: str, function_args: dict, target_letta_tool: Tool
) -> tuple[Any, Optional[SandboxRunResult]]:
    """
    Execute a tool call and persist the resulting agent state.

    Note: only some agent state modifications will be persisted, such as data in the AgentState ORM and block data

    The execution path is selected from ``target_letta_tool.tool_type``:

    * core / multi-agent tools are looked up by name and called with this agent injected as ``self``;
    * memory tools are called with a deep copy of ``agent_state`` whose memory is then handed to
      ``update_memory_if_changed``;
    * Composio and MCP tools are forwarded to their respective clients;
    * every other tool type is run through ``ToolExecutionSandbox``.

    Args:
        function_name: Name of the function/tool to run.
        function_args: Keyword arguments for the call. The caller's dict is not modified.
        target_letta_tool: Tool record supplying the tool type, name, tags and source code.

    Returns:
        ``(function_response, sandbox_run_result)``. The second element is ``None`` for core,
        multi-agent, memory and Composio tools, and a ``SandboxRunResult`` for MCP and sandbox tools.
        If anything raises, the first element is a friendly error message and the second is
        ``SandboxRunResult(status="error")``.
    """
    # TODO: add agent manager here
    orig_memory_str = self.agent_state.memory.compile()

    # TODO: need to have an AgentState object that actually has full access to the block data
    # this is because the sandbox tools need to be able to access block.value to edit this data
    try:
        tool_type = target_letta_tool.tool_type

        if tool_type == ToolType.LETTA_CORE:
            # base tools are allowed to access the `Agent` object and run on the database
            callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
            # `self` is dynamically linked; build a fresh kwargs dict so the caller's dict stays untouched
            function_response = callable_func(**{**function_args, "self": self})

        elif tool_type == ToolType.LETTA_MULTI_AGENT_CORE:
            callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
            # `self` is dynamically linked; build a fresh kwargs dict so the caller's dict stays untouched
            function_response = callable_func(**{**function_args, "self": self})

        elif tool_type == ToolType.LETTA_MEMORY_CORE:
            callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
            agent_state_copy = self.agent_state.__deepcopy__()
            # the tool edits the copy; changes are persisted afterwards via update_memory_if_changed
            function_response = callable_func(**{**function_args, "agent_state": agent_state_copy})
            self.update_memory_if_changed(agent_state_copy.memory)

        elif tool_type == ToolType.EXTERNAL_COMPOSIO:
            action_name = generate_composio_action_from_func_name(target_letta_tool.name)
            # Get entity ID from the agent_state (no break: the last matching variable wins)
            entity_id = None
            for env_var in self.agent_state.tool_exec_environment_variables:
                if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
                    entity_id = env_var.value
            # Get composio_api_key
            composio_api_key = get_composio_api_key(actor=self.user, logger=self.logger)
            function_response = execute_composio_action(
                action_name=action_name, args=function_args, api_key=composio_api_key, entity_id=entity_id
            )

        elif tool_type == ToolType.EXTERNAL_MCP:
            # Get the server name from the tool tag
            # TODO make a property instead?
            server_name = target_letta_tool.tags[0].split(":")[1]

            # Get the MCPClient from the server's handle
            # TODO these don't get raised properly
            if not self.mcp_clients:
                raise ValueError("No MCP client available to use")
            if server_name not in self.mcp_clients:
                raise ValueError(f"Unknown MCP server name: {server_name}")
            mcp_client = self.mcp_clients[server_name]
            if not isinstance(mcp_client, BaseMCPClient):
                raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")

            # Check that tool exists
            available_tool_names = [t.name for t in mcp_client.list_tools()]
            if function_name not in available_tool_names:
                raise ValueError(
                    f"{function_name} is not available in MCP server {server_name}. Please check your `~/.letta/mcp_config.json` file."
                )

            function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
            sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
            return function_response, sandbox_run_result

        else:
            sandbox_args = function_args
            try:
                # Parse the source code to extract function annotations
                annotations = get_function_annotations_from_source(target_letta_tool.source_code, function_name)
                # Coerce the function arguments to the correct types based on the annotations
                sandbox_args = coerce_dict_args_by_annotations(function_args, annotations)
            except ValueError as e:
                # best effort: fall back to the un-coerced arguments
                self.logger.debug(f"Error coercing function arguments: {e}")

            # execute tool in a sandbox
            # TODO: allow agent_state to specify which sandbox to execute tools in
            # TODO: This is only temporary, can remove after we publish a pip package with this object
            agent_state_copy = self.agent_state.__deepcopy__()
            agent_state_copy.tools = []
            agent_state_copy.tool_rules = []

            sandbox_run_result = ToolExecutionSandbox(function_name, sandbox_args, self.user, tool_object=target_letta_tool).run(
                agent_state=agent_state_copy
            )
            function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state

            # Explicit raise (rather than `assert`) so the check still runs under `python -O`
            if orig_memory_str != self.agent_state.memory.compile():
                raise AssertionError("Memory should not be modified in a sandbox tool")

            if updated_agent_state is not None:
                self.update_memory_if_changed(updated_agent_state.memory)
            return function_response, sandbox_run_result

    except Exception as e:
        # Need to catch error here, or else truncation won't happen
        # TODO: modify to function execution error
        function_response = get_friendly_error_msg(
            function_name=function_name, exception_name=type(e).__name__, exception_message=str(e)
        )
        return function_response, SandboxRunResult(status="error")

    return function_response, None
|
||||
|
||||
|
||||
def save_agent(agent: Agent):
|
||||
"""Save agent to metadata store"""
|
||||
|
||||
@@ -744,6 +744,8 @@ class AgentManager:
|
||||
Update internal memory object and system prompt if there have been modifications.
|
||||
|
||||
Args:
|
||||
actor:
|
||||
agent_id:
|
||||
new_memory (Memory): the new memory object to compare to the current memory object
|
||||
|
||||
Returns:
|
||||
|
||||
74
letta/services/tool_executor/tool_execution_manager.py
Normal file
74
letta/services/tool_executor/tool_execution_manager.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.services.tool_executor.tool_executor import (
|
||||
ExternalComposioToolExecutor,
|
||||
ExternalMCPToolExecutor,
|
||||
LettaCoreToolExecutor,
|
||||
LettaMemoryToolExecutor,
|
||||
LettaMultiAgentToolExecutor,
|
||||
SandboxToolExecutor,
|
||||
ToolExecutor,
|
||||
)
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
|
||||
class ToolExecutorFactory:
    """Selects the tool executor implementation that corresponds to a ``ToolType``."""

    # Tool types that have a dedicated executor class; any other type is handled by the sandbox executor.
    _executor_map: Dict[ToolType, Type[ToolExecutor]] = {
        ToolType.LETTA_CORE: LettaCoreToolExecutor,
        ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor,
        ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor,
        ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor,
        ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor,
    }

    @classmethod
    def get_executor(cls, tool_type: ToolType) -> ToolExecutor:
        """
        Return a new executor instance for ``tool_type``.

        Args:
            tool_type: The type of the tool about to be executed.

        Returns:
            An instance of the mapped executor class, or a ``SandboxToolExecutor``
            when ``tool_type`` has no entry in the map.
        """
        mapped_class = cls._executor_map.get(tool_type)

        # Unknown types default to sandboxed execution
        if not mapped_class:
            return SandboxToolExecutor()

        return mapped_class()
|
||||
|
||||
|
||||
class ToolExecutionManager:
    """Runs tools for one agent state on behalf of one actor, delegating to a type-specific executor."""

    def __init__(self, agent_state: AgentState, actor: User):
        """
        Args:
            agent_state: Agent state handed to every executor call.
            actor: User handed to every executor call.
        """
        self.logger = get_logger(__name__)
        self.agent_state = agent_state
        self.actor = actor

    def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
        """
        Execute a tool through the executor chosen for its tool type.

        Args:
            function_name: Name of the function to execute
            function_args: Arguments to pass to the function
            tool: Tool object containing metadata about the tool

        Returns:
            Whatever the executor returns: the function response and the sandbox run result (if applicable).
            If the executor lookup or the execution raises, the exception is logged and
            a friendly error message plus ``SandboxRunResult(status="error")`` is returned instead.
        """
        try:
            chosen_executor = ToolExecutorFactory.get_executor(tool.tool_type)
            outcome = chosen_executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
        except Exception as e:
            failure_text = str(e)
            self.logger.error(f"Error executing tool {function_name}: {failure_text}")
            error_message = get_friendly_error_msg(
                function_name=function_name,
                exception_name=type(e).__name__,
                exception_message=failure_text,
            )
            return error_message, SandboxRunResult(status="error")
        else:
            return outcome
|
||||
@@ -1,15 +1,19 @@
|
||||
import math
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import Any, Dict, Optional, Tuple, Type
|
||||
from typing import Any, Optional, Tuple
|
||||
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, LETTA_CORE_TOOL_MODULE_NAME, LETTA_MULTI_AGENT_TOOL_MODULE_NAME
|
||||
from letta.constants import COMPOSIO_ENTITY_ENV_VAR_KEY, RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.functions import get_function_from_module
|
||||
from letta.functions.helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.helpers.json_helpers import json_dumps
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.sandbox_config import SandboxRunResult
|
||||
from letta.schemas.tool import Tool
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.utils import get_friendly_error_msg
|
||||
|
||||
@@ -18,53 +22,234 @@ class ToolExecutor(ABC):
|
||||
"""Abstract base class for tool executors."""
|
||||
|
||||
@abstractmethod
|
||||
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
"""Execute the tool and return the result."""
|
||||
|
||||
|
||||
class LettaCoreToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA core tools."""
|
||||
"""Executor for LETTA core tools with direct implementation of functions."""
|
||||
|
||||
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
|
||||
function_args["self"] = agent # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
def execute(
    self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
) -> Tuple[Any, Optional[SandboxRunResult]]:
    """
    Run one of the core tools implemented directly on this executor.

    Args:
        function_name: Name of the core tool; must be a key of the dispatch table below.
        function_args: Keyword arguments for the tool. The dict is copied before use.
        agent_state: Agent state passed as the first argument to the tool method.
        tool: Tool metadata (not read by this executor).
        actor: User passed as the second argument to the tool method.

    Returns:
        ``(function_response, None)``; core tools never produce a sandbox run result.

    Raises:
        ValueError: If ``function_name`` is not a known core tool.
    """
    # Map function names to method calls
    function_map = {
        "send_message": self.send_message,
        "conversation_search": self.conversation_search,
        "archival_memory_search": self.archival_memory_search,
        # implemented on this class with the same (agent_state, actor, **kwargs) convention,
        # but previously missing from the table and therefore unreachable
        "archival_memory_insert": self.archival_memory_insert,
    }

    if function_name not in function_map:
        raise ValueError(f"Unknown function: {function_name}")

    # Execute the appropriate function
    function_args_copy = function_args.copy()  # Make a copy to avoid modifying the original
    function_response = function_map[function_name](agent_state, actor, **function_args_copy)
    return function_response, None
|
||||
|
||||
def send_message(self, agent_state: AgentState, actor: User, message: str) -> Optional[str]:
    """
    Sends a message to the human user.

    This executor method performs no work itself; it only returns ``None``.

    Args:
        agent_state: Agent state of the calling agent (unused here).
        actor: User on whose behalf the tool runs (unused here).
        message (str): Message contents. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: Always None, since this function does not produce a response.
    """
    return None
|
||||
|
||||
def conversation_search(self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0) -> Optional[str]:
    """
    Search prior conversation history using case-insensitive string matching.

    Args:
        agent_state: Agent whose messages are searched (its ``id`` is used).
        actor: User performing the search.
        query (str): String to search for.
        page (int): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).

    Returns:
        str: Query result string

    Raises:
        ValueError: If ``page`` cannot be converted to an integer.
    """
    # Treat a missing page or the literal string "none" as the first page
    if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
        page = 0
    try:
        page = int(page)
    except (TypeError, ValueError, OverflowError) as err:
        # Only conversion failures are translated; KeyboardInterrupt/SystemExit pass through
        raise ValueError("'page' argument must be an integer") from err

    count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE
    messages = MessageManager().list_user_messages_for_agent(
        agent_id=agent_state.id,
        actor=actor,
        query_text=query,
        limit=count,
    )

    # NOTE(review): `page` is only echoed in the header and is not passed to the query, and `total`
    # is the size of the (limit-capped) returned list - confirm whether real pagination is intended.
    total = len(messages)
    num_pages = math.ceil(total / count) - 1  # 0 index

    if len(messages) == 0:
        results_str = "No results found."
    else:
        results_pref = f"Showing {len(messages)} of {total} results (page {page}/{num_pages}):"
        results_formatted = [message.content[0].text for message in messages]
        results_str = f"{results_pref} {json_dumps(results_formatted)}"

    return results_str
|
||||
|
||||
def archival_memory_search(
    self, agent_state: AgentState, actor: User, query: str, page: Optional[int] = 0, start: Optional[int] = 0
) -> Tuple[list, int]:
    """
    Search archival memory using semantic (embedding-based) search.

    Args:
        agent_state: Agent whose passages are searched (its ``id`` and ``embedding_config`` are used).
        actor: User performing the search.
        query (str): String to search for.
        page (Optional[int]): Allows you to page through results. Only use on a follow-up query. Defaults to 0 (first page).
        start (Optional[int]): Starting index for the search results. Defaults to 0; ``None`` is treated as 0.

    Returns:
        A tuple ``(results, n)`` where ``results`` is a list of ``{"timestamp", "content"}`` dicts
        and ``n`` is the length of that list.

    Raises:
        ValueError: If ``page`` cannot be converted to an integer.
    """
    # Treat a missing page or the literal string "none" as the first page
    if page is None or (isinstance(page, str) and page.lower().strip() == "none"):
        page = 0
    try:
        # NOTE(review): `page` is validated but not used below; only `start` affects the window.
        page = int(page)
    except (TypeError, ValueError, OverflowError) as err:
        # Only conversion failures are translated; KeyboardInterrupt/SystemExit pass through
        raise ValueError("'page' argument must be an integer") from err

    # `start` is Optional, so guard against None before doing arithmetic with it
    if start is None:
        start = 0

    count = RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE

    # Get results using passage manager
    all_results = AgentManager().list_passages(
        actor=actor,
        agent_id=agent_state.id,
        query_text=query,
        limit=count + start,  # Request enough results to handle offset
        embedding_config=agent_state.embedding_config,
        embed_query=True,
    )

    # Apply pagination
    end = min(count + start, len(all_results))
    paged_results = all_results[start:end]

    # Format results to match previous implementation
    formatted_results = [{"timestamp": str(result.created_at), "content": result.text} for result in paged_results]

    return formatted_results, len(formatted_results)
|
||||
|
||||
def archival_memory_insert(self, agent_state: AgentState, actor: User, content: str) -> Optional[str]:
    """
    Add to archival memory. Make sure to phrase the memory contents such that it can be easily queried later.

    Inserts ``content`` as a passage for the agent and then forces a rebuild of the agent's system prompt.

    Args:
        agent_state: Agent that receives the new passage.
        actor: User performing the insert.
        content (str): Content to write to the memory. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: Always None, since this function does not produce a response.
    """
    target_agent_id = agent_state.id

    passage_manager = PassageManager()
    passage_manager.insert_passage(
        agent_state=agent_state,
        agent_id=target_agent_id,
        text=content,
        actor=actor,
    )

    # Rebuild is forced after every insert
    agent_manager = AgentManager()
    agent_manager.rebuild_system_prompt(agent_id=target_agent_id, actor=actor, force=True)
    return None
|
||||
|
||||
|
||||
class LettaMultiAgentToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA multi-agent core tools."""
|
||||
|
||||
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
|
||||
function_args["self"] = agent # need to attach self to arg since it's dynamically linked
|
||||
function_response = callable_func(**function_args)
|
||||
return function_response, None
|
||||
# TODO: Implement
|
||||
# def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[
|
||||
# Any, Optional[SandboxRunResult]]:
|
||||
# callable_func = get_function_from_module(LETTA_MULTI_AGENT_TOOL_MODULE_NAME, function_name)
|
||||
# function_args["self"] = agent # need to attach self to arg since it's dynamically linked
|
||||
# function_response = callable_func(**function_args)
|
||||
# return function_response, None
|
||||
|
||||
|
||||
class LettaMemoryToolExecutor(ToolExecutor):
|
||||
"""Executor for LETTA memory core tools."""
|
||||
"""Executor for LETTA memory core tools with direct implementation."""
|
||||
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
# Map function names to method calls
|
||||
function_map = {
|
||||
"core_memory_append": self.core_memory_append,
|
||||
"core_memory_replace": self.core_memory_replace,
|
||||
}
|
||||
|
||||
if function_name not in function_map:
|
||||
raise ValueError(f"Unknown function: {function_name}")
|
||||
|
||||
# Execute the appropriate function with the copied state
|
||||
function_args_copy = function_args.copy() # Make a copy to avoid modifying the original
|
||||
function_response = function_map[function_name](agent_state, **function_args_copy)
|
||||
|
||||
# Update memory if changed
|
||||
AgentManager().update_memory_if_changed(agent_id=agent_state.id, new_memory=agent_state.memory, actor=actor)
|
||||
|
||||
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
callable_func = get_function_from_module(LETTA_CORE_TOOL_MODULE_NAME, function_name)
|
||||
agent_state_copy = agent.agent_state.__deepcopy__()
|
||||
function_args["agent_state"] = agent_state_copy
|
||||
function_response = callable_func(**function_args)
|
||||
agent.update_memory_if_changed(agent_state_copy.memory)
|
||||
return function_response, None
|
||||
|
||||
def core_memory_append(self, agent_state: "AgentState", label: str, content: str) -> Optional[str]:
    """
    Append to the contents of core memory.

    The new text is added after the block's current value, separated by a newline.

    Args:
        agent_state: Agent state whose memory block is edited in place.
        label (str): Section of the memory to be edited (persona or human).
        content (str): Content to write to the memory. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: Always None, since this function does not produce a response.
    """
    memory = agent_state.memory
    existing_text = str(memory.get_block(label).value)
    combined_text = "\n".join((existing_text, str(content)))
    memory.update_block_value(label=label, value=combined_text)
    return None
|
||||
|
||||
def core_memory_replace(self, agent_state: "AgentState", label: str, old_content: str, new_content: str) -> Optional[str]:
    """
    Replace the contents of core memory. To delete memories, use an empty string for new_content.

    Args:
        agent_state: Agent state whose memory block is edited in place.
        label (str): Section of the memory to be edited (persona or human).
        old_content (str): String to replace. Must be an exact match.
        new_content (str): Content to write to the memory. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: None is always returned as this function does not produce a response.

    Raises:
        ValueError: If ``old_content`` does not occur in the memory block.
    """
    current_value = str(agent_state.memory.get_block(label).value)

    # Coerce once so the membership test and the replacement operate on the same text;
    # previously a non-str `old_content` raised TypeError in the `in` check even though
    # the replacement itself stringified it.
    old_text = str(old_content)
    if old_text not in current_value:
        raise ValueError(f"Old content '{old_content}' not found in memory block '{label}'")

    new_value = current_value.replace(old_text, str(new_content))
    agent_state.memory.update_block_value(label=label, value=new_value)
    return None
|
||||
|
||||
|
||||
class ExternalComposioToolExecutor(ToolExecutor):
|
||||
"""Executor for external Composio tools."""
|
||||
|
||||
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
action_name = generate_composio_action_from_func_name(tool.name)
|
||||
|
||||
# Get entity ID from the agent_state
|
||||
entity_id = self._get_entity_id(agent)
|
||||
entity_id = self._get_entity_id(agent_state)
|
||||
|
||||
# Get composio_api_key
|
||||
composio_api_key = get_composio_api_key(actor=agent.user, logger=agent.logger)
|
||||
composio_api_key = get_composio_api_key(actor=actor)
|
||||
|
||||
# TODO (matt): Roll in execute_composio_action into this class
|
||||
function_response = execute_composio_action(
|
||||
@@ -73,9 +258,9 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
|
||||
return function_response, None
|
||||
|
||||
def _get_entity_id(self, agent: "Agent") -> Optional[str]:
|
||||
def _get_entity_id(self, agent_state: AgentState) -> Optional[str]:
|
||||
"""Extract the entity ID from environment variables."""
|
||||
for env_var in agent.agent_state.tool_exec_environment_variables:
|
||||
for env_var in agent_state.tool_exec_environment_variables:
|
||||
if env_var.key == COMPOSIO_ENTITY_ENV_VAR_KEY:
|
||||
return env_var.value
|
||||
return None
|
||||
@@ -84,78 +269,83 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
class ExternalMCPToolExecutor(ToolExecutor):
|
||||
"""Executor for external MCP tools."""
|
||||
|
||||
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
# Get the server name from the tool tag
|
||||
server_name = self._extract_server_name(tool)
|
||||
|
||||
# Get the MCPClient
|
||||
mcp_client = self._get_mcp_client(agent, server_name)
|
||||
|
||||
# Validate tool exists
|
||||
self._validate_tool_exists(mcp_client, function_name, server_name)
|
||||
|
||||
# Execute the tool
|
||||
function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
|
||||
|
||||
sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
|
||||
return function_response, sandbox_run_result
|
||||
|
||||
def _extract_server_name(self, tool: Tool) -> str:
|
||||
"""Extract server name from tool tags."""
|
||||
return tool.tags[0].split(":")[1]
|
||||
|
||||
def _get_mcp_client(self, agent: "Agent", server_name: str):
|
||||
"""Get the MCP client for the given server name."""
|
||||
if not agent.mcp_clients:
|
||||
raise ValueError("No MCP client available to use")
|
||||
|
||||
if server_name not in agent.mcp_clients:
|
||||
raise ValueError(f"Unknown MCP server name: {server_name}")
|
||||
|
||||
mcp_client = agent.mcp_clients[server_name]
|
||||
if not isinstance(mcp_client, BaseMCPClient):
|
||||
raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")
|
||||
|
||||
return mcp_client
|
||||
|
||||
def _validate_tool_exists(self, mcp_client, function_name: str, server_name: str):
|
||||
"""Validate that the tool exists in the MCP server."""
|
||||
available_tools = mcp_client.list_tools()
|
||||
available_tool_names = [t.name for t in available_tools]
|
||||
|
||||
if function_name not in available_tool_names:
|
||||
raise ValueError(
|
||||
f"{function_name} is not available in MCP server {server_name}. " f"Please check your `~/.letta/mcp_config.json` file."
|
||||
)
|
||||
# TODO: Implement
|
||||
#
|
||||
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> Tuple[
|
||||
# Any, Optional[SandboxRunResult]]:
|
||||
# # Get the server name from the tool tag
|
||||
# server_name = self._extract_server_name(tool)
|
||||
#
|
||||
# # Get the MCPClient
|
||||
# mcp_client = self._get_mcp_client(agent, server_name)
|
||||
#
|
||||
# # Validate tool exists
|
||||
# self._validate_tool_exists(mcp_client, function_name, server_name)
|
||||
#
|
||||
# # Execute the tool
|
||||
# function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
|
||||
#
|
||||
# sandbox_run_result = SandboxRunResult(status="error" if is_error else "success")
|
||||
# return function_response, sandbox_run_result
|
||||
#
|
||||
# def _extract_server_name(self, tool: Tool) -> str:
|
||||
# """Extract server name from tool tags."""
|
||||
# return tool.tags[0].split(":")[1]
|
||||
#
|
||||
# def _get_mcp_client(self, agent: "Agent", server_name: str):
|
||||
# """Get the MCP client for the given server name."""
|
||||
# if not agent.mcp_clients:
|
||||
# raise ValueError("No MCP client available to use")
|
||||
#
|
||||
# if server_name not in agent.mcp_clients:
|
||||
# raise ValueError(f"Unknown MCP server name: {server_name}")
|
||||
#
|
||||
# mcp_client = agent.mcp_clients[server_name]
|
||||
# if not isinstance(mcp_client, BaseMCPClient):
|
||||
# raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")
|
||||
#
|
||||
# return mcp_client
|
||||
#
|
||||
# def _validate_tool_exists(self, mcp_client, function_name: str, server_name: str):
|
||||
# """Validate that the tool exists in the MCP server."""
|
||||
# available_tools = mcp_client.list_tools()
|
||||
# available_tool_names = [t.name for t in available_tools]
|
||||
#
|
||||
# if function_name not in available_tool_names:
|
||||
# raise ValueError(
|
||||
# f"{function_name} is not available in MCP server {server_name}. " f"Please check your `~/.letta/mcp_config.json` file."
|
||||
# )
|
||||
|
||||
|
||||
class SandboxToolExecutor(ToolExecutor):
|
||||
"""Executor for sandboxed tools."""
|
||||
|
||||
def execute(self, function_name: str, function_args: dict, agent: "Agent", tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
def execute(
|
||||
self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User
|
||||
) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
# Store original memory state
|
||||
orig_memory_str = agent.agent_state.memory.compile()
|
||||
orig_memory_str = agent_state.memory.compile()
|
||||
|
||||
try:
|
||||
# Prepare function arguments
|
||||
function_args = self._prepare_function_args(function_args, tool, function_name)
|
||||
|
||||
# Create agent state copy for sandbox
|
||||
agent_state_copy = self._create_agent_state_copy(agent)
|
||||
agent_state_copy = self._create_agent_state_copy(agent_state)
|
||||
|
||||
# Execute in sandbox
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, agent.user, tool_object=tool).run(
|
||||
sandbox_run_result = ToolExecutionSandbox(function_name, function_args, actor, tool_object=tool).run(
|
||||
agent_state=agent_state_copy
|
||||
)
|
||||
|
||||
function_response, updated_agent_state = sandbox_run_result.func_return, sandbox_run_result.agent_state
|
||||
|
||||
# Verify memory integrity
|
||||
assert orig_memory_str == agent.agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
|
||||
# Update agent memory if needed
|
||||
if updated_agent_state is not None:
|
||||
agent.update_memory_if_changed(updated_agent_state.memory)
|
||||
AgentManager().update_memory_if_changed(agent_state.id, updated_agent_state.memory, actor)
|
||||
|
||||
return function_response, sandbox_run_result
|
||||
|
||||
@@ -174,9 +364,9 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
# This is defensive programming - we try to coerce but fall back if it fails
|
||||
return function_args
|
||||
|
||||
def _create_agent_state_copy(self, agent: "Agent"):
|
||||
def _create_agent_state_copy(self, agent_state: AgentState):
|
||||
"""Create a copy of agent state for sandbox execution."""
|
||||
agent_state_copy = agent.agent_state.__deepcopy__()
|
||||
agent_state_copy = agent_state.__deepcopy__()
|
||||
# Remove tools from copy to prevent nested tool execution
|
||||
agent_state_copy.tools = []
|
||||
agent_state_copy.tool_rules = []
|
||||
@@ -188,58 +378,3 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
function_name=function_name, exception_name=type(exception).__name__, exception_message=str(exception)
|
||||
)
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
|
||||
|
||||
class ToolExecutorFactory:
|
||||
"""Factory for creating appropriate tool executors based on tool type."""
|
||||
|
||||
_executor_map: Dict[ToolType, Type[ToolExecutor]] = {
|
||||
ToolType.LETTA_CORE: LettaCoreToolExecutor,
|
||||
ToolType.LETTA_MULTI_AGENT_CORE: LettaMultiAgentToolExecutor,
|
||||
ToolType.LETTA_MEMORY_CORE: LettaMemoryToolExecutor,
|
||||
ToolType.EXTERNAL_COMPOSIO: ExternalComposioToolExecutor,
|
||||
ToolType.EXTERNAL_MCP: ExternalMCPToolExecutor,
|
||||
}
|
||||
|
||||
@classmethod
|
||||
def get_executor(cls, tool_type: ToolType) -> ToolExecutor:
|
||||
"""Get the appropriate executor for the given tool type."""
|
||||
executor_class = cls._executor_map.get(tool_type)
|
||||
|
||||
if executor_class:
|
||||
return executor_class()
|
||||
|
||||
# Default to sandbox executor for unknown types
|
||||
return SandboxToolExecutor()
|
||||
|
||||
|
||||
class ToolExecutionManager:
|
||||
"""Manager class for tool execution operations."""
|
||||
|
||||
def __init__(self, agent: "Agent"):
|
||||
self.agent = agent
|
||||
self.logger = agent.logger
|
||||
|
||||
def execute_tool(self, function_name: str, function_args: dict, tool: Tool) -> Tuple[Any, Optional[SandboxRunResult]]:
|
||||
"""
|
||||
Execute a tool and persist any state changes.
|
||||
|
||||
Args:
|
||||
function_name: Name of the function to execute
|
||||
function_args: Arguments to pass to the function
|
||||
tool: Tool object containing metadata about the tool
|
||||
|
||||
Returns:
|
||||
Tuple containing the function response and sandbox run result (if applicable)
|
||||
"""
|
||||
try:
|
||||
# Get the appropriate executor for this tool type
|
||||
executor = ToolExecutorFactory.get_executor(tool.tool_type)
|
||||
|
||||
# Execute the tool
|
||||
return executor.execute(function_name, function_args, self.agent, tool)
|
||||
|
||||
except Exception as e:
|
||||
self.logger.error(f"Error executing tool {function_name}: {str(e)}")
|
||||
error_message = get_friendly_error_msg(function_name=function_name, exception_name=type(e).__name__, exception_message=str(e))
|
||||
return error_message, SandboxRunResult(status="error")
|
||||
|
||||
@@ -10,7 +10,7 @@ from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.server.rest_api.app import app
|
||||
from letta.server.server import SyncServer
|
||||
from letta.services.tool_executor.tool_executor import ToolExecutionManager
|
||||
from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -67,9 +67,8 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
agent = server.load_agent(agent_state.id, actor=default_user)
|
||||
|
||||
function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool(
|
||||
function_response, sandbox_run_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
|
||||
function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool
|
||||
)
|
||||
|
||||
@@ -81,8 +80,7 @@ def test_composio_tool_execution_e2e(check_composio_key_set, composio_gmail_get_
|
||||
agent_update=UpdateAgent(tool_exec_environment_variables={COMPOSIO_ENTITY_ENV_VAR_KEY: "matt"}),
|
||||
actor=default_user,
|
||||
)
|
||||
agent = server.load_agent(agent_state.id, actor=default_user)
|
||||
function_response, sandbox_run_result = ToolExecutionManager(agent).execute_tool(
|
||||
function_response, sandbox_run_result = ToolExecutionManager(agent_state, actor=default_user).execute_tool(
|
||||
function_name=composio_gmail_get_profile_tool.name, function_args={}, tool=composio_gmail_get_profile_tool
|
||||
)
|
||||
assert function_response["response_data"]["emailAddress"] == "matt@letta.com"
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -577,36 +576,37 @@ def test_send_system_message(client: Letta, agent: AgentState):
|
||||
assert send_system_message_response, "Sending message failed"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_parallel(client: Letta, agent: AgentState, request):
|
||||
"""
|
||||
Test that sending two messages in parallel does not error.
|
||||
"""
|
||||
|
||||
# Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
|
||||
async def send_message_task(message: str):
|
||||
response = await asyncio.to_thread(
|
||||
client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)]
|
||||
)
|
||||
assert response, f"Sending message '{message}' failed"
|
||||
return response
|
||||
|
||||
# Prepare two tasks with different messages
|
||||
messages = ["Test message 1", "Test message 2"]
|
||||
tasks = [send_message_task(message) for message in messages]
|
||||
|
||||
# Run the tasks concurrently
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check for exceptions and validate responses
|
||||
for i, response in enumerate(responses):
|
||||
if isinstance(response, Exception):
|
||||
pytest.fail(f"Task {i} failed with exception: {response}")
|
||||
else:
|
||||
assert response, f"Task {i} returned an invalid response: {response}"
|
||||
|
||||
# Ensure both tasks completed
|
||||
assert len(responses) == len(messages), "Not all messages were processed"
|
||||
# TODO: Add back when new agent loop hits
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_send_message_parallel(client: Letta, agent: AgentState, request):
|
||||
# """
|
||||
# Test that sending two messages in parallel does not error.
|
||||
# """
|
||||
#
|
||||
# # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
|
||||
# async def send_message_task(message: str):
|
||||
# response = await asyncio.to_thread(
|
||||
# client.agents.messages.create, agent_id=agent.id, messages=[MessageCreate(role="user", content=message)]
|
||||
# )
|
||||
# assert response, f"Sending message '{message}' failed"
|
||||
# return response
|
||||
#
|
||||
# # Prepare two tasks with different messages
|
||||
# messages = ["Test message 1", "Test message 2"]
|
||||
# tasks = [send_message_task(message) for message in messages]
|
||||
#
|
||||
# # Run the tasks concurrently
|
||||
# responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
#
|
||||
# # Check for exceptions and validate responses
|
||||
# for i, response in enumerate(responses):
|
||||
# if isinstance(response, Exception):
|
||||
# pytest.fail(f"Task {i} failed with exception: {response}")
|
||||
# else:
|
||||
# assert response, f"Task {i} returned an invalid response: {response}"
|
||||
#
|
||||
# # Ensure both tasks completed
|
||||
# assert len(responses) == len(messages), "Not all messages were processed"
|
||||
|
||||
|
||||
# ----------------------------------------------------------------------------------------------------
|
||||
|
||||
@@ -1,4 +1,3 @@
|
||||
import asyncio
|
||||
import os
|
||||
import threading
|
||||
import time
|
||||
@@ -444,43 +443,44 @@ def test_function_always_error(client: LettaSDKClient, agent: AgentState):
|
||||
assert "ZeroDivisionError" in response_message.tool_return
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_send_message_parallel(client: LettaSDKClient, agent: AgentState):
|
||||
"""
|
||||
Test that sending two messages in parallel does not error.
|
||||
"""
|
||||
|
||||
# Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
|
||||
async def send_message_task(message: str):
|
||||
response = await asyncio.to_thread(
|
||||
client.agents.messages.create,
|
||||
agent_id=agent.id,
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=message,
|
||||
),
|
||||
],
|
||||
)
|
||||
assert response, f"Sending message '{message}' failed"
|
||||
return response
|
||||
|
||||
# Prepare two tasks with different messages
|
||||
messages = ["Test message 1", "Test message 2"]
|
||||
tasks = [send_message_task(message) for message in messages]
|
||||
|
||||
# Run the tasks concurrently
|
||||
responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
|
||||
# Check for exceptions and validate responses
|
||||
for i, response in enumerate(responses):
|
||||
if isinstance(response, Exception):
|
||||
pytest.fail(f"Task {i} failed with exception: {response}")
|
||||
else:
|
||||
assert response, f"Task {i} returned an invalid response: {response}"
|
||||
|
||||
# Ensure both tasks completed
|
||||
assert len(responses) == len(messages), "Not all messages were processed"
|
||||
# TODO: Add back when the new agent loop hits
|
||||
# @pytest.mark.asyncio
|
||||
# async def test_send_message_parallel(client: LettaSDKClient, agent: AgentState):
|
||||
# """
|
||||
# Test that sending two messages in parallel does not error.
|
||||
# """
|
||||
#
|
||||
# # Define a coroutine for sending a message using asyncio.to_thread for synchronous calls
|
||||
# async def send_message_task(message: str):
|
||||
# response = await asyncio.to_thread(
|
||||
# client.agents.messages.create,
|
||||
# agent_id=agent.id,
|
||||
# messages=[
|
||||
# MessageCreate(
|
||||
# role="user",
|
||||
# content=message,
|
||||
# ),
|
||||
# ],
|
||||
# )
|
||||
# assert response, f"Sending message '{message}' failed"
|
||||
# return response
|
||||
#
|
||||
# # Prepare two tasks with different messages
|
||||
# messages = ["Test message 1", "Test message 2"]
|
||||
# tasks = [send_message_task(message) for message in messages]
|
||||
#
|
||||
# # Run the tasks concurrently
|
||||
# responses = await asyncio.gather(*tasks, return_exceptions=True)
|
||||
#
|
||||
# # Check for exceptions and validate responses
|
||||
# for i, response in enumerate(responses):
|
||||
# if isinstance(response, Exception):
|
||||
# pytest.fail(f"Task {i} failed with exception: {response}")
|
||||
# else:
|
||||
# assert response, f"Task {i} returned an invalid response: {response}"
|
||||
#
|
||||
# # Ensure both tasks completed
|
||||
# assert len(responses) == len(messages), "Not all messages were processed"
|
||||
|
||||
|
||||
def test_send_message_async(client: LettaSDKClient, agent: AgentState):
|
||||
|
||||
Reference in New Issue
Block a user