diff --git a/letta/server/rest_api/routers/v1/llms.py b/letta/server/rest_api/routers/v1/llms.py index 02c369f6..f050cf7d 100644 --- a/letta/server/rest_api/routers/v1/llms.py +++ b/letta/server/rest_api/routers/v1/llms.py @@ -15,10 +15,11 @@ router = APIRouter(prefix="/models", tags=["models", "llms"]) @router.get("/", response_model=List[LLMConfig], operation_id="list_models") def list_llm_models( byok_only: Optional[bool] = Query(None), + default_only: Optional[bool] = Query(None), server: "SyncServer" = Depends(get_letta_server), ): - models = server.list_llm_models(byok_only=byok_only) + models = server.list_llm_models(byok_only=byok_only, default_only=default_only) # print(models) return models diff --git a/letta/server/rest_api/routers/v1/providers.py b/letta/server/rest_api/routers/v1/providers.py index 02615f63..a8f01c1b 100644 --- a/letta/server/rest_api/routers/v1/providers.py +++ b/letta/server/rest_api/routers/v1/providers.py @@ -1,7 +1,9 @@ from typing import TYPE_CHECKING, List, Optional -from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query +from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query, status +from fastapi.responses import JSONResponse +from letta.orm.errors import NoResultFound from letta.schemas.enums import ProviderType from letta.schemas.providers import Provider, ProviderCreate, ProviderUpdate from letta.server.rest_api.utils import get_letta_server diff --git a/letta/server/server.py b/letta/server/server.py index e8ef2d07..d99fd74d 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -1201,10 +1201,10 @@ class SyncServer(Server): except NoResultFound: raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found") - def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]: + def list_llm_models(self, byok_only: bool = False, default_only: bool = False) -> List[LLMConfig]: """List available models""" llm_models = [] - for provider in self.get_enabled_providers(byok_only=byok_only): + for provider in self.get_enabled_providers(byok_only=byok_only, default_only=default_only): try: llm_models.extend(provider.list_llm_models()) except Exception as e: @@ -1224,11 +1224,17 @@ class SyncServer(Server): warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}") return embedding_models - def get_enabled_providers(self, byok_only: bool = False): + def get_enabled_providers(self, byok_only: bool = False, default_only: bool = False): + providers_from_env = {p.name: p for p in self._enabled_providers} + + if default_only: + return list(providers_from_env.values()) + providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()} + if byok_only: return list(providers_from_db.values()) - providers_from_env = {p.name: p for p in self._enabled_providers} + return list(providers_from_env.values()) + list(providers_from_db.values()) @trace_method