fix: various breaking bugs with local LLM implementation and postgres docker. (#1355)

This commit is contained in:
madgrizzle
2024-05-12 14:53:46 -04:00
committed by GitHub
parent 9c457bdc77
commit e9c9513f84
7 changed files with 55 additions and 11 deletions

db/run_postgres.sh — Normal file → Executable file (0 changed lines)

View File

@@ -1,6 +1,7 @@
import base64
import os
import uuid
from datetime import datetime
from typing import Dict, Iterator, List, Optional
import numpy as np
@@ -379,7 +380,7 @@ class SQLStorageConnector(StorageConnector):
unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
return unique_data_sources
- def query_date(self, start_date, end_date, offset=0, limit=None):
+ def query_date(self, start_date, end_date, limit=None, offset=0):
filters = self.get_filters({})
with self.session_maker() as session:
query = (
@@ -387,6 +388,8 @@ class SQLStorageConnector(StorageConnector):
.filter(*filters)
.filter(self.db_model.created_at >= start_date)
.filter(self.db_model.created_at <= end_date)
+ .filter(self.db_model.role != "system")
+ .filter(self.db_model.role != "tool")
.offset(offset)
)
if limit:
@@ -394,7 +397,7 @@ class SQLStorageConnector(StorageConnector):
results = query.all()
return [result.to_record() for result in results]
- def query_text(self, query, offset=0, limit=None):
+ def query_text(self, query, limit=None, offset=0):
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
filters = self.get_filters({})
with self.session_maker() as session:
@@ -402,6 +405,8 @@ class SQLStorageConnector(StorageConnector):
session.query(self.db_model)
.filter(*filters)
.filter(func.lower(self.db_model.text).contains(func.lower(query)))
+ .filter(self.db_model.role != "system")
+ .filter(self.db_model.role != "tool")
.offset(offset)
)
if limit:
@@ -527,6 +532,30 @@ class PostgresStorageConnector(SQLStorageConnector):
# Commit the changes to the database
session.commit()
+ def str_to_datetime(self, str_date):
+ val = str_date.split("-")
+ _datetime = datetime(int(val[0]), int(val[1]), int(val[2]))
+ return _datetime
+ def query_date(self, start_date, end_date, limit=None, offset=0):
+ filters = self.get_filters({})
+ _start_date = self.str_to_datetime(start_date)
+ _end_date = self.str_to_datetime(end_date)
+ with self.session_maker() as session:
+ query = (
+ session.query(self.db_model)
+ .filter(*filters)
+ .filter(self.db_model.created_at >= _start_date)
+ .filter(self.db_model.created_at <= _end_date)
+ .filter(self.db_model.role != "system")
+ .filter(self.db_model.role != "tool")
+ .offset(offset)
+ )
+ if limit:
+ query = query.limit(limit)
+ results = query.all()
+ return [result.to_record() for result in results]
class SQLLiteStorageConnector(SQLStorageConnector):
def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
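Note on the storage changes above: query_date and query_text now take limit before offset, which matters because BaseRecallMemory (see the memory.py hunk below) calls them positionally as (query, count, start); with the old order, count was silently bound to offset and start to limit. The added role != "system" / role != "tool" filters keep internal messages out of search results, and the Postgres-specific query_date converts "YYYY-MM-DD" strings into real datetimes before comparing against created_at. A minimal, self-contained sketch of that date handling (illustrative only, not the exact module layout):

from datetime import datetime

def str_to_datetime(str_date: str) -> datetime:
    # "2024-05-12" -> datetime(2024, 5, 12), mirroring the helper added above
    year, month, day = (int(part) for part in str_date.split("-"))
    return datetime(year, month, day)

# created_at comparisons then use datetime objects rather than raw date strings,
# which is what the SQLAlchemy DateTime column expects
assert str_to_datetime("2024-05-12") == datetime(2024, 5, 12)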

View File

@@ -171,7 +171,6 @@ class MemGPTConfig:
"config_path": config_path,
"memgpt_version": get_field(config, "version", "memgpt_version"),
}
# Don't include null values
config_dict = {k: v for k, v in config_dict.items() if v is not None}

View File

@@ -244,6 +244,11 @@ class Message(Record):
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
)
+ def to_openai_dict_search_results(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
+ result_json = self.to_openai_dict()
+ search_result_json = {"timestamp": self.created_at, "message": {"content": result_json["content"], "role": result_json["role"]}}
+ return search_result_json
def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
"""Go from Message class to ChatCompletion message object"""

View File

@@ -327,12 +327,12 @@ class BaseRecallMemory(RecallMemory):
def text_search(self, query_string, count=None, start=None):
results = self.storage.query_text(query_string, count, start)
- results_json = [message.to_openai_dict() for message in results]
+ results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)
def date_search(self, start_date, end_date, count=None, start=None):
results = self.storage.query_date(start_date, end_date, count, start)
- results_json = [message.to_openai_dict() for message in results]
+ results_json = [message.to_openai_dict_search_results() for message in results]
return results_json, len(results)
def __repr__(self) -> str:
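These call sites pass (count, start) positionally, so they only behave as intended together with the reordered storage signatures above. A quick sketch of the alignment (toy stand-in, not the real connector):

def query_text(query, limit=None, offset=0):
    # toy stand-in that just echoes what it received
    return {"query": query, "limit": limit, "offset": offset}

# BaseRecallMemory.text_search forwards (query_string, count, start) positionally:
print(query_text("hello", 10, 0))  # {'query': 'hello', 'limit': 10, 'offset': 0}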

View File

@@ -216,7 +216,7 @@ class SyncServer(LockingServer):
# Update storage URI to match passed in settings
# TODO: very hack, fix in the future
for memory_type in ("archival", "recall", "metadata"):
- if settings.memgpt_pg_uri:
+ if settings.memgpt_pg_uri_no_default:
# override with env
setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
self.config.save()
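Background on this check: memgpt_pg_uri always returns a value because it falls back to a hardcoded local-Docker URI, so the old if settings.memgpt_pg_uri: branch fired even when no Postgres environment variables were set and clobbered the configured storage URIs. The new memgpt_pg_uri_no_default property (settings hunk below) yields a URI only when one was actually provided via MEMGPT_PG_URI or the individual MEMGPT_PG_* variables. A condensed sketch of the intended guard (hypothetical helper, not the server's actual structure):

from typing import Optional

def resolve_storage_uri(configured_uri: str, env_uri: Optional[str]) -> str:
    # only let the environment override the saved config when a Postgres URI
    # was explicitly provided (env_uri is None otherwise)
    return env_uri if env_uri else configured_uri

assert resolve_storage_uri("sqlite:///memgpt.db", None) == "sqlite:///memgpt.db"
assert resolve_storage_uri("sqlite:///memgpt.db", "postgresql+pg8000://u:p@h:5432/db").startswith("postgresql")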

View File

@@ -7,16 +7,27 @@ class Settings(BaseSettings):
model_config = SettingsConfigDict(env_prefix="memgpt_")
server_pass: Optional[str] = None
pg_db: Optional[str] = "memgpt"
pg_user: Optional[str] = "memgpt"
pg_password: Optional[str] = "memgpt"
pg_host: Optional[str] = "localhost"
pg_port: Optional[int] = 5432
pg_db: Optional[str] = None
pg_user: Optional[str] = None
pg_password: Optional[str] = None
pg_host: Optional[str] = None
pg_port: Optional[int] = None
pg_uri: Optional[str] = None # option to specify full uri
cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"]
@property
def memgpt_pg_uri(self) -> str:
if self.pg_uri:
return self.pg_uri
elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port:
return f"postgresql+pg8000://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}"
else:
return f"postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt"
+ # add this property to avoid being returned the default
+ # reference: https://github.com/cpacker/MemGPT/issues/1362
+ @property
+ def memgpt_pg_uri_no_default(self) -> str:
+ if self.pg_uri:
+ return self.pg_uri
+ elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port: