80 lines
3.1 KiB
Python
80 lines
3.1 KiB
Python
import os
|
|
from typing import List, Optional
|
|
|
|
from openai import AsyncOpenAI, AsyncStream, OpenAI
|
|
from openai.types.chat.chat_completion import ChatCompletion
|
|
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
|
|
|
from letta.llm_api.openai_client import OpenAIClient
|
|
from letta.otel.tracing import trace_method
|
|
from letta.schemas.embedding_config import EmbeddingConfig
|
|
from letta.schemas.llm_config import LLMConfig
|
|
from letta.schemas.message import Message as PydanticMessage
|
|
from letta.settings import model_settings
|
|
|
|
|
|
class GroqClient(OpenAIClient):
    """LLM client that talks to Groq through its OpenAI-compatible API.

    Request construction is delegated to ``OpenAIClient`` and then adjusted
    for Groq; the transport methods use the OpenAI SDK pointed at the
    endpoint taken from the supplied config.
    """

    @staticmethod
    def _resolve_api_key() -> Optional[str]:
        """Return the Groq API key from settings, else from ``GROQ_API_KEY``.

        Returns:
            The configured key, or ``None`` when neither source provides one.
            NOTE(review): with ``None`` the OpenAI SDK applies its own key
            resolution -- confirm that fallback is acceptable for Groq.
        """
        return model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        """Always return ``False`` for Groq; ``llm_config`` is not consulted."""
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        """Always return ``True`` for Groq; ``llm_config`` is not consulted."""
        return True

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the chat-completion payload and adjust it for Groq.

        Args:
            messages: Conversation messages forwarded to the parent builder.
            llm_config: Model configuration forwarded to the parent builder.
            tools: Optional tool definitions forwarded to the parent builder.
            force_tool_call: Optional tool name forwarded to the parent builder.

        Returns:
            The parent's request dict with ``top_logprobs`` and ``logit_bias``
            removed, ``logprobs`` set to ``False`` and ``n`` set to ``1``.
        """
        data = super().build_request_data(messages, llm_config, tools, force_tool_call)

        # Groq validation - these fields are not supported and will cause 400 errors
        # https://console.groq.com/docs/openai
        for unsupported_field in ("top_logprobs", "logit_bias"):
            data.pop(unsupported_field, None)
        data["logprobs"] = False
        data["n"] = 1

        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to Groq API and returns raw response dict.
        """
        # A fresh client is created per call; the context manager closes it so
        # its HTTP connections are released promptly rather than at GC time.
        with OpenAI(api_key=self._resolve_api_key(), base_url=llm_config.model_endpoint) as client:
            response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to Groq API and returns raw response dict.
        """
        # Close the per-call async client deterministically instead of relying
        # on garbage collection to release its connections.
        async with AsyncOpenAI(api_key=self._resolve_api_key(), base_url=llm_config.model_endpoint) as client:
            response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings given texts and embedding config.

        Args:
            inputs: Texts to embed, sent as a single request.
            embedding_config: Supplies the embedding endpoint and model name.

        Returns:
            One embedding per item in the response data, in response order.
        """
        async with AsyncOpenAI(api_key=self._resolve_api_key(), base_url=embedding_config.embedding_endpoint) as client:
            response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [item.embedding for item in response.data]

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """Streaming is not implemented for Groq.

        Raises:
            NotImplementedError: Always.
        """
        raise NotImplementedError("Streaming not supported for Groq.")
|