chore: fix branch (#1865)

This commit is contained in:
Sarah Wooders
2024-10-10 14:07:45 -07:00
committed by GitHub
parent fb8ba76e42
commit e15dea623d
5 changed files with 50 additions and 12 deletions

View File

@@ -14,7 +14,9 @@ from letta.constants import CLI_WARNING_PREFIX, LETTA_DIR
from letta.local_llm.constants import ASSISTANT_MESSAGE_CLI_SYMBOL
from letta.log import get_logger
from letta.metadata import MetadataStore
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import OptionState
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import ChatMemory, Memory
from letta.server.server import logger as server_logger
@@ -233,25 +235,46 @@ def run(
# choose from list of llm_configs
llm_configs = client.list_llm_configs()
llm_options = [llm_config.model for llm_config in llm_configs]
# TODO move into LLMConfig as a class method?
def prettify_llm_config(llm_config: LLMConfig) -> str:
return f"{llm_config.model}" + f" ({llm_config.model_endpoint})" if llm_config.model_endpoint else ""
llm_choices = [questionary.Choice(title=prettify_llm_config(llm_config), value=llm_config) for llm_config in llm_configs]
# select model
if len(llm_options) == 0:
raise ValueError("No LLM models found. Please enable a provider.")
elif len(llm_options) == 1:
llm_model_name = llm_options[0]
else:
llm_model_name = questionary.select("Select LLM model:", choices=llm_options).ask()
llm_model_name = questionary.select("Select LLM model:", choices=llm_choices).ask().model
llm_config = [llm_config for llm_config in llm_configs if llm_config.model == llm_model_name][0]
# choose form list of embedding configs
embedding_configs = client.list_embedding_configs()
embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs]
# TODO move into EmbeddingConfig as a class method?
def prettify_embed_config(embedding_config: EmbeddingConfig) -> str:
return (
f"{embedding_config.embedding_model}" + f" ({embedding_config.embedding_endpoint})"
if embedding_config.embedding_endpoint
else ""
)
embedding_choices = [
questionary.Choice(title=prettify_embed_config(embedding_config), value=embedding_config)
for embedding_config in embedding_configs
]
# select model
if len(embedding_options) == 0:
raise ValueError("No embedding models found. Please enable a provider.")
elif len(embedding_options) == 1:
embedding_model_name = embedding_options[0]
else:
embedding_model_name = questionary.select("Select embedding model:", choices=embedding_options).ask()
embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model
embedding_config = [
embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name
][0]