feat: add zai provider support (#7626)
* feat: add zai provider support * add zai_api_key secret to deploy-core * add to justfile * add testing, provider integration skill * enable zai key * fix zai test * clean up skill a little * small changes
This commit is contained in:
@@ -21021,6 +21021,9 @@
|
||||
{
|
||||
"$ref": "#/components/schemas/XAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ZAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/GroqModelSettings"
|
||||
},
|
||||
@@ -21046,7 +21049,8 @@
|
||||
"groq": "#/components/schemas/GroqModelSettings",
|
||||
"openai": "#/components/schemas/OpenAIModelSettings",
|
||||
"together": "#/components/schemas/TogetherModelSettings",
|
||||
"xai": "#/components/schemas/XAIModelSettings"
|
||||
"xai": "#/components/schemas/XAIModelSettings",
|
||||
"zai": "#/components/schemas/ZAIModelSettings"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -24789,6 +24793,9 @@
|
||||
{
|
||||
"$ref": "#/components/schemas/XAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ZAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/GroqModelSettings"
|
||||
},
|
||||
@@ -24814,7 +24821,8 @@
|
||||
"groq": "#/components/schemas/GroqModelSettings",
|
||||
"openai": "#/components/schemas/OpenAIModelSettings",
|
||||
"together": "#/components/schemas/TogetherModelSettings",
|
||||
"xai": "#/components/schemas/XAIModelSettings"
|
||||
"xai": "#/components/schemas/XAIModelSettings",
|
||||
"zai": "#/components/schemas/ZAIModelSettings"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -24897,6 +24905,9 @@
|
||||
{
|
||||
"$ref": "#/components/schemas/XAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ZAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/GroqModelSettings"
|
||||
},
|
||||
@@ -24922,7 +24933,8 @@
|
||||
"groq": "#/components/schemas/GroqModelSettings",
|
||||
"openai": "#/components/schemas/OpenAIModelSettings",
|
||||
"together": "#/components/schemas/TogetherModelSettings",
|
||||
"xai": "#/components/schemas/XAIModelSettings"
|
||||
"xai": "#/components/schemas/XAIModelSettings",
|
||||
"zai": "#/components/schemas/ZAIModelSettings"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -25739,6 +25751,9 @@
|
||||
{
|
||||
"$ref": "#/components/schemas/XAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ZAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/GroqModelSettings"
|
||||
},
|
||||
@@ -25764,7 +25779,8 @@
|
||||
"groq": "#/components/schemas/GroqModelSettings",
|
||||
"openai": "#/components/schemas/OpenAIModelSettings",
|
||||
"together": "#/components/schemas/TogetherModelSettings",
|
||||
"xai": "#/components/schemas/XAIModelSettings"
|
||||
"xai": "#/components/schemas/XAIModelSettings",
|
||||
"zai": "#/components/schemas/ZAIModelSettings"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -29972,6 +29988,9 @@
|
||||
{
|
||||
"$ref": "#/components/schemas/XAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ZAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/GroqModelSettings"
|
||||
},
|
||||
@@ -29997,7 +30016,8 @@
|
||||
"groq": "#/components/schemas/GroqModelSettings",
|
||||
"openai": "#/components/schemas/OpenAIModelSettings",
|
||||
"together": "#/components/schemas/TogetherModelSettings",
|
||||
"xai": "#/components/schemas/XAIModelSettings"
|
||||
"xai": "#/components/schemas/XAIModelSettings",
|
||||
"zai": "#/components/schemas/ZAIModelSettings"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -30893,7 +30913,8 @@
|
||||
"together",
|
||||
"bedrock",
|
||||
"deepseek",
|
||||
"xai"
|
||||
"xai",
|
||||
"zai"
|
||||
],
|
||||
"title": "Model Endpoint Type",
|
||||
"description": "The endpoint type for the model."
|
||||
@@ -33149,7 +33170,8 @@
|
||||
"together",
|
||||
"bedrock",
|
||||
"deepseek",
|
||||
"xai"
|
||||
"xai",
|
||||
"zai"
|
||||
],
|
||||
"title": "Model Endpoint Type",
|
||||
"description": "Deprecated: Use 'provider_type' field instead. The endpoint type for the model.",
|
||||
@@ -34596,7 +34618,8 @@
|
||||
"openai",
|
||||
"together",
|
||||
"vllm",
|
||||
"xai"
|
||||
"xai",
|
||||
"zai"
|
||||
],
|
||||
"title": "ProviderType"
|
||||
},
|
||||
@@ -39311,6 +39334,9 @@
|
||||
{
|
||||
"$ref": "#/components/schemas/XAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ZAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/GroqModelSettings"
|
||||
},
|
||||
@@ -39336,7 +39362,8 @@
|
||||
"groq": "#/components/schemas/GroqModelSettings",
|
||||
"openai": "#/components/schemas/OpenAIModelSettings",
|
||||
"together": "#/components/schemas/TogetherModelSettings",
|
||||
"xai": "#/components/schemas/XAIModelSettings"
|
||||
"xai": "#/components/schemas/XAIModelSettings",
|
||||
"zai": "#/components/schemas/ZAIModelSettings"
|
||||
}
|
||||
}
|
||||
},
|
||||
@@ -40262,6 +40289,68 @@
|
||||
"title": "XAIModelSettings",
|
||||
"description": "xAI model configuration (OpenAI-compatible)."
|
||||
},
|
||||
"ZAIModelSettings": {
|
||||
"properties": {
|
||||
"max_output_tokens": {
|
||||
"type": "integer",
|
||||
"title": "Max Output Tokens",
|
||||
"description": "The maximum number of tokens the model can generate.",
|
||||
"default": 4096
|
||||
},
|
||||
"parallel_tool_calls": {
|
||||
"type": "boolean",
|
||||
"title": "Parallel Tool Calls",
|
||||
"description": "Whether to enable parallel tool calling.",
|
||||
"default": false
|
||||
},
|
||||
"provider_type": {
|
||||
"type": "string",
|
||||
"const": "zai",
|
||||
"title": "Provider Type",
|
||||
"description": "The type of the provider.",
|
||||
"default": "zai"
|
||||
},
|
||||
"temperature": {
|
||||
"type": "number",
|
||||
"title": "Temperature",
|
||||
"description": "The temperature of the model.",
|
||||
"default": 0.7
|
||||
},
|
||||
"response_format": {
|
||||
"anyOf": [
|
||||
{
|
||||
"oneOf": [
|
||||
{
|
||||
"$ref": "#/components/schemas/TextResponseFormat"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/JsonSchemaResponseFormat"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/JsonObjectResponseFormat"
|
||||
}
|
||||
],
|
||||
"discriminator": {
|
||||
"propertyName": "type",
|
||||
"mapping": {
|
||||
"json_object": "#/components/schemas/JsonObjectResponseFormat",
|
||||
"json_schema": "#/components/schemas/JsonSchemaResponseFormat",
|
||||
"text": "#/components/schemas/TextResponseFormat"
|
||||
}
|
||||
}
|
||||
},
|
||||
{
|
||||
"type": "null"
|
||||
}
|
||||
],
|
||||
"title": "Response Format",
|
||||
"description": "The response format for the model."
|
||||
}
|
||||
},
|
||||
"type": "object",
|
||||
"title": "ZAIModelSettings",
|
||||
"description": "Z.ai (ZhipuAI) model configuration (OpenAI-compatible)."
|
||||
},
|
||||
"letta__schemas__agent_file__AgentSchema": {
|
||||
"properties": {
|
||||
"name": {
|
||||
@@ -40589,6 +40678,9 @@
|
||||
{
|
||||
"$ref": "#/components/schemas/XAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/ZAIModelSettings"
|
||||
},
|
||||
{
|
||||
"$ref": "#/components/schemas/GroqModelSettings"
|
||||
},
|
||||
@@ -40614,7 +40706,8 @@
|
||||
"groq": "#/components/schemas/GroqModelSettings",
|
||||
"openai": "#/components/schemas/OpenAIModelSettings",
|
||||
"together": "#/components/schemas/TogetherModelSettings",
|
||||
"xai": "#/components/schemas/XAIModelSettings"
|
||||
"xai": "#/components/schemas/XAIModelSettings",
|
||||
"zai": "#/components/schemas/ZAIModelSettings"
|
||||
}
|
||||
}
|
||||
},
|
||||
|
||||
@@ -79,6 +79,13 @@ class LLMClient:
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.zai:
|
||||
from letta.llm_api.zai_client import ZAIClient
|
||||
|
||||
return ZAIClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.groq:
|
||||
from letta.llm_api.groq_client import GroqClient
|
||||
|
||||
|
||||
81
letta/llm_api/zai_client.py
Normal file
81
letta/llm_api/zai_client.py
Normal file
@@ -0,0 +1,81 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import AgentType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class ZAIClient(OpenAIClient):
    """Z.ai (ZhipuAI) client - uses OpenAI-compatible API.

    Z.ai exposes an OpenAI-compatible chat-completions endpoint, so this
    client reuses the ``OpenAIClient`` request machinery and only swaps in
    the Z.ai API key / endpoint plus provider-specific capability flags.
    """

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        # Z.ai accepts an explicit tool_choice; no forced "auto" fallback needed.
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        # Structured output (json_schema response_format) is not enabled for Z.ai.
        return False

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """Build the chat-completions payload via the OpenAI-compatible builder.

        Fix: the previous override accepted ``tool_return_truncation_chars``
        but silently dropped it when delegating, so tool-return truncation
        never reached the underlying builder. Forward it explicitly.
        """
        return super().build_request_data(
            agent_type,
            messages,
            llm_config,
            tools,
            force_tool_call,
            requires_subsequent_tool_call,
            tool_return_truncation_chars=tool_return_truncation_chars,
        )

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to Z.ai API and returns raw response dict.
        """
        api_key = model_settings.zai_api_key
        client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)

        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to Z.ai API and returns raw response dict.
        """
        api_key = model_settings.zai_api_key
        client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)

        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying asynchronous streaming request to Z.ai and returns the async stream iterator.

        ``stream_options={"include_usage": True}`` asks the API to append
        token-usage stats in the final chunk of the stream.
        """
        api_key = model_settings.zai_api_key
        client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config"""
        api_key = model_settings.zai_api_key
        client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        return [r.embedding for r in response.data]
