From 4ec6649caf133fb17f953bddbe625bca13a191b9 Mon Sep 17 00:00:00 2001 From: Ari Webb Date: Tue, 20 Jan 2026 17:27:37 -0800 Subject: [PATCH] feat: byok provider models in db also (#8317) * feat: byok provider models in db also * make tests and sync api * fix inconsistent state with recreating provider of same name * fix sync on byok creation * update revision * move stripe code for testing purposes * revert * add refresh byok models endpoint * just stage publish api * add tests * reorder revision * add test for name clashes --- ...0244fc_last_synced_column_for_providers.py | 31 + fern/openapi.json | 60 ++ letta/orm/provider.py | 8 +- letta/schemas/providers/base.py | 1 + letta/server/rest_api/routers/v1/providers.py | 23 +- letta/server/server.py | 75 ++- letta/services/provider_manager.py | 112 ++-- tests/test_server_providers.py | 618 ++++++++++++++++++ 8 files changed, 846 insertions(+), 82 deletions(-) create mode 100644 alembic/versions/308a180244fc_last_synced_column_for_providers.py diff --git a/alembic/versions/308a180244fc_last_synced_column_for_providers.py b/alembic/versions/308a180244fc_last_synced_column_for_providers.py new file mode 100644 index 00000000..03aa169c --- /dev/null +++ b/alembic/versions/308a180244fc_last_synced_column_for_providers.py @@ -0,0 +1,31 @@ +"""last_synced column for providers + +Revision ID: 308a180244fc +Revises: 82feb220a9b8 +Create Date: 2026-01-05 18:54:15.996786 + +""" + +from typing import Sequence, Union + +import sqlalchemy as sa + +from alembic import op + +# revision identifiers, used by Alembic. +revision: str = "308a180244fc" +down_revision: Union[str, None] = "82feb220a9b8" +branch_labels: Union[str, Sequence[str], None] = None +depends_on: Union[str, Sequence[str], None] = None + + +def upgrade() -> None: + # ### commands auto generated by Alembic - please adjust! 
### + op.add_column("providers", sa.Column("last_synced", sa.DateTime(timezone=True), nullable=True)) + # ### end Alembic commands ### + + +def downgrade() -> None: + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column("providers", "last_synced") + # ### end Alembic commands ### diff --git a/fern/openapi.json b/fern/openapi.json index 18b61446..3ae41301 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -14826,6 +14826,53 @@ } } }, + "/v1/providers/{provider_id}/refresh": { + "patch": { + "tags": ["providers"], + "summary": "Refresh Provider Models", + "description": "Refresh models for a BYOK provider by querying the provider's API.\nAdds new models and removes ones no longer available.", + "operationId": "refresh_provider_models", + "parameters": [ + { + "name": "provider_id", + "in": "path", + "required": true, + "schema": { + "type": "string", + "minLength": 45, + "maxLength": 45, + "pattern": "^provider-[0-9a-f]{8}-[0-9a-f]{4}-4[0-9a-f]{3}-[89ab][0-9a-f]{3}-[0-9a-f]{12}$", + "description": "The ID of the provider in the format 'provider-'", + "examples": ["provider-123e4567-e89b-42d3-8456-426614174000"], + "title": "Provider Id" + }, + "description": "The ID of the provider in the format 'provider-'" + } + ], + "responses": { + "200": { + "description": "Successful Response", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/Provider" + } + } + } + }, + "422": { + "description": "Validation Error", + "content": { + "application/json": { + "schema": { + "$ref": "#/components/schemas/HTTPValidationError" + } + } + } + } + } + } + }, "/v1/runs/": { "get": { "tags": ["runs"], @@ -38973,6 +39020,19 @@ "title": "Updated At", "description": "The last update timestamp of the provider." }, + "last_synced": { + "anyOf": [ + { + "type": "string", + "format": "date-time" + }, + { + "type": "null" + } + ], + "title": "Last Synced", + "description": "The last time models were synced for this provider." 
+ }, "api_key_enc": { "anyOf": [ { diff --git a/letta/orm/provider.py b/letta/orm/provider.py index bd42a1be..f784caa1 100644 --- a/letta/orm/provider.py +++ b/letta/orm/provider.py @@ -1,6 +1,7 @@ +from datetime import datetime from typing import TYPE_CHECKING, Optional -from sqlalchemy import ForeignKey, String, Text, UniqueConstraint +from sqlalchemy import DateTime, ForeignKey, String, Text, UniqueConstraint from sqlalchemy.orm import Mapped, mapped_column, relationship from letta.orm.mixins import OrganizationMixin @@ -41,6 +42,11 @@ class Provider(SqlalchemyBase, OrganizationMixin): api_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted API key or secret key for the provider.") access_key_enc: Mapped[Optional[str]] = mapped_column(Text, nullable=True, doc="Encrypted access key for the provider.") + # sync tracking + last_synced: Mapped[Optional[datetime]] = mapped_column( + DateTime(timezone=True), nullable=True, doc="Last time models were synced for this provider." 
+ ) + # relationships organization: Mapped["Organization"] = relationship("Organization", back_populates="providers") models: Mapped[list["ProviderModel"]] = relationship("ProviderModel", back_populates="provider", cascade="all, delete-orphan") diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py index 1893ce94..6b5b722d 100644 --- a/letta/schemas/providers/base.py +++ b/letta/schemas/providers/base.py @@ -32,6 +32,7 @@ class Provider(ProviderBase): api_version: str | None = Field(None, description="API version used for requests to the provider.") organization_id: str | None = Field(None, description="The organization id of the user") updated_at: datetime | None = Field(None, description="The last update timestamp of the provider.") + last_synced: datetime | None = Field(None, description="The last time models were synced for this provider.") # Encrypted fields (stored as Secret objects, serialized to strings for DB) # Secret class handles validation and serialization automatically via __get_pydantic_core_schema__ diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 5d5135f9..5c0ae926 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -1,6 +1,6 @@ from typing import TYPE_CHECKING, List, Literal, Optional -from fastapi import APIRouter, Body, Depends, Query, status +from fastapi import APIRouter, Body, Depends, HTTPException, Query, status from fastapi.responses import JSONResponse from letta.schemas.enums import ProviderCategory, ProviderType @@ -144,6 +144,27 @@ async def check_existing_provider( ) +@router.patch("/{provider_id}/refresh", response_model=Provider, operation_id="refresh_provider_models") +async def refresh_provider_models( + provider_id: ProviderId, + headers: HeaderParams = Depends(get_headers), + server: "SyncServer" = Depends(get_letta_server), +): + """ + Refresh models for a BYOK provider by 
querying the provider's API. + Adds new models and removes ones no longer available. + """ + actor = await server.user_manager.get_actor_or_default_async(actor_id=headers.actor_id) + provider = await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor) + + # Only allow refresh for BYOK providers + if provider.provider_category != ProviderCategory.byok: + raise HTTPException(status_code=400, detail="Refresh is only supported for BYOK providers") + + await server.provider_manager._sync_default_models_for_provider(provider, actor) + return await server.provider_manager.get_provider_async(provider_id=provider_id, actor=actor) + + @router.delete("/{provider_id}", response_model=None, operation_id="delete_provider") async def delete_provider( provider_id: ProviderId, diff --git a/letta/server/server.py b/letta/server/server.py index a45c31a8..f3225ecf 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -466,6 +466,8 @@ class SyncServer(object): embedding_models=embedding_models, organization_id=None, # Global models ) + # Update last_synced timestamp + await self.provider_manager.update_provider_last_synced_async(persisted_provider.id) logger.info( f"Synced {len(llm_models)} LLM models and {len(embedding_models)} embedding models for provider {persisted_provider.name}" ) @@ -1177,7 +1179,7 @@ class SyncServer(object): llm_config = LLMConfig( model=model.name, model_endpoint_type=model.model_endpoint_type, - model_endpoint=provider.base_url or model.model_endpoint_type, + model_endpoint=provider.base_url, context_window=model.max_context_window or 16384, handle=model.handle, provider_name=provider.name, @@ -1185,7 +1187,7 @@ class SyncServer(object): ) llm_models.append(llm_config) - # Get BYOK provider models by hitting provider endpoints directly + # Get BYOK provider models - sync if not synced yet, then read from DB if include_byok: byok_providers = await self.provider_manager.list_providers_async( actor=actor, @@ -1196,9 
+1198,37 @@ class SyncServer(object): for provider in byok_providers: try: - typed_provider = provider.cast_to_subtype() - models = await typed_provider.list_llm_models_async() - llm_models.extend(models) + # Sync models if not synced yet + if provider.last_synced is None: + typed_provider = provider.cast_to_subtype() + models = await typed_provider.list_llm_models_async() + embedding_models = await typed_provider.list_embedding_models_async() + await self.provider_manager.sync_provider_models_async( + provider=provider, + llm_models=models, + embedding_models=embedding_models, + organization_id=provider.organization_id, + ) + await self.provider_manager.update_provider_last_synced_async(provider.id) + + # Read from database + provider_llm_models = await self.provider_manager.list_models_async( + actor=actor, + model_type="llm", + provider_id=provider.id, + enabled=True, + ) + for model in provider_llm_models: + llm_config = LLMConfig( + model=model.name, + model_endpoint_type=model.model_endpoint_type, + model_endpoint=provider.base_url, + context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW, + handle=model.handle, + provider_name=provider.name, + provider_category=ProviderCategory.byok, + ) + llm_models.append(llm_config) except Exception as e: logger.warning(f"Failed to fetch models from BYOK provider {provider.name}: {e}") @@ -1240,7 +1270,7 @@ class SyncServer(object): ) embedding_models.append(embedding_config) - # Get BYOK provider models by hitting provider endpoints directly + # Get BYOK provider models - sync if not synced yet, then read from DB byok_providers = await self.provider_manager.list_providers_async( actor=actor, provider_category=[ProviderCategory.byok], @@ -1248,9 +1278,36 @@ class SyncServer(object): for provider in byok_providers: try: - typed_provider = provider.cast_to_subtype() - models = await typed_provider.list_embedding_models_async() - embedding_models.extend(models) + # Sync models if not synced yet + if 
provider.last_synced is None: + typed_provider = provider.cast_to_subtype() + llm_models = await typed_provider.list_llm_models_async() + emb_models = await typed_provider.list_embedding_models_async() + await self.provider_manager.sync_provider_models_async( + provider=provider, + llm_models=llm_models, + embedding_models=emb_models, + organization_id=provider.organization_id, + ) + await self.provider_manager.update_provider_last_synced_async(provider.id) + + # Read from database + provider_embedding_models = await self.provider_manager.list_models_async( + actor=actor, + model_type="embedding", + provider_id=provider.id, + enabled=True, + ) + for model in provider_embedding_models: + embedding_config = EmbeddingConfig( + embedding_model=model.name, + embedding_endpoint_type=model.model_endpoint_type, + embedding_endpoint=provider.base_url or model.model_endpoint_type, + embedding_dim=model.embedding_dim or 1536, + embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE, + handle=model.handle, + ) + embedding_models.append(embedding_config) except Exception as e: logger.warning(f"Failed to fetch embedding models from BYOK provider {provider.name}: {e}") diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 3757e5fc..1e846170 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -98,9 +98,30 @@ class ProviderManager: deleted_provider.access_key_enc = access_key_secret.get_encrypted() await deleted_provider.update_async(session, actor=actor) + + # Also restore any soft-deleted models associated with this provider + # This is needed because the unique constraint on provider_models doesn't include is_deleted, + # so soft-deleted models would block creation of new models with the same handle + from sqlalchemy import update + + restore_models_stmt = ( + update(ProviderModelORM) + .where( + and_( + ProviderModelORM.provider_id == deleted_provider.id, + ProviderModelORM.is_deleted == True, + ) 
+ ) + .values(is_deleted=False) + ) + result = await session.execute(restore_models_stmt) + if result.rowcount > 0: + logger.info(f"Restored {result.rowcount} soft-deleted model(s) for provider '{request.name}'") + provider_pydantic = deleted_provider.to_pydantic() # For BYOK providers, automatically sync available models + # This will add any new models and remove any that are no longer available if is_byok: await self._sync_default_models_for_provider(provider_pydantic, actor) @@ -201,6 +222,17 @@ class ProviderManager: await existing_provider.update_async(session, actor=actor) return existing_provider.to_pydantic() + @enforce_types + @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) + async def update_provider_last_synced_async(self, provider_id: str) -> None: + """Update the last_synced timestamp for a provider.""" + from datetime import datetime, timezone + + async with db_registry.async_session() as session: + provider = await ProviderModel.read_async(db_session=session, identifier=provider_id, actor=None) + provider.last_synced = datetime.now(timezone.utc) + await session.commit() + @enforce_types @raise_on_invalid_id(param_name="provider_id", expected_prefix=PrimitiveType.PROVIDER) @trace_method @@ -476,81 +508,19 @@ class ProviderManager: async def _sync_default_models_for_provider(self, provider: PydanticProvider, actor: PydanticUser) -> None: """Sync models for a newly created BYOK provider by querying the provider's API.""" - from letta.log import get_logger - - logger = get_logger(__name__) - try: - # Get the provider class and create an instance - from letta.schemas.enums import ProviderType - from letta.schemas.providers.anthropic import AnthropicProvider - from letta.schemas.providers.azure import AzureProvider - from letta.schemas.providers.bedrock import BedrockProvider - from letta.schemas.providers.google_gemini import GoogleAIProvider - from letta.schemas.providers.groq import GroqProvider - from 
letta.schemas.providers.ollama import OllamaProvider - from letta.schemas.providers.openai import OpenAIProvider - from letta.schemas.providers.zai import ZAIProvider + # Use cast_to_subtype() which properly handles all provider types and preserves api_key_enc + typed_provider = provider.cast_to_subtype() + llm_models = await typed_provider.list_llm_models_async() + embedding_models = await typed_provider.list_embedding_models_async() - # ChatGPT OAuth requires cast_to_subtype to preserve api_key_enc and id - # (needed for OAuth token refresh and database persistence) - if provider.provider_type == ProviderType.chatgpt_oauth: - provider_instance = provider.cast_to_subtype() - else: - provider_type_to_class = { - "openai": OpenAIProvider, - "anthropic": AnthropicProvider, - "groq": GroqProvider, - "google": GoogleAIProvider, - "ollama": OllamaProvider, - "bedrock": BedrockProvider, - "azure": AzureProvider, - "zai": ZAIProvider, - } - - provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type) - provider_class = provider_type_to_class.get(provider_type) - - if not provider_class: - logger.warning(f"No provider class found for type '{provider_type}'") - return - - # Create provider instance with necessary parameters - api_key = await provider.api_key_enc.get_plaintext_async() if provider.api_key_enc else None - access_key = await provider.access_key_enc.get_plaintext_async() if provider.access_key_enc else None - kwargs = { - "name": provider.name, - "api_key": api_key, - "provider_category": provider.provider_category, - } - if provider.base_url: - kwargs["base_url"] = provider.base_url - if access_key: - kwargs["access_key"] = access_key - if provider.region: - kwargs["region"] = provider.region - if provider.api_version: - kwargs["api_version"] = provider.api_version - - provider_instance = provider_class(**kwargs) - - # Query the provider's API for available models - llm_models = await 
provider_instance.list_llm_models_async() - embedding_models = await provider_instance.list_embedding_models_async() - - # Update handles and provider_name for BYOK providers - for model in llm_models: - model.provider_name = provider.name - model.handle = f"{provider.name}/{model.model}" - model.provider_category = provider.provider_category - - for model in embedding_models: - model.handle = f"{provider.name}/{model.embedding_model}" - - # Use existing sync_provider_models_async to save to database await self.sync_provider_models_async( - provider=provider, llm_models=llm_models, embedding_models=embedding_models, organization_id=actor.organization_id + provider=provider, + llm_models=llm_models, + embedding_models=embedding_models, + organization_id=actor.organization_id, ) + await self.update_provider_last_synced_async(provider.id) except Exception as e: logger.error(f"Failed to sync models for provider '{provider.name}': {e}") diff --git a/tests/test_server_providers.py b/tests/test_server_providers.py index c3ca87da..1e07ceea 100644 --- a/tests/test_server_providers.py +++ b/tests/test_server_providers.py @@ -2292,6 +2292,7 @@ async def test_server_list_llm_models_byok_from_provider_api(default_user, provi # Create a mock typed provider that returns our test models mock_typed_provider = MagicMock() mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models) + mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[]) # Patch cast_to_subtype on the Provider class to return our mock with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider): @@ -2525,3 +2526,620 @@ async def test_create_agent_with_byok_handle_dynamic_fetch(default_user, provide # Cleanup await server.agent_manager.delete_agent_async(agent_id=agent.id, actor=default_user) + + +@pytest.mark.asyncio +async def test_byok_provider_last_synced_triggers_sync_when_null(default_user, provider_manager): + """Test that BYOK providers with 
last_synced=null trigger a sync on first model listing.""" + from letta.schemas.providers import Provider + from letta.server.server import SyncServer + + test_id = generate_test_id() + + # Create a BYOK provider (last_synced will be null by default) + byok_provider_create = ProviderCreate( + name=f"test-byok-sync-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-byok-key", + ) + byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True) + + # Verify last_synced is null initially + assert byok_provider.last_synced is None + + # Create server + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.provider_manager = provider_manager + server._enabled_providers = [] + + # Mock the BYOK provider's list_llm_models_async to return test models + mock_byok_models = [ + LLMConfig( + model=f"byok-gpt-4o-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=64000, + handle=f"test-byok-sync-{test_id}/gpt-4o", + provider_name=byok_provider.name, + provider_category=ProviderCategory.byok, + ) + ] + + mock_typed_provider = MagicMock() + mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_byok_models) + mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[]) + + with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider): + # List BYOK models - should trigger sync because last_synced is null + byok_models = await server.list_llm_models_async( + actor=default_user, + provider_category=[ProviderCategory.byok], + ) + + # Verify sync was triggered (cast_to_subtype was called to fetch from API) + # Note: may be called multiple times if other BYOK providers exist in DB + mock_typed_provider.list_llm_models_async.assert_called() + + # Verify last_synced was updated for our provider + updated_providers = await 
provider_manager.list_providers_async(name=byok_provider.name, actor=default_user) + assert len(updated_providers) == 1 + assert updated_providers[0].last_synced is not None + + # Verify models were synced to database + synced_models = await provider_manager.list_models_async( + actor=default_user, + model_type="llm", + provider_id=byok_provider.id, + ) + assert len(synced_models) == 1 + assert synced_models[0].name == f"byok-gpt-4o-{test_id}" + + +@pytest.mark.asyncio +async def test_byok_provider_last_synced_skips_sync_when_set(default_user, provider_manager): + """Test that BYOK providers with last_synced set skip sync and read from DB.""" + from datetime import datetime, timezone + + from letta.schemas.providers import Provider + from letta.server.server import SyncServer + + test_id = generate_test_id() + + # Create a BYOK provider + byok_provider_create = ProviderCreate( + name=f"test-byok-cached-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-byok-key", + ) + byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True) + + # Manually sync models to DB + cached_model = LLMConfig( + model=f"cached-gpt-4o-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=64000, + handle=f"test-byok-cached-{test_id}/gpt-4o", + provider_name=byok_provider.name, + provider_category=ProviderCategory.byok, + ) + await provider_manager.sync_provider_models_async( + provider=byok_provider, + llm_models=[cached_model], + embedding_models=[], + organization_id=default_user.organization_id, + ) + + # Set last_synced to indicate models are already synced + await provider_manager.update_provider_last_synced_async(byok_provider.id) + + # Create server + server = SyncServer(init_with_default_org_and_user=False) + server.default_user = default_user + server.provider_manager = provider_manager + server._enabled_providers = [] + + # Mock cast_to_subtype - should NOT be 
called since last_synced is set + mock_typed_provider = MagicMock() + mock_typed_provider.list_llm_models_async = AsyncMock(return_value=[]) + mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=[]) + + with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider): + # List BYOK models - should read from DB, not trigger sync + byok_models = await server.list_llm_models_async( + actor=default_user, + provider_category=[ProviderCategory.byok], + ) + + # Verify sync was NOT triggered (cast_to_subtype should not be called) + mock_typed_provider.list_llm_models_async.assert_not_called() + + # Verify we got the cached model from DB + byok_handles = [m.handle for m in byok_models] + assert f"test-byok-cached-{test_id}/gpt-4o" in byok_handles + + +@pytest.mark.asyncio +async def test_base_provider_updates_last_synced_on_sync(default_user, provider_manager): + """Test that base provider sync updates the last_synced timestamp.""" + from letta.server.server import SyncServer + + test_id = generate_test_id() + + # Create a base provider + base_provider_create = ProviderCreate( + name=f"test-base-sync-{test_id}", + provider_type=ProviderType.openai, + api_key="", # Base providers don't store API keys + ) + base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False) + + # Verify last_synced is null initially + assert base_provider.last_synced is None + + # Sync models for the base provider + base_model = LLMConfig( + model=f"base-gpt-4o-{test_id}", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=64000, + handle=f"test-base-sync-{test_id}/gpt-4o", + ) + await provider_manager.sync_provider_models_async( + provider=base_provider, + llm_models=[base_model], + embedding_models=[], + organization_id=None, + ) + await provider_manager.update_provider_last_synced_async(base_provider.id) + + # Verify last_synced was updated + updated_providers 
= await provider_manager.list_providers_async(name=base_provider.name, actor=default_user) + assert len(updated_providers) == 1 + assert updated_providers[0].last_synced is not None + + +@pytest.mark.asyncio +async def test_byok_provider_models_synced_on_creation(default_user, provider_manager): + """Test that models are automatically synced when a BYOK provider is created. + + When create_provider_async is called with is_byok=True, it should: + 1. Create the provider in the database + 2. Call _sync_default_models_for_provider to fetch and persist models from the provider API + 3. Update last_synced timestamp + """ + from letta.schemas.providers import Provider + + test_id = generate_test_id() + + # Mock models that the provider API would return + mock_llm_models = [ + LLMConfig( + model="gpt-4o", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=128000, + handle=f"test-byok-creation-{test_id}/gpt-4o", + provider_name=f"test-byok-creation-{test_id}", + provider_category=ProviderCategory.byok, + ), + LLMConfig( + model="gpt-4o-mini", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=128000, + handle=f"test-byok-creation-{test_id}/gpt-4o-mini", + provider_name=f"test-byok-creation-{test_id}", + provider_category=ProviderCategory.byok, + ), + ] + mock_embedding_models = [ + EmbeddingConfig( + embedding_model="text-embedding-3-small", + embedding_endpoint_type="openai", + embedding_endpoint="https://api.openai.com/v1", + embedding_dim=1536, + embedding_chunk_size=300, + handle=f"test-byok-creation-{test_id}/text-embedding-3-small", + ), + ] + + # Create a mock typed provider that returns our test models + mock_typed_provider = MagicMock() + mock_typed_provider.list_llm_models_async = AsyncMock(return_value=mock_llm_models) + mock_typed_provider.list_embedding_models_async = AsyncMock(return_value=mock_embedding_models) + + # Patch cast_to_subtype to return our mock when 
_sync_default_models_for_provider is called + with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider): + # Create the BYOK provider - this should automatically sync models + byok_provider_create = ProviderCreate( + name=f"test-byok-creation-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-test-key", + ) + byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True) + + # Verify the provider API was called during creation + mock_typed_provider.list_llm_models_async.assert_called_once() + mock_typed_provider.list_embedding_models_async.assert_called_once() + + # Re-fetch the provider to get the updated last_synced value + # (the returned object from create_provider_async is stale since last_synced is set after) + byok_provider = await provider_manager.get_provider_async(byok_provider.id, default_user) + + # Verify last_synced was set (indicating sync completed) + assert byok_provider.last_synced is not None + + # Verify LLM models were persisted to the database + synced_llm_models = await provider_manager.list_models_async( + actor=default_user, + model_type="llm", + provider_id=byok_provider.id, + ) + assert len(synced_llm_models) == 2 + synced_llm_names = {m.name for m in synced_llm_models} + assert "gpt-4o" in synced_llm_names + assert "gpt-4o-mini" in synced_llm_names + + # Verify embedding models were persisted to the database + synced_embedding_models = await provider_manager.list_models_async( + actor=default_user, + model_type="embedding", + provider_id=byok_provider.id, + ) + assert len(synced_embedding_models) == 1 + assert synced_embedding_models[0].name == "text-embedding-3-small" + + +@pytest.mark.asyncio +async def test_refresh_byok_provider_adds_new_models(default_user, provider_manager): + """Test that refreshing a BYOK provider adds new models from the provider API. + + When _sync_default_models_for_provider is called (via refresh endpoint): + 1. 
It should fetch current models from the provider API + 2. Add any new models that weren't previously synced + 3. Update the last_synced timestamp + """ + from letta.schemas.providers import Provider + + test_id = generate_test_id() + + # Initial models when provider is created + initial_models = [ + LLMConfig( + model="gpt-4o", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=128000, + handle=f"test-refresh-add-{test_id}/gpt-4o", + provider_name=f"test-refresh-add-{test_id}", + provider_category=ProviderCategory.byok, + ), + ] + + # Updated models after refresh (includes a new model) + updated_models = [ + LLMConfig( + model="gpt-4o", + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=128000, + handle=f"test-refresh-add-{test_id}/gpt-4o", + provider_name=f"test-refresh-add-{test_id}", + provider_category=ProviderCategory.byok, + ), + LLMConfig( + model="gpt-4.1", # New model added by provider + model_endpoint_type="openai", + model_endpoint="https://api.openai.com/v1", + context_window=256000, + handle=f"test-refresh-add-{test_id}/gpt-4.1", + provider_name=f"test-refresh-add-{test_id}", + provider_category=ProviderCategory.byok, + ), + ] + + # Create mock for initial sync during provider creation + mock_typed_provider_initial = MagicMock() + mock_typed_provider_initial.list_llm_models_async = AsyncMock(return_value=initial_models) + mock_typed_provider_initial.list_embedding_models_async = AsyncMock(return_value=[]) + + # Create the provider with initial models + with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_initial): + byok_provider_create = ProviderCreate( + name=f"test-refresh-add-{test_id}", + provider_type=ProviderType.openai, + api_key="sk-test-key", + ) + byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True) + + # Re-fetch the provider to get the updated last_synced value + 
    byok_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)

    # Verify initial sync - should have 1 model
    initial_synced_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(initial_synced_models) == 1
    assert initial_synced_models[0].name == "gpt-4o"

    initial_last_synced = byok_provider.last_synced
    assert initial_last_synced is not None  # Verify sync happened during creation

    # Create mock for refresh with updated models
    mock_typed_provider_refresh = MagicMock()
    mock_typed_provider_refresh.list_llm_models_async = AsyncMock(return_value=updated_models)
    mock_typed_provider_refresh.list_embedding_models_async = AsyncMock(return_value=[])

    # Refresh the provider (simulating what the endpoint does)
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_refresh):
        await provider_manager._sync_default_models_for_provider(byok_provider, default_user)

    # Verify the API was called during refresh
    mock_typed_provider_refresh.list_llm_models_async.assert_called_once()

    # Verify new model was added
    refreshed_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(refreshed_models) == 2
    refreshed_names = {m.name for m in refreshed_models}
    assert "gpt-4o" in refreshed_names
    assert "gpt-4.1" in refreshed_names

    # Verify last_synced was updated (>= rather than > to tolerate coarse clock resolution)
    updated_provider = await provider_manager.get_provider_async(byok_provider.id, default_user)
    assert updated_provider.last_synced is not None
    assert updated_provider.last_synced >= initial_last_synced


