import os
from typing import List, Optional

from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from letta.helpers.json_helpers import sanitize_unicode_surrogates
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings


class GroqClient(OpenAIClient):
    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        return True

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        data = super().build_request_data(agent_type, messages, llm_config, tools, force_tool_call, requires_subsequent_tool_call)

        # Groq only supports string values for tool_choice: "none", "auto", "required".
        # Convert object-format tool_choice (used for force_tool_call) to "required".
        if "tool_choice" in data and isinstance(data["tool_choice"], dict):
            data["tool_choice"] = "required"

        # Groq validation: these fields are not supported and will cause 400 errors.
        # https://console.groq.com/docs/openai
        if "top_logprobs" in data:
            del data["top_logprobs"]
        if "logit_bias" in data:
            del data["logit_bias"]
        data["logprobs"] = False
        data["n"] = 1

        # Groq rejects assistant messages that carry reasoning fields, e.g.:
        # openai.BadRequestError: Error code: 400 - {'error': {'message': "'messages.2' :
        # for 'role:assistant' the following must be satisfied[('messages.2' : property
        # 'reasoning_content' is unsupported)]", 'type': 'invalid_request_error'}}
        if "messages" in data:
            for message in data["messages"]:
                if "reasoning_content" in message:
                    del message["reasoning_content"]
                if "reasoning_content_signature" in message:
                    del message["reasoning_content_signature"]

        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying synchronous request to the Groq API and returns the raw response dict.
        """
        api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
        client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs the underlying asynchronous request to the Groq API and returns the raw response dict.
        """
        request_data = sanitize_unicode_surrogates(request_data)
        api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
        client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and an embedding config."""
        api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
        client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
        # TODO: add total usage
        return [r.embedding for r in response.data]

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        raise NotImplementedError("Streaming not supported for Groq.")