feat: add description to source (#1175)
This commit is contained in:
@@ -775,7 +775,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
"""List all data sources"""
|
||||
|
||||
# create table
|
||||
table.field_names = ["Name", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
|
||||
table.field_names = ["Name", "Description", "Embedding Model", "Embedding Dim", "Created At", "Agents"]
|
||||
# TODO: eventually look accross all storage connections
|
||||
# TODO: add data source stats
|
||||
# TODO: connect to agents
|
||||
@@ -788,7 +788,14 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
agent_names = [agent_state.name for agent_state in agent_states if agent_state is not None]
|
||||
|
||||
table.add_row(
|
||||
[source.name, source.embedding_model, source.embedding_dim, utils.format_datetime(source.created_at), ",".join(agent_names)]
|
||||
[
|
||||
source.name,
|
||||
source.description,
|
||||
source.embedding_model,
|
||||
source.embedding_dim,
|
||||
utils.format_datetime(source.created_at),
|
||||
",".join(agent_names),
|
||||
]
|
||||
)
|
||||
|
||||
print(table)
|
||||
|
||||
@@ -88,6 +88,7 @@ def load_directory(
|
||||
recursive: Annotated[bool, typer.Option(help="Recursively search for files in directory.")] = False,
|
||||
extensions: Annotated[str, typer.Option(help="Comma separated list of file extensions to load")] = default_extensions,
|
||||
user_id: Annotated[Optional[uuid.UUID], typer.Option(help="User ID to associate with dataset.")] = None, # TODO: remove
|
||||
description: Annotated[Optional[str], typer.Option(help="Description of the source.")] = None,
|
||||
):
|
||||
try:
|
||||
connector = DirectoryConnector(input_files=input_files, input_directory=input_dir, recursive=recursive, extensions=extensions)
|
||||
@@ -101,6 +102,7 @@ def load_directory(
|
||||
user_id=user_id,
|
||||
embedding_model=config.default_embedding_config.embedding_model,
|
||||
embedding_dim=config.default_embedding_config.embedding_dim,
|
||||
description=description,
|
||||
)
|
||||
ms.create_source(source)
|
||||
passage_storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
|
||||
|
||||
@@ -506,6 +506,7 @@ class Source:
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
name: str,
|
||||
description: Optional[str] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
id: Optional[uuid.UUID] = None,
|
||||
# embedding info
|
||||
@@ -521,6 +522,7 @@ class Source:
|
||||
|
||||
self.name = name
|
||||
self.user_id = user_id
|
||||
self.description = description
|
||||
self.created_at = created_at if created_at is not None else get_utc_time()
|
||||
|
||||
# embedding info (optional)
|
||||
|
||||
@@ -202,6 +202,7 @@ class SourceModel(Base):
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
embedding_dim = Column(BIGINT)
|
||||
embedding_model = Column(String)
|
||||
description = Column(String)
|
||||
|
||||
# TODO: add num passages
|
||||
|
||||
@@ -216,6 +217,7 @@ class SourceModel(Base):
|
||||
created_at=self.created_at,
|
||||
embedding_dim=self.embedding_dim,
|
||||
embedding_model=self.embedding_model,
|
||||
description=self.description,
|
||||
)
|
||||
|
||||
|
||||
|
||||
@@ -102,6 +102,7 @@ class SourceModel(SQLModel, table=True):
|
||||
user_id: uuid.UUID = Field(..., description="The unique identifier of the user associated with the source.")
|
||||
created_at: datetime = Field(default_factory=get_utc_time, description="The unix timestamp of when the source was created.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the source.", primary_key=True)
|
||||
description: Optional[str] = Field(None, description="The description of the source.")
|
||||
# embedding info
|
||||
# embedding_config: EmbeddingConfigModel = Field(..., description="The embedding configuration used by the source.")
|
||||
embedding_config: Optional[EmbeddingConfigModel] = Field(
|
||||
|
||||
Reference in New Issue
Block a user