feat: byok provider models in db also (#8317)

* feat: byok provider models in db also

* make tests and sync api

* fix inconsistent state with recreating provider of same name

* fix sync on byok creation

* update revision

* move stripe code for testing purposes

* revert

* add refresh byok models endpoint

* just stage publish api

* add tests

* reorder revision

* add test for name clashes
This commit is contained in:
Ari Webb
2026-01-20 17:27:37 -08:00
committed by Caren Thomas
parent fa92f711fe
commit 4ec6649caf
8 changed files with 846 additions and 82 deletions

View File

@@ -0,0 +1,31 @@
"""last_synced column for providers
Revision ID: 308a180244fc
Revises: 82feb220a9b8
Create Date: 2026-01-05 18:54:15.996786
"""
from typing import Sequence, Union
import sqlalchemy as sa
from alembic import op
# revision identifiers, used by Alembic.
revision: str = "308a180244fc"
down_revision: Union[str, None] = "82feb220a9b8"
branch_labels: Union[str, Sequence[str], None] = None
depends_on: Union[str, Sequence[str], None] = None
def upgrade() -> None:
    """Apply the migration: add a nullable, timezone-aware ``last_synced``
    timestamp column to the ``providers`` table."""
    last_synced_column = sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True)
    op.add_column("providers", last_synced_column)
def downgrade() -> None:
    """Revert the migration: drop the ``last_synced`` column from ``providers``.

    Safe to drop without a data migration — the column is nullable and only
    holds sync bookkeeping, not user data.
    """
    # ### commands auto generated by Alembic - please adjust! ###
    op.drop_column("providers", "last_synced")
    # ### end Alembic commands ###

View File

