fix: patch grok-3 and grok-3-fast (skip reasoners for now) (#1703)
@@ -247,6 +247,13 @@ def create(
             use_structured_output=False,  # NOTE: not supported atm for xAI
         )

+        # Specific bug for the mini models (as of Apr 14, 2025)
+        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
+        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
+        if "grok-3-mini-" in llm_config.model:
+            data.presence_penalty = None
+            data.frequency_penalty = None
+
     if stream:  # Client requested token streaming
         data.stream = True
         assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
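The hunk above works around xAI rejecting OpenAI sampling parameters on the mini models. A minimal, self-contained sketch of the same guard, where `ChatCompletionRequest` is a hypothetical stand-in for letta's actual request object:

```python
from dataclasses import dataclass
from typing import Optional

@dataclass
class ChatCompletionRequest:
    # Hypothetical stand-in for letta's request schema
    model: str
    presence_penalty: Optional[float] = 0.0
    frequency_penalty: Optional[float] = 0.0

def strip_unsupported_penalties(data: ChatCompletionRequest) -> ChatCompletionRequest:
    # grok-3-mini-* returns 400 for presencePenalty/frequencyPenalty,
    # so null both fields before the request is serialized
    if "grok-3-mini-" in data.model:
        data.presence_penalty = None
        data.frequency_penalty = None
    return data

# e.g. strip_unsupported_penalties(ChatCompletionRequest(model="grok-3-mini-beta"))
```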
@@ -228,63 +228,6 @@ class OpenAIProvider(Provider):
         return LLM_MAX_TOKENS["DEFAULT"]


-class xAIProvider(OpenAIProvider):
-    """https://docs.x.ai/docs/api-reference"""
-
-    name: str = "xai"
-    api_key: str = Field(..., description="API key for the xAI/Grok API.")
-    base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.")
-
-    def get_model_context_window_size(self, model_name: str) -> Optional[int]:
-        # xAI doesn't return context window in the model listing,
-        # so these are hardcoded from their website
-        if model_name == "grok-2-1212":
-            return 131072
-        else:
-            return None
-
-    def list_llm_models(self) -> List[LLMConfig]:
-        from letta.llm_api.openai import openai_get_model_list
-
-        response = openai_get_model_list(self.base_url, api_key=self.api_key)
-
-        if "data" in response:
-            data = response["data"]
-        else:
-            data = response
-
-        configs = []
-        for model in data:
-            assert "id" in model, f"xAI/Grok model missing 'id' field: {model}"
-            model_name = model["id"]
-
-            # In case xAI starts supporting it in the future:
-            if "context_length" in model:
-                context_window_size = model["context_length"]
-            else:
-                context_window_size = self.get_model_context_window_size(model_name)
-
-            if not context_window_size:
-                warnings.warn(f"Couldn't find context window size for model {model_name}")
-                continue
-
-            configs.append(
-                LLMConfig(
-                    model=model_name,
-                    model_endpoint_type="xai",
-                    model_endpoint=self.base_url,
-                    context_window=context_window_size,
-                    handle=self.get_handle(model_name),
-                )
-            )
-
-        return configs
-
-    def list_embedding_models(self) -> List[EmbeddingConfig]:
-        # No embeddings supported
-        return []
-
-
 class DeepSeekProvider(OpenAIProvider):
     """
     DeepSeek ChatCompletions API is similar to OpenAI's reasoning API,
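For reference, the `list_llm_models` being moved tolerates two response shapes from the models endpoint: an OpenAI-style `{"data": [...]}` envelope and a bare list. A standalone sketch of that normalization against `GET {base_url}/models`, using plain `requests` instead of letta's `openai_get_model_list` helper (an assumption for illustration):

```python
import requests

def list_xai_model_ids(api_key: str, base_url: str = "https://api.x.ai/v1") -> list[str]:
    resp = requests.get(
        f"{base_url}/models",
        headers={"Authorization": f"Bearer {api_key}"},
        timeout=30,
    )
    resp.raise_for_status()
    body = resp.json()
    # Accept either an OpenAI-style {"data": [...]} envelope or a bare list
    models = body["data"] if isinstance(body, dict) and "data" in body else body
    return [model["id"] for model in models]
```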
@@ -478,7 +421,7 @@ class LMStudioOpenAIProvider(OpenAIProvider):
         return configs


-class xAIProvider(OpenAIProvider):
+class XAIProvider(OpenAIProvider):
     """https://docs.x.ai/docs/api-reference"""

     name: str = "xai"
@@ -490,6 +433,15 @@ class xAIProvider(OpenAIProvider):
         # so these are hardcoded from their website
         if model_name == "grok-2-1212":
             return 131072
+        # NOTE: disabling the minis for now since they return weird MM parts
+        # elif model_name == "grok-3-mini-fast-beta":
+        #     return 131072
+        # elif model_name == "grok-3-mini-beta":
+        #     return 131072
+        elif model_name == "grok-3-fast-beta":
+            return 131072
+        elif model_name == "grok-3-beta":
+            return 131072
         else:
             return None

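Because xAI's model listing omits `context_length`, the chain above amounts to a static lookup table. An equivalent dict-based sketch with the same hardcoded values, keeping the grok-3 minis excluded:

```python
from typing import Optional

# Hardcoded from xAI's website; grok-3-mini-* intentionally omitted for now
# (they "return weird MM parts", per the NOTE in the diff)
XAI_CONTEXT_WINDOWS = {
    "grok-2-1212": 131072,
    "grok-3-fast-beta": 131072,
    "grok-3-beta": 131072,
}

def get_model_context_window_size(model_name: str) -> Optional[int]:
    # Returns None for unknown models, matching the diff's else branch
    return XAI_CONTEXT_WINDOWS.get(model_name)
```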
@@ -67,7 +67,7 @@ from letta.schemas.providers import (
     TogetherProvider,
     VLLMChatCompletionsProvider,
     VLLMCompletionsProvider,
-    xAIProvider,
+    XAIProvider,
 )
 from letta.schemas.sandbox_config import SandboxType
 from letta.schemas.source import Source
@@ -321,7 +321,7 @@ class SyncServer(Server):
         if model_settings.deepseek_api_key:
             self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key))
         if model_settings.xai_api_key:
-            self._enabled_providers.append(xAIProvider(api_key=model_settings.xai_api_key))
+            self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key))

         # For MCP
         """Initialize the MCP clients (there may be multiple)"""
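With the rename applied end to end, a server configured with an xAI key now registers `XAIProvider`. A quick usage sketch built only from what the diff shows (the `api_key` field, `list_llm_models`, and the `handle`/`context_window` fields on `LLMConfig`):

```python
from letta.schemas.providers import XAIProvider

provider = XAIProvider(api_key="xai-...")  # placeholder key
for config in provider.list_llm_models():
    # e.g. grok-3-beta at 131072 tokens (handle format depends on get_handle)
    print(config.handle, config.context_window)
```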