feat: Add MistralProvider (#1883)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
47
letta/llm_api/mistral.py
Normal file
47
letta/llm_api/mistral.py
Normal file
@@ -0,0 +1,47 @@
|
||||
import requests
|
||||
|
||||
from letta.utils import printd, smart_urljoin
|
||||
|
||||
|
||||
def mistral_get_model_list(url: str, api_key: str) -> dict:
    """Query the Mistral API ``/models`` endpoint and return the parsed JSON.

    Args:
        url: Base URL of the Mistral API (e.g. ``https://api.mistral.ai/v1``).
        api_key: Bearer token; if ``None`` the request is sent unauthenticated.

    Returns:
        The decoded JSON payload (a dict; Mistral puts model records under ``"data"``).

    Raises:
        requests.exceptions.HTTPError: on 4XX/5XX responses.
        requests.exceptions.RequestException: on connection-level failures.
    """
    url = smart_urljoin(url, "models")

    headers = {"Content-Type": "application/json"}
    if api_key is not None:
        headers["Authorization"] = f"Bearer {api_key}"

    printd(f"Sending request to {url}")
    response = None
    try:
        # TODO add query param "tool" to be true
        response = requests.get(url, headers=headers)
        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
        return response.json()  # convert to dict from string
    except requests.exceptions.HTTPError as http_err:
        # Handle HTTP errors (e.g., response 4XX, 5XX)
        printd(f"Got HTTPError, exception={http_err}, response={_error_body(response)}")
        raise http_err
    except requests.exceptions.RequestException as req_err:
        # Handle other requests-related errors (e.g., connection error)
        printd(f"Got RequestException, exception={req_err}, response={_error_body(response)}")
        raise req_err
    except Exception as e:
        # Handle other potential errors
        printd(f"Got unknown Exception, exception={e}, response={_error_body(response)}")
        raise e


def _error_body(response):
    """Best-effort extraction of a response body for error logging.

    BUG FIX: the original guarded the parse with ``if response:``, but
    ``requests.Response.__bool__`` is False for 4XX/5XX statuses, so the JSON
    error body was never extracted in the HTTPError path — exactly where it is
    most useful. Check ``is not None`` instead. Also narrows the original bare
    ``except:`` (which would swallow SystemExit/KeyboardInterrupt) to
    ``except Exception``.
    """
    if response is None:
        return None
    try:
        return response.json()
    except Exception:
        # Body wasn't JSON (or was already consumed) — log the raw response object.
        return response
|
||||
@@ -139,6 +139,50 @@ class AnthropicProvider(Provider):
|
||||
return []
|
||||
|
||||
|
||||
class MistralProvider(Provider):
    """Provider for the Mistral API (served over an OpenAI-compatible chat endpoint)."""

    name: str = "mistral"
    api_key: str = Field(..., description="API key for the Mistral API.")
    base_url: str = "https://api.mistral.ai/v1"

    def list_llm_models(self) -> List[LLMConfig]:
        """Return an LLMConfig for every Mistral model that supports chat + function calling."""
        from letta.llm_api.mistral import mistral_get_model_list

        # Mistral's /models endpoint reports per-model capabilities, so we can
        # filter down to chat models with function (tool) calling enabled.
        response = mistral_get_model_list(self.base_url, api_key=self.api_key)

        assert "data" in response, f"Mistral model query response missing 'data' field: {response}"

        configs = []
        for model in response["data"]:
            # Only keep models with both chat completions and function calling enabled
            if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
                configs.append(
                    LLMConfig(
                        model=model["id"],
                        # The Mistral API is OpenAI-compatible, so reuse the openai endpoint type
                        model_endpoint_type="openai",
                        model_endpoint=self.base_url,
                        context_window=model["max_context_length"],
                    )
                )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        # Embedding model listing is not supported for Mistral
        return []

    def get_model_context_window(self, model_name: str) -> Optional[int]:
        """Return the context window for ``model_name``, or None if no model matches."""
        # Redoing this is fine because it's a pretty lightweight call
        for m in self.list_llm_models():
            # BUG FIX: list_llm_models() returns LLMConfig objects, not raw API
            # dicts — the original `m["id"]` / `m["max_context_length"]` would
            # raise TypeError on every call. Use attribute access on the config,
            # preserving the original substring-match semantics.
            if model_name in m.model:
                return int(m.context_window)

        return None
|
||||
|
||||
|
||||
class OllamaProvider(OpenAIProvider):
|
||||
"""Ollama provider that uses the native /api/generate endpoint
|
||||
|
||||
|
||||
@@ -2,7 +2,9 @@ import os
|
||||
|
||||
from letta.providers import (
|
||||
AnthropicProvider,
|
||||
AzureProvider,
|
||||
GoogleAIProvider,
|
||||
MistralProvider,
|
||||
OllamaProvider,
|
||||
OpenAIProvider,
|
||||
)
|
||||
@@ -33,10 +35,13 @@ def test_anthropic():
|
||||
#
|
||||
|
||||
|
||||
# TODO: Add this test
|
||||
# https://linear.app/letta/issue/LET-159/add-tests-for-azure-openai-in-test-providerspy-and-test-endpointspy
|
||||
def test_azure():
    """Smoke test for AzureProvider: list LLM and embedding models.

    Requires AZURE_API_KEY and AZURE_BASE_URL in the environment.
    """
    # BUG FIX: removed a stray `pass` no-op that was left at the top of the
    # function body ahead of the real test code (dead statement).
    provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
    models = provider.list_llm_models()
    print([m.model for m in models])

    embed_models = provider.list_embedding_models()
    print([m.embedding_model for m in embed_models])
|
||||
|
||||
|
||||
def test_ollama():
|
||||
@@ -60,6 +65,12 @@ def test_googleai():
|
||||
provider.list_embedding_models()
|
||||
|
||||
|
||||
def test_mistral():
    """Smoke test for MistralProvider: list LLM models and print their names.

    Requires MISTRAL_API_KEY in the environment.
    """
    mistral = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
    available = mistral.list_llm_models()
    print([llm.model for llm in available])
|
||||
|
||||
|
||||
# def test_vllm():
|
||||
# provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
|
||||
# models = provider.list_llm_models()
|
||||
|
||||
Reference in New Issue
Block a user