feat: add provider_category field to distinguish byok (#2038)
This commit is contained in:
@@ -42,7 +42,7 @@ from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus, ProviderCategory, ProviderType
|
||||
from letta.schemas.environment_variables import SandboxEnvironmentVariableCreate
|
||||
from letta.schemas.group import GroupCreate, ManagerType, SleeptimeManager, VoiceSleeptimeManager
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
@@ -734,17 +734,17 @@ class SyncServer(Server):
|
||||
return self._command(user_id=user_id, agent_id=agent_id, command=command)
|
||||
|
||||
@trace_method
|
||||
def get_cached_llm_config(self, **kwargs):
|
||||
def get_cached_llm_config(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._llm_config_cache:
|
||||
self._llm_config_cache[key] = self.get_llm_config_from_handle(**kwargs)
|
||||
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
|
||||
return self._llm_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
def get_cached_embedding_config(self, **kwargs):
|
||||
def get_cached_embedding_config(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._embedding_config_cache:
|
||||
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(**kwargs)
|
||||
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
|
||||
return self._embedding_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
@@ -766,7 +766,7 @@ class SyncServer(Server):
|
||||
"enable_reasoner": request.enable_reasoner,
|
||||
}
|
||||
log_event(name="start get_cached_llm_config", attributes=config_params)
|
||||
request.llm_config = self.get_cached_llm_config(**config_params)
|
||||
request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
|
||||
log_event(name="end get_cached_llm_config", attributes=config_params)
|
||||
|
||||
if request.embedding_config is None:
|
||||
@@ -777,7 +777,7 @@ class SyncServer(Server):
|
||||
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
}
|
||||
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
|
||||
request.embedding_config = self.get_cached_embedding_config(**embedding_config_params)
|
||||
request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
|
||||
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
|
||||
|
||||
log_event(name="start create_agent db")
|
||||
@@ -802,10 +802,10 @@ class SyncServer(Server):
|
||||
actor: User,
|
||||
) -> AgentState:
|
||||
if request.model is not None:
|
||||
request.llm_config = self.get_llm_config_from_handle(handle=request.model)
|
||||
request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
|
||||
|
||||
if request.embedding is not None:
|
||||
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding)
|
||||
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
|
||||
|
||||
if request.enable_sleeptime:
|
||||
agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@@ -1201,10 +1201,21 @@ class SyncServer(Server):
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Organization with id {org_id} not found")
|
||||
|
||||
def list_llm_models(self, byok_only: bool = False, default_only: bool = False) -> List[LLMConfig]:
|
||||
def list_llm_models(
|
||||
self,
|
||||
actor: User,
|
||||
provider_category: Optional[List[ProviderCategory]] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
) -> List[LLMConfig]:
|
||||
"""List available models"""
|
||||
llm_models = []
|
||||
for provider in self.get_enabled_providers(byok_only=byok_only, default_only=default_only):
|
||||
for provider in self.get_enabled_providers(
|
||||
provider_category=provider_category,
|
||||
provider_name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
):
|
||||
try:
|
||||
llm_models.extend(provider.list_llm_models())
|
||||
except Exception as e:
|
||||
@@ -1214,32 +1225,49 @@ class SyncServer(Server):
|
||||
|
||||
return llm_models
|
||||
|
||||
def list_embedding_models(self) -> List[EmbeddingConfig]:
|
||||
def list_embedding_models(self, actor: User) -> List[EmbeddingConfig]:
|
||||
"""List available embedding models"""
|
||||
embedding_models = []
|
||||
for provider in self.get_enabled_providers():
|
||||
for provider in self.get_enabled_providers(actor):
|
||||
try:
|
||||
embedding_models.extend(provider.list_embedding_models())
|
||||
except Exception as e:
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return embedding_models
|
||||
|
||||
def get_enabled_providers(self, byok_only: bool = False, default_only: bool = False):
|
||||
providers_from_env = {p.name: p for p in self._enabled_providers}
|
||||
def get_enabled_providers(
|
||||
self,
|
||||
actor: User,
|
||||
provider_category: Optional[List[ProviderCategory]] = None,
|
||||
provider_name: Optional[str] = None,
|
||||
provider_type: Optional[ProviderType] = None,
|
||||
) -> List[Provider]:
|
||||
providers = []
|
||||
if not provider_category or ProviderCategory.base in provider_category:
|
||||
providers_from_env = [p for p in self._enabled_providers]
|
||||
providers.extend(providers_from_env)
|
||||
|
||||
if default_only:
|
||||
return list(providers_from_env.values())
|
||||
if not provider_category or ProviderCategory.byok in provider_category:
|
||||
providers_from_db = self.provider_manager.list_providers(
|
||||
name=provider_name,
|
||||
provider_type=provider_type,
|
||||
actor=actor,
|
||||
)
|
||||
providers_from_db = [p.cast_to_subtype() for p in providers_from_db]
|
||||
providers.extend(providers_from_db)
|
||||
|
||||
providers_from_db = {p.name: p.cast_to_subtype() for p in self.provider_manager.list_providers()}
|
||||
if provider_name is not None:
|
||||
providers = [p for p in providers if p.name == provider_name]
|
||||
|
||||
if byok_only:
|
||||
return list(providers_from_db.values())
|
||||
if provider_type is not None:
|
||||
providers = [p for p in providers if p.provider_type == provider_type]
|
||||
|
||||
return list(providers_from_env.values()) + list(providers_from_db.values())
|
||||
return providers
|
||||
|
||||
@trace_method
|
||||
def get_llm_config_from_handle(
|
||||
self,
|
||||
actor: User,
|
||||
handle: str,
|
||||
context_window_limit: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
@@ -1248,7 +1276,7 @@ class SyncServer(Server):
|
||||
) -> LLMConfig:
|
||||
try:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = self.get_provider_from_name(provider_name)
|
||||
provider = self.get_provider_from_name(provider_name, actor)
|
||||
|
||||
llm_configs = [config for config in provider.list_llm_models() if config.handle == handle]
|
||||
if not llm_configs:
|
||||
@@ -1292,11 +1320,11 @@ class SyncServer(Server):
|
||||
|
||||
@trace_method
|
||||
def get_embedding_config_from_handle(
|
||||
self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
) -> EmbeddingConfig:
|
||||
try:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = self.get_provider_from_name(provider_name)
|
||||
provider = self.get_provider_from_name(provider_name, actor)
|
||||
|
||||
embedding_configs = [config for config in provider.list_embedding_models() if config.handle == handle]
|
||||
if not embedding_configs:
|
||||
@@ -1319,8 +1347,8 @@ class SyncServer(Server):
|
||||
|
||||
return embedding_config
|
||||
|
||||
def get_provider_from_name(self, provider_name: str) -> Provider:
|
||||
providers = [provider for provider in self.get_enabled_providers() if provider.name == provider_name]
|
||||
def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
|
||||
providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
|
||||
if not providers:
|
||||
raise ValueError(f"Provider {provider_name} is not supported")
|
||||
elif len(providers) > 1:
|
||||
|
||||
Reference in New Issue
Block a user