Hotfix bug from async refactor (#203)

This commit is contained in:
Vivian Fang
2023-10-30 15:38:25 -07:00
committed by GitHub
parent cc84d46d8b
commit 5a6b0ef1e8

View File

@@ -382,7 +382,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
if embedding is None:
# Get the embedding
embedding = async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
return await self._insert(memory_string, embedding)
return self._insert(memory_string, embedding)
def _search(self, query_embedding, query_string, count=None, start=None):
"""Simple embedding-based search (inefficient, no caching)"""
@@ -588,7 +588,7 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
self.embedding_model = "text-embedding-ada-002"
self.only_use_preloaded_embeddings = False
def _text_search(self, embedding_getter_func, query_string, count, start):
def text_search(self, query_string, count, start):
# in the dummy version, run an (inefficient) case-insensitive match search
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
@@ -603,11 +603,11 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
message_pool_filtered.append(d)
elif message_str not in self.embeddings:
printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, computing now")
self.embeddings[message_str] = embedding_getter_func(message_str, model=self.embedding_model)
self.embeddings[message_str] = get_embedding_with_backoff(message_str, model=self.embedding_model)
message_pool_filtered.append(d)
# our wrapped version supports backoff/rate-limits
query_embedding = embedding_getter_func(query_string, model=self.embedding_model)
query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)
similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered]
# Sort the archive based on similarity scores
@@ -633,11 +633,8 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
else:
return matches, len(matches)
def text_search(self, query_string, count=None, start=None):
return self._text_search(get_embedding_with_backoff, query_string, count, start)
async def a_text_search(self, query_string, count=None, start=None):
return await self._text_search(async_get_embedding_with_backoff, query_string, count, start)
return self.text_search(query_string, count, start)
class LocalArchivalMemory(ArchivalMemory):