diff --git a/memgpt/utils.py b/memgpt/utils.py index 37746a15..77658228 100644 --- a/memgpt/utils.py +++ b/memgpt/utils.py @@ -1,4 +1,6 @@ from datetime import datetime + +import csv import difflib import demjson3 as demjson import numpy as np @@ -96,6 +98,16 @@ def read_in_chunks(file_object, chunk_size): break yield data +def read_in_rows_csv(file_object, chunk_size): + csvreader = csv.reader(file_object) + header = next(csvreader) + for row in csvreader: + next_row_terms = [] + for h, v in zip(header, row): + next_row_terms.append(f"{h}={v}") + next_row_str = ', '.join(next_row_terms) + yield next_row_str + def prepare_archival_index_from_files(glob_pattern, tkns_per_chunk=300, model='gpt-4'): encoding = tiktoken.encoding_for_model(model) files = glob.glob(glob_pattern) @@ -111,12 +123,16 @@ def total_bytes(pattern): def chunk_file(file, tkns_per_chunk=300, model='gpt-4'): encoding = tiktoken.encoding_for_model(model) with open(file, 'r') as f: - lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)] + if file.endswith('.csv'): + lines = [l for l in read_in_rows_csv(f, tkns_per_chunk*8)] + else: + lines = [l for l in read_in_chunks(f, tkns_per_chunk*4)] curr_chunk = [] curr_token_ct = 0 for i, line in enumerate(lines): line = line.rstrip() line = line.lstrip() + line += '\n' try: line_token_ct = len(encoding.encode(line)) except Exception as e: