From cd45212acbdb8c2ca106cfafa47a24671fbe566f Mon Sep 17 00:00:00 2001 From: Ari Webb Date: Fri, 19 Dec 2025 16:05:41 -0800 Subject: [PATCH] feat: add zai provider support (#7626) * feat: add zai provider support * add zai_api_key secret to deploy-core * add to justfile * add testing, provider integration skill * enable zai key * fix zai test * clean up skill a little * small changes --- fern/openapi.json | 113 +++++++++++++++++++-- letta/llm_api/llm_client.py | 7 ++ letta/llm_api/zai_client.py | 81 +++++++++++++++ letta/schemas/enums.py | 1 + letta/schemas/llm_config.py | 7 ++ letta/schemas/model.py | 19 ++++ letta/schemas/providers/__init__.py | 2 + letta/schemas/providers/base.py | 3 + letta/schemas/providers/zai.py | 75 ++++++++++++++ letta/server/rest_api/routers/v1/agents.py | 4 + letta/server/server.py | 9 ++ letta/services/provider_manager.py | 2 + letta/services/streaming_service.py | 1 + letta/settings.py | 4 + tests/integration_test_send_message_v2.py | 13 ++- tests/model_settings/zai-glm-4.6.json | 9 ++ tests/test_providers.py | 14 +++ 17 files changed, 351 insertions(+), 13 deletions(-) create mode 100644 letta/llm_api/zai_client.py create mode 100644 letta/schemas/providers/zai.py create mode 100644 tests/model_settings/zai-glm-4.6.json diff --git a/fern/openapi.json b/fern/openapi.json index 26a2ecd6..95c49b16 100644 --- a/fern/openapi.json +++ b/fern/openapi.json @@ -21021,6 +21021,9 @@ { "$ref": "#/components/schemas/XAIModelSettings" }, + { + "$ref": "#/components/schemas/ZAIModelSettings" + }, { "$ref": "#/components/schemas/GroqModelSettings" }, @@ -21046,7 +21049,8 @@ "groq": "#/components/schemas/GroqModelSettings", "openai": "#/components/schemas/OpenAIModelSettings", "together": "#/components/schemas/TogetherModelSettings", - "xai": "#/components/schemas/XAIModelSettings" + "xai": "#/components/schemas/XAIModelSettings", + "zai": "#/components/schemas/ZAIModelSettings" } } }, @@ -24789,6 +24793,9 @@ { "$ref": "#/components/schemas/XAIModelSettings" }, + { + "$ref": "#/components/schemas/ZAIModelSettings" + }, { "$ref": "#/components/schemas/GroqModelSettings" }, @@ -24814,7 +24821,8 @@ "groq": "#/components/schemas/GroqModelSettings", "openai": "#/components/schemas/OpenAIModelSettings", "together": "#/components/schemas/TogetherModelSettings", - "xai": "#/components/schemas/XAIModelSettings" + "xai": "#/components/schemas/XAIModelSettings", + "zai": "#/components/schemas/ZAIModelSettings" } } }, @@ -24897,6 +24905,9 @@ { "$ref": "#/components/schemas/XAIModelSettings" }, + { + "$ref": "#/components/schemas/ZAIModelSettings" + }, { "$ref": "#/components/schemas/GroqModelSettings" }, @@ -24922,7 +24933,8 @@ "groq": "#/components/schemas/GroqModelSettings", "openai": "#/components/schemas/OpenAIModelSettings", "together": "#/components/schemas/TogetherModelSettings", - "xai": "#/components/schemas/XAIModelSettings" + "xai": "#/components/schemas/XAIModelSettings", + "zai": "#/components/schemas/ZAIModelSettings" } } }, @@ -25739,6 +25751,9 @@ { "$ref": "#/components/schemas/XAIModelSettings" }, + { + "$ref": "#/components/schemas/ZAIModelSettings" + }, { "$ref": "#/components/schemas/GroqModelSettings" }, @@ -25764,7 +25779,8 @@ "groq": "#/components/schemas/GroqModelSettings", "openai": "#/components/schemas/OpenAIModelSettings", "together": "#/components/schemas/TogetherModelSettings", - "xai": "#/components/schemas/XAIModelSettings" + "xai": "#/components/schemas/XAIModelSettings", + "zai": "#/components/schemas/ZAIModelSettings" } } }, @@ -29972,6 +29988,9 @@ { "$ref": "#/components/schemas/XAIModelSettings" }, + { + "$ref": "#/components/schemas/ZAIModelSettings" + }, { "$ref": "#/components/schemas/GroqModelSettings" }, @@ -29997,7 +30016,8 @@ "groq": "#/components/schemas/GroqModelSettings", "openai": "#/components/schemas/OpenAIModelSettings", "together": "#/components/schemas/TogetherModelSettings", - "xai": "#/components/schemas/XAIModelSettings" + "xai": "#/components/schemas/XAIModelSettings", + "zai": "#/components/schemas/ZAIModelSettings" } } }, @@ -30893,7 +30913,8 @@ "together", "bedrock", "deepseek", - "xai" + "xai", + "zai" ], "title": "Model Endpoint Type", "description": "The endpoint type for the model." @@ -33149,7 +33170,8 @@ "together", "bedrock", "deepseek", - "xai" + "xai", + "zai" ], "title": "Model Endpoint Type", "description": "Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", @@ -34596,7 +34618,8 @@ "openai", "together", "vllm", - "xai" + "xai", + "zai" ], "title": "ProviderType" }, @@ -39311,6 +39334,9 @@ { "$ref": "#/components/schemas/XAIModelSettings" }, + { + "$ref": "#/components/schemas/ZAIModelSettings" + }, { "$ref": "#/components/schemas/GroqModelSettings" }, @@ -39336,7 +39362,8 @@ "groq": "#/components/schemas/GroqModelSettings", "openai": "#/components/schemas/OpenAIModelSettings", "together": "#/components/schemas/TogetherModelSettings", - "xai": "#/components/schemas/XAIModelSettings" + "xai": "#/components/schemas/XAIModelSettings", + "zai": "#/components/schemas/ZAIModelSettings" } } }, @@ -40262,6 +40289,68 @@ "title": "XAIModelSettings", "description": "xAI model configuration (OpenAI-compatible)." }, + "ZAIModelSettings": { + "properties": { + "max_output_tokens": { + "type": "integer", + "title": "Max Output Tokens", + "description": "The maximum number of tokens the model can generate.", + "default": 4096 + }, + "parallel_tool_calls": { + "type": "boolean", + "title": "Parallel Tool Calls", + "description": "Whether to enable parallel tool calling.", + "default": false + }, + "provider_type": { + "type": "string", + "const": "zai", + "title": "Provider Type", + "description": "The type of the provider.", + "default": "zai" + }, + "temperature": { + "type": "number", + "title": "Temperature", + "description": "The temperature of the model.", + "default": 0.7 + }, + "response_format": { + "anyOf": [ + { + "oneOf": [ + { + "$ref": "#/components/schemas/TextResponseFormat" + }, + { + "$ref": "#/components/schemas/JsonSchemaResponseFormat" + }, + { + "$ref": "#/components/schemas/JsonObjectResponseFormat" + } + ], + "discriminator": { + "propertyName": "type", + "mapping": { + "json_object": "#/components/schemas/JsonObjectResponseFormat", + "json_schema": "#/components/schemas/JsonSchemaResponseFormat", + "text": "#/components/schemas/TextResponseFormat" + } + } + }, + { + "type": "null" + } + ], + "title": "Response Format", + "description": "The response format for the model." + } + }, + "type": "object", + "title": "ZAIModelSettings", + "description": "Z.ai (ZhipuAI) model configuration (OpenAI-compatible)." + }, "letta__schemas__agent_file__AgentSchema": { "properties": { "name": { @@ -40589,6 +40678,9 @@ { "$ref": "#/components/schemas/XAIModelSettings" }, + { + "$ref": "#/components/schemas/ZAIModelSettings" + }, { "$ref": "#/components/schemas/GroqModelSettings" }, @@ -40614,7 +40706,8 @@ "groq": "#/components/schemas/GroqModelSettings", "openai": "#/components/schemas/OpenAIModelSettings", "together": "#/components/schemas/TogetherModelSettings", - "xai": "#/components/schemas/XAIModelSettings" + "xai": "#/components/schemas/XAIModelSettings", + "zai": "#/components/schemas/ZAIModelSettings" } } }, diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index d778b319..264d7e2f 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -79,6 +79,13 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) + case ProviderType.zai: + from letta.llm_api.zai_client import ZAIClient + + return ZAIClient( + put_inner_thoughts_first=put_inner_thoughts_first, + actor=actor, + ) case ProviderType.groq: from letta.llm_api.groq_client import GroqClient diff --git a/letta/llm_api/zai_client.py b/letta/llm_api/zai_client.py new file mode 100644 index 00000000..9eec79c2 --- /dev/null +++ b/letta/llm_api/zai_client.py @@ -0,0 +1,81 @@ +import os +from typing import List, Optional + +from openai import AsyncOpenAI, AsyncStream, OpenAI +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk + +from letta.llm_api.openai_client import OpenAIClient +from letta.otel.tracing import trace_method +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.enums import AgentType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.message import Message as PydanticMessage +from letta.settings import model_settings + + +class ZAIClient(OpenAIClient): + """Z.ai (ZhipuAI) client - uses OpenAI-compatible API.""" + + def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool: + return False + + def supports_structured_output(self, llm_config: LLMConfig) -> bool: + return False + + @trace_method + def build_request_data( + self, + agent_type: AgentType, + messages: List[PydanticMessage], + llm_config: LLMConfig, + tools: Optional[List[dict]] = None, + force_tool_call: Optional[str] = None, + requires_subsequent_tool_call: bool = False, + tool_return_truncation_chars: Optional[int] = None, + ) -> dict: + data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call) + return data + + @trace_method + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: + """ + Performs underlying synchronous request to Z.ai API and returns raw response dict. + """ + api_key = model_settings.zai_api_key + client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint) + + response: ChatCompletion = client.chat.completions.create(**request_data) + return response.model_dump() + + @trace_method + async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: + """ + Performs underlying asynchronous request to Z.ai API and returns raw response dict. + """ + api_key = model_settings.zai_api_key + client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint) + + response: ChatCompletion = await client.chat.completions.create(**request_data) + return response.model_dump() + + @trace_method + async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]: + """ + Performs underlying asynchronous streaming request to Z.ai and returns the async stream iterator. + """ + api_key = model_settings.zai_api_key + client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint) + response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create( + **request_data, stream=True, stream_options={"include_usage": True} + ) + return response_stream + + @trace_method + async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]: + """Request embeddings given texts and embedding config""" + api_key = model_settings.zai_api_key + client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint) + response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs) + + return [r.embedding for r in response.data] diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index da4dc27f..e0a697b9 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -67,6 +67,7 @@ class ProviderType(str, Enum): together = "together" vllm = "vllm" xai = "xai" + zai = "zai" class AgentType(str, Enum): diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 26db755f..3f1c484d 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -48,6 +48,7 @@ class LLMConfig(BaseModel): "bedrock", "deepseek", "xai", + "zai", ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") provider_name: Optional[str] = Field(None, description="The provider name for the model.") @@ -317,6 +318,7 @@ class LLMConfig(BaseModel): OpenAIReasoning, TogetherModelSettings, XAIModelSettings, + ZAIModelSettings, ) if self.model_endpoint_type == "openai": @@ -359,6 +361,11 @@ class LLMConfig(BaseModel): max_output_tokens=self.max_tokens or 4096, temperature=self.temperature, ) + elif self.model_endpoint_type == "zai": + return ZAIModelSettings( + max_output_tokens=self.max_tokens or 4096, + temperature=self.temperature, + ) elif self.model_endpoint_type == "groq": return GroqModelSettings( max_output_tokens=self.max_tokens or 4096, diff --git a/letta/schemas/model.py b/letta/schemas/model.py index d2aa6a7f..daf3291e 100644 --- a/letta/schemas/model.py +++ b/letta/schemas/model.py @@ -47,6 +47,7 @@ class Model(LLMConfig, ModelBase): "bedrock", "deepseek", "xai", + "zai", ] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True) context_window: int = Field( ..., description="Deprecated: Use 'max_context_window' field instead. The context window size for the model.", deprecated=True @@ -131,6 +132,7 @@ class Model(LLMConfig, ModelBase): ProviderType.google_vertex: GoogleVertexModelSettings, ProviderType.azure: AzureModelSettings, ProviderType.xai: XAIModelSettings, + ProviderType.zai: ZAIModelSettings, ProviderType.groq: GroqModelSettings, ProviderType.deepseek: DeepseekModelSettings, ProviderType.together: TogetherModelSettings, @@ -352,6 +354,22 @@ class XAIModelSettings(ModelSettings): } +class ZAIModelSettings(ModelSettings): + """Z.ai (ZhipuAI) model configuration (OpenAI-compatible).""" + + provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.") + temperature: float = Field(0.7, description="The temperature of the model.") + response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.") + + def _to_legacy_config_params(self) -> dict: + return { + "temperature": self.temperature, + "max_tokens": self.max_output_tokens, + "response_format": self.response_format, + "parallel_tool_calls": self.parallel_tool_calls, + } + + class GroqModelSettings(ModelSettings): """Groq model configuration (OpenAI-compatible).""" @@ -424,6 +442,7 @@ ModelSettingsUnion = Annotated[ GoogleVertexModelSettings, AzureModelSettings, XAIModelSettings, + ZAIModelSettings, GroqModelSettings, DeepseekModelSettings, TogetherModelSettings, diff --git a/letta/schemas/providers/__init__.py b/letta/schemas/providers/__init__.py index 2d6fd735..8486ee51 100644 --- a/letta/schemas/providers/__init__.py +++ b/letta/schemas/providers/__init__.py @@ -18,6 +18,7 @@ from .openrouter import OpenRouterProvider from .together import TogetherProvider from .vllm import VLLMProvider from .xai import XAIProvider +from .zai import ZAIProvider __all__ = [ # Base classes @@ -43,5 +44,6 @@ __all__ = [ "TogetherProvider", "VLLMProvider", # Replaces ChatCompletions and Completions "XAIProvider", + "ZAIProvider", "OpenRouterProvider", ] diff --git a/letta/schemas/providers/base.py b/letta/schemas/providers/base.py index d6ca1df2..1f11e8cd 100644 --- a/letta/schemas/providers/base.py +++ b/letta/schemas/providers/base.py @@ -196,6 +196,7 @@ class Provider(ProviderBase): TogetherProvider, VLLMProvider, XAIProvider, + ZAIProvider, ) if self.base_url == "": @@ -230,6 +231,8 @@ class Provider(ProviderBase): return CerebrasProvider(**self.model_dump(exclude_none=True)) case ProviderType.xai: return XAIProvider(**self.model_dump(exclude_none=True)) + case ProviderType.zai: + return ZAIProvider(**self.model_dump(exclude_none=True)) case ProviderType.lmstudio_openai: return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True)) case ProviderType.bedrock: diff --git a/letta/schemas/providers/zai.py b/letta/schemas/providers/zai.py new file mode 100644 index 00000000..28866f6a --- /dev/null +++ b/letta/schemas/providers/zai.py @@ -0,0 +1,75 @@ +from typing import Literal + +from letta.log import get_logger + +logger = get_logger(__name__) + +from pydantic import Field + +from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.llm_config import LLMConfig +from letta.schemas.providers.openai import OpenAIProvider + +# Z.ai model context windows +# Reference: https://docs.z.ai/ +MODEL_CONTEXT_WINDOWS = { + "glm-4": 128_000, + "glm-4.5": 128_000, + "glm-4.5-air": 128_000, + "glm-4.6": 128_000, + "glm-4.6v": 128_000, + "glm-4-assistant": 128_000, + "charglm-3": 8_000, +} + + +class ZAIProvider(OpenAIProvider): + """Z.ai (ZhipuAI) provider - https://docs.z.ai/""" + + provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.") + provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)") + api_key: str | None = Field(None, description="API key for the Z.ai API.", deprecated=True) + base_url: str = Field("https://api.z.ai/api/paas/v4/", description="Base URL for the Z.ai API.") + + def get_model_context_window_size(self, model_name: str) -> int | None: + # Z.ai doesn't return context window in the model listing, + # this is hardcoded from documentation + return MODEL_CONTEXT_WINDOWS.get(model_name) + + async def list_llm_models_async(self) -> list[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list_async + + api_key = await self.api_key_enc.get_plaintext_async() if self.api_key_enc else None + response = await openai_get_model_list_async(self.base_url, api_key=api_key) + + data = response.get("data", response) + + configs = [] + for model in data: + assert "id" in model, f"Z.ai model missing 'id' field: {model}" + model_name = model["id"] + + # In case Z.ai starts supporting it in the future: + if "context_length" in model: + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + logger.warning(f"Couldn't find context window size for model {model_name}") + continue + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type=self.provider_type.value, + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + max_tokens=self.get_default_max_output_tokens(model_name), + provider_name=self.name, + provider_category=self.provider_category, + ) + ) + + return configs diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 25b784dd..8e2b43a3 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1522,6 +1522,7 @@ async def send_message( "ollama", "azure", "xai", + "zai", "groq", "deepseek", ] @@ -1772,6 +1773,7 @@ async def _process_message_background( "ollama", "azure", "xai", + "zai", "groq", "deepseek", ] @@ -2076,6 +2078,7 @@ async def preview_model_request( "ollama", "azure", "xai", + "zai", "groq", "deepseek", ] @@ -2129,6 +2132,7 @@ async def summarize_messages( "ollama", "azure", "xai", + "zai", "groq", "deepseek", ] diff --git a/letta/server/server.py b/letta/server/server.py index 63e40785..df31a139 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -69,6 +69,7 @@ from letta.schemas.providers import ( TogetherProvider, VLLMProvider, XAIProvider, + ZAIProvider, ) from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate from letta.schemas.secret import Secret @@ -316,6 +317,14 @@ class SyncServer(object): api_key_enc=Secret.from_plaintext(model_settings.xai_api_key), ) ) + if model_settings.zai_api_key: + self._enabled_providers.append( + ZAIProvider( + name="zai", + api_key_enc=Secret.from_plaintext(model_settings.zai_api_key), + base_url=model_settings.zai_base_url, + ) + ) if model_settings.openrouter_api_key: self._enabled_providers.append( OpenRouterProvider( diff --git a/letta/services/provider_manager.py b/letta/services/provider_manager.py index 7b1d685c..26384c4b 100644 --- a/letta/services/provider_manager.py +++ b/letta/services/provider_manager.py @@ -421,6 +421,7 @@ class ProviderManager: from letta.schemas.providers.groq import GroqProvider from letta.schemas.providers.ollama import OllamaProvider from letta.schemas.providers.openai import OpenAIProvider + from letta.schemas.providers.zai import ZAIProvider provider_type_to_class = { "openai": OpenAIProvider, @@ -430,6 +431,7 @@ class ProviderManager: "ollama": OllamaProvider, "bedrock": BedrockProvider, "azure": AzureProvider, + "zai": ZAIProvider, } provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type) diff --git a/letta/services/streaming_service.py b/letta/services/streaming_service.py index ec8bf2d6..b589ca46 100644 --- a/letta/services/streaming_service.py +++ b/letta/services/streaming_service.py @@ -460,6 +460,7 @@ class StreamingService: "ollama", "azure", "xai", + "zai", "groq", "deepseek", ] diff --git a/letta/settings.py b/letta/settings.py index be2b6185..b769f1e2 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -134,6 +134,10 @@ class ModelSettings(BaseSettings): # xAI / Grok xai_api_key: Optional[str] = None + # Z.ai (ZhipuAI) + zai_api_key: Optional[str] = None + zai_base_url: str = "https://api.z.ai/api/paas/v4/" + # groq groq_api_key: Optional[str] = None diff --git a/tests/integration_test_send_message_v2.py b/tests/integration_test_send_message_v2.py index 2087e818..d5e47ece 100644 --- a/tests/integration_test_send_message_v2.py +++ b/tests/integration_test_send_message_v2.py @@ -37,6 +37,7 @@ all_configs = [ "openai-gpt-5.json", "claude-4-5-sonnet.json", "gemini-2.5-pro.json", + "zai-glm-4.6.json", ] @@ -206,9 +207,9 @@ def assert_tool_call_response( # Reasoning is non-deterministic, so don't throw if missing pass - # Special case for claude-sonnet-4-5-20250929 and opus-4.1 which can generate an extra AssistantMessage before tool call + # Special case for claude-sonnet-4-5-20250929, opus-4.1, and zai which can generate an extra AssistantMessage before tool call if ( - ("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle) + ("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle or model_settings.get("provider_type") == "zai") and index < len(messages) and isinstance(messages[index], AssistantMessage) ): @@ -436,6 +437,10 @@ def get_expected_message_count_range( if "claude-opus-4-1" in model_handle: expected_range += 1 + # Z.ai models output an AssistantMessage with each ReasoningMessage (not just the final one) + if model_settings.get("provider_type") == "zai": + expected_range += 1 + if tool_call: # tool call and tool return messages expected_message_count += 2 @@ -477,8 +482,10 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool: is_google_ai_reasoning = ( model_settings.get("provider_type") == "google_ai" and model_settings.get("thinking_config", {}).get("include_thoughts") is True ) + # Z.ai models output reasoning by default + is_zai_reasoning = model_settings.get("provider_type") == "zai" - return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning + return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning # ------------------------------ diff --git a/tests/model_settings/zai-glm-4.6.json b/tests/model_settings/zai-glm-4.6.json new file mode 100644 index 00000000..7e03c966 --- /dev/null +++ b/tests/model_settings/zai-glm-4.6.json @@ -0,0 +1,9 @@ +{ + "handle": "zai/glm-4.6", + "model_settings": { + "provider_type": "zai", + "temperature": 0.7, + "max_output_tokens": 4096, + "parallel_tool_calls": false + } +} diff --git a/tests/test_providers.py b/tests/test_providers.py index a23b49cc..aebd2b93 100644 --- a/tests/test_providers.py +++ b/tests/test_providers.py @@ -15,6 +15,7 @@ from letta.schemas.providers import ( OpenAIProvider, TogetherProvider, VLLMProvider, + ZAIProvider, ) from letta.schemas.secret import Secret from letta.settings import model_settings @@ -104,6 +105,19 @@ async def test_deepseek(): assert models[0].handle == f"{provider.name}/{models[0].model}" +@pytest.mark.skipif(model_settings.zai_api_key is None, reason="Only run if ZAI_API_KEY is set.") +@pytest.mark.asyncio +async def test_zai(): + provider = ZAIProvider( + name="zai", + api_key_enc=Secret.from_plaintext(model_settings.zai_api_key), + base_url=model_settings.zai_base_url, + ) + models = await provider.list_llm_models_async() + assert len(models) > 0 + assert models[0].handle == f"{provider.name}/{models[0].model}" + + @pytest.mark.skipif(model_settings.groq_api_key is None, reason="Only run if GROQ_API_KEY is set.") @pytest.mark.asyncio async def test_groq():