feat: Enable adding files (#1864)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
@@ -1,11 +1,15 @@
-from typing import Dict, Iterator, List, Optional, Tuple
+from typing import Dict, Iterator, List, Tuple
 
 import typer
-from llama_index.core import Document as LlamaIndexDocument
 
 from letta.agent_store.storage import StorageConnector
+from letta.data_sources.connectors_helper import (
+    assert_all_files_exist_locally,
+    extract_metadata_from_files,
+    get_filenames_in_dir,
+)
 from letta.embeddings import embedding_model
-from letta.schemas.document import Document
+from letta.schemas.file import FileMetadata
 from letta.schemas.passage import Passage
 from letta.schemas.source import Source
 from letta.utils import create_uuid_from_string
@@ -13,23 +17,23 @@ from letta.utils import create_uuid_from_string
 
 class DataConnector:
     """
-    Base class for data connectors that can be extended to generate documents and passages from a custom data source.
+    Base class for data connectors that can be extended to generate files and passages from a custom data source.
     """
 
-    def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+    def find_files(self, source: Source) -> Iterator[FileMetadata]:
        """
-        Generate document text and metadata from a data source.
+        Generate file metadata from a data source.
 
        Returns:
-            documents (Iterator[Tuple[str, Dict]]): Generate a tuple of string text and metadata dictionary for each document.
+            files (Iterator[FileMetadata]): Generate file metadata for each file found.
        """
 
-    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
+    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
        """
-        Generate passage text and metadata from a list of documents.
+        Generate passage text and metadata from a single file.
 
        Args:
-            documents (List[Document]): List of documents to generate passages from.
+            file (FileMetadata): The file to generate passages from.
            chunk_size (int, optional): Chunk size for splitting passages. Defaults to 1024.
 
        Returns:
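Note: under this new interface, a custom connector implements find_files() to enumerate FileMetadata and generate_passages() to chunk one file at a time, rather than the old generate_documents() pair. A minimal sketch follows; the connector below is hypothetical (not part of this commit) and assumes FileMetadata fields left unset default to None.

    from typing import Dict, Iterator, Tuple

    from letta.data_sources.connectors import DataConnector
    from letta.schemas.file import FileMetadata
    from letta.schemas.source import Source

    class InMemoryTextConnector(DataConnector):
        """Hypothetical connector that serves one in-memory string as a single 'file'."""

        def __init__(self, text: str):
            self.text = text

        def find_files(self, source: Source) -> Iterator[FileMetadata]:
            # Nothing on disk, so only file_name is populated here.
            yield FileMetadata(user_id=source.user_id, source_id=source.id, file_name="in_memory.txt")

        def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
            # Naive fixed-size character chunking; DirectoryConnector uses a token splitter instead.
            for i in range(0, len(self.text), chunk_size):
                yield self.text[i : i + chunk_size], {}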
@@ -41,33 +45,25 @@ def load_data(
     connector: DataConnector,
     source: Source,
     passage_store: StorageConnector,
-    document_store: Optional[StorageConnector] = None,
+    file_metadata_store: StorageConnector,
 ):
-    """Load data from a connector (generates documents and passages) into a specified source_id, associatedw with a user_id."""
+    """Load data from a connector (generates files and passages) into a specified source_id, associated with a user_id."""
     embedding_config = source.embedding_config
 
     # embedding model
     embed_model = embedding_model(embedding_config)
 
-    # insert passages/documents
+    # insert passages/files
     passages = []
     embedding_to_document_name = {}
     passage_count = 0
-    document_count = 0
-    for document_text, document_metadata in connector.generate_documents():
-        # insert document into storage
-        document = Document(
-            text=document_text,
-            metadata_=document_metadata,
-            source_id=source.id,
-            user_id=source.user_id,
-        )
-        document_count += 1
-        if document_store:
-            document_store.insert(document)
+    file_count = 0
+    for file_metadata in connector.find_files(source):
+        file_count += 1
+        file_metadata_store.insert(file_metadata)
 
         # generate passages
-        for passage_text, passage_metadata in connector.generate_passages([document], chunk_size=embedding_config.embedding_chunk_size):
+        for passage_text, passage_metadata in connector.generate_passages(file_metadata, chunk_size=embedding_config.embedding_chunk_size):
             # for some reason, llama index parsers sometimes return empty strings
             if len(passage_text) == 0:
                 typer.secho(
@@ -89,7 +85,7 @@ def load_data(
             passage = Passage(
                 id=create_uuid_from_string(f"{str(source.id)}_{passage_text}"),
                 text=passage_text,
-                doc_id=document.id,
+                file_id=file_metadata.id,
                 source_id=source.id,
                 metadata_=passage_metadata,
                 user_id=source.user_id,
@@ -98,16 +94,16 @@ def load_data(
             )
 
             hashable_embedding = tuple(passage.embedding)
-            document_name = document.metadata_.get("file_path", document.id)
+            file_name = file_metadata.file_name
             if hashable_embedding in embedding_to_document_name:
                 typer.secho(
-                    f"Warning: Duplicate embedding found for passage in {document_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
+                    f"Warning: Duplicate embedding found for passage in {file_name} (already exists in {embedding_to_document_name[hashable_embedding]}), skipping insert into VectorDB.",
                     fg=typer.colors.YELLOW,
                 )
                 continue
 
             passages.append(passage)
-            embedding_to_document_name[hashable_embedding] = document_name
+            embedding_to_document_name[hashable_embedding] = file_name
             if len(passages) >= 100:
                 # insert passages into passage store
                 passage_store.insert_many(passages)
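A side note on the duplicate check above: embeddings arrive as lists, which are unhashable, so tuple(passage.embedding) converts them into valid dict keys. A tiny illustration with made-up values:

    seen = {}
    emb_a, emb_b = [0.1, 0.2], [0.1, 0.2]
    seen[tuple(emb_a)] = "a.txt"
    print(tuple(emb_b) in seen)  # True: identical embeddings map to the same key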
@@ -120,7 +116,7 @@ def load_data(
         passage_store.insert_many(passages)
         passage_count += len(passages)
 
-    return passage_count, document_count
+    return passage_count, file_count
 
 
 class DirectoryConnector(DataConnector):
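Putting the pieces together, a caller might use the updated load_data() roughly as below. This is a sketch only: the source and the two store objects would come from Letta's server/metadata layer, and the DirectoryConnector keyword arguments are inferred from the attributes this diff shows rather than confirmed by it.

    from letta.data_sources.connectors import DirectoryConnector, load_data

    connector = DirectoryConnector(input_files=["README.md"])  # hypothetical construction
    passage_count, file_count = load_data(
        connector=connector,
        source=source,                            # assumed to exist already
        passage_store=passage_store,              # assumed StorageConnector instance
        file_metadata_store=file_metadata_store,  # assumed StorageConnector instance
    )
    print(f"Inserted {passage_count} passages from {file_count} files")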
@@ -143,105 +139,109 @@ class DirectoryConnector(DataConnector):
         if self.recursive == True:
             assert self.input_directory is not None, "Must provide input directory if recursive is True."
 
-    def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
-        from llama_index.core import SimpleDirectoryReader
-
+    def find_files(self, source: Source) -> Iterator[FileMetadata]:
         if self.input_directory is not None:
-            reader = SimpleDirectoryReader(
+            files = get_filenames_in_dir(
                 input_dir=self.input_directory,
                 recursive=self.recursive,
                 required_exts=[ext.strip() for ext in str(self.extensions).split(",")],
                 exclude=["*png", "*jpg", "*jpeg"],
             )
         else:
             assert self.input_files is not None, "Must provide input files if input_dir is None"
-            reader = SimpleDirectoryReader(input_files=[str(f) for f in self.input_files])
+            files = self.input_files
 
-        llama_index_docs = reader.load_data(show_progress=True)
-        for llama_index_doc in llama_index_docs:
-            # TODO: add additional metadata?
-            # doc = Document(text=llama_index_doc.text, metadata=llama_index_doc.metadata)
-            # docs.append(doc)
-            yield llama_index_doc.text, llama_index_doc.metadata
+        # Check that file paths are valid
+        assert_all_files_exist_locally(files)
 
-    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
-        # use llama index to run embeddings code
-        # from llama_index.core.node_parser import SentenceSplitter
+        for metadata in extract_metadata_from_files(files):
+            yield FileMetadata(
+                user_id=source.user_id,
+                source_id=source.id,
+                file_name=metadata.get("file_name"),
+                file_path=metadata.get("file_path"),
+                file_type=metadata.get("file_type"),
+                file_size=metadata.get("file_size"),
+                file_creation_date=metadata.get("file_creation_date"),
+                file_last_modified_date=metadata.get("file_last_modified_date"),
+            )
+
+    def generate_passages(self, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]:
+        from llama_index.core import SimpleDirectoryReader
         from llama_index.core.node_parser import TokenTextSplitter
 
         parser = TokenTextSplitter(chunk_size=chunk_size)
-        for document in documents:
-            llama_index_docs = [LlamaIndexDocument(text=document.text, metadata=document.metadata_)]
-            nodes = parser.get_nodes_from_documents(llama_index_docs)
-            for node in nodes:
-                # passage = Passage(
-                #     text=node.text,
-                #     doc_id=document.id,
-                # )
-                yield node.text, None
+        documents = SimpleDirectoryReader(input_files=[file.file_path]).load_data()
+        nodes = parser.get_nodes_from_documents(documents)
+        for node in nodes:
+            yield node.text, None
 
 
-class WebConnector(DirectoryConnector):
-    def __init__(self, urls: List[str] = None, html_to_text: bool = True):
-        self.urls = urls
-        self.html_to_text = html_to_text
-
-    def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
-        from llama_index.readers.web import SimpleWebPageReader
-
-        documents = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
-        for document in documents:
-            yield document.text, {"url": document.id_}
-
-
-class VectorDBConnector(DataConnector):
-    # NOTE: this class has not been properly tested, so is unlikely to work
-    # TODO: allow loading multiple tables (1:1 mapping between Document and Table)
-
-    def __init__(
-        self,
-        name: str,
-        uri: str,
-        table_name: str,
-        text_column: str,
-        embedding_column: str,
-        embedding_dim: int,
-    ):
-        self.name = name
-        self.uri = uri
-        self.table_name = table_name
-        self.text_column = text_column
-        self.embedding_column = embedding_column
-        self.embedding_dim = embedding_dim
-
-        # connect to db table
-        from sqlalchemy import create_engine
-
-        self.engine = create_engine(uri)
-
-    def generate_documents(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
-        yield self.table_name, None
-
-    def generate_passages(self, documents: List[Document], chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
-        from pgvector.sqlalchemy import Vector
-        from sqlalchemy import Inspector, MetaData, Table, select
-
-        metadata = MetaData()
-        # Create an inspector to inspect the database
-        inspector = Inspector.from_engine(self.engine)
-        table_names = inspector.get_table_names()
-        assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
-
-        table = Table(self.table_name, metadata, autoload_with=self.engine)
-
-        # Prepare a select statement
-        select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
-
-        # Execute the query and fetch the results
-        # TODO: paginate results
-        with self.engine.connect() as connection:
-            result = connection.execute(select_statement).fetchall()
-
-        for text, embedding in result:
-            # assume that embeddings are the same model as in config
-            # TODO: don't re-compute embedding
-            yield text, {"embedding": embedding}
+"""
+The below isn't used anywhere, it isn't tested, and pretty much should be deleted.
+- Matt
+"""
+# class WebConnector(DirectoryConnector):
+#     def __init__(self, urls: List[str] = None, html_to_text: bool = True):
+#         self.urls = urls
+#         self.html_to_text = html_to_text
+#
+#     def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+#         from llama_index.readers.web import SimpleWebPageReader
+#
+#         files = SimpleWebPageReader(html_to_text=self.html_to_text).load_data(self.urls)
+#         for document in files:
+#             yield document.text, {"url": document.id_}
+#
+#
+# class VectorDBConnector(DataConnector):
+#     # NOTE: this class has not been properly tested, so is unlikely to work
+#     # TODO: allow loading multiple tables (1:1 mapping between FileMetadata and Table)
+#
+#     def __init__(
+#         self,
+#         name: str,
+#         uri: str,
+#         table_name: str,
+#         text_column: str,
+#         embedding_column: str,
+#         embedding_dim: int,
+#     ):
+#         self.name = name
+#         self.uri = uri
+#         self.table_name = table_name
+#         self.text_column = text_column
+#         self.embedding_column = embedding_column
+#         self.embedding_dim = embedding_dim
+#
+#         # connect to db table
+#         from sqlalchemy import create_engine
+#
+#         self.engine = create_engine(uri)
+#
+#     def generate_files(self) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Document]:
+#         yield self.table_name, None
+#
+#     def generate_passages(self, file_text: str, file: FileMetadata, chunk_size: int = 1024) -> Iterator[Tuple[str, Dict]]: # -> Iterator[Passage]:
+#         from pgvector.sqlalchemy import Vector
+#         from sqlalchemy import Inspector, MetaData, Table, select
+#
+#         metadata = MetaData()
+#         # Create an inspector to inspect the database
+#         inspector = Inspector.from_engine(self.engine)
+#         table_names = inspector.get_table_names()
+#         assert self.table_name in table_names, f"Table {self.table_name} not found in database: tables that exist {table_names}."
+#
+#         table = Table(self.table_name, metadata, autoload_with=self.engine)
+#
+#         # Prepare a select statement
+#         select_statement = select(table.c[self.text_column], table.c[self.embedding_column].cast(Vector(self.embedding_dim)))
+#
+#         # Execute the query and fetch the results
+#         # TODO: paginate results
+#         with self.engine.connect() as connection:
+#             result = connection.execute(select_statement).fetchall()
+#
+#         for text, embedding in result:
+#             # assume that embeddings are the same model as in config
+#             # TODO: don't re-compute embedding
+#             yield text, {"embedding": embedding}
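The new generate_passages() path reduces to a small llama-index pipeline: read one file into Document objects, then split them into roughly chunk_size-token nodes. A standalone sketch of that flow, using a placeholder file path:

    from llama_index.core import SimpleDirectoryReader
    from llama_index.core.node_parser import TokenTextSplitter

    parser = TokenTextSplitter(chunk_size=1024)
    documents = SimpleDirectoryReader(input_files=["README.md"]).load_data()
    nodes = parser.get_nodes_from_documents(documents)
    for node in nodes:
        print(len(node.text), node.text[:60])  # passage length and preview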