@pytest.mark.asyncio
async def test_refresh_byok_provider_removes_old_models(default_user, provider_manager):
    """Test that refreshing a BYOK provider removes models no longer available from the provider API.

    When _sync_default_models_for_provider is called (via refresh endpoint):
    1. It should fetch current models from the provider API
    2. Remove any models that are no longer available (soft delete)
    3. Keep models that are still available
    """
    from letta.schemas.providers import Provider

    test_id = generate_test_id()

    # Initial models when provider is created (includes a model that will be removed)
    initial_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-remove-{test_id}/gpt-4o",
            provider_name=f"test-refresh-remove-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
        LLMConfig(
            model="gpt-4-turbo",  # This model will be deprecated/removed
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-remove-{test_id}/gpt-4-turbo",
            provider_name=f"test-refresh-remove-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]

    # Updated models after refresh (gpt-4-turbo is no longer available)
    updated_models = [
        LLMConfig(
            model="gpt-4o",
            model_endpoint_type="openai",
            model_endpoint="https://api.openai.com/v1",
            context_window=128000,
            handle=f"test-refresh-remove-{test_id}/gpt-4o",
            provider_name=f"test-refresh-remove-{test_id}",
            provider_category=ProviderCategory.byok,
        ),
    ]

    # Create mock for initial sync during provider creation
    mock_typed_provider_initial = MagicMock()
    mock_typed_provider_initial.list_llm_models_async = AsyncMock(return_value=initial_models)
    mock_typed_provider_initial.list_embedding_models_async = AsyncMock(return_value=[])

    # Create the provider with initial models
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_initial):
        byok_provider_create = ProviderCreate(
            name=f"test-refresh-remove-{test_id}",
            provider_type=ProviderType.openai,
            api_key="sk-test-key",
        )
        byok_provider = await provider_manager.create_provider_async(byok_provider_create, actor=default_user, is_byok=True)

    # Verify initial sync - should have 2 models
    initial_synced_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(initial_synced_models) == 2
    initial_names = {m.name for m in initial_synced_models}
    assert "gpt-4o" in initial_names
    assert "gpt-4-turbo" in initial_names

    # Create mock for refresh with fewer models
    mock_typed_provider_refresh = MagicMock()
    mock_typed_provider_refresh.list_llm_models_async = AsyncMock(return_value=updated_models)
    mock_typed_provider_refresh.list_embedding_models_async = AsyncMock(return_value=[])

    # Refresh the provider (simulating what the endpoint does)
    with patch.object(Provider, "cast_to_subtype", return_value=mock_typed_provider_refresh):
        await provider_manager._sync_default_models_for_provider(byok_provider, default_user)

    # Verify the removed model is no longer in the list
    refreshed_models = await provider_manager.list_models_async(
        actor=default_user,
        model_type="llm",
        provider_id=byok_provider.id,
    )
    assert len(refreshed_models) == 1
    assert refreshed_models[0].name == "gpt-4o"

    # Verify gpt-4-turbo was removed (soft deleted — presumably list_models_async
    # filters soft-deleted rows; confirm against provider_manager implementation)
    refreshed_names = {m.name for m in refreshed_models}
    assert "gpt-4-turbo" not in refreshed_names


