fix: various breaking bugs with local LLM implementation and postgres docker. (#1355)
This commit is contained in:
0
db/run_postgres.sh
Normal file → Executable file
0
db/run_postgres.sh
Normal file → Executable file
@@ -1,6 +1,7 @@
|
||||
import base64
|
||||
import os
|
||||
import uuid
|
||||
from datetime import datetime
|
||||
from typing import Dict, Iterator, List, Optional
|
||||
|
||||
import numpy as np
|
||||
@@ -379,7 +380,7 @@ class SQLStorageConnector(StorageConnector):
|
||||
unique_data_sources = session.query(self.db_model.data_source).filter(*self.filters).distinct().all()
|
||||
return unique_data_sources
|
||||
|
||||
def query_date(self, start_date, end_date, offset=0, limit=None):
|
||||
def query_date(self, start_date, end_date, limit=None, offset=0):
|
||||
filters = self.get_filters({})
|
||||
with self.session_maker() as session:
|
||||
query = (
|
||||
@@ -387,6 +388,8 @@ class SQLStorageConnector(StorageConnector):
|
||||
.filter(*filters)
|
||||
.filter(self.db_model.created_at >= start_date)
|
||||
.filter(self.db_model.created_at <= end_date)
|
||||
.filter(self.db_model.role != "system")
|
||||
.filter(self.db_model.role != "tool")
|
||||
.offset(offset)
|
||||
)
|
||||
if limit:
|
||||
@@ -394,7 +397,7 @@ class SQLStorageConnector(StorageConnector):
|
||||
results = query.all()
|
||||
return [result.to_record() for result in results]
|
||||
|
||||
def query_text(self, query, offset=0, limit=None):
|
||||
def query_text(self, query, limit=None, offset=0):
|
||||
# todo: make fuzz https://stackoverflow.com/questions/42388956/create-a-full-text-search-index-with-sqlalchemy-on-postgresql/42390204#42390204
|
||||
filters = self.get_filters({})
|
||||
with self.session_maker() as session:
|
||||
@@ -402,6 +405,8 @@ class SQLStorageConnector(StorageConnector):
|
||||
session.query(self.db_model)
|
||||
.filter(*filters)
|
||||
.filter(func.lower(self.db_model.text).contains(func.lower(query)))
|
||||
.filter(self.db_model.role != "system")
|
||||
.filter(self.db_model.role != "tool")
|
||||
.offset(offset)
|
||||
)
|
||||
if limit:
|
||||
@@ -527,6 +532,30 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
# Commit the changes to the database
|
||||
session.commit()
|
||||
|
||||
def str_to_datetime(self, str_date):
|
||||
val = str_date.split("-")
|
||||
_datetime = datetime(int(val[0]), int(val[1]), int(val[2]))
|
||||
return _datetime
|
||||
|
||||
def query_date(self, start_date, end_date, limit=None, offset=0):
|
||||
filters = self.get_filters({})
|
||||
_start_date = self.str_to_datetime(start_date)
|
||||
_end_date = self.str_to_datetime(end_date)
|
||||
with self.session_maker() as session:
|
||||
query = (
|
||||
session.query(self.db_model)
|
||||
.filter(*filters)
|
||||
.filter(self.db_model.created_at >= _start_date)
|
||||
.filter(self.db_model.created_at <= _end_date)
|
||||
.filter(self.db_model.role != "system")
|
||||
.filter(self.db_model.role != "tool")
|
||||
.offset(offset)
|
||||
)
|
||||
if limit:
|
||||
query = query.limit(limit)
|
||||
results = query.all()
|
||||
return [result.to_record() for result in results]
|
||||
|
||||
|
||||
class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
def __init__(self, table_type: str, config: MemGPTConfig, user_id, agent_id=None):
|
||||
|
||||
@@ -171,7 +171,6 @@ class MemGPTConfig:
|
||||
"config_path": config_path,
|
||||
"memgpt_version": get_field(config, "version", "memgpt_version"),
|
||||
}
|
||||
|
||||
# Don't include null values
|
||||
config_dict = {k: v for k, v in config_dict.items() if v is not None}
|
||||
|
||||
|
||||
@@ -244,6 +244,11 @@ class Message(Record):
|
||||
tool_call_id=openai_message_dict["tool_call_id"] if "tool_call_id" in openai_message_dict else None,
|
||||
)
|
||||
|
||||
def to_openai_dict_search_results(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
|
||||
result_json = self.to_openai_dict()
|
||||
search_result_json = {"timestamp": self.created_at, "message": {"content": result_json["content"], "role": result_json["role"]}}
|
||||
return search_result_json
|
||||
|
||||
def to_openai_dict(self, max_tool_id_length=TOOL_CALL_ID_MAX_LEN) -> dict:
|
||||
"""Go from Message class to ChatCompletion message object"""
|
||||
|
||||
|
||||
@@ -327,12 +327,12 @@ class BaseRecallMemory(RecallMemory):
|
||||
|
||||
def text_search(self, query_string, count=None, start=None):
|
||||
results = self.storage.query_text(query_string, count, start)
|
||||
results_json = [message.to_openai_dict() for message in results]
|
||||
results_json = [message.to_openai_dict_search_results() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def date_search(self, start_date, end_date, count=None, start=None):
|
||||
results = self.storage.query_date(start_date, end_date, count, start)
|
||||
results_json = [message.to_openai_dict() for message in results]
|
||||
results_json = [message.to_openai_dict_search_results() for message in results]
|
||||
return results_json, len(results)
|
||||
|
||||
def __repr__(self) -> str:
|
||||
|
||||
@@ -216,7 +216,7 @@ class SyncServer(LockingServer):
|
||||
# Update storage URI to match passed in settings
|
||||
# TODO: very hack, fix in the future
|
||||
for memory_type in ("archival", "recall", "metadata"):
|
||||
if settings.memgpt_pg_uri:
|
||||
if settings.memgpt_pg_uri_no_default:
|
||||
# override with env
|
||||
setattr(self.config, f"{memory_type}_storage_uri", settings.memgpt_pg_uri)
|
||||
self.config.save()
|
||||
|
||||
@@ -7,16 +7,27 @@ class Settings(BaseSettings):
|
||||
model_config = SettingsConfigDict(env_prefix="memgpt_")
|
||||
|
||||
server_pass: Optional[str] = None
|
||||
pg_db: Optional[str] = "memgpt"
|
||||
pg_user: Optional[str] = "memgpt"
|
||||
pg_password: Optional[str] = "memgpt"
|
||||
pg_host: Optional[str] = "localhost"
|
||||
pg_port: Optional[int] = 5432
|
||||
pg_db: Optional[str] = None
|
||||
pg_user: Optional[str] = None
|
||||
pg_password: Optional[str] = None
|
||||
pg_host: Optional[str] = None
|
||||
pg_port: Optional[int] = None
|
||||
pg_uri: Optional[str] = None # option to specifiy full uri
|
||||
cors_origins: Optional[list] = ["http://memgpt.localhost", "http://localhost:8283", "http://localhost:8083"]
|
||||
|
||||
@property
|
||||
def memgpt_pg_uri(self) -> str:
|
||||
if self.pg_uri:
|
||||
return self.pg_uri
|
||||
elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port:
|
||||
return f"postgresql+pg8000://{self.pg_user}:{self.pg_password}@{self.pg_host}:{self.pg_port}/{self.pg_db}"
|
||||
else:
|
||||
return f"postgresql+pg8000://memgpt:memgpt@localhost:5432/memgpt"
|
||||
|
||||
# add this property to avoid being returned the default
|
||||
# reference: https://github.com/cpacker/MemGPT/issues/1362
|
||||
@property
|
||||
def memgpt_pt_uri_no_default(self) -> str:
|
||||
if self.pg_uri:
|
||||
return self.pg_uri
|
||||
elif self.pg_db and self.pg_user and self.pg_password and self.pg_host and self.pg_port:
|
||||
|
||||
Reference in New Issue
Block a user