diff --git a/letta/server/server.py b/letta/server/server.py index f3225ecf..eea548e9 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -286,9 +286,7 @@ class SyncServer(object): # e.g. "... --enable-auto-tool-choice --tool-call-parser hermes" # Auto-append /v1 to the base URL vllm_url = ( - model_settings.vllm_api_base - if model_settings.vllm_api_base.endswith("/v1") - else model_settings.vllm_api_base + "/v1" + model_settings.vllm_api_base if model_settings.vllm_api_base.endswith("/v1") else model_settings.vllm_api_base + "/v1" ) self._enabled_providers.append( VLLMProvider( @@ -302,9 +300,7 @@ class SyncServer(object): if model_settings.sglang_api_base: # Auto-append /v1 to the base URL sglang_url = ( - model_settings.sglang_api_base - if model_settings.sglang_api_base.endswith("/v1") - else model_settings.sglang_api_base + "/v1" + model_settings.sglang_api_base if model_settings.sglang_api_base.endswith("/v1") else model_settings.sglang_api_base + "/v1" ) self._enabled_providers.append( SGLangProvider( @@ -1198,9 +1194,11 @@ class SyncServer(object): for provider in byok_providers: try: + # Get typed provider to access schema defaults (e.g., base_url) + typed_provider = provider.cast_to_subtype() + # Sync models if not synced yet if provider.last_synced is None: - typed_provider = provider.cast_to_subtype() models = await typed_provider.list_llm_models_async() embedding_models = await typed_provider.list_embedding_models_async() await self.provider_manager.sync_provider_models_async( @@ -1222,7 +1220,7 @@ class SyncServer(object): llm_config = LLMConfig( model=model.name, model_endpoint_type=model.model_endpoint_type, - model_endpoint=provider.base_url, + model_endpoint=typed_provider.base_url, context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW, handle=model.handle, provider_name=provider.name, @@ -1278,9 +1276,11 @@ class SyncServer(object): for provider in byok_providers: try: + # Get typed provider to access schema 
defaults (e.g., base_url) + typed_provider = provider.cast_to_subtype() + # Sync models if not synced yet if provider.last_synced is None: - typed_provider = provider.cast_to_subtype() llm_models = await typed_provider.list_llm_models_async() emb_models = await typed_provider.list_embedding_models_async() await self.provider_manager.sync_provider_models_async( @@ -1302,7 +1302,7 @@ class SyncServer(object): embedding_config = EmbeddingConfig( embedding_model=model.name, embedding_endpoint_type=model.model_endpoint_type, - embedding_endpoint=provider.base_url or model.model_endpoint_type, + embedding_endpoint=typed_provider.base_url, embedding_dim=model.embedding_dim or 1536, embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE, handle=model.handle, diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 1e846170..55af875c 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -140,6 +140,14 @@ class ProviderManager: # if provider.name == provider.provider_type.value: # raise ValueError("Provider name must be unique and different from provider type") + # Fill in schema-default base_url if not provided + # This ensures providers like ZAI get their default endpoint persisted to DB + # rather than relying on cast_to_subtype() at read time + if provider.base_url is None: + typed_provider = provider.cast_to_subtype() + if typed_provider.base_url is not None: + provider.base_url = typed_provider.base_url + # Only assign organization id for non-base providers # Base providers should be globally accessible (org_id = None) if is_byok: diff --git a/tests/test_server_providers.py b/tests/test_server_providers.py index 1e07ceea..10579fc7 100644 --- a/tests/test_server_providers.py +++ b/tests/test_server_providers.py @@ -3143,3 +3143,100 @@ async def test_get_model_by_handle_prioritizes_byok_over_base(default_user, prov # The key assertion: org-specific (BYOK) model should be returned, not the global 
(base) model assert retrieved_model.organization_id == default_user.organization_id assert retrieved_model.provider_id == byok_provider.id
+
+
+@pytest.mark.asyncio
+async def test_byok_provider_uses_schema_default_base_url(default_user, provider_manager):
+    """Check that listed BYOK models use the provider's schema-default base_url.
+
+    Scenario: a ZAI provider is stored without an explicit base_url, so the value read
+    back from the database is NULL even though the ZAI schema declares a default
+    (e.g., "https://api.z.ai/api/paas/v4/"). If list_llm_models_async copied
+    provider.base_url straight from that row, model_endpoint would be None/wrong and
+    requests would go to the wrong endpoint.
+
+    Expected: the server resolves the endpoint through cast_to_subtype(), i.e. from the
+    typed provider that carries the schema default.
+    """
+    from datetime import datetime, timezone
+
+    from letta.orm.provider import Provider as ProviderORM
+    from letta.schemas.providers import Provider as PydanticProvider
+    # NOTE(review): ZAIProvider is not referenced anywhere in this test - confirm the import is needed.
+    from letta.schemas.providers.zai import ZAIProvider
+    from letta.server.db import db_registry
+
+    test_id = generate_test_id()
+    provider_name = f"test-zai-{test_id}"
+    model_handle = f"{provider_name}/glm-4-flash"
+    # Endpoint given to the synced model below and expected back from the server listing.
+    expected_endpoint = "https://api.z.ai/api/paas/v4/"
+
+    # Build the provider from the generic schema and deliberately leave base_url unset:
+    # the row inserted below must have base_url = NULL, which is the bug scenario.
+    provider_schema = PydanticProvider(
+        name=provider_name,
+        provider_type=ProviderType.zai,
+        provider_category=ProviderCategory.byok,
+        organization_id=default_user.organization_id,
+    )
+    provider_schema.resolve_identifier()
+
+    # Insert the row through the ORM and read it back as a pydantic provider.
+    async with db_registry.async_session() as session:
+        created_row = ProviderORM(**provider_schema.model_dump(to_orm=True))
+        await created_row.create_async(session, actor=default_user)
+        byok_provider = created_row.to_pydantic()
+
+    # Precondition for the scenario: nothing was persisted for base_url.
+    assert byok_provider.base_url is None, "base_url should be NULL in DB for this test"
+    assert byok_provider.provider_type == ProviderType.zai
+
+    # Mark the provider as already synced so the server reads its models from the DB
+    # instead of calling the provider API.
+    async with db_registry.async_session() as session:
+        stored_row = await ProviderORM.read_async(session, identifier=byok_provider.id, actor=None)
+        stored_row.last_synced = datetime.now(timezone.utc)
+        await session.commit()
+
+    # Store a single LLM model for the provider.
+    synced_llm_config = LLMConfig(
+        handle=model_handle,
+        model="glm-4-flash",
+        model_endpoint_type="zai",
+        model_endpoint=expected_endpoint,
+        context_window=128000,
+        provider_name=provider_name,
+        provider_category=ProviderCategory.byok,
+    )
+    await provider_manager.sync_provider_models_async(
+        provider=byok_provider,
+        llm_models=[synced_llm_config],
+        embedding_models=[],
+        organization_id=default_user.organization_id,
+    )
+
+    # Stand up a server wired to the test user and provider manager.
+    server = SyncServer(init_with_default_org_and_user=False)
+    server.default_user = default_user
+    server.provider_manager = provider_manager
+
+    # Restrict the listing to BYOK providers.
+    listed_models = await server.list_llm_models_async(
+        actor=default_user,
+        provider_category=[ProviderCategory.byok],
+    )
+
+    # Exactly one listed model should carry our handle.
+    zai_models = [candidate for candidate in listed_models if candidate.handle == model_handle]
+    assert len(zai_models) == 1, f"Expected 1 ZAI model, got {len(zai_models)}"
+
+    zai_model = zai_models[0]
+
+    # Key assertion: the endpoint must be the ZAI schema default rather than the
+    # NULL base_url stored in the provider row.
+    assert zai_model.model_endpoint == expected_endpoint, (
+        f"model_endpoint should be '{expected_endpoint}' from ZAI schema default, "
+        f"but got '{zai_model.model_endpoint}'. This indicates the bug where "
+        f"provider.base_url (NULL from DB) was used instead of typed_provider.base_url."
+    )