@@ -14826,6 +14826,53 @@
}
}
},
"/v1/providers/{provider_id}/refresh": {
"patch": {
"tags": ["providers"],
"summary": "Refresh Provider Models",
"description": "Refresh models for a BYOK provider by querying the provider's API.\nAdds new models and removes ones no longer available.",
"operationId": "refresh_provider_models",
"parameters": [
{
"name": "provider_id",
"in": "path",
"required": true,
"schema": {
"type": "string",
"minLength": 45,
"maxLength": 45,
"pattern": "^provider-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$",
"description": "The ID of the provider in the format 'provider-<uuid4>'",
"examples": ["provider-123e4567-e89b-42d3-8456-426614174000"],
"title": "Provider Id"
},
"description": "The ID of the provider in the format 'provider-<uuid4>'"
}
],
"responses": {
"200": {
"description": "Successful Response",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/Provider"
}
}
}
},
"422": {
"description": "Validation Error",
"content": {
"application/json": {
"schema": {
"$ref": "#/components/schemas/HTTPValidationError"
}
}
}
}
}
}
},
"/v1/runs/": {
"get": {
"tags": ["runs"],
@@ -38973,6 +39020,19 @@
"title": "Updated At",
"description": "The last update timestamp of the provider."
},
"last_synced": {
"anyOf": [
{
"type": "string",
"format": "date-time"
},
{
"type": "null"
}
],
"title": "Last Synced",
"description": "The last time models were synced for this provider."
},
"api_key_enc": {
"anyOf": [
{

View File

@@ -1,6 +1,7 @@
from datetime import datetime
from typing import TYPE_CHECKING, Optional
from sqlalchemy import ForeignKey, String, Text, UniqueConstraint
from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint
from sqlalchemy.orm import Mapped, mapped_column, relationship
from letta.orm.mixins import OrganizationMixin
@@ -41,6 +42,11 @@ class Provider(SqlalchemyBase, OrganizationMixin):
api_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted API key or secret key for the provider.")
access_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access key for the provider.")
# sync tracking
last_synced: Mapped[Optional[datetime]] = mapped_column(
DateTime(timezone=True), nullable=True, doc="Last time models were synced for this provider."
)
# relationships
organization: Mapped["Organization"] = relationship("Organization", back_populates="providers")
models: Mapped[list["ProviderModel"]] = relationship("ProviderModel", back_populates="provider", cascade="all, delete-orphan")

View File

@@ -32,6 +32,7 @@ class Provider(ProviderBase):
api_version: str | None = Field(None, description="API version used for requests to the provider.")
organization_id: str | None = Field(None, description="The organization id of the user")
updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.")
last_synced: datetime | None = Field(None, description="The last time models were synced for this provider.")
# Encrypted fields (stored as Secret objects, serialized to strings for DB)
# Secret class handles validation and serialization automatically via __get_pydantic_core_schema__

View File

@@ -1,6 +1,6 @@
from typing import TYPE_CHECKING, List, Literal, Optional
from fastapi import APIRouter, Body, Depends, Query, status
from fastapi import APIRouter, Body, Depends, HTTPException, Query, status
from fastapi.responses import JSONResponse
from letta.schemas.enums import ProviderCategory, ProviderType
@@ -144,6 +144,27 @@ async def check_existing_provider(
)
@router.patch("/{provider_id}/refresh", response_model=Provider, operation_id="refresh_provider_models")
async def refresh_provider_models(
    provider_id: ProviderId,
    headers: HeaderParams = Depends(get_headers),
    server: "SyncServer" = Depends(get_letta_server),
):
    """
    Refresh models for a BYOK provider by querying the provider's API.
    Adds new models and removes ones no longer available.
    """
    manager = server.provider_manager
    requesting_actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id)

    target = await manager.get_provider_async(provider_id=provider_id, actor=requesting_actor)
    # Base providers are env-managed; refresh applies to BYOK providers only.
    if target.provider_category != ProviderCategory.byok:
        raise HTTPException(status_code=400, detail="Refresh is only supported for BYOK providers")

    # Re-sync the model catalog, then re-read so the response carries the
    # freshly stamped last_synced value.
    await manager._sync_default_models_for_provider(target, requesting_actor)
    return await manager.get_provider_async(provider_id=provider_id, actor=requesting_actor)
@router.delete("/{provider_id}", response_model=None, operation_id="delete_provider")
async def delete_provider(
provider_id: ProviderId,

View File

@@ -466,6 +466,8 @@ class SyncServer(object):
embedding_models=embedding_models,
organization_id=None, # Global models
)
# Update last_synced timestamp
await self.provider_manager.update_provider_last_synced_async(persisted_provider.id)
logger.info(
f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}"
)
@@ -1177,7 +1179,7 @@ class SyncServer(object):
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url or model.model_endpoint_type,
model_endpoint=provider.base_url,
context_window=model.max_context_window or 16384,
handle=model.handle,
provider_name=provider.name,
@@ -1185,7 +1187,7 @@ class SyncServer(object):
)
llm_models.append(llm_config)
# Get BYOK provider models by hitting provider endpoints directly
# Get BYOK provider models - sync if not synced yet, then read from DB
if include_byok:
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
@@ -1196,9 +1198,37 @@ class SyncServer(object):
for provider in byok_providers:
try:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
llm_models.extend(models)
# Sync models if not synced yet
if provider.last_synced is None:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
embedding_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
provider=provider,
llm_models=models,
embedding_models=embedding_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id)
# Read from database
provider_llm_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="llm",
provider_id=provider.id,
enabled=True,
)
for model in provider_llm_models:
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url,
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
handle=model.handle,
provider_name=provider.name,
provider_category=ProviderCategory.byok,
)
llm_models.append(llm_config)
except Exception as e:
logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}")
@@ -1240,7 +1270,7 @@ class SyncServer(object):
)
embedding_models.append(embedding_config)
# Get BYOK provider models by hitting provider endpoints directly
# Get BYOK provider models - sync if not synced yet, then read from DB
byok_providers = await self.provider_manager.list_providers_async(
actor=actor,
provider_category=[ProviderCategory.byok],
@@ -1248,9 +1278,36 @@ class SyncServer(object):
for provider in byok_providers:
try:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_embedding_models_async()
embedding_models.extend(models)
# Sync models if not synced yet
if provider.last_synced is None:
typed_provider = provider.cast_to_subtype()
llm_models = await typed_provider.list_llm_models_async()
emb_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
provider=provider,
llm_models=llm_models,
embedding_models=emb_models,
organization_id=provider.organization_id,
)
await self.provider_manager.update_provider_last_synced_async(provider.id)
# Read from database
provider_embedding_models = await self.provider_manager.list_models_async(
actor=actor,
model_type="embedding",
provider_id=provider.id,
enabled=True,
)
for model in provider_embedding_models:
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=provider.base_url or model.model_endpoint_type,
embedding_dim=model.embedding_dim or 1536,
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,
)
embedding_models.append(embedding_config)
except Exception as e:
logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}")

