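"""Load precomputed Wikipedia passage embeddings into a Letta passage store.

Reads a JSONL file of (embedding request, embedding response) pairs and bulk-upserts
the passages, with their OpenAI embeddings, into the PASSAGES table behind a
StorageConnector backed by the pgvector test database (PGVECTOR_TEST_DB_URL).
"""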

import copy
import hashlib
import json
import os
import time
import uuid
from concurrent.futures import ThreadPoolExecutor, as_completed

from absl import app, flags
from icml_experiments.utils import get_experiment_config
from tqdm import tqdm

from letta.agent_store.storage import StorageConnector, TableType
from letta.cli.cli_config import delete
from letta.data_types import Passage

# Configure the experiment against the pgvector test database, using OpenAI embeddings
source_name = "wikipedia"
config = get_experiment_config(os.environ.get("PGVECTOR_TEST_DB_URL"), endpoint_type="openai")
config.save()  # save config to file
user_id = uuid.UUID(config.anon_clientid)

FLAGS = flags.FLAGS
flags.DEFINE_boolean("drop_db", default=False, required=False, help="Drop existing source DB")
flags.DEFINE_string("file", default=None, required=True, help="File to parse")
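
# Example invocation (script filename is illustrative; the shard name follows the
# convention in the commented-out list in main() below):
#   python load_wikipedia_passages.py --file data/wikipedia_passages_shard_1-02.jsonl --drop_db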


def create_uuid_from_string(val: str):
    """
    Generate a consistent UUID from a string
    from: https://samos-it.com/posts/python-create-uuid-from-random-string-of-words.html
    """
    hex_string = hashlib.md5(val.encode("UTF-8")).hexdigest()
    return uuid.UUID(hex=hex_string)
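

# Each input line is expected to be a JSON array pairing an embedding request with its
# response (schema inferred from the parsing in insert_lines() below), roughly:
#   [{"input": "<passage text>", "model": "<embedding model>", ...},
#    {"data": [{"embedding": [<1536 floats>]}], "usage": {...}}]
# Because create_uuid_from_string() is deterministic, identical passage texts map to the
# same UUID, so duplicates are skipped and re-runs upsert rather than duplicate rows.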
def insert_lines(lines, conn, show_progress=False):
    """Parse a list of JSONL lines and insert them into the source database"""
    passages = []
    iterator = tqdm(lines) if show_progress else lines
    added = set()
    for line in iterator:
        d = json.loads(line)
        # pprint(d)
        assert len(d) == 2, f"Expected a [request, response] pair, got {len(d)} elements"
        text = d[0]["input"]
        model = d[0]["model"]
        embedding = d[1]["data"][0]["embedding"]
        embedding_dim = len(embedding)
        assert embedding_dim == 1536, f"Wrong embedding dim: {embedding_dim}"
        assert len(d[1]["data"]) == 1, f"More than one embedding: {len(d[1]['data'])}"
        # d[1]["usage"] is also present in the response but unused here
        # print(text)

        passage_id = create_uuid_from_string(text)  # consistent hash for text (prevent duplicates)
        if passage_id in added:
            continue
        added.add(passage_id)
        # if conn.get(passage_id):
        #     continue

        passage = Passage(
            id=passage_id,
            user_id=user_id,
            text=text,
            embedding_model=model,
            embedding_dim=embedding_dim,
            embedding=embedding,
            # metadata=None,
            data_source=source_name,
        )
        # print(passage.id)
        passages.append(passage)
    st = time.time()
    # insert_passages_into_source(passages, source_name=source_name, user_id=user_id, config=config)
    # conn.insert_many(passages)
    conn.upsert_many(passages)
    return time.time() - st
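

# main() streams the input file in chunks of `chunk_size` lines: the first chunk is
# inserted synchronously (likely so the backing table exists before concurrent writers
# touch it), and every later chunk is handed to a 64-worker thread pool.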
def main(argv):
    # clear out the existing source
    if FLAGS.drop_db:
        delete("source", source_name)
        try:
            passages_table = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
            passages_table.delete_table()
        except Exception as e:
            print("Failed to delete source")
            print(e)

    # Open the file and read line by line
    count = 0
    # files = [
    #     # 'data/wikipedia_passages_shard_1-00.jsonl',
    #     # 'data/wikipedia_passages_shard_1-01.jsonl',
    #     'data/wikipedia_passages_shard_1-02.jsonl',
    #     # 'data/wikipedia_passages_shard_1-03.jsonl',
    #     # 'data/wikipedia_passages_shard_1-04.jsonl',
    #     # 'data/wikipedia_passages_shard_1-05.jsonl',
    #     # 'data/wikipedia_passages_shard_1-06.jsonl',
    #     # 'data/wikipedia_passages_shard_1-07.jsonl',
    #     # 'data/wikipedia_passages_shard_1-08.jsonl',
    #     # 'data/wikipedia_passages_shard_1-09.jsonl',
    # ]
    files = [FLAGS.file]
    chunk_size = 1000
    conn = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
    for file_path in files:
        print(file_path)
        futures = []
        with ThreadPoolExecutor(max_workers=64) as p:
            with open(file_path, "r") as file:
                lines = []

                # insert lines in chunks of `chunk_size`
                for line in tqdm(file):
                    lines.append(line)
                    if len(lines) >= chunk_size:
                        if count == 0:
                            # future = p.submit(insert_lines, copy.deepcopy(lines), conn, True)
                            print("Await first result (hack to avoid concurrency issues)")
                            t = insert_lines(lines, conn, True)
                            # t = future.result()
                            print("Finished first result", t)
                        else:
                            future = p.submit(insert_lines, copy.deepcopy(lines), conn)
                            futures.append(future)
                        count += len(lines)
                        lines = []

                # insert any remaining lines
                if len(lines) > 0:
                    future = p.submit(insert_lines, copy.deepcopy(lines), conn)
                    futures.append(future)
                    count += len(lines)
                    lines = []

                ## breaking point (uncomment to stop early while debugging)
                # if count >= 3000:
                #     break

            print(f"Waiting for {len(futures)} futures")
            # wait for futures
            for future in tqdm(as_completed(futures)):
                future.result()

    # check metadata
    # storage = StorageConnector.get_storage_connector(TableType.PASSAGES, config, user_id)
    # size = storage.size()
    size = conn.size()
    print("Number of passages", size)


if __name__ == "__main__":
    app.run(main)