diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 3d069f1b..e1bd0435 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -247,6 +247,13 @@ def create( use_structured_output=False, # NOTE: not supported atm for xAI ) + # Specific bug for the mini models (as of Apr 14, 2025) + # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'} + # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'} + if "grok-3-mini-" in llm_config.model: + data.presence_penalty = None + data.frequency_penalty = None + if stream: # Client requested token streaming data.stream = True assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 90776f9d..e5129d15 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -228,63 +228,6 @@ class OpenAIProvider(Provider): return LLM_MAX_TOKENS["DEFAULT"] -class xAIProvider(OpenAIProvider): - """https://docs.x.ai/docs/api-reference""" - - name: str = "xai" - api_key: str = Field(..., description="API key for the xAI/Grok API.") - base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") - - def get_model_context_window_size(self, model_name: str) -> Optional[int]: - # xAI doesn't return context window in the model listing, - # so these are hardcoded from their website - if model_name == "grok-2-1212": - return 131072 - else: - return None - - def list_llm_models(self) -> List[LLMConfig]: - from letta.llm_api.openai import openai_get_model_list - - response = openai_get_model_list(self.base_url, api_key=self.api_key) - - if "data" in response: - data = response["data"] - else: - data = response - - configs = [] - for model in data: - assert "id" in model, f"xAI/Grok model missing 'id' field: {model}" - model_name = model["id"] - - # In case xAI starts supporting it in the future: - if "context_length" in model: - context_window_size = model["context_length"] - else: - context_window_size = self.get_model_context_window_size(model_name) - - if not context_window_size: - warnings.warn(f"Couldn't find context window size for model {model_name}") - continue - - configs.append( - LLMConfig( - model=model_name, - model_endpoint_type="xai", - model_endpoint=self.base_url, - context_window=context_window_size, - handle=self.get_handle(model_name), - ) - ) - - return configs - - def list_embedding_models(self) -> List[EmbeddingConfig]: - # No embeddings supported - return [] - - class DeepSeekProvider(OpenAIProvider): """ DeepSeek ChatCompletions API is similar to OpenAI's reasoning API, @@ -478,7 +421,7 @@ class LMStudioOpenAIProvider(OpenAIProvider): return configs -class xAIProvider(OpenAIProvider): +class XAIProvider(OpenAIProvider): """https://docs.x.ai/docs/api-reference""" name: str = "xai" @@ -490,6 +433,15 @@ class xAIProvider(OpenAIProvider): # so these are hardcoded from their website if model_name == "grok-2-1212": return 131072 + # NOTE: disabling the minis for now since they return weird MM parts + # elif model_name == "grok-3-mini-fast-beta": + # return 131072 + # elif model_name == "grok-3-mini-beta": + # return 131072 + elif model_name == "grok-3-fast-beta": + return 131072 + elif model_name == "grok-3-beta": + return 131072 else: return None diff --git a/letta/server/server.py b/letta/server/server.py index 428afa0d..7f12020c 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -67,7 +67,7 @@ from letta.schemas.providers import ( TogetherProvider, VLLMChatCompletionsProvider, VLLMCompletionsProvider, - xAIProvider, + XAIProvider, ) from letta.schemas.sandbox_config import SandboxType from letta.schemas.source import Source @@ -321,7 +321,7 @@ class SyncServer(Server): if model_settings.deepseek_api_key: self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key)) if model_settings.xai_api_key: - self._enabled_providers.append(xAIProvider(api_key=model_settings.xai_api_key)) + self._enabled_providers.append(XAIProvider(api_key=model_settings.xai_api_key)) # For MCP """Initialize the MCP clients (there may be multiple)"""