Merge pull request #85 from cpacker/fast-embeddings
Parallelize embedding generation
This commit is contained in:
@@ -1,5 +1,6 @@
|
||||
from datetime import datetime
|
||||
|
||||
import asyncio
|
||||
import csv
|
||||
import difflib
|
||||
import demjson3 as demjson
|
||||
@@ -194,6 +195,31 @@ def chunk_files_for_jsonl(files, tkns_per_chunk=300, model='gpt-4'):
|
||||
ret.append(curr_file)
|
||||
return ret
|
||||
|
||||
async def process_chunk(i, chunk, model):
    """Embed a single chunk and return it tagged with its position.

    Awaits ``async_get_embedding_with_backoff`` on ``chunk['content']`` and
    returns the pair ``(i, result)`` so a caller that completes chunks out of
    order can put each result back in its original slot.

    Args:
        i: Index passed through unchanged as the first element of the result.
        chunk: Mapping that provides the text under the ``'content'`` key.
        model: Forwarded as the ``model`` keyword of the embedding helper.

    Raises:
        Exception: Any failure from the lookup or the embedding call is
            re-raised after the offending chunk has been printed.
    """
    try:
        embedding = await async_get_embedding_with_backoff(chunk['content'], model=model)
    except Exception as exc:
        # Dump the chunk that failed so it can be identified, then propagate.
        print(chunk)
        raise exc
    return i, embedding
|
||||
|
||||
async def process_concurrently(archival_database, model, concurrency=10):
    """Embed every chunk of ``archival_database`` with bounded concurrency.

    Each chunk is handed to ``process_chunk``; at most ``concurrency`` of
    those calls are in flight at any moment. Results are stored by the index
    that ``process_chunk`` returns, so the output order matches the input
    order even though chunks finish in arbitrary order.

    Args:
        archival_database: Sized sequence of chunks (``len()`` is taken of it).
        model: Passed through to ``process_chunk`` for every chunk.
        concurrency: Maximum number of chunks processed at the same time.
            Must be at least 1.

    Returns:
        A list with one entry per chunk, in input order, holding the second
        element of each ``process_chunk`` result.

    Raises:
        ValueError: If ``concurrency`` is smaller than 1.
        Exception: The first failure raised by ``process_chunk``; chunks that
            have not finished by then are cancelled.
    """
    # A semaphore initialised to 0 can never be acquired, which would make
    # this coroutine hang forever instead of reporting the bad argument.
    if concurrency < 1:
        raise ValueError("concurrency must be at least 1")

    # Create a semaphore to limit the number of concurrent tasks
    semaphore = asyncio.Semaphore(concurrency)

    async def bounded_process_chunk(i, chunk):
        async with semaphore:
            return await process_chunk(i, chunk, model)

    # Create real tasks (not bare coroutines) so we keep handles that can be
    # cancelled if something goes wrong part-way through.
    tasks = [
        asyncio.ensure_future(bounded_process_chunk(i, chunk))
        for i, chunk in enumerate(archival_database)
    ]
    # None marks "not filled yet"; every slot is overwritten on success.
    embedding_data = [None] * len(tasks)

    try:
        for future in tqdm(asyncio.as_completed(tasks), total=len(tasks), desc="Processing file chunks"):
            i, result = await future
            embedding_data[i] = result
    finally:
        # On success every task is already done and cancel() is a no-op.
        # On failure or outer cancellation this stops the remaining chunks
        # from continuing to run in the background.
        for task in tasks:
            task.cancel()

    return embedding_data
|
||||
|
||||
async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkns_per_chunk=300, model='gpt-4', embeddings_model='text-embedding-ada-002'):
|
||||
files = sorted(glob.glob(glob_pattern))
|
||||
save_dir = "archival_index_from_files_" + get_local_time().replace(' ', '_').replace(':', '_')
|
||||
@@ -206,15 +232,7 @@ async def prepare_archival_index_from_files_compute_embeddings(glob_pattern, tkn
|
||||
|
||||
# chunk the files, make embeddings
|
||||
archival_database = chunk_files(files, tkns_per_chunk, model)
|
||||
embedding_data = []
|
||||
for chunk in tqdm(archival_database, desc="Processing file chunks", total=len(archival_database)):
|
||||
# for chunk in tqdm(f, desc=f"Embedding file {i+1}/{len(chunks_by_file)}", total=len(f), leave=False):
|
||||
try:
|
||||
embedding = await async_get_embedding_with_backoff(chunk['content'], model=embeddings_model)
|
||||
except Exception as e:
|
||||
print(chunk)
|
||||
raise e
|
||||
embedding_data.append(embedding)
|
||||
embedding_data = await process_concurrently(archival_database, embeddings_model)
|
||||
embeddings_file = os.path.join(save_dir, "embeddings.json")
|
||||
with open(embeddings_file, 'w') as f:
|
||||
print(f"Saving embeddings to {embeddings_file}")
|
||||
|
||||
Reference in New Issue
Block a user