feat: support byoc on cloud (#2005)

Co-authored-by: Shubham Naik <shub@memgpt.ai>
This commit is contained in:
Shubham Naik
2025-05-06 11:37:30 -07:00
committed by GitHub
parent acda68c0a8
commit 230eb944d1
3 changed files with 15 additions and 6 deletions

View File

@@ -1201,10 +1201,10 @@ class SyncServer(Server):
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(self, byok_only: bool = False) -> List[LLMConfig]:
def list_llm_models(self, byok_only: bool = False, default_only: bool = False) -> List[LLMConfig]:
"""List available models"""
llm_models = []
for provider in self.get_enabled_providers(byok_only=byok_only):
for provider in self.get_enabled_providers(byok_only=byok_only, default_only=default_only):
try:
llm_models.extend(provider.list_llm_models())
except Exception as e:
@@ -1224,11 +1224,17 @@ class SyncServer(Server):
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models
def get_enabled_providers(self, byok_only: bool = False):
def get_enabled_providers(self, byok_only: bool = False, default_only: bool = False):
providers_from_env = {p.name: p for p in self._enabled_providers}
if default_only:
return list(providers_from_env.values())
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
if byok_only:
return list(providers_from_db.values())
providers_from_env = {p.name: p for p in self._enabled_providers}
return list(providers_from_env.values()) + list(providers_from_db.values())
@trace_method