feat: Add optional llm and embedding handle args to CreateAgent request (#2260)

This commit is contained in:
cthomas
2024-12-17 15:31:19 -08:00
committed by GitHub
parent 6ec36303e5
commit bb06ab0fcb
6 changed files with 114 additions and 11 deletions
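
The change lets a CreateAgent request name its models by a "provider/model" handle (the new llm and embedding fields) instead of supplying full LLMConfig/EmbeddingConfig objects. A minimal sketch of the new request shape; the field names come from the diff below, while the import path and the handle values are assumptions that depend on which providers the server has enabled:

    from letta.schemas.agent import CreateAgent

    request = CreateAgent(
        llm="openai/gpt-4",                         # "provider/model" handle, resolved server-side
        embedding="openai/text-embedding-ada-002",  # likewise for the embedding model
        context_window_limit=16000,                 # optional cap applied to the resolved config
    )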


@@ -776,6 +776,18 @@ class SyncServer(Server):
        # interface
        interface: Union[AgentInterface, None] = None,
    ) -> AgentState:
        if request.llm_config is None:
            if request.llm is None:
                raise ValueError("Must specify either llm or llm_config in request")
            request.llm_config = self.get_llm_config_from_handle(handle=request.llm, context_window_limit=request.context_window_limit)

        if request.embedding_config is None:
            if request.embedding is None:
                raise ValueError("Must specify either embedding or embedding_config in request")
            request.embedding_config = self.get_embedding_config_from_handle(
                handle=request.embedding, embedding_chunk_size=request.embedding_chunk_size or constants.DEFAULT_EMBEDDING_CHUNK_SIZE
            )

        """Create a new agent using a config"""

        # Invoke manager
        agent_state = self.agent_manager.create_agent(
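
The handles act as fallbacks: an explicit llm_config or embedding_config still takes precedence, and a request carrying neither the config nor the handle fails fast. A hypothetical illustration (the server and user objects here are assumed, not part of this diff):

    try:
        server.create_agent(CreateAgent(), actor=user)  # no config and no handle
    except ValueError as e:
        print(e)  # Must specify either llm or llm_config in request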
@@ -1283,6 +1295,57 @@ class SyncServer(Server):
                warnings.warn(f"An error occurred while listing embedding models for provider {provider}: {e}")
        return embedding_models

    def get_llm_config_from_handle(self, handle: str, context_window_limit: Optional[int] = None) -> LLMConfig:
        provider_name, model_name = handle.split("/", 1)
        provider = self.get_provider_from_name(provider_name)

        llm_configs = [config for config in provider.list_llm_models() if config.model == model_name]
        if not llm_configs:
            raise ValueError(f"LLM model {model_name} is not supported by {provider_name}")
        elif len(llm_configs) > 1:
            raise ValueError(f"Multiple LLM models with name {model_name} supported by {provider_name}")
        else:
            llm_config = llm_configs[0]

        if context_window_limit:
            if context_window_limit > llm_config.context_window:
                raise ValueError(
                    f"Context window limit ({context_window_limit}) is greater than maximum of ({llm_config.context_window})"
                )
            llm_config.context_window = context_window_limit

        return llm_config
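
A hedged usage sketch of the resolver above; the handle and the limits are illustrative:

    config = server.get_llm_config_from_handle("openai/gpt-4", context_window_limit=8000)
    assert config.context_window == 8000  # clamped down from the model's maximum
    # Asking for more than the model supports raises rather than silently expanding:
    server.get_llm_config_from_handle("openai/gpt-4", context_window_limit=10_000_000)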
    def get_embedding_config_from_handle(
        self, handle: str, embedding_chunk_size: int = constants.DEFAULT_EMBEDDING_CHUNK_SIZE
    ) -> EmbeddingConfig:
        provider_name, model_name = handle.split("/", 1)
        provider = self.get_provider_from_name(provider_name)

        embedding_configs = [config for config in provider.list_embedding_models() if config.embedding_model == model_name]
        if not embedding_configs:
            raise ValueError(f"Embedding model {model_name} is not supported by {provider_name}")
        elif len(embedding_configs) > 1:
            raise ValueError(f"Multiple embedding models with name {model_name} supported by {provider_name}")
        else:
            embedding_config = embedding_configs[0]

        if embedding_chunk_size:
            embedding_config.embedding_chunk_size = embedding_chunk_size

        return embedding_config
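
And the embedding counterpart, again with illustrative values:

    emb = server.get_embedding_config_from_handle("openai/text-embedding-ada-002", embedding_chunk_size=300)
    assert emb.embedding_chunk_size == 300  # overrides DEFAULT_EMBEDDING_CHUNK_SIZE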
    def get_provider_from_name(self, provider_name: str) -> Provider:
        providers = [provider for provider in self._enabled_providers if provider.name == provider_name]
        if not providers:
            raise ValueError(f"Provider {provider_name} is not supported")
        elif len(providers) > 1:
            raise ValueError(f"Multiple providers with name {provider_name} supported")
        else:
            provider = providers[0]

        return provider
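
Note that both resolvers split the handle with handle.split("/", 1), so only the first slash separates provider from model; model names that themselves contain slashes survive intact (the handle below is illustrative):

    provider_name, model_name = "together/meta-llama/Llama-3-70b".split("/", 1)
    # provider_name == "together", model_name == "meta-llama/Llama-3-70b"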
    def add_llm_model(self, request: LLMConfig) -> LLMConfig:
        """Add a new LLM model"""