feat: add zai provider support (#7626)

* feat: add zai provider support

* add zai_api_key secret to deploy-core

* add to justfile

* add testing, provider integration skill

* enable zai key

* fix zai test

* clean up skill a little

* small changes
This commit is contained in:
Ari Webb
2025-12-19 16:05:41 -08:00
committed by Caren Thomas
parent cb283373b7
commit cd45212acb
17 changed files with 351 additions and 13 deletions

View File

@@ -21021,6 +21021,9 @@
{
"$ref": "#/components/schemas/XAIModelSettings"
},
{
"$ref": "#/components/schemas/ZAIModelSettings"
},
{
"$ref": "#/components/schemas/GroqModelSettings"
},
@@ -21046,7 +21049,8 @@
"groq": "#/components/schemas/GroqModelSettings",
"openai": "#/components/schemas/OpenAIModelSettings",
"together": "#/components/schemas/TogetherModelSettings",
"xai": "#/components/schemas/XAIModelSettings"
"xai": "#/components/schemas/XAIModelSettings",
"zai": "#/components/schemas/ZAIModelSettings"
}
}
},
@@ -24789,6 +24793,9 @@
{
"$ref": "#/components/schemas/XAIModelSettings"
},
{
"$ref": "#/components/schemas/ZAIModelSettings"
},
{
"$ref": "#/components/schemas/GroqModelSettings"
},
@@ -24814,7 +24821,8 @@
"groq": "#/components/schemas/GroqModelSettings",
"openai": "#/components/schemas/OpenAIModelSettings",
"together": "#/components/schemas/TogetherModelSettings",
"xai": "#/components/schemas/XAIModelSettings"
"xai": "#/components/schemas/XAIModelSettings",
"zai": "#/components/schemas/ZAIModelSettings"
}
}
},
@@ -24897,6 +24905,9 @@
{
"$ref": "#/components/schemas/XAIModelSettings"
},
{
"$ref": "#/components/schemas/ZAIModelSettings"
},
{
"$ref": "#/components/schemas/GroqModelSettings"
},
@@ -24922,7 +24933,8 @@
"groq": "#/components/schemas/GroqModelSettings",
"openai": "#/components/schemas/OpenAIModelSettings",
"together": "#/components/schemas/TogetherModelSettings",
"xai": "#/components/schemas/XAIModelSettings"
"xai": "#/components/schemas/XAIModelSettings",
"zai": "#/components/schemas/ZAIModelSettings"
}
}
},
@@ -25739,6 +25751,9 @@
{
"$ref": "#/components/schemas/XAIModelSettings"
},
{
"$ref": "#/components/schemas/ZAIModelSettings"
},
{
"$ref": "#/components/schemas/GroqModelSettings"
},
@@ -25764,7 +25779,8 @@
"groq": "#/components/schemas/GroqModelSettings",
"openai": "#/components/schemas/OpenAIModelSettings",
"together": "#/components/schemas/TogetherModelSettings",
"xai": "#/components/schemas/XAIModelSettings"
"xai": "#/components/schemas/XAIModelSettings",
"zai": "#/components/schemas/ZAIModelSettings"
}
}
},
@@ -29972,6 +29988,9 @@
{
"$ref": "#/components/schemas/XAIModelSettings"
},
{
"$ref": "#/components/schemas/ZAIModelSettings"
},
{
"$ref": "#/components/schemas/GroqModelSettings"
},
@@ -29997,7 +30016,8 @@
"groq": "#/components/schemas/GroqModelSettings",
"openai": "#/components/schemas/OpenAIModelSettings",
"together": "#/components/schemas/TogetherModelSettings",
"xai": "#/components/schemas/XAIModelSettings"
"xai": "#/components/schemas/XAIModelSettings",
"zai": "#/components/schemas/ZAIModelSettings"
}
}
},
@@ -30893,7 +30913,8 @@
"together",
"bedrock",
"deepseek",
"xai"
"xai",
"zai"
],
"title": "Model Endpoint Type",
"description": "The endpoint type for the model."
@@ -33149,7 +33170,8 @@
"together",
"bedrock",
"deepseek",
"xai"
"xai",
"zai"
],
"title": "Model Endpoint Type",
"description": "Deprecated: Use 'provider_type' field instead. The endpoint type for the model.",
@@ -34596,7 +34618,8 @@
"openai",
"together",
"vllm",
"xai"
"xai",
"zai"
],
"title": "ProviderType"
},
@@ -39311,6 +39334,9 @@
{
"$ref": "#/components/schemas/XAIModelSettings"
},
{
"$ref": "#/components/schemas/ZAIModelSettings"
},
{
"$ref": "#/components/schemas/GroqModelSettings"
},
@@ -39336,7 +39362,8 @@
"groq": "#/components/schemas/GroqModelSettings",
"openai": "#/components/schemas/OpenAIModelSettings",
"together": "#/components/schemas/TogetherModelSettings",
"xai": "#/components/schemas/XAIModelSettings"
"xai": "#/components/schemas/XAIModelSettings",
"zai": "#/components/schemas/ZAIModelSettings"
}
}
},
@@ -40262,6 +40289,68 @@
"title": "XAIModelSettings",
"description": "xAI model configuration (OpenAI-compatible)."
},
"ZAIModelSettings": {
"properties": {
"max_output_tokens": {
"type": "integer",
"title": "Max Output Tokens",
"description": "The maximum number of tokens the model can generate.",
"default": 4096
},
"parallel_tool_calls": {
"type": "boolean",
"title": "Parallel Tool Calls",
"description": "Whether to enable parallel tool calling.",
"default": false
},
"provider_type": {
"type": "string",
"const": "zai",
"title": "Provider Type",
"description": "The type of the provider.",
"default": "zai"
},
"temperature": {
"type": "number",
"title": "Temperature",
"description": "The temperature of the model.",
"default": 0.7
},
"response_format": {
"anyOf": [
{
"oneOf": [
{
"$ref": "#/components/schemas/TextResponseFormat"
},
{
"$ref": "#/components/schemas/JsonSchemaResponseFormat"
},
{
"$ref": "#/components/schemas/JsonObjectResponseFormat"
}
],
"discriminator": {
"propertyName": "type",
"mapping": {
"json_object": "#/components/schemas/JsonObjectResponseFormat",
"json_schema": "#/components/schemas/JsonSchemaResponseFormat",
"text": "#/components/schemas/TextResponseFormat"
}
}
},
{
"type": "null"
}
],
"title": "Response Format",
"description": "The response format for the model."
}
},
"type": "object",
"title": "ZAIModelSettings",
"description": "Z.ai (ZhipuAI) model configuration (OpenAI-compatible)."
},
"letta__schemas__agent_file__AgentSchema": {
"properties": {
"name": {
@@ -40589,6 +40678,9 @@
{
"$ref": "#/components/schemas/XAIModelSettings"
},
{
"$ref": "#/components/schemas/ZAIModelSettings"
},
{
"$ref": "#/components/schemas/GroqModelSettings"
},
@@ -40614,7 +40706,8 @@
"groq": "#/components/schemas/GroqModelSettings",
"openai": "#/components/schemas/OpenAIModelSettings",
"together": "#/components/schemas/TogetherModelSettings",
"xai": "#/components/schemas/XAIModelSettings"
"xai": "#/components/schemas/XAIModelSettings",
"zai": "#/components/schemas/ZAIModelSettings"
}
}
},