View File

@@ -98,9 +98,30 @@ class ProviderManager:
deleted_provider.access_key_enc = access_key_secret.get_encrypted()
await deleted_provider.update_async(session, actor=actor)
# Also restore any soft-deleted models associated with this provider
# This is needed because the unique constraint on provider_models doesn't include is_deleted,
# so soft-deleted models would block creation of new models with the same handle
from sqlalchemy import update
restore_models_stmt = (
update(ProviderModelORM)
.where(
and_(
ProviderModelORM.provider_id == deleted_provider.id,
ProviderModelORM.is_deleted == True,
)
)
.values(is_deleted=False)
)
result = await session.execute(restore_models_stmt)
if result.rowcount > 0:
logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'")
provider_pydantic = deleted_provider.to_pydantic()
# For BYOK providers, automatically sync available models
# This will add any new models and remove any that are no longer available
if is_byok:
await self._sync_default_models_for_provider(provider_pydantic, actor)
@@ -201,6 +222,17 @@ class ProviderManager:
await existing_provider.update_async(session, actor=actor)
return existing_provider.to_pydantic()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
async def update_provider_last_synced_async(self, provider_id: str) -> None:
    """Stamp the provider's ``last_synced`` column with the current UTC time."""
    from datetime import datetime, timezone

    synced_at = datetime.now(timezone.utc)
    async with db_registry.async_session() as session:
        # actor=None: sync bookkeeping is internal and not scoped to a user
        orm_provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=None)
        orm_provider.last_synced = synced_at
        await session.commit()
