Bugfixes for get_all function and code cleanup to match main

This commit is contained in:
Sarah Wooders
2023-12-26 17:50:49 +04:00
parent 11096b20a4
commit 0c2bf05406
8 changed files with 660 additions and 647 deletions

View File

@@ -93,11 +93,14 @@ class ChromaStorageConnector(StorageConnector):
for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
]
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
ids, filters = self.get_filters(filters)
if self.collection.count() == 0:
return []
results = self.collection.get(ids=ids, include=self.include, where=filters, limit=limit)
if limit:
results = self.collection.get(ids=ids, include=self.include, where=filters, limit=limit)
else:
results = self.collection.get(ids=ids, include=self.include, where=filters)
return self.results_to_records(results)
def get(self, id: str) -> Optional[Record]:

View File

@@ -191,7 +191,6 @@ class SQLStorageConnector(StorageConnector):
filter_conditions = {**self.filters, **filters}
else:
filter_conditions = self.filters
print("SQL FILTERS", filter_conditions)
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
@@ -232,7 +231,6 @@ class SQLStorageConnector(StorageConnector):
# return size of table
session = self.Session()
filters = self.get_filters(filters)
print("ALL FILTERS", filters)
return session.query(self.db_model).filter(*filters).count()
def insert(self, record: Record):
@@ -345,33 +343,6 @@ class PostgresStorageConnector(SQLStorageConnector):
records = [result.to_record() for result in results]
return records
def delete(self, filters: Optional[Dict] = {}):
session = self.Session()
filters = self.get_filters(filters)
session.query(self.db_model).filter(*filters).delete()
session.commit()
class PostgresStorageConnector(SQLStorageConnector):
"""Storage via Postgres"""
# TODO: this should probably eventually be moved into a parent DB class
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
super().__init__(table_type=table_type, agent_config=agent_config)
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
session = self.Session()
filters = self.get_filters(filters)
results = session.scalars(
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
).all()
# Convert the results into Passage objects
records = [result.to_record() for result in results]
return records
class SQLLiteStorageConnector(SQLStorageConnector):
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):

View File

@@ -73,8 +73,6 @@ class StorageConnector:
else:
self.filters = {}
print("FILTERS", self.filters)
def get_filters(self, filters: Optional[Dict] = {}):
# get all filters for query
if filters is not None:
@@ -87,19 +85,12 @@ class StorageConnector:
if agent_config is not None:
# Table names for agent-specific tables
if agent_config.memgpt_version < "0.2.6":
# if agent is prior version, use old table name
if table_type == TableType.ARCHIVAL_MEMORY:
return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
else:
raise ValueError(f"Table type {table_type} not implemented")
if table_type == TableType.ARCHIVAL_MEMORY:
return ARCHIVAL_TABLE_NAME
elif table_type == TableType.RECALL_MEMORY:
return RECALL_TABLE_NAME
else:
if table_type == TableType.ARCHIVAL_MEMORY:
return ARCHIVAL_TABLE_NAME
elif table_type == TableType.RECALL_MEMORY:
return RECALL_TABLE_NAME
else:
raise ValueError(f"Table type {table_type} not implemented")
raise ValueError(f"Table type {table_type} not implemented")
else:
# table names for non-agent specific tables
if table_type == TableType.PASSAGES:
@@ -132,10 +123,12 @@ class StorageConnector:
from memgpt.connectors.chroma import ChromaStorageConnector
return ChromaStorageConnector(agent_config=agent_config, table_type=table_type)
elif storage_type == "lancedb":
from memgpt.connectors.db import LanceDBConnector
return LanceDBConnector(agent_config=agent_config, table_type=table_type)
# TODO: add back
# elif storage_type == "lancedb":
# from memgpt.connectors.db import LanceDBConnector
# return LanceDBConnector(agent_config=agent_config, table_type=table_type)
elif storage_type == "local":
from memgpt.connectors.local import InMemoryStorageConnector

View File

@@ -167,10 +167,12 @@ class ArchivalMemory(ABC):
class RecallMemory(ABC):
@abstractmethod
def text_search(self, query_string, count=None, start=None):
"""Search messages that match query_string in recall memory"""
pass
@abstractmethod
def date_search(self, query_string, count=None, start=None):
def date_search(self, start_date, end_date, count=None, start=None):
"""Search messages between start_date and end_date in recall memory"""
pass
@abstractmethod
@@ -179,6 +181,7 @@ class RecallMemory(ABC):
@abstractmethod
def insert(self, message: Message):
"""Insert message into recall memory"""
pass
@@ -396,6 +399,16 @@ class EmbeddingArchivalMemory(ArchivalMemory):
# breakup string into passages
for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
embedding = self.embed_model.get_text_embedding(node.text)
# fixing weird bug where type returned isn't a list, but instead is an object
# eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
if isinstance(embedding, dict):
try:
embedding = embedding["data"][0]["embedding"]
except (KeyError, IndexError):
# TODO as a fallback, see if we can find any lists in the payload
raise TypeError(
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
)
passages.append(self.create_passage(node.text, embedding))
# insert passages

View File

@@ -1,5 +1,13 @@
from datetime import datetime
import re
import json
import os
import pickle
import platform
import subprocess
import sys
import io
from contextlib import contextmanager
import difflib
import demjson3 as demjson
import pytz
@@ -135,8 +143,7 @@ def get_local_time(timezone=None):
local_time = datetime.now().astimezone()
# You may format it as you desire, including AM/PM
formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
print("formatted_time", formatted_time)
time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
return time_str.strip()

1205
poetry.lock generated

File diff suppressed because it is too large Load Diff

View File

@@ -49,8 +49,15 @@ tiktoken = "^0.5.1"
python-box = "^7.1.1"
pypdf = "^3.17.1"
pyyaml = "^6.0.1"
chromadb = {version = "^0.4.18", optional = true}
chromadb = "^0.4.18"
sqlalchemy-json = "^0.7.0"
fastapi = {version = "^0.104.1", optional = true}
uvicorn = {version = "^0.24.0.post1", optional = true}
pytest-asyncio = {version = "^0.23.2", optional = true}
pydantic = "^2.5.2"
pyautogen = {version = "0.2.0", optional = true}
html2text = "^2020.1.16"
docx2txt = "^0.8"
[tool.poetry.extras]
local = ["torch", "huggingface-hub", "transformers"]

View File

@@ -9,10 +9,7 @@ import pytest
# ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
#
# subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
import pgvector # Try to import again after installing
from memgpt.connectors.storage import StorageConnector, TableType
from memgpt.connectors.chroma import ChromaStorageConnector
from memgpt.connectors.db import SQLStorageConnector, LanceDBConnector
from memgpt.embeddings import embedding_model
from memgpt.data_types import Message, Passage
from memgpt.config import MemGPTConfig, AgentConfig
@@ -72,6 +69,7 @@ def test_storage(storage_connector, table_type):
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
if storage_connector == "lancedb":
# TODO: complete lancedb implementation
if not os.getenv("LANCEDB_TEST_URL"):
print("Skipping test, missing LanceDB URI")
return