feat: add new xai llm client (#3936)
This commit is contained in:
@@ -246,65 +246,6 @@ def create(
|
||||
|
||||
return response
|
||||
|
||||
elif llm_config.model_endpoint_type == "xai":
|
||||
api_key = model_settings.xai_api_key
|
||||
|
||||
if function_call is None and functions is not None and len(functions) > 0:
|
||||
# force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice
|
||||
function_call = "required"
|
||||
|
||||
data = build_openai_chat_completions_request(
|
||||
llm_config,
|
||||
messages,
|
||||
user_id,
|
||||
functions,
|
||||
function_call,
|
||||
use_tool_naming,
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
use_structured_output=False, # NOTE: not supported atm for xAI
|
||||
)
|
||||
|
||||
# Specific bug for the mini models (as of Apr 14, 2025)
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
|
||||
# 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
|
||||
if "grok-3-mini-" in llm_config.model:
|
||||
data.presence_penalty = None
|
||||
data.frequency_penalty = None
|
||||
|
||||
if stream: # Client requested token streaming
|
||||
data.stream = True
|
||||
assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance(
|
||||
stream_interface, AgentRefreshStreamingInterface
|
||||
), type(stream_interface)
|
||||
response = openai_chat_completions_process_stream(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=api_key,
|
||||
chat_completion_request=data,
|
||||
stream_interface=stream_interface,
|
||||
name=name,
|
||||
# TODO turn on to support reasoning content from xAI reasoners:
|
||||
# https://docs.x.ai/docs/guides/reasoning#reasoning
|
||||
expect_reasoning_content=False,
|
||||
)
|
||||
else: # Client did not request token streaming (expect a blocking backend response)
|
||||
data.stream = False
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.stream_start()
|
||||
try:
|
||||
response = openai_chat_completions_request(
|
||||
url=llm_config.model_endpoint,
|
||||
api_key=api_key,
|
||||
chat_completion_request=data,
|
||||
)
|
||||
finally:
|
||||
if isinstance(stream_interface, AgentChunkStreamingInterface):
|
||||
stream_interface.stream_end()
|
||||
|
||||
if llm_config.put_inner_thoughts_in_kwargs:
|
||||
response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG)
|
||||
|
||||
return response
|
||||
|
||||
elif llm_config.model_endpoint_type == "groq":
|
||||
if stream:
|
||||
raise NotImplementedError("Streaming not yet implemented for Groq.")
|
||||
|
||||
@@ -79,5 +79,12 @@ class LLMClient:
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.xai:
|
||||
from letta.llm_api.xai_client import XAIClient
|
||||
|
||||
return XAIClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -146,6 +146,9 @@ class OpenAIClient(LLMClientBase):
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return requires_auto_tool_choice(llm_config)
|
||||
|
||||
def supports_structured_output(self, llm_config: LLMConfig) -> bool:
|
||||
return supports_structured_output(llm_config)
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
|
||||
85
letta/llm_api/xai_client.py
Normal file
85
letta/llm_api/xai_client.py
Normal file
@@ -0,0 +1,85 @@
|
||||
import os
|
||||
from typing import List, Optional
|
||||
|
||||
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class XAIClient(OpenAIClient):
    """OpenAI-compatible client for the xAI (Grok) chat API.

    Reuses ``OpenAIClient`` request building/parsing and overrides only the
    xAI-specific pieces: API-key resolution, capability flags, and a
    workaround for penalty arguments that the grok-3-mini models reject.
    """

    @staticmethod
    def _resolve_api_key() -> Optional[str]:
        """Return the xAI API key from settings, falling back to the XAI_API_KEY env var."""
        return model_settings.xai_api_key or os.environ.get("XAI_API_KEY")

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        # xAI honors explicit tool_choice values, so no need to force "auto".
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        # NOTE: structured output is not supported by xAI at the moment.
        return False

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build an OpenAI-style request dict, dropping arguments xAI rejects.

        Specific bug for the mini models (as of Apr 14, 2025):
        400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
        400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
        """
        data = super().build_request_data(messages, llm_config, tools, force_tool_call)

        # Match "grok-3-mini" rather than "grok-3-mini-" so the base mini model
        # is covered in addition to suffixed variants (e.g. "grok-3-mini-fast").
        if "grok-3-mini" in llm_config.model:
            data.pop("presence_penalty", None)
            data.pop("frequency_penalty", None)

        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to OpenAI API and returns raw response dict.
        """
        client = OpenAI(api_key=self._resolve_api_key(), base_url=llm_config.model_endpoint)

        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to OpenAI API and returns raw response dict.
        """
        client = AsyncOpenAI(api_key=self._resolve_api_key(), base_url=llm_config.model_endpoint)

        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs underlying asynchronous streaming request to OpenAI and returns the async stream iterator.
        """
        client = AsyncOpenAI(api_key=self._resolve_api_key(), base_url=llm_config.model_endpoint)
        # include_usage makes the final chunk carry token counts for the whole stream.
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config; returns one vector per input."""
        client = AsyncOpenAI(api_key=self._resolve_api_key(), base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [r.embedding for r in response.data]
|
||||
@@ -999,7 +999,7 @@ async def send_message(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
]
|
||||
|
||||
# Create a new run for execution tracking
|
||||
@@ -1143,7 +1143,7 @@ async def send_message_streaming(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
]
|
||||
model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"]
|
||||
|
||||
@@ -1538,7 +1538,7 @@ async def preview_raw_payload(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
@@ -1608,7 +1608,7 @@ async def summarize_agent_conversation(
|
||||
"bedrock",
|
||||
"ollama",
|
||||
"azure",
|
||||
"together",
|
||||
"xai",
|
||||
]
|
||||
|
||||
if agent_eligible and model_compatible:
|
||||
|
||||
Reference in New Issue
Block a user