@enforce_types
@raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER)
@trace_method
@@ -476,81 +508,19 @@ class ProviderManager:
async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None:
"""Sync models for a newly created BYOK provider by querying the provider's API."""
from letta.log import get_logger
logger = get_logger(__name__)
try:
# Get the provider class and create an instance
from letta.schemas.enums import ProviderType
from letta.schemas.providers.anthropic import AnthropicProvider
from letta.schemas.providers.azure import AzureProvider
from letta.schemas.providers.bedrock import BedrockProvider
from letta.schemas.providers.google_gemini import GoogleAIProvider
from letta.schemas.providers.groq import GroqProvider
from letta.schemas.providers.ollama import OllamaProvider
from letta.schemas.providers.openai import OpenAIProvider
from letta.schemas.providers.zai import ZAIProvider
# Use cast_to_subtype() which properly handles all provider types and preserves api_key_enc
typed_provider = provider.cast_to_subtype()
llm_models = await typed_provider.list_llm_models_async()
embedding_models = await typed_provider.list_embedding_models_async()
# ChatGPT OAuth requires cast_to_subtype to preserve api_key_enc and id
# (needed for OAuth token refresh and database persistence)
if provider.provider_type == ProviderType.chatgpt_oauth:
provider_instance = provider.cast_to_subtype()
else:
provider_type_to_class = {
"openai": OpenAIProvider,
"anthropic": AnthropicProvider,
"groq": GroqProvider,
"google": GoogleAIProvider,
"ollama": OllamaProvider,
"bedrock": BedrockProvider,
"azure": AzureProvider,
"zai": ZAIProvider,
}
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
provider_class = provider_type_to_class.get(provider_type)
if not provider_class:
logger.warning(f"No provider class found for type '{provider_type}'")
return
# Create provider instance with necessary parameters
api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None
access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None
kwargs = {
"name": provider.name,
"api_key": api_key,
"provider_category": provider.provider_category,
}
if provider.base_url:
kwargs["base_url"] = provider.base_url
if access_key:
kwargs["access_key"] = access_key
if provider.region:
kwargs["region"] = provider.region
if provider.api_version:
kwargs["api_version"] = provider.api_version
provider_instance = provider_class(**kwargs)
# Query the provider's API for available models
llm_models = await provider_instance.list_llm_models_async()
embedding_models = await provider_instance.list_embedding_models_async()
# Update handles and provider_name for BYOK providers
for model in llm_models:
model.provider_name = provider.name
model.handle = f"{provider.name}/{model.model}"
model.provider_category = provider.provider_category
for model in embedding_models:
model.handle = f"{provider.name}/{model.embedding_model}"
# Use existing sync_provider_models_async to save to database
await self.sync_provider_models_async(
provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=actor.organization_id
provider=provider,
llm_models=llm_models,
embedding_models=embedding_models,
organization_id=actor.organization_id,
)
await self.update_provider_last_synced_async(provider.id)
except Exception as e:
logger.error(f"Failed to sync models for provider '{provider.name}': {e}")

View File

@@ -2292,6 +2292,7 @@ async def test_server_list_llm_models_byok_from_provider_api(default_user, provi
# Create a mock typed provider that returns our test models
mock_typed_provider = MagicMock()
mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models)
mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[])
# Patch cast_to_subtype on the Provider class to return our mock
with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
@@ -2525,3 +2526,620 @@ async def test_create_agent_with_byok_handle_dynamic_fetch(default_user, provide
# Cleanup
await server.agent_manager.delete_agent_async(agent_id=agent.id, actor=default_user)
@pytest.mark.asyncio
async def test_byok_provider_last_synced_triggers_sync_when_null(default_user, provider_manager):
    """Test that BYOK providers with last_synced=null trigger a sync on first model listing.

    Lazy-sync path: a freshly created BYOK provider has ``last_synced`` unset,
    so listing LLM models must fetch from the provider API (via the patched
    ``cast_to_subtype``), persist the result, and stamp ``last_synced``.
    """
    from letta.schemas.providers import Provider
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a BYOK provider (last_synced will be null by default)
    byok_provider_create = ProviderCreate(
        name=f"test-byok-sync-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Verify last_synced is null initially
    assert byok_provider.last_synced is None

    # Create server
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    # No base providers enabled, so listing only exercises the BYOK path.
    server._enabled_providers = []

    # Mock the BYOK provider's list_llm_models_async to return test models
    mock_byok_models = [
        LLMConfig(
            model=f"byok-gpt-4o-{test_id}",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=64000,
            handle=f"test-byok-sync-{test_id}/gpt-4o",
            provider_name=byok_provider.name,
            provider_category=ProviderCategory.byok,
        )
    ]
    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models)
    mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[])

    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        # List BYOK models - should trigger sync because last_synced is null
        # (return value is not asserted directly; DB state is checked below)
        byok_models = await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )

    # Verify sync was triggered (cast_to_subtype was called to fetch from API)
    # Note: may be called multiple times if other BYOK providers exist in DB
    mock_typed_provider.list_llm_models_async.assert_called()

    # Verify last_synced was updated for our provider
    updated_providers = await provider_manager.list_providers_async(name=byok_provider.name, actor=default_user)
    assert len(updated_providers) == 1
    assert updated_providers[0].last_synced is not None

    # Verify models were synced to database
    synced_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(synced_models) == 1
    assert synced_models[0].name == f"byok-gpt-4o-{test_id}"
