Hotfix bug from async refactor (#203)
This commit is contained in:
@@ -382,7 +382,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
|
||||
if embedding is None:
|
||||
# Get the embedding
|
||||
embedding = async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
|
||||
return await self._insert(memory_string, embedding)
|
||||
return self._insert(memory_string, embedding)
|
||||
|
||||
def _search(self, query_embedding, query_string, count=None, start=None):
|
||||
"""Simple embedding-based search (inefficient, no caching)"""
|
||||
@@ -588,7 +588,7 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
|
||||
self.embedding_model = "text-embedding-ada-002"
|
||||
self.only_use_preloaded_embeddings = False
|
||||
|
||||
def _text_search(self, embedding_getter_func, query_string, count, start):
|
||||
def text_search(self, query_string, count, start):
|
||||
# in the dummy version, run an (inefficient) case-insensitive match search
|
||||
message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]]
|
||||
|
||||
@@ -603,11 +603,11 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
|
||||
message_pool_filtered.append(d)
|
||||
elif message_str not in self.embeddings:
|
||||
printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, computing now")
|
||||
self.embeddings[message_str] = embedding_getter_func(message_str, model=self.embedding_model)
|
||||
self.embeddings[message_str] = get_embedding_with_backoff(message_str, model=self.embedding_model)
|
||||
message_pool_filtered.append(d)
|
||||
|
||||
# our wrapped version supports backoff/rate-limits
|
||||
query_embedding = embedding_getter_func(query_string, model=self.embedding_model)
|
||||
query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model)
|
||||
similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered]
|
||||
|
||||
# Sort the archive based on similarity scores
|
||||
@@ -633,11 +633,8 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory):
|
||||
else:
|
||||
return matches, len(matches)
|
||||
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
return self._text_search(get_embedding_with_backoff, query_string, count, start)
|
||||
|
||||
async def a_text_search(self, query_string, count=None, start=None):
|
||||
return await self._text_search(async_get_embedding_with_backoff, query_string, count, start)
|
||||
return self.text_search(query_string, count, start)
|
||||
|
||||
|
||||
class LocalArchivalMemory(ArchivalMemory):
|
||||
|
||||
Reference in New Issue
Block a user