Files
letta-server/letta/llm_api/xai_client.py
Charles Packer a4041879a4 feat: add new agent loop (squash rebase of OSS PR) (#4815)
* feat: squash rebase of OSS PR

* fix: revert changes that weren't on manual rebase

* fix: caught another one

* fix: disable force

* chore: drop print

* fix: just stage-api && just publish-api

* fix: make agent_type consistently an arg in the client

* fix: patch multi-modal support

* chore: put in todo stub

* fix: disable hardcoding for tests

* fix: patch validate agent sync (#4882)

patch validate agent sync

* fix: strip bad merge diff

* fix: revert unrelated diff

* fix: react_v2 naming -> letta_v1 naming

* fix: strip bad merge

---------

Co-authored-by: Kevin Lin <klin5061@gmail.com>
2025-10-07 17:50:45 -07:00

87 lines
3.8 KiB
Python

import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
class XAIClient(OpenAIClient):
    """LLM client for xAI models, layered on top of ``OpenAIClient``.

    Requests are issued with the OpenAI SDK against the endpoint stored in
    the supplied config (``llm_config.model_endpoint`` or
    ``embedding_config.embedding_endpoint``) and authenticated with the xAI
    API key.
    """

    @staticmethod
    def _resolve_xai_api_key() -> Optional[str]:
        """Return the xAI API key.

        The value from ``model_settings.xai_api_key`` wins; when it is unset
        or empty, the ``XAI_API_KEY`` environment variable is used instead.
        Returns ``None`` when neither is available.
        """
        return model_settings.xai_api_key or os.environ.get("XAI_API_KEY")

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        """Always return ``False`` for this client, regardless of ``llm_config``."""
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        """Always return ``False`` for this client, regardless of ``llm_config``."""
        return False

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the request payload, removing arguments the mini models reject.

        The payload is produced by the parent class and then adjusted in
        place for the grok-3-mini family.

        Returns:
            The request payload dict.
        """
        data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call)

        # Specific bug for the mini models (as of Apr 14, 2025)
        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: presencePenalty'}
        # 400 - {'code': 'Client specified an invalid argument', 'error': 'Argument not supported on this model: frequencyPenalty'}
        #
        # Match without a trailing dash so that the bare "grok-3-mini" name is
        # covered as well as suffixed variants such as "grok-3-mini-beta".
        if "grok-3-mini" in llm_config.model:
            data.pop("presence_penalty", None)
            data.pop("frequency_penalty", None)

        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying synchronous chat-completion request against
        ``llm_config.model_endpoint`` and returns the raw response dict.
        """
        client = OpenAI(api_key=self._resolve_xai_api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying asynchronous chat-completion request against
        ``llm_config.model_endpoint`` and returns the raw response dict.
        """
        client = AsyncOpenAI(api_key=self._resolve_xai_api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """
        Performs the underlying asynchronous streaming request against
        ``llm_config.model_endpoint`` and returns the async stream iterator.
        """
        client = AsyncOpenAI(api_key=self._resolve_xai_api_key(), base_url=llm_config.model_endpoint)
        # include_usage asks the API to report token usage as part of the stream.
        response_stream: AsyncStream[ChatCompletionChunk] = await client.chat.completions.create(
            **request_data, stream=True, stream_options={"include_usage": True}
        )
        return response_stream

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config.

        Returns:
            One embedding vector per item in the response data, in the order
            the API returned them.
        """
        client = AsyncOpenAI(api_key=self._resolve_xai_api_key(), base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [r.embedding for r in response.data]