diff --git a/letta/cli/cli_load.py b/letta/cli/cli_load.py index 61518bc0..b27da4d8 100644 --- a/letta/cli/cli_load.py +++ b/letta/cli/cli_load.py @@ -11,6 +11,7 @@ letta load --name [ADDITIONAL ARGS] import uuid from typing import Annotated, List, Optional +import questionary import typer from letta import create_client @@ -37,8 +38,27 @@ def load_directory( # create connector connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions) + # choose form list of embedding configs + embedding_configs = client.list_embedding_configs() + embedding_options = [embedding_config.embedding_model for embedding_config in embedding_configs] + + embedding_choices = [ + questionary.Choice(title=embedding_config.pretty_print(), value=embedding_config) for embedding_config in embedding_configs + ] + + # select model + if len(embedding_options) == 0: + raise ValueError("No embedding models found. Please enable a provider.") + elif len(embedding_options) == 1: + embedding_model_name = embedding_options[0] + else: + embedding_model_name = questionary.select("Select embedding model:", choices=embedding_choices).ask().embedding_model + embedding_config = [ + embedding_config for embedding_config in embedding_configs if embedding_config.embedding_model == embedding_model_name + ][0] + # create source - source = client.create_source(name=name) + source = client.create_source(name=name, embedding_config=embedding_config) # load data try: @@ -46,71 +66,3 @@ def load_directory( except Exception as e: typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) client.delete_source(source.id) - - -# @app.command("webpage") -# def load_webpage( -# name: Annotated[str, typer.Option(help="Name of dataset to load.")], -# urls: Annotated[List[str], typer.Option(help="List of urls to load.")], -# ): -# try: -# from llama_index.readers.web import SimpleWebPageReader -# -# docs = SimpleWebPageReader(html_to_text=True).load_data(urls) -# store_docs(name, docs) -# -# except ValueError as e: -# typer.secho(f"Failed to load webpage from provided information.\n{e}", fg=typer.colors.RED) - - -@app.command("vector-database") -def load_vector_database( - name: Annotated[str, typer.Option(help="Name of dataset to load.")], - uri: Annotated[str, typer.Option(help="Database URI.")], - table_name: Annotated[str, typer.Option(help="Name of table containing data.")], - text_column: Annotated[str, typer.Option(help="Name of column containing text.")], - embedding_column: Annotated[str, typer.Option(help="Name of column containing embedding.")], - user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, -): - """Load pre-computed embeddings into Letta from a database.""" - raise NotImplementedError - # try: - # config = LettaConfig.load() - # connector = VectorDBConnector( - # uri=uri, - # table_name=table_name, - # text_column=text_column, - # embedding_column=embedding_column, - # embedding_dim=config.default_embedding_config.embedding_dim, - # ) - # if not user_id: - # user_id = uuid.UUID(config.anon_clientid) - - # ms = MetadataStore(config) - # source = Source( - # name=name, - # user_id=user_id, - # embedding_model=config.default_embedding_config.embedding_model, - # embedding_dim=config.default_embedding_config.embedding_dim, - # ) - # ms.create_source(source) - # passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) - # # TODO: also get document store - - # # ingest data into passage/document store - # try: - # num_passages, num_documents = load_data( - # connector=connector, - # source=source, - # embedding_config=config.default_embedding_config, - # document_store=None, - # passage_store=passage_storage, - # ) - # print(f"Loaded {num_passages} passages and {num_documents} files from {name}") - # except Exception as e: - # typer.secho(f"Failed to load data from provided information.\n{e}", fg=typer.colors.RED) - # ms.delete_source(source_id=source.id) - - # except ValueError as e: - # typer.secho(f"Failed to load VectorDB from provided information.\n{e}", fg=typer.colors.RED) - # raise diff --git a/letta/client/client.py b/letta/client/client.py index d4cedb03..7e8ec304 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1319,6 +1319,7 @@ class RESTClient(AbstractClient): Returns: source (Source): Created source """ + assert embedding_config or self._default_embedding_config, f"Must specify embedding_config for source" source_create = SourceCreate(name=name, embedding_config=embedding_config or self._default_embedding_config) payload = source_create.model_dump() response = requests.post(f"{self.base_url}/{self.api_prefix}/sources", json=payload, headers=self.headers) @@ -2896,6 +2897,7 @@ class LocalClient(AbstractClient): Returns: source (Source): Created source """ + assert embedding_config or self._default_embedding_config, f"Must specify embedding_config for source" source = Source( name=name, embedding_config=embedding_config or self._default_embedding_config, organization_id=self.user.organization_id ) diff --git a/letta/main.py b/letta/main.py index 1f8e19ef..6a394fcf 100644 --- a/letta/main.py +++ b/letta/main.py @@ -191,7 +191,6 @@ def run_agent_loop( print(f"\nDumping memory contents:\n") print(f"{letta_agent.agent_state.memory.compile()}") print(f"{letta_agent.archival_memory.compile()}") - print(f"{letta_agent.recall_memory.compile()}") continue elif user_input.lower() == "/model":