From 72dac99e92854b79f133d33112c1d40e898682fe Mon Sep 17 00:00:00 2001 From: Charles Packer Date: Wed, 26 Feb 2025 11:02:42 -0800 Subject: [PATCH] feat: add xAI / Grok support (#1122) Co-authored-by: Shubham Naik --- letta/llm_api/llm_api_tools.py | 59 ++++++++++++++++++- letta/llm_api/openai.py | 26 +++++--- letta/schemas/llm_config.py | 1 + letta/schemas/providers.py | 57 ++++++++++++++++++ letta/server/server.py | 6 +- letta/settings.py | 3 + .../configs/llm_model_configs/xai-grok-2.json | 6 ++ tests/test_model_letta_performance.py | 11 ++++ 8 files changed, 160 insertions(+), 9 deletions(-) create mode 100644 tests/configs/llm_model_configs/xai-grok-2.json diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index c45e8e4c..32a07136 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -187,8 +187,65 @@ def create( function_call = "required" data = build_openai_chat_completions_request( - llm_config, messages, user_id, functions, function_call, use_tool_naming, put_inner_thoughts_first=put_inner_thoughts_first + llm_config, + messages, + user_id, + functions, + function_call, + use_tool_naming, + put_inner_thoughts_first=put_inner_thoughts_first, + use_structured_output=True, # NOTE: turn on all the time for OpenAI API ) + + if stream: # Client requested token streaming + data.stream = True + assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( + stream_interface, AgentRefreshStreamingInterface + ), type(stream_interface) + response = openai_chat_completions_process_stream( + url=llm_config.model_endpoint, + api_key=api_key, + chat_completion_request=data, + stream_interface=stream_interface, + ) + else: # Client did not request token streaming (expect a blocking backend response) + data.stream = False + if isinstance(stream_interface, AgentChunkStreamingInterface): + stream_interface.stream_start() + try: + response = openai_chat_completions_request( + url=llm_config.model_endpoint, + api_key=api_key, + chat_completion_request=data, + ) + finally: + if isinstance(stream_interface, AgentChunkStreamingInterface): + stream_interface.stream_end() + + if llm_config.put_inner_thoughts_in_kwargs: + response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) + + return response + + elif llm_config.model_endpoint_type == "xai": + + api_key = model_settings.xai_api_key + + if function_call is None and functions is not None and len(functions) > 0: + # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice + function_call = "required" + + data = build_openai_chat_completions_request( + llm_config, + messages, + user_id, + functions, + function_call, + use_tool_naming, + put_inner_thoughts_first=put_inner_thoughts_first, + use_structured_output=False, # NOTE: not supported atm for xAI + ) + if stream: # Client requested token streaming data.stream = True assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py index 25189d5c..56710682 100644 --- a/letta/llm_api/openai.py +++ b/letta/llm_api/openai.py @@ -13,7 +13,7 @@ from letta.schemas.message import Message as _Message from letta.schemas.message import MessageRole as _MessageRole from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_request import FunctionCall as ToolFunctionChoiceFunctionCall -from letta.schemas.openai.chat_completion_request import Tool, ToolFunctionChoice, cast_message_to_subtype +from letta.schemas.openai.chat_completion_request import FunctionSchema, Tool, ToolFunctionChoice, cast_message_to_subtype from letta.schemas.openai.chat_completion_response import ( ChatCompletionChunkResponse, ChatCompletionResponse, @@ -95,6 +95,7 @@ def build_openai_chat_completions_request( function_call: Optional[str], use_tool_naming: bool, put_inner_thoughts_first: bool = True, + use_structured_output: bool = True, ) -> ChatCompletionRequest: if functions and llm_config.put_inner_thoughts_in_kwargs: # Special case for LM Studio backend since it needs extra guidance to force out the thoughts first @@ -157,6 +158,16 @@ def build_openai_chat_completions_request( data.user = str(uuid.UUID(int=0)) data.model = "memgpt-openai" + if use_structured_output and data.tools is not None and len(data.tools) > 0: + # Convert to structured output style (which has 'strict' and no optionals) + for tool in data.tools: + try: + # tool["function"] = convert_to_structured_output(tool["function"]) + structured_output_version = convert_to_structured_output(tool.function.model_dump()) + tool.function = FunctionSchema(**structured_output_version) + except ValueError as e: + warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}") + return data @@ -455,11 +466,12 @@ def prepare_openai_payload(chat_completion_request: ChatCompletionRequest): data.pop("tools") data.pop("tool_choice", None) # extra safe, should exist always (default="auto") - if "tools" in data: - for tool in data["tools"]: - try: - tool["function"] = convert_to_structured_output(tool["function"]) - except ValueError as e: - warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}") + # # NOTE: move this out to wherever the ChatCompletionRequest is created + # if "tools" in data: + # for tool in data["tools"]: + # try: + # tool["function"] = convert_to_structured_output(tool["function"]) + # except ValueError as e: + # warnings.warn(f"Failed to convert tool function to structured output, tool={tool}, error={e}") return data diff --git a/letta/schemas/llm_config.py b/letta/schemas/llm_config.py index 7a941940..1cdadd0f 100644 --- a/letta/schemas/llm_config.py +++ b/letta/schemas/llm_config.py @@ -42,6 +42,7 @@ class LLMConfig(BaseModel): "together", # completions endpoint "bedrock", "deepseek", + "xai", ] = Field(..., description="The endpoint type for the model.") model_endpoint: Optional[str] = Field(None, description="The endpoint for the model.") model_wrapper: Optional[str] = Field(None, description="The wrapper for the model.") diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py index 1de05b0e..ec532b8f 100644 --- a/letta/schemas/providers.py +++ b/letta/schemas/providers.py @@ -211,6 +211,63 @@ class OpenAIProvider(Provider): return None +class xAIProvider(OpenAIProvider): + """https://docs.x.ai/docs/api-reference""" + + name: str = "xai" + api_key: str = Field(..., description="API key for the xAI/Grok API.") + base_url: str = Field("https://api.x.ai/v1", description="Base URL for the xAI/Grok API.") + + def get_model_context_window_size(self, model_name: str) -> Optional[int]: + # xAI doesn't return context window in the model listing, + # so these are hardcoded from their website + if model_name == "grok-2-1212": + return 131072 + else: + return None + + def list_llm_models(self) -> List[LLMConfig]: + from letta.llm_api.openai import openai_get_model_list + + response = openai_get_model_list(self.base_url, api_key=self.api_key) + + if "data" in response: + data = response["data"] + else: + data = response + + configs = [] + for model in data: + assert "id" in model, f"xAI/Grok model missing 'id' field: {model}" + model_name = model["id"] + + # In case xAI starts supporting it in the future: + if "context_length" in model: + context_window_size = model["context_length"] + else: + context_window_size = self.get_model_context_window_size(model_name) + + if not context_window_size: + warnings.warn(f"Couldn't find context window size for model {model_name}") + continue + + configs.append( + LLMConfig( + model=model_name, + model_endpoint_type="xai", + model_endpoint=self.base_url, + context_window=context_window_size, + handle=self.get_handle(model_name), + ) + ) + + return configs + + def list_embedding_models(self) -> List[EmbeddingConfig]: + # No embeddings supported + return [] + + class DeepSeekProvider(OpenAIProvider): """ DeepSeek ChatCompletions API is similar to OpenAI's reasoning API, diff --git a/letta/server/server.py b/letta/server/server.py index 5797111f..4383e8b2 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -60,6 +60,7 @@ from letta.schemas.providers import ( TogetherProvider, VLLMChatCompletionsProvider, VLLMCompletionsProvider, + xAIProvider, ) from letta.schemas.sandbox_config import SandboxType from letta.schemas.source import Source @@ -311,6 +312,8 @@ class SyncServer(Server): self._enabled_providers.append(LMStudioOpenAIProvider(base_url=lmstudio_url)) if model_settings.deepseek_api_key: self._enabled_providers.append(DeepSeekProvider(api_key=model_settings.deepseek_api_key)) + if model_settings.xai_api_key: + self._enabled_providers.append(xAIProvider(api_key=model_settings.xai_api_key)) def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent: """Updated method to load agents from persisted storage""" @@ -1197,7 +1200,8 @@ class SyncServer(Server): # Disable token streaming if not OpenAI or Anthropic # TODO: cleanup this logic llm_config = letta_agent.agent_state.llm_config - supports_token_streaming = ["openai", "anthropic", "deepseek"] + # supports_token_streaming = ["openai", "anthropic", "xai", "deepseek"] + supports_token_streaming = ["openai", "anthropic", "deepseek"] # TODO re-enable xAI once streaming is patched if stream_tokens and ( llm_config.model_endpoint_type not in supports_token_streaming or "inference.memgpt.ai" in llm_config.model_endpoint ): diff --git a/letta/settings.py b/letta/settings.py index 7dd756e7..d69c0777 100644 --- a/letta/settings.py +++ b/letta/settings.py @@ -63,6 +63,9 @@ class ModelSettings(BaseSettings): # deepseek deepseek_api_key: Optional[str] = None + # xAI / Grok + xai_api_key: Optional[str] = None + # groq groq_api_key: Optional[str] = None diff --git a/tests/configs/llm_model_configs/xai-grok-2.json b/tests/configs/llm_model_configs/xai-grok-2.json new file mode 100644 index 00000000..c3b93abd --- /dev/null +++ b/tests/configs/llm_model_configs/xai-grok-2.json @@ -0,0 +1,6 @@ +{ + "context_window": 131072, + "model": "grok-2-1212", + "model_endpoint_type": "xai", + "model_endpoint": "https://api.x.ai/v1" +} diff --git a/tests/test_model_letta_performance.py b/tests/test_model_letta_performance.py index 3d425f49..ea9c30ea 100644 --- a/tests/test_model_letta_performance.py +++ b/tests/test_model_letta_performance.py @@ -328,6 +328,17 @@ def test_deepseek_reasoner_returns_valid_first_message(): print(f"Got successful response from client: \n\n{response}") +# ====================================================================================================================== +# xAI TESTS +# ====================================================================================================================== +@pytest.mark.xai_basic +def test_xai_grok2_returns_valid_first_message(): + filename = os.path.join(llm_config_dir, "xai-grok-2.json") + response = check_first_response_is_valid_for_llm_endpoint(filename) + # Log out successful response + print(f"Got successful response from client: \n\n{response}") + + # ====================================================================================================================== # TOGETHER TESTS # ======================================================================================================================