feat: add description to source (#1175)

This commit is contained in:
Sarah Wooders
2024-03-21 13:02:50 -07:00
committed by GitHub
parent 542d78c102
commit 08fd722f13
5 changed files with 16 additions and 2 deletions

View File

@@ -775,7 +775,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
"""List all data sources"""
# create table
table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
# TODO: eventually look across all storage connections
# TODO: add data source stats
# TODO: connect to agents
@@ -788,7 +788,14 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None]
table.add_row(
[source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)]
[
source.name,
source.description,
source.embedding_model,
source.embedding_dim,
utils.format_datetime(source.created_at),
",".join(agent_names),
]
)
print(table)

View File

@@ -88,6 +88,7 @@ def load_directory(
recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False,
extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions,
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove
description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None,
):
try:
connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
@@ -101,6 +102,7 @@ def load_directory(
user_id=user_id,
embedding_model=config.default_embedding_config.embedding_model,
embedding_dim=config.default_embedding_config.embedding_dim,
description=description,
)
ms.create_source(source)
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)

View File

@@ -506,6 +506,7 @@ class Source:
self,
user_id: uuid.UUID,
name: str,
description: Optional[str] = None,
created_at: Optional[datetime] = None,
id: Optional[uuid.UUID] = None,
# embedding info
@@ -521,6 +522,7 @@ class Source:
self.name = name
self.user_id = user_id
self.description = description
self.created_at = created_at if created_at is not None else get_utc_time()
# embedding info (optional)

View File

@@ -202,6 +202,7 @@ class SourceModel(Base):
created_at = Column(DateTime(timezone=True), server_default=func.now())
embedding_dim = Column(BIGINT)
embedding_model = Column(String)
description = Column(String)
# TODO: add num passages
@@ -216,6 +217,7 @@ class SourceModel(Base):
created_at=self.created_at,
embedding_dim=self.embedding_dim,
embedding_model=self.embedding_model,
description=self.description,
)

View File

@@ -102,6 +102,7 @@ class SourceModel(SQLModel, table=True):
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the source was created.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
description: Optional[str] = Field(None, description="The description of the source.")
# embedding info
# embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
embedding_config: Optional[EmbeddingConfigModel] = Field(