feat: add new together llm client (#3875)
This commit is contained in:
@@ -58,12 +58,19 @@ class LLMClient:
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.openai | ProviderType.together | ProviderType.ollama:
|
||||
case ProviderType.openai | ProviderType.ollama:
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
|
||||
return OpenAIClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case ProviderType.together:
|
||||
from letta.llm_api.together_client import TogetherClient
|
||||
|
||||
return TogetherClient(
|
||||
put_inner_thoughts_first=put_inner_thoughts_first,
|
||||
actor=actor,
|
||||
)
|
||||
case _:
|
||||
return None
|
||||
|
||||
@@ -26,7 +26,7 @@ from letta.local_llm.constants import INNER_THOUGHTS_KWARG, INNER_THOUGHTS_KWARG
|
||||
from letta.log import get_logger
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import ProviderCategory, ProviderType
|
||||
from letta.schemas.enums import ProviderCategory
|
||||
from letta.schemas.letta_message_content import MessageContentType
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
@@ -102,8 +102,6 @@ def requires_auto_tool_choice(llm_config: LLMConfig) -> bool:
|
||||
"""Certain providers require the tool choice to be set to 'auto'."""
|
||||
if "nebius.com" in llm_config.model_endpoint:
|
||||
return True
|
||||
if "together.ai" in llm_config.model_endpoint or "together.xyz" in llm_config.model_endpoint:
|
||||
return True
|
||||
if llm_config.handle and "vllm" in llm_config.handle:
|
||||
return True
|
||||
if llm_config.compatibility_type == "mlx":
|
||||
@@ -118,8 +116,6 @@ class OpenAIClient(LLMClientBase):
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = ProviderManager().get_override_key(llm_config.provider_name, actor=self.actor)
|
||||
if llm_config.model_endpoint_type == ProviderType.together:
|
||||
api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
|
||||
|
||||
if not api_key:
|
||||
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
|
||||
@@ -130,12 +126,7 @@ class OpenAIClient(LLMClientBase):
|
||||
return kwargs
|
||||
|
||||
def _prepare_client_kwargs_embedding(self, embedding_config: EmbeddingConfig) -> dict:
|
||||
api_key = None
|
||||
if embedding_config.embedding_endpoint_type == ProviderType.together:
|
||||
api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
|
||||
|
||||
if not api_key:
|
||||
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
|
||||
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
|
||||
# supposedly the openai python client requires a dummy API key
|
||||
api_key = api_key or "DUMMY_API_KEY"
|
||||
kwargs = {"api_key": api_key, "base_url": embedding_config.embedding_endpoint}
|
||||
@@ -147,8 +138,6 @@ class OpenAIClient(LLMClientBase):
|
||||
from letta.services.provider_manager import ProviderManager
|
||||
|
||||
api_key = await ProviderManager().get_override_key_async(llm_config.provider_name, actor=self.actor)
|
||||
if llm_config.model_endpoint_type == ProviderType.together:
|
||||
api_key = model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")
|
||||
|
||||
if not api_key:
|
||||
api_key = model_settings.openai_api_key or os.environ.get("OPENAI_API_KEY")
|
||||
@@ -158,6 +147,9 @@ class OpenAIClient(LLMClientBase):
|
||||
|
||||
return kwargs
|
||||
|
||||
def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
|
||||
return requires_auto_tool_choice(llm_config)
|
||||
|
||||
@trace_method
|
||||
def build_request_data(
|
||||
self,
|
||||
@@ -204,7 +196,7 @@ class OpenAIClient(LLMClientBase):
|
||||
# TODO(matt) move into LLMConfig
|
||||
# TODO: This vllm checking is very brittle and is a patch at most
|
||||
tool_choice = None
|
||||
if requires_auto_tool_choice(llm_config):
|
||||
if self.requires_auto_tool_choice(llm_config):
|
||||
tool_choice = "auto"
|
||||
elif tools:
|
||||
# only set if tools is non-Null
|
||||
|
||||
49
letta/llm_api/together_client.py
Normal file
49
letta/llm_api/together_client.py
Normal file
@@ -0,0 +1,49 @@
|
||||
import os
|
||||
from typing import List
|
||||
|
||||
from openai import AsyncOpenAI, OpenAI
|
||||
from openai.types.chat.chat_completion import ChatCompletion
|
||||
|
||||
from letta.llm_api.openai_client import OpenAIClient
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.settings import model_settings
|
||||
|
||||
|
||||
class TogetherClient(OpenAIClient):
    """LLM client for Together AI's OpenAI-compatible endpoints.

    Reuses the OpenAIClient request-building/response-parsing pipeline and
    overrides only the transport methods so that the Together API key is used
    instead of the OpenAI one.
    """

    @staticmethod
    def _together_api_key() -> str | None:
        """Resolve the Together API key from settings, falling back to the env var.

        Returns None when neither source is configured; the OpenAI SDK will
        then raise on client construction, surfacing the misconfiguration.
        """
        return model_settings.together_api_key or os.environ.get("TOGETHER_API_KEY")

    def requires_auto_tool_choice(self, llm_config: LLMConfig) -> bool:
        # Together's OpenAI-compatible endpoint requires tool_choice="auto"
        # whenever tools are supplied (see requires_auto_tool_choice in
        # openai_client, which special-cases together.ai/together.xyz URLs).
        return True

    @trace_method
    def request(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying synchronous request to Together's OpenAI-compatible
        API and returns the raw response as a dict.
        """
        client = OpenAI(api_key=self._together_api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_async(self, request_data: dict, llm_config: LLMConfig) -> dict:
        """
        Performs underlying asynchronous request to Together's OpenAI-compatible
        API and returns the raw response as a dict.
        """
        client = AsyncOpenAI(api_key=self._together_api_key(), base_url=llm_config.model_endpoint)
        response: ChatCompletion = await client.chat.completions.create(**request_data)
        return response.model_dump()

    @trace_method
    async def request_embeddings(self, inputs: List[str], embedding_config: EmbeddingConfig) -> List[List[float]]:
        """Request embeddings for the given texts using the embedding config.

        Returns one embedding vector (list of floats) per input, in order.
        """
        client = AsyncOpenAI(api_key=self._together_api_key(), base_url=embedding_config.embedding_endpoint)
        response = await client.embeddings.create(model=embedding_config.embedding_model, input=inputs)

        # TODO: add total usage
        return [r.embedding for r in response.data]
|
||||
Reference in New Issue
Block a user