diff --git a/memgpt/memory.py b/memgpt/memory.py index ca8be9dc..8de814b4 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -382,7 +382,7 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory): if embedding is None: # Get the embedding embedding = async_get_embedding_with_backoff(memory_string, model=self.embedding_model) - return await self._insert(memory_string, embedding) + return self._insert(memory_string, embedding) def _search(self, query_embedding, query_string, count=None, start=None): """Simple embedding-based search (inefficient, no caching)""" @@ -588,7 +588,7 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): self.embedding_model = "text-embedding-ada-002" self.only_use_preloaded_embeddings = False - def _text_search(self, embedding_getter_func, query_string, count, start): + def text_search(self, query_string, count, start): # in the dummy version, run an (inefficient) case-insensitive match search message_pool = [d for d in self._message_logs if d["message"]["role"] not in ["system", "function"]] @@ -603,11 +603,11 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): message_pool_filtered.append(d) elif message_str not in self.embeddings: printd(f"recall_memory.text_search -- '{message_str}' was not in embedding dict, computing now") - self.embeddings[message_str] = embedding_getter_func(message_str, model=self.embedding_model) + self.embeddings[message_str] = get_embedding_with_backoff(message_str, model=self.embedding_model) message_pool_filtered.append(d) # our wrapped version supports backoff/rate-limits - query_embedding = embedding_getter_func(query_string, model=self.embedding_model) + query_embedding = get_embedding_with_backoff(query_string, model=self.embedding_model) similarity_scores = [cosine_similarity(self.embeddings[d["message"]["content"]], query_embedding) for d in message_pool_filtered] # Sort the archive based on similarity scores @@ -633,11 +633,8 @@ class DummyRecallMemoryWithEmbeddings(DummyRecallMemory): else: return matches, len(matches) - def text_search(self, query_string, count=None, start=None): - return self._text_search(get_embedding_with_backoff, query_string, count, start) - async def a_text_search(self, query_string, count=None, start=None): - return await self._text_search(async_get_embedding_with_backoff, query_string, count, start) + return self.text_search(query_string, count, start) class LocalArchivalMemory(ArchivalMemory):