|
||||
@@ -67,6 +67,7 @@ class ProviderType(str, Enum):
|
||||
together = "together"
|
||||
vllm = "vllm"
|
||||
xai = "xai"
|
||||
zai = "zai"
|
||||
|
||||
|
||||
class AgentType(str, Enum):
|
||||
|
||||
@@ -48,6 +48,7 @@ class LLMConfig(BaseModel):
|
||||
"bedrock",
|
||||
"deepseek",
|
||||
"xai",
|
||||
"zai",
|
||||
] = Field(..., description="The endpoint type for the model.")
|
||||
model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.")
|
||||
provider_name: Optional[str] = Field(None, description="The provider name for the model.")
|
||||
@@ -317,6 +318,7 @@ class LLMConfig(BaseModel):
|
||||
OpenAIReasoning,
|
||||
TogetherModelSettings,
|
||||
XAIModelSettings,
|
||||
ZAIModelSettings,
|
||||
)
|
||||
|
||||
if self.model_endpoint_type == "openai":
|
||||
@@ -359,6 +361,11 @@ class LLMConfig(BaseModel):
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "zai":
|
||||
return ZAIModelSettings(
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
temperature=self.temperature,
|
||||
)
|
||||
elif self.model_endpoint_type == "groq":
|
||||
return GroqModelSettings(
|
||||
max_output_tokens=self.max_tokens or 4096,
|
||||
|
||||
@@ -47,6 +47,7 @@ class Model(LLMConfig, ModelBase):
|
||||
"bedrock",
|
||||
"deepseek",
|
||||
"xai",
|
||||
"zai",
|
||||
] = Field(..., description="Deprecated: Use 'provider_type' field instead. The endpoint type for the model.", deprecated=True)
|
||||
context_window: int = Field(
|
||||
..., description="Deprecated: Use 'max_context_window' field instead. The context window size for the model.", deprecated=True
|
||||
@@ -131,6 +132,7 @@ class Model(LLMConfig, ModelBase):
|
||||
ProviderType.google_vertex: GoogleVertexModelSettings,
|
||||
ProviderType.azure: AzureModelSettings,
|
||||
ProviderType.xai: XAIModelSettings,
|
||||
ProviderType.zai: ZAIModelSettings,
|
||||
ProviderType.groq: GroqModelSettings,
|
||||
ProviderType.deepseek: DeepseekModelSettings,
|
||||
ProviderType.together: TogetherModelSettings,
|
||||
@@ -352,6 +354,22 @@ class XAIModelSettings(ModelSettings):
|
||||
}
|
||||
|
||||
|
||||
class ZAIModelSettings(ModelSettings):
    """Z.ai (ZhipuAI) model configuration (OpenAI-compatible)."""

    # Discriminator value for the settings union; always "zai" for this provider.
    provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
    # Sampling temperature forwarded to the chat-completions request (default 0.7).
    temperature: float = Field(0.7, description="The temperature of the model.")
    # Optional response format (text / json_schema / json_object); None means provider default.
    response_format: Optional[ResponseFormatUnion] = Field(None, description="The response format for the model.")

    def _to_legacy_config_params(self) -> dict:
        # Map typed settings back onto legacy LLMConfig keyword names.
        # max_output_tokens and parallel_tool_calls are not declared here, so
        # they come from the ModelSettings base class.
        return {
            "temperature": self.temperature,
            "max_tokens": self.max_output_tokens,
            "response_format": self.response_format,
            "parallel_tool_calls": self.parallel_tool_calls,
        }
