feat(asyncify): migrate list models (#2369)
This commit is contained in:
@@ -54,7 +54,7 @@ class Provider(ProviderBase):
|
||||
return []
|
||||
|
||||
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
|
||||
return []
|
||||
return self.list_embedding_models()
|
||||
|
||||
def get_model_context_window(self, model_name: str) -> Optional[int]:
|
||||
raise NotImplementedError
|
||||
|
||||
@@ -741,6 +741,13 @@ class SyncServer(Server):
|
||||
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
|
||||
return self._llm_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
async def get_cached_llm_config_async(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._llm_config_cache:
|
||||
self._llm_config_cache[key] = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
|
||||
return self._llm_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
def get_cached_embedding_config(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
@@ -748,6 +755,13 @@ class SyncServer(Server):
|
||||
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
|
||||
return self._embedding_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
async def get_cached_embedding_config_async(self, actor: User, **kwargs):
|
||||
key = make_key(**kwargs)
|
||||
if key not in self._embedding_config_cache:
|
||||
self._embedding_config_cache[key] = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
|
||||
return self._embedding_config_cache[key]
|
||||
|
||||
@trace_method
|
||||
def create_agent(
|
||||
self,
|
||||
@@ -815,7 +829,7 @@ class SyncServer(Server):
|
||||
"enable_reasoner": request.enable_reasoner,
|
||||
}
|
||||
log_event(name="start get_cached_llm_config", attributes=config_params)
|
||||
request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
|
||||
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
|
||||
log_event(name="end get_cached_llm_config", attributes=config_params)
|
||||
|
||||
if request.embedding_config is None:
|
||||
@@ -826,7 +840,7 @@ class SyncServer(Server):
|
||||
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
|
||||
}
|
||||
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
|
||||
request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
|
||||
request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
|
||||
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
|
||||
|
||||
log_event(name="start create_agent db")
|
||||
@@ -877,10 +891,10 @@ class SyncServer(Server):
|
||||
actor: User,
|
||||
) -> AgentState:
|
||||
if request.model is not None:
|
||||
request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
|
||||
request.llm_config = await self.get_llm_config_from_handle_async(handle=request.model, actor=actor)
|
||||
|
||||
if request.embedding is not None:
|
||||
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
|
||||
request.embedding_config = await self.get_embedding_config_from_handle_async(handle=request.embedding, actor=actor)
|
||||
|
||||
if request.enable_sleeptime:
|
||||
agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
@@ -1568,6 +1582,61 @@ class SyncServer(Server):
|
||||
|
||||
return llm_config
|
||||
|
||||
@trace_method
|
||||
async def get_llm_config_from_handle_async(
|
||||
self,
|
||||
actor: User,
|
||||
handle: str,
|
||||
context_window_limit: Optional[int] = None,
|
||||
max_tokens: Optional[int] = None,
|
||||
max_reasoning_tokens: Optional[int] = None,
|
||||
enable_reasoner: Optional[bool] = None,
|
||||
) -> LLMConfig:
|
||||
try:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = await self.get_provider_from_name_async(provider_name, actor)
|
||||
|
||||
all_llm_configs = await provider.list_llm_models_async()
|
||||
llm_configs = [config for config in all_llm_configs if config.handle == handle]
|
||||
if not llm_configs:
|
||||
llm_configs = [config for config in all_llm_configs if config.model == model_name]
|
||||
if not llm_configs:
|
||||
available_handles = [config.handle for config in all_llm_configs]
|
||||
raise HandleNotFoundError(handle, available_handles)
|
||||
except ValueError as e:
|
||||
llm_configs = [config for config in self.get_local_llm_configs() if config.handle == handle]
|
||||
if not llm_configs:
|
||||
llm_configs = [config for config in self.get_local_llm_configs() if config.model == model_name]
|
||||
if not llm_configs:
|
||||
raise e
|
||||
|
||||
if len(llm_configs) == 1:
|
||||
llm_config = llm_configs[0]
|
||||
elif len(llm_configs) > 1:
|
||||
raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
|
||||
else:
|
||||
llm_config = llm_configs[0]
|
||||
|
||||
if context_window_limit is not None:
|
||||
if context_window_limit > llm_config.context_window:
|
||||
raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})")
|
||||
llm_config.context_window = context_window_limit
|
||||
else:
|
||||
llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit)
|
||||
|
||||
if max_tokens is not None:
|
||||
llm_config.max_tokens = max_tokens
|
||||
if max_reasoning_tokens is not None:
|
||||
if not max_tokens or max_reasoning_tokens > max_tokens:
|
||||
raise ValueError(f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})")
|
||||
llm_config.max_reasoning_tokens = max_reasoning_tokens
|
||||
if enable_reasoner is not None:
|
||||
llm_config.enable_reasoner = enable_reasoner
|
||||
if enable_reasoner and llm_config.model_endpoint_type == "anthropic":
|
||||
llm_config.put_inner_thoughts_in_kwargs = False
|
||||
|
||||
return llm_config
|
||||
|
||||
@trace_method
|
||||
def get_embedding_config_from_handle(
|
||||
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
@@ -1597,6 +1666,36 @@ class SyncServer(Server):
|
||||
|
||||
return embedding_config
|
||||
|
||||
@trace_method
|
||||
async def get_embedding_config_from_handle_async(
|
||||
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
) -> EmbeddingConfig:
|
||||
try:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = await self.get_provider_from_name_async(provider_name, actor)
|
||||
|
||||
all_embedding_configs = await provider.list_embedding_models_async()
|
||||
embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
|
||||
if not embedding_configs:
|
||||
raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
|
||||
except ValueError as e:
|
||||
# search local configs
|
||||
embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
|
||||
if not embedding_configs:
|
||||
raise e
|
||||
|
||||
if len(embedding_configs) == 1:
|
||||
embedding_config = embedding_configs[0]
|
||||
elif len(embedding_configs) > 1:
|
||||
raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
|
||||
else:
|
||||
embedding_config = embedding_configs[0]
|
||||
|
||||
if embedding_chunk_size:
|
||||
embedding_config.embedding_chunk_size = embedding_chunk_size
|
||||
|
||||
return embedding_config
|
||||
|
||||
def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
|
||||
providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
|
||||
if not providers:
|
||||
@@ -1608,6 +1707,18 @@ class SyncServer(Server):
|
||||
|
||||
return provider
|
||||
|
||||
async def get_provider_from_name_async(self, provider_name: str, actor: User) -> Provider:
|
||||
all_providers = await self.get_enabled_providers_async(actor)
|
||||
providers = [provider for provider in all_providers if provider.name == provider_name]
|
||||
if not providers:
|
||||
raise ValueError(f"Provider {provider_name} is not supported")
|
||||
elif len(providers) > 1:
|
||||
raise ValueError(f"Multiple providers with name {provider_name} supported")
|
||||
else:
|
||||
provider = providers[0]
|
||||
|
||||
return provider
|
||||
|
||||
def get_local_llm_configs(self):
|
||||
llm_models = []
|
||||
try:
|
||||
|
||||
Reference in New Issue
Block a user