feat: add new together llm client (#3875)

Authored by cthomas on 2025-08-12 13:37:20 -07:00 (committed by GitHub)
Parent: 17bd5ff2b0
Commit: 227e472282
3 changed files with 63 additions and 15 deletions

letta/llm_api/llm_client.py

@@ -58,12 +58,19 @@ class LLMClient:
                     put_inner_thoughts_first=put_inner_thoughts_first,
                     actor=actor,
                 )
-            case ProviderType.openai | ProviderType.together | ProviderType.ollama:
+            case ProviderType.openai | ProviderType.ollama:
                 from letta.llm_api.openai_client import OpenAIClient

                 return OpenAIClient(
                     put_inner_thoughts_first=put_inner_thoughts_first,
                     actor=actor,
                 )
+            case ProviderType.together:
+                from letta.llm_api.together_client import TogetherClient
+
+                return TogetherClient(
+                    put_inner_thoughts_first=put_inner_thoughts_first,
+                    actor=actor,
+                )
             case _:
                 return None
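
Note: with this change, ProviderType.together no longer falls through to the plain OpenAIClient; the factory returns the new TogetherClient. A minimal usage sketch, assuming the match statement above sits inside a factory method such as LLMClient.create taking the keywords shown (the surrounding signature is not part of this diff):

# Hedged usage sketch; LLMClient.create and actor=None are assumptions, since
# only the match arms appear in the commit.
from letta.llm_api.llm_client import LLMClient
from letta.schemas.enums import ProviderType

client = LLMClient.create(
    provider_type=ProviderType.together,
    put_inner_thoughts_first=True,
    actor=None,
)
print(type(client).__name__)  # expected: TogetherClient rather than OpenAIClient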

letta/llm_api/openai_client.py

@@ -26,7 +26,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG
 from letta.log import get_logger
 from letta.otel.tracing import trace_method
 from letta.schemas.embedding_config import EmbeddingConfig
-from letta.schemas.enums import ProviderCategory, ProviderType
+from letta.schemas.enums import ProviderCategory
 from letta.schemas.letta_message_content import MessageContentType
 from letta.schemas.llm_config import LLMConfig
 from letta.schemas.message import Message as PydanticMessage
@@ -102,8 +102,6 @@ def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
"""Certain providers require the tool choice to be set to 'auto'."""
if "nebius.com" in llm_config.model_endpoint:
return True
if "together.ai" in llm_config.model_endpoint or "together.xyz" in llm_config.model_endpoint:
return True
if llm_config.handle and "vllm" in llm_config.handle:
return True
if llm_config.compatibility_type == "mlx":
@@ -118,8 +116,6 @@ class OpenAIClient(LLMClientBase):
             from letta.services.provider_manager import ProviderManager

             api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)

-        if llm_config.model_endpoint_type == ProviderType.together:
-            api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
         if not api_key:
             api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -130,12 +126,7 @@ class OpenAIClient(LLMClientBase):
         return kwargs

     def _prepare_client_kwargs_embedding(self, embedding_config: EmbeddingConfig) -> dict:
-        api_key = None
-        if embedding_config.embedding_endpoint_type == ProviderType.together:
-            api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
-
-        if not api_key:
-            api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
+        api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
         # supposedly the openai python client requires a dummy API key
         api_key = api_key or "DUMMY_API_KEY"
         kwargs = {"api_key": api_key, "base_url": embedding_config.embedding_endpoint}
@@ -147,8 +138,6 @@ class OpenAIClient(LLMClientBase):
             from letta.services.provider_manager import ProviderManager

             api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)

-        if llm_config.model_endpoint_type == ProviderType.together:
-            api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
         if not api_key:
             api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
@@ -158,6 +147,9 @@ class OpenAIClient(LLMClientBase):
         return kwargs

+    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
+        return requires_auto_tool_choice(llm_config)
+
     @trace_method
     def build_request_data(
         self,
@@ -204,7 +196,7 @@ class OpenAIClient(LLMClientBase):
         # TODO(matt) move into LLMConfig
         # TODO: This vllm checking is very brittle and is a patch at most
         tool_choice = None
-        if requires_auto_tool_choice(llm_config):
+        if self.requires_auto_tool_choice(llm_config):
             tool_choice = "auto"
         elif tools:
             # only set if tools is non-Null
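
Taken together, these edits move every Together-specific branch out of OpenAIClient: the endpoint string-matching is gone from the module-level requires_auto_tool_choice heuristic, the TOGETHER_API_KEY fallbacks are gone from the client-kwargs helpers, and build_request_data now calls the new instance method so subclasses can override the tool-choice decision. A condensed sketch of that override seam (bodies are illustrative stubs, not the full clients):

# Illustrative stubs only; the real classes carry many more methods.
def requires_auto_tool_choice(llm_config) -> bool:
    # Shared module-level heuristic: endpoint/handle string checks live here.
    return "nebius.com" in llm_config.model_endpoint

class OpenAIClient:
    def requires_auto_tool_choice(self, llm_config) -> bool:
        # Base client defers to the shared heuristic.
        return requires_auto_tool_choice(llm_config)

class TogetherClient(OpenAIClient):
    def requires_auto_tool_choice(self, llm_config) -> bool:
        # Together's OpenAI-compatible API always wants tool_choice="auto",
        # so the subclass short-circuits instead of matching endpoint strings.
        return True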

letta/llm_api/together_client.py

@@ -0,0 +1,49 @@
+import os
+from typing import List
+
+from openai import AsyncOpenAI, OpenAI
+from openai.types.chat.chat_completion import ChatCompletion
+
+from letta.llm_api.openai_client import OpenAIClient
+from letta.otel.tracing import trace_method
+from letta.schemas.embedding_config import EmbeddingConfig
+from letta.schemas.llm_config import LLMConfig
+from letta.settings import model_settings
+
+
+class TogetherClient(OpenAIClient):
+
+    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
+        return True
+
+    @trace_method
+    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
+        """
+        Performs underlying synchronous request to OpenAI API and returns raw response dict.
+        """
+        api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
+        client = OpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
+
+        response: ChatCompletion = client.chat.completions.create(**request_data)
+        return response.model_dump()
+
+    @trace_method
+    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
+        """
+        Performs underlying asynchronous request to OpenAI API and returns raw response dict.
+        """
+        api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
+        client = AsyncOpenAI(api_key=api_key, base_url=llm_config.model_endpoint)
+
+        response: ChatCompletion = await client.chat.completions.create(**request_data)
+        return response.model_dump()
+
+    @trace_method
+    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
+        """Request embeddings given texts and embedding config"""
+        api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
+        client = AsyncOpenAI(api_key=api_key, base_url=embedding_config.embedding_endpoint)
+        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)
+
+        # TODO: add total usage
+        return [r.embedding for r in response.data]
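
For reference, a standalone sketch of what request() reduces to at runtime. The base URL and model name are placeholders rather than values taken from this commit (the real ones come from llm_config.model_endpoint and the request payload), and TOGETHER_API_KEY is assumed to be set:

# Standalone sketch mirroring TogetherClient.request(); base_url and model are
# illustrative assumptions, not part of the commit.
import os

from openai import OpenAI

client = OpenAI(
    api_key=os.environ["TOGETHER_API_KEY"],
    base_url="https://api.together.xyz/v1",  # assumed OpenAI-compatible endpoint
)
response = client.chat.completions.create(
    model="meta-llama/Meta-Llama-3.1-8B-Instruct-Turbo",  # placeholder model id
    messages=[{"role": "user", "content": "Say hello."}],
)
# The client hands back the raw dict, matching response.model_dump() above.
print(response.model_dump()["choices"][0]["message"]["content"])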