Files
letta-server/letta/llm_api/groq_client.py
cthomas 3f8f2e622a fix: filter our reasoning for groq client [LET-7135] (#8982)
fix: filter our reasoning for groq client
2026-01-29 12:43:53 -08:00

91 lines
3.9 KiB
Python

import os
from typing import List, Optional
from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import AgentType
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings
class GroqClient(OpenAIClient):
    """LLM client for Groq, built on top of the OpenAI-compatible client.

    Request payloads are produced by ``OpenAIClient.build_request_data`` and
    then adjusted to drop fields that Groq rejects. Requests are sent with the
    ``openai`` SDK pointed at the endpoint configured on the LLM/embedding
    config.
    """

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        """Return ``False`` for every config."""
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        """Return ``True`` for every config."""
        return True

    @staticmethod
    def _get_groq_api_key() -> Optional[str]:
        """Return the Groq API key from settings, else the ``GROQ_API_KEY`` env var."""
        return model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")

    @trace_method
    def build_request_data(
        self,
        agent_type: AgentType,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
        requires_subsequent_tool_call: bool = False,
        tool_return_truncation_chars: Optional[int] = None,
    ) -> dict:
        """Build the chat-completions payload and strip fields Groq rejects.

        All arguments are forwarded to the parent implementation, including
        ``tool_return_truncation_chars`` (previously accepted here but never
        passed on, so the caller's value was silently ignored).

        Returns:
            The request dict produced by the parent class with
            ``top_logprobs`` / ``logit_bias`` removed, ``logprobs`` forced to
            ``False``, ``n`` forced to ``1``, and reasoning fields removed
            from every message.
        """
        data = super().build_request_data(
            agent_type,
            messages,
            llm_config,
            tools,
            force_tool_call,
            requires_subsequent_tool_call,
            tool_return_truncation_chars=tool_return_truncation_chars,
        )

        # Groq validation - these fields are not supported and will cause 400 errors
        # https://console.groq.com/docs/openai
        data.pop("top_logprobs", None)
        data.pop("logit_bias", None)
        data["logprobs"] = False
        data["n"] = 1

        # Reasoning fields on messages are rejected by Groq, e.g.:
        # openai.BadRequestError: Error code: 400 - {'error': {'message':
        # "'messages.2' : for 'role:assistant' the following must be satisfied
        # [('messages.2' : property 'reasoning_content' is unsupported)]",
        # 'type': 'invalid_request_error'}}
        for message in data.get("messages", []):
            message.pop("reasoning_content", None)
            message.pop("reasoning_content_signature", None)

        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to Groq API and returns raw response dict.
        """
        client = OpenAI(api_key=self._get_groq_api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to Groq API and returns raw response dict.
        """
        client = AsyncOpenAI(api_key=self._get_groq_api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config.

        Returns:
            One embedding vector per item of ``response.data``, in the order
            the API returned them.
        """
        client = AsyncOpenAI(api_key=self._get_groq_api_key(), base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [r.embedding for r in response.data]

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """Streaming entry point; always raises because it is not implemented here.

        Raises:
            NotImplementedError: on every call.
        """
        raise NotImplementedError("Streaming not supported for Groq.")