|
||||
|
||||
|
||||
class GroqModelSettings(ModelSettings):
|
||||
"""Groq model configuration (OpenAI-compatible)."""
|
||||
|
||||
@@ -424,6 +442,7 @@ ModelSettingsUnion = Annotated[
|
||||
GoogleVertexModelSettings,
|
||||
AzureModelSettings,
|
||||
XAIModelSettings,
|
||||
ZAIModelSettings,
|
||||
GroqModelSettings,
|
||||
DeepseekModelSettings,
|
||||
TogetherModelSettings,
|
||||
|
||||
@@ -18,6 +18,7 @@ from .openrouter import OpenRouterProvider
|
||||
from .together import TogetherProvider
|
||||
from .vllm import VLLMProvider
|
||||
from .xai import XAIProvider
|
||||
from .zai import ZAIProvider
|
||||
|
||||
__all__ = [
|
||||
# Base classes
|
||||
@@ -43,5 +44,6 @@ __all__ = [
|
||||
"TogetherProvider",
|
||||
"VLLMProvider", # Replaces ChatCompletions and Completions
|
||||
"XAIProvider",
|
||||
"ZAIProvider",
|
||||
"OpenRouterProvider",
|
||||
]
|
||||
|
||||
@@ -196,6 +196,7 @@ class Provider(ProviderBase):
|
||||
TogetherProvider,
|
||||
VLLMProvider,
|
||||
XAIProvider,
|
||||
ZAIProvider,
|
||||
)
|
||||
|
||||
if self.base_url == "":
|
||||
@@ -230,6 +231,8 @@ class Provider(ProviderBase):
|
||||
return CerebrasProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.xai:
|
||||
return XAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.zai:
|
||||
return ZAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.lmstudio_openai:
|
||||
return LMStudioOpenAIProvider(**self.model_dump(exclude_none=True))
|
||||
case ProviderType.bedrock:
|
||||
|
||||
75
letta/schemas/providers/zai.py
Normal file
75
letta/schemas/providers/zai.py
Normal file
@@ -0,0 +1,75 @@
|
||||
from typing import Literal
|
||||
|
||||
from letta.log import get_logger
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
from pydantic import Field
|
||||
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.providers.openai import OpenAIProvider
|
||||
|
||||
# Z.ai model context windows
# Reference: https://docs.z.ai/
# NOTE: the Z.ai model-listing endpoint does not report context windows, so
# these values are maintained by hand from the published docs; models absent
# from this table are skipped during listing unless the API starts returning
# a "context_length" field.
MODEL_CONTEXT_WINDOWS = {
    "glm-4": 128_000,
    "glm-4.5": 128_000,
    "glm-4.5-air": 128_000,
    "glm-4.6": 128_000,
    "glm-4.6v": 128_000,
    "glm-4-assistant": 128_000,
    "charglm-3": 8_000,
}
|
||||
|
||||
|
||||
class ZAIProvider(OpenAIProvider):
    """Z.ai (ZhipuAI) provider - https://docs.z.ai/"""

    provider_type: Literal[ProviderType.zai] = Field(ProviderType.zai, description="The type of the provider.")
    provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
    api_key: str | None = Field(None, description="API key for the Z.ai API.", deprecated=True)
    base_url: str = Field("https://api.z.ai/api/paas/v4/", description="Base URL for the Z.ai API.")

    def get_model_context_window_size(self, model_name: str) -> int | None:
        """Look up the context window for *model_name* in the static table.

        Z.ai doesn't return context window in the model listing, so this is
        hardcoded from documentation; unknown models yield ``None``.
        """
        return MODEL_CONTEXT_WINDOWS.get(model_name)

    async def list_llm_models_async(self) -> list[LLMConfig]:
        """Fetch the Z.ai model listing and convert each entry to ``LLMConfig``.

        Entries whose context window cannot be determined are skipped with a
        warning rather than raising.
        """
        from letta.llm_api.openai import openai_get_model_list_async

        plaintext_key = None
        if self.api_key_enc:
            plaintext_key = await self.api_key_enc.get_plaintext_async()
        listing = await openai_get_model_list_async(self.base_url, api_key=plaintext_key)

        # The response may be {"data": [...]} or a bare list of entries.
        entries = listing.get("data", listing)

        llm_configs: list[LLMConfig] = []
        for model in entries:
            assert "id" in model, f"Z.ai model missing 'id' field: {model}"
            model_name = model["id"]

            # Prefer a server-reported context length should Z.ai ever add
            # one; otherwise fall back to the hardcoded table.
            if "context_length" in model:
                window = model["context_length"]
            else:
                window = self.get_model_context_window_size(model_name)
            if not window:
                logger.warning(f"Couldn't find context window size for model {model_name}")
                continue

            llm_configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type=self.provider_type.value,
                    model_endpoint=self.base_url,
                    context_window=window,
                    handle=self.get_handle(model_name),
                    max_tokens=self.get_default_max_output_tokens(model_name),
                    provider_name=self.name,
                    provider_category=self.provider_category,
                )
            )

        return llm_configs
