feat(asyncify): migrate list models (#2369)

This commit is contained in:
cthomas
2025-05-23 09:09:47 -07:00
committed by GitHub
parent f9d2793caf
commit 878511b460
2 changed files with 116 additions and 5 deletions

View File

@@ -54,7 +54,7 @@ class Provider(ProviderBase):
return []
async def list_embedding_models_async(self) -> List[EmbeddingConfig]:
    # Async counterpart of list_embedding_models. This base implementation
    # returns an empty list; presumably provider subclasses override it with a
    # real model listing — TODO confirm against concrete Provider subclasses.
    return []
return self.list_embedding_models()
def get_model_context_window(self, model_name: str) -> Optional[int]:
    # Maximum context window (in tokens) for the given model name.
    # Not implemented at this level: always raises, so callers must use a
    # subclass that provides a concrete lookup.
    raise NotImplementedError

View File

@@ -741,6 +741,13 @@ class SyncServer(Server):
self._llm_config_cache[key] = self.get_llm_config_from_handle(actor=actor, **kwargs)
return self._llm_config_cache[key]
@trace_method
async def get_cached_llm_config_async(self, actor: User, **kwargs):
    """Return the LLM config for this parameter set, memoized in ``self._llm_config_cache``.

    The cache key is derived from ``kwargs`` via ``make_key``; on a miss the
    config is resolved through ``get_llm_config_from_handle_async`` and stored.
    """
    cache_key = make_key(**kwargs)
    try:
        # Fast path: already resolved for this exact parameter combination.
        return self._llm_config_cache[cache_key]
    except KeyError:
        resolved = await self.get_llm_config_from_handle_async(actor=actor, **kwargs)
        self._llm_config_cache[cache_key] = resolved
        return resolved
@trace_method
def get_cached_embedding_config(self, actor: User, **kwargs):
key = make_key(**kwargs)
@@ -748,6 +755,13 @@ class SyncServer(Server):
self._embedding_config_cache[key] = self.get_embedding_config_from_handle(actor=actor, **kwargs)
return self._embedding_config_cache[key]
@trace_method
async def get_cached_embedding_config_async(self, actor: User, **kwargs):
    """Return the embedding config for this parameter set, memoized in ``self._embedding_config_cache``.

    The cache key is derived from ``kwargs`` via ``make_key``; on a miss the
    config is resolved through ``get_embedding_config_from_handle_async`` and stored.
    """
    cache_key = make_key(**kwargs)
    try:
        # Fast path: already resolved for this exact parameter combination.
        return self._embedding_config_cache[cache_key]
    except KeyError:
        resolved = await self.get_embedding_config_from_handle_async(actor=actor, **kwargs)
        self._embedding_config_cache[cache_key] = resolved
        return resolved
@trace_method
def create_agent(
self,
@@ -815,7 +829,7 @@ class SyncServer(Server):
"enable_reasoner": request.enable_reasoner,
}
log_event(name="start get_cached_llm_config", attributes=config_params)
request.llm_config = self.get_cached_llm_config(actor=actor, **config_params)
request.llm_config = await self.get_cached_llm_config_async(actor=actor, **config_params)
log_event(name="end get_cached_llm_config", attributes=config_params)
if request.embedding_config is None:
@@ -826,7 +840,7 @@ class SyncServer(Server):
"embedding_chunk_size": request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE,
}
log_event(name="start get_cached_embedding_config", attributes=embedding_config_params)
request.embedding_config = self.get_cached_embedding_config(actor=actor, **embedding_config_params)
request.embedding_config = await self.get_cached_embedding_config_async(actor=actor, **embedding_config_params)
log_event(name="end get_cached_embedding_config", attributes=embedding_config_params)
log_event(name="start create_agent db")
@@ -877,10 +891,10 @@ class SyncServer(Server):
actor: User,
) -> AgentState:
if request.model is not None:
request.llm_config = self.get_llm_config_from_handle(handle=request.model, actor=actor)
request.llm_config = await self.get_llm_config_from_handle_async(handle=request.model, actor=actor)
if request.embedding is not None:
request.embedding_config = self.get_embedding_config_from_handle(handle=request.embedding, actor=actor)
request.embedding_config = await self.get_embedding_config_from_handle_async(handle=request.embedding, actor=actor)
if request.enable_sleeptime:
agent = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
@@ -1568,6 +1582,61 @@ class SyncServer(Server):
return llm_config
@trace_method
async def get_llm_config_from_handle_async(
    self,
    actor: User,
    handle: str,
    context_window_limit: Optional[int] = None,
    max_tokens: Optional[int] = None,
    max_reasoning_tokens: Optional[int] = None,
    enable_reasoner: Optional[bool] = None,
) -> LLMConfig:
    """Resolve an ``LLMConfig`` from a ``"provider/model"`` handle, then apply overrides.

    Args:
        actor: User whose enabled providers are searched.
        handle: ``"provider_name/model_name"`` identifier (the model part may itself contain ``/``).
        context_window_limit: Optional cap; must not exceed the model's own maximum.
        max_tokens: Optional completion-token override.
        max_reasoning_tokens: Optional reasoning budget; requires ``max_tokens`` and must not exceed it.
        enable_reasoner: Optional toggle for extended reasoning.

    Raises:
        HandleNotFoundError: Provider was found but no model matches the handle.
        ValueError: Malformed handle / unknown provider with no local fallback,
            ambiguous model name, or invalid override values.
    """
    # Pre-seed these so the fallback path below never sees unbound names when
    # the handle has no "/" and the unpack raises ValueError (the previous
    # version hit a NameError on `model_name` in that case).
    provider_name, model_name = "", handle
    try:
        provider_name, model_name = handle.split("/", 1)
        provider = await self.get_provider_from_name_async(provider_name, actor)
        all_llm_configs = await provider.list_llm_models_async()
        llm_configs = [config for config in all_llm_configs if config.handle == handle]
        if not llm_configs:
            # Fall back to matching on the bare model name.
            llm_configs = [config for config in all_llm_configs if config.model == model_name]
        if not llm_configs:
            available_handles = [config.handle for config in all_llm_configs]
            raise HandleNotFoundError(handle, available_handles)
    except ValueError as e:
        # Unknown provider (or malformed handle): fall back to locally defined configs.
        local_configs = self.get_local_llm_configs()  # fetched once, not once per filter
        llm_configs = [config for config in local_configs if config.handle == handle]
        if not llm_configs:
            llm_configs = [config for config in local_configs if config.model == model_name]
        if not llm_configs:
            raise e
    if len(llm_configs) > 1:
        raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
    llm_config = llm_configs[0]  # exactly one match at this point (empty lists raised above)
    if context_window_limit is not None:
        if context_window_limit > llm_config.context_window:
            raise ValueError(f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})")
        llm_config.context_window = context_window_limit
    else:
        # No explicit limit: clamp to the global server-wide maximum.
        llm_config.context_window = min(llm_config.context_window, model_settings.global_max_context_window_limit)
    if max_tokens is not None:
        llm_config.max_tokens = max_tokens
    if max_reasoning_tokens is not None:
        # A reasoning budget only makes sense relative to an explicit max_tokens.
        if not max_tokens or max_reasoning_tokens > max_tokens:
            raise ValueError(f"Max reasoning tokens ({max_reasoning_tokens}) must be less than max tokens ({max_tokens})")
        llm_config.max_reasoning_tokens = max_reasoning_tokens
    if enable_reasoner is not None:
        llm_config.enable_reasoner = enable_reasoner
        if enable_reasoner and llm_config.model_endpoint_type == "anthropic":
            llm_config.put_inner_thoughts_in_kwargs = False
    return llm_config