@pytest.mark.asyncio
async def test_refresh_base_provider_fails(default_user, provider_manager):
    """Test that attempting to refresh a base provider returns an error.

    The refresh endpoint should only work for BYOK providers, not base providers.
    Base providers are managed by environment variables and shouldn't be refreshed.
    """
    from fastapi import HTTPException

    from letta.server.rest_api.routers.v1.providers import refresh_provider_models
    from letta.server.server import SyncServer

    test_id = generate_test_id()

    # Create a base provider
    base_provider_create = ProviderCreate(
        name=f"test-base-refresh-{test_id}",
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)

    # Verify it's a base provider
    assert base_provider.provider_category == ProviderCategory.base

    # Create a mock server (no default org/user so the test fixtures stay authoritative)
    server = SyncServer(init_with_default_org_and_user=False)
    server.provider_manager = provider_manager

    # Create mock headers
    mock_headers = MagicMock()
    mock_headers.actor_id = default_user.id

    # Mock get_actor_or_default_async to return our test user
    server.user_manager = MagicMock()
    server.user_manager.get_actor_or_default_async = AsyncMock(return_value=default_user)

    # Attempt to refresh the base provider - should raise HTTPException
    # (calling the route handler directly rather than going through the ASGI stack)
    with pytest.raises(HTTPException) as exc_info:
        await refresh_provider_models(
            provider_id=base_provider.id,
            headers=mock_headers,
            server=server,
        )

    assert exc_info.value.status_code == 400
    assert "BYOK" in exc_info.value.detail


