feat: add provider_category field to distinguish BYOK providers from base providers (#2038)

This commit is contained in:
cthomas
2025-05-06 17:31:36 -07:00
committed by GitHub
parent 230eb944d1
commit db6982a4bc
23 changed files with 250 additions and 112 deletions

View File

@@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
from letta.schemas.job import Job, JobUpdate
@@ -734,17 +734,17 @@ class SyncServer(Server):
return self._command(user_id=user_id, agent_id=agent_id, command=command)
@trace_method
def get_cached_llm_config(self, **kwargs):
def get_cached_llm_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._llm_config_cache:
self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
return self._llm_config_cache[key]
@trace_method
def get_cached_embedding_config(self, **kwargs):
def get_cached_embedding_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
if key not in self._embedding_config_cache:
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
return self._embedding_config_cache[key]
@trace_method
@@ -766,7 +766,7 @@ class SyncServer(Server):
"enable_reasoner": request.enable_reasoner,
}
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = self.get_cached_llm_config(**config_params)
request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
if request.embedding_config is None:
@@ -777,7 +777,7 @@ class SyncServer(Server):
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
}
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
request.embedding_config = self.get_cached_embedding_config(**embedding_config_params)
request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
log_event(name="start create_agent db")
@@ -802,10 +802,10 @@ class SyncServer(Server):
actor: User,
) -> AgentState:
if request.model is not None:
request.llm_config = self.get_llm_config_from_handle(handle=request.model)
request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
if request.embedding is not None:
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding)
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
if request.enable_sleeptime:
agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1201,10 +1201,21 @@ class SyncServer(Server):
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
def list_llm_models(self, byok_only: bool = False, default_only: bool = False) -> List[LLMConfig]:
def list_llm_models(
self,
actor: User,
provider_category: Optional[List[ProviderCategory]] = None,
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[LLMConfig]:
"""List available models"""
llm_models = []
for provider in self.get_enabled_providers(byok_only=byok_only, default_only=default_only):
for provider in self.get_enabled_providers(
provider_category=provider_category,
provider_name=provider_name,
provider_type=provider_type,
actor=actor,
):
try:
llm_models.extend(provider.list_llm_models())
except Exception as e:
@@ -1214,32 +1225,49 @@ class SyncServer(Server):
return llm_models
def list_embedding_models(self) -> List[EmbeddingConfig]:
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
"""List available embedding models"""
embedding_models = []
for provider in self.get_enabled_providers():
for provider in self.get_enabled_providers(actor):
try:
embedding_models.extend(provider.list_embedding_models())
except Exception as e:
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
return embedding_models
def get_enabled_providers(self, byok_only: bool = False, default_only: bool = False):
providers_from_env = {p.name: p for p in self._enabled_providers}
def get_enabled_providers(
self,
actor: User,
provider_category: Optional[List[ProviderCategory]] = None,
provider_name: Optional[str] = None,
provider_type: Optional[ProviderType] = None,
) -> List[Provider]:
providers = []
if not provider_category or ProviderCategory.base in provider_category:
providers_from_env = [p for p in self._enabled_providers]
providers.extend(providers_from_env)
if default_only:
return list(providers_from_env.values())
if not provider_category or ProviderCategory.byok in provider_category:
providers_from_db = self.provider_manager.list_providers(
name=provider_name,
provider_type=provider_type,
actor=actor,
)
providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
providers.extend(providers_from_db)
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
if provider_name is not None:
providers = [p for p in providers if p.name == provider_name]
if byok_only:
return list(providers_from_db.values())
if provider_type is not None:
providers = [p for p in providers if p.provider_type == provider_type]
return list(providers_from_env.values()) + list(providers_from_db.values())
return providers
@trace_method
def get_llm_config_from_handle(
self,
actor: User,
handle: str,
context_window_limit: Optional[int] = None,
max_tokens: Optional[int] = None,
@@ -1248,7 +1276,7 @@ class SyncServer(Server):
) -> LLMConfig:
try:
provider_name, model_name = handle.split("/", 1)
provider = self.get_provider_from_name(provider_name)
provider = self.get_provider_from_name(provider_name, actor)
llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
if not llm_configs:
@@ -1292,11 +1320,11 @@ class SyncServer(Server):
@trace_method
def get_embedding_config_from_handle(
self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
) -> EmbeddingConfig:
try:
provider_name, model_name = handle.split("/", 1)
provider = self.get_provider_from_name(provider_name)
provider = self.get_provider_from_name(provider_name, actor)
embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle]
if not embedding_configs:
@@ -1319,8 +1347,8 @@ class SyncServer(Server):
return embedding_config
def get_provider_from_name(self, provider_name: str) -> Provider:
providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
if not providers:
raise ValueError(f"Provider {provider_name} is not supported")
elif len(providers) > 1: