diff --git a/db/run_postgres.sh b/db/run_postgres.sh old mode 100644 new mode 100755 diff --git a/memgpt/agent_store/db.py b/memgpt/agent_store/db.py index 8584f46d..52146db4 100644 --- a/memgpt/agent_store/db.py +++ b/memgpt/agent_store/db.py @@ -1,6 +1,7 @@ import base64 import os import uuid +from datetime import datetime from typing import Dict, Iterator, List, Optional import numpy as np @@ -379,7 +380,7 @@ class SQLStorageConnector(StorageConnector): unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all() return unique_data_sources - def query_date(self, start_date, end_date, offset=0, limit=None): + def query_date(self, start_date, end_date, limit=None, offset=0): filters = self.get_filters({}) with self.session_maker() as session: query = ( @@ -387,6 +388,8 @@ class SQLStorageConnector(StorageConnector): .filter(*filters) .filter(self.db_model.created_at >= start_date) .filter(self.db_model.created_at <= end_date) + .filter(self.db_model.role != "system") + .filter(self.db_model.role != "tool") .offset(offset) ) if limit: @@ -394,7 +397,7 @@ class SQLStorageConnector(StorageConnector): results = query.all() return [result.to_record() for result in results] - def query_text(self, query, offset=0, limit=None): + def query_text(self, query, limit=None, offset=0): # todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204 filters = self.get_filters({}) with self.session_maker() as session: @@ -402,6 +405,8 @@ class SQLStorageConnector(StorageConnector): session.query(self.db_model) .filter(*filters) .filter(func.lower(self.db_model.text).contains(func.lower(query))) + .filter(self.db_model.role != "system") + .filter(self.db_model.role != "tool") .offset(offset) ) if limit: @@ -527,6 +532,30 @@ class PostgresStorageConnector(SQLStorageConnector): # Commit the changes to the database session.commit() + def str_to_datetime(self, str_date): + val = str_date.split("-") + _datetime = datetime(int(val[0]), int(val[1]), int(val[2])) + return _datetime + + def query_date(self, start_date, end_date, limit=None, offset=0): + filters = self.get_filters({}) + _start_date = self.str_to_datetime(start_date) + _end_date = self.str_to_datetime(end_date) + with self.session_maker() as session: + query = ( + session.query(self.db_model) + .filter(*filters) + .filter(self.db_model.created_at >= _start_date) + .filter(self.db_model.created_at <= _end_date) + .filter(self.db_model.role != "system") + .filter(self.db_model.role != "tool") + .offset(offset) + ) + if limit: + query = query.limit(limit) + results = query.all() + return [result.to_record() for result in results] + class SQLLiteStorageConnector(SQLStorageConnector): def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None): diff --git a/memgpt/config.py b/memgpt/config.py index 0db59ab2..30e4e8df 100644 --- a/memgpt/config.py +++ b/memgpt/config.py @@ -171,7 +171,6 @@ class MemGPTConfig: "config_path": config_path, "memgpt_version": get_field(config, "version", "memgpt_version"), } - # Don't include null values config_dict = {k: v for k, v in config_dict.items() if v is not None} diff --git a/memgpt/data_types.py b/memgpt/data_types.py index 9e40be63..682270e3 100644 --- a/memgpt/data_types.py +++ b/memgpt/data_types.py @@ -244,6 +244,11 @@ class Message(Record): tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None, ) + def to_openai_dict_search_results(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict: + result_json = self.to_openai_dict() + search_result_json = {"timestamp": self.created_at, "message": {"content": result_json["content"], "role": result_json["role"]}} + return search_result_json + def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict: """Go from Message class to ChatCompletion message object""" diff --git a/memgpt/memory.py b/memgpt/memory.py index b07b0263..a3f6b3f9 100644 --- a/memgpt/memory.py +++ b/memgpt/memory.py @@ -327,12 +327,12 @@ class BaseRecallMemory(RecallMemory): def text_search(self, query_string, count=None, start=None): results = self.storage.query_text(query_string, count, start) - results_json = [message.to_openai_dict() for message in results] + results_json = [message.to_openai_dict_search_results() for message in results] return results_json, len(results) def date_search(self, start_date, end_date, count=None, start=None): results = self.storage.query_date(start_date, end_date, count, start) - results_json = [message.to_openai_dict() for message in results] + results_json = [message.to_openai_dict_search_results() for message in results] return results_json, len(results) def __repr__(self) -> str: diff --git a/memgpt/server/server.py b/memgpt/server/server.py index cda83e03..00c776d4 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -216,7 +216,7 @@ class SyncServer(LockingServer): # Update storage URI to match passed in settings # TODO: very hack, fix in the future for memory_type in ("archival", "recall", "metadata"): - if settings.memgpt_pg_uri: + if settings.memgpt_pg_uri_no_default: # override with env setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri) self.config.save() diff --git a/memgpt/settings.py b/memgpt/settings.py index 8b3da82a..6fceff37 100644 --- a/memgpt/settings.py +++ b/memgpt/settings.py @@ -7,16 +7,27 @@ class Settings(BaseSettings): model_config = SettingsConfigDict(env_prefix="memgpt_") server_pass: Optional[str] = None - pg_db: Optional[str] = "memgpt" - pg_user: Optional[str] = "memgpt" - pg_password: Optional[str] = "memgpt" - pg_host: Optional[str] = "localhost" - pg_port: Optional[int] = 5432 + pg_db: Optional[str] = None + pg_user: Optional[str] = None + pg_password: Optional[str] = None + pg_host: Optional[str] = None + pg_port: Optional[int] = None pg_uri: Optional[str] = None # option to specifiy full uri cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"] @property def memgpt_pg_uri(self) -> str: + if self.pg_uri: + return self.pg_uri + elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port: + return f"postgresql+pg8000://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}" + else: + return f"postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt" + + # add this property to avoid being returned the default + # reference: https://github.com/cpacker/MemGPT/issues/1362 + @property + def memgpt_pt_uri_no_default(self) -> str: if self.pg_uri: return self.pg_uri elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port: