fix: update base provider to only handle _enc fields (#6591)
* base * update * another pass * fix * generate * fix test * don't set on create * last fixes --------- Co-authored-by: Letta Bot <noreply@letta.com>
This commit is contained in:
@@ -73,8 +73,8 @@ async def test_provider_create_encrypts_api_key(provider_manager, default_user,
|
||||
assert created_provider.name == "test-openai-provider"
|
||||
assert created_provider.provider_type == ProviderType.openai
|
||||
|
||||
# Verify plaintext api_key is still accessible (dual-write during migration)
|
||||
assert created_provider.api_key == "sk-test-plaintext-api-key-12345"
|
||||
# Verify encrypted api_key can be decrypted
|
||||
assert created_provider.api_key_enc.get_plaintext() == "sk-test-plaintext-api-key-12345"
|
||||
|
||||
# Read directly from database to verify encryption
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -84,14 +84,10 @@ async def test_provider_create_encrypts_api_key(provider_manager, default_user,
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify plaintext column has the value (dual-write)
|
||||
assert provider_orm.api_key == "sk-test-plaintext-api-key-12345"
|
||||
|
||||
# Verify encrypted column is populated and different from plaintext
|
||||
# Verify encrypted column is populated and decrypts correctly
|
||||
assert provider_orm.api_key_enc is not None
|
||||
assert provider_orm.api_key_enc != "sk-test-plaintext-api-key-12345"
|
||||
# Encrypted value should be base64-encoded and longer
|
||||
assert len(provider_orm.api_key_enc) > len("sk-test-plaintext-api-key-12345")
|
||||
decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext()
|
||||
assert decrypted == "sk-test-plaintext-api-key-12345"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -110,13 +106,8 @@ async def test_provider_read_decrypts_api_key(provider_manager, default_user, en
|
||||
# Read the provider back
|
||||
retrieved_provider = await provider_manager.get_provider_async(provider_id, actor=default_user)
|
||||
|
||||
# Verify the api_key is decrypted correctly
|
||||
assert retrieved_provider.api_key == "sk-ant-test-key-67890"
|
||||
|
||||
# Verify we can get the decrypted key through the secret getter
|
||||
api_key_secret = retrieved_provider.get_api_key_secret()
|
||||
assert isinstance(api_key_secret, Secret)
|
||||
decrypted_key = api_key_secret.get_plaintext()
|
||||
# Verify the api_key is decrypted correctly via api_key_enc
|
||||
decrypted_key = retrieved_provider.api_key_enc.get_plaintext()
|
||||
assert decrypted_key == "sk-ant-test-key-67890"
|
||||
|
||||
|
||||
@@ -140,8 +131,8 @@ async def test_provider_update_encrypts_new_api_key(provider_manager, default_us
|
||||
|
||||
updated_provider = await provider_manager.update_provider_async(provider_id, provider_update, actor=default_user)
|
||||
|
||||
# Verify the updated key is accessible
|
||||
assert updated_provider.api_key == "gsk-updated-key-456"
|
||||
# Verify the updated key is accessible via the encrypted field
|
||||
assert updated_provider.api_key_enc.get_plaintext() == "gsk-updated-key-456"
|
||||
|
||||
# Read from DB to verify new encrypted value
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -151,11 +142,7 @@ async def test_provider_update_encrypts_new_api_key(provider_manager, default_us
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Verify both columns are updated
|
||||
assert provider_orm.api_key == "gsk-updated-key-456"
|
||||
assert provider_orm.api_key_enc is not None
|
||||
|
||||
# Decrypt and verify
|
||||
decrypted = Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext()
|
||||
assert decrypted == "gsk-updated-key-456"
|
||||
|
||||
@@ -174,9 +161,9 @@ async def test_bedrock_credentials_encryption(provider_manager, default_user, en
|
||||
|
||||
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
|
||||
|
||||
# Verify both keys are accessible
|
||||
assert created_provider.api_key == "secret-access-key-xyz"
|
||||
assert created_provider.access_key == "access-key-id-abc"
|
||||
# Verify both keys are accessible via encrypted fields
|
||||
assert created_provider.api_key_enc.get_plaintext() == "secret-access-key-xyz"
|
||||
assert created_provider.access_key_enc.get_plaintext() == "access-key-id-abc"
|
||||
|
||||
# Read from DB to verify both are encrypted
|
||||
async with db_registry.async_session() as session:
|
||||
@@ -191,8 +178,8 @@ async def test_bedrock_credentials_encryption(provider_manager, default_user, en
|
||||
assert provider_orm.access_key_enc is not None
|
||||
|
||||
# Verify encrypted values are different from plaintext
|
||||
assert provider_orm.api_key_enc != "secret-access-key-xyz"
|
||||
assert provider_orm.access_key_enc != "access-key-id-abc"
|
||||
assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == "secret-access-key-xyz"
|
||||
assert Secret.from_encrypted(provider_orm.access_key_enc).get_plaintext() == "access-key-id-abc"
|
||||
|
||||
# Test the manager method for getting Bedrock credentials
|
||||
access_key, secret_key, region = await provider_manager.get_bedrock_credentials_async("test-bedrock-provider", actor=default_user)
|
||||
@@ -215,7 +202,7 @@ async def test_provider_secret_not_exposed_in_logs(provider_manager, default_use
|
||||
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
|
||||
|
||||
# Get the Secret object
|
||||
api_key_secret = created_provider.get_api_key_secret()
|
||||
api_key_secret = created_provider.api_key_enc
|
||||
|
||||
# Verify string representation doesn't expose the key
|
||||
secret_str = str(api_key_secret)
|
||||
@@ -240,19 +227,19 @@ async def test_provider_pydantic_to_orm_serialization(provider_manager, default_
|
||||
|
||||
# Step 1: Create provider (Pydantic → ORM)
|
||||
created_provider = await provider_manager.create_provider_async(provider_create, actor=default_user)
|
||||
original_api_key = created_provider.api_key
|
||||
original_api_key = created_provider.api_key_enc.get_plaintext()
|
||||
|
||||
# Step 2: Read provider back (ORM → Pydantic)
|
||||
retrieved_provider = await provider_manager.get_provider_async(created_provider.id, actor=default_user)
|
||||
|
||||
# Verify data integrity
|
||||
assert retrieved_provider.api_key == original_api_key
|
||||
assert retrieved_provider.api_key_enc.get_plaintext() == original_api_key
|
||||
assert retrieved_provider.name == "test-roundtrip-provider"
|
||||
assert retrieved_provider.provider_type == ProviderType.openai
|
||||
assert retrieved_provider.base_url == "https://api.openai.com/v1"
|
||||
|
||||
# Verify Secret object works correctly
|
||||
api_key_secret = retrieved_provider.get_api_key_secret()
|
||||
api_key_secret = retrieved_provider.api_key_enc
|
||||
assert api_key_secret.get_plaintext() == original_api_key
|
||||
|
||||
# Step 3: Convert to ORM again (should preserve encrypted field)
|
||||
@@ -261,7 +248,7 @@ async def test_provider_pydantic_to_orm_serialization(provider_manager, default_
|
||||
# Verify encrypted field is in the ORM data
|
||||
assert "api_key_enc" in orm_data
|
||||
assert orm_data["api_key_enc"] is not None
|
||||
assert orm_data["api_key"] == original_api_key
|
||||
assert Secret.from_encrypted(orm_data["api_key_enc"]).get_plaintext() == original_api_key
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -290,8 +277,8 @@ async def test_provider_with_none_api_key(provider_manager, default_user, encryp
|
||||
)
|
||||
|
||||
# api_key_enc should handle empty string appropriately
|
||||
# (encrypt empty string or store as None)
|
||||
assert provider_orm.api_key_enc is not None or provider_orm.api_key == ""
|
||||
assert provider_orm.api_key_enc is not None
|
||||
assert Secret.from_encrypted(provider_orm.api_key_enc).get_plaintext() == ""
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -316,9 +303,7 @@ async def test_list_providers_decrypts_all(provider_manager, default_user, encry
|
||||
# Verify all are decrypted correctly
|
||||
assert len(test_providers) == 3
|
||||
for i, provider in enumerate(sorted(test_providers, key=lambda p: p.name)):
|
||||
assert provider.api_key == f"sk-key-{i}"
|
||||
# Verify Secret getter works
|
||||
secret = provider.get_api_key_secret()
|
||||
secret = provider.api_key_enc
|
||||
assert secret.get_plaintext() == f"sk-key-{i}"
|
||||
|
||||
|
||||
|
||||
@@ -16,13 +16,14 @@ from letta.schemas.providers import (
|
||||
TogetherProvider,
|
||||
VLLMProvider,
|
||||
)
|
||||
from letta.schemas.secret import Secret
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
def test_openai():
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
models = provider.list_llm_models()
|
||||
@@ -38,7 +39,7 @@ def test_openai():
|
||||
async def test_openai_async():
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
@@ -54,7 +55,7 @@ async def test_openai_async():
|
||||
async def test_anthropic():
|
||||
provider = AnthropicProvider(
|
||||
name="anthropic",
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key),
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
assert len(models) > 0
|
||||
@@ -67,7 +68,7 @@ async def test_googleai():
|
||||
assert api_key is not None
|
||||
provider = GoogleAIProvider(
|
||||
name="google_ai",
|
||||
api_key=api_key,
|
||||
api_key_enc=Secret.from_plaintext(api_key),
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
assert len(models) > 0
|
||||
@@ -97,7 +98,7 @@ async def test_google_vertex():
|
||||
@pytest.mark.skipif(model_settings.deepseek_api_key is None, reason="Only run if DEEPSEEK_API_KEY is set.")
|
||||
@pytest.mark.asyncio
|
||||
async def test_deepseek():
|
||||
provider = DeepSeekProvider(name="deepseek", api_key=model_settings.deepseek_api_key)
|
||||
provider = DeepSeekProvider(name="deepseek", api_key_enc=Secret.from_plaintext(model_settings.deepseek_api_key))
|
||||
models = await provider.list_llm_models_async()
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
@@ -108,7 +109,7 @@ async def test_deepseek():
|
||||
async def test_groq():
|
||||
provider = GroqProvider(
|
||||
name="groq",
|
||||
api_key=model_settings.groq_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.groq_api_key),
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
assert len(models) > 0
|
||||
@@ -120,7 +121,7 @@ async def test_groq():
|
||||
async def test_azure():
|
||||
provider = AzureProvider(
|
||||
name="azure",
|
||||
api_key=model_settings.azure_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.azure_api_key),
|
||||
base_url=model_settings.azure_base_url,
|
||||
api_version=model_settings.azure_api_version,
|
||||
)
|
||||
@@ -138,7 +139,7 @@ async def test_azure():
|
||||
async def test_together():
|
||||
provider = TogetherProvider(
|
||||
name="together",
|
||||
api_key=model_settings.together_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.together_api_key),
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
@@ -161,7 +162,6 @@ async def test_ollama():
|
||||
provider = OllamaProvider(
|
||||
name="ollama",
|
||||
base_url=model_settings.ollama_base_url,
|
||||
api_key=None,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
@@ -203,7 +203,7 @@ async def test_vllm():
|
||||
async def test_custom_anthropic():
|
||||
provider = AnthropicProvider(
|
||||
name="custom_anthropic",
|
||||
api_key=model_settings.anthropic_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.anthropic_api_key),
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
assert len(models) > 0
|
||||
@@ -214,7 +214,7 @@ def test_provider_context_window():
|
||||
"""Test that providers implement context window methods correctly."""
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
|
||||
@@ -230,7 +230,7 @@ async def test_provider_context_window_async():
|
||||
"""Test that providers implement async context window methods correctly."""
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
|
||||
@@ -244,7 +244,7 @@ def test_provider_handle_generation():
|
||||
"""Test that providers generate handles correctly."""
|
||||
provider = OpenAIProvider(
|
||||
name="test_openai",
|
||||
api_key="test_key",
|
||||
api_key_enc=Secret.from_plaintext("test_key"),
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
|
||||
@@ -266,14 +266,14 @@ def test_provider_casting():
|
||||
name="test_provider",
|
||||
provider_type=ProviderType.openai,
|
||||
provider_category=ProviderCategory.base,
|
||||
api_key="test_key",
|
||||
api_key_enc=Secret.from_plaintext("test_key"),
|
||||
base_url="https://api.openai.com/v1",
|
||||
)
|
||||
|
||||
cast_provider = base_provider.cast_to_subtype()
|
||||
assert isinstance(cast_provider, OpenAIProvider)
|
||||
assert cast_provider.name == "test_provider"
|
||||
assert cast_provider.api_key == "test_key"
|
||||
assert cast_provider.api_key_enc.get_plaintext() == "test_key"
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@@ -281,7 +281,7 @@ async def test_provider_embedding_models_consistency():
|
||||
"""Test that providers return consistent embedding model formats."""
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
|
||||
@@ -301,7 +301,7 @@ async def test_provider_llm_models_consistency():
|
||||
"""Test that providers return consistent LLM model formats."""
|
||||
provider = OpenAIProvider(
|
||||
name="openai",
|
||||
api_key=model_settings.openai_api_key,
|
||||
api_key_enc=Secret.from_plaintext(model_settings.openai_api_key),
|
||||
base_url=model_settings.openai_api_base,
|
||||
)
|
||||
|
||||
|
||||
Reference in New Issue
Block a user