feat: add new xai llm client (#3936)

This commit is contained in:
cthomas
2025-08-20 15:49:09 -07:00
committed by GitHub
parent 5b70d554bb
commit c7b71ad8a9
5 changed files with 99 additions and 63 deletions

View File

@@ -246,65 +246,6 @@ def create(
return response
elif llm_config.model_endpoint_type == "xai":
api_key = model_settings.xai_api_key
if function_call is None and functions is not None and len(functions) > 0:
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
function_call = "required"
data = build_openai_chat_completions_request(
llm_config,
messages,
user_id,
functions,
function_call,
use_tool_naming,
put_inner_thoughts_first=put_inner_thoughts_first,
use_structured_output=False, # NOTE: not supported atm for xAI
)
# Specific bug for the mini models (as of Apr 14, 2025)
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
if "grok-3-mini-" in llm_config.model:
data.presence_penalty = None
data.frequency_penalty = None
if stream: # Client requested token streaming
data.stream = True
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
stream_interface, AgentRefreshStreamingInterface
), type(stream_interface)
response = openai_chat_completions_process_stream(
url=llm_config.model_endpoint,
api_key=api_key,
chat_completion_request=data,
stream_interface=stream_interface,
name=name,
# TODO turn on to support reasoning content from xAI reasoners:
# https://docs.x.ai/docs/guides/reasoning#reasoning
expect_reasoning_content=False,
)
else: # Client did not request token streaming (expect a blocking backend response)
data.stream = False
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_start()
try:
response = openai_chat_completions_request(
url=llm_config.model_endpoint,
api_key=api_key,
chat_completion_request=data,
)
finally:
if isinstance(stream_interface, AgentChunkStreamingInterface):
stream_interface.stream_end()
if llm_config.put_inner_thoughts_in_kwargs:
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
return response
elif llm_config.model_endpoint_type == "groq":
if stream:
raise NotImplementedError("Streaming not yet implemented for Groq.")

View File

@@ -79,5 +79,12 @@ class LLMClient:
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case ProviderType.xai:
from letta.llm_api.xai_client import XAIClient
return XAIClient(
put_inner_thoughts_first=put_inner_thoughts_first,
actor=actor,
)
case _:
return None

View File

@@ -146,6 +146,9 @@ class OpenAIClient(LLMClientBase):
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
return requires_auto_tool_choice(llm_config)
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
return supports_structured_output(llm_config)
@trace_method
def build_request_data(
self,

View File

@@ -0,0 +1,85 @@
import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
class XAIClient(OpenAIClient):
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
return False
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
return False
@trace_method
def build_request_data(
self,
messages: List[PydanticMessage],
llm_config: LLMConfig,
tools: Optional[List[dict]] = None,
force_tool_call: Optional[str] = None,
) -> dict:
data = super().build_request_data(messages, llm_config, tools, force_tool_call)
# Specific bug for the mini models (as of Apr 14, 2025)
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
if "grok-3-mini-" in llm_config.model:
data.pop("presence_penalty", None)
data.pop("frequency_penalty", None)
return data
@trace_method
def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying synchronous request to OpenAI API and returns raw response dict.
"""
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response: ChatCompletion = client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
"""
Performs underlying asynchronous request to OpenAI API and returns raw response dict.
"""
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response: ChatCompletion = await client.chat.completions.create(**request_data)
return response.model_dump()
@trace_method
async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
"""
Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
"""
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
**request_data, stream=True, stream_options={"include_usage": True}
)
return response_stream
@trace_method
async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
"""Request embeddings given texts and embedding config"""
api_key = model_settings.xai_api_key or os.environ.get("XAI_API_KEY")
client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
# TODO: add total usage
return [r.embedding for r in response.data]