View File

@@ -79,6 +79,13 @@ class LLMClient:
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.zai:
from letta.llm_api.zai_client import ZAIClient
return ZAIClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.groq:
from letta.llm_api.groq_client import GroqClient

View File

@@ -0,0 +1,81 @@
import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
class ZAIClient(OpenAIClient):
    """Z.ai (ZhipuAI) client - uses OpenAI-compatible API.

    Inherits all request/response handling from OpenAIClient; only overrides
    provider capability flags and client construction, which reads the Z.ai
    API key from ``model_settings.zai_api_key`` on every call.
    """

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        # Z.ai accepts explicit tool_choice values; no need to force "auto".
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        # Z.ai does not support OpenAI-style structured output (json_schema).
        return False

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """Build the chat-completion request payload via the OpenAI builder.

        Fix: forward ``tool_return_truncation_chars`` to the parent builder.
        The previous implementation accepted the parameter but silently
        dropped it, so tool-return truncation was never applied for Z.ai.
        """
        data = super().build_request_data(
            agent_type,
            messages,
            llm_config,
            tools,
            force_tool_call,
            requires_subsequent_tool_call,
            tool_return_truncation_chars=tool_return_truncation_chars,
        )
        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """Perform a synchronous request to the Z.ai API.

        Returns the raw response as a plain dict.
        """
        client = OpenAI(api_key=model_settings.zai_api_key, base_url=llm_config.model_endpoint)
        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """Perform an asynchronous request to the Z.ai API.

        Returns the raw response as a plain dict.
        """
        client = AsyncOpenAI(api_key=model_settings.zai_api_key, base_url=llm_config.model_endpoint)
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """Perform an asynchronous streaming request to Z.ai.

        Returns the async stream iterator; usage stats are requested in the
        final chunk via ``stream_options={"include_usage": True}``.
        """
        client = AsyncOpenAI(api_key=model_settings.zai_api_key, base_url=llm_config.model_endpoint)
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings for `inputs` using the given embedding config."""
        client = AsyncOpenAI(api_key=model_settings.zai_api_key, base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
        return [r.embedding for r in response.data]

View File

@@ -67,6 +67,7 @@ class ProviderType(str, Enum):
together = "together"
vllm = "vllm"
xai = "xai"
zai = "zai"
class AgentType(str, Enum):

View File

@@ -48,6 +48,7 @@ class LLMConfig(BaseModel):
"bedrock",
"deepseek",
"xai",
"zai",
] = Field(..., description="The endpoint type for the model.")
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
@@ -317,6 +318,7 @@ class LLMConfig(BaseModel):
OpenAIReasoning,
TogetherModelSettings,
XAIModelSettings,
ZAIModelSettings,
)
if self.model_endpoint_type == "openai":
@@ -359,6 +361,11 @@ class LLMConfig(BaseModel):
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
)
elif self.model_endpoint_type == "zai":
return ZAIModelSettings(
max_output_tokens=self.max_tokens or 4096,
temperature=self.temperature,
)
elif self.model_endpoint_type == "groq":
return GroqModelSettings(
max_output_tokens=self.max_tokens or 4096,

View File

@@ -47,6 +47,7 @@ class Model(LLMConfig, ModelBase):
"bedrock",
"deepseek",
"xai",
"zai",
] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True)
context_window: int = Field(
..., description="Deprecated: Use 'max_context_window' field instead. The context window size for the model.", deprecated=True
@@ -131,6 +132,7 @@ class Model(LLMConfig, ModelBase):
ProviderType.google_vertex: GoogleVertexModelSettings,
ProviderType.azure: AzureModelSettings,
ProviderType.xai: XAIModelSettings,
ProviderType.zai: ZAIModelSettings,
ProviderType.groq: GroqModelSettings,
ProviderType.deepseek: DeepseekModelSettings,
ProviderType.together: TogetherModelSettings,
@@ -352,6 +354,22 @@ class XAIModelSettings(ModelSettings):
}
class ZAIModelSettings(ModelSettings):
"""Z.ai (ZhipuAI) model configuration (OpenAI-compatible)."""
provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
temperature: float = Field(0.7, description="The temperature of the model.")
response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")
def _to_legacy_config_params(self) -> dict:
return {
"temperature": self.temperature,
"max_tokens": self.max_output_tokens,
"response_format": self.response_format,
"parallel_tool_calls": self.parallel_tool_calls,
}
class GroqModelSettings(ModelSettings):
"""Groq model configuration (OpenAI-compatible)."""
@@ -424,6 +442,7 @@ ModelSettingsUnion = Annotated[
GoogleVertexModelSettings,
AzureModelSettings,
XAIModelSettings,
ZAIModelSettings,
GroqModelSettings,
DeepseekModelSettings,
TogetherModelSettings,

View File

@@ -18,6 +18,7 @@ from .openrouter import OpenRouterProvider
from .together import TogetherProvider
from .vllm import VLLMProvider
from .xai import XAIProvider
from .zai import ZAIProvider
__all__ = [
# Base classes
@@ -43,5 +44,6 @@ __all__ = [
"TogetherProvider",
"VLLMProvider", # Replaces ChatCompletions and Completions
"XAIProvider",
"ZAIProvider",
"OpenRouterProvider",
]

View File

@@ -196,6 +196,7 @@ class Provider(ProviderBase):
TogetherProvider,
VLLMProvider,
XAIProvider,
ZAIProvider,
)
if self.base_url == "":
@@ -230,6 +231,8 @@ class Provider(ProviderBase):
return CerebrasProvider(**self.model_dump(exclude_none=True))
case ProviderType.xai:
return XAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.zai:
return ZAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.lmstudio_openai:
return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
case ProviderType.bedrock:

View File

@@ -0,0 +1,75 @@
from typing import Literal
from letta.log import get_logger
logger = get_logger(__name__)
from pydantic import Field
from letta.schemas.enums import ProviderCategory, ProviderType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.providers.openai import OpenAIProvider
# Z.ai model context windows
# Reference: https://docs.z.ai/
# Z.ai's /models endpoint does not report context sizes, so these values are
# hardcoded from the published documentation and used as a fallback lookup.
# NOTE(review): verify against current docs when new GLM models ship —
# unknown model names yield None and the model is skipped at listing time.
MODEL_CONTEXT_WINDOWS = {
    "glm-4": 128_000,
    "glm-4.5": 128_000,
    "glm-4.5-air": 128_000,
    "glm-4.6": 128_000,
    "glm-4.6v": 128_000,
    "glm-4-assistant": 128_000,
    "charglm-3": 8_000,
}
class ZAIProvider(OpenAIProvider):
    """Z.ai (ZhipuAI) provider - https://docs.z.ai/"""

    provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str | None = Field(None, description="API key for the Z.ai API.", deprecated=True)
    base_url: str = Field("https://api.z.ai/api/paas/v4/", description="Base URL for the Z.ai API.")

    def get_model_context_window_size(self, model_name: str) -> int | None:
        """Return the context window for `model_name`, or None if unknown.

        Z.ai doesn't return context window in the model listing; this is
        hardcoded from documentation (see MODEL_CONTEXT_WINDOWS above).
        """
        return MODEL_CONTEXT_WINDOWS.get(model_name)

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """Fetch Z.ai's model list and convert each entry to an LLMConfig.

        Entries whose context window cannot be determined (neither reported
        by the API nor present in MODEL_CONTEXT_WINDOWS) are skipped with a
        warning rather than listed with a bogus size.
        """
        from letta.llm_api.openai import openai_get_model_list_async

        api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None
        response = await openai_get_model_list_async(self.base_url, api_key=api_key)
        data = response.get("data", response)

        configs: list[LLMConfig] = []
        for model in data:
            # Robustness fix: the previous implementation used `assert` for
            # validation, which is stripped under `python -O` and aborts the
            # whole listing on one malformed entry; skip bad entries instead.
            model_name = model.get("id") if isinstance(model, dict) else None
            if not model_name:
                logger.warning(f"Z.ai model entry missing 'id' field, skipping: {model}")
                continue

            # Prefer an API-reported context length in case Z.ai starts
            # supporting it in the future; otherwise fall back to the table.
            if "context_length" in model:
                context_window_size = model["context_length"]
            else:
                context_window_size = self.get_model_context_window_size(model_name)

            if not context_window_size:
                logger.warning(f"Couldn't find context window size for model {model_name}")
                continue

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type=self.provider_type.value,
                    model_endpoint=self.base_url,
                    context_window=context_window_size,
                    handle=self.get_handle(model_name),
                    max_tokens=self.get_default_max_output_tokens(model_name),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )
        return configs

