fix: fix zai and others byok (#8991)

* fix: fix zai and other byok providers

* fix test

* get endpoint from typed provider and add test

* also add base_url on provider create
This commit is contained in:
Ari Webb
2026-01-20 19:05:51 -08:00
committed by Caren Thomas
parent 7133083b81
commit 2e826577d9
3 changed files with 115 additions and 10 deletions

View File

@@ -286,9 +286,7 @@ class SyncServer(object):
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
# Auto-append /v1 to the base URL
vllm_url = (
model_settings.vllm_api_base
if model_settings.vllm_api_base.endswith("/v1")
else model_settings.vllm_api_base + "/v1"
model_settings.vllm_api_base if model_settings.vllm_api_base.endswith("/v1") else model_settings.vllm_api_base + "/v1"
)
self._enabled_providers.append(
VLLMProvider(
@@ -302,9 +300,7 @@ class SyncServer(object):
if model_settings.sglang_api_base:
# Auto-append /v1 to the base URL
sglang_url = (
model_settings.sglang_api_base
if model_settings.sglang_api_base.endswith("/v1")
else model_settings.sglang_api_base + "/v1"
model_settings.sglang_api_base if model_settings.sglang_api_base.endswith("/v1") else model_settings.sglang_api_base + "/v1"
)
self._enabled_providers.append(
SGLangProvider(
@@ -1198,9 +1194,11 @@ class SyncServer(object):
for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
# Sync models if not synced yet
if provider.last_synced is None:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
embedding_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
@@ -1222,7 +1220,7 @@ class SyncServer(object):
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url,
model_endpoint=typed_provider.base_url,
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
handle=model.handle,
provider_name=provider.name,
@@ -1278,9 +1276,11 @@ class SyncServer(object):
for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
# Sync models if not synced yet
if provider.last_synced is None:
typed_provider = provider.cast_to_subtype()
llm_models = await typed_provider.list_llm_models_async()
emb_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
@@ -1302,7 +1302,7 @@ class SyncServer(object):
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=provider.base_url or model.model_endpoint_type,
embedding_endpoint=typed_provider.base_url,
embedding_dim=model.embedding_dim or 1536,
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,

View File

@@ -140,6 +140,14 @@ class ProviderManager:
# if provider.name == provider.provider_type.value:
# raise ValueError("Provider name must be unique and different from provider type")
# Fill in schema-default base_url if not provided
# This ensures providers like ZAI get their default endpoint persisted to DB
# rather than relying on cast_to_subtype() at read time
if provider.base_url is None:
typed_provider = provider.cast_to_subtype()
if typed_provider.base_url is not None:
provider.base_url = typed_provider.base_url
# Only assign organization id for non-base providers
# Base providers should be globally accessible (org_id = None)
if is_byok:

View File

@@ -3143,3 +3143,100 @@ async def test_get_model_by_handle_prioritizes_byok_over_base(default_user, prov
# The key assertion: org-specific (BYOK) model should be returned, not the global (base) model
assert retrieved_model.organization_id == default_user.organization_id
assert retrieved_model.provider_id == byok_provider.id
@pytest.mark.asyncio
async def test_byok_provider_uses_schema_default_base_url(default_user, provider_manager):
"""Test that BYOK providers with schema-default base_url get correct model_endpoint.
This tests a bug where providers like ZAI have a schema-default base_url
(e.g., "https://api.z.ai/api/paas/v4/") that isn't stored in the database.
When list_llm_models_async reads from DB, the base_url is NULL, and if the code
uses provider.base_url directly instead of typed_provider.base_url, the
model_endpoint would be None/wrong, causing requests to go to the wrong endpoint.
The fix uses cast_to_subtype() to get the typed provider with schema defaults.
"""
from letta.orm.provider import Provider as ProviderORM
from letta.schemas.providers import Provider as PydanticProvider
from letta.schemas.providers.zai import ZAIProvider
from letta.server.db import db_registry
test_id = generate_test_id()
provider_name = f"test-zai-{test_id}"
# Create a ZAI BYOK provider WITHOUT explicitly setting base_url
# This simulates what happens when a user creates a ZAI provider via the API
# The schema default "https://api.z.ai/api/paas/v4/" applies in memory but
# may not be stored in the database (base_url column is NULL)
byok_pydantic_provider = PydanticProvider(
name=provider_name,
provider_type=ProviderType.zai,
provider_category=ProviderCategory.byok,
organization_id=default_user.organization_id,
# NOTE: base_url is intentionally NOT set - this is the bug scenario
# The DB will have base_url=NULL
)
byok_pydantic_provider.resolve_identifier()
async with db_registry.async_session() as session:
byok_provider_orm = ProviderORM(**byok_pydantic_provider.model_dump(to_orm=True))
await byok_provider_orm.create_async(session, actor=default_user)
byok_provider = byok_provider_orm.to_pydantic()
# Verify base_url is None in the provider loaded from DB
assert byok_provider.base_url is None, "base_url should be NULL in DB for this test"
assert byok_provider.provider_type == ProviderType.zai
# Sync a model for the provider (simulating what happens after provider creation)
# Set last_synced so the server reads from DB instead of calling provider API
from datetime import datetime, timezone
async with db_registry.async_session() as session:
provider_orm = await ProviderORM.read_async(session, identifier=byok_provider.id, actor=None)
provider_orm.last_synced = datetime.now(timezone.utc)
await session.commit()
model_handle = f"{provider_name}/glm-4-flash"
byok_llm_model = LLMConfig(
model="glm-4-flash",
model_endpoint_type="zai",
model_endpoint="https://api.z.ai/api/paas/v4/", # The correct endpoint
context_window=128000,
handle=model_handle,
provider_name=provider_name,
provider_category=ProviderCategory.byok,
)
await provider_manager.sync_provider_models_async(
provider=byok_provider,
llm_models=[byok_llm_model],
embedding_models=[],
organization_id=default_user.organization_id,
)
# Create server and list LLM models
server = SyncServer(init_with_default_org_and_user=False)
server.default_user = default_user
server.provider_manager = provider_manager
# List LLM models - this should use typed_provider.base_url (schema default)
# NOT provider.base_url (which is NULL in DB)
models = await server.list_llm_models_async(
actor=default_user,
provider_category=[ProviderCategory.byok], # Only BYOK providers
)
# Find our ZAI model
zai_models = [m for m in models if m.handle == model_handle]
assert len(zai_models) == 1, f"Expected 1 ZAI model, got {len(zai_models)}"
zai_model = zai_models[0]
# THE KEY ASSERTION: model_endpoint should be the ZAI schema default,
# NOT None (which would cause requests to go to OpenAI's endpoint)
expected_endpoint = "https://api.z.ai/api/paas/v4/"
assert zai_model.model_endpoint == expected_endpoint, (
f"model_endpoint should be '{expected_endpoint}' from ZAI schema default, "
f"but got '{zai_model.model_endpoint}'. This indicates the bug where "
f"provider.base_url (NULL from DB) was used instead of typed_provider.base_url."
)