first stab at running embedding generation concurrently

Vivian Fang
2023-10-21 10:58:49 -07:00
parent 3c38cbbdca
commit a557b8d58c


@@ -1,5 +1,6 @@
 from datetime import datetime
+import asyncio
 import csv
 import difflib
 import demjson3 as demjson
@@ -194,6 +195,31 @@ def chunk_files_for_jsonl(files, tkns_per_chunk=300, model='gpt-4'):
         ret.append(curr_file)
     return ret
+
+
+async def process_chunk(i, chunk, model):
+    try:
+        return i, await async_get_embedding_with_backoff(chunk['content'], model=model)
+    except Exception as e:
+        print(chunk)
+        raise e
+
+
+async def process_concurrently(archival_database, model, concurrency=5):
+    # Create a semaphore to limit the number of concurrent tasks
+    semaphore = asyncio.Semaphore(concurrency)
+
+    async def bounded_process_chunk(i, chunk):
+        async with semaphore:
+            return await process_chunk(i, chunk, model)
+
+    # Preallocate so each result lands at its original index even though
+    # tasks finish out of order
+    embedding_data = [0 for _ in archival_database]
+    # Create a list of tasks for chunks
+    tasks = [bounded_process_chunk(i, chunk) for i, chunk in enumerate(archival_database)]
+    for future in tqdm(asyncio.as_completed(tasks), total=len(archival_database), desc="Processing file chunks"):
+        i, result = await future
+        embedding_data[i] = result
+    return embedding_data
+
+
 async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
     files = sorted(glob.glob(glob_pattern))
     save_dir = "archival_index_from_files_" + get_local_time().replace(' ', '_').replace(':', '_')
@@ -206,15 +232,7 @@ async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
     # chunk the files, make embeddings
     archival_database = chunk_files(files, tkns_per_chunk, model)
-    embedding_data = []
-    for chunk in tqdm(archival_database, desc="Processing file chunks", total=len(archival_database)):
-        # for chunk in tqdm(f, desc=f"Embedding file {i+1}/{len(chunks_by_file)}", total=len(f), leave=False):
-        try:
-            embedding = await async_get_embedding_with_backoff(chunk['content'], model=embeddings_model)
-        except Exception as e:
-            print(chunk)
-            raise e
-        embedding_data.append(embedding)
+    embedding_data = await process_concurrently(archival_database, embeddings_model)
     embeddings_file = os.path.join(save_dir, "embeddings.json")
     with open(embeddings_file, 'w') as f:
         print(f"Saving embeddings to {embeddings_file}")