feat: allow customizing the handle base for openrouter and for vllm [LET-4609] (#5114)
* feat: allow setting VLLM_HANDLE_BASE
* feat: same thing for openrouter
This commit is contained in:
committed by
Caren Thomas
parent
4f52aab652
commit
811b3e6cb6
@@ -21,6 +21,7 @@ class OpenRouterProvider(OpenAIProvider):
|
||||
provider_category: ProviderCategory = Field(ProviderCategory.base, description="The category of the provider (base or byok)")
|
||||
api_key: str = Field(..., description="API key for the OpenRouter API.")
|
||||
base_url: str = Field("https://openrouter.ai/api/v1", description="Base URL for the OpenRouter API.")
|
||||
handle_base: str | None = Field(None, description="Custom handle base name for model handles (e.g., 'custom' instead of 'openrouter').")
|
||||
|
||||
def _list_llm_models(self, data: list[dict]) -> list[LLMConfig]:
|
||||
"""
|
||||
@@ -33,7 +34,7 @@ class OpenRouterProvider(OpenAIProvider):
|
||||
continue
|
||||
model_name, context_window_size = check
|
||||
|
||||
handle = self.get_handle(model_name)
|
||||
handle = self.get_handle(model_name, base_name=self.handle_base) if self.handle_base else self.get_handle(model_name)
|
||||
|
||||
config = LLMConfig(
|
||||
model=model_name,
|
||||
|
||||
@@ -23,6 +23,7 @@ class VLLMProvider(Provider):
|
||||
default_prompt_formatter: str | None = Field(
|
||||
default=None, description="Default prompt formatter (aka model wrapper) to use on a /completions style API."
|
||||
)
|
||||
handle_base: str | None = Field(None, description="Custom handle base name for model handles (e.g., 'custom' instead of 'vllm').")
|
||||
|
||||
async def list_llm_models_async(self) -> list[LLMConfig]:
|
||||
from letta.llm_api.openai import openai_get_model_list_async
|
||||
@@ -43,7 +44,7 @@ class VLLMProvider(Provider):
|
||||
model_endpoint=base_url,
|
||||
model_wrapper=self.default_prompt_formatter,
|
||||
context_window=model["max_model_len"],
|
||||
handle=self.get_handle(model_name),
|
||||
handle=self.get_handle(model_name, base_name=self.handle_base) if self.handle_base else self.get_handle(model_name),
|
||||
provider_name=self.name,
|
||||
provider_category=self.provider_category,
|
||||
)
|
||||
|
||||
@@ -278,6 +278,7 @@ class SyncServer(object):
|
||||
name="vllm",
|
||||
base_url=model_settings.vllm_api_base,
|
||||
default_prompt_formatter=model_settings.default_prompt_formatter,
|
||||
handle_base=model_settings.vllm_handle_base,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -306,6 +307,7 @@ class SyncServer(object):
|
||||
OpenRouterProvider(
|
||||
name="openrouter",
|
||||
api_key=model_settings.openrouter_api_key,
|
||||
handle_base=model_settings.openrouter_handle_base,
|
||||
)
|
||||
)
|
||||
|
||||
|
||||
@@ -117,6 +117,7 @@ class ModelSettings(BaseSettings):
|
||||
# See https://openrouter.ai/docs/quick-start for details
|
||||
openrouter_referer: Optional[str] = None # e.g., your site URL
|
||||
openrouter_title: Optional[str] = None # e.g., your app name
|
||||
openrouter_handle_base: Optional[str] = None
|
||||
|
||||
# deepseek
|
||||
deepseek_api_key: Optional[str] = None
|
||||
@@ -163,6 +164,7 @@ class ModelSettings(BaseSettings):
|
||||
|
||||
# vLLM
|
||||
vllm_api_base: Optional[str] = None
|
||||
vllm_handle_base: Optional[str] = None
|
||||
|
||||
# lmstudio
|
||||
lmstudio_base_url: Optional[str] = None
|
||||
|
||||
Reference in New Issue
Block a user