View File

@@ -1522,6 +1522,7 @@ async def send_message(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]
@@ -1772,6 +1773,7 @@ async def _process_message_background(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]
@@ -2076,6 +2078,7 @@ async def preview_model_request(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]
@@ -2129,6 +2132,7 @@ async def summarize_messages(
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]

View File

@@ -69,6 +69,7 @@ from letta.schemas.providers import (
TogetherProvider,
VLLMProvider,
XAIProvider,
ZAIProvider,
)
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate
from letta.schemas.secret import Secret
@@ -316,6 +317,14 @@ class SyncServer(object):
api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
)
)
if model_settings.zai_api_key:
self._enabled_providers.append(
ZAIProvider(
name="zai",
api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
base_url=model_settings.zai_base_url,
)
)
if model_settings.openrouter_api_key:
self._enabled_providers.append(
OpenRouterProvider(

View File

@@ -421,6 +421,7 @@ class ProviderManager:
from letta.schemas.providers.groq import GroqProvider
from letta.schemas.providers.ollama import OllamaProvider
from letta.schemas.providers.openai import OpenAIProvider
from letta.schemas.providers.zai import ZAIProvider
provider_type_to_class = {
"openai": OpenAIProvider,
@@ -430,6 +431,7 @@ class ProviderManager:
"ollama": OllamaProvider,
"bedrock": BedrockProvider,
"azure": AzureProvider,
"zai": ZAIProvider,
}
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)

View File

@@ -460,6 +460,7 @@ class StreamingService:
"ollama",
"azure",
"xai",
"zai",
"groq",
"deepseek",
]