|
||||
@@ -1522,6 +1522,7 @@ async def send_message(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
@@ -1772,6 +1773,7 @@ async def _process_message_background(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
@@ -2076,6 +2078,7 @@ async def preview_model_request(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
@@ -2129,6 +2132,7 @@ async def summarize_messages(
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
|
||||
@@ -69,6 +69,7 @@ from letta.schemas.providers import (
|
||||
TogetherProvider,
|
||||
VLLMProvider,
|
||||
XAIProvider,
|
||||
ZAIProvider,
|
||||
)
|
||||
from letta.schemas.sandbox_config import LocalSandboxConfig, SandboxConfigCreate
|
||||
from letta.schemas.secret import Secret
|
||||
@@ -316,6 +317,14 @@ class SyncServer(object):
|
||||
api_key_enc=Secret.from_plaintext(model_settings.xai_api_key),
|
||||
)
|
||||
)
|
||||
if model_settings.zai_api_key:
|
||||
self._enabled_providers.append(
|
||||
ZAIProvider(
|
||||
name="zai",
|
||||
api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
|
||||
base_url=model_settings.zai_base_url,
|
||||
)
|
||||
)
|
||||
if model_settings.openrouter_api_key:
|
||||
self._enabled_providers.append(
|
||||
OpenRouterProvider(
|
||||
|
||||
@@ -421,6 +421,7 @@ class ProviderManager:
|
||||
from letta.schemas.providers.groq import GroqProvider
|
||||
from letta.schemas.providers.ollama import OllamaProvider
|
||||
from letta.schemas.providers.openai import OpenAIProvider
|
||||
from letta.schemas.providers.zai import ZAIProvider
|
||||
|
||||
provider_type_to_class = {
|
||||
"openai": OpenAIProvider,
|
||||
@@ -430,6 +431,7 @@ class ProviderManager:
|
||||
"ollama": OllamaProvider,
|
||||
"bedrock": BedrockProvider,
|
||||
"azure": AzureProvider,
|
||||
"zai": ZAIProvider,
|
||||
}
|
||||
|
||||
provider_type = provider.provider_type.value if hasattr(provider.provider_type, "value") else str(provider.provider_type)
|
||||
|
||||
@@ -460,6 +460,7 @@ class StreamingService:
|
||||
"ollama",
|
||||
"azure",
|
||||
"xai",
|
||||
"zai",
|
||||
"groq",
|
||||
"deepseek",
|
||||
]
|
||||
|
||||
@@ -134,6 +134,10 @@ class ModelSettings(BaseSettings):
|
||||
# xAI / Grok
|
||||
xai_api_key: Optional[str] = None
|
||||
|
||||
# Z.ai (ZhipuAI)
|
||||
zai_api_key: Optional[str] = None
|
||||
zai_base_url: str = "https://api.z.ai/api/paas/v4/"
|
||||
|
||||
# groq
|
||||
groq_api_key: Optional[str] = None
|
||||
|
||||
|
||||
@@ -37,6 +37,7 @@ all_configs = [
|
||||
"openai-gpt-5.json",
|
||||
"claude-4-5-sonnet.json",
|
||||
"gemini-2.5-pro.json",
|
||||
"zai-glm-4.6.json",
|
||||
]
|
||||
|
||||
|
||||
@@ -206,9 +207,9 @@ def assert_tool_call_response(
|
||||
# Reasoning is non-deterministic, so don't throw if missing
|
||||
pass
|
||||
|
||||
# Special case for claude-sonnet-4-5-20250929 and opus-4.1 which can generate an extra AssistantMessage before tool call
|
||||
# Special case for claude-sonnet-4-5-20250929, opus-4.1, and zai which can generate an extra AssistantMessage before tool call
|
||||
if (
|
||||
("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle)
|
||||
("claude-sonnet-4-5-20250929" in model_handle or "claude-opus-4-1" in model_handle or model_settings.get("provider_type") == "zai")
|
||||
and index < len(messages)
|
||||
and isinstance(messages[index], AssistantMessage)
|
||||
):
|
||||
@@ -436,6 +437,10 @@ def get_expected_message_count_range(
|
||||
if "claude-opus-4-1" in model_handle:
|
||||
expected_range += 1
|
||||
|
||||
# Z.ai models output an AssistantMessage with each ReasoningMessage (not just the final one)
|
||||
if model_settings.get("provider_type") == "zai":
|
||||
expected_range += 1
|
||||
|
||||
if tool_call:
|
||||
# tool call and tool return messages
|
||||
expected_message_count += 2
|
||||
@@ -477,8 +482,10 @@ def is_reasoner_model(model_handle: str, model_settings: dict) -> bool:
|
||||
is_google_ai_reasoning = (
|
||||
model_settings.get("provider_type") == "google_ai" and model_settings.get("thinking_config", {}).get("include_thoughts") is True
|
||||
)
|
||||
# Z.ai models output reasoning by default
|
||||
is_zai_reasoning = model_settings.get("provider_type") == "zai"
|
||||
|
||||
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning
|
||||
return is_openai_reasoning or is_anthropic_reasoning or is_google_vertex_reasoning or is_google_ai_reasoning or is_zai_reasoning
|
||||
|
||||
|
||||
# ------------------------------
|
||||
|
||||
9
tests/model_settings/zai-glm-4.6.json
Normal file
9
tests/model_settings/zai-glm-4.6.json
Normal file
@@ -0,0 +1,9 @@
|
||||
{
|
||||
"handle": "zai/glm-4.6",
|
||||
"model_settings": {
|
||||
"provider_type": "zai",
|
||||
"temperature": 0.7,
|
||||
"max_output_tokens": 4096,
|
||||
"parallel_tool_calls": false
|
||||
}
|
||||
}
|
||||
@@ -15,6 +15,7 @@ from letta.schemas.providers import (
|
||||
OpenAIProvider,
|
||||
TogetherProvider,
|
||||
VLLMProvider,
|
||||
ZAIProvider,
|
||||
)
|
||||
from letta.schemas.secret import Secret
|
||||
from letta.settings import model_settings
|
||||
@@ -104,6 +105,19 @@ async def test_deepseek():
|
||||
assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(model_settings.zai_api_key is None, reason="Only run if ZAI_API_KEY is set.")
@pytest.mark.asyncio
async def test_zai():
    """Smoke-test ZAIProvider model listing against the live Z.ai API.

    Requires ZAI_API_KEY to be configured; asserts at least one model is
    returned and that handles follow the "<provider name>/<model>" convention.
    """
    provider = ZAIProvider(
        name="zai",
        api_key_enc=Secret.from_plaintext(model_settings.zai_api_key),
        base_url=model_settings.zai_base_url,
    )
    models = await provider.list_llm_models_async()
    assert len(models) > 0
    assert models[0].handle == f"{provider.name}/{models[0].model}"
|
||||
|
||||
|
||||
@pytest.mark.skipif(model_settings.groq_api_key is None, reason="Only run if GROQ_API_KEY is set.")
|
||||
@pytest.mark.asyncio
|
||||
async def test_groq():
|
||||
|
||||
Reference in New Issue
Block a user