feat: Add MistralProvider (#1883)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-10-14 13:34:12 -07:00
committed by GitHub
parent f408436669
commit 6e6aad7bb0
3 changed files with 105 additions and 3 deletions

47
letta/llm_api/mistral.py Normal file
View File

@@ -0,0 +1,47 @@
import requests
from letta.utils import printd, smart_urljoin
def mistral_get_model_list(url: str, api_key: str) -> dict:
    """Fetch the list of available models from the Mistral API.

    Args:
        url: Base URL of the Mistral API (e.g. "https://api.mistral.ai/v1");
            "models" is joined onto it.
        api_key: Bearer token for the API; may be None, in which case no
            Authorization header is sent.

    Returns:
        The parsed JSON response as a dict (expected to contain a "data" field).

    Raises:
        requests.exceptions.HTTPError: for 4XX/5XX responses.
        requests.exceptions.RequestException: for connection-level failures.
        Exception: re-raises any other unexpected error.
    """
    url = smart_urljoin(url, "models")
    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    def _error_body(resp):
        # Best-effort decode of the response body for error logging.
        # NOTE: compare against None explicitly -- requests.Response.__bool__
        # is False for 4XX/5XX status codes, so `if resp:` would skip exactly
        # the error responses we most want to inspect.
        if resp is None:
            return None
        try:
            return resp.json()
        except Exception:
            return resp

    printd(f"Sending request to {url}")
    response = None
    try:
        # TODO add query param "tool" to be true
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        return response.json()  # convert to dict from string
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, response={_error_body(response)}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}, response={_error_body(response)}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}, response={_error_body(response)}")
        raise e

View File

@@ -139,6 +139,50 @@ class AnthropicProvider(Provider):
return []
class MistralProvider(Provider):
    """Provider backed by the Mistral API (OpenAI-compatible chat endpoint)."""

    name: str = "mistral"
    api_key: str = Field(..., description="API key for the Mistral API.")
    base_url: str = "https://api.mistral.ai/v1"

    def list_llm_models(self) -> List[LLMConfig]:
        """Return an LLMConfig for every Mistral model that supports both
        chat completions and function calling (tool use)."""
        from letta.llm_api.mistral import mistral_get_model_list

        response = mistral_get_model_list(self.base_url, api_key=self.api_key)
        assert "data" in response, f"Mistral model query response missing 'data' field: {response}"

        configs = []
        for model in response["data"]:
            # Only keep models with chat completions and function calling enabled,
            # since agents require tool-calling support.
            if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
                configs.append(
                    LLMConfig(
                        model=model["id"],
                        # Mistral's API is OpenAI-compatible, so reuse the openai endpoint type
                        model_endpoint_type="openai",
                        model_endpoint=self.base_url,
                        context_window=model["max_context_length"],
                    )
                )
        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # Embedding models are not supported for Mistral via this provider
        return []

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        """Return the context window size for `model_name`, or None if unknown.

        Re-fetching the model list here is fine because it's a lightweight call.
        """
        for m in self.list_llm_models():
            # BUG FIX: list_llm_models returns LLMConfig objects, not dicts,
            # so use attribute access (m.model / m.context_window), not m["id"].
            if model_name in m.model:
                return int(m.context_window)
        return None
class OllamaProvider(OpenAIProvider):
"""Ollama provider that uses the native /api/generate endpoint

View File

@@ -2,7 +2,9 @@ import os
from letta.providers import (
AnthropicProvider,
AzureProvider,
GoogleAIProvider,
MistralProvider,
OllamaProvider,
OpenAIProvider,
)
@@ -33,10 +35,13 @@ def test_anthropic():
#
# TODO: Add this test
# https://linear.app/letta/issue/LET-159/add-tests-for-azure-openai-in-test-providerspy-and-test-endpointspy
def test_azure():
    """Smoke test: the Azure provider lists both LLM and embedding models."""
    provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
    llm_models = provider.list_llm_models()
    print([llm.model for llm in llm_models])
    embedding_models = provider.list_embedding_models()
    print([emb.embedding_model for emb in embedding_models])
def test_ollama():
@@ -60,6 +65,12 @@ def test_googleai():
provider.list_embedding_models()
def test_mistral():
    """Smoke test: the Mistral provider lists its LLM models."""
    mistral = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
    listed = mistral.list_llm_models()
    print([cfg.model for cfg in listed])
# def test_vllm():
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
# models = provider.list_llm_models()