feat: support togetherAI via /completions (#2045)
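
Usage sketch (illustrative, not part of the diff): with this change a
TogetherProvider can be constructed and queried for the chat models it
supports. The import path, the key placeholder, and the "chatml" formatter
name below are assumptions for the example, not values taken from this commit.

    from letta.schemas.providers import TogetherProvider  # module path assumed

    provider = TogetherProvider(
        api_key="<TOGETHER_API_KEY>",       # placeholder; use a real TogetherAI key
        default_prompt_formatter="chatml",  # assumed model-wrapper name for /completions
    )

    # Each returned LLMConfig uses the "together" endpoint type added in this PR
    for llm_config in provider.list_llm_models():
        print(llm_config.model, llm_config.context_window)
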
@@ -2,7 +2,7 @@ from typing import List, Optional
 
 from pydantic import BaseModel, Field, model_validator
 
-from letta.constants import LLM_MAX_TOKENS
+from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
 from letta.llm_api.azure_openai import (
     get_azure_chat_completions_endpoint,
     get_azure_embeddings_endpoint,
@@ -67,10 +67,15 @@ class OpenAIProvider(Provider):
         extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
         response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params)
 
-        assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+        # TogetherAI's response is missing the 'data' field
+        # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+        if "data" in response:
+            data = response["data"]
+        else:
+            data = response
 
         configs = []
-        for model in response["data"]:
+        for model in data:
             assert "id" in model, f"OpenAI model missing 'id' field: {model}"
             model_name = model["id"]
 
@@ -82,6 +87,32 @@ class OpenAIProvider(Provider):
 
             if not context_window_size:
                 continue
+
+            # TogetherAI includes the type, which we can use to filter out embedding models
+            if self.base_url == "https://api.together.ai/v1":
+                if "type" in model and model["type"] != "chat":
+                    continue
+
+                # for TogetherAI, we need to skip the models that don't support JSON mode / function calling
+                # requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
+                #   "error": {
+                #     "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
+                #     "type": "invalid_request_error",
+                #     "param": null,
+                #     "code": "constraints_model"
+                #   }
+                # }
+                if "config" not in model:
+                    continue
+                if "chat_template" not in model["config"]:
+                    continue
+                if model["config"]["chat_template"] is None:
+                    continue
+                if "tools" not in model["config"]["chat_template"]:
+                    continue
+                # if "config" in data and "chat_template" in data["config"] and "tools" not in data["config"]["chat_template"]:
+                #     continue
+
             configs.append(
                 LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
             )
@@ -325,6 +356,113 @@ class GroqProvider(OpenAIProvider):
         raise NotImplementedError
 
 
+class TogetherProvider(OpenAIProvider):
+    """TogetherAI provider that uses the /completions API
+
+    TogetherAI can also be used via the /chat/completions API
+    by setting OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key
+    and API URL; however, /completions is preferred because their /chat/completions
+    function calling support is limited.
+    """
+
+    name: str = "together"
+    base_url: str = "https://api.together.ai/v1"
+    api_key: str = Field(..., description="API key for the TogetherAI API.")
+    default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on the /completions API.")
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from letta.llm_api.openai import openai_get_model_list
+
+        response = openai_get_model_list(self.base_url, api_key=self.api_key)
+
+        # TogetherAI's response is missing the 'data' field
+        # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+        if "data" in response:
+            data = response["data"]
+        else:
+            data = response
+
+        configs = []
+        for model in data:
+            assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
+            model_name = model["id"]
+
+            if "context_length" in model:
+                # TogetherAI returns the context length as "context_length" (the same key OpenRouter uses)
+                context_window_size = model["context_length"]
+            else:
+                context_window_size = self.get_model_context_window_size(model_name)
+
+            # We need the context length for embeddings too
+            if not context_window_size:
+                continue
+
+            # Skip models that are too small for Letta
+            if context_window_size <= MIN_CONTEXT_WINDOW:
+                continue
+
+            # TogetherAI includes the type, which we can use to filter out embedding models
+            if "type" in model and model["type"] not in ["chat", "language"]:
+                continue
+
+            configs.append(
+                LLMConfig(
+                    model=model_name,
+                    model_endpoint_type="together",
+                    model_endpoint=self.base_url,
+                    model_wrapper=self.default_prompt_formatter,
+                    context_window=context_window_size,
+                )
+            )
+
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # TODO: re-enable once we figure out how to pass API keys through properly
+        return []
+
+        # from letta.llm_api.openai import openai_get_model_list
+
+        # response = openai_get_model_list(self.base_url, api_key=self.api_key)
+
+        # # TogetherAI's response is missing the 'data' field
+        # # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
+        # if "data" in response:
+        #     data = response["data"]
+        # else:
+        #     data = response
+
+        # configs = []
+        # for model in data:
+        #     assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
+        #     model_name = model["id"]
+
+        #     if "context_length" in model:
+        #         # Context length is returned in OpenRouter as "context_length"
+        #         context_window_size = model["context_length"]
+        #     else:
+        #         context_window_size = self.get_model_context_window_size(model_name)
+
+        #     if not context_window_size:
+        #         continue
+
+        #     # TogetherAI includes the type, which we can use to filter for embedding models
+        #     if "type" in model and model["type"] not in ["embedding"]:
+        #         continue
+
+        #     configs.append(
+        #         EmbeddingConfig(
+        #             embedding_model=model_name,
+        #             embedding_endpoint_type="openai",
+        #             embedding_endpoint=self.base_url,
+        #             embedding_dim=context_window_size,
+        #             embedding_chunk_size=300,  # TODO: change?
+        #         )
+        #     )
+
+        # return configs
+
+
 class GoogleAIProvider(Provider):
     # gemini
     api_key: str = Field(..., description="API key for the Google AI API.")
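
Note on the 'data' fallback introduced above (in both OpenAIProvider and
TogetherProvider): OpenAI-compatible /models endpoints usually wrap the model
list in a "data" field, while TogetherAI returns a bare list. A standalone
sketch of the same normalization; the helper name and sample payloads are
illustrative, not from this commit:

    # OpenAI-style responses wrap the list; TogetherAI returns the list directly
    openai_style = {"data": [{"id": "gpt-4"}]}
    together_style = [{"id": "mistralai/Mixtral-8x7B-Instruct-v0.1", "type": "chat"}]

    def normalize_model_list(response):
        # Accept both the wrapped dict shape and the bare-list shape
        if isinstance(response, dict) and "data" in response:
            return response["data"]
        return response

    assert normalize_model_list(openai_style) == openai_style["data"]
    assert normalize_model_list(together_style) == together_style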