feat: add MCP servers into a table and MCP tool execution to new agent loop (#2323)
Co-authored-by: Matt Zhou <mattzh1314@gmail.com> Co-authored-by: Kian Jones <11655409+kianjones9@users.noreply.github.com>
This commit is contained in:
@@ -0,0 +1,51 @@
|
||||
"""add table to store mcp servers
|
||||
|
||||
Revision ID: 9ecbdbaa409f
|
||||
Revises: 6c53224a7a58
|
||||
Create Date: 2025-05-21 15:25:12.483026
|
||||
|
||||
"""
|
||||
|
||||
from typing import Sequence, Union
|
||||
|
||||
import sqlalchemy as sa
|
||||
|
||||
import letta
|
||||
from alembic import op
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision: str = "9ecbdbaa409f"
|
||||
down_revision: Union[str, None] = "6c53224a7a58"
|
||||
branch_labels: Union[str, Sequence[str], None] = None
|
||||
depends_on: Union[str, Sequence[str], None] = None
|
||||
|
||||
|
||||
def upgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table(
|
||||
"mcp_server",
|
||||
sa.Column("id", sa.String(), nullable=False),
|
||||
sa.Column("server_name", sa.String(), nullable=False),
|
||||
sa.Column("server_type", sa.String(), nullable=False),
|
||||
sa.Column("server_url", sa.String(), nullable=True),
|
||||
sa.Column("stdio_config", letta.orm.custom_columns.MCPStdioServerConfigColumn(), nullable=True),
|
||||
sa.Column("organization_id", sa.String(), nullable=False),
|
||||
sa.Column("is_deleted", sa.Boolean(), server_default=sa.text("FALSE"), nullable=False),
|
||||
sa.Column("created_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("updated_at", sa.DateTime(timezone=True), server_default=sa.text("now()"), nullable=True),
|
||||
sa.Column("_created_by_id", sa.String(), nullable=True),
|
||||
sa.Column("_last_updated_by_id", sa.String(), nullable=True),
|
||||
sa.Column("metadata_", sa.JSON(), nullable=True),
|
||||
sa.PrimaryKeyConstraint("id"),
|
||||
sa.ForeignKeyConstraint(
|
||||
["organization_id"],
|
||||
["organizations.id"],
|
||||
),
|
||||
sa.UniqueConstraint("server_name", "organization_id", name="uix_name_organization_mcp_server"),
|
||||
)
|
||||
|
||||
|
||||
def downgrade() -> None:
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_table("mcp_server")
|
||||
# ### end Alembic commands ###
|
||||
@@ -27,7 +27,6 @@ from letta.errors import ContextWindowExceededError
|
||||
from letta.functions.ast_parsers import coerce_dict_args_by_annotations, get_function_annotations_from_source
|
||||
from letta.functions.composio_helpers import execute_composio_action, generate_composio_action_from_func_name
|
||||
from letta.functions.functions import get_function_from_module
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.helpers import ToolRulesSolver
|
||||
from letta.helpers.composio_helpers import get_composio_api_key
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
@@ -61,6 +60,7 @@ from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import check_supports_structured_output, compile_memory_metadata_block
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
@@ -103,7 +103,7 @@ class Agent(BaseAgent):
|
||||
# extras
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
# MCP sessions, state held in-memory in the server
|
||||
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
||||
mcp_clients: Optional[Dict[str, AsyncBaseMCPClient]] = None,
|
||||
save_last_response: bool = False,
|
||||
):
|
||||
assert isinstance(agent_state.memory, Memory), f"Memory object is not of type Memory: {type(agent_state.memory)}"
|
||||
@@ -168,7 +168,11 @@ class Agent(BaseAgent):
|
||||
self.logger = get_logger(agent_state.id)
|
||||
|
||||
# MCPClient, state/sessions managed by the server
|
||||
self.mcp_clients = mcp_clients
|
||||
# TODO: This is temporary, as a bridge
|
||||
self.mcp_clients = None
|
||||
# TODO: no longer supported
|
||||
# if mcp_clients:
|
||||
# self.mcp_clients = {client_id: client.to_sync_client() for client_id, client in mcp_clients.items()}
|
||||
|
||||
def load_last_function_response(self):
|
||||
"""Load the last function response from message history"""
|
||||
@@ -1601,8 +1605,6 @@ class Agent(BaseAgent):
|
||||
if server_name not in self.mcp_clients:
|
||||
raise ValueError(f"Unknown MCP server name: {server_name}")
|
||||
mcp_client = self.mcp_clients[server_name]
|
||||
if not isinstance(mcp_client, BaseMCPClient):
|
||||
raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")
|
||||
|
||||
# Check that tool exists
|
||||
available_tools = mcp_client.list_tools()
|
||||
|
||||
@@ -304,6 +304,7 @@ class LettaAgent(BaseAgent):
|
||||
current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async(
|
||||
input_messages, agent_state, self.message_manager, self.actor
|
||||
)
|
||||
|
||||
tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
llm_client = LLMClient.create(
|
||||
provider_type=agent_state.llm_config.model_endpoint_type,
|
||||
@@ -469,6 +470,8 @@ class LettaAgent(BaseAgent):
|
||||
ToolType.LETTA_SLEEPTIME_CORE,
|
||||
ToolType.LETTA_VOICE_SLEEPTIME_CORE,
|
||||
ToolType.LETTA_BUILTIN,
|
||||
ToolType.EXTERNAL_COMPOSIO,
|
||||
ToolType.EXTERNAL_MCP,
|
||||
}
|
||||
or (t.tool_type == ToolType.EXTERNAL_COMPOSIO)
|
||||
]
|
||||
@@ -533,12 +536,12 @@ class LettaAgent(BaseAgent):
|
||||
|
||||
tool_call_id = tool_call.id or f"call_{uuid.uuid4().hex[:8]}"
|
||||
|
||||
tool_result, success_flag = await self._execute_tool(
|
||||
tool_execution_result = await self._execute_tool(
|
||||
tool_name=tool_call_name,
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
)
|
||||
function_response = package_function_response(tool_result, success_flag)
|
||||
function_response = package_function_response(tool_execution_result.func_return, tool_execution_result.success_flag)
|
||||
|
||||
# 4. Register tool call with tool rule solver
|
||||
# Resolve whether or not to continue stepping
|
||||
@@ -575,9 +578,10 @@ class LettaAgent(BaseAgent):
|
||||
model=agent_state.llm_config.model,
|
||||
function_name=tool_call_name,
|
||||
function_arguments=tool_args,
|
||||
tool_execution_result=tool_execution_result,
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=success_flag,
|
||||
function_response=tool_result,
|
||||
function_call_success=tool_execution_result.success_flag,
|
||||
function_response=tool_execution_result.func_return,
|
||||
actor=self.actor,
|
||||
add_heartbeat_request_system_message=continue_stepping,
|
||||
reasoning_content=reasoning_content,
|
||||
@@ -592,34 +596,37 @@ class LettaAgent(BaseAgent):
|
||||
return persisted_messages, continue_stepping
|
||||
|
||||
@trace_method
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult":
|
||||
"""
|
||||
Executes a tool and returns (result, success_flag).
|
||||
"""
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
# Special memory case
|
||||
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
|
||||
if not target_tool:
|
||||
return f"Tool not found: {tool_name}", False
|
||||
# TODO: fix this error message
|
||||
return ToolExecutionResult(
|
||||
func_return=f"Tool {tool_name} not found",
|
||||
status="error",
|
||||
)
|
||||
|
||||
# TODO: This temp. Move this logic and code to executors
|
||||
try:
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
# TODO: Integrate sandbox result
|
||||
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
|
||||
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
||||
function_name=tool_name, function_args=tool_args, tool=target_tool
|
||||
)
|
||||
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
|
||||
return tool_execution_result.func_return, True
|
||||
except Exception as e:
|
||||
return f"Failed to call tool. Error: {e}", False
|
||||
tool_execution_manager = ToolExecutionManager(
|
||||
agent_state=agent_state,
|
||||
message_manager=self.message_manager,
|
||||
agent_manager=self.agent_manager,
|
||||
block_manager=self.block_manager,
|
||||
passage_manager=self.passage_manager,
|
||||
actor=self.actor,
|
||||
)
|
||||
# TODO: Integrate sandbox result
|
||||
log_event(name=f"start_{tool_name}_execution", attributes=tool_args)
|
||||
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
||||
function_name=tool_name, function_args=tool_args, tool=target_tool
|
||||
)
|
||||
log_event(name=f"finish_{tool_name}_execution", attributes=tool_args)
|
||||
return tool_execution_result
|
||||
|
||||
@trace_method
|
||||
async def _send_message_to_agents_matching_tags(
|
||||
|
||||
@@ -27,6 +27,7 @@ from letta.schemas.llm_batch_job import LLMBatchItem
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.openai.chat_completion_response import ToolCall as OpenAIToolCall
|
||||
from letta.schemas.sandbox_config import SandboxConfig, SandboxType
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.utils import create_heartbeat_system_message, create_letta_messages_from_llm_response
|
||||
from letta.services.agent_manager import AgentManager
|
||||
@@ -66,15 +67,17 @@ class _ResumeContext:
|
||||
request_status_updates: List[RequestStatusUpdateInfo]
|
||||
|
||||
|
||||
async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[str, bool]]:
|
||||
async def execute_tool_wrapper(params: ToolExecutionParams) -> tuple[str, ToolExecutionResult]:
|
||||
"""
|
||||
Executes the tool in an out‑of‑process worker and returns:
|
||||
(agent_id, (tool_result:str, success_flag:bool))
|
||||
"""
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
# locate the tool on the agent
|
||||
target_tool = next((t for t in params.agent_state.tools if t.name == params.tool_call_name), None)
|
||||
if not target_tool:
|
||||
return params.agent_id, (f"Tool not found: {params.tool_call_name}", False)
|
||||
return params.agent_id, ToolExecutionResult(func_return=f"Tool not found: {params.tool_call_name}", status="error")
|
||||
|
||||
try:
|
||||
mgr = ToolExecutionManager(
|
||||
@@ -88,9 +91,9 @@ async def execute_tool_wrapper(params: ToolExecutionParams) -> Tuple[str, Tuple[
|
||||
function_args=params.tool_args,
|
||||
tool=target_tool,
|
||||
)
|
||||
return params.agent_id, (tool_execution_result.func_return, True)
|
||||
return params.agent_id, tool_execution_result
|
||||
except Exception as e:
|
||||
return params.agent_id, (f"Failed to call tool. Error: {e}", False)
|
||||
return params.agent_id, ToolExecutionResult(func_return=f"Failed to call tool. Error: {e}", status="error")
|
||||
|
||||
|
||||
# TODO: Limitations ->
|
||||
@@ -393,7 +396,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
return cfg, env
|
||||
|
||||
@trace_method
|
||||
async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]:
|
||||
async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[tuple[str, ToolExecutionResult]]:
|
||||
sbx_cfg, sbx_env = self._build_sandbox()
|
||||
rethink_memory_tool_name = "rethink_memory"
|
||||
tool_params = []
|
||||
@@ -424,7 +427,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
return await pool.map(execute_tool_wrapper, tool_params)
|
||||
|
||||
@trace_method
|
||||
async def _bulk_rethink_memory_async(self, params: List[ToolExecutionParams]) -> Sequence[Tuple[str, Tuple[str, bool]]]:
|
||||
async def _bulk_rethink_memory_async(self, params: List[ToolExecutionParams]) -> Sequence[tuple[str, ToolExecutionResult]]:
|
||||
updates = {}
|
||||
result = []
|
||||
for param in params:
|
||||
@@ -443,7 +446,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
updates[block_id] = new_value
|
||||
|
||||
# TODO: This is quite ugly and confusing - this is mostly to align with the returns of other tools
|
||||
result.append((param.agent_id, ("", True)))
|
||||
result.append((param.agent_id, ToolExecutionResult(status="success")))
|
||||
|
||||
await self.block_manager.bulk_update_block_values_async(updates=updates, actor=self.actor)
|
||||
|
||||
@@ -451,7 +454,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
|
||||
async def _persist_tool_messages(
|
||||
self,
|
||||
exec_results: Sequence[Tuple[str, Tuple[str, bool]]],
|
||||
exec_results: Sequence[Tuple[str, "ToolExecutionResult"]],
|
||||
ctx: _ResumeContext,
|
||||
) -> Dict[str, List[Message]]:
|
||||
# TODO: This is redundant, we should have this ready on the ctx
|
||||
@@ -459,14 +462,15 @@ class LettaAgentBatch(BaseAgent):
|
||||
agent_item_map: Dict[str, LLMBatchItem] = {item.agent_id: item for item in ctx.batch_items}
|
||||
|
||||
msg_map: Dict[str, List[Message]] = {}
|
||||
for aid, (tool_res, success) in exec_results:
|
||||
for aid, tool_exec_result in exec_results:
|
||||
msgs = self._create_tool_call_messages(
|
||||
llm_batch_item_id=agent_item_map[aid].id,
|
||||
agent_state=ctx.agent_state_map[aid],
|
||||
tool_call_name=ctx.tool_call_name_map[aid],
|
||||
tool_call_args=ctx.tool_call_args_map[aid],
|
||||
tool_exec_result=tool_res,
|
||||
success_flag=success,
|
||||
tool_exec_result=tool_exec_result.func_return,
|
||||
success_flag=tool_exec_result.success_flag,
|
||||
tool_exec_result_obj=tool_exec_result,
|
||||
reasoning_content=None,
|
||||
)
|
||||
msg_map[aid] = msgs
|
||||
@@ -482,14 +486,14 @@ class LettaAgentBatch(BaseAgent):
|
||||
|
||||
def _prepare_next_iteration(
|
||||
self,
|
||||
exec_results: Sequence[Tuple[str, Tuple[str, bool]]],
|
||||
exec_results: Sequence[Tuple[str, "ToolExecutionResult"]],
|
||||
ctx: _ResumeContext,
|
||||
msg_map: Dict[str, List[Message]],
|
||||
) -> Tuple[List[LettaBatchRequest], Dict[str, AgentStepState]]:
|
||||
# who continues?
|
||||
continues = [aid for aid, cont in ctx.should_continue_map.items() if cont]
|
||||
|
||||
success_flag_map = {aid: flag for aid, (_res, flag) in exec_results}
|
||||
success_flag_map = {aid: result.success_flag for aid, result in exec_results}
|
||||
|
||||
batch_reqs: List[LettaBatchRequest] = []
|
||||
for aid in continues:
|
||||
@@ -528,6 +532,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
tool_call_name: str,
|
||||
tool_call_args: Dict[str, Any],
|
||||
tool_exec_result: str,
|
||||
tool_exec_result_obj: "ToolExecutionResult",
|
||||
success_flag: bool,
|
||||
reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None,
|
||||
) -> List[Message]:
|
||||
@@ -541,6 +546,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=success_flag,
|
||||
function_response=tool_exec_result,
|
||||
tool_execution_result=tool_exec_result_obj,
|
||||
actor=self.actor,
|
||||
add_heartbeat_request_system_message=False,
|
||||
reasoning_content=reasoning_content,
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import json
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional, Tuple
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
|
||||
import openai
|
||||
|
||||
@@ -118,6 +118,7 @@ class VoiceAgent(BaseAgent):
|
||||
Main streaming loop that yields partial tokens.
|
||||
Whenever we detect a tool call, we yield from _handle_ai_response as well.
|
||||
"""
|
||||
print("CALL STREAM")
|
||||
if len(input_messages) != 1 or input_messages[0].role != MessageRole.user:
|
||||
raise ValueError(f"Voice Agent was invoked with multiple input messages or message did not have role `user`: {input_messages}")
|
||||
|
||||
@@ -238,14 +239,17 @@ class VoiceAgent(BaseAgent):
|
||||
)
|
||||
in_memory_message_history.append(assistant_tool_call_msg.model_dump())
|
||||
|
||||
tool_result, success_flag = await self._execute_tool(
|
||||
tool_execution_result = await self._execute_tool(
|
||||
user_query=user_query,
|
||||
tool_name=tool_call_name,
|
||||
tool_args=tool_args,
|
||||
agent_state=agent_state,
|
||||
)
|
||||
tool_result = tool_execution_result.func_return
|
||||
success_flag = tool_execution_result.success_flag
|
||||
|
||||
# 3. Provide function_call response back into the conversation
|
||||
# TODO: fix this tool format
|
||||
tool_message = ToolMessage(
|
||||
content=json.dumps({"result": tool_result}),
|
||||
tool_call_id=tool_call_id,
|
||||
@@ -267,6 +271,7 @@ class VoiceAgent(BaseAgent):
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=success_flag,
|
||||
function_response=tool_result,
|
||||
tool_execution_result=tool_execution_result,
|
||||
actor=self.actor,
|
||||
add_heartbeat_request_system_message=True,
|
||||
)
|
||||
@@ -388,10 +393,14 @@ class VoiceAgent(BaseAgent):
|
||||
for t in tools
|
||||
]
|
||||
|
||||
async def _execute_tool(self, user_query: str, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
async def _execute_tool(self, user_query: str, tool_name: str, tool_args: dict, agent_state: AgentState) -> "ToolExecutionResult":
|
||||
"""
|
||||
Executes a tool and returns (result, success_flag).
|
||||
"""
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
print("EXECUTING TOOL")
|
||||
|
||||
# Special memory case
|
||||
if tool_name == "search_memory":
|
||||
tool_result = await self._search_memory(
|
||||
@@ -401,11 +410,17 @@ class VoiceAgent(BaseAgent):
|
||||
end_minutes_ago=tool_args["end_minutes_ago"],
|
||||
agent_state=agent_state,
|
||||
)
|
||||
return tool_result, True
|
||||
return ToolExecutionResult(
|
||||
func_return=tool_result,
|
||||
status="success",
|
||||
)
|
||||
else:
|
||||
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
|
||||
if not target_tool:
|
||||
return f"Tool not found: {tool_name}", False
|
||||
return ToolExecutionResult(
|
||||
func_return=f"Tool not found: {tool_name}",
|
||||
status="error",
|
||||
)
|
||||
|
||||
try:
|
||||
tool_result, _ = execute_external_tool(
|
||||
@@ -416,9 +431,9 @@ class VoiceAgent(BaseAgent):
|
||||
actor=self.actor,
|
||||
allow_agent_state_modifications=False,
|
||||
)
|
||||
return tool_result, True
|
||||
return ToolExecutionResult(func_return=tool_result, status="success")
|
||||
except Exception as e:
|
||||
return f"Failed to call tool. Error: {e}", False
|
||||
return ToolExecutionResult(func_return=f"Failed to call tool. Error: {e}", status="error")
|
||||
|
||||
async def _search_memory(
|
||||
self,
|
||||
|
||||
@@ -89,20 +89,23 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
)
|
||||
|
||||
@trace_method
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
async def _execute_tool(self, tool_name: str, tool_args: dict, agent_state: AgentState):
|
||||
"""
|
||||
Executes a tool and returns (result, success_flag).
|
||||
"""
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
|
||||
# Special memory case
|
||||
target_tool = next((x for x in agent_state.tools if x.name == tool_name), None)
|
||||
if not target_tool:
|
||||
return f"Tool not found: {tool_name}", False
|
||||
return ToolExecutionResult(func_return=f"Tool not found: {tool_name}", success_flag=False)
|
||||
|
||||
try:
|
||||
if target_tool.name == "rethink_user_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
|
||||
return self.rethink_user_memory(agent_state=agent_state, **tool_args)
|
||||
func_return, success_flag = self.rethink_user_memory(agent_state=agent_state, **tool_args)
|
||||
return ToolExecutionResult(func_return=func_return, status="success" if success_flag else "error")
|
||||
elif target_tool.name == "finish_rethinking_memory" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
|
||||
return "", True
|
||||
return ToolExecutionResult(func_return="", status="success")
|
||||
elif target_tool.name == "store_memories" and target_tool.tool_type == ToolType.LETTA_VOICE_SLEEPTIME_CORE:
|
||||
chunks = tool_args.get("chunks", [])
|
||||
results = [self.store_memory(agent_state=self.convo_agent_state, **chunk_args) for chunk_args in chunks]
|
||||
@@ -110,12 +113,14 @@ class VoiceSleeptimeAgent(LettaAgent):
|
||||
aggregated_result = next((res for res, _ in results if res is not None), None)
|
||||
aggregated_success = all(success for _, success in results)
|
||||
|
||||
return aggregated_result, aggregated_success # Note that here we store to the convo agent's archival memory
|
||||
return ToolExecutionResult(
|
||||
func_return=aggregated_result, status="success" if aggregated_success else "error"
|
||||
) # Note that here we store to the convo agent's archival memory
|
||||
else:
|
||||
result = f"Voice sleeptime agent tried invoking invalid tool with type {target_tool.tool_type}: {target_tool}"
|
||||
return result, False
|
||||
return ToolExecutionResult(func_return=result, status="error")
|
||||
except Exception as e:
|
||||
return f"Failed to call tool. Error: {e}", False
|
||||
return ToolExecutionResult(func_return=f"Failed to call tool. Error: {e}", status="error")
|
||||
|
||||
def rethink_user_memory(self, new_memory: str, agent_state: AgentState) -> Tuple[str, bool]:
|
||||
if agent_state.memory.get_block(self.target_block_label) is None:
|
||||
|
||||
@@ -99,6 +99,13 @@ LETTA_TOOL_SET = set(
|
||||
+ BUILTIN_TOOLS
|
||||
)
|
||||
|
||||
|
||||
def FUNCTION_RETURN_VALUE_TRUNCATED(return_str, return_char: int, return_char_limit: int):
|
||||
return (
|
||||
f"{return_str}... [NOTE: function output was truncated since it exceeded the character limit: {return_char} > {return_char_limit}]"
|
||||
)
|
||||
|
||||
|
||||
# The name of the tool used to send message to the user
|
||||
# May not be relevant in cases where the agent has multiple ways to message to user (send_imessage, send_discord_mesasge, ...)
|
||||
# or in cases where the agent has no concept of messaging a user (e.g. a workflow agent)
|
||||
|
||||
@@ -1,102 +1,156 @@
|
||||
import asyncio
|
||||
from typing import List, Optional, Tuple
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.types import TextContent
|
||||
|
||||
from letta.functions.mcp_client.exceptions import MCPTimeoutError
|
||||
from letta.functions.mcp_client.types import BaseServerConfig, MCPTool
|
||||
from letta.log import get_logger
|
||||
from letta.settings import tool_settings
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class BaseMCPClient:
|
||||
def __init__(self, server_config: BaseServerConfig):
|
||||
self.server_config = server_config
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.stdio = None
|
||||
self.write = None
|
||||
self.initialized = False
|
||||
self.loop = asyncio.new_event_loop()
|
||||
self.cleanup_funcs = []
|
||||
|
||||
def connect_to_server(self):
|
||||
asyncio.set_event_loop(self.loop)
|
||||
success = self._initialize_connection(self.server_config, timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
|
||||
if success:
|
||||
try:
|
||||
self.loop.run_until_complete(
|
||||
asyncio.wait_for(self.session.initialize(), timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
)
|
||||
self.initialized = True
|
||||
except asyncio.TimeoutError:
|
||||
raise MCPTimeoutError("initializing session", self.server_config.server_name, tool_settings.mcp_connect_to_server_timeout)
|
||||
else:
|
||||
raise RuntimeError(
|
||||
f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
|
||||
)
|
||||
|
||||
def _initialize_connection(self, server_config: BaseServerConfig, timeout: float) -> bool:
|
||||
raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
|
||||
def list_tools(self) -> List[MCPTool]:
|
||||
self._check_initialized()
|
||||
try:
|
||||
response = self.loop.run_until_complete(
|
||||
asyncio.wait_for(self.session.list_tools(), timeout=tool_settings.mcp_list_tools_timeout)
|
||||
)
|
||||
return response.tools
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"Timed out while listing tools for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_list_tools_timeout}s)."
|
||||
)
|
||||
raise MCPTimeoutError("listing tools", self.server_config.server_name, tool_settings.mcp_list_tools_timeout)
|
||||
|
||||
def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
self._check_initialized()
|
||||
try:
|
||||
result = self.loop.run_until_complete(
|
||||
asyncio.wait_for(self.session.call_tool(tool_name, tool_args), timeout=tool_settings.mcp_execute_tool_timeout)
|
||||
)
|
||||
|
||||
parsed_content = []
|
||||
for content_piece in result.content:
|
||||
if isinstance(content_piece, TextContent):
|
||||
parsed_content.append(content_piece.text)
|
||||
print("parsed_content (text)", parsed_content)
|
||||
else:
|
||||
parsed_content.append(str(content_piece))
|
||||
print("parsed_content (other)", parsed_content)
|
||||
|
||||
if len(parsed_content) > 0:
|
||||
final_content = " ".join(parsed_content)
|
||||
else:
|
||||
# TODO move hardcoding to constants
|
||||
final_content = "Empty response from tool"
|
||||
|
||||
return final_content, result.isError
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(
|
||||
f"Timed out while executing tool '{tool_name}' for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_execute_tool_timeout}s)."
|
||||
)
|
||||
raise MCPTimeoutError(f"executing tool '{tool_name}'", self.server_config.server_name, tool_settings.mcp_execute_tool_timeout)
|
||||
|
||||
def _check_initialized(self):
|
||||
if not self.initialized:
|
||||
logger.error("MCPClient has not been initialized")
|
||||
raise RuntimeError("MCPClient has not been initialized")
|
||||
|
||||
def cleanup(self):
|
||||
try:
|
||||
for cleanup_func in self.cleanup_funcs:
|
||||
cleanup_func()
|
||||
self.initialized = False
|
||||
if not self.loop.is_closed():
|
||||
self.loop.close()
|
||||
except Exception as e:
|
||||
logger.warning(e)
|
||||
finally:
|
||||
logger.info("Cleaned up MCP clients on shutdown.")
|
||||
# class BaseMCPClient:
|
||||
# def __init__(self, server_config: BaseServerConfig):
|
||||
# self.server_config = server_config
|
||||
# self.session: Optional[ClientSession] = None
|
||||
# self.stdio = None
|
||||
# self.write = None
|
||||
# self.initialized = False
|
||||
# self.loop = asyncio.new_event_loop()
|
||||
# self.cleanup_funcs = []
|
||||
#
|
||||
# def connect_to_server(self):
|
||||
# asyncio.set_event_loop(self.loop)
|
||||
# success = self._initialize_connection(self.server_config, timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
#
|
||||
# if success:
|
||||
# try:
|
||||
# self.loop.run_until_complete(
|
||||
# asyncio.wait_for(self.session.initialize(), timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
# )
|
||||
# self.initialized = True
|
||||
# except asyncio.TimeoutError:
|
||||
# raise MCPTimeoutError("initializing session", self.server_config.server_name, tool_settings.mcp_connect_to_server_timeout)
|
||||
# else:
|
||||
# raise RuntimeError(
|
||||
# f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
|
||||
# )
|
||||
#
|
||||
# def _initialize_connection(self, server_config: BaseServerConfig, timeout: float) -> bool:
|
||||
# raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
#
|
||||
# def list_tools(self) -> List[MCPTool]:
|
||||
# self._check_initialized()
|
||||
# try:
|
||||
# response = self.loop.run_until_complete(
|
||||
# asyncio.wait_for(self.session.list_tools(), timeout=tool_settings.mcp_list_tools_timeout)
|
||||
# )
|
||||
# return response.tools
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(
|
||||
# f"Timed out while listing tools for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_list_tools_timeout}s)."
|
||||
# )
|
||||
# raise MCPTimeoutError("listing tools", self.server_config.server_name, tool_settings.mcp_list_tools_timeout)
|
||||
#
|
||||
# def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
# self._check_initialized()
|
||||
# try:
|
||||
# result = self.loop.run_until_complete(
|
||||
# asyncio.wait_for(self.session.call_tool(tool_name, tool_args), timeout=tool_settings.mcp_execute_tool_timeout)
|
||||
# )
|
||||
#
|
||||
# parsed_content = []
|
||||
# for content_piece in result.content:
|
||||
# if isinstance(content_piece, TextContent):
|
||||
# parsed_content.append(content_piece.text)
|
||||
# print("parsed_content (text)", parsed_content)
|
||||
# else:
|
||||
# parsed_content.append(str(content_piece))
|
||||
# print("parsed_content (other)", parsed_content)
|
||||
#
|
||||
# if len(parsed_content) > 0:
|
||||
# final_content = " ".join(parsed_content)
|
||||
# else:
|
||||
# # TODO move hardcoding to constants
|
||||
# final_content = "Empty response from tool"
|
||||
#
|
||||
# return final_content, result.isError
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(
|
||||
# f"Timed out while executing tool '{tool_name}' for MCP server {self.server_config.server_name} (timeout={tool_settings.mcp_execute_tool_timeout}s)."
|
||||
# )
|
||||
# raise MCPTimeoutError(f"executing tool '{tool_name}'", self.server_config.server_name, tool_settings.mcp_execute_tool_timeout)
|
||||
#
|
||||
# def _check_initialized(self):
|
||||
# if not self.initialized:
|
||||
# logger.error("MCPClient has not been initialized")
|
||||
# raise RuntimeError("MCPClient has not been initialized")
|
||||
#
|
||||
# def cleanup(self):
|
||||
# try:
|
||||
# for cleanup_func in self.cleanup_funcs:
|
||||
# cleanup_func()
|
||||
# self.initialized = False
|
||||
# if not self.loop.is_closed():
|
||||
# self.loop.close()
|
||||
# except Exception as e:
|
||||
# logger.warning(e)
|
||||
# finally:
|
||||
# logger.info("Cleaned up MCP clients on shutdown.")
|
||||
#
|
||||
#
|
||||
# class BaseAsyncMCPClient:
|
||||
# def __init__(self, server_config: BaseServerConfig):
|
||||
# self.server_config = server_config
|
||||
# self.session: Optional[ClientSession] = None
|
||||
# self.stdio = None
|
||||
# self.write = None
|
||||
# self.initialized = False
|
||||
# self.cleanup_funcs = []
|
||||
#
|
||||
# async def connect_to_server(self):
|
||||
#
|
||||
# success = await self._initialize_connection(self.server_config, timeout=tool_settings.mcp_connect_to_server_timeout)
|
||||
#
|
||||
# if success:
|
||||
# self.initialized = True
|
||||
# else:
|
||||
# raise RuntimeError(
|
||||
# f"Connecting to MCP server failed. Please review your server config: {self.server_config.model_dump_json(indent=4)}"
|
||||
# )
|
||||
#
|
||||
# async def list_tools(self) -> List[MCPTool]:
|
||||
# self._check_initialized()
|
||||
# response = await self.session.list_tools()
|
||||
# return response.tools
|
||||
#
|
||||
# async def execute_tool(self, tool_name: str, tool_args: dict) -> Tuple[str, bool]:
|
||||
# self._check_initialized()
|
||||
# result = await self.session.call_tool(tool_name, tool_args)
|
||||
#
|
||||
# parsed_content = []
|
||||
# for content_piece in result.content:
|
||||
# if isinstance(content_piece, TextContent):
|
||||
# parsed_content.append(content_piece.text)
|
||||
# else:
|
||||
# parsed_content.append(str(content_piece))
|
||||
#
|
||||
# if len(parsed_content) > 0:
|
||||
# final_content = " ".join(parsed_content)
|
||||
# else:
|
||||
# # TODO move hardcoding to constants
|
||||
# final_content = "Empty response from tool"
|
||||
#
|
||||
# return final_content, result.isError
|
||||
#
|
||||
# def _check_initialized(self):
|
||||
# if not self.initialized:
|
||||
# logger.error("MCPClient has not been initialized")
|
||||
# raise RuntimeError("MCPClient has not been initialized")
|
||||
#
|
||||
# async def cleanup(self):
|
||||
# try:
|
||||
# for cleanup_func in self.cleanup_funcs:
|
||||
# cleanup_func()
|
||||
# self.initialized = False
|
||||
# if not self.loop.is_closed():
|
||||
# self.loop.close()
|
||||
# except Exception as e:
|
||||
# logger.warning(e)
|
||||
# finally:
|
||||
# logger.info("Cleaned up MCP clients on shutdown.")
|
||||
#
|
||||
|
||||
@@ -1,33 +1,51 @@
|
||||
import asyncio
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.functions.mcp_client.types import SSEServerConfig
|
||||
from letta.log import get_logger
|
||||
|
||||
# see: https://modelcontextprotocol.io/quickstart/user
|
||||
MCP_CONFIG_TOPLEVEL_KEY = "mcpServers"
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# import asyncio
|
||||
#
|
||||
# from mcp import ClientSession
|
||||
# from mcp.client.sse import sse_client
|
||||
#
|
||||
# from letta.functions.mcp_client.base_client import BaseAsyncMCPClient, BaseMCPClient
|
||||
# from letta.functions.mcp_client.types import SSEServerConfig
|
||||
# from letta.log import get_logger
|
||||
#
|
||||
## see: https://modelcontextprotocol.io/quickstart/user
|
||||
#
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
|
||||
class SSEMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
|
||||
try:
|
||||
sse_cm = sse_client(url=server_config.server_url)
|
||||
sse_transport = self.loop.run_until_complete(asyncio.wait_for(sse_cm.__aenter__(), timeout=timeout))
|
||||
self.stdio, self.write = sse_transport
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None)))
|
||||
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timed out while establishing SSE connection (timeout={timeout}s).")
|
||||
return False
|
||||
except Exception:
|
||||
logger.exception("Exception occurred while initializing SSE client session.")
|
||||
return False
|
||||
# class SSEMCPClient(BaseMCPClient):
|
||||
# def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
|
||||
# try:
|
||||
# sse_cm = sse_client(url=server_config.server_url)
|
||||
# sse_transport = self.loop.run_until_complete(asyncio.wait_for(sse_cm.__aenter__(), timeout=timeout))
|
||||
# self.stdio, self.write = sse_transport
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(sse_cm.__aexit__(None, None, None)))
|
||||
#
|
||||
# session_cm = ClientSession(self.stdio, self.write)
|
||||
# self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
# return True
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(f"Timed out while establishing SSE connection (timeout={timeout}s).")
|
||||
# return False
|
||||
# except Exception:
|
||||
# logger.exception("Exception occurred while initializing SSE client session.")
|
||||
# return False
|
||||
#
|
||||
#
|
||||
# class AsyncSSEMCPClient(BaseAsyncMCPClient):
|
||||
#
|
||||
# async def _initialize_connection(self, server_config: SSEServerConfig, timeout: float) -> bool:
|
||||
# try:
|
||||
# sse_cm = sse_client(url=server_config.server_url)
|
||||
# sse_transport = await sse_cm.__aenter__()
|
||||
# self.stdio, self.write = sse_transport
|
||||
# self.cleanup_funcs.append(lambda: sse_cm.__aexit__(None, None, None))
|
||||
#
|
||||
# session_cm = ClientSession(self.stdio, self.write)
|
||||
# self.session = await session_cm.__aenter__()
|
||||
# self.cleanup_funcs.append(lambda: session_cm.__aexit__(None, None, None))
|
||||
# return True
|
||||
# except Exception:
|
||||
# logger.exception("Exception occurred while initializing SSE client session.")
|
||||
# return False
|
||||
#
|
||||
|
||||
@@ -1,108 +1,109 @@
|
||||
import asyncio
|
||||
import sys
|
||||
from contextlib import asynccontextmanager
|
||||
|
||||
import anyio
|
||||
import anyio.lowlevel
|
||||
import mcp.types as types
|
||||
from anyio.streams.text import TextReceiveStream
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import get_default_environment
|
||||
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
# import asyncio
|
||||
# import sys
|
||||
# from contextlib import asynccontextmanager
|
||||
#
|
||||
# import anyio
|
||||
# import anyio.lowlevel
|
||||
# import mcp.types as types
|
||||
# from anyio.streams.text import TextReceiveStream
|
||||
# from mcp import ClientSession, StdioServerParameters
|
||||
# from mcp.client.stdio import get_default_environment
|
||||
#
|
||||
# from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
# from letta.functions.mcp_client.types import StdioServerConfig
|
||||
# from letta.log import get_logger
|
||||
#
|
||||
# logger = get_logger(__name__)
|
||||
|
||||
|
||||
class StdioMCPClient(BaseMCPClient):
|
||||
def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
|
||||
try:
|
||||
server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
|
||||
stdio_cm = forked_stdio_client(server_params)
|
||||
stdio_transport = self.loop.run_until_complete(asyncio.wait_for(stdio_cm.__aenter__(), timeout=timeout))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None)))
|
||||
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
|
||||
self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
return True
|
||||
except asyncio.TimeoutError:
|
||||
logger.error(f"Timed out while establishing stdio connection (timeout={timeout}s).")
|
||||
return False
|
||||
except Exception:
|
||||
logger.exception("Exception occurred while initializing stdio client session.")
|
||||
return False
|
||||
|
||||
|
||||
@asynccontextmanager
|
||||
async def forked_stdio_client(server: StdioServerParameters):
|
||||
"""
|
||||
Client transport for stdio: this will connect to a server by spawning a
|
||||
process and communicating with it over stdin/stdout.
|
||||
"""
|
||||
read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
|
||||
try:
|
||||
process = await anyio.open_process(
|
||||
[server.command, *server.args],
|
||||
env=server.env or get_default_environment(),
|
||||
stderr=sys.stderr, # Consider logging stderr somewhere instead of silencing it
|
||||
)
|
||||
except OSError as exc:
|
||||
raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc
|
||||
|
||||
async def stdout_reader():
|
||||
assert process.stdout, "Opened process is missing stdout"
|
||||
buffer = ""
|
||||
try:
|
||||
async with read_stream_writer:
|
||||
async for chunk in TextReceiveStream(
|
||||
process.stdout,
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
):
|
||||
lines = (buffer + chunk).split("\n")
|
||||
buffer = lines.pop()
|
||||
for line in lines:
|
||||
try:
|
||||
message = types.JSONRPCMessage.model_validate_json(line)
|
||||
except Exception as exc:
|
||||
await read_stream_writer.send(exc)
|
||||
continue
|
||||
await read_stream_writer.send(message)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def stdin_writer():
|
||||
assert process.stdin, "Opened process is missing stdin"
|
||||
try:
|
||||
async with write_stream_reader:
|
||||
async for message in write_stream_reader:
|
||||
json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
await process.stdin.send(
|
||||
(json + "\n").encode(
|
||||
encoding=server.encoding,
|
||||
errors=server.encoding_error_handler,
|
||||
)
|
||||
)
|
||||
except anyio.ClosedResourceError:
|
||||
await anyio.lowlevel.checkpoint()
|
||||
|
||||
async def watch_process_exit():
|
||||
returncode = await process.wait()
|
||||
if returncode != 0:
|
||||
raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")
|
||||
|
||||
async with anyio.create_task_group() as tg, process:
|
||||
tg.start_soon(stdout_reader)
|
||||
tg.start_soon(stdin_writer)
|
||||
tg.start_soon(watch_process_exit)
|
||||
|
||||
with anyio.move_on_after(0.2):
|
||||
await anyio.sleep_forever()
|
||||
|
||||
yield read_stream, write_stream
|
||||
# class StdioMCPClient(BaseMCPClient):
|
||||
# def _initialize_connection(self, server_config: StdioServerConfig, timeout: float) -> bool:
|
||||
# try:
|
||||
# server_params = StdioServerParameters(command=server_config.command, args=server_config.args, env=server_config.env)
|
||||
# stdio_cm = forked_stdio_client(server_params)
|
||||
# stdio_transport = self.loop.run_until_complete(asyncio.wait_for(stdio_cm.__aenter__(), timeout=timeout))
|
||||
# self.stdio, self.write = stdio_transport
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(stdio_cm.__aexit__(None, None, None)))
|
||||
#
|
||||
# session_cm = ClientSession(self.stdio, self.write)
|
||||
# self.session = self.loop.run_until_complete(asyncio.wait_for(session_cm.__aenter__(), timeout=timeout))
|
||||
# self.cleanup_funcs.append(lambda: self.loop.run_until_complete(session_cm.__aexit__(None, None, None)))
|
||||
# return True
|
||||
# except asyncio.TimeoutError:
|
||||
# logger.error(f"Timed out while establishing stdio connection (timeout={timeout}s).")
|
||||
# return False
|
||||
# except Exception:
|
||||
# logger.exception("Exception occurred while initializing stdio client session.")
|
||||
# return False
|
||||
#
|
||||
#
|
||||
# @asynccontextmanager
|
||||
# async def forked_stdio_client(server: StdioServerParameters):
|
||||
# """
|
||||
# Client transport for stdio: this will connect to a server by spawning a
|
||||
# process and communicating with it over stdin/stdout.
|
||||
# """
|
||||
# read_stream_writer, read_stream = anyio.create_memory_object_stream(0)
|
||||
# write_stream, write_stream_reader = anyio.create_memory_object_stream(0)
|
||||
#
|
||||
# try:
|
||||
# process = await anyio.open_process(
|
||||
# [server.command, *server.args],
|
||||
# env=server.env or get_default_environment(),
|
||||
# stderr=sys.stderr, # Consider logging stderr somewhere instead of silencing it
|
||||
# )
|
||||
# except OSError as exc:
|
||||
# raise RuntimeError(f"Failed to spawn process: {server.command} {server.args}") from exc
|
||||
#
|
||||
# async def stdout_reader():
|
||||
# assert process.stdout, "Opened process is missing stdout"
|
||||
# buffer = ""
|
||||
# try:
|
||||
# async with read_stream_writer:
|
||||
# async for chunk in TextReceiveStream(
|
||||
# process.stdout,
|
||||
# encoding=server.encoding,
|
||||
# errors=server.encoding_error_handler,
|
||||
# ):
|
||||
# lines = (buffer + chunk).split("\n")
|
||||
# buffer = lines.pop()
|
||||
# for line in lines:
|
||||
# try:
|
||||
# message = types.JSONRPCMessage.model_validate_json(line)
|
||||
# except Exception as exc:
|
||||
# await read_stream_writer.send(exc)
|
||||
# continue
|
||||
# await read_stream_writer.send(message)
|
||||
# except anyio.ClosedResourceError:
|
||||
# await anyio.lowlevel.checkpoint()
|
||||
#
|
||||
# async def stdin_writer():
|
||||
# assert process.stdin, "Opened process is missing stdin"
|
||||
# try:
|
||||
# async with write_stream_reader:
|
||||
# async for message in write_stream_reader:
|
||||
# json = message.model_dump_json(by_alias=True, exclude_none=True)
|
||||
# await process.stdin.send(
|
||||
# (json + "\n").encode(
|
||||
# encoding=server.encoding,
|
||||
# errors=server.encoding_error_handler,
|
||||
# )
|
||||
# )
|
||||
# except anyio.ClosedResourceError:
|
||||
# await anyio.lowlevel.checkpoint()
|
||||
#
|
||||
# async def watch_process_exit():
|
||||
# returncode = await process.wait()
|
||||
# if returncode != 0:
|
||||
# raise RuntimeError(f"Subprocess exited with code {returncode}. Command: {server.command} {server.args}")
|
||||
#
|
||||
# async with anyio.create_task_group() as tg, process:
|
||||
# tg.start_soon(stdout_reader)
|
||||
# tg.start_soon(stdin_writer)
|
||||
# tg.start_soon(watch_process_exit)
|
||||
#
|
||||
# with anyio.move_on_after(0.2):
|
||||
# await anyio.sleep_forever()
|
||||
#
|
||||
# yield read_stream, write_stream
|
||||
#
|
||||
|
||||
@@ -2,13 +2,13 @@ import json
|
||||
from typing import Dict, Optional, Union
|
||||
|
||||
from letta.agent import Agent
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm.group import Group
|
||||
from letta.orm.user import User
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.group import ManagerType
|
||||
from letta.schemas.message import Message
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
|
||||
|
||||
def load_multi_agent(
|
||||
@@ -16,7 +16,7 @@ def load_multi_agent(
|
||||
agent_state: Optional[AgentState],
|
||||
actor: User,
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
||||
mcp_clients: Optional[Dict[str, AsyncBaseMCPClient]] = None,
|
||||
) -> Agent:
|
||||
if len(group.agent_ids) == 0:
|
||||
raise ValueError("Empty group: group must have at least one agent")
|
||||
@@ -76,7 +76,6 @@ def load_multi_agent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
user=actor,
|
||||
mcp_clients=mcp_clients,
|
||||
group_id=group.id,
|
||||
agent_ids=group.agent_ids,
|
||||
description=group.description,
|
||||
|
||||
@@ -1,10 +1,9 @@
|
||||
import asyncio
|
||||
import threading
|
||||
from datetime import datetime, timezone
|
||||
from typing import Dict, List, Optional
|
||||
from typing import List, Optional
|
||||
|
||||
from letta.agent import Agent, AgentState
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.groups.helpers import stringify_message
|
||||
from letta.interface import AgentInterface
|
||||
from letta.orm import User
|
||||
@@ -27,7 +26,7 @@ class SleeptimeMultiAgent(Agent):
|
||||
interface: AgentInterface,
|
||||
agent_state: AgentState,
|
||||
user: User,
|
||||
mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
||||
# mcp_clients: Optional[Dict[str, BaseMCPClient]] = None,
|
||||
# custom
|
||||
group_id: str = "",
|
||||
agent_ids: List[str] = [],
|
||||
@@ -42,7 +41,8 @@ class SleeptimeMultiAgent(Agent):
|
||||
self.group_manager = GroupManager()
|
||||
self.message_manager = MessageManager()
|
||||
self.job_manager = JobManager()
|
||||
self.mcp_clients = mcp_clients
|
||||
# TODO: add back MCP support with new agent loop
|
||||
self.mcp_clients = {}
|
||||
|
||||
def _run_async_in_new_thread(self, coro):
|
||||
"""Run an async coroutine in a new thread with its own event loop"""
|
||||
|
||||
@@ -7,6 +7,7 @@ from openai.types.chat.chat_completion_message_tool_call import ChatCompletionMe
|
||||
from openai.types.chat.chat_completion_message_tool_call import Function as OpenAIFunction
|
||||
from sqlalchemy import Dialect
|
||||
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.schemas.agent import AgentStepState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderType, ToolRuleType
|
||||
@@ -400,3 +401,22 @@ def deserialize_response_format(data: Optional[Dict]) -> Optional[ResponseFormat
|
||||
return JsonSchemaResponseFormat(**data)
|
||||
if data["type"] == ResponseFormatType.json_object:
|
||||
return JsonObjectResponseFormat(**data)
|
||||
|
||||
|
||||
# --------------------------
|
||||
# MCP Stdio Server Config Serialization
|
||||
# --------------------------
|
||||
|
||||
|
||||
def serialize_mcp_stdio_config(config: Union[Optional[StdioServerConfig], Dict]) -> Optional[Dict]:
|
||||
"""Convert an StdioServerConfig object into a JSON-serializable dictionary."""
|
||||
if config and isinstance(config, StdioServerConfig):
|
||||
return config.to_dict()
|
||||
return config
|
||||
|
||||
|
||||
def deserialize_mcp_stdio_config(data: Optional[Dict]) -> Optional[StdioServerConfig]:
|
||||
"""Convert a dictionary back into an StdioServerConfig object."""
|
||||
if not data:
|
||||
return None
|
||||
return StdioServerConfig(**data)
|
||||
|
||||
@@ -16,6 +16,7 @@ class OpenAIChatCompletionsStreamingInterface:
|
||||
"""
|
||||
|
||||
def __init__(self, stream_pre_execution_message: bool = True):
|
||||
print("CHAT COMPLETITION INTERFACE")
|
||||
self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser()
|
||||
self.stream_pre_execution_message: bool = stream_pre_execution_message
|
||||
|
||||
|
||||
@@ -15,6 +15,7 @@ from letta.orm.job import Job
|
||||
from letta.orm.job_messages import JobMessage
|
||||
from letta.orm.llm_batch_items import LLMBatchItem
|
||||
from letta.orm.llm_batch_job import LLMBatchJob
|
||||
from letta.orm.mcp_server import MCPServer
|
||||
from letta.orm.message import Message
|
||||
from letta.orm.organization import Organization
|
||||
from letta.orm.passage import AgentPassage, BasePassage, SourcePassage
|
||||
|
||||
@@ -7,6 +7,7 @@ from letta.helpers.converters import (
|
||||
deserialize_create_batch_response,
|
||||
deserialize_embedding_config,
|
||||
deserialize_llm_config,
|
||||
deserialize_mcp_stdio_config,
|
||||
deserialize_message_content,
|
||||
deserialize_poll_batch_response,
|
||||
deserialize_response_format,
|
||||
@@ -19,6 +20,7 @@ from letta.helpers.converters import (
|
||||
serialize_create_batch_response,
|
||||
serialize_embedding_config,
|
||||
serialize_llm_config,
|
||||
serialize_mcp_stdio_config,
|
||||
serialize_message_content,
|
||||
serialize_poll_batch_response,
|
||||
serialize_response_format,
|
||||
@@ -183,3 +185,14 @@ class ResponseFormatColumn(TypeDecorator):
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return deserialize_response_format(value)
|
||||
|
||||
|
||||
class MCPStdioServerConfigColumn(TypeDecorator):
|
||||
impl = JSON
|
||||
cache_ok = True
|
||||
|
||||
def process_bind_param(self, value, dialect):
|
||||
return serialize_mcp_stdio_config(value)
|
||||
|
||||
def process_result_value(self, value, dialect):
|
||||
return deserialize_mcp_stdio_config(value)
|
||||
|
||||
@@ -32,3 +32,8 @@ class ActorType(str, Enum):
|
||||
LETTA_USER = "letta_user"
|
||||
LETTA_AGENT = "letta_agent"
|
||||
LETTA_SYSTEM = "letta_system"
|
||||
|
||||
|
||||
class MCPServerType(str, Enum):
|
||||
SSE = "sse"
|
||||
STDIO = "stdio"
|
||||
|
||||
48
letta/orm/mcp_server.py
Normal file
48
letta/orm/mcp_server.py
Normal file
@@ -0,0 +1,48 @@
|
||||
from typing import TYPE_CHECKING, Optional
|
||||
|
||||
from sqlalchemy import JSON, String, UniqueConstraint
|
||||
from sqlalchemy.orm import Mapped, mapped_column
|
||||
|
||||
from letta.functions.mcp_client.types import StdioServerConfig
|
||||
from letta.orm.custom_columns import MCPStdioServerConfigColumn
|
||||
|
||||
# TODO everything in functions should live in this model
|
||||
from letta.orm.enums import MCPServerType
|
||||
from letta.orm.mixins import OrganizationMixin
|
||||
from letta.orm.sqlalchemy_base import SqlalchemyBase
|
||||
from letta.schemas.mcp import MCPServer
|
||||
|
||||
if TYPE_CHECKING:
|
||||
pass
|
||||
|
||||
|
||||
class MCPServer(SqlalchemyBase, OrganizationMixin):
|
||||
"""Represents a registered MCP server"""
|
||||
|
||||
__tablename__ = "mcp_server"
|
||||
__pydantic_model__ = MCPServer
|
||||
|
||||
# Add unique constraint on (name, _organization_id)
|
||||
# An organization should not have multiple tools with the same name
|
||||
__table_args__ = (UniqueConstraint("server_name", "organization_id", name="uix_name_organization_mcp_server"),)
|
||||
|
||||
server_name: Mapped[str] = mapped_column(doc="The display name of the MCP server")
|
||||
server_type: Mapped[MCPServerType] = mapped_column(
|
||||
String, default=MCPServerType.SSE, doc="The type of the MCP server. Only SSE is supported for remote servers."
|
||||
)
|
||||
|
||||
# sse server
|
||||
server_url: Mapped[Optional[str]] = mapped_column(
|
||||
String, nullable=True, doc="The URL of the server (MCP SSE client will connect to this URL)"
|
||||
)
|
||||
|
||||
# stdio server
|
||||
stdio_config: Mapped[Optional[StdioServerConfig]] = mapped_column(
|
||||
MCPStdioServerConfigColumn, nullable=True, doc="The configuration for the stdio server"
|
||||
)
|
||||
|
||||
metadata_: Mapped[Optional[dict]] = mapped_column(
|
||||
JSON, default=lambda: {}, doc="A dictionary of additional metadata for the MCP server."
|
||||
)
|
||||
# relationships
|
||||
# organization: Mapped["Organization"] = relationship("Organization", back_populates="mcp_server", lazy="selectin")
|
||||
@@ -28,6 +28,7 @@ class Organization(SqlalchemyBase):
|
||||
# relationships
|
||||
users: Mapped[List["User"]] = relationship("User", back_populates="organization", cascade="all, delete-orphan")
|
||||
tools: Mapped[List["Tool"]] = relationship("Tool", back_populates="organization", cascade="all, delete-orphan")
|
||||
# mcp_servers: Mapped[List["MCPServer"]] = relationship("MCPServer", back_populates="organization", cascade="all, delete-orphan")
|
||||
blocks: Mapped[List["Block"]] = relationship("Block", back_populates="organization", cascade="all, delete-orphan")
|
||||
sources: Mapped[List["Source"]] = relationship("Source", back_populates="organization", cascade="all, delete-orphan")
|
||||
files: Mapped[List["FileMetadata"]] = relationship("FileMetadata", back_populates="organization", cascade="all, delete-orphan")
|
||||
|
||||
74
letta/schemas/mcp.py
Normal file
74
letta/schemas/mcp.py
Normal file
@@ -0,0 +1,74 @@
|
||||
from typing import Any, Dict, Optional, Union
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.functions.mcp_client.types import MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.schemas.letta_base import LettaBase
|
||||
|
||||
|
||||
class BaseMCPServer(LettaBase):
|
||||
__id_prefix__ = "mcp_server"
|
||||
|
||||
|
||||
class MCPServer(BaseMCPServer):
|
||||
id: str = BaseMCPServer.generate_id_field()
|
||||
server_type: MCPServerType = MCPServerType.SSE
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
|
||||
# sse config
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
# stdio config
|
||||
stdio_config: Optional[StdioServerConfig] = Field(
|
||||
None, description="The configuration for the server (MCP 'local' client will run this command)"
|
||||
)
|
||||
|
||||
organization_id: Optional[str] = Field(None, description="The unique identifier of the organization associated with the tool.")
|
||||
|
||||
# metadata fields
|
||||
created_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
last_updated_by_id: Optional[str] = Field(None, description="The id of the user that made this Tool.")
|
||||
metadata_: Optional[Dict[str, Any]] = Field(default_factory=dict, description="A dictionary of additional metadata for the tool.")
|
||||
|
||||
# TODO: add tokens?
|
||||
|
||||
def to_config(self) -> Union[SSEServerConfig, StdioServerConfig]:
|
||||
if self.server_type == MCPServerType.SSE:
|
||||
return SSEServerConfig(
|
||||
server_name=self.server_name,
|
||||
server_url=self.server_url,
|
||||
)
|
||||
elif self.server_type == MCPServerType.STDIO:
|
||||
return self.stdio_config
|
||||
|
||||
|
||||
class RegisterSSEMCPServer(LettaBase):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
server_type: MCPServerType = MCPServerType.SSE
|
||||
server_url: str = Field(..., description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
|
||||
class RegisterStdioMCPServer(LettaBase):
|
||||
server_name: str = Field(..., description="The name of the server")
|
||||
server_type: MCPServerType = MCPServerType.STDIO
|
||||
stdio_config: StdioServerConfig = Field(..., description="The configuration for the server (MCP 'local' client will run this command)")
|
||||
|
||||
|
||||
class UpdateSSEMCPServer(LettaBase):
|
||||
"""Update an SSE MCP server"""
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the server")
|
||||
server_url: Optional[str] = Field(None, description="The URL of the server (MCP SSE client will connect to this URL)")
|
||||
|
||||
|
||||
class UpdateStdioMCPServer(LettaBase):
|
||||
"""Update a Stdio MCP server"""
|
||||
|
||||
server_name: Optional[str] = Field(None, description="The name of the server")
|
||||
stdio_config: Optional[StdioServerConfig] = Field(
|
||||
None, description="The configuration for the server (MCP 'local' client will run this command)"
|
||||
)
|
||||
|
||||
|
||||
UpdateMCPServer = Union[UpdateSSEMCPServer, UpdateStdioMCPServer]
|
||||
RegisterMCPServer = Union[RegisterSSEMCPServer, RegisterStdioMCPServer]
|
||||
@@ -1101,3 +1101,4 @@ class ToolReturn(BaseModel):
|
||||
status: Literal["success", "error"] = Field(..., description="The status of the tool call")
|
||||
stdout: Optional[List[str]] = Field(None, description="Captured stdout (e.g. prints, logs) from the tool invocation")
|
||||
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the tool invocation")
|
||||
# func_return: Optional[Any] = Field(None, description="The function return object")
|
||||
|
||||
@@ -14,7 +14,6 @@ from letta.constants import (
|
||||
from letta.functions.ast_parsers import get_function_name_and_description
|
||||
from letta.functions.composio_helpers import generate_composio_tool_wrapper
|
||||
from letta.functions.functions import derive_openai_json_schema, get_json_schema_from_module
|
||||
from letta.functions.helpers import generate_langchain_tool_wrapper, generate_mcp_tool_wrapper, generate_model_from_args_json_schema
|
||||
from letta.functions.mcp_client.types import MCPTool
|
||||
from letta.functions.schema_generator import (
|
||||
generate_schema_from_args_schema_v2,
|
||||
@@ -71,6 +70,8 @@ class Tool(BaseTool):
|
||||
"""
|
||||
Refresh name, description, source_code, and json_schema.
|
||||
"""
|
||||
from letta.functions.helpers import generate_model_from_args_json_schema
|
||||
|
||||
if self.tool_type == ToolType.CUSTOM:
|
||||
# If it's a custom tool, we need to ensure source_code is present
|
||||
if not self.source_code:
|
||||
@@ -146,6 +147,8 @@ class ToolCreate(LettaBase):
|
||||
|
||||
@classmethod
|
||||
def from_mcp(cls, mcp_server_name: str, mcp_tool: MCPTool) -> "ToolCreate":
|
||||
from letta.functions.helpers import generate_mcp_tool_wrapper
|
||||
|
||||
# Pass the MCP tool to the schema generator
|
||||
json_schema = generate_tool_schema_for_mcp(mcp_tool=mcp_tool)
|
||||
|
||||
@@ -218,6 +221,8 @@ class ToolCreate(LettaBase):
|
||||
Returns:
|
||||
Tool: A Letta Tool initialized with attributes derived from the provided LangChain BaseTool object.
|
||||
"""
|
||||
from letta.functions.helpers import generate_langchain_tool_wrapper
|
||||
|
||||
description = langchain_tool.description
|
||||
source_type = "python"
|
||||
tags = ["langchain"]
|
||||
|
||||
@@ -6,9 +6,14 @@ from letta.schemas.agent import AgentState
|
||||
|
||||
|
||||
class ToolExecutionResult(BaseModel):
|
||||
|
||||
status: Literal["success", "error"] = Field(..., description="The status of the tool execution and return object")
|
||||
func_return: Optional[Any] = Field(None, description="The function return object")
|
||||
agent_state: Optional[AgentState] = Field(None, description="The agent state")
|
||||
stdout: Optional[List[str]] = Field(None, description="Captured stdout (prints, logs) from function invocation")
|
||||
stderr: Optional[List[str]] = Field(None, description="Captured stderr from the function invocation")
|
||||
sandbox_config_fingerprint: Optional[str] = Field(None, description="The fingerprint of the config for the sandbox")
|
||||
|
||||
@property
|
||||
def success_flag(self) -> bool:
|
||||
return self.status == "success"
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
import asyncio
|
||||
import concurrent.futures
|
||||
import json
|
||||
import logging
|
||||
import os
|
||||
@@ -17,7 +15,6 @@ from letta.__init__ import __version__
|
||||
from letta.agents.exceptions import IncompatibleAgentType
|
||||
from letta.constants import ADMIN_PREFIX, API_PREFIX, OPENAI_API_PREFIX
|
||||
from letta.errors import BedrockPermissionError, LettaAgentNotFoundError, LettaUserNotFoundError
|
||||
from letta.jobs.scheduler import shutdown_scheduler_and_release_lock, start_scheduler_with_leader_election
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import DatabaseTimeoutError, ForeignKeyConstraintViolationError, NoResultFound, UniqueConstraintViolationError
|
||||
from letta.schemas.letta_message import create_letta_message_union_schema
|
||||
@@ -100,7 +97,7 @@ class CheckPasswordMiddleware(BaseHTTPMiddleware):
|
||||
async def dispatch(self, request, call_next):
|
||||
|
||||
# Exclude health check endpoint from password protection
|
||||
if request.url.path == "/v1/health/" or request.url.path == "/latest/health/":
|
||||
if request.url.path in {"/v1/health", "/v1/health/", "/latest/health/"}:
|
||||
return await call_next(request)
|
||||
|
||||
if (
|
||||
@@ -142,34 +139,6 @@ def create_application() -> "FastAPI":
|
||||
debug=debug_mode, # if True, the stack trace will be printed in the response
|
||||
)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def configure_executor():
|
||||
print(f"INFO: Configured event loop executor with {settings.event_loop_threadpool_max_workers} workers.")
|
||||
loop = asyncio.get_running_loop()
|
||||
executor = concurrent.futures.ThreadPoolExecutor(max_workers=settings.event_loop_threadpool_max_workers)
|
||||
loop.set_default_executor(executor)
|
||||
|
||||
@app.on_event("startup")
|
||||
async def on_startup():
|
||||
global server
|
||||
|
||||
await start_scheduler_with_leader_election(server)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
def shutdown_mcp_clients():
|
||||
global server
|
||||
import threading
|
||||
|
||||
def cleanup_clients():
|
||||
if hasattr(server, "mcp_clients"):
|
||||
for client in server.mcp_clients.values():
|
||||
client.cleanup()
|
||||
server.mcp_clients.clear()
|
||||
|
||||
t = threading.Thread(target=cleanup_clients)
|
||||
t.start()
|
||||
t.join()
|
||||
|
||||
@app.exception_handler(IncompatibleAgentType)
|
||||
async def handle_incompatible_agent_type(request: Request, exc: IncompatibleAgentType):
|
||||
return JSONResponse(
|
||||
@@ -320,12 +289,6 @@ def create_application() -> "FastAPI":
|
||||
# Generate OpenAPI schema after all routes are mounted
|
||||
generate_openapi_schema(app)
|
||||
|
||||
@app.on_event("shutdown")
|
||||
async def on_shutdown():
|
||||
global server
|
||||
# server = None
|
||||
await shutdown_scheduler_and_release_lock()
|
||||
|
||||
return app
|
||||
|
||||
|
||||
|
||||
@@ -21,6 +21,7 @@ from letta.schemas.letta_message import ToolReturnMessage
|
||||
from letta.schemas.tool import Tool, ToolCreate, ToolRunFromSource, ToolUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import tool_settings
|
||||
|
||||
router = APIRouter(prefix="/tools", tags=["tools"])
|
||||
|
||||
@@ -354,18 +355,21 @@ def add_composio_tool(
|
||||
|
||||
# Specific routes for MCP
|
||||
@router.get("/mcp/servers", response_model=dict[str, Union[SSEServerConfig, StdioServerConfig]], operation_id="list_mcp_servers")
|
||||
def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")):
|
||||
async def list_mcp_servers(server: SyncServer = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id")):
|
||||
"""
|
||||
Get a list of all configured MCP servers
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
return server.get_mcp_servers()
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return server.get_mcp_servers()
|
||||
else:
|
||||
mcp_servers = await server.mcp_manager.list_mcp_servers(actor=server.user_manager.get_user_or_default(user_id=user_id))
|
||||
return {server.server_name: server.to_config() for server in mcp_servers}
|
||||
|
||||
|
||||
# NOTE: async because the MCP client/session calls are async
|
||||
# TODO: should we make the return type MCPTool, not Tool (since we don't have ID)?
|
||||
@router.get("/mcp/servers/{mcp_server_name}/tools", response_model=List[MCPTool], operation_id="list_mcp_tools_by_server")
|
||||
def list_mcp_tools_by_server(
|
||||
async def list_mcp_tools_by_server(
|
||||
mcp_server_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -373,32 +377,36 @@ def list_mcp_tools_by_server(
|
||||
"""
|
||||
Get a list of all tools for a specific MCP server
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
try:
|
||||
return server.get_tools_from_mcp_server(mcp_server_name=mcp_server_name)
|
||||
except ValueError as e:
|
||||
# ValueError means that the MCP server name doesn't exist
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408, # Timeout
|
||||
detail={
|
||||
"code": "MCPTimeoutError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
if tool_settings.mcp_read_from_config:
|
||||
try:
|
||||
return await server.get_tools_from_mcp_server(mcp_server_name=mcp_server_name)
|
||||
except ValueError as e:
|
||||
# ValueError means that the MCP server name doesn't exist
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408, # Timeout
|
||||
detail={
|
||||
"code": "MCPTimeoutError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
else:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
mcp_tools = await server.mcp_manager.list_mcp_server_tools(mcp_server_name=mcp_server_name, actor=actor)
|
||||
return mcp_tools
|
||||
|
||||
|
||||
@router.post("/mcp/servers/{mcp_server_name}/{mcp_tool_name}", response_model=Tool, operation_id="add_mcp_tool")
|
||||
def add_mcp_tool(
|
||||
async def add_mcp_tool(
|
||||
mcp_server_name: str,
|
||||
mcp_tool_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
@@ -409,50 +417,55 @@ def add_mcp_tool(
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
available_tools = server.get_tools_from_mcp_server(mcp_server_name=mcp_server_name)
|
||||
except ValueError as e:
|
||||
# ValueError means that the MCP server name doesn't exist
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408, # Timeout
|
||||
detail={
|
||||
"code": "MCPTimeoutError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
if tool_settings.mcp_read_from_config:
|
||||
|
||||
# See if the tool is in the available list
|
||||
mcp_tool = None
|
||||
for tool in available_tools:
|
||||
if tool.name == mcp_tool_name:
|
||||
mcp_tool = tool
|
||||
break
|
||||
if not mcp_tool:
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPToolNotFoundError",
|
||||
"message": f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}",
|
||||
"mcp_tool_name": mcp_tool_name,
|
||||
},
|
||||
)
|
||||
try:
|
||||
available_tools = await server.get_tools_from_mcp_server(mcp_server_name=mcp_server_name)
|
||||
except ValueError as e:
|
||||
# ValueError means that the MCP server name doesn't exist
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
except MCPTimeoutError as e:
|
||||
raise HTTPException(
|
||||
status_code=408, # Timeout
|
||||
detail={
|
||||
"code": "MCPTimeoutError",
|
||||
"message": str(e),
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return server.tool_manager.create_or_update_mcp_tool(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
|
||||
# See if the tool is in the available list
|
||||
mcp_tool = None
|
||||
for tool in available_tools:
|
||||
if tool.name == mcp_tool_name:
|
||||
mcp_tool = tool
|
||||
break
|
||||
if not mcp_tool:
|
||||
raise HTTPException(
|
||||
status_code=400, # Bad Request
|
||||
detail={
|
||||
"code": "MCPToolNotFoundError",
|
||||
"message": f"Tool {mcp_tool_name} not found in MCP server {mcp_server_name} - available tools: {', '.join([tool.name for tool in available_tools])}",
|
||||
"mcp_tool_name": mcp_tool_name,
|
||||
},
|
||||
)
|
||||
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return await server.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
|
||||
|
||||
else:
|
||||
return await server.mcp_manager.add_tool_from_mcp_server(mcp_server_name=mcp_server_name, mcp_tool_name=mcp_tool_name, actor=actor)
|
||||
|
||||
|
||||
@router.put("/mcp/servers", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="add_mcp_server")
|
||||
def add_mcp_server_to_config(
|
||||
async def add_mcp_server_to_config(
|
||||
request: Union[StdioServerConfig, SSEServerConfig] = Body(...),
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -460,14 +473,31 @@ def add_mcp_server_to_config(
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
# write to config file
|
||||
return await server.add_mcp_server_to_config(server_config=request, allow_upsert=True)
|
||||
else:
|
||||
# log to DB
|
||||
from letta.schemas.mcp import MCPServer
|
||||
|
||||
if isinstance(request, StdioServerConfig):
|
||||
mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, stdio_config=request)
|
||||
elif isinstance(request, SSEServerConfig):
|
||||
mapped_request = MCPServer(server_name=request.server_name, server_type=request.type, server_url=request.server_url)
|
||||
mcp_server = await server.mcp_manager.create_or_update_mcp_server(mapped_request, actor=actor)
|
||||
|
||||
# TODO: don't do this in the future (just return MCPServer)
|
||||
all_servers = await server.mcp_manager.list_mcp_servers(actor=actor)
|
||||
return [server.to_config() for server in all_servers]
|
||||
|
||||
|
||||
@router.delete(
|
||||
"/mcp/servers/{mcp_server_name}", response_model=List[Union[StdioServerConfig, SSEServerConfig]], operation_id="delete_mcp_server"
|
||||
)
|
||||
def delete_mcp_server_from_config(
|
||||
async def delete_mcp_server_from_config(
|
||||
mcp_server_name: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"),
|
||||
@@ -475,5 +505,11 @@ def delete_mcp_server_from_config(
|
||||
"""
|
||||
Add a new MCP server to the Letta MCP server config
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
return server.delete_mcp_server_from_config(server_name=mcp_server_name)
|
||||
if tool_settings.mcp_read_from_config:
|
||||
# write to config file
|
||||
return server.delete_mcp_server_from_config(server_name=mcp_server_name)
|
||||
else:
|
||||
# log to DB
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
mcp_server_id = await server.mcp_manager.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||
return server.mcp_manager.delete_mcp_server_by_id(mcp_server_id, actor=actor)
|
||||
|
||||
@@ -21,7 +21,8 @@ from letta.log import get_logger
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message, MessageCreate
|
||||
from letta.schemas.message import Message, MessageCreate, ToolReturn
|
||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.server.rest_api.interface import StreamingServerInterface
|
||||
@@ -181,6 +182,7 @@ def create_letta_messages_from_llm_response(
|
||||
model: str,
|
||||
function_name: str,
|
||||
function_arguments: Dict,
|
||||
tool_execution_result: ToolExecutionResult,
|
||||
tool_call_id: str,
|
||||
function_call_success: bool,
|
||||
function_response: Optional[str],
|
||||
@@ -234,6 +236,14 @@ def create_letta_messages_from_llm_response(
|
||||
created_at=get_utc_time(),
|
||||
name=function_name,
|
||||
batch_item_id=llm_batch_item_id,
|
||||
tool_returns=[
|
||||
ToolReturn(
|
||||
status=tool_execution_result.status,
|
||||
stderr=tool_execution_result.stderr,
|
||||
stdout=tool_execution_result.stdout,
|
||||
# func_return=tool_execution_result.func_return,
|
||||
)
|
||||
],
|
||||
)
|
||||
if pre_computed_tool_message_id:
|
||||
tool_message.id = pre_computed_tool_message_id
|
||||
@@ -286,6 +296,7 @@ def create_assistant_messages_from_openai_response(
|
||||
model=model,
|
||||
function_name=DEFAULT_MESSAGE_TOOL,
|
||||
function_arguments={DEFAULT_MESSAGE_TOOL_KWARG: response_text}, # Avoid raw string manipulation
|
||||
tool_execution_result=ToolExecutionResult(status="success"),
|
||||
tool_call_id=tool_call_id,
|
||||
function_call_success=True,
|
||||
function_response=None,
|
||||
|
||||
@@ -23,9 +23,6 @@ from letta.config import LettaConfig
|
||||
from letta.constants import LETTA_TOOL_EXECUTION_DIR
|
||||
from letta.data_sources.connectors import DataConnector, load_data
|
||||
from letta.errors import HandleNotFoundError
|
||||
from letta.functions.mcp_client.base_client import BaseMCPClient
|
||||
from letta.functions.mcp_client.sse_client import MCP_CONFIG_TOPLEVEL_KEY, SSEMCPClient
|
||||
from letta.functions.mcp_client.stdio_client import StdioMCPClient
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.groups.helpers import load_multi_agent
|
||||
from letta.helpers.datetime_helpers import get_utc_time
|
||||
@@ -87,6 +84,10 @@ from letta.services.helpers.tool_execution_helper import prepare_local_sandbox
|
||||
from letta.services.identity_manager import IdentityManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.llm_batch_manager import LLMBatchManager
|
||||
from letta.services.mcp.base_client import AsyncBaseMCPClient
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -203,6 +204,7 @@ class SyncServer(Server):
|
||||
self.passage_manager = PassageManager()
|
||||
self.user_manager = UserManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.mcp_manager = MCPManager()
|
||||
self.block_manager = BlockManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.sandbox_config_manager = SandboxConfigManager()
|
||||
@@ -380,30 +382,9 @@ class SyncServer(Server):
|
||||
self._enabled_providers.append(XAIProvider(name="xai", api_key=model_settings.xai_api_key))
|
||||
|
||||
# For MCP
|
||||
# TODO: remove this
|
||||
"""Initialize the MCP clients (there may be multiple)"""
|
||||
mcp_server_configs = self.get_mcp_servers()
|
||||
self.mcp_clients: Dict[str, BaseMCPClient] = {}
|
||||
|
||||
for server_name, server_config in mcp_server_configs.items():
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
self.mcp_clients[server_name] = SSEMCPClient(server_config)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
self.mcp_clients[server_name] = StdioMCPClient(server_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
|
||||
try:
|
||||
self.mcp_clients[server_name].connect_to_server()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
self.mcp_clients.pop(server_name)
|
||||
|
||||
# Print out the tools that are connected
|
||||
for server_name, client in self.mcp_clients.items():
|
||||
logger.info(f"Attempting to fetch tools from MCP server: {server_name}")
|
||||
mcp_tools = client.list_tools()
|
||||
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
|
||||
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
|
||||
self.mcp_clients: Dict[str, AsyncBaseMCPClient] = {}
|
||||
|
||||
# TODO: Remove these in memory caches
|
||||
self._llm_config_cache = {}
|
||||
@@ -412,6 +393,31 @@ class SyncServer(Server):
|
||||
# TODO: Replace this with the Anthropic client we have in house
|
||||
self.anthropic_async_client = AsyncAnthropic()
|
||||
|
||||
async def init_mcp_clients(self):
|
||||
# TODO: remove this
|
||||
mcp_server_configs = self.get_mcp_servers()
|
||||
|
||||
for server_name, server_config in mcp_server_configs.items():
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
self.mcp_clients[server_name] = AsyncSSEMCPClient(server_config)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
self.mcp_clients[server_name] = AsyncStdioMCPClient(server_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
|
||||
try:
|
||||
await self.mcp_clients[server_name].connect_to_server()
|
||||
except Exception as e:
|
||||
logger.error(e)
|
||||
self.mcp_clients.pop(server_name)
|
||||
|
||||
# Print out the tools that are connected
|
||||
for server_name, client in self.mcp_clients.items():
|
||||
logger.info(f"Attempting to fetch tools from MCP server: {server_name}")
|
||||
mcp_tools = await client.list_tools()
|
||||
logger.info(f"MCP tools connected: {', '.join([t.name for t in mcp_tools])}")
|
||||
logger.debug(f"MCP tools: {', '.join([str(t) for t in mcp_tools])}")
|
||||
|
||||
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@@ -1918,7 +1924,8 @@ class SyncServer(Server):
|
||||
|
||||
# TODO implement non-flatfile mechanism
|
||||
if not tool_settings.mcp_read_from_config:
|
||||
raise RuntimeError("MCP config file disabled. Enable it in settings.")
|
||||
return {}
|
||||
# raise RuntimeError("MCP config file disabled. Enable it in settings.")
|
||||
|
||||
mcp_server_list = {}
|
||||
|
||||
@@ -1972,14 +1979,14 @@ class SyncServer(Server):
|
||||
# If the file doesn't exist, return empty dictionary
|
||||
return mcp_server_list
|
||||
|
||||
def get_tools_from_mcp_server(self, mcp_server_name: str) -> List[MCPTool]:
|
||||
async def get_tools_from_mcp_server(self, mcp_server_name: str) -> List[MCPTool]:
|
||||
"""List the tools in an MCP server. Requires a client to be created."""
|
||||
if mcp_server_name not in self.mcp_clients:
|
||||
raise ValueError(f"No client was created for MCP server: {mcp_server_name}")
|
||||
|
||||
return self.mcp_clients[mcp_server_name].list_tools()
|
||||
return await self.mcp_clients[mcp_server_name].list_tools()
|
||||
|
||||
def add_mcp_server_to_config(
|
||||
async def add_mcp_server_to_config(
|
||||
self, server_config: Union[SSEServerConfig, StdioServerConfig], allow_upsert: bool = True
|
||||
) -> List[Union[SSEServerConfig, StdioServerConfig]]:
|
||||
"""Add a new server config to the MCP config file"""
|
||||
@@ -2008,19 +2015,19 @@ class SyncServer(Server):
|
||||
|
||||
# Attempt to initialize the connection to the server
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
new_mcp_client = SSEMCPClient(server_config)
|
||||
new_mcp_client = AsyncSSEMCPClient(server_config)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
new_mcp_client = StdioMCPClient(server_config)
|
||||
new_mcp_client = AsyncStdioMCPClient(server_config)
|
||||
else:
|
||||
raise ValueError(f"Invalid MCP server config: {server_config}")
|
||||
try:
|
||||
new_mcp_client.connect_to_server()
|
||||
await new_mcp_client.connect_to_server()
|
||||
except:
|
||||
logger.exception(f"Failed to connect to MCP server: {server_config.server_name}")
|
||||
raise RuntimeError(f"Failed to connect to MCP server: {server_config.server_name}")
|
||||
# Print out the tools that are connected
|
||||
logger.info(f"Attempting to fetch tools from MCP server: {server_config.server_name}")
|
||||
new_mcp_tools = new_mcp_client.list_tools()
|
||||
new_mcp_tools = await new_mcp_client.list_tools()
|
||||
logger.info(f"MCP tools connected: {', '.join([t.name for t in new_mcp_tools])}")
|
||||
logger.debug(f"MCP tools: {', '.join([str(t) for t in new_mcp_tools])}")
|
||||
|
||||
|
||||
@@ -30,7 +30,7 @@ class AsyncBaseMCPClient:
|
||||
)
|
||||
raise e
|
||||
|
||||
async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: BaseServerConfig) -> None:
|
||||
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
|
||||
raise NotImplementedError("Subclasses must implement _initialize_connection")
|
||||
|
||||
async def list_tools(self) -> list[MCPTool]:
|
||||
@@ -65,3 +65,6 @@ class AsyncBaseMCPClient:
|
||||
async def cleanup(self):
|
||||
"""Clean up resources"""
|
||||
await self.exit_stack.aclose()
|
||||
|
||||
def to_sync_client(self):
|
||||
raise NotImplementedError("Subclasses must implement to_sync_client")
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession
|
||||
from mcp.client.sse import sse_client
|
||||
|
||||
@@ -15,11 +13,11 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: SSEServerConfig) -> None:
|
||||
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
|
||||
sse_cm = sse_client(url=server_config.server_url)
|
||||
sse_transport = await exit_stack.enter_async_context(sse_cm)
|
||||
sse_transport = await self.exit_stack.enter_async_context(sse_cm)
|
||||
self.stdio, self.write = sse_transport
|
||||
|
||||
# Create and enter the ClientSession context manager
|
||||
session_cm = ClientSession(self.stdio, self.write)
|
||||
self.session = await exit_stack.enter_async_context(session_cm)
|
||||
self.session = await self.exit_stack.enter_async_context(session_cm)
|
||||
|
||||
@@ -1,5 +1,3 @@
|
||||
from contextlib import AsyncExitStack
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
@@ -12,8 +10,8 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncStdioMCPClient(AsyncBaseMCPClient):
|
||||
async def _initialize_connection(self, exit_stack: AsyncExitStack[bool | None], server_config: StdioServerConfig) -> None:
|
||||
async def _initialize_connection(self, server_config: StdioServerConfig) -> None:
|
||||
server_params = StdioServerParameters(command=server_config.command, args=server_config.args)
|
||||
stdio_transport = await exit_stack.enter_async_context(stdio_client(server_params))
|
||||
stdio_transport = await self.exit_stack.enter_async_context(stdio_client(server_params))
|
||||
self.stdio, self.write = stdio_transport
|
||||
self.session = await exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
self.session = await self.exit_stack.enter_async_context(ClientSession(self.stdio, self.write))
|
||||
|
||||
280
letta/services/mcp_manager.py
Normal file
280
letta/services/mcp_manager.py
Normal file
@@ -0,0 +1,280 @@
|
||||
import json
|
||||
import os
|
||||
from typing import Any, Dict, List, Optional, Union
|
||||
|
||||
import letta.constants as constants
|
||||
from letta.functions.mcp_client.types import MCPServerType, MCPTool, SSEServerConfig, StdioServerConfig
|
||||
from letta.log import get_logger
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.mcp_server import MCPServer as MCPServerModel
|
||||
from letta.schemas.mcp import MCPServer, UpdateMCPServer, UpdateSSEMCPServer, UpdateStdioMCPServer
|
||||
from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.mcp.sse_client import MCP_CONFIG_TOPLEVEL_KEY, AsyncSSEMCPClient
|
||||
from letta.services.mcp.stdio_client import AsyncStdioMCPClient
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
class MCPManager:
|
||||
"""Manager class to handle business logic related to MCP."""
|
||||
|
||||
def __init__(self):
|
||||
# TODO: timeouts?
|
||||
self.tool_manager = ToolManager()
|
||||
self.cached_mcp_servers = {} # maps id -> async connection
|
||||
|
||||
@enforce_types
|
||||
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]:
|
||||
"""Get a list of all tools for a specific MCP server."""
|
||||
print("mcp_server_name", mcp_server_name)
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
|
||||
if mcp_config.server_type == MCPServerType.SSE:
|
||||
mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
elif mcp_config.server_type == MCPServerType.STDIO:
|
||||
mcp_client = AsyncStdioMCPClient(server_config=server_config)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# list tools
|
||||
tools = await mcp_client.list_tools()
|
||||
# TODO: change to pydantic tools
|
||||
|
||||
await mcp_client.cleanup()
|
||||
|
||||
return tools
|
||||
|
||||
@enforce_types
|
||||
async def execute_mcp_server_tool(
|
||||
self, mcp_server_name: str, tool_name: str, tool_args: Optional[Dict[str, Any]], actor: PydanticUser
|
||||
) -> PydanticTool:
|
||||
"""Call a specific tool from a specific MCP server."""
|
||||
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if not tool_settings.mcp_read_from_config:
|
||||
# read from DB
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
else:
|
||||
# read from config file
|
||||
mcp_config = self.read_mcp_config()
|
||||
if mcp_server_name not in mcp_config:
|
||||
print("MCP server not found in config.", mcp_config)
|
||||
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
|
||||
server_config = mcp_config[mcp_server_name]
|
||||
|
||||
if isinstance(server_config, SSEServerConfig):
|
||||
mcp_client = AsyncSSEMCPClient(server_config=server_config)
|
||||
elif isinstance(server_config, StdioServerConfig):
|
||||
mcp_client = AsyncStdioMCPClient(server_config=server_config)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# call tool
|
||||
result = await mcp_client.execute_tool(tool_name, tool_args)
|
||||
# TODO: change to pydantic tool
|
||||
|
||||
await mcp_client.cleanup()
|
||||
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
async def add_tool_from_mcp_server(self, mcp_server_name: str, mcp_tool_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
"""Add a tool from an MCP server to the Letta tool registry."""
|
||||
mcp_tools = await self.list_mcp_server_tools(mcp_server_name, actor=actor)
|
||||
|
||||
for mcp_tool in mcp_tools:
|
||||
if mcp_tool.name == mcp_tool_name:
|
||||
tool_create = ToolCreate.from_mcp(mcp_server_name=mcp_server_name, mcp_tool=mcp_tool)
|
||||
return await self.tool_manager.create_mcp_tool_async(tool_create=tool_create, mcp_server_name=mcp_server_name, actor=actor)
|
||||
|
||||
# failed to add - handle error?
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def list_mcp_servers(self, actor: PydanticUser) -> List[MCPServer]:
|
||||
"""List all MCP servers available"""
|
||||
async with db_registry.async_session() as session:
|
||||
mcp_servers = await MCPServerModel.list_async(
|
||||
db_session=session,
|
||||
organization_id=actor.organization_id,
|
||||
)
|
||||
|
||||
return [mcp_server.to_pydantic() for mcp_server in mcp_servers]
|
||||
|
||||
@enforce_types
|
||||
async def create_or_update_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> MCPServer:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name=pydantic_mcp_server.server_name, actor=actor)
|
||||
print("FOUND SERVER", mcp_server_id, pydantic_mcp_server.server_name)
|
||||
if mcp_server_id:
|
||||
# Put to dict and remove fields that should not be reset
|
||||
update_data = pydantic_mcp_server.model_dump(exclude_unset=True, exclude_none=True)
|
||||
|
||||
# If there's anything to update (can only update the configs, not the name)
|
||||
if update_data:
|
||||
if pydantic_mcp_server.server_type == MCPServerType.SSE:
|
||||
update_request = UpdateSSEMCPServer(server_url=pydantic_mcp_server.server_url)
|
||||
elif pydantic_mcp_server.server_type == MCPServerType.STDIO:
|
||||
update_request = UpdateStdioMCPServer(stdio_config=pydantic_mcp_server.stdio_config)
|
||||
mcp_server = await self.update_mcp_server_by_id(mcp_server_id, update_request, actor)
|
||||
print("RETURN", mcp_server)
|
||||
else:
|
||||
printd(
|
||||
f"`create_or_update_mcp_server` was called with user_id={actor.id}, organization_id={actor.organization_id}, name={pydantic_mcp_server.server_name}, but found existing mcp server with nothing to update."
|
||||
)
|
||||
mcp_server = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
else:
|
||||
mcp_server = await self.create_mcp_server(pydantic_mcp_server, actor=actor)
|
||||
|
||||
return mcp_server
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server(self, pydantic_mcp_server: MCPServer, actor: PydanticUser) -> PydanticTool:
|
||||
"""Create a new tool based on the ToolCreate schema."""
|
||||
with db_registry.session() as session:
|
||||
# Set the organization id at the ORM layer
|
||||
pydantic_mcp_server.organization_id = actor.organization_id
|
||||
mcp_server_data = pydantic_mcp_server.model_dump(to_orm=True)
|
||||
|
||||
mcp_server = MCPServerModel(**mcp_server_data)
|
||||
mcp_server.create(session, actor=actor) # Re-raise other database-related errors
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def update_mcp_server_by_id(self, mcp_server_id: str, mcp_server_update: UpdateMCPServer, actor: PydanticUser) -> PydanticTool:
|
||||
"""Update a tool by its ID with the given ToolUpdate object."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch the tool by ID
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
|
||||
# Update tool attributes with only the fields that were explicitly set
|
||||
update_data = mcp_server_update.model_dump(to_orm=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(mcp_server, key, value)
|
||||
|
||||
mcp_server = await mcp_server.update_async(db_session=session, actor=actor)
|
||||
|
||||
# Save the updated tool to the database mcp_server = await mcp_server.update_async(db_session=session, actor=actor)
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_mcp_server_id_by_name(self, mcp_server_name: str, actor: PydanticUser) -> Optional[str]:
|
||||
"""Retrieve a MCP server by its name and a user"""
|
||||
try:
|
||||
async with db_registry.async_session() as session:
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, server_name=mcp_server_name, actor=actor)
|
||||
return mcp_server.id
|
||||
except NoResultFound:
|
||||
return None
|
||||
|
||||
@enforce_types
|
||||
async def get_mcp_server_by_id_async(self, mcp_server_id: str, actor: PydanticUser) -> MCPServer:
|
||||
"""Fetch a tool by its ID."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Retrieve tool by id using the Tool model's read method
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
# Convert the SQLAlchemy Tool object to PydanticTool
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
"""Get a tool by name."""
|
||||
async with db_registry.async_session() as session:
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||
mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
if not mcp_server:
|
||||
raise HTTPException(
|
||||
status_code=404, # Not Found
|
||||
detail={
|
||||
"code": "MCPServerNotFoundError",
|
||||
"message": f"MCP server {mcp_server_name} not found",
|
||||
"mcp_server_name": mcp_server_name,
|
||||
},
|
||||
)
|
||||
return mcp_server.to_pydantic()
|
||||
|
||||
# @enforce_types
|
||||
# async def delete_mcp_server(self, mcp_server_name: str, actor: PydanticUser) -> None:
|
||||
# """Delete an existing tool."""
|
||||
# with db_registry.session() as session:
|
||||
# mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor)
|
||||
# mcp_server = await MCPServerModel.read_async(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
# if not mcp_server:
|
||||
# raise HTTPException(
|
||||
# status_code=404, # Not Found
|
||||
# detail={
|
||||
# "code": "MCPServerNotFoundError",
|
||||
# "message": f"MCP server {mcp_server_name} not found",
|
||||
# "mcp_server_name": mcp_server_name,
|
||||
# },
|
||||
# )
|
||||
# mcp_server.delete(session, actor=actor) # Re-raise other database-related errors
|
||||
|
||||
@enforce_types
|
||||
def delete_mcp_server_by_id(self, mcp_server_id: str, actor: PydanticUser) -> None:
|
||||
"""Delete a tool by its ID."""
|
||||
with db_registry.session() as session:
|
||||
try:
|
||||
mcp_server = MCPServerModel.read(db_session=session, identifier=mcp_server_id, actor=actor)
|
||||
mcp_server.hard_delete(db_session=session, actor=actor)
|
||||
except NoResultFound:
|
||||
raise ValueError(f"MCP server with id {mcp_server_id} not found.")
|
||||
|
||||
def read_mcp_config(self) -> dict[str, Union[SSEServerConfig, StdioServerConfig]]:
|
||||
mcp_server_list = {}
|
||||
|
||||
# Attempt to read from ~/.letta/mcp_config.json
|
||||
mcp_config_path = os.path.join(constants.LETTA_DIR, constants.MCP_CONFIG_NAME)
|
||||
if os.path.exists(mcp_config_path):
|
||||
with open(mcp_config_path, "r") as f:
|
||||
|
||||
try:
|
||||
mcp_config = json.load(f)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse MCP config file ({mcp_config_path}) as json: {e}")
|
||||
return mcp_server_list
|
||||
|
||||
# Proper formatting is "mcpServers" key at the top level,
|
||||
# then a dict with the MCP server name as the key,
|
||||
# with the value being the schema from StdioServerParameters
|
||||
if MCP_CONFIG_TOPLEVEL_KEY in mcp_config:
|
||||
for server_name, server_params_raw in mcp_config[MCP_CONFIG_TOPLEVEL_KEY].items():
|
||||
|
||||
# No support for duplicate server names
|
||||
if server_name in mcp_server_list:
|
||||
logger.error(f"Duplicate MCP server name found (skipping): {server_name}")
|
||||
continue
|
||||
|
||||
if "url" in server_params_raw:
|
||||
# Attempt to parse the server params as an SSE server
|
||||
try:
|
||||
server_params = SSEServerConfig(
|
||||
server_name=server_name,
|
||||
server_url=server_params_raw["url"],
|
||||
)
|
||||
mcp_server_list[server_name] = server_params
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
|
||||
continue
|
||||
else:
|
||||
# Attempt to parse the server params as a StdioServerParameters
|
||||
try:
|
||||
server_params = StdioServerConfig(
|
||||
server_name=server_name,
|
||||
command=server_params_raw["command"],
|
||||
args=server_params_raw.get("args", []),
|
||||
env=server_params_raw.get("env", {}),
|
||||
)
|
||||
mcp_server_list[server_name] = server_params
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to parse server params for MCP server {server_name} (skipping): {e}")
|
||||
continue
|
||||
return mcp_server_list
|
||||
@@ -1,6 +1,7 @@
|
||||
import traceback
|
||||
from typing import Any, Dict, Optional, Type
|
||||
|
||||
from letta.constants import FUNCTION_RETURN_VALUE_TRUNCATED
|
||||
from letta.log import get_logger
|
||||
from letta.orm.enums import ToolType
|
||||
from letta.schemas.agent import AgentState
|
||||
@@ -143,11 +144,26 @@ class ToolExecutionManager:
|
||||
)
|
||||
# TODO: Extend this async model to composio
|
||||
if isinstance(
|
||||
executor, (SandboxToolExecutor, ExternalComposioToolExecutor, LettaBuiltinToolExecutor, LettaMultiAgentToolExecutor)
|
||||
executor,
|
||||
(
|
||||
SandboxToolExecutor,
|
||||
ExternalComposioToolExecutor,
|
||||
ExternalMCPToolExecutor,
|
||||
LettaBuiltinToolExecutor,
|
||||
LettaMultiAgentToolExecutor,
|
||||
),
|
||||
):
|
||||
result = await executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
else:
|
||||
result = executor.execute(function_name, function_args, self.agent_state, tool, self.actor)
|
||||
|
||||
print("TOOL RESULT", result)
|
||||
|
||||
# trim result
|
||||
return_str = str(result.func_return)
|
||||
if len(return_str) > tool.return_char_limit:
|
||||
# TODO: okay that this become a string?
|
||||
result.func_return = FUNCTION_RETURN_VALUE_TRUNCATED(return_str, len(return_str), tool.return_char_limit)
|
||||
return result
|
||||
|
||||
except Exception as e:
|
||||
|
||||
@@ -9,6 +9,7 @@ from typing import Any, Dict, List, Literal, Optional
|
||||
from letta.constants import (
|
||||
COMPOSIO_ENTITY_ENV_VAR_KEY,
|
||||
CORE_MEMORY_LINE_NUMBER_WARNING,
|
||||
MCP_TOOL_TAG_NAME_PREFIX,
|
||||
READ_ONLY_BLOCK_EDIT_ERROR,
|
||||
RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE,
|
||||
WEB_SEARCH_CLIP_CONTENT,
|
||||
@@ -31,6 +32,7 @@ from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.mcp_manager import MCPManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.tool_sandbox.e2b_sandbox import AsyncToolSandboxE2B
|
||||
@@ -668,53 +670,35 @@ class ExternalComposioToolExecutor(ToolExecutor):
|
||||
class ExternalMCPToolExecutor(ToolExecutor):
|
||||
"""Executor for external MCP tools."""
|
||||
|
||||
# TODO: Implement
|
||||
#
|
||||
# def execute(self, function_name: str, function_args: dict, agent_state: AgentState, tool: Tool, actor: User) -> ToolExecutionResult:
|
||||
# # Get the server name from the tool tag
|
||||
# server_name = self._extract_server_name(tool)
|
||||
#
|
||||
# # Get the MCPClient
|
||||
# mcp_client = self._get_mcp_client(agent, server_name)
|
||||
#
|
||||
# # Validate tool exists
|
||||
# self._validate_tool_exists(mcp_client, function_name, server_name)
|
||||
#
|
||||
# # Execute the tool
|
||||
# function_response, is_error = mcp_client.execute_tool(tool_name=function_name, tool_args=function_args)
|
||||
#
|
||||
# return ToolExecutionResult(
|
||||
# status="error" if is_error else "success",
|
||||
# func_return=function_response,
|
||||
# )
|
||||
#
|
||||
# def _extract_server_name(self, tool: Tool) -> str:
|
||||
# """Extract server name from tool tags."""
|
||||
# return tool.tags[0].split(":")[1]
|
||||
#
|
||||
# def _get_mcp_client(self, agent: "Agent", server_name: str):
|
||||
# """Get the MCP client for the given server name."""
|
||||
# if not agent.mcp_clients:
|
||||
# raise ValueError("No MCP client available to use")
|
||||
#
|
||||
# if server_name not in agent.mcp_clients:
|
||||
# raise ValueError(f"Unknown MCP server name: {server_name}")
|
||||
#
|
||||
# mcp_client = agent.mcp_clients[server_name]
|
||||
# if not isinstance(mcp_client, BaseMCPClient):
|
||||
# raise RuntimeError(f"Expected an MCPClient, but got: {type(mcp_client)}")
|
||||
#
|
||||
# return mcp_client
|
||||
#
|
||||
# def _validate_tool_exists(self, mcp_client, function_name: str, server_name: str):
|
||||
# """Validate that the tool exists in the MCP server."""
|
||||
# available_tools = mcp_client.list_tools()
|
||||
# available_tool_names = [t.name for t in available_tools]
|
||||
#
|
||||
# if function_name not in available_tool_names:
|
||||
# raise ValueError(
|
||||
# f"{function_name} is not available in MCP server {server_name}. " f"Please check your `~/.letta/mcp_config.json` file."
|
||||
# )
|
||||
@trace_method
|
||||
async def execute(
|
||||
self,
|
||||
function_name: str,
|
||||
function_args: dict,
|
||||
agent_state: AgentState,
|
||||
tool: Tool,
|
||||
actor: User,
|
||||
sandbox_config: Optional[SandboxConfig] = None,
|
||||
sandbox_env_vars: Optional[Dict[str, Any]] = None,
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
pass
|
||||
|
||||
mcp_server_tag = [tag for tag in tool.tags if tag.startswith(f"{MCP_TOOL_TAG_NAME_PREFIX}:")]
|
||||
if not mcp_server_tag:
|
||||
raise ValueError(f"Tool {tool.name} does not have a valid MCP server tag")
|
||||
mcp_server_name = mcp_server_tag[0].split(":")[1]
|
||||
|
||||
mcp_manager = MCPManager()
|
||||
# TODO: may need to have better client connection management
|
||||
function_response = await mcp_manager.execute_mcp_server_tool(
|
||||
mcp_server_name=mcp_server_name, tool_name=function_name, tool_args=function_args, actor=actor
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
status="success",
|
||||
func_return=function_response,
|
||||
)
|
||||
|
||||
|
||||
class SandboxToolExecutor(ToolExecutor):
|
||||
|
||||
@@ -1,7 +1,7 @@
|
||||
import asyncio
|
||||
import importlib
|
||||
import warnings
|
||||
from typing import List, Optional
|
||||
from typing import List, Optional, Union
|
||||
|
||||
from letta.constants import (
|
||||
BASE_FUNCTION_RETURN_CHAR_LIMIT,
|
||||
@@ -26,6 +26,7 @@ from letta.schemas.tool import Tool as PydanticTool
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.server.db import db_registry
|
||||
from letta.services.mcp.types import SSEServerConfig, StdioServerConfig
|
||||
from letta.tracing import trace_method
|
||||
from letta.utils import enforce_types, printd
|
||||
|
||||
@@ -90,6 +91,12 @@ class ToolManager:
|
||||
|
||||
return tool
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_server(
|
||||
self, server_config: Union[StdioServerConfig, SSEServerConfig], actor: PydanticUser
|
||||
) -> List[Union[StdioServerConfig, SSEServerConfig]]:
|
||||
pass
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_or_update_mcp_tool(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
@@ -101,6 +108,16 @@ class ToolManager:
|
||||
actor,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
async def create_mcp_tool_async(self, tool_create: ToolCreate, mcp_server_name: str, actor: PydanticUser) -> PydanticTool:
|
||||
metadata = {MCP_TOOL_TAG_NAME_PREFIX: {"server_name": mcp_server_name}}
|
||||
return await self.create_or_update_tool_async(
|
||||
PydanticTool(
|
||||
tool_type=ToolType.EXTERNAL_MCP, name=tool_create.json_schema["name"], metadata_=metadata, **tool_create.model_dump()
|
||||
),
|
||||
actor,
|
||||
)
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def create_or_update_composio_tool(self, tool_create: ToolCreate, actor: PydanticUser) -> PydanticTool:
|
||||
|
||||
@@ -1 +1 @@
|
||||
{}
|
||||
{}
|
||||
@@ -105,7 +105,8 @@ def agent_state(client):
|
||||
client.agents.delete(agent_state.id)
|
||||
|
||||
|
||||
def test_sse_mcp_server(client, agent_state):
|
||||
@pytest.mark.asyncio
|
||||
async def test_sse_mcp_server(client, agent_state):
|
||||
mcp_server_name = "github_composio"
|
||||
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
|
||||
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||
@@ -148,11 +149,16 @@ def test_sse_mcp_server(client, agent_state):
|
||||
# status field
|
||||
assert tr.status == "success", f"Bad status: {tr.status}"
|
||||
# parse JSON payload
|
||||
payload = json.loads(tr.tool_return)
|
||||
full_payload = json.loads(tr.tool_return)
|
||||
payload = json.loads(full_payload["message"][0])
|
||||
from pprint import pprint
|
||||
|
||||
pprint(payload)
|
||||
assert payload.get("successful", False), f"Tool returned failure payload: {payload}"
|
||||
assert payload["data"]["details"] == "Action executed successfully", f"Unexpected details: {payload}"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
def test_stdio_mcp_server(client, agent_state):
|
||||
req_file = Path(__file__).parent / "weather" / "requirements.txt"
|
||||
create_virtualenv_and_install_requirements(req_file, name="venv")
|
||||
|
||||
@@ -432,7 +432,8 @@ def test_function_always_error(client: Letta):
|
||||
|
||||
assert response_message, "ToolReturnMessage message not found in response"
|
||||
assert response_message.status == "error"
|
||||
assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return
|
||||
# TODO: add this back
|
||||
# assert "Error executing function testing_method" in response_message.tool_return, response_message.tool_return
|
||||
assert "ZeroDivisionError: division by zero" in response_message.stderr[0]
|
||||
|
||||
client.agents.delete(agent_id=agent.id)
|
||||
|
||||
@@ -505,6 +505,8 @@ async def test_partial_error_from_anthropic_batch(
|
||||
assert len(new_batch_responses) == 1
|
||||
post_resume_response = new_batch_responses[0]
|
||||
|
||||
print("POST", post_resume_response)
|
||||
print("PRE", pre_resume_response)
|
||||
assert (
|
||||
post_resume_response.letta_batch_id == pre_resume_response.letta_batch_id
|
||||
), "resume_step_after_request is expected to have the same letta_batch_id"
|
||||
|
||||
@@ -5598,3 +5598,61 @@ async def test_count_batch_items(
|
||||
|
||||
# Assert that the count matches the expected number.
|
||||
assert count == num_items, f"Expected {num_items} items, got {count}"
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# MCPManager Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_create_mcp_server(server, default_user, event_loop):
|
||||
from letta.schemas.mcp import MCPServer, MCPServerType, SSEServerConfig, StdioServerConfig
|
||||
from letta.settings import tool_settings
|
||||
|
||||
if tool_settings.mcp_read_from_config:
|
||||
return
|
||||
|
||||
# Test with a valid StdioServerConfig
|
||||
server_config = StdioServerConfig(
|
||||
server_name="test_server", type=MCPServerType.STDIO, command="echo 'test'", args=["arg1", "arg2"], env={"ENV1": "value1"}
|
||||
)
|
||||
mcp_server = MCPServer(server_name="test_server", server_type=MCPServerType.STDIO, stdio_config=server_config)
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_server, actor=default_user)
|
||||
print(created_server)
|
||||
assert created_server.server_name == server_config.server_name
|
||||
assert created_server.server_type == server_config.type
|
||||
|
||||
# Test with a valid SSEServerConfig
|
||||
mcp_server_name = "github_composio"
|
||||
server_url = "https://mcp.composio.dev/composio/server/3c44733b-75ae-4ba8-9a68-7153265fadd8"
|
||||
sse_mcp_config = SSEServerConfig(server_name=mcp_server_name, server_url=server_url)
|
||||
mcp_sse_server = MCPServer(server_name=mcp_server_name, server_type=MCPServerType.SSE, server_url=server_url)
|
||||
created_server = await server.mcp_manager.create_or_update_mcp_server(mcp_sse_server, actor=default_user)
|
||||
print(created_server)
|
||||
assert created_server.server_name == mcp_server_name
|
||||
assert created_server.server_type == MCPServerType.SSE
|
||||
|
||||
# list mcp servers
|
||||
servers = await server.mcp_manager.list_mcp_servers(actor=default_user)
|
||||
print(servers)
|
||||
assert len(servers) > 0, "No MCP servers found"
|
||||
|
||||
# list tools from sse server
|
||||
tools = await server.mcp_manager.list_mcp_server_tools(created_server.server_name, actor=default_user)
|
||||
print(tools)
|
||||
|
||||
# call a tool from the sse server
|
||||
tool_name = "GITHUB_STAR_A_REPOSITORY_FOR_THE_AUTHENTICATED_USER"
|
||||
tool_args = {"owner": "letta-ai", "repo": "letta"}
|
||||
result = await server.mcp_manager.execute_mcp_server_tool(
|
||||
created_server.server_name, tool_name=tool_name, tool_args=tool_args, actor=default_user
|
||||
)
|
||||
print(result)
|
||||
|
||||
# add a tool
|
||||
tool = await server.mcp_manager.add_tool_from_mcp_server(created_server.server_name, tool_name, actor=default_user)
|
||||
print(tool)
|
||||
assert tool.name == tool_name
|
||||
assert f"mcp:{created_server.server_name}" in tool.tags, f"Expected tag {f'mcp:{created_server.server_name}'}, got {tool.tags}"
|
||||
print("TAGS", tool.tags)
|
||||
|
||||
Reference in New Issue
Block a user