Remove embeddings as argument in archival_memory.insert (#284)

This commit is contained in:
Charles Packer
2023-11-05 12:48:22 -08:00
committed by GitHub
parent d9b9ad4860
commit cc1ce0ce33
2 changed files with 18 additions and 24 deletions

View File

@@ -800,8 +800,8 @@ class Agent(object):
results_str = f"{results_pref} {json.dumps(results_formatted)}"
return results_str
def archival_memory_insert(self, content, embedding=None):
self.persistence_manager.archival_memory.insert(content, embedding=None)
def archival_memory_insert(self, content):
self.persistence_manager.archival_memory.insert(content)
return None
def archival_memory_search(self, query, count=5, page=0):
@@ -1245,8 +1245,8 @@ class AgentAsync(Agent):
results_str = f"{results_pref} {json.dumps(results_formatted)}"
return results_str
async def archival_memory_insert(self, content, embedding=None):
await self.persistence_manager.archival_memory.a_insert(content, embedding=None)
async def archival_memory_insert(self, content):
await self.persistence_manager.archival_memory.a_insert(content)
return None
async def archival_memory_search(self, query, count=5, page=0):

View File

@@ -231,9 +231,7 @@ class DummyArchivalMemory(ArchivalMemory):
memory_str = "\n".join([d["content"] for d in self._archive])
return f"\n### ARCHIVAL MEMORY ###" + f"\n{memory_str}"
def insert(self, memory_string, embedding=None):
if embedding is not None:
raise ValueError("Basic text-based archival memory does not support embeddings")
def insert(self, memory_string):
self._archive.append(
{
# can eventually upgrade to adding semantic tags, etc
@@ -242,8 +240,8 @@ class DummyArchivalMemory(ArchivalMemory):
}
)
async def a_insert(self, memory_string, embedding=None):
return self.insert(memory_string, embedding)
async def a_insert(self, memory_string):
return self.insert(memory_string)
def search(self, query_string, count=None, start=None):
"""Simple text-based search"""
@@ -293,14 +291,12 @@ class DummyArchivalMemoryWithEmbeddings(DummyArchivalMemory):
}
)
def insert(self, memory_string, embedding=None):
if embedding is None:
embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
def insert(self, memory_string):
embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
return self._insert(memory_string, embedding)
async def a_insert(self, memory_string, embedding=None):
if embedding is None:
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
async def a_insert(self, memory_string):
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
return self._insert(memory_string, embedding)
def _search(self, query_embedding, query_string, count, start):
@@ -382,16 +378,14 @@ class DummyArchivalMemoryWithFaiss(DummyArchivalMemory):
embedding = np.array([embedding]).astype("float32")
self.index.add(embedding)
def insert(self, memory_string, embedding=None):
if embedding is None:
# Get the embedding
embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
def insert(self, memory_string):
# Get the embedding
embedding = get_embedding_with_backoff(memory_string, model=self.embedding_model)
return self._insert(memory_string, embedding)
async def a_insert(self, memory_string, embedding=None):
if embedding is None:
# Get the embedding
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
async def a_insert(self, memory_string):
# Get the embedding
embedding = await async_get_embedding_with_backoff(memory_string, model=self.embedding_model)
return self._insert(memory_string, embedding)
def _search(self, query_embedding, query_string, count=None, start=None):
@@ -814,7 +808,7 @@ class EmbeddingArchivalMemory(ArchivalMemory):
async def a_search(self, query_string, count=None, start=None):
return self.search(query_string, count, start)
async def a_insert(self, memory_string, embedding=None):
async def a_insert(self, memory_string):
return self.insert(memory_string)
def __repr__(self) -> str: