feat: add zai provider support (#7626)
* feat: add zai provider support
* add zai_api_key secret to deploy-core
* add to justfile
* add testing, provider integration skill
* enable zai key
* fix zai test
* clean up skill a little
* small changes
This commit is contained in:
@@ -37,6 +37,7 @@ all_configs = [
|
||||
"openai-gpt-5.json",
|
||||
"claude-4-5-sonnet.json",
|
||||
"gemini-2.5-pro.json",
|
||||
"zai-glm-4.6.json",
|
||||
]
|
||||
|
||||
|
||||
@@ -206,9 +207,9 @@ def assert_tool_call_response(
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
# Special case for claude-sonnet-4-5-20250929 and opus-4.1 which can generate an extra AssistantMessage before tool call
|
||||
# Special case for claude-sonnet-4-5-20250929, opus-4.1, and zai which can generate an extra AssistantMessage before tool call
|
||||
if (
|
||||
("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle)
|
||||
("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle or model_settings.get("provider_type") == "zai")
|
||||
and index < len(messages)
|
||||
and isinstance(messages[index], AssistantMessage)
|
||||
):
|
||||
@@ -436,6 +437,10 @@ def get_expected_message_count_range(
|
||||
if "claude-opus-4-1" in model_handle:
|
||||
expected_range += 1
|
||||
|
||||
# Z.ai models output an AssistantMessage with each ReasoningMessage (not just the final one)
|
||||
if model_settings.get("provider_type") == "zai":
|
||||
expected_range += 1
|
||||
|
||||
if tool_call:
|
||||
# tool call and tool return messages
|
||||
expected_message_count += 2
|
||||
@@ -477,8 +482,10 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
|
||||
is_google_ai_reasoning = (
|
||||
model_settings.get("provider_type") == "google_ai" and model_settings.get("thinking_config", {}).get("include_thoughts") is True
|
||||
)
|
||||
# Z.ai models output reasoning by default
|
||||
is_zai_reasoning = model_settings.get("provider_type") == "zai"
|
||||
|
||||
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning
|
||||
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning
|
||||
|
||||
|
||||
# ------------------------------
|
||||
|
||||
9
tests/model_settings/zai-glm-4.6.json
Normal file
9
tests/model_settings/zai-glm-4.6.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"handle": "zai/glm-4.6",
|
||||
"model_settings": {
|
||||
"provider_type": "zai",
|
||||
"temperature": 0.7,
|
||||
"max_output_tokens": 4096,
|
||||
"parallel_tool_calls": false
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ from letta.schemas.providers import (
|
||||
OpenAIProvider,
|
||||
TogetherProvider,
|
||||
VLLMProvider,
|
||||
ZAIProvider,
|
||||
)
|
||||
from letta.schemas.secret import Secret
|
||||
from letta.settings import model_settings
|
||||
@@ -104,6 +105,19 @@ async def test_deepseek():
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(model_settings.zai_api_key is None, reason="Only run if ZAI_API_KEY is set.")
|
||||
@pytest.mark.asyncio
|
||||
async def test_zai():
|
||||
provider = ZAIProvider(
|
||||
name="zai",
|
||||
api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
|
||||
base_url=model_settings.zai_base_url,
|
||||
)
|
||||
models = await provider.list_llm_models_async()
|
||||
assert len(models) > 0
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(model_settings.groq_api_key is None, reason="Only run if GROQ_API_KEY is set.")
|
||||
@pytest.mark.asyncio
|
||||
async def test_groq():
|
||||
|
||||
Reference in New Issue
Block a user