@pytest.mark.asyncio
async def test_byok_provider_last_synced_skips_sync_when_set(default_user, provider_manager):
    """Test that BYOK providers with last_synced set skip sync and read from DB.

    Once ``last_synced`` is stamped, listing models must serve the cached DB
    rows and never call out to the provider API.
    """
    # Fix: removed unused `from datetime import datetime, timezone` — the test
    # never referenced either name.
    from letta.schemas.providers import Provider
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a BYOK provider
    byok_provider_create = ProviderCreate(
        name=f"test-byok-cached-{test_id}",
        provider_type=ProviderType.openai,
        api_key="sk-byok-key",
    )
    byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Manually sync models to DB
    cached_model = LLMConfig(
        model=f"cached-gpt-4o-{test_id}",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=64000,
        handle=f"test-byok-cached-{test_id}/gpt-4o",
        provider_name=byok_provider.name,
        provider_category=ProviderCategory.byok,
    )
    await provider_manager.sync_provider_models_async(
        provider=byok_provider,
        llm_models=[cached_model],
        embedding_models=[],
        organization_id=default_user.organization_id,
    )
    # Set last_synced to indicate models are already synced
    await provider_manager.update_provider_last_synced_async(byok_provider.id)

    # Create server
    server = SyncServer(init_with_default_org_and_user=False)
    server.default_user = default_user
    server.provider_manager = provider_manager
    server._enabled_providers = []

    # Mock cast_to_subtype - should NOT be called since last_synced is set
    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=[])
    mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[])

    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        # List BYOK models - should read from DB, not trigger sync
        byok_models = await server.list_llm_models_async(
            actor=default_user,
            provider_category=[ProviderCategory.byok],
        )

    # Verify sync was NOT triggered (cast_to_subtype should not be called)
    mock_typed_provider.list_llm_models_async.assert_not_called()

    # Verify we got the cached model from DB
    byok_handles = [m.handle for m in byok_models]
    assert f"test-byok-cached-{test_id}/gpt-4o" in byok_handles
@pytest.mark.asyncio
async def test_base_provider_updates_last_synced_on_sync(default_user, provider_manager):
    """Test that base provider sync updates the last_synced timestamp.

    Creates a base (env-managed) provider, syncs one model, stamps
    ``last_synced`` explicitly, and verifies the timestamp is persisted.
    """
    # Fix: removed unused `from letta.server.server import SyncServer` — this
    # test exercises the manager directly and never constructs a server.
    test_id = generate_test_id()

    # Create a base provider
    base_provider_create = ProviderCreate(
        name=f"test-base-sync-{test_id}",
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)

    # Verify last_synced is null initially
    assert base_provider.last_synced is None

    # Sync models for the base provider
    base_model = LLMConfig(
        model=f"base-gpt-4o-{test_id}",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=64000,
        handle=f"test-base-sync-{test_id}/gpt-4o",
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[base_model],
        embedding_models=[],
        organization_id=None,  # base provider models are global, not org-scoped
    )
    await provider_manager.update_provider_last_synced_async(base_provider.id)

    # Verify last_synced was updated
    updated_providers = await provider_manager.list_providers_async(name=base_provider.name, actor=default_user)
    assert len(updated_providers) == 1
    assert updated_providers[0].last_synced is not None
