From 05e1623389fbdacf94dfa6ba322eb6aba22b3c67 Mon Sep 17 00:00:00 2001 From: jnjpng Date: Wed, 7 May 2025 16:26:55 -0700 Subject: [PATCH] feat: add endpoint to test connection to llm provider (#2032) Co-authored-by: Jin Peng --- letta/llm_api/anthropic.py | 16 +++++++++++- letta/llm_api/google_ai_client.py | 20 ++++++++++++++ letta/llm_api/google_constants.py | 2 ++ letta/llm_api/openai.py | 16 ++++++++++++ letta/schemas/enums.py | 1 + letta/schemas/providers.py | 26 ++++++++++++++++++- letta/server/rest_api/routers/v1/providers.py | 19 +++++++++++++- letta/services/provider_manager.py | 17 +++++++++++- 8 files changed, 113 insertions(+), 4 deletions(-) diff --git a/letta/llm_api/anthropic.py b/letta/llm_api/anthropic.py index aada2259..88cf0e79 100644 --- a/letta/llm_api/anthropic.py +++ b/letta/llm_api/anthropic.py @@ -19,7 +19,7 @@ from anthropic.types.beta import ( BetaToolUseBlock, ) -from letta.errors import BedrockError, BedrockPermissionError +from letta.errors import BedrockError, BedrockPermissionError, ErrorCode, LLMAuthenticationError, LLMError from letta.helpers.datetime_helpers import get_utc_time_int, timestamp_to_datetime from letta.llm_api.aws_bedrock import get_bedrock_client from letta.llm_api.helpers import add_inner_thoughts_to_functions @@ -119,6 +119,20 @@ DUMMY_FIRST_USER_MESSAGE = "User initializing bootup sequence." VALID_EVENT_TYPES = {"content_block_stop", "message_stop"} +def anthropic_check_valid_api_key(api_key: Union[str, None]) -> None: + if api_key: + anthropic_client = anthropic.Anthropic(api_key=api_key) + try: + # just use a cheap model to count some tokens - as of 5/7/2025 this is faster than fetching the list of models + anthropic_client.messages.count_tokens(model=MODEL_LIST[-1]["name"], messages=[{"role": "user", "content": "a"}]) + except anthropic.AuthenticationError as e: + raise LLMAuthenticationError(message=f"Failed to authenticate with Anthropic: {e}", code=ErrorCode.UNAUTHENTICATED) + except Exception as e: + raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) + else: + raise ValueError("No API key provided") + + def antropic_get_model_context_window(url: str, api_key: Union[str, None], model: str) -> int: for model_dict in anthropic_get_model_list(url=url, api_key=api_key): if model_dict["name"] == model: diff --git a/letta/llm_api/google_ai_client.py b/letta/llm_api/google_ai_client.py index ad650c5f..f056a64b 100644 --- a/letta/llm_api/google_ai_client.py +++ b/letta/llm_api/google_ai_client.py @@ -3,11 +3,14 @@ import uuid from typing import List, Optional, Tuple import requests +from google import genai from google.genai.types import FunctionCallingConfig, FunctionCallingConfigMode, ToolConfig from letta.constants import NON_USER_MSG_PREFIX +from letta.errors import ErrorCode, LLMAuthenticationError, LLMError from letta.helpers.datetime_helpers import get_utc_time_int from letta.helpers.json_helpers import json_dumps +from letta.llm_api.google_constants import GOOGLE_MODEL_FOR_API_KEY_CHECK from letta.llm_api.helpers import make_post_request from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.json_parser import clean_json_string_extra_backslash @@ -443,6 +446,23 @@ def get_gemini_endpoint_and_headers( return url, headers +def google_ai_check_valid_api_key(api_key: str): + client = genai.Client(api_key=api_key) + # use the count token endpoint for a cheap model - as of 5/7/2025 this is slightly faster than fetching the list of models + try: + client.models.count_tokens( + model=GOOGLE_MODEL_FOR_API_KEY_CHECK, + contents="", + ) + except genai.errors.ClientError as e: + # google api returns 400 invalid argument for invalid api key + if e.code == 400: + raise LLMAuthenticationError(message=f"Failed to authenticate with Google AI: {e}", code=ErrorCode.UNAUTHENTICATED) + raise e + except Exception as e: + raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) + + def google_ai_get_model_list(base_url: str, api_key: str, key_in_header: bool = True) -> List[dict]: from letta.utils import printd diff --git a/letta/llm_api/google_constants.py b/letta/llm_api/google_constants.py index c720a33a..1c30d615 100644 --- a/letta/llm_api/google_constants.py +++ b/letta/llm_api/google_constants.py @@ -14,3 +14,5 @@ GOOGLE_MODEL_TO_CONTEXT_LENGTH = { GOOGLE_MODEL_TO_OUTPUT_LENGTH = {"gemini-2.0-flash-001": 8192, "gemini-2.5-pro-exp-03-25": 65536} GOOGLE_EMBEDING_MODEL_TO_DIM = {"text-embedding-005": 768, "text-multilingual-embedding-002": 768} + +GOOGLE_MODEL_FOR_API_KEY_CHECK = "gemini-2.0-flash-lite" diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index e35429bc..2fe8ade3 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -5,6 +5,7 @@ import requests from openai import OpenAI from letta.constants import LETTA_MODEL_ENDPOINT +from letta.errors import ErrorCode, LLMAuthenticationError, LLMError from letta.helpers.datetime_helpers import timestamp_to_datetime from letta.llm_api.helpers import add_inner_thoughts_to_functions, convert_to_structured_output, make_post_request from letta.llm_api.openai_client import accepts_developer_role, supports_parallel_tool_calling, supports_temperature_param @@ -34,6 +35,21 @@ from letta.utils import get_tool_call_id, smart_urljoin logger = get_logger(__name__) +def openai_check_valid_api_key(base_url: str, api_key: Union[str, None]) -> None: + if api_key: + try: + # just get model list to check if the api key is valid until we find a cheaper / quicker endpoint + openai_get_model_list(url=base_url, api_key=api_key) + except requests.HTTPError as e: + if e.response.status_code == 401: + raise LLMAuthenticationError(message=f"Failed to authenticate with OpenAI: {e}", code=ErrorCode.UNAUTHENTICATED) + raise e + except Exception as e: + raise LLMError(message=f"{e}", code=ErrorCode.INTERNAL_SERVER_ERROR) + else: + raise ValueError("No API key provided") + + def openai_get_model_list( url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None ) -> dict: diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index 2a3de409..555ffadd 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -3,6 +3,7 @@ from enum import Enum class ProviderType(str, Enum): anthropic = "anthropic" + anthropic_bedrock = "bedrock" google_ai = "google_ai" google_vertex = "google_vertex" openai = "openai" diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 291271e3..f1e9edd6 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -2,7 +2,7 @@ import warnings from datetime import datetime from typing import List, Literal, Optional -from pydantic import Field, model_validator +from pydantic import BaseModel, Field, model_validator from letta.constants import LETTA_MODEL_ENDPOINT, LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW from letta.llm_api.azure_openai import get_azure_chat_completions_endpoint, get_azure_embeddings_endpoint @@ -40,6 +40,10 @@ class Provider(ProviderBase): if not self.id: self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) + def check_api_key(self): + """Check if the API key is valid for the provider""" + raise NotImplementedError + def list_llm_models(self) -> List[LLMConfig]: return [] @@ -112,6 +116,11 @@ class ProviderUpdate(ProviderBase): api_key: str = Field(..., description="API key used for requests to the provider.") +class ProviderCheck(BaseModel): + provider_type: ProviderType = Field(..., description="The type of the provider.") + api_key: str = Field(..., description="API key used for requests to the provider.") + + class LettaProvider(Provider): provider_type: Literal[ProviderType.letta] = Field(ProviderType.letta, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") @@ -148,6 +157,11 @@ class OpenAIProvider(Provider): api_key: str = Field(..., description="API key for the OpenAI API.") base_url: str = Field(..., description="Base URL for the OpenAI API.") + def check_api_key(self): + from letta.llm_api.openai import openai_check_valid_api_key + + openai_check_valid_api_key(self.base_url, self.api_key) + def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.openai import openai_get_model_list @@ -549,6 +563,11 @@ class AnthropicProvider(Provider): api_key: str = Field(..., description="API key for the Anthropic API.") base_url: str = "https://api.anthropic.com/v1" + def check_api_key(self): + from letta.llm_api.anthropic import anthropic_check_valid_api_key + + anthropic_check_valid_api_key(self.api_key) + def list_llm_models(self) -> List[LLMConfig]: from letta.llm_api.anthropic import MODEL_LIST, anthropic_get_model_list @@ -951,6 +970,11 @@ class GoogleAIProvider(Provider): api_key: str = Field(..., description="API key for the Google AI API.") base_url: str = "https://generativelanguage.googleapis.com" + def check_api_key(self): + from letta.llm_api.google_ai_client import google_ai_check_valid_api_key + + google_ai_check_valid_api_key(self.api_key) + def list_llm_models(self): from letta.llm_api.google_ai_client import google_ai_get_model_list diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 0cf29114..8111ffba 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -3,9 +3,10 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status from fastapi.responses import JSONResponse +from letta.errors import LLMAuthenticationError from letta.orm.errors import NoResultFound from letta.schemas.enums import ProviderType -from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate +from letta.schemas.providers import Provider, ProviderCheck, ProviderCreate, ProviderUpdate from letta.server.rest_api.utils import get_letta_server if TYPE_CHECKING: @@ -67,6 +68,22 @@ def modify_provider( return server.provider_manager.update_provider(provider_id=provider_id, request=request, actor=actor) +@router.get("/check", response_model=None, operation_id="check_provider") +def check_provider( + provider_type: ProviderType = Query(...), + api_key: str = Header(..., alias="x-api-key"), + server: "SyncServer" = Depends(get_letta_server), +): + try: + provider_check = ProviderCheck(provider_type=provider_type, api_key=api_key) + server.provider_manager.check_provider_api_key(provider_check=provider_check) + return JSONResponse(status_code=status.HTTP_200_OK, content={"message": f"Valid api key for provider_type={provider_type.value}"}) + except LLMAuthenticationError as e: + raise HTTPException(status_code=status.HTTP_401_UNAUTHORIZED, detail=f"{e.message}") + except Exception as e: + raise HTTPException(status_code=status.HTTP_500_INTERNAL_SERVER_ERROR, detail=f"{e}") + + @router.delete("/{provider_id}", response_model=None, operation_id="delete_provider") def delete_provider( provider_id: str, diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 49ec99f4..e77a3f2f 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -3,7 +3,7 @@ from typing import List, Optional, Union from letta.orm.provider import Provider as ProviderModel from letta.schemas.enums import ProviderCategory, ProviderType from letta.schemas.providers import Provider as PydanticProvider -from letta.schemas.providers import ProviderCreate, ProviderUpdate +from letta.schemas.providers import ProviderCheck, ProviderCreate, ProviderUpdate from letta.schemas.user import User as PydanticUser from letta.utils import enforce_types @@ -99,3 +99,18 @@ class ProviderManager: def get_override_key(self, provider_name: Union[str, None], actor: PydanticUser) -> Optional[str]: providers = self.list_providers(name=provider_name, actor=actor) return providers[0].api_key if providers else None + + @enforce_types + def check_provider_api_key(self, provider_check: ProviderCheck) -> None: + provider = PydanticProvider( + name=provider_check.provider_type.value, + provider_type=provider_check.provider_type, + api_key=provider_check.api_key, + provider_category=ProviderCategory.byok, + ).cast_to_subtype() + + # TODO: add more string sanity checks here before we hit actual endpoints + if not provider.api_key: + raise ValueError("API key is required") + + provider.check_api_key()