feat: lmstudio support via /v1/chat/completions proxy route (#724)
This commit is contained in:
@@ -30,7 +30,7 @@ OPENAI_SSE_DONE = "[DONE]"
|
||||
|
||||
|
||||
def openai_get_model_list(
|
||||
url: str, api_key: Union[str, None], fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
|
||||
url: str, api_key: Optional[str] = None, fix_url: Optional[bool] = False, extra_params: Optional[dict] = None
|
||||
) -> dict:
|
||||
"""https://platform.openai.com/docs/api-reference/models/list"""
|
||||
from letta.utils import printd
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import warnings
|
||||
from datetime import datetime
|
||||
from typing import List, Optional
|
||||
|
||||
@@ -210,6 +211,130 @@ class OpenAIProvider(Provider):
|
||||
return None
|
||||
|
||||
|
||||
class LMStudioOpenAIProvider(OpenAIProvider):
|
||||
name: str = "lmstudio-openai"
|
||||
base_url: str = Field(..., description="Base URL for the LMStudio OpenAI API.")
|
||||
api_key: Optional[str] = Field(None, description="API key for the LMStudio API.")
|
||||
|
||||
def list_llm_models(self) -> List[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list
|
||||
|
||||
# For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
|
||||
MODEL_ENDPOINT_URL = f"{self.base_url.strip('/v1')}/api/v0"
|
||||
response = openai_get_model_list(MODEL_ENDPOINT_URL)
|
||||
|
||||
"""
|
||||
Example response:
|
||||
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "qwen2-vl-7b-instruct",
|
||||
"object": "model",
|
||||
"type": "vlm",
|
||||
"publisher": "mlx-community",
|
||||
"arch": "qwen2_vl",
|
||||
"compatibility_type": "mlx",
|
||||
"quantization": "4bit",
|
||||
"state": "not-loaded",
|
||||
"max_context_length": 32768
|
||||
},
|
||||
...
|
||||
"""
|
||||
if "data" not in response:
|
||||
warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
|
||||
return []
|
||||
|
||||
configs = []
|
||||
for model in response["data"]:
|
||||
assert "id" in model, f"Model missing 'id' field: {model}"
|
||||
model_name = model["id"]
|
||||
|
||||
if "type" not in model:
|
||||
warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
|
||||
continue
|
||||
elif model["type"] not in ["vlm", "llm"]:
|
||||
continue
|
||||
|
||||
if "max_context_length" in model:
|
||||
context_window_size = model["max_context_length"]
|
||||
else:
|
||||
warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
LLMConfig(
|
||||
model=model_name,
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint=self.base_url,
|
||||
context_window=context_window_size,
|
||||
handle=self.get_handle(model_name),
|
||||
)
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list
|
||||
|
||||
# For LMStudio, we want to hit 'GET /api/v0/models' instead of 'GET /v1/models'
|
||||
MODEL_ENDPOINT_URL = f"{self.base_url}/api/v0"
|
||||
response = openai_get_model_list(MODEL_ENDPOINT_URL)
|
||||
|
||||
"""
|
||||
Example response:
|
||||
{
|
||||
"object": "list",
|
||||
"data": [
|
||||
{
|
||||
"id": "text-embedding-nomic-embed-text-v1.5",
|
||||
"object": "model",
|
||||
"type": "embeddings",
|
||||
"publisher": "nomic-ai",
|
||||
"arch": "nomic-bert",
|
||||
"compatibility_type": "gguf",
|
||||
"quantization": "Q4_0",
|
||||
"state": "not-loaded",
|
||||
"max_context_length": 2048
|
||||
}
|
||||
...
|
||||
"""
|
||||
if "data" not in response:
|
||||
warnings.warn(f"LMStudio OpenAI model query response missing 'data' field: {response}")
|
||||
return []
|
||||
|
||||
configs = []
|
||||
for model in response["data"]:
|
||||
assert "id" in model, f"Model missing 'id' field: {model}"
|
||||
model_name = model["id"]
|
||||
|
||||
if "type" not in model:
|
||||
warnings.warn(f"LMStudio OpenAI model missing 'type' field: {model}")
|
||||
continue
|
||||
elif model["type"] not in ["embeddings"]:
|
||||
continue
|
||||
|
||||
if "max_context_length" in model:
|
||||
context_window_size = model["max_context_length"]
|
||||
else:
|
||||
warnings.warn(f"LMStudio OpenAI model missing 'max_context_length' field: {model}")
|
||||
continue
|
||||
|
||||
configs.append(
|
||||
EmbeddingConfig(
|
||||
embedding_model=model_name,
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint=self.base_url,
|
||||
embedding_dim=context_window_size,
|
||||
embedding_chunk_size=300,
|
||||
handle=self.get_handle(model_name),
|
||||
),
|
||||
)
|
||||
|
||||
return configs
|
||||
|
||||
|
||||
class AnthropicProvider(Provider):
|
||||
name: str = "anthropic"
|
||||
api_key: str = Field(..., description="API key for the Anthropic API.")
|
||||
|
||||
@@ -49,6 +49,7 @@ from letta.schemas.providers import (
|
||||
GoogleAIProvider,
|
||||
GroqProvider,
|
||||
LettaProvider,
|
||||
LMStudioOpenAIProvider,
|
||||
OllamaProvider,
|
||||
OpenAIProvider,
|
||||
Provider,
|
||||
@@ -391,6 +392,13 @@ class SyncServer(Server):
|
||||
aws_region=model_settings.aws_region,
|
||||
)
|
||||
)
|
||||
# Attempt to enable LM Studio by default
|
||||
if model_settings.lmstudio_base_url:
|
||||
self._enabled_providers.append(
|
||||
LMStudioOpenAIProvider(
|
||||
base_url=model_settings.lmstudio_base_url,
|
||||
)
|
||||
)
|
||||
|
||||
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
|
||||
@@ -92,6 +92,9 @@ class ModelSettings(BaseSettings):
|
||||
# vLLM
|
||||
vllm_api_base: Optional[str] = None
|
||||
|
||||
# lmstudio
|
||||
lmstudio_base_url: Optional[str] = None
|
||||
|
||||
# openllm
|
||||
openllm_auth_type: Optional[str] = None
|
||||
openllm_api_key: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user