@pytest.mark.asyncio
async def test_byok_provider_models_synced_on_creation(default_user, provider_manager):
    """Test that models are automatically synced when a BYOK provider is created.

    When create_provider_async is called with is_byok=True, it should:
    1. Create the provider in the database
    2. Call _sync_default_models_for_provider to fetch and persist models from the provider API
    3. Update last_synced timestamp
    """
    from letta.schemas.providers import Provider

    test_id = generate_test_id()

    # Mock models that the provider API would return
    mock_llm_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-byok-creation-{test_id}/gpt-4o",
            provider_name=f"test-byok-creation-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
        LLMConfig(
            model="gpt-4o-mini",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-byok-creation-{test_id}/gpt-4o-mini",
            provider_name=f"test-byok-creation-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]
    mock_embedding_models = [
        EmbeddingConfig(
            embedding_model="text-embedding-3-small",
            embedding_endpoint_type="openai",
            embedding_endpoint="https://api.openai.com/v1",
            embedding_dim=1536,
            embedding_chunk_size=300,
            handle=f"test-byok-creation-{test_id}/text-embedding-3-small",
        ),
    ]

    # Create a mock typed provider that returns our test models
    mock_typed_provider = MagicMock()
    mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_llm_models)
    mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=mock_embedding_models)

    # Patch cast_to_subtype to return our mock when _sync_default_models_for_provider is called
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider):
        # Create the BYOK provider - this should automatically sync models
        byok_provider_create = ProviderCreate(
            name=f"test-byok-creation-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        )
        byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Verify the provider API was called during creation
    mock_typed_provider.list_llm_models_async.assert_called_once()
    mock_typed_provider.list_embedding_models_async.assert_called_once()

    # Re-fetch the provider to get the updated last_synced value
    # (the returned object from create_provider_async is stale since last_synced is set after)
    byok_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)

    # Verify last_synced was set (indicating sync completed)
    assert byok_provider.last_synced is not None

    # Verify LLM models were persisted to the database
    synced_llm_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(synced_llm_models) == 2
    synced_llm_names = {m.name for m in synced_llm_models}
    assert "gpt-4o" in synced_llm_names
    assert "gpt-4o-mini" in synced_llm_names

    # Verify embedding models were persisted to the database
    synced_embedding_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="embedding",
        provider_id=byok_provider.id,
    )
    assert len(synced_embedding_models) == 1
    assert synced_embedding_models[0].name == "text-embedding-3-small"
@pytest.mark.asyncio
async def test_refresh_byok_provider_adds_new_models(default_user, provider_manager):
    """Test that refreshing a BYOK provider adds new models from the provider API.

    When _sync_default_models_for_provider is called (via refresh endpoint):
    1. It should fetch current models from the provider API
    2. Add any new models that weren't previously synced
    3. Update the last_synced timestamp
    """
    from letta.schemas.providers import Provider

    test_id = generate_test_id()

    # Initial models when provider is created
    initial_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-add-{test_id}/gpt-4o",
            provider_name=f"test-refresh-add-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]
    # Updated models after refresh (includes a new model)
    updated_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-add-{test_id}/gpt-4o",
            provider_name=f"test-refresh-add-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
        LLMConfig(
            model="gpt-4.1",  # New model added by provider
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=256000,
            handle=f"test-refresh-add-{test_id}/gpt-4.1",
            provider_name=f"test-refresh-add-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]

    # Create mock for initial sync during provider creation
    mock_typed_provider_initial = MagicMock()
    mock_typed_provider_initial.list_llm_models_async = AsyncMock(return_value=initial_models)
    mock_typed_provider_initial.list_embedding_models_async = AsyncMock(return_value=[])

    # Create the provider with initial models
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_initial):
        byok_provider_create = ProviderCreate(
            name=f"test-refresh-add-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        )
        byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Re-fetch the provider to get the updated last_synced value
    byok_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)

    # Verify initial sync - should have 1 model
    initial_synced_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(initial_synced_models) == 1
    assert initial_synced_models[0].name == "gpt-4o"
    # Capture the creation-time timestamp so the refresh can be shown to advance it.
    initial_last_synced = byok_provider.last_synced
    assert initial_last_synced is not None  # Verify sync happened during creation

    # Create mock for refresh with updated models
    mock_typed_provider_refresh = MagicMock()
    mock_typed_provider_refresh.list_llm_models_async = AsyncMock(return_value=updated_models)
    mock_typed_provider_refresh.list_embedding_models_async = AsyncMock(return_value=[])

    # Refresh the provider (simulating what the endpoint does)
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_refresh):
        await provider_manager._sync_default_models_for_provider(byok_provider, default_user)

    # Verify the API was called during refresh
    mock_typed_provider_refresh.list_llm_models_async.assert_called_once()

    # Verify new model was added
    refreshed_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(refreshed_models) == 2
    refreshed_names = {m.name for m in refreshed_models}
    assert "gpt-4o" in refreshed_names
    assert "gpt-4.1" in refreshed_names

    # Verify last_synced was updated
    updated_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)
    assert updated_provider.last_synced is not None
    assert updated_provider.last_synced >= initial_last_synced
