feat: Add optional llm and embedding handle args to CreateAgent request (#2260)
This commit is contained in:
@@ -776,6 +776,18 @@ class SyncServer(Server):
|
||||
# interface
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
) -> AgentState:
|
||||
if request.llm_config is None:
|
||||
if request.llm is None:
|
||||
raise ValueError("Must specify either llm or llm_config in request")
|
||||
request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit)
|
||||
|
||||
if request.embedding_config is None:
|
||||
if request.embedding is None:
|
||||
raise ValueError("Must specify either embedding or embedding_config in request")
|
||||
request.embedding_config = self.get_embedding_config_from_handle(
|
||||
handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
)
|
||||
|
||||
"""Create a new agent using a config"""
|
||||
# Invoke manager
|
||||
agent_state = self.agent_manager.create_agent(
|
||||
@@ -1283,6 +1295,57 @@ class SyncServer(Server):
|
||||
warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
|
||||
return embedding_models
|
||||
|
||||
def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = self.get_provider_from_name(provider_name)
|
||||
|
||||
llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
|
||||
if not llm_configs:
|
||||
raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
|
||||
elif len(llm_configs) > 1:
|
||||
raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
|
||||
else:
|
||||
llm_config = llm_configs[0]
|
||||
|
||||
if context_window_limit:
|
||||
if context_window_limit > llm_config.context_window:
|
||||
raise ValueError(
|
||||
f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})"
|
||||
)
|
||||
llm_config.context_window = context_window_limit
|
||||
|
||||
return llm_config
|
||||
|
||||
def get_embedding_config_from_handle(
|
||||
self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
|
||||
) -> EmbeddingConfig:
|
||||
provider_name, model_name = handle.split("/", 1)
|
||||
provider = self.get_provider_from_name(provider_name)
|
||||
|
||||
embedding_configs = [config for config in provider.list_embedding_models() if config.embedding_model == model_name]
|
||||
if not embedding_configs:
|
||||
raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
|
||||
elif len(embedding_configs) > 1:
|
||||
raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
|
||||
else:
|
||||
embedding_config = embedding_configs[0]
|
||||
|
||||
if embedding_chunk_size:
|
||||
embedding_config.embedding_chunk_size = embedding_chunk_size
|
||||
|
||||
return embedding_config
|
||||
|
||||
def get_provider_from_name(self, provider_name: str) -> Provider:
|
||||
providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
|
||||
if not providers:
|
||||
raise ValueError(f"Provider {provider_name} is not supported")
|
||||
elif len(providers) > 1:
|
||||
raise ValueError(f"Multiple providers with name {provider_name} supported")
|
||||
else:
|
||||
provider = providers[0]
|
||||
|
||||
return provider
|
||||
|
||||
def add_llm_model(self, request: LLMConfig) -> LLMConfig:
|
||||
"""Add a new LLM model"""
|
||||
|
||||
|
||||
Reference in New Issue
Block a user