From d54e4fd4acf73ed508a77f71dfccc4fcd7b6574d Mon Sep 17 00:00:00 2001
From: Charles Packer
Date: Fri, 24 Jan 2025 15:08:23 -0800
Subject: [PATCH] feat: lmstudio support via /v1/chat/completions proxy route (#724)

---
 letta/llm_api/openai.py    |   2 +-
 letta/schemas/providers.py | 125 +++++++++++++++++++++++++++++++++++++
 letta/server/server.py     |   8 +++
 letta/settings.py          |   3 +
 4 files changed, 137 insertions(+), 1 deletion(-)

diff --git a/letta/llm_api/openai.py b/letta/llm_api/openai.py
index ee084947..ee4e7954 100644
--- a/letta/llm_api/openai.py
+++ b/letta/llm_api/openai.py
@@ -30,7 +30,7 @@ OPENAI_SSE_DONE = "[DONE]"
 
 
 def openai_get_model_list(
-    url: str, api_key: Union[str, None], fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
+    url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
 ) -> dict:
     """https://platform.openai.com/docs/api-reference/models/list"""
     from letta.utils import printd
diff --git a/letta/schemas/providers.py b/letta/schemas/providers.py
index b3e40a7d..8d38ad4c 100644
--- a/letta/schemas/providers.py
+++ b/letta/schemas/providers.py
@@ -1,3 +1,4 @@
+import warnings
 from datetime import datetime
 from typing import List, Optional
 
@@ -210,6 +211,130 @@ class OpenAIProvider(Provider):
         return None
 
 
+class LMStudioOpenAIProvider(OpenAIProvider):
+    name: str = "lmstudio-openai"
+    base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
+    api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
+
+    def list_llm_models(self) -> List[LLMConfig]:
+        from letta.llm_api.openai import openai_get_model_list
+
+        # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
+        MODEL_ENDPOINT_URL = f"{self.base_url.removesuffix('/v1')}/api/v0"  # removesuffix, not strip: strip('/v1') removes a char set, not the suffix
+        response = openai_get_model_list(MODEL_ENDPOINT_URL)
+
+        """
+        Example response:
+
+        {
+          "object": "list",
+          "data": [
+            {
+              "id": "qwen2-vl-7b-instruct",
+              "object": "model",
+              "type": "vlm",
+              "publisher": "mlx-community",
+              "arch": "qwen2_vl",
+              "compatibility_type": "mlx",
+              "quantization": "4bit",
+              "state": "not-loaded",
+              "max_context_length": 32768
+            },
+            ...
+        """
+        if "data" not in response:
+            warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
+            return []
+
+        configs = []
+        for model in response["data"]:
+            assert "id" in model, f"Model missing 'id' field: {model}"
+            model_name = model["id"]
+
+            if "type" not in model:
+                warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
+                continue
+            elif model["type"] not in ["vlm", "llm"]:
+                continue
+
+            if "max_context_length" in model:
+                context_window_size = model["max_context_length"]
+            else:
+                warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
+                continue
+
+            configs.append(
+                LLMConfig(
+                    model=model_name,
+                    model_endpoint_type="openai",
+                    model_endpoint=self.base_url,
+                    context_window=context_window_size,
+                    handle=self.get_handle(model_name),
+                )
+            )
+
+        return configs
+
+    def list_embedding_models(self) -> List[EmbeddingConfig]:
+        from letta.llm_api.openai import openai_get_model_list
+
+        # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
+        MODEL_ENDPOINT_URL = f"{self.base_url.removesuffix('/v1')}/api/v0"  # keep consistent with list_llm_models: /api/v0 lives at the server root, not under /v1
+        response = openai_get_model_list(MODEL_ENDPOINT_URL)
+
+        """
+        Example response:
+        {
+          "object": "list",
+          "data": [
+            {
+              "id": "text-embedding-nomic-embed-text-v1.5",
+              "object": "model",
+              "type": "embeddings",
+              "publisher": "nomic-ai",
+              "arch": "nomic-bert",
+              "compatibility_type": "gguf",
+              "quantization": "Q4_0",
+              "state": "not-loaded",
+              "max_context_length": 2048
+            }
+          ...
+        """
+        if "data" not in response:
+            warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
+            return []
+
+        configs = []
+        for model in response["data"]:
+            assert "id" in model, f"Model missing 'id' field: {model}"
+            model_name = model["id"]
+
+            if "type" not in model:
+                warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
+                continue
+            elif model["type"] not in ["embeddings"]:
+                continue
+
+            if "max_context_length" in model:
+                context_window_size = model["max_context_length"]
+            else:
+                warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
+                continue
+
+            configs.append(
+                EmbeddingConfig(
+                    embedding_model=model_name,
+                    embedding_endpoint_type="openai",
+                    embedding_endpoint=self.base_url,
+                    embedding_dim=context_window_size,  # NOTE(review): this is max_context_length, likely not the true embedding dim -- confirm
+                    embedding_chunk_size=300,
+                    handle=self.get_handle(model_name),
+                ),
+            )
+
+        return configs
+
+
 class AnthropicProvider(Provider):
     name: str = "anthropic"
     api_key: str = Field(..., description="API key for the Anthropic API.")
diff --git a/letta/server/server.py b/letta/server/server.py
index 48e18d86..ef7a8ec6 100644
--- a/letta/server/server.py
+++ b/letta/server/server.py
@@ -49,6 +49,7 @@ from letta.schemas.providers import (
     GoogleAIProvider,
     GroqProvider,
     LettaProvider,
+    LMStudioOpenAIProvider,
     OllamaProvider,
     OpenAIProvider,
     Provider,
@@ -391,6 +392,13 @@ class SyncServer(Server):
                     aws_region=model_settings.aws_region,
                 )
             )
+        # Attempt to enable LM Studio by default
+        if model_settings.lmstudio_base_url:
+            self._enabled_providers.append(
+                LMStudioOpenAIProvider(
+                    base_url=model_settings.lmstudio_base_url,
+                )
+            )
 
     def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
         """Updated method to load agents from persisted storage"""
diff --git a/letta/settings.py b/letta/settings.py
index 1c5f5bfe..4ef28021 100644
--- a/letta/settings.py
+++ b/letta/settings.py
@@ -92,6 +92,9 @@ class ModelSettings(BaseSettings):
     # vLLM
     vllm_api_base: Optional[str] = None
 
+    # lmstudio
+    lmstudio_base_url: Optional[str] = None
+
     # openllm
     openllm_auth_type: Optional[str] = None
     openllm_api_key: Optional[str] = None