feat: create model table to persist available models (#5835)
---------

Co-authored-by: Ari Webb <arijwebb@gmail.com>
Co-authored-by: Ari Webb <ari@letta.com>
This commit is contained in:
committed by
Caren Thomas
parent
f36845b485
commit
982501f6fa
@@ -320,3 +320,182 @@ async def test_list_providers_decrypts_all(provider_manager, default_user, encry
|
||||
# Verify Secret getter works
|
||||
secret = provider.get_api_key_secret()
|
||||
assert secret.get_plaintext() == f"sk-key-{i}"
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Handle to Config Conversion Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_handle_to_llm_config_conversion(provider_manager, default_user):
    """Round-trip synced models through the handle-based config lookups.

    The test creates a provider, syncs two LLM configs and one embedding
    config for it, and then checks that:

    * ``get_llm_config_from_handle`` returns the values stored for the
      ``gpt-4`` handle,
    * ``get_embedding_config_from_handle`` returns the values stored for the
      ``text-embedding-ada-002`` handle,
    * looking up an unknown handle raises ``NoResultFound``.
    """
    from letta.orm.errors import NoResultFound
    from letta.schemas.embedding_config import EmbeddingConfig
    from letta.schemas.llm_config import LLMConfig

    endpoint = "https://api.openai.com/v1"
    gpt4_handle = "test-handle-provider/gpt-4"
    ada_handle = "test-handle-provider/text-embedding-ada-002"

    # The provider has to exist first: the LLM configs below reference its name.
    provider = await provider_manager.create_provider_async(
        ProviderCreate(
            name="test-handle-provider",
            provider_type=ProviderType.openai,
            api_key="sk-test-handle-key",
            base_url=endpoint,
        ),
        actor=default_user,
    )

    # (model name, context window, handle) for every LLM that gets synced.
    llm_specs = [
        ("gpt-4", 8192, gpt4_handle),
        ("gpt-3.5-turbo", 4096, "test-handle-provider/gpt-3.5-turbo"),
    ]
    llm_models = [
        LLMConfig(
            model=model_name,
            model_endpoint_type="openai",
            model_endpoint=endpoint,
            context_window=window,
            handle=handle,
            provider_name=provider.name,
            provider_category=ProviderCategory.base,
        )
        for model_name, window, handle in llm_specs
    ]

    embedding_models = [
        EmbeddingConfig(
            embedding_model="text-embedding-ada-002",
            embedding_endpoint_type="openai",
            embedding_endpoint=endpoint,
            embedding_dim=1536,
            embedding_chunk_size=300,
            handle=ada_handle,
        )
    ]

    await provider_manager.sync_provider_models_async(
        provider=provider,
        llm_models=llm_models,
        embedding_models=embedding_models,
        organization_id=default_user.organization_id,
    )

    # LLM lookup: every field must match what was synced for gpt-4.
    llm_config = await provider_manager.get_llm_config_from_handle(handle=gpt4_handle, actor=default_user)
    assert llm_config.model == "gpt-4"
    assert llm_config.handle == "test-handle-provider/gpt-4"
    assert llm_config.context_window == 8192
    assert llm_config.model_endpoint == "https://api.openai.com/v1"
    assert llm_config.provider_name == "test-handle-provider"

    # Embedding lookup: same idea for the single embedding model.
    embedding_config = await provider_manager.get_embedding_config_from_handle(handle=ada_handle, actor=default_user)
    assert embedding_config.embedding_model == "text-embedding-ada-002"
    assert embedding_config.handle == "test-handle-provider/text-embedding-ada-002"
    assert embedding_config.embedding_dim == 1536
    assert embedding_config.embedding_chunk_size == 300
    assert embedding_config.embedding_endpoint == "https://api.openai.com/v1"

    # A context window limit override is not exercised here: it would be done
    # at the server level, and the provider_manager method doesn't support
    # context_window_limit directly.

    # A handle that was never synced must surface as NoResultFound.
    with pytest.raises(NoResultFound):
        await provider_manager.get_llm_config_from_handle(handle="nonexistent/model", actor=default_user)
|
||||
|
||||
|
||||
@pytest.mark.asyncio
async def test_byok_provider_auto_syncs_models(provider_manager, default_user, monkeypatch):
    """Check that a newly created BYOK provider ends up with synced models.

    ``provider_manager._sync_default_models_for_provider`` is monkeypatched
    with a stub that takes canned OpenAI-style configs, rewrites them to
    belong to the provider it is given, and passes them to
    ``sync_provider_models_async``. After ``create_provider_async`` the test
    expects ``list_models_async`` to return at least one LLM and one embedding
    model, every handle to start with the provider name, and the ``gpt-4o``
    handle to resolve through ``get_llm_config_from_handle``.
    """
    from letta.schemas.embedding_config import EmbeddingConfig
    from letta.schemas.llm_config import LLMConfig

    endpoint = "https://api.openai.com/v1"

    async def canned_llm_configs():
        # Stand-in for a provider's LLM listing: two base-category OpenAI models.
        specs = [
            ("gpt-4o", 128000, "openai/gpt-4o"),
            ("gpt-4", 8192, "openai/gpt-4"),
        ]
        return [
            LLMConfig(
                model=model_name,
                model_endpoint_type="openai",
                model_endpoint=endpoint,
                context_window=window,
                handle=handle,
                provider_name="openai",
                provider_category=ProviderCategory.base,
            )
            for model_name, window, handle in specs
        ]

    async def canned_embedding_configs():
        # Stand-in for a provider's embedding listing: a single model.
        return [
            EmbeddingConfig(
                embedding_model="text-embedding-ada-002",
                embedding_endpoint_type="openai",
                embedding_endpoint=endpoint,
                embedding_dim=1536,
                embedding_chunk_size=300,
                handle="openai/text-embedding-ada-002",
            )
        ]

    async def fake_default_sync(provider, actor):
        # Replacement for _sync_default_models_for_provider. The parameter
        # names are kept as-is in case the manager calls it by keyword.
        llm_configs = await canned_llm_configs()
        embedding_configs = await canned_embedding_configs()

        # Re-home the canned configs under the provider being synced.
        for llm in llm_configs:
            llm.provider_name = provider.name
            llm.handle = f"{provider.name}/{llm.model}"
            llm.provider_category = provider.provider_category

        for embedding in embedding_configs:
            embedding.handle = f"{provider.name}/{embedding.embedding_model}"

        await provider_manager.sync_provider_models_async(
            provider=provider,
            llm_models=llm_configs,
            embedding_models=embedding_configs,
            organization_id=actor.organization_id,
        )

    monkeypatch.setattr(provider_manager, "_sync_default_models_for_provider", fake_default_sync)

    # Simulates the UI "Add API Key" flow; is_byok is passed explicitly.
    provider = await provider_manager.create_provider_async(
        ProviderCreate(name="my-openai-key", provider_type=ProviderType.openai, api_key="sk-my-personal-key-123"),
        actor=default_user,
        is_byok=True,
    )

    assert provider.name == "my-openai-key"
    assert provider.provider_type == ProviderType.openai

    # Creating the provider should already have populated its models.
    models = await provider_manager.list_models_async(actor=default_user, provider_id=provider.id)

    synced_llms = [entry for entry in models if entry.model_type == "llm"]
    synced_embeddings = [entry for entry in models if entry.model_type == "embedding"]
    assert synced_llms, "No LLM models were synced"
    assert synced_embeddings, "No embedding models were synced"

    # Every handle must be namespaced by the BYOK provider's name.
    handle_prefix = f"{provider.name}/"
    assert all(entry.handle.startswith(handle_prefix) for entry in models)

    # The synced handle must resolve back to an LLM config.
    llm_config = await provider_manager.get_llm_config_from_handle(handle="my-openai-key/gpt-4o", actor=default_user)
    assert llm_config.model == "gpt-4o"
    assert llm_config.provider_name == "my-openai-key"
|
||||
|
||||
@@ -1147,7 +1147,7 @@ class TestAgentFileImport:
|
||||
"""Test basic agent import functionality with embedding override."""
|
||||
agent_file = await agent_serialization_manager.export([test_agent.id], default_user)
|
||||
|
||||
embedding_config_override = await server.get_cached_embedding_config_async(actor=other_user, handle=embedding_handle_override)
|
||||
embedding_config_override = await server.get_embedding_config_from_handle_async(actor=other_user, handle=embedding_handle_override)
|
||||
result = await agent_serialization_manager.import_file(agent_file, other_user, override_embedding_config=embedding_config_override)
|
||||
|
||||
assert result.success
|
||||
|
||||
1742
tests/test_server_providers.py
Normal file
1742
tests/test_server_providers.py
Normal file
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user