From 6e6aad7bb08915938ecbc0961055678240dd45e5 Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Mon, 14 Oct 2024 13:34:12 -0700
Subject: [PATCH] feat: Add MistralProvider (#1883)

Co-authored-by: Matt Zhou
---
 letta/llm_api/mistral.py | 47 ++++++++++++++++++++++++++++++++++++++++
 letta/providers.py       | 44 +++++++++++++++++++++++++++++++++++++
 tests/test_providers.py  | 17 ++++++++++++---
 3 files changed, 105 insertions(+), 3 deletions(-)
 create mode 100644 letta/llm_api/mistral.py

diff --git a/letta/llm_api/mistral.py b/letta/llm_api/mistral.py
new file mode 100644
index 00000000..932cf874
--- /dev/null
+++ b/letta/llm_api/mistral.py
@@ -0,0 +1,47 @@
+import requests
+
+from letta.utils import printd, smart_urljoin
+
+
+def mistral_get_model_list(url: str, api_key: str) -> dict:
+    url = smart_urljoin(url, "models")
+
+    headers = {"Content-Type": "application/json"}
+    if api_key is not None:
+        headers["Authorization"] = f"Bearer {api_key}"
+
+    printd(f"Sending request to {url}")
+    response = None
+    try:
+        # TODO add query param "tool" to be true
+        response = requests.get(url, headers=headers)
+        response.raise_for_status()  # Raises HTTPError for 4XX/5XX status
+        response_json = response.json()  # convert to dict from string
+        return response_json
+    except requests.exceptions.HTTPError as http_err:
+        # Handle HTTP errors (e.g., response 4XX, 5XX)
+        try:
+            if response:
+                response = response.json()
+        except Exception:
+            pass
+        printd(f"Got HTTPError, exception={http_err}, response={response}")
+        raise http_err
+    except requests.exceptions.RequestException as req_err:
+        # Handle other requests-related errors (e.g., connection error)
+        try:
+            if response:
+                response = response.json()
+        except Exception:
+            pass
+        printd(f"Got RequestException, exception={req_err}, response={response}")
+        raise req_err
+    except Exception as e:
+        # Handle other potential errors
+        try:
+            if response:
+                response = response.json()
+        except Exception:
+            pass
+        printd(f"Got unknown Exception, exception={e}, response={response}")
+        raise e
diff --git a/letta/providers.py b/letta/providers.py
index 9ea298c2..c4b878b6 100644
--- a/letta/providers.py
+++ b/letta/providers.py
@@ -139,6 +139,50 @@ class AnthropicProvider(Provider):
         return []
 
 
+class MistralProvider(Provider):
+    name: str = "mistral"
+    api_key: str = Field(..., description="API key for the Mistral API.")
+    base_url: str = "https://api.mistral.ai/v1"
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from letta.llm_api.mistral import mistral_get_model_list
+
+        # Query the Mistral /models endpoint; we filter on capabilities below
+        # so only chat models with function calling (tool use) are returned.
+        response = mistral_get_model_list(self.base_url, api_key=self.api_key)
+
+        assert "data" in response, f"Mistral model query response missing 'data' field: {response}"
+
+        configs = []
+        for model in response["data"]:
+            # If model has chat completions and function calling enabled
+            if model["capabilities"]["completion_chat"] and model["capabilities"]["function_calling"]:
+                configs.append(
+                    LLMConfig(
+                        model=model["id"],
+                        model_endpoint_type="openai",
+                        model_endpoint=self.base_url,
+                        context_window=model["max_context_length"],
+                    )
+                )
+
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        # Not supported for mistral
+        return []
+
+    def get_model_context_window(self, model_name: str) -> Optional[int]:
+        # Redoing this is fine because it's a pretty lightweight call
+        models = self.list_llm_models()
+
+        for m in models:
+            if model_name in m.model:
+                return m.context_window
+
+        return None
+
+
 class OllamaProvider(OpenAIProvider):
     """Ollama provider that uses the native /api/generate endpoint
 
diff --git a/tests/test_providers.py b/tests/test_providers.py
index f2d4f95b..e1e15be2 100644
--- a/tests/test_providers.py
+++ b/tests/test_providers.py
@@ -2,7 +2,9 @@ import os
 
 from letta.providers import (
     AnthropicProvider,
+    AzureProvider,
     GoogleAIProvider,
+    MistralProvider,
     OllamaProvider,
     OpenAIProvider,
 )
@@ -33,10 +35,13 @@ def test_anthropic():
 #
 
 
-# TODO: Add this test
-# https://linear.app/letta/issue/LET-159/add-tests-for-azure-openai-in-test-providerspy-and-test-endpointspy
 def test_azure():
-    pass
+    provider = AzureProvider(api_key=os.getenv("AZURE_API_KEY"), base_url=os.getenv("AZURE_BASE_URL"))
+    models = provider.list_llm_models()
+    print([m.model for m in models])
+
+    embed_models = provider.list_embedding_models()
+    print([m.embedding_model for m in embed_models])
 
 
 def test_ollama():
@@ -60,6 +65,12 @@ def test_googleai():
     provider.list_embedding_models()
 
 
+def test_mistral():
+    provider = MistralProvider(api_key=os.getenv("MISTRAL_API_KEY"))
+    models = provider.list_llm_models()
+    print([m.model for m in models])
+
+
 # def test_vllm():
 #     provider = VLLMProvider(base_url=os.getenv("VLLM_API_BASE"))
 #     models = provider.list_llm_models()