@pytest.mark.asyncio
async def test_get_model_by_handle_prioritizes_byok_over_base(default_user, provider_manager):
    """Test that get_model_by_handle_async returns the BYOK model when both BYOK and base providers have the same handle.

    This tests the legacy scenario where a user has both a BYOK provider and a base provider
    with the same name (and thus models with the same handle). The BYOK model should be
    returned because it's organization-specific, while base models are global.
    """
    test_id = generate_test_id()
    provider_name = f"test-duplicate-{test_id}"
    model_handle = f"{provider_name}/gpt-4o"

    # Step 1: Create a base provider and sync a model for it (global, organization_id=None)
    base_provider_create = ProviderCreate(
        name=provider_name,
        provider_type=ProviderType.openai,
        api_key="",  # Base providers don't store API keys
    )
    base_provider = await provider_manager.create_provider_async(base_provider_create, actor=default_user, is_byok=False)
    assert base_provider.provider_category == ProviderCategory.base

    # Sync a model for the base provider (global model with organization_id=None)
    base_llm_model = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=model_handle,
        provider_name=provider_name,
    )
    await provider_manager.sync_provider_models_async(
        provider=base_provider,
        llm_models=[base_llm_model],
        embedding_models=[],
        organization_id=None,  # Global model
    )

    # Verify base model was created
    base_model = await provider_manager.get_model_by_handle_async(
        handle=model_handle,
        actor=default_user,
        model_type="llm",
    )
    assert base_model is not None
    assert base_model.handle == model_handle
    assert base_model.organization_id is None  # Global model

    # Step 2: Create a BYOK provider with the same name (simulating legacy duplicate)
    # Note: In production, this is now prevented, but legacy data could have this
    # We need to bypass the name conflict check for this test (simulating legacy data)
    # Create the BYOK provider directly by manipulating the database
    from letta.orm.provider import Provider as ProviderORM
    from letta.schemas.providers import Provider as PydanticProvider
    from letta.server.db import db_registry

    # Create a pydantic provider first to generate an ID
    byok_pydantic_provider = PydanticProvider(
        name=provider_name,  # Same name as base provider
        provider_type=ProviderType.openai,
        provider_category=ProviderCategory.byok,
        organization_id=default_user.organization_id,
    )
    byok_pydantic_provider.resolve_identifier()

    # Insert directly via the ORM to bypass the manager-level name-conflict check
    async with db_registry.async_session() as session:
        byok_provider_orm = ProviderORM(**byok_pydantic_provider.model_dump(to_orm=True))
        await byok_provider_orm.create_async(session, actor=default_user)
        byok_provider = byok_provider_orm.to_pydantic()

    assert byok_provider.provider_category == ProviderCategory.byok

    # Sync a model for the BYOK provider (org-specific model)
    byok_llm_model = LLMConfig(
        model="gpt-4o",
        model_endpoint_type="openai",
        model_endpoint="https://api.openai.com/v1",
        context_window=128000,
        handle=model_handle,  # Same handle as base model
        provider_name=provider_name,
        provider_category=ProviderCategory.byok,
    )
    await provider_manager.sync_provider_models_async(
        provider=byok_provider,
        llm_models=[byok_llm_model],
        embedding_models=[],
        organization_id=default_user.organization_id,  # Org-specific model
    )

    # Step 3: Verify that get_model_by_handle_async returns the BYOK model (org-specific)
    retrieved_model = await provider_manager.get_model_by_handle_async(
        handle=model_handle,
        actor=default_user,
        model_type="llm",
    )

    assert retrieved_model is not None
    assert retrieved_model.handle == model_handle
    # The key assertion: org-specific (BYOK) model should be returned, not the global (base) model
    assert retrieved_model.organization_id == default_user.organization_id
    assert retrieved_model.provider_id == byok_provider.id