diff --git a/fern/openapi.json b/fern/openapi.json index 6c3affe2..8160efb0 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -33073,7 +33073,8 @@ } ], "title": "Api Key", - "description": "API key or secret key used for requests to the provider." + "description": "API key or secret key used for requests to the provider.", + "deprecated": true }, "base_url": { "anyOf": [ @@ -33097,7 +33098,8 @@ } ], "title": "Access Key", - "description": "Access key used for requests to the provider." + "description": "Access key used for requests to the provider.", + "deprecated": true }, "region": { "anyOf": [ diff --git a/letta/schemas/providers/anthropic.py b/letta/schemas/providers/anthropic.py index 6e1e4af7..d137e234 100644 --- a/letta/schemas/providers/anthropic.py +++ b/letta/schemas/providers/anthropic.py @@ -104,11 +104,11 @@ MODEL_LIST = [ class AnthropicProvider(Provider): provider_type: Literal[ProviderType.anthropic] = Field(ProviderType.anthropic, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the Anthropic API.") + api_key: str | None = Field(None, description="API key for the Anthropic API.", deprecated=True) base_url: str = "https://api.anthropic.com/v1" async def check_api_key(self): - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None if api_key: anthropic_client = anthropic.Anthropic(api_key=api_key) try: @@ -127,7 +127,7 @@ class AnthropicProvider(Provider): NOTE: currently there is no GET /models, so we need to hardcode """ - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None if api_key: anthropic_client = anthropic.AsyncAnthropic(api_key=api_key) elif model_settings.anthropic_api_key: diff --git a/letta/schemas/providers/azure.py b/letta/schemas/providers/azure.py index 2e2225a7..19f85a90 100644 --- a/letta/schemas/providers/azure.py +++ b/letta/schemas/providers/azure.py @@ -36,7 +36,7 @@ class AzureProvider(Provider): base_url: str = Field( ..., description="Base URL for the Azure API endpoint. This should be specific to your org, e.g. `https://letta.openai.azure.com`." ) - api_key: str = Field(..., description="API key for the Azure API.") + api_key: str | None = Field(None, description="API key for the Azure API.", deprecated=True) api_version: str = Field(default=LATEST_API_VERSION, description="API version for the Azure API") @field_validator("api_version", mode="before") @@ -60,7 +60,7 @@ class AzureProvider(Provider): async def azure_openai_get_deployed_model_list(self) -> list: """https://learn.microsoft.com/en-us/rest/api/azureopenai/models/list?view=rest-azureopenai-2023-05-15&tabs=HTTP""" - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None client = AsyncAzureOpenAI(api_key=api_key, api_version=self.api_version, azure_endpoint=self.base_url) try: @@ -169,7 +169,7 @@ class AzureProvider(Provider): return AZURE_MODEL_TO_CONTEXT_LENGTH.get(model_name, llm_default) async def check_api_key(self): - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None if not api_key: raise ValueError("No API key provided") diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py index f9457c7d..b89fc1bb 100644 --- a/letta/schemas/providers/base.py +++ b/letta/schemas/providers/base.py @@ -4,7 +4,7 @@ from letta.log import get_logger logger = get_logger(__name__) -from pydantic import BaseModel, Field, model_validator +from pydantic import BaseModel, Field, field_validator, model_validator from letta.schemas.embedding_config import EmbeddingConfig from letta.schemas.embedding_config_overrides import EMBEDDING_HANDLE_OVERRIDES @@ -25,9 +25,9 @@ class Provider(ProviderBase): name: str = Field(..., description="The name of the provider") provider_type: ProviderType = Field(..., description="The type of the provider") provider_category: ProviderCategory = Field(..., description="The category of the provider (base or byok)") - api_key: str | None = Field(None, description="API key or secret key used for requests to the provider.") + api_key: str | None = Field(None, description="API key or secret key used for requests to the provider.", deprecated=True) base_url: str | None = Field(None, description="Base URL for the provider.") - access_key: str | None = Field(None, description="Access key used for requests to the provider.") + access_key: str | None = Field(None, description="Access key used for requests to the provider.", deprecated=True) region: str | None = Field(None, description="Region used for requests to the provider.") api_version: str | None = Field(None, description="API version used for requests to the provider.") organization_id: str | None = Field(None, description="The organization id of the user") @@ -38,53 +38,50 @@ class Provider(ProviderBase): api_key_enc: Secret | None = Field(None, description="Encrypted API key as Secret object") access_key_enc: Secret | None = Field(None, description="Encrypted access key as Secret object") + # TODO: remove these checks once fully migrated to encrypted fields + def __setattr__(self, name: str, value) -> None: + if name in ("api_key", "access_key"): + logger.warning( + f"DEPRECATION: Setting '{name}' directly is deprecated. Use the encrypted fields (`api_key_enc`/`access_key_enc`) instead." + ) + return super().__setattr__(name, value) + + def __getattribute__(self, name: str): + if name in ("api_key", "access_key"): + logger.warning( + f"DEPRECATION: Accessing '{name}' directly is deprecated. " + "Use the encrypted fields (`api_key_enc`/`access_key_enc`) instead." + ) + return super().__getattribute__(name) + + @field_validator("api_key") + def deprecate_api_key(cls, v: str): + if v: + logger.warning( + "DEPRECATION: Creating provider with 'api_key' directly is deprecated. Use the encrypted fields (`api_key_enc`) instead." + ) + return v + + @field_validator("access_key") + def deprecate_access_key(cls, v: str): + if v: + logger.warning( + "DEPRECATION: Creating provider with 'access_key' directly is deprecated. Use the encrypted fields (`access_key_enc`) instead." + ) + return v + @model_validator(mode="after") def default_base_url(self): # Set default base URL if self.provider_type == ProviderType.openai and self.base_url is None: self.base_url = model_settings.openai_api_base + return self def resolve_identifier(self): if not self.id: self.id = ProviderBase.generate_id(prefix=ProviderBase.__id_prefix__) - def get_api_key_secret(self) -> Secret: - """Get the API key as a Secret object. Prefers encrypted, falls back to plaintext with error logging.""" - # If api_key_enc is already a Secret, return it - if self.api_key_enc is not None: - return self.api_key_enc - # Fallback to plaintext with error logging via Secret.from_db() - return Secret.from_db(encrypted_value=None, plaintext_value=self.api_key) - - def get_access_key_secret(self) -> Secret: - """Get the access key as a Secret object. Prefers encrypted, falls back to plaintext with error logging.""" - # If access_key_enc is already a Secret, return it - if self.access_key_enc is not None: - return self.access_key_enc - # Fallback to plaintext with error logging via Secret.from_db() - return Secret.from_db(encrypted_value=None, plaintext_value=self.access_key) - - def set_api_key_secret(self, secret: Secret) -> None: - """Set API key from a Secret object, directly storing the Secret.""" - self.api_key_enc = secret - # Also update plaintext field for dual-write during migration - secret_dict = secret.to_dict() - if not secret.was_encrypted: - self.api_key = secret_dict["plaintext"] - else: - self.api_key = None - - def set_access_key_secret(self, secret: Secret) -> None: - """Set access key from a Secret object, directly storing the Secret.""" - self.access_key_enc = secret - # Also update plaintext field for dual-write during migration - secret_dict = secret.to_dict() - if not secret.was_encrypted: - self.access_key = secret_dict["plaintext"] - else: - self.access_key = None - async def check_api_key(self): """Check if the API key is valid for the provider""" raise NotImplementedError diff --git a/letta/schemas/providers/bedrock.py b/letta/schemas/providers/bedrock.py index 461b77fa..ef809089 100644 --- a/letta/schemas/providers/bedrock.py +++ b/letta/schemas/providers/bedrock.py @@ -26,8 +26,8 @@ class BedrockProvider(Provider): try: # Decrypt credentials before using - access_key = self.get_access_key_secret().get_plaintext() - secret_key = self.get_api_key_secret().get_plaintext() + access_key = self.access_key_enc.get_plaintext() if self.access_key_enc else None + secret_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None session = Session() async with session.client( diff --git a/letta/schemas/providers/cerebras.py b/letta/schemas/providers/cerebras.py index 75adf5c9..19470e5e 100644 --- a/letta/schemas/providers/cerebras.py +++ b/letta/schemas/providers/cerebras.py @@ -26,7 +26,7 @@ class CerebrasProvider(OpenAIProvider): provider_type: Literal[ProviderType.cerebras] = Field(ProviderType.cerebras, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field("https://api.cerebras.ai/v1", description="Base URL for the Cerebras API.") - api_key: str = Field(..., description="API key for the Cerebras API.") + api_key: str | None = Field(None, description="API key for the Cerebras API.", deprecated=True) def get_model_context_window_size(self, model_name: str) -> int | None: """Cerebras has limited context window sizes. @@ -41,7 +41,7 @@ class CerebrasProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None response = await openai_get_model_list_async(self.base_url, api_key=api_key) if "data" in response: diff --git a/letta/schemas/providers/deepseek.py b/letta/schemas/providers/deepseek.py index be2ef0b1..a1ff7bb1 100644 --- a/letta/schemas/providers/deepseek.py +++ b/letta/schemas/providers/deepseek.py @@ -18,7 +18,7 @@ class DeepSeekProvider(OpenAIProvider): provider_type: Literal[ProviderType.deepseek] = Field(ProviderType.deepseek, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = Field("https://api.deepseek.com/v1", description="Base URL for the DeepSeek API.") - api_key: str = Field(..., description="API key for the DeepSeek API.") + api_key: str | None = Field(None, description="API key for the DeepSeek API.", deprecated=True) # TODO (cliandy): this may need to be updated to reflect current models def get_model_context_window_size(self, model_name: str) -> int | None: @@ -34,7 +34,7 @@ class DeepSeekProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None response = await openai_get_model_list_async(self.base_url, api_key=api_key) data = response.get("data", response) diff --git a/letta/schemas/providers/google_gemini.py b/letta/schemas/providers/google_gemini.py index 1261e668..c1c135f7 100644 --- a/letta/schemas/providers/google_gemini.py +++ b/letta/schemas/providers/google_gemini.py @@ -17,20 +17,20 @@ from letta.schemas.providers.base import Provider class GoogleAIProvider(Provider): provider_type: Literal[ProviderType.google_ai] = Field(ProviderType.google_ai, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the Google AI API.") + api_key: str | None = Field(None, description="API key for the Google AI API.", deprecated=True) base_url: str = "https://generativelanguage.googleapis.com" async def check_api_key(self): from letta.llm_api.google_ai_client import google_ai_check_valid_api_key_async - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None await google_ai_check_valid_api_key_async(api_key) async def list_llm_models_async(self): from letta.llm_api.google_ai_client import google_ai_get_model_list_async # Get and filter the model list - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key) model_options = [mo for mo in model_options if "generateContent" in mo["supportedGenerationMethods"]] model_options = [str(m["name"]) for m in model_options] @@ -64,7 +64,7 @@ class GoogleAIProvider(Provider): from letta.llm_api.google_ai_client import google_ai_get_model_list_async # TODO: use base_url instead - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None model_options = await google_ai_get_model_list_async(base_url=self.base_url, api_key=api_key) return self._list_embedding_models(model_options) @@ -98,7 +98,7 @@ class GoogleAIProvider(Provider): if model_name in LLM_MAX_TOKENS: return LLM_MAX_TOKENS[model_name] else: - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None return google_ai_get_model_context_window(self.base_url, api_key, model_name) async def get_model_context_window_async(self, model_name: str) -> int | None: @@ -107,5 +107,5 @@ class GoogleAIProvider(Provider): if model_name in LLM_MAX_TOKENS: return LLM_MAX_TOKENS[model_name] else: - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None return await google_ai_get_model_context_window_async(self.base_url, api_key, model_name) diff --git a/letta/schemas/providers/groq.py b/letta/schemas/providers/groq.py index 9945e4ff..23488c4b 100644 --- a/letta/schemas/providers/groq.py +++ b/letta/schemas/providers/groq.py @@ -11,12 +11,12 @@ class GroqProvider(OpenAIProvider): provider_type: Literal[ProviderType.groq] = Field(ProviderType.groq, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = "https://api.groq.com/openai/v1" - api_key: str = Field(..., description="API key for the Groq API.") + api_key: str | None = Field(None, description="API key for the Groq API.", deprecated=True) async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None response = await openai_get_model_list_async(self.base_url, api_key=api_key) configs = [] for model in response["data"]: diff --git a/letta/schemas/providers/mistral.py b/letta/schemas/providers/mistral.py index c4777eba..f174e381 100644 --- a/letta/schemas/providers/mistral.py +++ b/letta/schemas/providers/mistral.py @@ -10,7 +10,7 @@ from letta.schemas.providers.base import Provider class MistralProvider(Provider): provider_type: Literal[ProviderType.mistral] = Field(ProviderType.mistral, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the Mistral API.") + api_key: str | None = Field(None, description="API key for the Mistral API.", deprecated=True) base_url: str = "https://api.mistral.ai/v1" async def list_llm_models_async(self) -> list[LLMConfig]: @@ -18,7 +18,7 @@ class MistralProvider(Provider): # Some hardcoded support for OpenRouter (so that we only get models with tool calling support)... # See: https://openrouter.ai/docs/requests - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None response = await mistral_get_model_list_async(self.base_url, api_key=api_key) assert "data" in response, f"Mistral model query response missing 'data' field: {response}" diff --git a/letta/schemas/providers/openai.py b/letta/schemas/providers/openai.py index 52ff323d..4feaefc1 100644 --- a/letta/schemas/providers/openai.py +++ b/letta/schemas/providers/openai.py @@ -19,14 +19,14 @@ DEFAULT_EMBEDDING_BATCH_SIZE = 1024 class OpenAIProvider(Provider): provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the OpenAI API.") + api_key: str | None = Field(None, description="API key for the OpenAI API.", deprecated=True) base_url: str = Field("https://api.openai.com/v1", description="Base URL for the OpenAI API.") async def check_api_key(self): from letta.llm_api.openai import openai_check_valid_api_key # TODO: DO NOT USE THIS - old code path # Decrypt API key before using - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None openai_check_valid_api_key(self.base_url, api_key) async def _get_models_async(self) -> list[dict]: @@ -40,7 +40,7 @@ class OpenAIProvider(Provider): extra_params = {"verbose": True} if "nebius.com" in self.base_url else None # Decrypt API key before using - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None response = await openai_get_model_list_async( self.base_url, diff --git a/letta/schemas/providers/openrouter.py b/letta/schemas/providers/openrouter.py index 9e2a3052..5027b9f6 100644 --- a/letta/schemas/providers/openrouter.py +++ b/letta/schemas/providers/openrouter.py @@ -19,7 +19,7 @@ logger = get_logger(__name__) class OpenRouterProvider(OpenAIProvider): provider_type: Literal[ProviderType.openai] = Field(ProviderType.openai, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the OpenRouter API.") + api_key: str | None = Field(None, description="API key for the OpenRouter API.", deprecated=True) base_url: str = Field("https://openrouter.ai/api/v1", description="Base URL for the OpenRouter API.") def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]: diff --git a/letta/schemas/providers/together.py b/letta/schemas/providers/together.py index 1229d2bd..2bf099b5 100644 --- a/letta/schemas/providers/together.py +++ b/letta/schemas/providers/together.py @@ -22,7 +22,7 @@ class TogetherProvider(OpenAIProvider): provider_type: Literal[ProviderType.together] = Field(ProviderType.together, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") base_url: str = "https://api.together.xyz/v1" - api_key: str = Field(..., description="API key for the Together API.") + api_key: str | None = Field(None, description="API key for the Together API.", deprecated=True) default_prompt_formatter: Optional[str] = Field( None, description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API." ) @@ -30,7 +30,7 @@ class TogetherProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None models = await openai_get_model_list_async(self.base_url, api_key=api_key) return self._list_llm_models(models) @@ -93,7 +93,7 @@ class TogetherProvider(OpenAIProvider): return configs async def check_api_key(self): - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None if not api_key: raise ValueError("No API key provided") diff --git a/letta/schemas/providers/xai.py b/letta/schemas/providers/xai.py index 4e9bfd57..61b92f17 100644 --- a/letta/schemas/providers/xai.py +++ b/letta/schemas/providers/xai.py @@ -27,7 +27,7 @@ class XAIProvider(OpenAIProvider): provider_type: Literal[ProviderType.xai] = Field(ProviderType.xai, description="The type of the provider.") provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") - api_key: str = Field(..., description="API key for the xAI/Grok API.") + api_key: str | None = Field(None, description="API key for the xAI/Grok API.", deprecated=True) base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") def get_model_context_window_size(self, model_name: str) -> int | None: @@ -38,7 +38,7 @@ class XAIProvider(OpenAIProvider): async def list_llm_models_async(self) -> list[LLMConfig]: from letta.llm_api.openai import openai_get_model_list_async - api_key = self.get_api_key_secret().get_plaintext() + api_key = self.api_key_enc.get_plaintext() if self.api_key_enc else None response = await openai_get_model_list_async(self.base_url, api_key=api_key) data = response.get("data", response) diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index f1539797..cd0d493c 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -122,9 +122,12 @@ async def check_existing_provider( provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor) # Create a ProviderCheck from the existing provider + api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None + access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None provider_check = ProviderCheck( provider_type=provider.provider_type, - api_key=provider.api_key, + api_key=api_key, + access_key=access_key, base_url=provider.base_url, ) diff --git a/letta/server/server.py b/letta/server/server.py index 20cd7f85..85accb80 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -213,7 +213,6 @@ class SyncServer(object): self._enabled_providers.append( OpenAIProvider( name="openai", - api_key=model_settings.openai_api_key, api_key_enc=Secret.from_plaintext(model_settings.openai_api_key), base_url=model_settings.openai_api_base, ) @@ -222,7 +221,6 @@ class SyncServer(object): self._enabled_providers.append( AnthropicProvider( name="anthropic", - api_key=model_settings.anthropic_api_key, api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key), ) ) @@ -231,7 +229,6 @@ class SyncServer(object): OllamaProvider( name="ollama", base_url=model_settings.ollama_base_url, - api_key=None, default_prompt_formatter=model_settings.default_prompt_formatter, ) ) @@ -239,7 +236,6 @@ class SyncServer(object): self._enabled_providers.append( GoogleAIProvider( name="google_ai", - api_key=model_settings.gemini_api_key, api_key_enc=Secret.from_plaintext(model_settings.gemini_api_key), ) ) @@ -256,7 +252,6 @@ class SyncServer(object): self._enabled_providers.append( AzureProvider( name="azure", - api_key=model_settings.azure_api_key, api_key_enc=Secret.from_plaintext(model_settings.azure_api_key), base_url=model_settings.azure_base_url, api_version=model_settings.azure_api_version, @@ -266,7 +261,6 @@ class SyncServer(object): self._enabled_providers.append( GroqProvider( name="groq", - api_key=model_settings.groq_api_key, api_key_enc=Secret.from_plaintext(model_settings.groq_api_key), ) ) @@ -274,7 +268,6 @@ class SyncServer(object): self._enabled_providers.append( TogetherProvider( name="together", - api_key=model_settings.together_api_key, api_key_enc=Secret.from_plaintext(model_settings.together_api_key), default_prompt_formatter=model_settings.default_prompt_formatter, ) @@ -313,7 +306,6 @@ class SyncServer(object): self._enabled_providers.append( DeepSeekProvider( name="deepseek", - api_key=model_settings.deepseek_api_key, api_key_enc=Secret.from_plaintext(model_settings.deepseek_api_key), ) ) @@ -321,7 +313,6 @@ class SyncServer(object): self._enabled_providers.append( XAIProvider( name="xai", - api_key=model_settings.xai_api_key, api_key_enc=Secret.from_plaintext(model_settings.xai_api_key), ) ) @@ -329,7 +320,6 @@ class SyncServer(object): self._enabled_providers.append( OpenRouterProvider( name=model_settings.openrouter_handle_base if model_settings.openrouter_handle_base else "openrouter", - api_key=model_settings.openrouter_api_key, api_key_enc=Secret.from_plaintext(model_settings.openrouter_api_key), ) ) diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index f99f7944..bbd66364 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -1,5 +1,6 @@ from typing import List, Optional, Tuple, Union +from letta.log import get_logger from letta.orm.provider import Provider as ProviderModel from letta.orm.provider_model import ProviderModel as ProviderModelORM from letta.otel.tracing import trace_method @@ -14,6 +15,8 @@ from letta.server.db import db_registry from letta.utils import enforce_types from letta.validators import raise_on_invalid_id +logger = get_logger(__name__) + class ProviderManager: @enforce_types @@ -56,6 +59,11 @@ class ProviderManager: # Create provider with the appropriate category provider_data = request.model_dump() + + # Unset deprecated api_key and access_key as to not write plaintext values, api_key_enc and access_key_enc will be set below + provider_data.pop("api_key", None) + provider_data.pop("access_key", None) + provider_data["provider_category"] = ProviderCategory.byok if is_byok else ProviderCategory.base provider = PydanticProvider(**provider_data) @@ -71,10 +79,10 @@ class ProviderManager: provider.resolve_identifier() # Explicitly populate encrypted fields from plaintext - if provider.api_key is not None: - provider.api_key_enc = Secret.from_plaintext(provider.api_key) - if provider.access_key is not None: - provider.access_key_enc = Secret.from_plaintext(provider.access_key) + if request.api_key is not None: + provider.api_key_enc = Secret.from_plaintext(request.api_key) + if request.access_key is not None: + provider.access_key_enc = Secret.from_plaintext(request.access_key) new_provider = ProviderModel(**provider.model_dump(to_orm=True, exclude_unset=True)) await new_provider.create_async(session, actor=actor) @@ -108,14 +116,10 @@ class ProviderManager: if existing_provider.api_key_enc: existing_secret = Secret.from_encrypted(existing_provider.api_key_enc) existing_api_key = existing_secret.get_plaintext() - elif existing_provider.api_key: - existing_api_key = existing_provider.api_key # Only re-encrypt if different if existing_api_key != update_data["api_key"]: existing_provider.api_key_enc = Secret.from_plaintext(update_data["api_key"]).get_encrypted() - # Keep plaintext for dual-write during migration - existing_provider.api_key = update_data["api_key"] # Remove from update_data since we set directly on existing_provider update_data.pop("api_key", None) @@ -129,14 +133,10 @@ class ProviderManager: if existing_provider.access_key_enc: existing_secret = Secret.from_encrypted(existing_provider.access_key_enc) existing_access_key = existing_secret.get_plaintext() - elif existing_provider.access_key: - existing_access_key = existing_provider.access_key # Only re-encrypt if different if existing_access_key != update_data["access_key"]: existing_provider.access_key_enc = Secret.from_plaintext(update_data["access_key"]).get_encrypted() - # Keep plaintext for dual-write during migration - existing_provider.access_key = update_data["access_key"] # Remove from update_data since we set directly on existing_provider update_data.pop("access_key", None) @@ -160,7 +160,15 @@ class ProviderManager: existing_provider = await ProviderModel.read_async( db_session=session, identifier=provider_id, actor=actor, check_is_deleted=True ) + existing_provider.api_key_enc = None + existing_provider.access_key_enc = None + + # Only accessing these deprecated fields to clear, which may trigger a warning existing_provider.api_key = None + existing_provider.access_key = None + + logger.info("Soft deleting provider with id: %s", provider_id) + await existing_provider.update_async(session, actor=actor) # Soft delete in provider table @@ -217,6 +225,11 @@ class ProviderManager: # Combine both lists all_providers = org_providers + global_providers + # Remove deprecated api_key and access_key fields from the response + for provider in all_providers: + provider.api_key = None + provider.access_key = None + return [provider.to_pydantic() for provider in all_providers] @enforce_types @@ -291,6 +304,9 @@ class ProviderManager: result = await session.execute(stmt) provider_model = result.scalar_one_or_none() if provider_model: + # Remove deprecated api_key and access_key fields from the response + provider_model.api_key = None + provider_model.access_key = None return provider_model.to_pydantic() else: from letta.orm.errors import NoResultFound @@ -309,8 +325,8 @@ class ProviderManager: providers = self.list_providers(name=provider_name, actor=actor) if providers: # Decrypt the API key before returning - api_key_secret = providers[0].get_api_key_secret() - return api_key_secret.get_plaintext() + api_key_secret = providers[0].api_key_enc + return api_key_secret.get_plaintext() if api_key_secret else None return None @enforce_types @@ -319,8 +335,8 @@ class ProviderManager: providers = await self.list_providers_async(name=provider_name, actor=actor) if providers: # Decrypt the API key before returning - api_key_secret = providers[0].get_api_key_secret() - return api_key_secret.get_plaintext() + api_key_secret = providers[0].api_key_enc + return api_key_secret.get_plaintext() if api_key_secret else None return None @enforce_types @@ -331,10 +347,10 @@ class ProviderManager: providers = await self.list_providers_async(name=provider_name, actor=actor) if providers: # Decrypt the credentials before returning - access_key_secret = providers[0].get_access_key_secret() - api_key_secret = providers[0].get_api_key_secret() - access_key = access_key_secret.get_plaintext() - secret_key = api_key_secret.get_plaintext() + access_key_secret = providers[0].access_key_enc + api_key_secret = providers[0].api_key_enc + access_key = access_key_secret.get_plaintext() if access_key_secret else None + secret_key = api_key_secret.get_plaintext() if api_key_secret else None region = providers[0].region return access_key, secret_key, region return None, None, None @@ -347,8 +363,8 @@ class ProviderManager: providers = self.list_providers(name=provider_name, actor=actor) if providers: # Decrypt the API key before returning - api_key_secret = providers[0].get_api_key_secret() - api_key = api_key_secret.get_plaintext() + api_key_secret = providers[0].api_key_enc + api_key = api_key_secret.get_plaintext() if api_key_secret else None base_url = providers[0].base_url api_version = providers[0].api_version return api_key, base_url, api_version @@ -362,8 +378,8 @@ class ProviderManager: providers = await self.list_providers_async(name=provider_name, actor=actor) if providers: # Decrypt the API key before returning - api_key_secret = providers[0].get_api_key_secret() - api_key = api_key_secret.get_plaintext() + api_key_secret = providers[0].api_key_enc + api_key = api_key_secret.get_plaintext() if api_key_secret else None base_url = providers[0].base_url api_version = providers[0].api_version return api_key, base_url, api_version @@ -375,16 +391,16 @@ class ProviderManager: provider = PydanticProvider( name=provider_check.provider_type.value, provider_type=provider_check.provider_type, - api_key=provider_check.api_key, + api_key_enc=Secret.from_plaintext(provider_check.api_key), provider_category=ProviderCategory.byok, - access_key=provider_check.access_key, # This contains the access key ID for Bedrock + access_key_enc=Secret.from_plaintext(provider_check.access_key) if provider_check.access_key else None, region=provider_check.region, base_url=provider_check.base_url, api_version=provider_check.api_version, ).cast_to_subtype() # TODO: add more string sanity checks here before we hit actual endpoints - if not provider.api_key: + if not provider.api_key_enc or not provider.api_key_enc.get_plaintext(): raise ValueError("API key is required!") await provider.check_api_key() @@ -423,15 +439,17 @@ class ProviderManager: return # Create provider instance with necessary parameters + api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None + access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None kwargs = { "name": provider.name, - "api_key": provider.api_key, + "api_key": api_key, "provider_category": provider.provider_category, } if provider.base_url: kwargs["base_url"] = provider.base_url - if provider.access_key: - kwargs["access_key"] = provider.access_key + if access_key: + kwargs["access_key"] = access_key if provider.region: kwargs["region"] = provider.region if provider.api_version: @@ -498,11 +516,13 @@ class ProviderManager: continue # Convert Provider to ProviderCreate + api_key = provider.api_key_enc.get_plaintext() if provider.api_key_enc else None + access_key = provider.access_key_enc.get_plaintext() if provider.access_key_enc else None provider_create = ProviderCreate( name=provider.name, provider_type=provider.provider_type, - api_key=provider.api_key or "", # ProviderCreate requires api_key, use empty string if None - access_key=provider.access_key, + api_key=api_key or "", # ProviderCreate requires api_key, use empty string if None + access_key=access_key, region=provider.region, base_url=provider.base_url, api_version=provider.api_version, diff --git a/tests/managers/test_provider_manager.py b/tests/managers/test_provider_manager.py index 61e2597f..2f8e96b0 100644 --- a/tests/managers/test_provider_manager.py +++ b/tests/managers/test_provider_manager.py @@ -73,8 +73,8 @@ async def test_provider_create_encrypts_api_key(provider_manager, default_user, assert created_provider.name == "test-openai-provider" assert created_provider.provider_type == ProviderType.openai - # Verify plaintext api_key is still accessible (dual-write during migration) - assert created_provider.api_key == "sk-test-plaintext-api-key-12345" + # Verify encrypted api_key can be decrypted + assert created_provider.api_key_enc.get_plaintext() == "sk-test-plaintext-api-key-12345" # Read directly from database to verify encryption async with db_registry.async_session() as session: @@ -84,14 +84,10 @@ async def test_provider_create_encrypts_api_key(provider_manager, default_user, actor=default_user, ) - # Verify plaintext column has the value (dual-write) - assert provider_orm.api_key == "sk-test-plaintext-api-key-12345" - - # Verify encrypted column is populated and different from plaintext + # Verify encrypted column is populated and decrypts correctly assert provider_orm.api_key_enc is not None - assert provider_orm.api_key_enc != "sk-test-plaintext-api-key-12345" - # Encrypted value should be base64-encoded and longer - assert len(provider_orm.api_key_enc) > len("sk-test-plaintext-api-key-12345") + decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() + assert decrypted == "sk-test-plaintext-api-key-12345" @pytest.mark.asyncio @@ -110,13 +106,8 @@ async def test_provider_read_decrypts_api_key(provider_manager, default_user, en # Read the provider back retrieved_provider = await provider_manager.get_provider_async(provider_id, actor=default_user) - # Verify the api_key is decrypted correctly - assert retrieved_provider.api_key == "sk-ant-test-key-67890" - - # Verify we can get the decrypted key through the secret getter - api_key_secret = retrieved_provider.get_api_key_secret() - assert isinstance(api_key_secret, Secret) - decrypted_key = api_key_secret.get_plaintext() + # Verify the api_key is decrypted correctly via api_key_enc + decrypted_key = retrieved_provider.api_key_enc.get_plaintext() assert decrypted_key == "sk-ant-test-key-67890" @@ -140,8 +131,8 @@ async def test_provider_update_encrypts_new_api_key(provider_manager, default_us updated_provider = await provider_manager.update_provider_async(provider_id, provider_update, actor=default_user) - # Verify the updated key is accessible - assert updated_provider.api_key == "gsk-updated-key-456" + # Verify the updated key is accessible via the encrypted field + assert updated_provider.api_key_enc.get_plaintext() == "gsk-updated-key-456" # Read from DB to verify new encrypted value async with db_registry.async_session() as session: @@ -151,11 +142,7 @@ async def test_provider_update_encrypts_new_api_key(provider_manager, default_us actor=default_user, ) - # Verify both columns are updated - assert provider_orm.api_key == "gsk-updated-key-456" assert provider_orm.api_key_enc is not None - - # Decrypt and verify decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() assert decrypted == "gsk-updated-key-456" @@ -174,9 +161,9 @@ async def test_bedrock_credentials_encryption(provider_manager, default_user, en created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) - # Verify both keys are accessible - assert created_provider.api_key == "secret-access-key-xyz" - assert created_provider.access_key == "access-key-id-abc" + # Verify both keys are accessible via encrypted fields + assert created_provider.api_key_enc.get_plaintext() == "secret-access-key-xyz" + assert created_provider.access_key_enc.get_plaintext() == "access-key-id-abc" # Read from DB to verify both are encrypted async with db_registry.async_session() as session: @@ -191,8 +178,8 @@ async def test_bedrock_credentials_encryption(provider_manager, default_user, en assert provider_orm.access_key_enc is not None # Verify encrypted values are different from plaintext - assert provider_orm.api_key_enc != "secret-access-key-xyz" - assert provider_orm.access_key_enc != "access-key-id-abc" + assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == "secret-access-key-xyz" + assert Secret.from_encrypted(provider_orm.access_key_enc).get_plaintext() == "access-key-id-abc" # Test the manager method for getting Bedrock credentials access_key, secret_key, region = await provider_manager.get_bedrock_credentials_async("test-bedrock-provider", actor=default_user) @@ -215,7 +202,7 @@ async def test_provider_secret_not_exposed_in_logs(provider_manager, default_use created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) # Get the Secret object - api_key_secret = created_provider.get_api_key_secret() + api_key_secret = created_provider.api_key_enc # Verify string representation doesn't expose the key secret_str = str(api_key_secret) @@ -240,19 +227,19 @@ async def test_provider_pydantic_to_orm_serialization(provider_manager, default_ # Step 1: Create provider (Pydantic → ORM) created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user) - original_api_key = created_provider.api_key + original_api_key = created_provider.api_key_enc.get_plaintext() # Step 2: Read provider back (ORM → Pydantic) retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user) # Verify data integrity - assert retrieved_provider.api_key == original_api_key + assert retrieved_provider.api_key_enc.get_plaintext() == original_api_key assert retrieved_provider.name == "test-roundtrip-provider" assert retrieved_provider.provider_type == ProviderType.openai assert retrieved_provider.base_url == "https://api.openai.com/v1" # Verify Secret object works correctly - api_key_secret = retrieved_provider.get_api_key_secret() + api_key_secret = retrieved_provider.api_key_enc assert api_key_secret.get_plaintext() == original_api_key # Step 3: Convert to ORM again (should preserve encrypted field) @@ -261,7 +248,7 @@ async def test_provider_pydantic_to_orm_serialization(provider_manager, default_ # Verify encrypted field is in the ORM data assert "api_key_enc" in orm_data assert orm_data["api_key_enc"] is not None - assert orm_data["api_key"] == original_api_key + assert Secret.from_encrypted(orm_data["api_key_enc"]).get_plaintext() == original_api_key @pytest.mark.asyncio @@ -290,8 +277,8 @@ async def test_provider_with_none_api_key(provider_manager, default_user, encryp ) # api_key_enc should handle empty string appropriately - # (encrypt empty string or store as None) - assert provider_orm.api_key_enc is not None or provider_orm.api_key == "" + assert provider_orm.api_key_enc is not None + assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == "" @pytest.mark.asyncio @@ -316,9 +303,7 @@ async def test_list_providers_decrypts_all(provider_manager, default_user, encry # Verify all are decrypted correctly assert len(test_providers) == 3 for i, provider in enumerate(sorted(test_providers, key=lambda p: p.name)): - assert provider.api_key == f"sk-key-{i}" - # Verify Secret getter works - secret = provider.get_api_key_secret() + secret = provider.api_key_enc assert secret.get_plaintext() == f"sk-key-{i}" diff --git a/tests/test_providers.py b/tests/test_providers.py index e1062d6b..9671ed22 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -16,13 +16,14 @@ from letta.schemas.providers import ( TogetherProvider, VLLMProvider, ) +from letta.schemas.secret import Secret from letta.settings import model_settings def test_openai(): provider = OpenAIProvider( name="openai", - api_key=model_settings.openai_api_key, + api_key_enc=Secret.from_plaintext(model_settings.openai_api_key), base_url=model_settings.openai_api_base, ) models = provider.list_llm_models() @@ -38,7 +39,7 @@ def test_openai(): async def test_openai_async(): provider = OpenAIProvider( name="openai", - api_key=model_settings.openai_api_key, + api_key_enc=Secret.from_plaintext(model_settings.openai_api_key), base_url=model_settings.openai_api_base, ) models = await provider.list_llm_models_async() @@ -54,7 +55,7 @@ async def test_openai_async(): async def test_anthropic(): provider = AnthropicProvider( name="anthropic", - api_key=model_settings.anthropic_api_key, + api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key), ) models = await provider.list_llm_models_async() assert len(models) > 0 @@ -67,7 +68,7 @@ async def test_googleai(): assert api_key is not None provider = GoogleAIProvider( name="google_ai", - api_key=api_key, + api_key_enc=Secret.from_plaintext(api_key), ) models = await provider.list_llm_models_async() assert len(models) > 0 @@ -97,7 +98,7 @@ async def test_google_vertex(): @pytest.mark.skipif(model_settings.deepseek_api_key is None, reason="Only run if DEEPSEEK_API_KEY is set.") @pytest.mark.asyncio async def test_deepseek(): - provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key) + provider = DeepSeekProvider(name="deepseek", api_key_enc=Secret.from_plaintext(model_settings.deepseek_api_key)) models = await provider.list_llm_models_async() assert len(models) > 0 assert models[0].handle == f"{provider.name}/{models[0].model}" @@ -108,7 +109,7 @@ async def test_deepseek(): async def test_groq(): provider = GroqProvider( name="groq", - api_key=model_settings.groq_api_key, + api_key_enc=Secret.from_plaintext(model_settings.groq_api_key), ) models = await provider.list_llm_models_async() assert len(models) > 0 @@ -120,7 +121,7 @@ async def test_groq(): async def test_azure(): provider = AzureProvider( name="azure", - api_key=model_settings.azure_api_key, + api_key_enc=Secret.from_plaintext(model_settings.azure_api_key), base_url=model_settings.azure_base_url, api_version=model_settings.azure_api_version, ) @@ -138,7 +139,7 @@ async def test_azure(): async def test_together(): provider = TogetherProvider( name="together", - api_key=model_settings.together_api_key, + api_key_enc=Secret.from_plaintext(model_settings.together_api_key), default_prompt_formatter=model_settings.default_prompt_formatter, ) models = await provider.list_llm_models_async() @@ -161,7 +162,6 @@ async def test_ollama(): provider = OllamaProvider( name="ollama", base_url=model_settings.ollama_base_url, - api_key=None, default_prompt_formatter=model_settings.default_prompt_formatter, ) models = await provider.list_llm_models_async() @@ -203,7 +203,7 @@ async def test_vllm(): async def test_custom_anthropic(): provider = AnthropicProvider( name="custom_anthropic", - api_key=model_settings.anthropic_api_key, + api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key), ) models = await provider.list_llm_models_async() assert len(models) > 0 @@ -214,7 +214,7 @@ def test_provider_context_window(): """Test that providers implement context window methods correctly.""" provider = OpenAIProvider( name="openai", - api_key=model_settings.openai_api_key, + api_key_enc=Secret.from_plaintext(model_settings.openai_api_key), base_url=model_settings.openai_api_base, ) @@ -230,7 +230,7 @@ async def test_provider_context_window_async(): """Test that providers implement async context window methods correctly.""" provider = OpenAIProvider( name="openai", - api_key=model_settings.openai_api_key, + api_key_enc=Secret.from_plaintext(model_settings.openai_api_key), base_url=model_settings.openai_api_base, ) @@ -244,7 +244,7 @@ def test_provider_handle_generation(): """Test that providers generate handles correctly.""" provider = OpenAIProvider( name="test_openai", - api_key="test_key", + api_key_enc=Secret.from_plaintext("test_key"), base_url="https://api.openai.com/v1", ) @@ -266,14 +266,14 @@ def test_provider_casting(): name="test_provider", provider_type=ProviderType.openai, provider_category=ProviderCategory.base, - api_key="test_key", + api_key_enc=Secret.from_plaintext("test_key"), base_url="https://api.openai.com/v1", ) cast_provider = base_provider.cast_to_subtype() assert isinstance(cast_provider, OpenAIProvider) assert cast_provider.name == "test_provider" - assert cast_provider.api_key == "test_key" + assert cast_provider.api_key_enc.get_plaintext() == "test_key" @pytest.mark.asyncio @@ -281,7 +281,7 @@ async def test_provider_embedding_models_consistency(): """Test that providers return consistent embedding model formats.""" provider = OpenAIProvider( name="openai", - api_key=model_settings.openai_api_key, + api_key_enc=Secret.from_plaintext(model_settings.openai_api_key), base_url=model_settings.openai_api_base, ) @@ -301,7 +301,7 @@ async def test_provider_llm_models_consistency(): """Test that providers return consistent LLM model formats.""" provider = OpenAIProvider( name="openai", - api_key=model_settings.openai_api_key, + api_key_enc=Secret.from_plaintext(model_settings.openai_api_key), base_url=model_settings.openai_api_base, )