fix: fix zai and others byok (#8991)
* fix: fix zai and other byok providers * fix test * get endpoint from typed provider and add test * also add base_url on provider create
This commit is contained in:
@@ -3143,3 +3143,100 @@ async def test_get_model_by_handle_prioritizes_byok_over_base(default_user, prov
|
||||
# The key assertion: org-specific (BYOK) model should be returned, not the global (base) model
|
||||
assert retrieved_model.organization_id == default_user.organization_id
|
||||
assert retrieved_model.provider_id == byok_provider.id
|
||||
|
||||
|
||||
@pytest.mark.asyncio
|
||||
async def test_byok_provider_uses_schema_default_base_url(default_user, provider_manager):
|
||||
"""Test that BYOK providers with schema-default base_url get correct model_endpoint.
|
||||
|
||||
This tests a bug where providers like ZAI have a schema-default base_url
|
||||
(e.g., "https://api.z.ai/api/paas/v4/") that isn't stored in the database.
|
||||
When list_llm_models_async reads from DB, the base_url is NULL, and if the code
|
||||
uses provider.base_url directly instead of typed_provider.base_url, the
|
||||
model_endpoint would be None/wrong, causing requests to go to the wrong endpoint.
|
||||
|
||||
The fix uses cast_to_subtype() to get the typed provider with schema defaults.
|
||||
"""
|
||||
from letta.orm.provider import Provider as ProviderORM
|
||||
from letta.schemas.providers import Provider as PydanticProvider
|
||||
from letta.schemas.providers.zai import ZAIProvider
|
||||
from letta.server.db import db_registry
|
||||
|
||||
test_id = generate_test_id()
|
||||
provider_name = f"test-zai-{test_id}"
|
||||
|
||||
# Create a ZAI BYOK provider WITHOUT explicitly setting base_url
|
||||
# This simulates what happens when a user creates a ZAI provider via the API
|
||||
# The schema default "https://api.z.ai/api/paas/v4/" applies in memory but
|
||||
# may not be stored in the database (base_url column is NULL)
|
||||
byok_pydantic_provider = PydanticProvider(
|
||||
name=provider_name,
|
||||
provider_type=ProviderType.zai,
|
||||
provider_category=ProviderCategory.byok,
|
||||
organization_id=default_user.organization_id,
|
||||
# NOTE: base_url is intentionally NOT set - this is the bug scenario
|
||||
# The DB will have base_url=NULL
|
||||
)
|
||||
byok_pydantic_provider.resolve_identifier()
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
byok_provider_orm = ProviderORM(**byok_pydantic_provider.model_dump(to_orm=True))
|
||||
await byok_provider_orm.create_async(session, actor=default_user)
|
||||
byok_provider = byok_provider_orm.to_pydantic()
|
||||
|
||||
# Verify base_url is None in the provider loaded from DB
|
||||
assert byok_provider.base_url is None, "base_url should be NULL in DB for this test"
|
||||
assert byok_provider.provider_type == ProviderType.zai
|
||||
|
||||
# Sync a model for the provider (simulating what happens after provider creation)
|
||||
# Set last_synced so the server reads from DB instead of calling provider API
|
||||
from datetime import datetime, timezone
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
provider_orm = await ProviderORM.read_async(session, identifier=byok_provider.id, actor=None)
|
||||
provider_orm.last_synced = datetime.now(timezone.utc)
|
||||
await session.commit()
|
||||
|
||||
model_handle = f"{provider_name}/glm-4-flash"
|
||||
byok_llm_model = LLMConfig(
|
||||
model="glm-4-flash",
|
||||
model_endpoint_type="zai",
|
||||
model_endpoint="https://api.z.ai/api/paas/v4/", # The correct endpoint
|
||||
context_window=128000,
|
||||
handle=model_handle,
|
||||
provider_name=provider_name,
|
||||
provider_category=ProviderCategory.byok,
|
||||
)
|
||||
await provider_manager.sync_provider_models_async(
|
||||
provider=byok_provider,
|
||||
llm_models=[byok_llm_model],
|
||||
embedding_models=[],
|
||||
organization_id=default_user.organization_id,
|
||||
)
|
||||
|
||||
# Create server and list LLM models
|
||||
server = SyncServer(init_with_default_org_and_user=False)
|
||||
server.default_user = default_user
|
||||
server.provider_manager = provider_manager
|
||||
|
||||
# List LLM models - this should use typed_provider.base_url (schema default)
|
||||
# NOT provider.base_url (which is NULL in DB)
|
||||
models = await server.list_llm_models_async(
|
||||
actor=default_user,
|
||||
provider_category=[ProviderCategory.byok], # Only BYOK providers
|
||||
)
|
||||
|
||||
# Find our ZAI model
|
||||
zai_models = [m for m in models if m.handle == model_handle]
|
||||
assert len(zai_models) == 1, f"Expected 1 ZAI model, got {len(zai_models)}"
|
||||
|
||||
zai_model = zai_models[0]
|
||||
|
||||
# THE KEY ASSERTION: model_endpoint should be the ZAI schema default,
|
||||
# NOT None (which would cause requests to go to OpenAI's endpoint)
|
||||
expected_endpoint = "https://api.z.ai/api/paas/v4/"
|
||||
assert zai_model.model_endpoint == expected_endpoint, (
|
||||
f"model_endpoint should be '{expected_endpoint}' from ZAI schema default, "
|
||||
f"but got '{zai_model.model_endpoint}'. This indicates the bug where "
|
||||
f"provider.base_url (NULL from DB) was used instead of typed_provider.base_url."
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user