fix: fix zai and others byok (#8991)

* fix: fix zai and other byok providers

* fix test

* get endpoint from typed provider and add test

* also add base_url on provider create
This commit is contained in:
Ari Webb
2026-01-20 19:05:51 -08:00
committed by Caren Thomas
parent 7133083b81
commit 2e826577d9
3 changed files with 115 additions and 10 deletions

View File

@@ -286,9 +286,7 @@ class SyncServer(object):
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
# Auto-append /v1 to the base URL
vllm_url = (
model_settings.vllm_api_base
if model_settings.vllm_api_base.endswith("/v1")
else model_settings.vllm_api_base + "/v1"
model_settings.vllm_api_base if model_settings.vllm_api_base.endswith("/v1") else model_settings.vllm_api_base + "/v1"
)
self._enabled_providers.append(
VLLMProvider(
@@ -302,9 +300,7 @@ class SyncServer(object):
if model_settings.sglang_api_base:
# Auto-append /v1 to the base URL
sglang_url = (
model_settings.sglang_api_base
if model_settings.sglang_api_base.endswith("/v1")
else model_settings.sglang_api_base + "/v1"
model_settings.sglang_api_base if model_settings.sglang_api_base.endswith("/v1") else model_settings.sglang_api_base + "/v1"
)
self._enabled_providers.append(
SGLangProvider(
@@ -1198,9 +1194,11 @@ class SyncServer(object):
for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
# Sync models if not synced yet
if provider.last_synced is None:
typed_provider = provider.cast_to_subtype()
models = await typed_provider.list_llm_models_async()
embedding_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
@@ -1222,7 +1220,7 @@ class SyncServer(object):
llm_config = LLMConfig(
model=model.name,
model_endpoint_type=model.model_endpoint_type,
model_endpoint=provider.base_url,
model_endpoint=typed_provider.base_url,
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
handle=model.handle,
provider_name=provider.name,
@@ -1278,9 +1276,11 @@ class SyncServer(object):
for provider in byok_providers:
try:
# Get typed provider to access schema defaults (e.g., base_url)
typed_provider = provider.cast_to_subtype()
# Sync models if not synced yet
if provider.last_synced is None:
typed_provider = provider.cast_to_subtype()
llm_models = await typed_provider.list_llm_models_async()
emb_models = await typed_provider.list_embedding_models_async()
await self.provider_manager.sync_provider_models_async(
@@ -1302,7 +1302,7 @@ class SyncServer(object):
embedding_config = EmbeddingConfig(
embedding_model=model.name,
embedding_endpoint_type=model.model_endpoint_type,
embedding_endpoint=provider.base_url or model.model_endpoint_type,
embedding_endpoint=typed_provider.base_url,
embedding_dim=model.embedding_dim or 1536,
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
handle=model.handle,