import os
from typing import List, Optional

from openai import AsyncOpenAI, AsyncStream, OpenAI
from openai.types.chat.chat_completion import ChatCompletion
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk

from letta.errors import LettaConfigurationError
from letta.llm_api.openai_client import OpenAIClient
from letta.otel.tracing import trace_method
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message as PydanticMessage
from letta.settings import model_settings

# Request fields that Groq's OpenAI-compatible API rejects with a 400 error.
# https://console.groq.com/docs/openai
_UNSUPPORTED_REQUEST_FIELDS = ("top_logprobs", "logit_bias")


class GroqClient(OpenAIClient):
    """LLM client for Groq, layered on top of ``OpenAIClient``.

    Requests are sent with the OpenAI SDK pointed at the endpoint taken from
    the supplied config (``llm_config.model_endpoint`` for chat completions,
    ``embedding_config.embedding_endpoint`` for embeddings). Token streaming
    is not implemented.
    """

    @staticmethod
    def _resolve_api_key() -> str:
        """Return the Groq API key from settings or the ``GROQ_API_KEY`` env var.

        Raises:
            LettaConfigurationError: if neither source provides a key. Failing
                here matters because the OpenAI SDK, when handed
                ``api_key=None``, falls back to ``OPENAI_API_KEY`` -- which
                would send an OpenAI credential to the Groq endpoint.
        """
        api_key = model_settings.groq_api_key or os.environ.get("GROQ_API_KEY")
        if not api_key:
            raise LettaConfigurationError(
                message="Groq key is missing from letta config file",
                missing_fields=["groq_api_key"],
            )
        return api_key

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        """Always ``False`` for Groq, regardless of ``llm_config``."""
        return False

    def supports_structured_output(self, llm_config: LLMConfig) -> bool:
        """Always ``True`` for Groq, regardless of ``llm_config``."""
        return True

    @trace_method
    def build_request_data(
        self,
        messages: List[PydanticMessage],
        llm_config: LLMConfig,
        tools: Optional[List[dict]] = None,
        force_tool_call: Optional[str] = None,
    ) -> dict:
        """Build the request payload via the parent class, then adapt it for Groq.

        Args:
            messages: Conversation messages to send.
            llm_config: Model configuration forwarded to the parent builder.
            tools: Optional tool definitions forwarded to the parent builder.
            force_tool_call: Optional tool name forwarded to the parent builder.

        Returns:
            The parent's payload with ``top_logprobs`` and ``logit_bias``
            removed, ``logprobs`` forced to ``False`` and ``n`` forced to ``1``.
        """
        data = super().build_request_data(messages, llm_config, tools, force_tool_call)

        # Groq validation - these fields are not supported and will cause 400 errors
        for field in _UNSUPPORTED_REQUEST_FIELDS:
            data.pop(field, None)
        data["logprobs"] = False
        data["n"] = 1

        return data

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """Perform a synchronous chat-completion request and return the raw response dict.

        Raises:
            LettaConfigurationError: if no Groq API key is configured.
        """
        client = OpenAI(api_key=self._resolve_api_key(), base_url=llm_config.model_endpoint)

        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """Perform an asynchronous chat-completion request and return the raw response dict.

        Raises:
            LettaConfigurationError: if no Groq API key is configured.
        """
        client = AsyncOpenAI(api_key=self._resolve_api_key(), base_url=llm_config.model_endpoint)

        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings for ``inputs`` using the model and endpoint in ``embedding_config``.

        Returns:
            One embedding vector per entry of the response's ``data`` list.

        Raises:
            LettaConfigurationError: if no Groq API key is configured.
        """
        client = AsyncOpenAI(api_key=self._resolve_api_key(), base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [item.embedding for item in response.data]

    @trace_method
    async def stream_async(self, request_data: dict, llm_config: LLMConfig) -> AsyncStream[ChatCompletionChunk]:
        """Streaming is not implemented for Groq; always raises ``NotImplementedError``."""
        raise NotImplementedError("Streaming not supported for Groq.")
import CLI_WARNING_PREFIX from letta.errors import LettaConfigurationError, RateLimitExceededError from letta.llm_api.deepseek import build_deepseek_chat_completions_request, convert_deepseek_response_to_chatcompletion -from letta.llm_api.helpers import add_inner_thoughts_to_functions, unpack_all_inner_thoughts_from_kwargs +from letta.llm_api.helpers import unpack_all_inner_thoughts_from_kwargs from letta.llm_api.openai import ( build_openai_chat_completions_request, openai_chat_completions_process_stream, @@ -16,14 +16,13 @@ from letta.llm_api.openai import ( prepare_openai_payload, ) from letta.local_llm.chat_completion_proxy import get_chat_completion -from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG_DESCRIPTION +from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages from letta.orm.user import User from letta.otel.tracing import log_event, trace_method from letta.schemas.enums import ProviderCategory from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message -from letta.schemas.openai.chat_completion_request import ChatCompletionRequest from letta.schemas.openai.chat_completion_response import ChatCompletionResponse from letta.schemas.provider_trace import ProviderTraceCreate from letta.services.telemetry_manager import TelemetryManager @@ -246,57 +245,6 @@ def create( return response - elif llm_config.model_endpoint_type == "groq": - if stream: - raise NotImplementedError("Streaming not yet implemented for Groq.") - - if model_settings.groq_api_key is None and llm_config.model_endpoint == "https://api.groq.com/openai/v1/chat/completions": - raise LettaConfigurationError(message="Groq key is missing from letta config file", missing_fields=["groq_api_key"]) - - # force to true for groq, since they don't support 'content' is non-null - if llm_config.put_inner_thoughts_in_kwargs: - functions = 
add_inner_thoughts_to_functions( - functions=functions, - inner_thoughts_key=INNER_THOUGHTS_KWARG, - inner_thoughts_description=INNER_THOUGHTS_KWARG_DESCRIPTION, - ) - - tools = [{"type": "function", "function": f} for f in functions] if functions is not None else None - data = ChatCompletionRequest( - model=llm_config.model, - messages=[m.to_openai_dict(put_inner_thoughts_in_kwargs=llm_config.put_inner_thoughts_in_kwargs) for m in messages], - tools=tools, - tool_choice=function_call, - user=str(user_id), - ) - - # https://console.groq.com/docs/openai - # "The following fields are currently not supported and will result in a 400 error (yikes) if they are supplied:" - assert data.top_logprobs is None - assert data.logit_bias is None - assert data.logprobs == False - assert data.n == 1 - # They mention that none of the messages can have names, but it seems to not error out (for now) - - data.stream = False - if isinstance(stream_interface, AgentChunkStreamingInterface): - stream_interface.stream_start() - try: - # groq uses the openai chat completions API, so this component should be reusable - response = openai_chat_completions_request( - url=llm_config.model_endpoint, - api_key=model_settings.groq_api_key, - chat_completion_request=data, - ) - finally: - if isinstance(stream_interface, AgentChunkStreamingInterface): - stream_interface.stream_end() - - if llm_config.put_inner_thoughts_in_kwargs: - response = unpack_all_inner_thoughts_from_kwargs(response=response, inner_thoughts_key=INNER_THOUGHTS_KWARG) - - return response - elif llm_config.model_endpoint_type == "deepseek": if model_settings.deepseek_api_key is None and llm_config.model_endpoint == "": # only is a problem if we are *not* using an openai proxy diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index 4fcd082a..927c5951 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -86,5 +86,12 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, 
actor=actor, ) + case ProviderType.groq: + from letta.llm_api.groq_client import GroqClient + + return GroqClient( + put_inner_thoughts_first=put_inner_thoughts_first, + actor=actor, + ) case _: return None diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index ef6719e6..009c2682 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -1000,6 +1000,7 @@ async def send_message( "ollama", "azure", "xai", + "groq", ] # Create a new run for execution tracking @@ -1144,6 +1145,7 @@ async def send_message_streaming( "ollama", "azure", "xai", + "groq", ] model_compatible_token_streaming = agent.llm_config.model_endpoint_type in ["anthropic", "openai", "bedrock"] @@ -1350,6 +1352,7 @@ async def _process_message_background( "google_vertex", "bedrock", "ollama", + "groq", ] if agent_eligible and model_compatible: if agent.enable_sleeptime and agent.agent_type != AgentType.voice_convo_agent: @@ -1539,6 +1542,7 @@ async def preview_raw_payload( "ollama", "azure", "xai", + "groq", ] if agent_eligible and model_compatible: @@ -1609,6 +1613,7 @@ async def summarize_agent_conversation( "ollama", "azure", "xai", + "groq", ] if agent_eligible and model_compatible: diff --git a/tests/configs/llm_model_configs/groq.json b/tests/configs/llm_model_configs/groq.json index 5f5c92f9..87e0b50d 100644 --- a/tests/configs/llm_model_configs/groq.json +++ b/tests/configs/llm_model_configs/groq.json @@ -1,8 +1,8 @@ { - "context_window": 8192, - "model": "llama-3.1-70b-versatile", - "model_endpoint_type": "groq", - "model_endpoint": "https://api.groq.com/openai/v1", - "model_wrapper": null, - "put_inner_thoughts_in_kwargs": true + "context_window": 8192, + "model": "qwen/qwen3-32b", + "model_endpoint_type": "groq", + "model_endpoint": "https://api.groq.com/openai/v1", + "model_wrapper": null, + "put_inner_thoughts_in_kwargs": true } diff --git a/tests/integration_test_send_message.py 
b/tests/integration_test_send_message.py index f2a0d68c..acf2b733 100644 --- a/tests/integration_test_send_message.py +++ b/tests/integration_test_send_message.py @@ -144,7 +144,7 @@ USER_MESSAGE_BASE64_IMAGE: List[MessageCreate] = [ ] # configs for models that are to dumb to do much other than messaging -limited_configs = ["ollama.json", "together-qwen-2.5-72b-instruct.json", "vllm.json", "lmstudio.json"] +limited_configs = ["ollama.json", "together-qwen-2.5-72b-instruct.json", "vllm.json", "lmstudio.json", "groq.json"] all_configs = [ "openai-gpt-4o-mini.json", @@ -161,6 +161,7 @@ all_configs = [ "gemini-2.5-pro-vertex.json", "ollama.json", "together-qwen-2.5-72b-instruct.json", + "groq.json", ] reasoning_configs = [