feat: support togetherAI via /completions (#2045)

This commit is contained in:
Charles Packer
2024-11-18 15:15:05 -08:00
committed by GitHub
parent cada5976da
commit f57dc28552
14 changed files with 364 additions and 6 deletions

View File

@@ -2,7 +2,7 @@ from typing import List, Optional
from pydantic import BaseModel, Field, model_validator
from letta.constants import LLM_MAX_TOKENS
from letta.constants import LLM_MAX_TOKENS, MIN_CONTEXT_WINDOW
from letta.llm_api.azure_openai import (
get_azure_chat_completions_endpoint,
get_azure_embeddings_endpoint,
@@ -67,10 +67,15 @@ class OpenAIProvider(Provider):
extra_params = {"supported_parameters": "tools"} if "openrouter.ai" in self.base_url else None
response = openai_get_model_list(self.base_url, api_key=self.api_key, extra_params=extra_params)
assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
# TogetherAI's response is missing the 'data' field
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
if "data" in response:
data = response["data"]
else:
data = response
configs = []
for model in response["data"]:
for model in data:
assert "id" in model, f"OpenAI model missing 'id' field: {model}"
model_name = model["id"]
@@ -82,6 +87,32 @@ class OpenAIProvider(Provider):
if not context_window_size:
continue
# TogetherAI includes the type, which we can use to filter out embedding models
if self.base_url == "https://api.together.ai/v1":
if "type" in model and model["type"] != "chat":
continue
# for TogetherAI, we need to skip the models that don't support JSON mode / function calling
# requests.exceptions.HTTPError: HTTP error occurred: 400 Client Error: Bad Request for url: https://api.together.ai/v1/chat/completions | Status code: 400, Message: {
# "error": {
# "message": "mistralai/Mixtral-8x7B-v0.1 is not supported for JSON mode/function calling",
# "type": "invalid_request_error",
# "param": null,
# "code": "constraints_model"
# }
# }
if "config" not in model:
continue
if "chat_template" not in model["config"]:
continue
if model["config"]["chat_template"] is None:
continue
if "tools" not in model["config"]["chat_template"]:
continue
# if "config" in data and "chat_template" in data["config"] and "tools" not in data["config"]["chat_template"]:
# continue
configs.append(
LLMConfig(model=model_name, model_endpoint_type="openai", model_endpoint=self.base_url, context_window=context_window_size)
)
@@ -325,6 +356,113 @@ class GroqProvider(OpenAIProvider):
raise NotImplementedError
class TogetherProvider(OpenAIProvider):
"""TogetherAI provider that uses the /completions API
TogetherAI can also be used via the /chat/completions API
by settings OPENAI_API_KEY and OPENAI_API_BASE to the TogetherAI API key
and API URL, however /completions is preferred because their /chat/completions
function calling support is limited.
"""
name: str = "together"
base_url: str = "https://api.together.ai/v1"
api_key: str = Field(..., description="API key for the TogetherAI API.")
default_prompt_formatter: str = Field(..., description="Default prompt formatter (aka model wrapper) to use on vLLM /completions API.")
def list_llm_models(self) -> List[LLMConfig]:
from letta.llm_api.openai import openai_get_model_list
response = openai_get_model_list(self.base_url, api_key=self.api_key)
# TogetherAI's response is missing the 'data' field
# assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
if "data" in response:
data = response["data"]
else:
data = response
configs = []
for model in data:
assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
model_name = model["id"]
if "context_length" in model:
# Context length is returned in OpenRouter as "context_length"
context_window_size = model["context_length"]
else:
context_window_size = self.get_model_context_window_size(model_name)
# We need the context length for embeddings too
if not context_window_size:
continue
# Skip models that are too small for Letta
if context_window_size <= MIN_CONTEXT_WINDOW:
continue
# TogetherAI includes the type, which we can use to filter for embedding models
if "type" in model and model["type"] not in ["chat", "language"]:
continue
configs.append(
LLMConfig(
model=model_name,
model_endpoint_type="together",
model_endpoint=self.base_url,
model_wrapper=self.default_prompt_formatter,
context_window=context_window_size,
)
)
return configs
def list_embedding_models(self) -> List[EmbeddingConfig]:
# TODO renable once we figure out how to pass API keys through properly
return []
# from letta.llm_api.openai import openai_get_model_list
# response = openai_get_model_list(self.base_url, api_key=self.api_key)
# # TogetherAI's response is missing the 'data' field
# # assert "data" in response, f"OpenAI model query response missing 'data' field: {response}"
# if "data" in response:
# data = response["data"]
# else:
# data = response
# configs = []
# for model in data:
# assert "id" in model, f"TogetherAI model missing 'id' field: {model}"
# model_name = model["id"]
# if "context_length" in model:
# # Context length is returned in OpenRouter as "context_length"
# context_window_size = model["context_length"]
# else:
# context_window_size = self.get_model_context_window_size(model_name)
# if not context_window_size:
# continue
# # TogetherAI includes the type, which we can use to filter out embedding models
# if "type" in model and model["type"] not in ["embedding"]:
# continue
# configs.append(
# EmbeddingConfig(
# embedding_model=model_name,
# embedding_endpoint_type="openai",
# embedding_endpoint=self.base_url,
# embedding_dim=context_window_size,
# embedding_chunk_size=300, # TODO: change?
# )
# )
# return configs
class GoogleAIProvider(Provider):
# gemini
api_key: str = Field(..., description="API key for the Google AI API.")