View File

@@ -134,6 +134,10 @@ class ModelSettings(BaseSettings):
# xAI / Grok
xai_api_key: Optional[str] = None
# Z.ai (ZhipuAI)
zai_api_key: Optional[str] = None
zai_base_url: str = "https://api.z.ai/api/paas/v4/"
# groq
groq_api_key: Optional[str] = None

View File

@@ -37,6 +37,7 @@ all_configs = [
"openai-gpt-5.json",
"claude-4-5-sonnet.json",
"gemini-2.5-pro.json",
"zai-glm-4.6.json",
]
@@ -206,9 +207,9 @@ def assert_tool_call_response(
# Reasoning is non-deterministic, so don't throw if missing
pass
# Special case for claude-sonnet-4-5-20250929 and opus-4.1 which can generate an extra AssistantMessage before tool call
# Special case for claude-sonnet-4-5-20250929, opus-4.1, and zai which can generate an extra AssistantMessage before tool call
if (
("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle)
("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle or model_settings.get("provider_type") == "zai")
and index < len(messages)
and isinstance(messages[index], AssistantMessage)
):
@@ -436,6 +437,10 @@ def get_expected_message_count_range(
if "claude-opus-4-1" in model_handle:
expected_range += 1
# Z.ai models output an AssistantMessage with each ReasoningMessage (not just the final one)
if model_settings.get("provider_type") == "zai":
expected_range += 1
if tool_call:
# tool call and tool return messages
expected_message_count += 2
@@ -477,8 +482,10 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
is_google_ai_reasoning = (
model_settings.get("provider_type") == "google_ai" and model_settings.get("thinking_config", {}).get("include_thoughts") is True
)
# Z.ai models output reasoning by default
is_zai_reasoning = model_settings.get("provider_type") == "zai"
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning
# ------------------------------

View File

@@ -0,0 +1,9 @@
{
"handle": "zai/glm-4.6",
"model_settings": {
"provider_type": "zai",
"temperature": 0.7,
"max_output_tokens": 4096,
"parallel_tool_calls": false
}
}

View File

@@ -15,6 +15,7 @@ from letta.schemas.providers import (
OpenAIProvider,
TogetherProvider,
VLLMProvider,
ZAIProvider,
)
from letta.schemas.secret import Secret
from letta.settings import model_settings
@@ -104,6 +105,19 @@ async def test_deepseek():
assert models[0].handle == f"{provider.name}/{models[0].model}"
@pytest.mark.skipif(model_settings.zai_api_key is None, reason="Only run if ZAI_API_KEY is set.")
@pytest.mark.asyncio
async def test_zai():
    """Z.ai provider lists at least one model with a well-formed handle."""
    zai = ZAIProvider(
        name="zai",
        api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
        base_url=model_settings.zai_base_url,
    )
    listed = await zai.list_llm_models_async()
    assert len(listed) > 0
    first = listed[0]
    assert first.handle == f"{zai.name}/{first.model}"
@pytest.mark.skipif(model_settings.groq_api_key is None, reason="Only run if GROQ_API_KEY is set.")
@pytest.mark.asyncio
async def test_groq():