fix: fix zai and others byok (#8991)
* fix: fix zai and other byok providers * fix test * get endpoint from typed provider and add test * also add base_url on provider create
This commit is contained in:
@@ -286,9 +286,7 @@ class SyncServer(object):
|
||||
# e.g. "... --enable-auto-tool-choice --tool-call-parser hermes"
|
||||
# Auto-append /v1 to the base URL
|
||||
vllm_url = (
|
||||
model_settings.vllm_api_base
|
||||
if model_settings.vllm_api_base.endswith("/v1")
|
||||
else model_settings.vllm_api_base + "/v1"
|
||||
model_settings.vllm_api_base if model_settings.vllm_api_base.endswith("/v1") else model_settings.vllm_api_base + "/v1"
|
||||
)
|
||||
self._enabled_providers.append(
|
||||
VLLMProvider(
|
||||
@@ -302,9 +300,7 @@ class SyncServer(object):
|
||||
if model_settings.sglang_api_base:
|
||||
# Auto-append /v1 to the base URL
|
||||
sglang_url = (
|
||||
model_settings.sglang_api_base
|
||||
if model_settings.sglang_api_base.endswith("/v1")
|
||||
else model_settings.sglang_api_base + "/v1"
|
||||
model_settings.sglang_api_base if model_settings.sglang_api_base.endswith("/v1") else model_settings.sglang_api_base + "/v1"
|
||||
)
|
||||
self._enabled_providers.append(
|
||||
SGLangProvider(
|
||||
@@ -1198,9 +1194,11 @@ class SyncServer(object):
|
||||
|
||||
for provider in byok_providers:
|
||||
try:
|
||||
# Get typed provider to access schema defaults (e.g., base_url)
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
|
||||
# Sync models if not synced yet
|
||||
if provider.last_synced is None:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
models = await typed_provider.list_llm_models_async()
|
||||
embedding_models = await typed_provider.list_embedding_models_async()
|
||||
await self.provider_manager.sync_provider_models_async(
|
||||
@@ -1222,7 +1220,7 @@ class SyncServer(object):
|
||||
llm_config = LLMConfig(
|
||||
model=model.name,
|
||||
model_endpoint_type=model.model_endpoint_type,
|
||||
model_endpoint=provider.base_url,
|
||||
model_endpoint=typed_provider.base_url,
|
||||
context_window=model.max_context_window or constants.DEFAULT_CONTEXT_WINDOW,
|
||||
handle=model.handle,
|
||||
provider_name=provider.name,
|
||||
@@ -1278,9 +1276,11 @@ class SyncServer(object):
|
||||
|
||||
for provider in byok_providers:
|
||||
try:
|
||||
# Get typed provider to access schema defaults (e.g., base_url)
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
|
||||
# Sync models if not synced yet
|
||||
if provider.last_synced is None:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
llm_models = await typed_provider.list_llm_models_async()
|
||||
emb_models = await typed_provider.list_embedding_models_async()
|
||||
await self.provider_manager.sync_provider_models_async(
|
||||
@@ -1302,7 +1302,7 @@ class SyncServer(object):
|
||||
embedding_config = EmbeddingConfig(
|
||||
embedding_model=model.name,
|
||||
embedding_endpoint_type=model.model_endpoint_type,
|
||||
embedding_endpoint=provider.base_url or model.model_endpoint_type,
|
||||
embedding_endpoint=typed_provider.base_url,
|
||||
embedding_dim=model.embedding_dim or 1536,
|
||||
embedding_chunk_size=constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
handle=model.handle,
|
||||
|
||||
@@ -140,6 +140,14 @@ class ProviderManager:
|
||||
# if provider.name == provider.provider_type.value:
|
||||
# raise ValueError("Provider name must be unique and different from provider type")
|
||||
|
||||
# Fill in schema-default base_url if not provided
|
||||
# This ensures providers like ZAI get their default endpoint persisted to DB
|
||||
# rather than relying on cast_to_subtype() at read time
|
||||
if provider.base_url is None:
|
||||
typed_provider = provider.cast_to_subtype()
|
||||
if typed_provider.base_url is not None:
|
||||
provider.base_url = typed_provider.base_url
|
||||
|
||||
# Only assign organization id for non-base providers
|
||||
# Base providers should be globally accessible (org_id = None)
|
||||
if is_byok:
|
||||
|
||||
@@ -3143,3 +3143,100 @@ async def test_get_model_by_handle_prioritizes_byok_over_base(default_user, prov
|
||||
# The key assertion: org-specific (BYOK) model should be returned, not the global (base) model
|
||||
assert retrieved_model.organization_id == default_user.organization_id
|
||||
assert retrieved_model.provider_id == byok_provider.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_byok_provider_uses_schema_default_base_url(default_user, provider_manager):
|
||||
"""Test that BYOK providers with schema-default base_url get correct model_endpoint.
|
||||
|
||||
This tests a bug where providers like ZAI have a schema-default base_url
|
||||
(e.g., "https://api.z.ai/api/paas/v4/") that isn't stored in the database.
|
||||
When list_llm_models_async reads from DB, the base_url is NULL, and if the code
|
||||
uses provider.base_url directly instead of typed_provider.base_url, the
|
||||
model_endpoint would be None/wrong, causing requests to go to the wrong endpoint.
|
||||
|
||||
The fix uses cast_to_subtype() to get the typed provider with schema defaults.
|
||||
"""
|
||||
from letta.orm.provider import Provider as ProviderORM
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
from letta.schemas.providers.zai import ZAIProvider
|
||||
from letta.server.db import db_registry
|
||||
|
||||
test_id = generate_test_id()
|
||||
provider_name = f"test-zai-{test_id}"
|
||||
|
||||
# Create a ZAI BYOK provider WITHOUT explicitly setting base_url
|
||||
# This simulates what happens when a user creates a ZAI provider via the API
|
||||
# The schema default "https://api.z.ai/api/paas/v4/" applies in memory but
|
||||
# may not be stored in the database (base_url column is NULL)
|
||||
byok_pydantic_provider = PydanticProvider(
|
||||
name=provider_name,
|
||||
provider_type=ProviderType.zai,
|
||||
provider_category=ProviderCategory.byok,
|
||||
organization_id=default_user.organization_id,
|
||||
# NOTE: base_url is intentionally NOT set - this is the bug scenario
|
||||
# The DB will have base_url=NULL
|
||||
)
|
||||
byok_pydantic_provider.resolve_identifier()
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
byok_provider_orm = ProviderORM(**byok_pydantic_provider.model_dump(to_orm=True))
|
||||
await byok_provider_orm.create_async(session, actor=default_user)
|
||||
byok_provider = byok_provider_orm.to_pydantic()
|
||||
|
||||
# Verify base_url is None in the provider loaded from DB
|
||||
assert byok_provider.base_url is None, "base_url should be NULL in DB for this test"
|
||||
assert byok_provider.provider_type == ProviderType.zai
|
||||
|
||||
# Sync a model for the provider (simulating what happens after provider creation)
|
||||
# Set last_synced so the server reads from DB instead of calling provider API
|
||||
from datetime import datetime, timezone
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
provider_orm = await ProviderORM.read_async(session, identifier=byok_provider.id, actor=None)
|
||||
provider_orm.last_synced = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
model_handle = f"{provider_name}/glm-4-flash"
|
||||
byok_llm_model = LLMConfig(
|
||||
model="glm-4-flash",
|
||||
model_endpoint_type="zai",
|
||||
model_endpoint="https://api.z.ai/api/paas/v4/", # The correct endpoint
|
||||
context_window=128000,
|
||||
handle=model_handle,
|
||||
provider_name=provider_name,
|
||||
provider_category=ProviderCategory.byok,
|
||||
)
|
||||
await provider_manager.sync_provider_models_async(
|
||||
provider=byok_provider,
|
||||
llm_models=[byok_llm_model],
|
||||
embedding_models=[],
|
||||
organization_id=default_user.organization_id,
|
||||
)
|
||||
|
||||
# Create server and list LLM models
|
||||
server = SyncServer(init_with_default_org_and_user=False)
|
||||
server.default_user = default_user
|
||||
server.provider_manager = provider_manager
|
||||
|
||||
# List LLM models - this should use typed_provider.base_url (schema default)
|
||||
# NOT provider.base_url (which is NULL in DB)
|
||||
models = await server.list_llm_models_async(
|
||||
actor=default_user,
|
||||
provider_category=[ProviderCategory.byok], # Only BYOK providers
|
||||
)
|
||||
|
||||
# Find our ZAI model
|
||||
zai_models = [m for m in models if m.handle == model_handle]
|
||||
assert len(zai_models) == 1, f"Expected 1 ZAI model, got {len(zai_models)}"
|
||||
|
||||
zai_model = zai_models[0]
|
||||
|
||||
# THE KEY ASSERTION: model_endpoint should be the ZAI schema default,
|
||||
# NOT None (which would cause requests to go to OpenAI's endpoint)
|
||||
expected_endpoint = "https://api.z.ai/api/paas/v4/"
|
||||
assert zai_model.model_endpoint == expected_endpoint, (
|
||||
f"model_endpoint should be '{expected_endpoint}' from ZAI schema default, "
|
||||
f"but got '{zai_model.model_endpoint}'. This indicates the bug where "
|
||||
f"provider.base_url (NULL from DB) was used instead of typed_provider.base_url."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user