@pytest.mark.asyncio
async def test_refresh_byok_provider_removes_old_models(default_user, provider_manager):
    """Test that refreshing a BYOK provider removes models no longer available from the provider API.

    When _sync_default_models_for_provider is called (via refresh endpoint):
    1. It should fetch current models from the provider API
    2. Remove any models that are no longer available (soft delete)
    3. Keep models that are still available
    """
    from letta.schemas.providers import Provider

    test_id = generate_test_id()

    # Initial models when provider is created (includes a model that will be removed)
    initial_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-remove-{test_id}/gpt-4o",
            provider_name=f"test-refresh-remove-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
        LLMConfig(
            model="gpt-4-turbo",  # This model will be deprecated/removed
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-remove-{test_id}/gpt-4-turbo",
            provider_name=f"test-refresh-remove-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]
    # Updated models after refresh (gpt-4-turbo is no longer available)
    updated_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-remove-{test_id}/gpt-4o",
            provider_name=f"test-refresh-remove-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]

    # Create mock for initial sync during provider creation
    mock_typed_provider_initial = MagicMock()
    mock_typed_provider_initial.list_llm_models_async = AsyncMock(return_value=initial_models)
    mock_typed_provider_initial.list_embedding_models_async = AsyncMock(return_value=[])

    # Create the provider with initial models
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_initial):
        byok_provider_create = ProviderCreate(
            name=f"test-refresh-remove-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        )
        byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Verify initial sync - should have 2 models
    initial_synced_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(initial_synced_models) == 2
    initial_names = {m.name for m in initial_synced_models}
    assert "gpt-4o" in initial_names
    assert "gpt-4-turbo" in initial_names

    # Create mock for refresh with fewer models
    mock_typed_provider_refresh = MagicMock()
    mock_typed_provider_refresh.list_llm_models_async = AsyncMock(return_value=updated_models)
    mock_typed_provider_refresh.list_embedding_models_async = AsyncMock(return_value=[])

    # Refresh the provider (simulating what the endpoint does)
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_refresh):
        await provider_manager._sync_default_models_for_provider(byok_provider, default_user)

    # Verify the removed model is no longer in the list
    # (list_models_async is presumed to exclude soft-deleted rows — see assertions below)
    refreshed_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(refreshed_models) == 1
    assert refreshed_models[0].name == "gpt-4o"

    # Verify gpt-4-turbo was removed (soft deleted)
    refreshed_names = {m.name for m in refreshed_models}
    assert "gpt-4-turbo" not in refreshed_names
