diff --git a/memgpt/cli/cli_config.py b/memgpt/cli/cli_config.py index 783e9a42..0412c638 100644 --- a/memgpt/cli/cli_config.py +++ b/memgpt/cli/cli_config.py @@ -775,7 +775,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]): """List all data sources""" # create table - table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"] + table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"] # TODO: eventually look accross all storage connections # TODO: add data source stats # TODO: connect to agents @@ -788,7 +788,14 @@ def list(arg: Annotated[ListChoice, typer.Argument]): agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None] table.add_row( - [source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)] + [ + source.name, + source.description, + source.embedding_model, + source.embedding_dim, + utils.format_datetime(source.created_at), + ",".join(agent_names), + ] ) print(table) diff --git a/memgpt/cli/cli_load.py b/memgpt/cli/cli_load.py index 83c70633..8525899a 100644 --- a/memgpt/cli/cli_load.py +++ b/memgpt/cli/cli_load.py @@ -88,6 +88,7 @@ def load_directory( recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False, extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions, user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove + description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None, ): try: connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions) @@ -101,6 +102,7 @@ def load_directory( user_id=user_id, embedding_model=config.default_embedding_config.embedding_model, embedding_dim=config.default_embedding_config.embedding_dim, + description=description, ) ms.create_source(source) passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id) diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 593b23eb..82027429 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -506,6 +506,7 @@ class Source: self, user_id: uuid.UUID, name: str, + description: Optional[str] = None, created_at: Optional[datetime] = None, id: Optional[uuid.UUID] = None, # embedding info @@ -521,6 +522,7 @@ class Source: self.name = name self.user_id = user_id + self.description = description self.created_at = created_at if created_at is not None else get_utc_time() # embedding info (optional) diff --git a/memgpt/metadata.py b/memgpt/metadata.py index cbe2db79..6a99b806 100644 --- a/memgpt/metadata.py +++ b/memgpt/metadata.py @@ -202,6 +202,7 @@ class SourceModel(Base): created_at = Column(DateTime(timezone=True), server_default=func.now()) embedding_dim = Column(BIGINT) embedding_model = Column(String) + description = Column(String) # TODO: add num passages @@ -216,6 +217,7 @@ class SourceModel(Base): created_at=self.created_at, embedding_dim=self.embedding_dim, embedding_model=self.embedding_model, + description=self.description, ) diff --git a/memgpt/models/pydantic_models.py b/memgpt/models/pydantic_models.py index 4b3956bc..02b1df72 100644 --- a/memgpt/models/pydantic_models.py +++ b/memgpt/models/pydantic_models.py @@ -102,6 +102,7 @@ class SourceModel(SQLModel, table=True): user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.") created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the source was created.") id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True) + description: Optional[str] = Field(None, description="The description of the source.") # embedding info # embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.") embedding_config: Optional[EmbeddingConfigModel] = Field(