feat: have core ask cloud for any relavent api credentials to allow a… [LET-6179] (#6172)
feat: have core ask cloud for any relavent api credentials to allow an agent to perform letta tasks Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
committed by
Caren Thomas
parent
34f5b5e33c
commit
acbbccd28a
@@ -1298,6 +1298,7 @@ class SyncServer(object):
|
|||||||
actor=actor,
|
actor=actor,
|
||||||
sandbox_env_vars=tool_env_vars,
|
sandbox_env_vars=tool_env_vars,
|
||||||
)
|
)
|
||||||
|
|
||||||
# TODO: Integrate sandbox result
|
# TODO: Integrate sandbox result
|
||||||
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
tool_execution_result = await tool_execution_manager.execute_tool_async(
|
||||||
function_name=tool_name,
|
function_name=tool_name,
|
||||||
|
|||||||
80
letta/services/sandbox_credentials_service.py
Normal file
80
letta/services/sandbox_credentials_service.py
Normal file
@@ -0,0 +1,80 @@
|
|||||||
|
import logging
|
||||||
|
import os
|
||||||
|
from typing import Any, Dict, Optional
|
||||||
|
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
from letta.schemas.user import User
|
||||||
|
|
||||||
|
logger = logging.getLogger(__name__)
|
||||||
|
|
||||||
|
|
||||||
|
class SandboxCredentialsService:
|
||||||
|
"""Service for fetching sandbox credentials from a webhook."""
|
||||||
|
|
||||||
|
def __init__(self):
|
||||||
|
self.credentials_webhook_url = os.getenv("STEP_ORCHESTRATOR_ENDPOINT")
|
||||||
|
self.credentials_webhook_key = os.getenv("STEP_COMPLETE_KEY")
|
||||||
|
|
||||||
|
async def fetch_credentials(
|
||||||
|
self,
|
||||||
|
actor: User,
|
||||||
|
tool_name: Optional[str] = None,
|
||||||
|
agent_id: Optional[str] = None,
|
||||||
|
) -> Dict[str, Any]:
|
||||||
|
"""
|
||||||
|
Fetch sandbox credentials from the configured webhook.
|
||||||
|
|
||||||
|
Args:
|
||||||
|
actor: The user executing the tool
|
||||||
|
tool_name: Optional name of the tool being executed
|
||||||
|
agent_id: Optional ID of the agent executing the tool
|
||||||
|
|
||||||
|
Returns:
|
||||||
|
Dict[str, Any]: Dictionary of environment variables to add to sandbox
|
||||||
|
"""
|
||||||
|
if not self.credentials_webhook_url:
|
||||||
|
logger.debug("SANDBOX_CREDENTIALS_WEBHOOK not configured, skipping credentials fetch")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
try:
|
||||||
|
headers = {}
|
||||||
|
if self.credentials_webhook_key:
|
||||||
|
headers["Authorization"] = f"Bearer {self.credentials_webhook_key}"
|
||||||
|
|
||||||
|
payload = {
|
||||||
|
"user_id": actor.id,
|
||||||
|
"organization_id": actor.organization_id,
|
||||||
|
}
|
||||||
|
|
||||||
|
if tool_name:
|
||||||
|
payload["tool_name"] = tool_name
|
||||||
|
if agent_id:
|
||||||
|
payload["agent_id"] = agent_id
|
||||||
|
|
||||||
|
async with httpx.AsyncClient(timeout=10.0) as client:
|
||||||
|
response = await client.post(
|
||||||
|
self.credentials_webhook_url + "/webhook/sandbox-credentials",
|
||||||
|
json=payload,
|
||||||
|
headers=headers,
|
||||||
|
)
|
||||||
|
response.raise_for_status()
|
||||||
|
|
||||||
|
response_data = response.json()
|
||||||
|
|
||||||
|
if not isinstance(response_data, dict):
|
||||||
|
logger.warning(f"Invalid response format from credentials webhook: expected dict, got {type(response_data)}")
|
||||||
|
return {}
|
||||||
|
|
||||||
|
logger.info(f"Successfully fetched sandbox credentials for user {actor.id}")
|
||||||
|
return response_data
|
||||||
|
|
||||||
|
except httpx.TimeoutException:
|
||||||
|
logger.warning(f"Timeout fetching sandbox credentials for user {actor.id}")
|
||||||
|
return {}
|
||||||
|
except httpx.HTTPStatusError as e:
|
||||||
|
logger.warning(f"HTTP error fetching sandbox credentials for user {actor.id}: {e.response.status_code}")
|
||||||
|
return {}
|
||||||
|
except Exception as e:
|
||||||
|
logger.error(f"Unexpected error fetching sandbox credentials for user {actor.id}: {e}")
|
||||||
|
return {}
|
||||||
149
letta/services/sandbox_credentials_service_test.py
Normal file
149
letta/services/sandbox_credentials_service_test.py
Normal file
@@ -0,0 +1,149 @@
|
|||||||
|
"""
|
||||||
|
Test for sandbox credentials service functionality.
|
||||||
|
|
||||||
|
To run this test:
|
||||||
|
python -m pytest letta/services/sandbox_credentials_service_test.py -v
|
||||||
|
|
||||||
|
To test with actual webhook:
|
||||||
|
export SANDBOX_CREDENTIALS_WEBHOOK=https://your-webhook-url.com/endpoint
|
||||||
|
export SANDBOX_CREDENTIALS_KEY=your-secret-key
|
||||||
|
python -m pytest letta/services/sandbox_credentials_service_test.py -v
|
||||||
|
"""
|
||||||
|
|
||||||
|
import os
|
||||||
|
from unittest.mock import AsyncMock, patch
|
||||||
|
|
||||||
|
import pytest
|
||||||
|
|
||||||
|
from letta.schemas.user import User
|
||||||
|
from letta.services.sandbox_credentials_service import SandboxCredentialsService
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credentials_not_configured():
|
||||||
|
"""Test that credentials fetch returns empty dict when URL is not configured."""
|
||||||
|
with patch.dict(os.environ, {}, clear=True):
|
||||||
|
service = SandboxCredentialsService()
|
||||||
|
mock_user = User(id="user_123", organization_id="org_456")
|
||||||
|
result = await service.fetch_credentials(mock_user)
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credentials_fetch_success():
|
||||||
|
"""Test successful credentials fetch."""
|
||||||
|
with patch.dict(
|
||||||
|
os.environ,
|
||||||
|
{"SANDBOX_CREDENTIALS_WEBHOOK": "https://example.com/credentials", "SANDBOX_CREDENTIALS_KEY": "test-key"},
|
||||||
|
):
|
||||||
|
service = SandboxCredentialsService()
|
||||||
|
mock_user = User(id="user_123", organization_id="org_456")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.raise_for_status = AsyncMock()
|
||||||
|
mock_response.json = AsyncMock(return_value={"API_KEY": "secret_key_123", "OTHER_VAR": "value"})
|
||||||
|
|
||||||
|
mock_post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||||
|
|
||||||
|
result = await service.fetch_credentials(mock_user, tool_name="my_tool", agent_id="agent_789")
|
||||||
|
|
||||||
|
assert result == {"API_KEY": "secret_key_123", "OTHER_VAR": "value"}
|
||||||
|
mock_post.assert_called_once()
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
assert call_args.kwargs["json"] == {
|
||||||
|
"user_id": "user_123",
|
||||||
|
"organization_id": "org_456",
|
||||||
|
"tool_name": "my_tool",
|
||||||
|
"agent_id": "agent_789",
|
||||||
|
}
|
||||||
|
assert call_args.kwargs["headers"]["Authorization"] == "Bearer test-key"
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credentials_fetch_without_auth():
|
||||||
|
"""Test credentials fetch without authentication key."""
|
||||||
|
with patch.dict(os.environ, {"SANDBOX_CREDENTIALS_WEBHOOK": "https://example.com/credentials"}, clear=True):
|
||||||
|
service = SandboxCredentialsService()
|
||||||
|
mock_user = User(id="user_123", organization_id="org_456")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.raise_for_status = AsyncMock()
|
||||||
|
mock_response.json = AsyncMock(return_value={"API_KEY": "secret_key_123"})
|
||||||
|
|
||||||
|
mock_post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||||
|
|
||||||
|
result = await service.fetch_credentials(mock_user)
|
||||||
|
|
||||||
|
assert result == {"API_KEY": "secret_key_123"}
|
||||||
|
call_args = mock_post.call_args
|
||||||
|
# Should not have Authorization header
|
||||||
|
assert "Authorization" not in call_args.kwargs["headers"]
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credentials_fetch_timeout():
|
||||||
|
"""Test credentials fetch timeout handling."""
|
||||||
|
with patch.dict(os.environ, {"SANDBOX_CREDENTIALS_WEBHOOK": "https://example.com/credentials"}):
|
||||||
|
service = SandboxCredentialsService()
|
||||||
|
mock_user = User(id="user_123", organization_id="org_456")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
mock_post = AsyncMock(side_effect=httpx.TimeoutException("Request timed out"))
|
||||||
|
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||||
|
|
||||||
|
result = await service.fetch_credentials(mock_user)
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credentials_fetch_http_error():
|
||||||
|
"""Test credentials fetch HTTP error handling."""
|
||||||
|
with patch.dict(os.environ, {"SANDBOX_CREDENTIALS_WEBHOOK": "https://example.com/credentials"}):
|
||||||
|
service = SandboxCredentialsService()
|
||||||
|
mock_user = User(id="user_123", organization_id="org_456")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
import httpx
|
||||||
|
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 500
|
||||||
|
mock_response.raise_for_status = AsyncMock(
|
||||||
|
side_effect=httpx.HTTPStatusError("Server error", request=None, response=mock_response)
|
||||||
|
)
|
||||||
|
|
||||||
|
mock_post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||||
|
|
||||||
|
result = await service.fetch_credentials(mock_user)
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
|
|
||||||
|
|
||||||
|
@pytest.mark.asyncio
|
||||||
|
async def test_credentials_fetch_invalid_response():
|
||||||
|
"""Test credentials fetch with invalid response format."""
|
||||||
|
with patch.dict(os.environ, {"SANDBOX_CREDENTIALS_WEBHOOK": "https://example.com/credentials"}):
|
||||||
|
service = SandboxCredentialsService()
|
||||||
|
mock_user = User(id="user_123", organization_id="org_456")
|
||||||
|
|
||||||
|
with patch("httpx.AsyncClient") as mock_client:
|
||||||
|
mock_response = AsyncMock()
|
||||||
|
mock_response.status_code = 200
|
||||||
|
mock_response.raise_for_status = AsyncMock()
|
||||||
|
mock_response.json = AsyncMock(return_value="not a dict")
|
||||||
|
|
||||||
|
mock_post = AsyncMock(return_value=mock_response)
|
||||||
|
mock_client.return_value.__aenter__.return_value.post = mock_post
|
||||||
|
|
||||||
|
result = await service.fetch_credentials(mock_user)
|
||||||
|
|
||||||
|
assert result == {}
|
||||||
@@ -11,6 +11,7 @@ from letta.schemas.tool import Tool
|
|||||||
from letta.schemas.tool_execution_result import ToolExecutionResult
|
from letta.schemas.tool_execution_result import ToolExecutionResult
|
||||||
from letta.schemas.user import User
|
from letta.schemas.user import User
|
||||||
from letta.services.agent_manager import AgentManager
|
from letta.services.agent_manager import AgentManager
|
||||||
|
from letta.services.sandbox_credentials_service import SandboxCredentialsService
|
||||||
from letta.services.tool_executor.tool_executor_base import ToolExecutor
|
from letta.services.tool_executor.tool_executor_base import ToolExecutor
|
||||||
from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal
|
from letta.services.tool_sandbox.local_sandbox import AsyncToolSandboxLocal
|
||||||
from letta.settings import tool_settings
|
from letta.settings import tool_settings
|
||||||
@@ -40,6 +41,20 @@ class SandboxToolExecutor(ToolExecutor):
|
|||||||
else:
|
else:
|
||||||
orig_memory_str = None
|
orig_memory_str = None
|
||||||
|
|
||||||
|
# Fetch credentials from webhook
|
||||||
|
credentials_service = SandboxCredentialsService()
|
||||||
|
|
||||||
|
fetched_credentials = await credentials_service.fetch_credentials(
|
||||||
|
actor=actor,
|
||||||
|
tool_name=tool.name,
|
||||||
|
agent_id=agent_state.id if agent_state else None,
|
||||||
|
)
|
||||||
|
|
||||||
|
# Merge fetched credentials with provided sandbox_env_vars
|
||||||
|
if sandbox_env_vars is None:
|
||||||
|
sandbox_env_vars = {}
|
||||||
|
sandbox_env_vars = {**fetched_credentials, **sandbox_env_vars}
|
||||||
|
|
||||||
try:
|
try:
|
||||||
# Prepare function arguments
|
# Prepare function arguments
|
||||||
function_args = self._prepare_function_args(function_args, tool, function_name)
|
function_args = self._prepare_function_args(function_args, tool, function_name)
|
||||||
|
|||||||
Reference in New Issue
Block a user