diff --git a/letta/llm_api/llm_api_tools.py b/letta/llm_api/llm_api_tools.py index 718bd2bf..184615fa 100644 --- a/letta/llm_api/llm_api_tools.py +++ b/letta/llm_api/llm_api_tools.py @@ -246,65 +246,6 @@ def create( return response - elif llm_config.model_endpoint_type == "xai": - api_key = model_settings.xai_api_key - - if function_call is None and functions is not None and len(functions) > 0: - # force function calling for reliability, see https://platform.openai.com/docs/api-reference/chat/create#chat-create-tool_choice - function_call = "required" - - data = build_openai_chat_completions_request( - llm_config, - messages, - user_id, - functions, - function_call, - use_tool_naming, - put_inner_thoughts_first=put_inner_thoughts_first, - use_structured_output=False, # NOTE: not supported atm for xAI - ) - - # Specific bug for the mini models (as of Apr 14, 2025) - # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'} - # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'} - if "grok-3-mini-" in llm_config.model: - data.presence_penalty = None - data.frequency_penalty = None - - if stream: # Client requested token streaming - data.stream = True - assert isinstance(stream_interface, AgentChunkStreamingInterface) or isinstance( - stream_interface, AgentRefreshStreamingInterface - ), type(stream_interface) - response = openai_chat_completions_process_stream( - url=llm_config.model_endpoint, - api_key=api_key, - chat_completion_request=data, - stream_interface=stream_interface, - name=name, - # TODO turn on to support reasoning content from xAI reasoners: - # https://docs.x.ai/docs/guides/reasoning#reasoning - expect_reasoning_content=False, - ) - else: # Client did not request token streaming (expect a blocking backend response) - data.stream = False - if isinstance(stream_interface, AgentChunkStreamingInterface): - 
stream_interface.stream_start() - try: - response = openai_chat_completions_request( - url=llm_config.model_endpoint, - api_key=api_key, - chat_completion_request=data, - ) - finally: - if isinstance(stream_interface, AgentChunkStreamingInterface): - stream_interface.stream_end() - - if llm_config.put_inner_thoughts_in_kwargs: - response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) - - return response - elif llm_config.model_endpoint_type == "groq": if stream: raise NotImplementedError("Streaming not yet implemented for Groq.") diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index d5686eed..4fcd082a 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -79,5 +79,12 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) + case ProviderType.xai: + from letta.llm_api.xai_client import XAIClient + + return XAIClient( + put_inner_thoughts_first=put_inner_thoughts_first, + actor=actor, + ) case _: return None diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index de132002..267898bf 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -146,6 +146,9 @@ class OpenAIClient(LLMClientBase): def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool: return requires_auto_tool_choice(llm_config) + def supports_structured_output(self, llm_config: LLMConfig) -> bool: + return supports_structured_output(llm_config) + @trace_method def build_request_data( self, diff --git a/letta/llm_api/xai_client.py b/letta/llm_api/xai_client.py new file mode 100644 index 00000000..059073e4 --- /dev/null +++ b/letta/llm_api/xai_client.py @@ -0,0 +1,85 @@ +import os +from typing import List, Optional + +from openai import AsyncOpenAI, AsyncStream, OpenAI +from openai.types.chat.chat_completion import ChatCompletion +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk + +from 
import os
from typing import List, Optional

from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings


class XAIClient(OpenAIClient):
    """LLM client for xAI's OpenAI-compatible Chat Completions API.

    Inherits all request building and response handling from ``OpenAIClient``
    and overrides only the xAI-specific pieces: API-key resolution
    (``model_settings.xai_api_key`` / ``XAI_API_KEY``), feature flags, and
    per-model request sanitization.
    """

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        # xAI accepts explicit tool_choice values, so we never need to force "auto".
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        # Structured output (json_schema response_format) is not supported for xAI.
        return False

    @staticmethod
    def _api_key() -> str:
        """Resolve the xAI API key from settings, falling back to the environment."""
        return model_settings.xai_api_key or os.environ.get("XAI_API_KEY")

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the OpenAI-style request payload, stripping arguments xAI rejects."""
        data = super().build_request_data(messages, llm_config, tools, force_tool_call)

        # Specific bug for the mini models (as of Apr 14, 2025)
        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
        # Match "grok-3-mini" without a trailing hyphen so the bare model name
        # is covered as well as suffixed variants such as "grok-3-mini-fast".
        if "grok-3-mini" in llm_config.model:
            data.pop("presence_penalty", None)
            data.pop("frequency_penalty", None)

        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """Perform a blocking chat-completion request against the xAI endpoint.

        Returns the raw response as a plain dict (``model_dump`` of the SDK object).
        """
        client = OpenAI(api_key=self._api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """Perform an asynchronous chat-completion request against the xAI endpoint.

        Returns the raw response as a plain dict (``model_dump`` of the SDK object).
        """
        client = AsyncOpenAI(api_key=self._api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """Start an asynchronous streaming chat completion and return the chunk iterator.

        ``stream_options={"include_usage": True}`` asks the server to append a
        final usage chunk to the stream.
        """
        client = AsyncOpenAI(api_key=self._api_key(), base_url=llm_config.model_endpoint)
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings for ``inputs`` and return one vector per input text."""
        client = AsyncOpenAI(api_key=self._api_key(), base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [r.embedding for r in response.data]
preview_raw_payload( "bedrock", "ollama", "azure", - "together", + "xai", ] if agent_eligible and model_compatible: @@ -1608,7 +1608,7 @@ async def summarize_agent_conversation( "bedrock", "ollama", "azure", - "together", + "xai", ] if agent_eligible and model_compatible: