Bugfixes for get_all function and code cleanup to match main
This commit is contained in:
@@ -93,11 +93,14 @@ class ChromaStorageConnector(StorageConnector):
|
||||
for (text, id, metadatas) in zip(results["documents"], results["ids"], results["metadatas"])
|
||||
]
|
||||
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=10) -> List[Record]:
|
||||
def get_all(self, filters: Optional[Dict] = {}, limit=None) -> List[Record]:
|
||||
ids, filters = self.get_filters(filters)
|
||||
if self.collection.count() == 0:
|
||||
return []
|
||||
results = self.collection.get(ids=ids, include=self.include, where=filters, limit=limit)
|
||||
if limit:
|
||||
results = self.collection.get(ids=ids, include=self.include, where=filters, limit=limit)
|
||||
else:
|
||||
results = self.collection.get(ids=ids, include=self.include, where=filters)
|
||||
return self.results_to_records(results)
|
||||
|
||||
def get(self, id: str) -> Optional[Record]:
|
||||
|
||||
@@ -191,7 +191,6 @@ class SQLStorageConnector(StorageConnector):
|
||||
filter_conditions = {**self.filters, **filters}
|
||||
else:
|
||||
filter_conditions = self.filters
|
||||
print("SQL FILTERS", filter_conditions)
|
||||
return [getattr(self.db_model, key) == value for key, value in filter_conditions.items()]
|
||||
|
||||
def get_all_paginated(self, filters: Optional[Dict] = {}, page_size: Optional[int] = 1000) -> Iterator[List[Record]]:
|
||||
@@ -232,7 +231,6 @@ class SQLStorageConnector(StorageConnector):
|
||||
# return size of table
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
print("ALL FILTERS", filters)
|
||||
return session.query(self.db_model).filter(*filters).count()
|
||||
|
||||
def insert(self, record: Record):
|
||||
@@ -345,33 +343,6 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
records = [result.to_record() for result in results]
|
||||
return records
|
||||
|
||||
def delete(self, filters: Optional[Dict] = {}):
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
session.query(self.db_model).filter(*filters).delete()
|
||||
session.commit()
|
||||
|
||||
|
||||
class PostgresStorageConnector(SQLStorageConnector):
|
||||
"""Storage via Postgres"""
|
||||
|
||||
# TODO: this should probably eventually be moved into a parent DB class
|
||||
|
||||
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||
super().__init__(table_type=table_type, agent_config=agent_config)
|
||||
self.Session().execute(text("CREATE EXTENSION IF NOT EXISTS vector")) # Enables the vector extension
|
||||
|
||||
def query(self, query: str, query_vec: List[float], top_k: int = 10, filters: Optional[Dict] = {}) -> List[Record]:
|
||||
session = self.Session()
|
||||
filters = self.get_filters(filters)
|
||||
results = session.scalars(
|
||||
select(self.db_model).filter(*filters).order_by(self.db_model.embedding.l2_distance(query_vec)).limit(top_k)
|
||||
).all()
|
||||
|
||||
# Convert the results into Passage objects
|
||||
records = [result.to_record() for result in results]
|
||||
return records
|
||||
|
||||
|
||||
class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
def __init__(self, table_type: str, agent_config: Optional[AgentConfig] = None):
|
||||
|
||||
@@ -73,8 +73,6 @@ class StorageConnector:
|
||||
else:
|
||||
self.filters = {}
|
||||
|
||||
print("FILTERS", self.filters)
|
||||
|
||||
def get_filters(self, filters: Optional[Dict] = {}):
|
||||
# get all filters for query
|
||||
if filters is not None:
|
||||
@@ -87,19 +85,12 @@ class StorageConnector:
|
||||
|
||||
if agent_config is not None:
|
||||
# Table names for agent-specific tables
|
||||
if agent_config.memgpt_version < "0.2.6":
|
||||
# if agent is prior version, use old table name
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
return f"memgpt_agent_{self.sanitize_table_name(agent_config.name)}"
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
return ARCHIVAL_TABLE_NAME
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
return RECALL_TABLE_NAME
|
||||
else:
|
||||
if table_type == TableType.ARCHIVAL_MEMORY:
|
||||
return ARCHIVAL_TABLE_NAME
|
||||
elif table_type == TableType.RECALL_MEMORY:
|
||||
return RECALL_TABLE_NAME
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
else:
|
||||
# table names for non-agent specific tables
|
||||
if table_type == TableType.PASSAGES:
|
||||
@@ -132,10 +123,12 @@ class StorageConnector:
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
|
||||
return ChromaStorageConnector(agent_config=agent_config, table_type=table_type)
|
||||
elif storage_type == "lancedb":
|
||||
from memgpt.connectors.db import LanceDBConnector
|
||||
|
||||
return LanceDBConnector(agent_config=agent_config, table_type=table_type)
|
||||
# TODO: add back
|
||||
# elif storage_type == "lancedb":
|
||||
# from memgpt.connectors.db import LanceDBConnector
|
||||
|
||||
# return LanceDBConnector(agent_config=agent_config, table_type=table_type)
|
||||
|
||||
elif storage_type == "local":
|
||||
from memgpt.connectors.local import InMemoryStorageConnector
|
||||
|
||||
@@ -167,10 +167,12 @@ class ArchivalMemory(ABC):
|
||||
class RecallMemory(ABC):
|
||||
@abstractmethod
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
"""Search messages that match query_string in recall memory"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def date_search(self, query_string, count=None, start=None):
|
||||
def date_search(self, start_date, end_date, count=None, start=None):
|
||||
"""Search messages between start_date and end_date in recall memory"""
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
@@ -179,6 +181,7 @@ class RecallMemory(ABC):
|
||||
|
||||
@abstractmethod
|
||||
def insert(self, message: Message):
|
||||
"""Insert message into recall memory"""
|
||||
pass
|
||||
|
||||
|
||||
@@ -396,6 +399,16 @@ class EmbeddingArchivalMemory(ArchivalMemory):
|
||||
# breakup string into passages
|
||||
for node in parser.get_nodes_from_documents([Document(text=memory_string)]):
|
||||
embedding = self.embed_model.get_text_embedding(node.text)
|
||||
# fixing weird bug where type returned isn't a list, but instead is an object
|
||||
# eg: embedding={'object': 'list', 'data': [{'object': 'embedding', 'embedding': [-0.0071973633, -0.07893023,
|
||||
if isinstance(embedding, dict):
|
||||
try:
|
||||
embedding = embedding["data"][0]["embedding"]
|
||||
except (KeyError, IndexError):
|
||||
# TODO as a fallback, see if we can find any lists in the payload
|
||||
raise TypeError(
|
||||
f"Got back an unexpected payload from text embedding function, type={type(embedding)}, value={embedding}"
|
||||
)
|
||||
passages.append(self.create_passage(node.text, embedding))
|
||||
|
||||
# insert passages
|
||||
|
||||
@@ -1,5 +1,13 @@
|
||||
from datetime import datetime
|
||||
import re
|
||||
import json
|
||||
import os
|
||||
import pickle
|
||||
import platform
|
||||
import subprocess
|
||||
import sys
|
||||
import io
|
||||
from contextlib import contextmanager
|
||||
import difflib
|
||||
import demjson3 as demjson
|
||||
import pytz
|
||||
@@ -135,8 +143,7 @@ def get_local_time(timezone=None):
|
||||
local_time = datetime.now().astimezone()
|
||||
|
||||
# You may format it as you desire, including AM/PM
|
||||
formatted_time = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
print("formatted_time", formatted_time)
|
||||
time_str = local_time.strftime("%Y-%m-%d %I:%M:%S %p %Z%z")
|
||||
|
||||
return time_str.strip()
|
||||
|
||||
|
||||
1205
poetry.lock
generated
1205
poetry.lock
generated
File diff suppressed because it is too large
Load Diff
@@ -49,8 +49,15 @@ tiktoken = "^0.5.1"
|
||||
python-box = "^7.1.1"
|
||||
pypdf = "^3.17.1"
|
||||
pyyaml = "^6.0.1"
|
||||
chromadb = {version = "^0.4.18", optional = true}
|
||||
chromadb = "^0.4.18"
|
||||
sqlalchemy-json = "^0.7.0"
|
||||
fastapi = {version = "^0.104.1", optional = true}
|
||||
uvicorn = {version = "^0.24.0.post1", optional = true}
|
||||
pytest-asyncio = {version = "^0.23.2", optional = true}
|
||||
pydantic = "^2.5.2"
|
||||
pyautogen = {version = "0.2.0", optional = true}
|
||||
html2text = "^2020.1.16"
|
||||
docx2txt = "^0.8"
|
||||
|
||||
[tool.poetry.extras]
|
||||
local = ["torch", "huggingface-hub", "transformers"]
|
||||
|
||||
@@ -9,10 +9,7 @@ import pytest
|
||||
# ) # , "psycopg_binary"]) # "psycopg", "libpq-dev"])
|
||||
#
|
||||
# subprocess.check_call([sys.executable, "-m", "pip", "install", "lancedb"])
|
||||
import pgvector # Try to import again after installing
|
||||
from memgpt.connectors.storage import StorageConnector, TableType
|
||||
from memgpt.connectors.chroma import ChromaStorageConnector
|
||||
from memgpt.connectors.db import SQLStorageConnector, LanceDBConnector
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.data_types import Message, Passage
|
||||
from memgpt.config import MemGPTConfig, AgentConfig
|
||||
@@ -72,6 +69,7 @@ def test_storage(storage_connector, table_type):
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
if storage_connector == "lancedb":
|
||||
# TODO: complete lancedb implementation
|
||||
if not os.getenv("LANCEDB_TEST_URL"):
|
||||
print("Skipping test, missing LanceDB URI")
|
||||
return
|
||||
|
||||
Reference in New Issue
Block a user