feat: allow suers to test their own key [LET-5312] (#5257)
* feat: allow suers to test their own key * chore: add unit test --------- Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
committed by
Caren Thomas
parent
6889a73dcf
commit
7eae22f77c
@@ -10137,6 +10137,45 @@
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/providers/{provider_id}/check": {
|
||||
"post": {
|
||||
"tags": ["providers"],
|
||||
"summary": "Check Existing Provider",
|
||||
"description": "Verify the API key and additional parameters for an existing provider.",
|
||||
"operationId": "check_existing_provider",
|
||||
"parameters": [
|
||||
{
|
||||
"name": "provider_id",
|
||||
"in": "path",
|
||||
"required": true,
|
||||
"schema": {
|
||||
"type": "string",
|
||||
"title": "Provider Id"
|
||||
}
|
||||
}
|
||||
],
|
||||
"responses": {
|
||||
"200": {
|
||||
"description": "Successful Response",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {}
|
||||
}
|
||||
}
|
||||
},
|
||||
"422": {
|
||||
"description": "Validation Error",
|
||||
"content": {
|
||||
"application/json": {
|
||||
"schema": {
|
||||
"$ref": "#/components/schemas/HTTPValidationError"
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
}
|
||||
},
|
||||
"/v1/runs/": {
|
||||
"get": {
|
||||
"tags": ["runs"],
|
||||
|
||||
@@ -120,6 +120,40 @@ async def check_provider(
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{e}")
|
||||
|
||||
|
||||
@router.post("/{provider_id}/check", response_model=None, operation_id="check_existing_provider")
|
||||
async def check_existing_provider(
|
||||
provider_id: str,
|
||||
headers: HeaderParams = Depends(get_headers),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Verify the API key and additional parameters for an existing provider.
|
||||
"""
|
||||
try:
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)
|
||||
provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor)
|
||||
|
||||
# Create a ProviderCheck from the existing provider
|
||||
provider_check = ProviderCheck(
|
||||
provider_type=provider.provider_type,
|
||||
api_key=provider.api_key,
|
||||
base_url=provider.base_url,
|
||||
)
|
||||
|
||||
await server.provider_manager.check_provider_api_key(provider_check=provider_check)
|
||||
return JSONResponse(
|
||||
status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={provider.provider_type.value}"}
|
||||
)
|
||||
except LLMAuthenticationError as e:
|
||||
raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"{e.message}")
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Provider provider_id={provider_id} not found for user_id={actor.id}.")
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{e}")
|
||||
|
||||
|
||||
@router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
|
||||
async def delete_provider(
|
||||
provider_id: str,
|
||||
|
||||
96
tests/test_provider_api.py
Normal file
96
tests/test_provider_api.py
Normal file
@@ -0,0 +1,96 @@
|
||||
import os
|
||||
import threading
|
||||
|
||||
import pytest
|
||||
from dotenv import load_dotenv
|
||||
from letta_client import Letta
|
||||
from letta_client.core.api_error import ApiError
|
||||
|
||||
from tests.utils import wait_for_server
|
||||
|
||||
# Constants
|
||||
SERVER_PORT = 8283
|
||||
|
||||
|
||||
def run_server():
|
||||
load_dotenv()
|
||||
|
||||
from letta.server.rest_api.app import start_server
|
||||
|
||||
print("Starting server...")
|
||||
start_server(debug=True)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def client(request):
|
||||
# Get URL from environment or start server
|
||||
api_url = os.getenv("LETTA_API_URL")
|
||||
server_url = os.getenv("LETTA_SERVER_URL", f"http://localhost:{SERVER_PORT}")
|
||||
if not os.getenv("LETTA_SERVER_URL"):
|
||||
print("Starting server thread")
|
||||
thread = threading.Thread(target=run_server, daemon=True)
|
||||
thread.start()
|
||||
wait_for_server(server_url)
|
||||
print("Running client tests with server:", server_url)
|
||||
|
||||
# Overide the base_url if the LETTA_API_URL is set
|
||||
base_url = api_url if api_url else server_url
|
||||
# create the Letta client
|
||||
yield Letta(base_url=base_url, token=None)
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def test_provider(client: Letta):
|
||||
"""Create a test provider for testing."""
|
||||
# Create a provider with a test API key
|
||||
provider = client.providers.create(
|
||||
provider_type="openai",
|
||||
api_key="test-api-key-123",
|
||||
name="test-openai-provider",
|
||||
)
|
||||
|
||||
yield provider
|
||||
|
||||
# Clean up - delete the provider
|
||||
try:
|
||||
client.providers.delete(provider.id)
|
||||
except ApiError:
|
||||
# Provider might already be deleted
|
||||
pass
|
||||
|
||||
|
||||
def test_check_existing_provider_success(client: Letta, test_provider):
|
||||
"""Test checking an existing provider with valid credentials."""
|
||||
# This test assumes the test_provider has valid credentials
|
||||
# In a real scenario, you would need to use actual valid API keys
|
||||
# For this test, we'll check that the endpoint is callable
|
||||
try:
|
||||
response = client.providers.check(test_provider.id)
|
||||
# If we get here, the endpoint is working
|
||||
assert response is not None
|
||||
except ApiError as e:
|
||||
# Expected for invalid API key - just verify the endpoint exists
|
||||
# and returns 401 for invalid credentials
|
||||
assert e.status_code in [401, 500] # 401 for auth error, 500 for connection error
|
||||
|
||||
|
||||
def test_check_existing_provider_not_found(client: Letta):
|
||||
"""Test checking a provider that doesn't exist."""
|
||||
fake_provider_id = "00000000-0000-0000-0000-000000000000"
|
||||
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
client.providers.check(fake_provider_id)
|
||||
|
||||
# Should return 404 for provider not found
|
||||
assert exc_info.value.status_code == 404
|
||||
|
||||
|
||||
def test_check_existing_provider_unauthorized(client: Letta, test_provider):
|
||||
"""Test checking an existing provider with invalid API key."""
|
||||
# The test provider has a test API key which will fail authentication
|
||||
with pytest.raises(ApiError) as exc_info:
|
||||
client.providers.check(test_provider.id)
|
||||
|
||||
# Should return 401 for invalid API key
|
||||
# or 500 if the provider check fails for other reasons
|
||||
assert exc_info.value.status_code in [401, 500]
|
||||
Reference in New Issue
Block a user