feat: add x-agent-id header for mcp tool execution

Co-authored-by: Jin Peng <jinjpeng@Jins-MacBook-Pro.local>
This commit is contained in:
jnjpng
2025-08-22 14:09:03 -07:00
committed by GitHub
parent c38b1f5992
commit afc4809be0
6 changed files with 41 additions and 11 deletions

View File

@@ -14,9 +14,15 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncBaseMCPClient:
def __init__(self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
# HTTP headers
AGENT_ID_HEADER = "X-Agent-Id"
def __init__(
self, server_config: BaseServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
):
self.server_config = server_config
self.oauth_provider = oauth_provider
self.agent_id = agent_id
self.exit_stack = AsyncExitStack()
self.session: Optional[ClientSession] = None
self.initialized = False

View File

@@ -16,8 +16,10 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncSSEMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
super().__init__(server_config, oauth_provider)
def __init__(
self, server_config: SSEServerConfig, oauth_provider: Optional[OAuthClientProvider] = None, agent_id: Optional[str] = None
):
super().__init__(server_config, oauth_provider, agent_id)
async def _initialize_connection(self, server_config: SSEServerConfig) -> None:
headers = {}
@@ -27,6 +29,9 @@ class AsyncSSEMCPClient(AsyncBaseMCPClient):
if server_config.auth_header and server_config.auth_token:
headers[server_config.auth_header] = server_config.auth_token
if self.agent_id:
headers[self.AGENT_ID_HEADER] = self.agent_id
# Use OAuth provider if available, otherwise use regular headers
if self.oauth_provider:
sse_cm = sse_client(url=server_config.server_url, headers=headers if headers else None, auth=self.oauth_provider)

View File

@@ -1,3 +1,5 @@
from typing import Optional
from mcp import ClientSession, StdioServerParameters
from mcp.client.stdio import stdio_client
@@ -10,6 +12,9 @@ logger = get_logger(__name__)
# TODO: Get rid of Async prefix on this class name once we deprecate old sync code
class AsyncStdioMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: StdioServerConfig, oauth_provider=None, agent_id: Optional[str] = None):
super().__init__(server_config, oauth_provider, agent_id)
async def _initialize_connection(self, server_config: StdioServerConfig) -> None:
args = [arg.split() for arg in server_config.args]
# flatten

View File

@@ -12,8 +12,13 @@ logger = get_logger(__name__)
class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
def __init__(self, server_config: StreamableHTTPServerConfig, oauth_provider: Optional[OAuthClientProvider] = None):
super().__init__(server_config, oauth_provider)
def __init__(
self,
server_config: StreamableHTTPServerConfig,
oauth_provider: Optional[OAuthClientProvider] = None,
agent_id: Optional[str] = None,
):
super().__init__(server_config, oauth_provider, agent_id)
async def _initialize_connection(self, server_config: BaseServerConfig) -> None:
if not isinstance(server_config, StreamableHTTPServerConfig):
@@ -28,6 +33,10 @@ class AsyncStreamableHTTPMCPClient(AsyncBaseMCPClient):
if server_config.auth_header and server_config.auth_token:
headers[server_config.auth_header] = server_config.auth_token
# Add agent ID header if provided
if self.agent_id:
headers[self.AGENT_ID_HEADER] = self.agent_id
# Use OAuth provider if available, otherwise use regular headers
if self.oauth_provider:
streamable_http_cm = streamablehttp_client(

View File

@@ -56,14 +56,14 @@ class MCPManager:
self.cached_mcp_servers = {} # maps id -> async connection
@enforce_types
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser) -> List[MCPTool]:
async def list_mcp_server_tools(self, mcp_server_name: str, actor: PydanticUser, agent_id: Optional[str] = None) -> List[MCPTool]:
"""Get a list of all tools for a specific MCP server."""
mcp_client = None
try:
mcp_server_id = await self.get_mcp_server_id_by_name(mcp_server_name, actor=actor)
mcp_config = await self.get_mcp_server_by_id_async(mcp_server_id, actor=actor)
server_config = mcp_config.to_config()
mcp_client = await self.get_mcp_client(server_config, actor)
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
await mcp_client.connect_to_server()
# list tools
@@ -93,6 +93,7 @@ class MCPManager:
tool_args: Optional[Dict[str, Any]],
environment_variables: Dict[str, str],
actor: PydanticUser,
agent_id: Optional[str] = None,
) -> Tuple[str, bool]:
"""Call a specific tool from a specific MCP server."""
try:
@@ -108,7 +109,7 @@ class MCPManager:
raise ValueError(f"MCP server {mcp_server_name} not found in config.")
server_config = mcp_config[mcp_server_name]
mcp_client = await self.get_mcp_client(server_config, actor)
mcp_client = await self.get_mcp_client(server_config, actor, agent_id=agent_id)
await mcp_client.connect_to_server()
# call tool
@@ -449,6 +450,7 @@ class MCPManager:
server_config: Union[SSEServerConfig, StdioServerConfig, StreamableHTTPServerConfig],
actor: PydanticUser,
oauth_provider: Optional[Any] = None,
agent_id: Optional[str] = None,
) -> Union[AsyncSSEMCPClient, AsyncStdioMCPClient, AsyncStreamableHTTPMCPClient]:
"""
Helper function to create the appropriate MCP client based on server configuration.
@@ -481,13 +483,13 @@ class MCPManager:
if server_config.type == MCPServerType.SSE:
server_config = SSEServerConfig(**server_config.model_dump())
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider)
return AsyncSSEMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
elif server_config.type == MCPServerType.STDIO:
server_config = StdioServerConfig(**server_config.model_dump())
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider)
return AsyncStdioMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
elif server_config.type == MCPServerType.STREAMABLE_HTTP:
server_config = StreamableHTTPServerConfig(**server_config.model_dump())
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider)
return AsyncStreamableHTTPMCPClient(server_config=server_config, oauth_provider=oauth_provider, agent_id=agent_id)
else:
raise ValueError(f"Unsupported server config type: {type(server_config)}")

View File

@@ -37,8 +37,10 @@ class ExternalMCPToolExecutor(ToolExecutor):
# TODO: may need to have better client connection management
environment_variables = {}
agent_id = None
if agent_state:
environment_variables = agent_state.get_agent_env_vars_as_dict()
agent_id = agent_state.id
function_response, success = await mcp_manager.execute_mcp_server_tool(
mcp_server_name=mcp_server_name,
@@ -46,6 +48,7 @@ class ExternalMCPToolExecutor(ToolExecutor):
tool_args=function_args,
environment_variables=environment_variables,
actor=actor,
agent_id=agent_id,
)
return ToolExecutionResult(