From 227e4722820aa37b59e815ca8caa2f95504ba68e Mon Sep 17 00:00:00 2001 From: cthomas Date: Tue, 12 Aug 2025 13:37:20 -0700 Subject: [PATCH] feat: add new together llm client (#3875) --- letta/llm_api/llm_client.py | 9 +++++- letta/llm_api/openai_client.py | 20 ++++--------- letta/llm_api/together_client.py | 49 ++++++++++++++++++++++++++++++++ 3 files changed, 63 insertions(+), 15 deletions(-) create mode 100644 letta/llm_api/together_client.py diff --git a/letta/llm_api/llm_client.py b/letta/llm_api/llm_client.py index 1e03c5f7..b047ec85 100644 --- a/letta/llm_api/llm_client.py +++ b/letta/llm_api/llm_client.py @@ -58,12 +58,19 @@ class LLMClient: put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) - case ProviderType.openai | ProviderType.together | ProviderType.ollama: + case ProviderType.openai | ProviderType.ollama: from letta.llm_api.openai_client import OpenAIClient return OpenAIClient( put_inner_thoughts_first=put_inner_thoughts_first, actor=actor, ) + case ProviderType.together: + from letta.llm_api.together_client import TogetherClient + + return TogetherClient( + put_inner_thoughts_first=put_inner_thoughts_first, + actor=actor, + ) case _: return None diff --git a/letta/llm_api/openai_client.py b/letta/llm_api/openai_client.py index 159372d2..035a7734 100644 --- a/letta/llm_api/openai_client.py +++ b/letta/llm_api/openai_client.py @@ -26,7 +26,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG from letta.log import get_logger from letta.otel.tracing import trace_method from letta.schemas.embedding_config import EmbeddingConfig -from letta.schemas.enums import ProviderCategory, ProviderType +from letta.schemas.enums import ProviderCategory from letta.schemas.letta_message_content import MessageContentType from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage @@ -102,8 +102,6 @@ def requires_auto_tool_choice(llm_config: LLMConfig) -> bool: """Certain 
providers require the tool choice to be set to 'auto'.""" if "nebius.com" in llm_config.model_endpoint: return True - if "together.ai" in llm_config.model_endpoint or "together.xyz" in llm_config.model_endpoint: - return True if llm_config.handle and "vllm" in llm_config.handle: return True if llm_config.compatibility_type == "mlx": @@ -118,8 +116,6 @@ class OpenAIClient(LLMClientBase): from letta.services.provider_manager import ProviderManager api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor) - if llm_config.model_endpoint_type == ProviderType.together: - api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY") if not api_key: api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") @@ -130,12 +126,7 @@ class OpenAIClient(LLMClientBase): return kwargs def _prepare_client_kwargs_embedding(self, embedding_config: EmbeddingConfig) -> dict: - api_key = None - if embedding_config.embedding_endpoint_type == ProviderType.together: - api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY") - - if not api_key: - api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") + api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") # supposedly the openai python client requires a dummy API key api_key = api_key or "DUMMY_API_KEY" kwargs = {"api_key": api_key, "base_url": embedding_config.embedding_endpoint} @@ -147,8 +138,6 @@ class OpenAIClient(LLMClientBase): from letta.services.provider_manager import ProviderManager api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor) - if llm_config.model_endpoint_type == ProviderType.together: - api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY") if not api_key: api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY") @@ -158,6 +147,9 @@ class OpenAIClient(LLMClientBase): return kwargs + def 
requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool: + return requires_auto_tool_choice(llm_config) + @trace_method def build_request_data( self, @@ -204,7 +196,7 @@ class OpenAIClient(LLMClientBase): # TODO(matt) move into LLMConfig # TODO: This vllm checking is very brittle and is a patch at most tool_choice = None - if requires_auto_tool_choice(llm_config): + if self.requires_auto_tool_choice(llm_config): tool_choice = "auto" elif tools: # only set if tools is non-Null diff --git a/letta/llm_api/together_client.py b/letta/llm_api/together_client.py new file mode 100644 index 00000000..83ccd435 --- /dev/null +++ b/letta/llm_api/together_client.py @@ -0,0 +1,49 @@ +import os +from typing import List + +from openai import AsyncOpenAI, OpenAI +from openai.types.chat.chat_completion import ChatCompletion + +from letta.llm_api.openai_client import OpenAIClient +from letta.otel.tracing import trace_method +from letta.schemas.embedding_config import EmbeddingConfig +from letta.schemas.llm_config import LLMConfig +from letta.settings import model_settings + + +class TogetherClient(OpenAIClient): + + def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool: + return True + + @trace_method + def request(self, request_data: dict, llm_config: LLMConfig) -> dict: + """ + Performs underlying synchronous request to the Together API (via the OpenAI-compatible client) and returns raw response dict. + """ + api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY") + client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint) + + response: ChatCompletion = client.chat.completions.create(**request_data) + return response.model_dump() + + @trace_method + async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict: + """ + Performs underlying asynchronous request to the Together API (via the OpenAI-compatible client) and returns raw response dict. 
+ """ + api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY") + client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint) + + response: ChatCompletion = await client.chat.completions.create(**request_data) + return response.model_dump() + + @trace_method + async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]: + """Request embeddings given texts and embedding config""" + api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY") + client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint) + response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs) + + # TODO: add total usage + return [r.embedding for r in response.data]