feat: add x-agent-id header for mcp tool execution
Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
@@ -14,9 +14,15 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncBaseMCPClient:
|
||||
def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
# HTTP headers
|
||||
AGENT_ID_HEADER = "X-Agent-Id"
|
||||
|
||||
def __init__(
|
||||
self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
|
||||
):
|
||||
self.server_config = server_config
|
||||
self.oauth_provider = oauth_provider
|
||||
self.agent_id = agent_id
|
||||
self.exit_stack = AsyncExitStack()
|
||||
self.session: Optional[ClientSession] = None
|
||||
self.initialized = False
|
||||
|
||||
@@ -16,8 +16,10 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
super().__init__(server_config, oauth_provider)
|
||||
def __init__(
|
||||
self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
|
||||
):
|
||||
super().__init__(server_config, oauth_provider, agent_id)
|
||||
|
||||
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
|
||||
headers = {}
|
||||
@@ -27,6 +29,9 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient):
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
if self.agent_id:
|
||||
headers[self.AGENT_ID_HEADER] = self.agent_id
|
||||
|
||||
# Use OAuth provider if available, otherwise use regular headers
|
||||
if self.oauth_provider:
|
||||
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider)
|
||||
|
||||
@@ -1,3 +1,5 @@
|
||||
from typing import Optional
|
||||
|
||||
from mcp import ClientSession, StdioServerParameters
|
||||
from mcp.client.stdio import stdio_client
|
||||
|
||||
@@ -10,6 +12,9 @@ logger = get_logger(__name__)
|
||||
|
||||
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
|
||||
class AsyncStdioMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: StdioServerConfig, oauth_provider=None, agent_id: Optional[str] = None):
|
||||
super().__init__(server_config, oauth_provider, agent_id)
|
||||
|
||||
async def _initialize_connection(self, server_config: StdioServerConfig) -> None:
|
||||
args = [arg.split() for arg in server_config.args]
|
||||
# flatten
|
||||
|
||||
@@ -12,8 +12,13 @@ logger = get_logger(__name__)
|
||||
|
||||
|
||||
class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
|
||||
def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
|
||||
super().__init__(server_config, oauth_provider)
|
||||
def __init__(
|
||||
self,
|
||||
server_config: StreamableHTTPServerConfig,
|
||||
oauth_provider: Optional[OAuthClientProvider] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
):
|
||||
super().__init__(server_config, oauth_provider, agent_id)
|
||||
|
||||
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
|
||||
if not isinstance(server_config, StreamableHTTPServerConfig):
|
||||
@@ -28,6 +33,10 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
|
||||
if server_config.auth_header and server_config.auth_token:
|
||||
headers[server_config.auth_header] = server_config.auth_token
|
||||
|
||||
# Add agent ID header if provided
|
||||
if self.agent_id:
|
||||
headers[self.AGENT_ID_HEADER] = self.agent_id
|
||||
|
||||
# Use OAuth provider if available, otherwise use regular headers
|
||||
if self.oauth_provider:
|
||||
streamable_http_cm = streamablehttp_client(
|
||||
|
||||
@@ -56,14 +56,14 @@ class MCPManager:
|
||||
self.cached_mcp_servers = {} # maps id -> async connection
|
||||
|
||||
@enforce_types
|
||||
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]:
|
||||
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
|
||||
"""Get a list of all tools for a specific MCP server."""
|
||||
mcp_client = None
|
||||
try:
|
||||
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
|
||||
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
|
||||
server_config = mcp_config.to_config()
|
||||
mcp_client = await self.get_mcp_client(server_config, actor)
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# list tools
|
||||
@@ -93,6 +93,7 @@ class MCPManager:
|
||||
tool_args: Optional[Dict[str, Any]],
|
||||
environment_variables: Dict[str, str],
|
||||
actor: PydanticUser,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Tuple[str, bool]:
|
||||
"""Call a specific tool from a specific MCP server."""
|
||||
try:
|
||||
@@ -108,7 +109,7 @@ class MCPManager:
|
||||
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
|
||||
server_config = mcp_config[mcp_server_name]
|
||||
|
||||
mcp_client = await self.get_mcp_client(server_config, actor)
|
||||
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
|
||||
await mcp_client.connect_to_server()
|
||||
|
||||
# call tool
|
||||
@@ -449,6 +450,7 @@ class MCPManager:
|
||||
server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
|
||||
actor: PydanticUser,
|
||||
oauth_provider: Optional[Any] = None,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]:
|
||||
"""
|
||||
Helper function to create the appropriate MCP client based on server configuration.
|
||||
@@ -481,13 +483,13 @@ class MCPManager:
|
||||
|
||||
if server_config.type == MCPServerType.SSE:
|
||||
server_config = SSEServerConfig(**server_config.model_dump())
|
||||
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
|
||||
elif server_config.type == MCPServerType.STDIO:
|
||||
server_config = StdioServerConfig(**server_config.model_dump())
|
||||
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
|
||||
elif server_config.type == MCPServerType.STREAMABLE_HTTP:
|
||||
server_config = StreamableHTTPServerConfig(**server_config.model_dump())
|
||||
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider)
|
||||
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
|
||||
else:
|
||||
raise ValueError(f"Unsupported server config type: {type(server_config)}")
|
||||
|
||||
|
||||
@@ -37,8 +37,10 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
# TODO: may need to have better client connection management
|
||||
|
||||
environment_variables = {}
|
||||
agent_id = None
|
||||
if agent_state:
|
||||
environment_variables = agent_state.get_agent_env_vars_as_dict()
|
||||
agent_id = agent_state.id
|
||||
|
||||
function_response, success = await mcp_manager.execute_mcp_server_tool(
|
||||
mcp_server_name=mcp_server_name,
|
||||
@@ -46,6 +48,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
|
||||
tool_args=function_args,
|
||||
environment_variables=environment_variables,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
)
|
||||
|
||||
return ToolExecutionResult(
|
||||
|
||||
Reference in New Issue
Block a user