@pytest.mark.asyncio
async def test_refresh_base_provider_fails(default_user, provider_manager):
    """Ensure the refresh endpoint rejects non-BYOK providers.

    Refreshing models is only meaningful for BYOK providers; base providers
    are driven by environment variables, so the endpoint must respond with a
    400 error instead of attempting a sync.
    """
    from fastapi import HTTPException
    from letta.server.rest_api.routers.v1.providers import refresh_provider_models
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a base (non-BYOK) provider; these carry no API key of their own.
    provider_request = ProviderCreate(
        name=f"test-base-refresh-{test_id}",
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    provider = await provider_manager.create_provider_async(provider_request, actor=default_user, is_byok=False)
    assert provider.provider_category == ProviderCategory.base

    # Stand up a minimal server whose user lookup resolves to the test actor.
    server = SyncServer(init_with_default_org_and_user=False)
    server.provider_manager = provider_manager
    server.user_manager = MagicMock()
    server.user_manager.get_actor_or_default_async = AsyncMock(return_value=default_user)

    # Fake request headers carrying the actor id.
    headers = MagicMock()
    headers.actor_id = default_user.id

    # The endpoint must refuse to refresh a base provider with a 400.
    with pytest.raises(HTTPException) as exc_info:
        await refresh_provider_models(
            provider_id=provider.id,
            headers=headers,
            server=server,
        )
    assert exc_info.value.status_code == 400
    assert "BYOK" in exc_info.value.detail
@pytest.mark.asyncio
async def test_get_model_by_handle_prioritizes_byok_over_base(default_user, provider_manager):
    """BYOK models win over base models when their handles collide.

    Legacy data may contain a BYOK provider and a base provider sharing one
    name, producing models with identical handles. BYOK models are
    organization-scoped while base models are global, so a lookup by handle
    must resolve to the org-specific (BYOK) model.
    """
    test_id = generate_test_id()
    provider_name = f"test-duplicate-{test_id}"
    model_handle = f"{provider_name}/gpt-4o"

    # --- Step 1: base provider with a global model (organization_id=None) ---
    base_create = ProviderCreate(
        name=provider_name,
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    base_provider = await provider_manager.create_provider_async(base_create, actor=default_user, is_byok=False)
    assert base_provider.provider_category == ProviderCategory.base

    base_config = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=model_handle,
        provider_name=provider_name,
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[base_config],
        embedding_models=[],
        organization_id=None,  # Global model
    )

    # Sanity check: the global base model is resolvable before any clash.
    found = await provider_manager.get_model_by_handle_async(
        handle=model_handle,
        actor=default_user,
        model_type="llm",
    )
    assert found is not None
    assert found.handle == model_handle
    assert found.organization_id is None  # Global model

    # --- Step 2: BYOK provider with the same name (legacy duplicate) ---
    # Production now blocks this name clash, so write the row directly to the
    # database to simulate pre-existing legacy data.
    from letta.orm.provider import Provider as ProviderORM
    from letta.schemas.providers import Provider as PydanticProvider
    from letta.server.db import db_registry

    byok_pydantic = PydanticProvider(
        name=provider_name,  # Same name as base provider
        provider_type=ProviderType.openai,
        provider_category=ProviderCategory.byok,
        organization_id=default_user.organization_id,
    )
    byok_pydantic.resolve_identifier()
    async with db_registry.async_session() as session:
        row = ProviderORM(**byok_pydantic.model_dump(to_orm=True))
        await row.create_async(session, actor=default_user)
        byok_provider = row.to_pydantic()
    assert byok_provider.provider_category == ProviderCategory.byok

    # Sync an org-specific model whose handle matches the base model's.
    byok_config = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=model_handle,  # Same handle as base model
        provider_name=provider_name,
        provider_category=ProviderCategory.byok,
    )
    await provider_manager.sync_provider_models_async(
        provider=byok_provider,
        llm_models=[byok_config],
        embedding_models=[],
        organization_id=default_user.organization_id,  # Org-specific model
    )

    # --- Step 3: the handle lookup must return the org-scoped BYOK model ---
    resolved = await provider_manager.get_model_by_handle_async(
        handle=model_handle,
        actor=default_user,
        model_type="llm",
    )
    assert resolved is not None
    assert resolved.handle == model_handle
    # Key assertion: the org-specific (BYOK) model wins over the global one.
    assert resolved.organization_id == default_user.organization_id
    assert resolved.provider_id == byok_provider.id