feat: lmstudio support via /v1/chat/completions proxy route (#724)

This commit is contained in:
Charles Packer
2025-01-24 15:08:23 -08:00
committed by GitHub
parent 5ac7154c86
commit d54e4fd4ac
4 changed files with 137 additions and 1 deletion

View File

@@ -30,7 +30,7 @@ OPENAI_SSE_DONE = "[DONE]"
def openai_get_model_list(
url: str, api_key: Union[str, None], fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
) -> dict:
"""https://platform.openai.com/docs/api-reference/models/list"""
from letta.utils import printd

View File

@@ -1,3 +1,4 @@
import warnings
from datetime import datetime
from typing import List, Optional
@@ -210,6 +211,130 @@ class OpenAIProvider(Provider):
return None
class LMStudioOpenAIProvider(OpenAIProvider):
    """Provider that discovers models from a local LM Studio server.

    LM Studio serves an OpenAI-compatible API under ``/v1``, but model
    metadata (``type``, ``max_context_length``, ...) is only exposed by its
    native ``GET /api/v0/models`` endpoint. Listing therefore queries
    ``/api/v0`` while the returned configs still point callers at
    ``base_url`` (the OpenAI-compatible endpoint).
    """

    name: str = "lmstudio-openai"
    base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
    api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")

    def _model_list_url(self) -> str:
        """Return the LM Studio native model-listing base URL.

        BUGFIX: the original used ``base_url.strip("/v1")``, but ``str.strip``
        removes a *character set* ('/', 'v', '1') from both ends of the
        string — e.g. a base URL ending in port ``...:4111`` would be
        mangled. ``removesuffix`` removes only an exact trailing "/v1".
        """
        return f"{self.base_url.removesuffix('/v1')}/api/v0"

    def list_llm_models(self) -> List[LLMConfig]:
        """Query LM Studio for chat-capable ('llm' / 'vlm') models.

        Returns:
            One ``LLMConfig`` per usable model. Models of other types, or
            models missing required metadata, are skipped (with a warning
            when metadata is missing).
        """
        from letta.llm_api.openai import openai_get_model_list

        # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
        response = openai_get_model_list(self._model_list_url())

        # Example response:
        # {
        #   "object": "list",
        #   "data": [
        #     {
        #       "id": "qwen2-vl-7b-instruct",
        #       "object": "model",
        #       "type": "vlm",
        #       "publisher": "mlx-community",
        #       "arch": "qwen2_vl",
        #       "compatibility_type": "mlx",
        #       "quantization": "4bit",
        #       "state": "not-loaded",
        #       "max_context_length": 32768
        #     },
        #     ...
        #   ]
        # }
        if "data" not in response:
            warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
            return []

        configs = []
        for model in response["data"]:
            assert "id" in model, f"Model missing 'id' field: {model}"
            model_name = model["id"]

            if "type" not in model:
                warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
                continue
            elif model["type"] not in ["vlm", "llm"]:
                # Embedding models are handled by list_embedding_models;
                # skip anything else silently.
                continue

            if "max_context_length" in model:
                context_window_size = model["max_context_length"]
            else:
                warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
                continue

            configs.append(
                LLMConfig(
                    model=model_name,
                    model_endpoint_type="openai",
                    model_endpoint=self.base_url,
                    context_window=context_window_size,
                    handle=self.get_handle(model_name),
                )
            )

        return configs

    def list_embedding_models(self) -> List[EmbeddingConfig]:
        """Query LM Studio for embedding ('embeddings') models.

        Returns:
            One ``EmbeddingConfig`` per usable model; models missing
            required metadata are skipped with a warning.
        """
        from letta.llm_api.openai import openai_get_model_list

        # For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
        # CONSISTENCY FIX: use the same suffix-stripped URL helper as
        # list_llm_models; the original appended "/api/v0" to base_url
        # directly, producing ".../v1/api/v0" when base_url ends in "/v1".
        response = openai_get_model_list(self._model_list_url())

        # Example response:
        # {
        #   "object": "list",
        #   "data": [
        #     {
        #       "id": "text-embedding-nomic-embed-text-v1.5",
        #       "object": "model",
        #       "type": "embeddings",
        #       "publisher": "nomic-ai",
        #       "arch": "nomic-bert",
        #       "compatibility_type": "gguf",
        #       "quantization": "Q4_0",
        #       "state": "not-loaded",
        #       "max_context_length": 2048
        #     }
        #     ...
        #   ]
        # }
        if "data" not in response:
            warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
            return []

        configs = []
        for model in response["data"]:
            assert "id" in model, f"Model missing 'id' field: {model}"
            model_name = model["id"]

            if "type" not in model:
                warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
                continue
            elif model["type"] not in ["embeddings"]:
                continue

            if "max_context_length" in model:
                context_window_size = model["max_context_length"]
            else:
                warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
                continue

            configs.append(
                EmbeddingConfig(
                    embedding_model=model_name,
                    embedding_endpoint_type="openai",
                    embedding_endpoint=self.base_url,
                    # NOTE(review): preserved from original — this sets the
                    # embedding *dimension* to max_context_length (the context
                    # window), which is unlikely to be the true vector size
                    # (e.g. nomic-embed-text-v1.5 is 768-dim). The /api/v0
                    # response does not expose the dimension; confirm upstream.
                    embedding_dim=context_window_size,
                    embedding_chunk_size=300,
                    handle=self.get_handle(model_name),
                ),
            )

        return configs
class AnthropicProvider(Provider):
name: str = "anthropic"
api_key: str = Field(..., description="API key for the Anthropic API.")

View File

@@ -49,6 +49,7 @@ from letta.schemas.providers import (
GoogleAIProvider,
GroqProvider,
LettaProvider,
LMStudioOpenAIProvider,
OllamaProvider,
OpenAIProvider,
Provider,
@@ -391,6 +392,13 @@ class SyncServer(Server):
aws_region=model_settings.aws_region,
)
)
# Attempt to enable LM Studio by default
if model_settings.lmstudio_base_url:
self._enabled_providers.append(
LMStudioOpenAIProvider(
base_url=model_settings.lmstudio_base_url,
)
)
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
"""Updated method to load agents from persisted storage"""

View File

@@ -92,6 +92,9 @@ class ModelSettings(BaseSettings):
# vLLM
vllm_api_base: Optional[str] = None
# lmstudio
lmstudio_base_url: Optional[str] = None
# openllm
openllm_auth_type: Optional[str] = None
openllm_api_key: Optional[str] = None