fix: update base provider to only handle _enc fields (#6591)

* base

* update

* another pass

* fix

* generate

* fix test

* don't set on create

* last fixes

---------

Co-authored-by: Letta Bot <noreply@letta.com>
This commit is contained in:
jnjpng
2025-12-10 15:04:47 -08:00
committed by Caren Thomas
parent 99126c6283
commit 3221ed8a14
19 changed files with 166 additions and 169 deletions

View File

@@ -73,8 +73,8 @@ async def test_provider_create_encrypts_api_key(provider_manager, default_user,
assert created_provider.name == "test-openai-provider"
assert created_provider.provider_type == ProviderType.openai
# Verify plaintext api_key is still accessible (dual-write during migration)
assert created_provider.api_key == "sk-test-plaintext-api-key-12345"
# Verify encrypted api_key can be decrypted
assert created_provider.api_key_enc.get_plaintext() == "sk-test-plaintext-api-key-12345"
# Read directly from database to verify encryption
async with db_registry.async_session() as session:
@@ -84,14 +84,10 @@ async def test_provider_create_encrypts_api_key(provider_manager, default_user,
actor=default_user,
)
# Verify plaintext column has the value (dual-write)
assert provider_orm.api_key == "sk-test-plaintext-api-key-12345"
# Verify encrypted column is populated and different from plaintext
# Verify encrypted column is populated and decrypts correctly
assert provider_orm.api_key_enc is not None
assert provider_orm.api_key_enc != "sk-test-plaintext-api-key-12345"
# Encrypted value should be base64-encoded and longer
assert len(provider_orm.api_key_enc) > len("sk-test-plaintext-api-key-12345")
decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext()
assert decrypted == "sk-test-plaintext-api-key-12345"
@pytest.mark.asyncio
@@ -110,13 +106,8 @@ async def test_provider_read_decrypts_api_key(provider_manager, default_user, en
# Read the provider back
retrieved_provider = await provider_manager.get_provider_async(provider_id, actor=default_user)
# Verify the api_key is decrypted correctly
assert retrieved_provider.api_key == "sk-ant-test-key-67890"
# Verify we can get the decrypted key through the secret getter
api_key_secret = retrieved_provider.get_api_key_secret()
assert isinstance(api_key_secret, Secret)
decrypted_key = api_key_secret.get_plaintext()
# Verify the api_key is decrypted correctly via api_key_enc
decrypted_key = retrieved_provider.api_key_enc.get_plaintext()
assert decrypted_key == "sk-ant-test-key-67890"
@@ -140,8 +131,8 @@ async def test_provider_update_encrypts_new_api_key(provider_manager, default_us
updated_provider = await provider_manager.update_provider_async(provider_id, provider_update, actor=default_user)
# Verify the updated key is accessible
assert updated_provider.api_key == "gsk-updated-key-456"
# Verify the updated key is accessible via the encrypted field
assert updated_provider.api_key_enc.get_plaintext() == "gsk-updated-key-456"
# Read from DB to verify new encrypted value
async with db_registry.async_session() as session:
@@ -151,11 +142,7 @@ async def test_provider_update_encrypts_new_api_key(provider_manager, default_us
actor=default_user,
)
# Verify both columns are updated
assert provider_orm.api_key == "gsk-updated-key-456"
assert provider_orm.api_key_enc is not None
# Decrypt and verify
decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext()
assert decrypted == "gsk-updated-key-456"
@@ -174,9 +161,9 @@ async def test_bedrock_credentials_encryption(provider_manager, default_user, en
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
# Verify both keys are accessible
assert created_provider.api_key == "secret-access-key-xyz"
assert created_provider.access_key == "access-key-id-abc"
# Verify both keys are accessible via encrypted fields
assert created_provider.api_key_enc.get_plaintext() == "secret-access-key-xyz"
assert created_provider.access_key_enc.get_plaintext() == "access-key-id-abc"
# Read from DB to verify both are encrypted
async with db_registry.async_session() as session:
@@ -191,8 +178,8 @@ async def test_bedrock_credentials_encryption(provider_manager, default_user, en
assert provider_orm.access_key_enc is not None
# Verify encrypted values are different from plaintext
assert provider_orm.api_key_enc != "secret-access-key-xyz"
assert provider_orm.access_key_enc != "access-key-id-abc"
assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == "secret-access-key-xyz"
assert Secret.from_encrypted(provider_orm.access_key_enc).get_plaintext() == "access-key-id-abc"
# Test the manager method for getting Bedrock credentials
access_key, secret_key, region = await provider_manager.get_bedrock_credentials_async("test-bedrock-provider", actor=default_user)
@@ -215,7 +202,7 @@ async def test_provider_secret_not_exposed_in_logs(provider_manager, default_use
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
# Get the Secret object
api_key_secret = created_provider.get_api_key_secret()
api_key_secret = created_provider.api_key_enc
# Verify string representation doesn't expose the key
secret_str = str(api_key_secret)
@@ -240,19 +227,19 @@ async def test_provider_pydantic_to_orm_serialization(provider_manager, default_
# Step 1: Create provider (Pydantic → ORM)
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
original_api_key = created_provider.api_key
original_api_key = created_provider.api_key_enc.get_plaintext()
# Step 2: Read provider back (ORM → Pydantic)
retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user)
# Verify data integrity
assert retrieved_provider.api_key == original_api_key
assert retrieved_provider.api_key_enc.get_plaintext() == original_api_key
assert retrieved_provider.name == "test-roundtrip-provider"
assert retrieved_provider.provider_type == ProviderType.openai
assert retrieved_provider.base_url == "https://api.openai.com/v1"
# Verify Secret object works correctly
api_key_secret = retrieved_provider.get_api_key_secret()
api_key_secret = retrieved_provider.api_key_enc
assert api_key_secret.get_plaintext() == original_api_key
# Step 3: Convert to ORM again (should preserve encrypted field)
@@ -261,7 +248,7 @@ async def test_provider_pydantic_to_orm_serialization(provider_manager, default_
# Verify encrypted field is in the ORM data
assert "api_key_enc" in orm_data
assert orm_data["api_key_enc"] is not None
assert orm_data["api_key"] == original_api_key
assert Secret.from_encrypted(orm_data["api_key_enc"]).get_plaintext() == original_api_key
@pytest.mark.asyncio
@@ -290,8 +277,8 @@ async def test_provider_with_none_api_key(provider_manager, default_user, encryp
)
# api_key_enc should handle empty string appropriately
# (encrypt empty string or store as None)
assert provider_orm.api_key_enc is not None or provider_orm.api_key == ""
assert provider_orm.api_key_enc is not None
assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == ""
@pytest.mark.asyncio
@@ -316,9 +303,7 @@ async def test_list_providers_decrypts_all(provider_manager, default_user, encry
# Verify all are decrypted correctly
assert len(test_providers) == 3
for i, provider in enumerate(sorted(test_providers, key=lambda p: p.name)):
assert provider.api_key == f"sk-key-{i}"
# Verify Secret getter works
secret = provider.get_api_key_secret()
secret = provider.api_key_enc
assert secret.get_plaintext() == f"sk-key-{i}"

View File

@@ -16,13 +16,14 @@ from letta.schemas.providers import (
TogetherProvider,
VLLMProvider,
)
from letta.schemas.secret import Secret
from letta.settings import model_settings
def test_openai():
provider = OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
base_url=model_settings.openai_api_base,
)
models = provider.list_llm_models()
@@ -38,7 +39,7 @@ def test_openai():
async def test_openai_async():
provider = OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
base_url=model_settings.openai_api_base,
)
models = await provider.list_llm_models_async()
@@ -54,7 +55,7 @@ async def test_openai_async():
async def test_anthropic():
provider = AnthropicProvider(
name="anthropic",
api_key=model_settings.anthropic_api_key,
api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key),
)
models = await provider.list_llm_models_async()
assert len(models) > 0
@@ -67,7 +68,7 @@ async def test_googleai():
assert api_key is not None
provider = GoogleAIProvider(
name="google_ai",
api_key=api_key,
api_key_enc=Secret.from_plaintext(api_key),
)
models = await provider.list_llm_models_async()
assert len(models) > 0
@@ -97,7 +98,7 @@ async def test_google_vertex():
@pytest.mark.skipif(model_settings.deepseek_api_key is None, reason="Only run if DEEPSEEK_API_KEY is set.")
@pytest.mark.asyncio
async def test_deepseek():
provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
provider = DeepSeekProvider(name="deepseek", api_key_enc=Secret.from_plaintext(model_settings.deepseek_api_key))
models = await provider.list_llm_models_async()
assert len(models) > 0
assert models[0].handle == f"{provider.name}/{models[0].model}"
@@ -108,7 +109,7 @@ async def test_deepseek():
async def test_groq():
provider = GroqProvider(
name="groq",
api_key=model_settings.groq_api_key,
api_key_enc=Secret.from_plaintext(model_settings.groq_api_key),
)
models = await provider.list_llm_models_async()
assert len(models) > 0
@@ -120,7 +121,7 @@ async def test_groq():
async def test_azure():
provider = AzureProvider(
name="azure",
api_key=model_settings.azure_api_key,
api_key_enc=Secret.from_plaintext(model_settings.azure_api_key),
base_url=model_settings.azure_base_url,
api_version=model_settings.azure_api_version,
)
@@ -138,7 +139,7 @@ async def test_azure():
async def test_together():
provider = TogetherProvider(
name="together",
api_key=model_settings.together_api_key,
api_key_enc=Secret.from_plaintext(model_settings.together_api_key),
default_prompt_formatter=model_settings.default_prompt_formatter,
)
models = await provider.list_llm_models_async()
@@ -161,7 +162,6 @@ async def test_ollama():
provider = OllamaProvider(
name="ollama",
base_url=model_settings.ollama_base_url,
api_key=None,
default_prompt_formatter=model_settings.default_prompt_formatter,
)
models = await provider.list_llm_models_async()
@@ -203,7 +203,7 @@ async def test_vllm():
async def test_custom_anthropic():
provider = AnthropicProvider(
name="custom_anthropic",
api_key=model_settings.anthropic_api_key,
api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key),
)
models = await provider.list_llm_models_async()
assert len(models) > 0
@@ -214,7 +214,7 @@ def test_provider_context_window():
"""Test that providers implement context window methods correctly."""
provider = OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
base_url=model_settings.openai_api_base,
)
@@ -230,7 +230,7 @@ async def test_provider_context_window_async():
"""Test that providers implement async context window methods correctly."""
provider = OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
base_url=model_settings.openai_api_base,
)
@@ -244,7 +244,7 @@ def test_provider_handle_generation():
"""Test that providers generate handles correctly."""
provider = OpenAIProvider(
name="test_openai",
api_key="test_key",
api_key_enc=Secret.from_plaintext("test_key"),
base_url="https://api.openai.com/v1",
)
@@ -266,14 +266,14 @@ def test_provider_casting():
name="test_provider",
provider_type=ProviderType.openai,
provider_category=ProviderCategory.base,
api_key="test_key",
api_key_enc=Secret.from_plaintext("test_key"),
base_url="https://api.openai.com/v1",
)
cast_provider = base_provider.cast_to_subtype()
assert isinstance(cast_provider, OpenAIProvider)
assert cast_provider.name == "test_provider"
assert cast_provider.api_key == "test_key"
assert cast_provider.api_key_enc.get_plaintext() == "test_key"
@pytest.mark.asyncio
@@ -281,7 +281,7 @@ async def test_provider_embedding_models_consistency():
"""Test that providers return consistent embedding model formats."""
provider = OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
base_url=model_settings.openai_api_base,
)
@@ -301,7 +301,7 @@ async def test_provider_llm_models_consistency():
"""Test that providers return consistent LLM model formats."""
provider = OpenAIProvider(
name="openai",
api_key=model_settings.openai_api_key,
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
base_url=model_settings.openai_api_base,
)