@trace_method
def get_embedding_config_from_handle(
self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
@@ -1597,6 +1666,36 @@ class SyncServer(Server):
return embedding_config
@trace_method
async def get_embedding_config_from_handle_async(
    self, actor: User, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
) -> EmbeddingConfig:
    """Resolve an ``EmbeddingConfig`` from a ``"provider/model"`` handle.

    Args:
        actor: User whose enabled providers are searched.
        handle: ``"provider_name/model_name"`` identifier.
        embedding_chunk_size: If truthy, written onto the resolved config.

    Raises:
        ValueError: Unknown provider/model with no local fallback, or an
            ambiguous match.
    """
    # Pre-seed these so the fallback/error paths below never see unbound names
    # when the handle has no "/" and the unpack raises ValueError (the previous
    # version hit a NameError on `model_name`/`provider_name` in that case).
    provider_name, model_name = "", handle
    try:
        provider_name, model_name = handle.split("/", 1)
        provider = await self.get_provider_from_name_async(provider_name, actor)
        all_embedding_configs = await provider.list_embedding_models_async()
        embedding_configs = [config for config in all_embedding_configs if config.handle == handle]
        if not embedding_configs:
            raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
    except ValueError as e:
        # search local configs
        embedding_configs = [config for config in self.get_local_embedding_configs() if config.handle == handle]
        if not embedding_configs:
            raise e
    if len(embedding_configs) > 1:
        raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
    embedding_config = embedding_configs[0]  # exactly one match at this point
    if embedding_chunk_size:
        embedding_config.embedding_chunk_size = embedding_chunk_size
    return embedding_config
def get_provider_from_name(self, provider_name: str, actor: User) -> Provider:
providers = [provider for provider in self.get_enabled_providers(actor) if provider.name == provider_name]
if not providers:
@@ -1608,6 +1707,18 @@ class SyncServer(Server):
return provider
async def get_provider_from_name_async(self, provider_name: str, actor: User) -> Provider:
    """Look up the single enabled provider named *provider_name* for *actor*.

    Raises ValueError when no enabled provider (or more than one) carries
    that name.
    """
    matches = [p for p in await self.get_enabled_providers_async(actor) if p.name == provider_name]
    if not matches:
        raise ValueError(f"Provider {provider_name} is not supported")
    if len(matches) > 1:
        raise ValueError(f"Multiple providers with name {provider_name} supported")
    return matches[0]
def get_local_llm_configs(self):
llm_models = []
try: