feat: refactor CoreMemory to support generalized memory fields and memory editing functions (#1479)

Co-authored-by: cpacker <packercharles@gmail.com>
Co-authored-by: Maximilian-Winter <maximilian.winter.91@gmail.com>
This commit is contained in:
Sarah Wooders
2024-07-01 11:50:57 -07:00
committed by GitHub
parent 545bc5348e
commit 9b15cbef39
32 changed files with 805 additions and 1243 deletions

View File

@@ -56,7 +56,7 @@ jobs:
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
run: |
pipx install poetry==1.8.2
poetry install -E dev
poetry install -E dev -E postgres
poetry run pytest -s tests/test_client.py
poetry run pytest -s tests/test_concurrent_connections.py

View File

@@ -3,7 +3,6 @@ import inspect
import json
import traceback
import uuid
from pathlib import Path
from typing import List, Optional, Tuple, Union, cast
from tqdm import tqdm
@@ -11,8 +10,6 @@ from tqdm import tqdm
from memgpt.agent_store.storage import StorageConnector
from memgpt.constants import (
CLI_WARNING_PREFIX,
CORE_MEMORY_HUMAN_CHAR_LIMIT,
CORE_MEMORY_PERSONA_CHAR_LIMIT,
FIRST_MESSAGE_ATTEMPTS,
JSON_ENSURE_ASCII,
JSON_LOADS_STRICT,
@@ -24,9 +21,7 @@ from memgpt.constants import (
from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage
from memgpt.interface import AgentInterface
from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
from memgpt.memory import ArchivalMemory
from memgpt.memory import CoreMemory as InContextMemory
from memgpt.memory import RecallMemory, summarize_messages
from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages
from memgpt.metadata import MetadataStore
from memgpt.models import chat_completion_response
from memgpt.models.pydantic_models import ToolModel
@@ -41,7 +36,6 @@ from memgpt.utils import (
count_tokens,
create_uuid_from_string,
get_local_time,
get_schema_diff,
get_tool_call_id,
get_utc_time,
is_utc_datetime,
@@ -53,81 +47,17 @@ from memgpt.utils import (
)
from .errors import LLMError
from .functions.functions import USER_FUNCTIONS_DIR, load_all_function_sets
def link_functions(function_schemas: list):
    """Resolve a list of function schemas against the loaded function library.

    agent.state.functions stores only OpenAI-style schemas
    ([{'name': ..., 'description': ...}, ...]); the matching Python callables
    live in the function library. Returns a mapping of:
        function_name -> {"json_schema": ..., "python_function": ...}

    A schema whose name is absent from the library is skipped with a warning;
    a schema that differs from the library copy is still linked (warning via
    printd) so that old configs keep working instead of erroring out.
    """
    library = load_all_function_sets()
    linked = {}
    for schema in function_schemas:
        name = schema.get("name")
        if name is None:
            raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{schema}")
        entry = library.get(name)
        if entry is None:
            # Not fatal: warn and skip so agents with stale function lists still load.
            print(
                f"Function '{name}' was specified in agent.state.functions, but is not in function library:\n{library.keys()}"
            )
            continue
        # Compare the stored schema against the library's copy via canonical JSON.
        stored = json.dumps(schema, ensure_ascii=JSON_ENSURE_ASCII)
        current = json.dumps(entry["json_schema"], ensure_ascii=JSON_ENSURE_ASCII)
        if stored != current:
            # NOTE: to tolerate old configs we warn instead of raising here.
            warning = (
                f"Found matching function '{name}' from agent.state.functions inside function library, but schemas are different.\n"
                + "".join(get_schema_diff(schema, entry["json_schema"]))
            )
            printd(warning)
        linked[name] = entry
    return linked
def initialize_memory(ai_notes: Union[str, None], human_notes: Union[str, None]):
    """Build the agent's in-context (core) memory from persona and human notes.

    Args:
        ai_notes: Persona text for the agent. Must not be None.
        human_notes: Notes describing the human user. Must not be None.

    Returns:
        InContextMemory populated with the persona and human sections, using
        the configured core-memory character limits.

    Raises:
        ValueError: If either ai_notes or human_notes is None.
    """
    # Previously these raised ValueError(ai_notes) / ValueError(human_notes),
    # which produced the unhelpful message "ValueError: None".
    if ai_notes is None:
        raise ValueError("ai_notes (persona) must not be None when initializing core memory")
    if human_notes is None:
        raise ValueError("human_notes must not be None when initializing core memory")
    memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
    memory.edit_persona(ai_notes)
    memory.edit_human(human_notes)
    return memory
def construct_system_with_memory(
system: str,
memory: InContextMemory,
memory: BaseMemory,
memory_edit_timestamp: str,
archival_memory: Optional[ArchivalMemory] = None,
recall_memory: Optional[RecallMemory] = None,
include_char_count: bool = True,
):
# TODO: modify this to be generalized
full_system_message = "\n".join(
[
system,
@@ -136,12 +66,13 @@ def construct_system_with_memory(
f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
f'<persona characters="{len(memory.persona)}/{memory.persona_char_limit}">' if include_char_count else "<persona>",
memory.persona,
"</persona>",
f'<human characters="{len(memory.human)}/{memory.human_char_limit}">' if include_char_count else "<human>",
memory.human,
"</human>",
str(memory),
# f'<persona characters="{len(memory.persona)}/{memory.persona_char_limit}">' if include_char_count else "<persona>",
# memory.persona,
# "</persona>",
# f'<human characters="{len(memory.human)}/{memory.human_char_limit}">' if include_char_count else "<human>",
# memory.human,
# "</human>",
]
)
return full_system_message
@@ -150,7 +81,7 @@ def construct_system_with_memory(
def initialize_message_sequence(
model: str,
system: str,
memory: InContextMemory,
memory: BaseMemory,
archival_memory: Optional[ArchivalMemory] = None,
recall_memory: Optional[RecallMemory] = None,
memory_edit_timestamp: Optional[str] = None,
@@ -195,13 +126,14 @@ class Agent(object):
# agents can be created from providing agent_state
agent_state: AgentState,
tools: List[ToolModel],
# memory: BaseMemory,
# extras
messages_total: Optional[int] = None, # TODO remove?
first_message_verify_mono: bool = True, # TODO move to config?
):
# tools
for tool in tools:
assert tool, f"Tool is None - must be error in querying tool from DB"
assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools"
for tool_name in agent_state.tools:
assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
@@ -230,13 +162,8 @@ class Agent(object):
self.system = self.agent_state.system
# Initialize the memory object
# TODO: support more general memory types
if "persona" not in self.agent_state.state: # TODO: remove
raise ValueError(f"'persona' not found in provided AgentState")
if "human" not in self.agent_state.state: # TODO: remove
raise ValueError(f"'human' not found in provided AgentState")
self.memory = initialize_memory(ai_notes=self.agent_state.state["persona"], human_notes=self.agent_state.state["human"])
printd("INITIALIZED MEMORY", self.memory.persona, self.memory.human)
self.memory = BaseMemory.load(self.agent_state.state["memory"])
printd("Initialized memory object", self.memory)
# Interface must implement:
# - internal_monologue
@@ -285,7 +212,7 @@ class Agent(object):
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
else:
# print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
printd(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
init_messages = initialize_message_sequence(
self.model,
self.system,
@@ -311,12 +238,10 @@ class Agent(object):
# Keep track of the total number of messages throughout all time
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
# self.messages_total_init = self.messages_total
self.messages_total_init = len(self._messages) - 1
printd(f"Agent initialized, self.messages_total={self.messages_total}")
# Create the agent in the DB
# self.save()
self.update_state()
@property
@@ -609,6 +534,10 @@ class Agent(object):
heartbeat_request = False
function_failed = False
# rebuild memory
# TODO: @charles please check this
self.rebuild_memory()
return messages, heartbeat_request, function_failed
def step(
@@ -770,6 +699,10 @@ class Agent(object):
self._append_to_messages(all_new_messages)
messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages
# update state after each step
self.update_state()
return messages_to_return, heartbeat_request, function_failed, active_memory_warning, response.usage
except Exception as e:
@@ -913,6 +846,14 @@ class Agent(object):
def rebuild_memory(self):
"""Rebuilds the system message with the latest memory object"""
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
# NOTE: This is a hacky way to check if the memory has changed
memory_repr = str(self.memory)
if memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
printd(f"Memory has not changed, not rebuilding system")
return
# update memory (TODO: potentially update recall/archival stats seperately)
new_system_message = initialize_message_sequence(
self.model,
self.system,
@@ -922,136 +863,85 @@ class Agent(object):
)[0]
diff = united_diff(curr_system_message["content"], new_system_message["content"])
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
if len(diff) > 0: # there was a diff
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
# Swap the system message out
self._swap_system_message(
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message
# Swap the system message out (only if there is a diff)
self._swap_system_message(
Message.dict_to_message(
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message
)
)
assert self.messages[0]["content"] == new_system_message["content"], (
self.messages[0]["content"],
new_system_message["content"],
)
)
# def to_agent_state(self) -> AgentState:
# # The state may have change since the last time we wrote it
# updated_state = {
# "persona": self.memory.persona,
# "human": self.memory.human,
# "system": self.system,
# "functions": self.functions,
# "messages": [str(msg.id) for msg in self._messages],
# }
# agent_state = AgentState(
# name=self.agent_state.name,
# user_id=self.agent_state.user_id,
# persona=self.agent_state.persona,
# human=self.agent_state.human,
# llm_config=self.agent_state.llm_config,
# embedding_config=self.agent_state.embedding_config,
# preset=self.agent_state.preset,
# id=self.agent_state.id,
# created_at=self.agent_state.created_at,
# state=updated_state,
# )
# return agent_state
def add_function(self, function_name: str) -> str:
if function_name in self.functions_python.keys():
msg = f"Function {function_name} already loaded"
printd(msg)
return msg
# TODO: refactor
raise NotImplementedError
# if function_name in self.functions_python.keys():
# msg = f"Function {function_name} already loaded"
# printd(msg)
# return msg
available_functions = load_all_function_sets()
if function_name not in available_functions.keys():
raise ValueError(f"Function {function_name} not found in function library")
# available_functions = load_all_function_sets()
# if function_name not in available_functions.keys():
# raise ValueError(f"Function {function_name} not found in function library")
self.functions.append(available_functions[function_name]["json_schema"])
self.functions_python[function_name] = available_functions[function_name]["python_function"]
# self.functions.append(available_functions[function_name]["json_schema"])
# self.functions_python[function_name] = available_functions[function_name]["python_function"]
msg = f"Added function {function_name}"
# self.save()
self.update_state()
printd(msg)
return msg
# msg = f"Added function {function_name}"
## self.save()
# self.update_state()
# printd(msg)
# return msg
def remove_function(self, function_name: str) -> str:
if function_name not in self.functions_python.keys():
msg = f"Function {function_name} not loaded, ignoring"
printd(msg)
return msg
# TODO: refactor
raise NotImplementedError
# if function_name not in self.functions_python.keys():
# msg = f"Function {function_name} not loaded, ignoring"
# printd(msg)
# return msg
# only allow removal of user defined functions
user_func_path = Path(USER_FUNCTIONS_DIR)
func_path = Path(inspect.getfile(self.functions_python[function_name]))
is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts
## only allow removal of user defined functions
# user_func_path = Path(USER_FUNCTIONS_DIR)
# func_path = Path(inspect.getfile(self.functions_python[function_name]))
# is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts
if not is_subpath:
raise ValueError(f"Function {function_name} is not user defined and cannot be removed")
# if not is_subpath:
# raise ValueError(f"Function {function_name} is not user defined and cannot be removed")
self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name]
self.functions_python.pop(function_name)
# self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name]
# self.functions_python.pop(function_name)
msg = f"Removed function {function_name}"
# self.save()
self.update_state()
printd(msg)
return msg
# def save(self):
# """Save agent state locally"""
# new_agent_state = self.to_agent_state()
# # without this, even after Agent.__init__, agent.config.state["messages"] will be None
# self.agent_state = new_agent_state
# # Check if we need to create the agent
# if not self.ms.get_agent(agent_id=new_agent_state.id, user_id=new_agent_state.user_id, agent_name=new_agent_state.name):
# # print(f"Agent.save {new_agent_state.id} :: agent does not exist, creating...")
# self.ms.create_agent(agent=new_agent_state)
# # Otherwise, we should update the agent
# else:
# # print(f"Agent.save {new_agent_state.id} :: agent already exists, updating...")
# print(f"Agent.save {new_agent_state.id} :: preupdate:\n\tmessages={new_agent_state.state['messages']}")
# self.ms.update_agent(agent=new_agent_state)
# msg = f"Removed function {function_name}"
## self.save()
# self.update_state()
# printd(msg)
# return msg
def update_state(self) -> AgentState:
# updated_state = {
# "persona": self.memory.persona,
# "human": self.memory.human,
# "system": self.system,
# "functions": self.functions,
# "messages": [str(msg.id) for msg in self._messages],
# }
memory = {
"system": self.system,
"persona": self.memory.persona,
"human": self.memory.human,
"memory": self.memory.to_dict(),
"messages": [str(msg.id) for msg in self._messages], # TODO: move out into AgentState.message_ids
}
# TODO: add this field
metadata = { # TODO
"human_name": self.agent_state.persona,
"persona_name": self.agent_state.human,
}
self.agent_state = AgentState(
name=self.agent_state.name,
user_id=self.agent_state.user_id,
tools=self.agent_state.tools,
system=self.system,
persona=self.agent_state.persona, # TODO: remove (stores persona_name)
human=self.agent_state.human, # TODO: remove (stores human_name)
## "model_state"
llm_config=self.agent_state.llm_config,
embedding_config=self.agent_state.embedding_config,
preset=self.agent_state.preset,
id=self.agent_state.id,
created_at=self.agent_state.created_at,
## "agent_state"
state=memory,
_metadata=self.agent_state._metadata,
)
return self.agent_state

View File

@@ -493,9 +493,7 @@ class PostgresStorageConnector(SQLStorageConnector):
if isinstance(records[0], Passage):
with self.engine.connect() as conn:
db_records = [vars(record) for record in records]
# print("records", db_records)
stmt = insert(self.db_model.__table__).values(db_records)
# print(stmt)
if exists_ok:
upsert_stmt = stmt.on_conflict_do_update(
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
@@ -594,9 +592,7 @@ class SQLLiteStorageConnector(SQLStorageConnector):
if isinstance(records[0], Passage):
with self.engine.connect() as conn:
db_records = [vars(record) for record in records]
# print("records", db_records)
stmt = insert(self.db_model.__table__).values(db_records)
# print(stmt)
if exists_ok:
upsert_stmt = stmt.on_conflict_do_update(
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column

View File

@@ -13,16 +13,19 @@ import requests
import typer
import memgpt.utils as utils
from memgpt import create_client
from memgpt.agent import Agent, save_agent
from memgpt.cli.cli_config import configure
from memgpt.config import MemGPTConfig
from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, User
from memgpt.data_types import EmbeddingConfig, LLMConfig, User
from memgpt.log import get_logger
from memgpt.memory import ChatMemory
from memgpt.metadata import MetadataStore
from memgpt.migrate import migrate_all_agents, migrate_all_sources
from memgpt.server.constants import WS_DEFAULT_PORT
from memgpt.server.server import logger as server_logger
# from memgpt.interface import CLIInterface as interface # for printing to terminal
from memgpt.streaming_interface import (
@@ -391,7 +394,6 @@ def run(
persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None,
agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None,
human: Annotated[Optional[str], typer.Option(help="Specify human")] = None,
preset: Annotated[Optional[str], typer.Option(help="Specify preset")] = None,
# model flags
model: Annotated[Optional[str], typer.Option(help="Specify the LLM model")] = None,
model_wrapper: Annotated[Optional[str], typer.Option(help="Specify the LLM model wrapper")] = None,
@@ -427,8 +429,10 @@ def run(
if debug:
logger.setLevel(logging.DEBUG)
server_logger.setLevel(logging.DEBUG)
else:
logger.setLevel(logging.CRITICAL)
server_logger.setLevel(logging.CRITICAL)
from memgpt.migrate import (
VERSION_CUTOFF,
@@ -511,8 +515,6 @@ def run(
# read user id from config
ms = MetadataStore(config)
user = create_default_user_or_exit(config, ms)
human = human if human else config.human
persona = persona if persona else config.persona
# determine agent to use, if not provided
if not yes and not agent:
@@ -529,6 +531,8 @@ def run(
# create agent config
agent_state = ms.get_agent(agent_name=agent, user_id=user.id) if agent else None
human = human if human else config.human
persona = persona if persona else config.persona
if agent and agent_state: # use existing agent
typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN)
# agent_config = AgentConfig.load(agent)
@@ -540,14 +544,6 @@ def run(
# printd("Index path:", agent_config.save_agent_index_dir())
# persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load
# TODO: load prior agent state
if persona and persona != agent_state.persona:
typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing persona {agent_state.persona} with {persona}", fg=typer.colors.YELLOW)
agent_state.persona = persona
# raise ValueError(f"Cannot override {agent_state.name} existing persona {agent_state.persona} with {persona}")
if human and human != agent_state.human:
typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing human {agent_state.human} with {human}", fg=typer.colors.YELLOW)
agent_state.human = human
# raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")
# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
if model and model != agent_state.llm_config.model:
@@ -582,7 +578,12 @@ def run(
# Update the agent with any overrides
ms.update_agent(agent_state)
tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools]
tools = []
for tool_name in agent_state.tools:
tool = ms.get_tool(tool_name, agent_state.user_id)
if tool is None:
typer.secho(f"Couldn't find tool {tool_name} in database, please run `memgpt add tool`", fg=typer.colors.RED)
tools.append(tool)
# create agent
memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
@@ -626,45 +627,30 @@ def run(
# create agent
try:
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
client = create_client()
human_obj = ms.get_human(human, user.id)
persona_obj = ms.get_persona(persona, user.id)
if preset_obj is None:
# create preset records in metadata store
from memgpt.presets.presets import add_default_presets
add_default_presets(user.id, ms)
# try again
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
if preset_obj is None:
typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED)
sys.exit(1)
if human_obj is None:
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
if persona_obj is None:
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
# Overwrite fields in the preset if they were specified
preset_obj.human = ms.get_human(human, user.id).text
preset_obj.persona = ms.get_persona(persona, user.id).text
memory = ChatMemory(human=human_obj.text, persona=persona_obj.text)
metadata = {"human": human_obj.name, "persona": persona_obj.name}
typer.secho(f"-> 🤖 Using persona profile: '{preset_obj.persona_name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{preset_obj.human_name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
agent_state = AgentState(
# add tools
agent_state = client.create_agent(
name=agent_name,
user_id=user.id,
tools=list([schema["name"] for schema in preset_obj.functions_schema]),
system=preset_obj.system,
llm_config=llm_config,
embedding_config=embedding_config,
human=preset_obj.human,
persona=preset_obj.persona,
preset=preset_obj.name,
state={"messages": None, "persona": preset_obj.persona, "human": preset_obj.human},
llm_config=llm_config,
memory=memory,
metadata=metadata,
)
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools]
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
memgpt_agent = Agent(
interface=interface(),

View File

@@ -38,8 +38,7 @@ from memgpt.local_llm.constants import (
)
from memgpt.local_llm.utils import get_available_wrappers
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import HumanModel, PersonaModel
from memgpt.presets.presets import create_preset_from_file
from memgpt.models.pydantic_models import PersonaModel
from memgpt.server.utils import shorten_key_middle
app = typer.Typer()
@@ -1085,18 +1084,12 @@ def configure():
else:
ms.create_user(user)
# create preset records in metadata store
from memgpt.presets.presets import add_default_presets
add_default_presets(user_id, ms)
class ListChoice(str, Enum):
agents = "agents"
humans = "humans"
personas = "personas"
sources = "sources"
presets = "presets"
@app.command()
@@ -1133,7 +1126,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
elif arg == ListChoice.humans:
"""List all humans"""
table.field_names = ["Name", "Text"]
for human in client.list_humans(user_id=user_id):
for human in client.list_humans():
table.add_row([human.name, human.text.replace("\n", "")[:100]])
print(table)
elif arg == ListChoice.personas:
@@ -1170,21 +1163,6 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
)
print(table)
elif arg == ListChoice.presets:
"""List all available presets"""
table.field_names = ["Name", "Description", "Sources", "Functions"]
for preset in ms.list_presets(user_id=user_id):
sources = ms.get_preset_sources(preset_id=preset.id)
table.add_row(
[
preset.name,
preset.description,
",".join([source.name for source in sources]),
# json.dumps(preset.functions_schema, indent=4)
",\n".join([f["name"] for f in preset.functions_schema]),
]
)
print(table)
else:
raise ValueError(f"Unknown argument {arg}")
@@ -1208,7 +1186,7 @@ def add(
with open(filename, "r", encoding="utf-8") as f:
text = f.read()
if option == "persona":
persona = ms.get_persona(name=name, user_id=user_id)
persona = ms.get_persona(name=name)
if persona:
# config if user wants to overwrite
if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
@@ -1220,7 +1198,7 @@ def add(
ms.add_persona(persona)
elif option == "human":
human = client.get_human(name=name, user_id=user_id)
human = client.get_human(name=name)
if human:
# config if user wants to overwrite
if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
@@ -1228,11 +1206,7 @@ def add(
human.text = text
client.update_human(human)
else:
human = HumanModel(name=name, text=text, user_id=user_id)
client.add_human(HumanModel(name=name, text=text, user_id=user_id))
elif option == "preset":
assert filename, "Must specify filename for preset"
create_preset_from_file(filename, name, user_id, ms)
human = client.create_human(name=name, human=text)
else:
raise ValueError(f"Unknown kind {option}")
@@ -1281,18 +1255,14 @@ def delete(option: str, name: str):
ms.delete_agent(agent_id=agent.id)
elif option == "human":
human = client.get_human(name=name, user_id=user_id)
human = client.get_human(name=name)
assert human is not None, f"Human {name} does not exist"
client.delete_human(name=name, user_id=user_id)
client.delete_human(name=name)
elif option == "persona":
persona = ms.get_persona(name=name, user_id=user_id)
persona = ms.get_persona(name=name)
assert persona is not None, f"Persona {name} does not exist"
ms.delete_persona(name=name, user_id=user_id)
assert ms.get_persona(name=name, user_id=user_id) is None, f"Persona {name} still exists"
elif option == "preset":
preset = ms.get_preset(name=name, user_id=user_id)
assert preset is not None, f"Preset {name} does not exist"
ms.delete_preset(name=name, user_id=user_id)
ms.delete_persona(name=name)
assert ms.get_persona(name=name) is None, f"Persona {name} still exists"
else:
raise ValueError(f"Option {option} not implemented")

View File

@@ -43,7 +43,6 @@ class Admin:
def create_key(self, user_id: uuid.UUID, key_name: str):
payload = {"user_id": str(user_id), "key_name": key_name}
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload)
print(response.json())
if response.status_code != 200:
raise HTTPError(response.json())
return CreateAPIKeyResponse(**response.json())
@@ -53,7 +52,6 @@ class Admin:
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
if response.status_code != 200:
raise HTTPError(response.json())
print(response.text, response.status_code)
return GetAPIKeysResponse(**response.json()).api_key_list
def delete_key(self, api_key: str):
@@ -114,6 +112,11 @@ class Admin:
source_type = "python"
json_schema["name"]
if "memory" in tags:
# special modifications to memory functions
# self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
source_code = source_code.replace("self.memory", "self.memory.memory")
# create data
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
CreateToolRequest(**data) # validate

View File

@@ -6,23 +6,17 @@ from typing import Dict, List, Optional, Tuple, Union
import requests
from memgpt.config import MemGPTConfig
from memgpt.constants import BASE_TOOLS, DEFAULT_PRESET
from memgpt.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_sources.connectors import DataConnector
from memgpt.data_types import (
AgentState,
EmbeddingConfig,
LLMConfig,
Preset,
Source,
User,
)
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, Preset, Source
from memgpt.functions.functions import parse_source_code
from memgpt.functions.schema_generator import generate_schema
from memgpt.metadata import MetadataStore
from memgpt.memory import BaseMemory, ChatMemory, get_memory_functions
from memgpt.models.pydantic_models import (
HumanModel,
JobModel,
JobStatus,
LLMConfigModel,
PersonaModel,
PresetModel,
SourceModel,
@@ -32,6 +26,7 @@ from memgpt.server.rest_api.agents.command import CommandResponse
from memgpt.server.rest_api.agents.config import GetAgentResponse
from memgpt.server.rest_api.agents.index import CreateAgentResponse, ListAgentsResponse
from memgpt.server.rest_api.agents.memory import (
ArchivalMemoryObject,
GetAgentArchivalMemoryResponse,
GetAgentMemoryResponse,
InsertAgentArchivalMemoryResponse,
@@ -56,6 +51,7 @@ from memgpt.server.rest_api.sources.index import ListSourcesResponse
# import pydantic response objects from memgpt.server.rest_api
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
from memgpt.server.server import SyncServer
from memgpt.utils import get_human_text
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
@@ -242,8 +238,6 @@ class RESTClient(AbstractClient):
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/config", headers=self.headers)
print(response.text, response.status_code)
print(response)
if response.status_code == 404:
# not found error
return False
@@ -252,17 +246,24 @@ class RESTClient(AbstractClient):
else:
raise ValueError(f"Failed to check if agent exists: {response.text}")
def get_tool(self, tool_name: str):
response = requests.get(f"{self.base_url}/api/tools/{tool_name}", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to get tool: {response.text}")
return ToolModel(**response.json())
def create_agent(
self,
name: Optional[str] = None,
preset: Optional[str] = None, # TODO: this should actually be re-named preset_name
persona: Optional[str] = None,
human: Optional[str] = None,
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
# tools
tools: Optional[List[str]] = None,
include_base_tools: Optional[bool] = True,
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
) -> AgentState:
"""
Create an agent
@@ -274,7 +275,6 @@ class RESTClient(AbstractClient):
Returns:
agent_state (AgentState): State of the the created agent.
"""
if embedding_config or llm_config:
raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API")
@@ -286,14 +286,22 @@ class RESTClient(AbstractClient):
if include_base_tools:
tool_names += BASE_TOOLS
# add memory tools
memory_functions = get_memory_functions(memory)
for func_name, func in memory_functions.items():
tool = self.create_tool(func, name=func_name, tags=["memory", "memgpt-base"], update=True)
tool_names.append(tool.name)
# TODO: distinguish between name and objects
# TODO: add metadata
payload = {
"config": {
"name": name,
"preset": preset,
"persona": persona,
"human": human,
"persona": memory.memory["persona"].value,
"human": memory.memory["human"].value,
"function_names": tool_names,
"metadata": metadata,
}
}
response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers)
@@ -322,14 +330,12 @@ class RESTClient(AbstractClient):
id=response.agent_state.id,
name=response.agent_state.name,
user_id=response.agent_state.user_id,
preset=response.agent_state.preset,
persona=response.agent_state.persona,
human=response.agent_state.human,
llm_config=llm_config,
embedding_config=embedding_config,
state=response.agent_state.state,
system=response.agent_state.system,
tools=response.agent_state.tools,
_metadata=response.agent_state.metadata,
# load datetime from timestampe
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc),
)
@@ -352,25 +358,8 @@ class RESTClient(AbstractClient):
response_obj = GetAgentResponse(**response.json())
return self.get_agent_response_to_state(response_obj)
## presets
# def create_preset(self, preset: Preset) -> CreatePresetResponse:
# # TODO should the arg type here be PresetModel, not Preset?
# payload = CreatePresetsRequest(
# id=str(preset.id),
# name=preset.name,
# description=preset.description,
# system=preset.system,
# persona=preset.persona,
# human=preset.human,
# persona_name=preset.persona_name,
# human_name=preset.human_name,
# functions_schema=preset.functions_schema,
# )
# response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
# assert response.status_code == 200, f"Failed to create preset: {response.text}"
# return CreatePresetResponse(**response.json())
def get_preset(self, name: str) -> PresetModel:
    """Fetch a preset by name via the REST API."""
    # TODO: remove
    resp = requests.get(f"{self.base_url}/api/presets/{name}", headers=self.headers)
    assert resp.status_code == 200, f"Failed to get preset: {resp.text}"
    return PresetModel(**resp.json())
@@ -385,6 +374,7 @@ class RESTClient(AbstractClient):
tools: Optional[List[ToolModel]] = None,
default_tools: bool = True,
) -> PresetModel:
# TODO: remove
"""Create an agent preset
:param name: Name of the preset
@@ -406,7 +396,6 @@ class RESTClient(AbstractClient):
schema = []
if tools:
for tool in tools:
print("CUSOTM TOOL", tool.json_schema)
schema.append(tool.json_schema)
# include default tools
@@ -427,9 +416,6 @@ class RESTClient(AbstractClient):
human_name=human_name,
functions_schema=schema,
)
print(schema)
print(human_name, persona_name, system_name, name)
print(payload.model_dump())
response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
assert response.status_code == 200, f"Failed to create preset: {response.text}"
return CreatePresetResponse(**response.json()).preset
@@ -482,7 +468,6 @@ class RESTClient(AbstractClient):
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", json={"content": memory}, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to insert archival memory: {response.text}")
print(response.json())
return InsertAgentArchivalMemoryResponse(**response.json())
def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID):
@@ -518,8 +503,6 @@ class RESTClient(AbstractClient):
response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create human: {response.text}")
print(response.json())
return HumanModel(**response.json())
def list_personas(self) -> ListPersonasResponse:
@@ -531,9 +514,24 @@ class RESTClient(AbstractClient):
response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to create persona: {response.text}")
print(response.json())
return PersonaModel(**response.json())
def get_persona(self, name: str) -> PersonaModel:
    """Look up a persona by name.

    Returns None when the persona does not exist (HTTP 404); raises
    ValueError on any other non-200 response.
    """
    resp = requests.get(f"{self.base_url}/api/personas/{name}", headers=self.headers)
    if resp.status_code == 200:
        return PersonaModel(**resp.json())
    if resp.status_code == 404:
        # not-found is an expected outcome, not an error
        return None
    raise ValueError(f"Failed to get persona: {resp.text}")
def get_human(self, name: str) -> HumanModel:
    """Look up a human profile by name.

    Returns None when the human does not exist (HTTP 404); raises
    ValueError on any other non-200 response.
    """
    resp = requests.get(f"{self.base_url}/api/humans/{name}", headers=self.headers)
    if resp.status_code == 200:
        return HumanModel(**resp.json())
    if resp.status_code == 404:
        # not-found is an expected outcome, not an error
        return None
    raise ValueError(f"Failed to get human: {resp.text}")
# sources
def list_sources(self):
@@ -638,7 +636,7 @@ class RESTClient(AbstractClient):
json_schema["name"]
# create data
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema, "update": update}
try:
CreateToolRequest(**data) # validate data
except Exception as e:
@@ -694,23 +692,12 @@ class LocalClient(AbstractClient):
else:
self.user_id = uuid.UUID(config.anon_clientid)
# create user if does not exist
ms = MetadataStore(config)
self.user = User(id=self.user_id)
if ms.get_user(self.user_id):
# update user
ms.update_user(self.user)
else:
ms.create_user(self.user)
# create preset records in metadata store
from memgpt.presets.presets import add_default_presets
add_default_presets(self.user_id, ms)
self.interface = QueuingInterface(debug=debug)
self.server = SyncServer(default_interface_factory=lambda: self.interface)
# create user if does not exist
self.server.create_user({"id": self.user_id}, exists_ok=True)
# messages
def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse:
self.interface.clear()
@@ -740,14 +727,16 @@ class LocalClient(AbstractClient):
def create_agent(
self,
name: Optional[str] = None,
preset: Optional[str] = None, # TODO: this should actually be re-named preset_name
persona: Optional[str] = None,
human: Optional[str] = None,
# model configs
embedding_config: Optional[EmbeddingConfig] = None,
llm_config: Optional[LLMConfig] = None,
# memory
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
# tools
tools: Optional[List[str]] = None,
include_base_tools: Optional[bool] = True,
# metadata
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
) -> AgentState:
if name and self.agent_exists(agent_name=name):
raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})")
@@ -759,23 +748,35 @@ class LocalClient(AbstractClient):
if include_base_tools:
tool_names += BASE_TOOLS
# add memory tools
memory_functions = get_memory_functions(memory)
for func_name, func in memory_functions.items():
tool = self.create_tool(func, name=func_name, tags=["memory", "memgpt-base"])
tool_names.append(tool.name)
self.interface.clear()
# create agent
agent_state = self.server.create_agent(
user_id=self.user_id,
name=name,
preset=preset,
persona=persona,
human=human,
memory=memory,
llm_config=llm_config,
embedding_config=embedding_config,
tools=tool_names,
metadata=metadata,
)
return agent_state
def rename_agent(self, agent_id: uuid.UUID, new_name: str):
# TODO: check valid name
agent_state = self.server.rename_agent(user_id=self.user_id, agent_id=agent_id, new_agent_name=new_name)
return agent_state
def delete_agent(self, agent_id: uuid.UUID):
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
def get_agent_config(self, agent_id: str) -> AgentState:
def get_agent(self, agent_id: uuid.UUID) -> AgentState:
self.interface.clear()
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)
@@ -803,6 +804,19 @@ class LocalClient(AbstractClient):
# agent interactions
def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse:
    """Send a message to an agent and return the resulting response messages.

    Args:
        agent_id: ID of the target agent.
        message: Message text to deliver.
        role: Either "system" or "user"; any other role raises ValueError.
        stream: Unused here (streaming not supported by the local client path).

    Returns:
        UserMessageResponse: The interface's accumulated messages plus usage stats.

    Raises:
        ValueError: If `role` is not "system" or "user".
    """
    self.interface.clear()
    if role == "system":
        usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message)
    elif role == "user":
        usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
    else:
        raise ValueError(f"Role {role} not supported")
    # BUG FIX: previously the response was only returned when auto_save was
    # disabled (it sat in an `else` branch), so callers with auto_save=True
    # got None despite the annotated return type. Save first, then always return.
    if self.auto_save:
        self.save()
    return UserMessageResponse(messages=self.interface.to_list(), usage=usage)
def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]:
self.interface.clear()
usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
@@ -820,36 +834,37 @@ class LocalClient(AbstractClient):
# archival memory
def get_agent_archival_memory(
    self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
):
    """Page through an agent's archival memory and return the raw JSON records."""
    cursor_result = self.server.get_agent_archival_cursor(
        user_id=self.user_id, agent_id=agent_id, after=after, before=before, limit=limit
    )
    # first element of the tuple (the next-page cursor) is not surfaced here
    return cursor_result[1]
# messages
# humans / personas
def list_humans(self, user_id: uuid.UUID):
return self.server.list_humans(user_id=user_id if user_id else self.user_id)
def create_human(self, name: str, human: str):
return self.server.add_human(HumanModel(name=name, text=human, user_id=self.user_id))
def get_human(self, name: str, user_id: uuid.UUID):
return self.server.get_human(name=name, user_id=user_id)
def create_persona(self, name: str, persona: str):
return self.server.add_persona(PersonaModel(name=name, text=persona, user_id=self.user_id))
def add_human(self, human: HumanModel):
return self.server.add_human(human=human)
def list_humans(self):
return self.server.list_humans(user_id=self.user_id if self.user_id else self.user_id)
def get_human(self, name: str):
return self.server.get_human(name=name, user_id=self.user_id)
def update_human(self, human: HumanModel):
return self.server.update_human(human=human)
def delete_human(self, name: str, user_id: uuid.UUID):
return self.server.delete_human(name, user_id)
def delete_human(self, name: str):
return self.server.delete_human(name, self.user_id)
def list_personas(self):
return self.server.list_personas(user_id=self.user_id)
def get_persona(self, name: str):
return self.server.get_persona(name=name, user_id=self.user_id)
def update_persona(self, persona: PersonaModel):
return self.server.update_persona(persona=persona)
def delete_persona(self, name: str):
return self.server.delete_persona(name, self.user_id)
# tools
def create_tool(
@@ -879,6 +894,11 @@ class LocalClient(AbstractClient):
source_type = "python"
tool_name = json_schema["name"]
if "memory" in tags:
# special modifications to memory functions
# self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
source_code = source_code.replace("self.memory", "self.memory.memory")
# check if already exists:
existing_tool = self.server.ms.get_tool(tool_name, self.user_id)
if existing_tool:
@@ -924,3 +944,48 @@ class LocalClient(AbstractClient):
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)
def get_agent_archival_memory(
    self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
):
    """Fetch a page of an agent's archival memory, wrapped in response models."""
    self.interface.clear()
    # TODO need to add support for non-postgres here
    # chroma will throw:
    # raise ValueError("Cannot run get_all_cursor with chroma")
    cursor_result = self.server.get_agent_archival_cursor(
        user_id=self.user_id, agent_id=agent_id, after=after, before=before, limit=limit
    )
    records = cursor_result[1]
    memories = []
    for passage in records:
        memories.append(ArchivalMemoryObject(id=passage["id"], contents=passage["text"]))
    return GetAgentArchivalMemoryResponse(archival_memory=memories)
def insert_archival_memory(self, agent_id: uuid.UUID, memory: str) -> InsertAgentArchivalMemoryResponse:
    """Insert content into an agent's archival memory.

    Args:
        agent_id: ID of the target agent.
        memory: Text content to store.

    Returns:
        InsertAgentArchivalMemoryResponse: IDs of the created passages.
        (The previous annotation of GetAgentArchivalMemoryResponse was wrong —
        this method has always returned InsertAgentArchivalMemoryResponse.)
    """
    memory_ids = self.server.insert_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_contents=memory)
    return InsertAgentArchivalMemoryResponse(ids=memory_ids)
def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID):
    """Delete a single archival memory passage from an agent (delegates to the server)."""
    self.server.delete_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_id=memory_id)
def get_messages(
    self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
) -> GetAgentMessagesResponse:
    """Return a page of an agent's message history, newest first.

    Args:
        agent_id: ID of the target agent.
        before: Cursor — only return messages before this message ID.
        after: Cursor — only return messages after this message ID.
        limit: Maximum number of messages to return.
    """
    self.interface.clear()
    # BUG FIX: `after` was accepted but never forwarded, so callers paginating
    # forward silently got the same page. Forward it alongside `before`.
    # NOTE(review): assumes get_agent_recall_cursor accepts `after` like
    # get_agent_archival_cursor does — confirm against server implementation.
    [_, messages] = self.server.get_agent_recall_cursor(
        user_id=self.user_id, agent_id=agent_id, before=before, after=after, limit=limit, reverse=True
    )
    return GetAgentMessagesResponse(messages=messages)
def list_models(self) -> ListModelsResponse:
    """Report the single LLM configuration the local server is running with."""
    cfg = self.server.server_llm_config
    model_entry = LLMConfigModel(
        model=cfg.model,
        model_endpoint=cfg.model_endpoint,
        model_endpoint_type=cfg.model_endpoint_type,
        model_wrapper=cfg.model_wrapper,
        context_window=cfg.context_window,
    )
    return ListModelsResponse(models=[model_entry])

View File

@@ -38,7 +38,7 @@ class MemGPTConfig:
anon_clientid: str = str(uuid.UUID(int=0))
# preset
preset: str = DEFAULT_PRESET
preset: str = DEFAULT_PRESET # TODO: rename to system prompt
# persona parameters
persona: str = DEFAULT_PERSONA

View File

@@ -24,8 +24,6 @@ DEFAULT_PRESET = "memgpt_chat"
# Tools
BASE_TOOLS = [
"send_message",
"core_memory_replace",
"core_memory_append",
"pause_heartbeats",
"conversation_search",
"conversation_search_date",

View File

@@ -763,22 +763,14 @@ class AgentState:
# system prompt
system: str,
# config
persona: str, # the filename where the persona was originally sourced from # TODO: remove
human: str, # the filename where the human was originally sourced from # TODO: remove
llm_config: LLMConfig,
embedding_config: EmbeddingConfig,
preset: str, # TODO: remove
# (in-context) state contains:
# persona: str # the current persona text
# human: str # the current human text
# system: str, # system prompt (not required if initializing with a preset)
# functions: dict, # schema definitions ONLY (function code linked at runtime)
# messages: List[dict], # in-context messages
id: Optional[uuid.UUID] = None,
state: Optional[dict] = None,
created_at: Optional[datetime] = None,
# messages (TODO: implement this)
# _metadata: Optional[dict] = None,
_metadata: Optional[dict] = None,
):
if id is None:
self.id = uuid.uuid4()
@@ -792,11 +784,8 @@ class AgentState:
self.name = name
assert self.name, f"AgentState name must be a non-empty string"
self.user_id = user_id
self.preset = preset
# The INITIAL values of the persona and human
# The values inside self.state['persona'], self.state['human'] are the CURRENT values
self.persona = persona
self.human = human
self.llm_config = llm_config
self.embedding_config = embedding_config
@@ -811,9 +800,10 @@ class AgentState:
# system
self.system = system
assert self.system is not None, f"Must provide system prompt, cannot be None"
# metadata
# self._metadata = _metadata
self._metadata = _metadata
class Source:
@@ -866,6 +856,7 @@ class Token:
class Preset(BaseModel):
# TODO: remove Preset
name: str = Field(..., description="The name of the preset.")
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")

View File

@@ -56,39 +56,6 @@ def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
pause_heartbeats.__doc__ = pause_heartbeats_docstring
# NOTE(review): agent-function docstrings appear to be parsed into the JSON
# schema shown to the LLM (see load_function_set usage) — treat the docstring
# below as a user-facing tool description, not just internal docs. TODO confirm.
def core_memory_append(self: Agent, name: str, content: str) -> Optional[str]:
    """
    Append to the contents of core memory.

    Args:
        name (str): Section of the memory to be edited (persona or human).
        content (str): Content to write to the memory. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: None is always returned as this function does not produce a response.
    """
    # Delegate the edit to the agent's core memory, then rebuild the system
    # message so the in-context view reflects the change immediately.
    self.memory.edit_append(name, content)
    self.rebuild_memory()
    return None
# NOTE(review): as with core_memory_append, this docstring doubles as the
# LLM-facing tool description — edit its wording with care.
def core_memory_replace(self: Agent, name: str, old_content: str, new_content: str) -> Optional[str]:
    """
    Replace the contents of core memory. To delete memories, use an empty string for new_content.

    Args:
        name (str): Section of the memory to be edited (persona or human).
        old_content (str): String to replace. Must be an exact match.
        new_content (str): Content to write to the memory. All unicode (including emojis) are supported.

    Returns:
        Optional[str]: None is always returned as this function does not produce a response.
    """
    # edit_replace raises if old_content is empty or not found; rebuild only
    # runs after a successful edit.
    self.memory.edit_replace(name, old_content, new_content)
    self.rebuild_memory()
    return None
def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
"""
Search prior conversation history using case-insensitive string matching.

View File

@@ -2,7 +2,6 @@ import importlib
import inspect
import os
import sys
import warnings
from textwrap import dedent # remove indentation
from types import ModuleType
@@ -104,97 +103,3 @@ def load_function_file(filepath: str) -> dict:
# load all functions in the module
function_dict = load_function_set(module)
return function_dict
def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict:
    """Discover and load every function set available to MemGPT.

    Scans two locations: the built-in ``function_sets`` package directory next
    to this module, and the user directory ``USER_FUNCTIONS_DIR`` (created if
    missing). Modules that fail to import are skipped with a warning rather
    than aborting the whole scan.

    Args:
        merge: If True, flatten all sets into one name -> info dict;
            otherwise return a nested dict keyed by set name.
        ignore_duplicates: When merging, warn (True) or raise (False) on a
            function name that appears in more than one set.

    Returns:
        dict: Either the merged flat mapping or the nested per-set mapping.
    """
    from memgpt.utils import printd

    # functions/examples/*.py
    scripts_dir = os.path.dirname(os.path.abspath(__file__))  # Get the directory of the current script
    function_sets_dir = os.path.join(scripts_dir, "function_sets")  # Path to the function_sets directory
    # List all .py files in the directory (excluding __init__.py)
    example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"]

    # ~/.memgpt/functions/*.py
    # create if missing
    if not os.path.exists(USER_FUNCTIONS_DIR):
        os.makedirs(USER_FUNCTIONS_DIR)
    user_module_files = [f for f in os.listdir(USER_FUNCTIONS_DIR) if f.endswith(".py") and f != "__init__.py"]

    # combine them both (pull from both examples and user-provided)
    # all_module_files = example_module_files + user_module_files

    # Add user_scripts_dir to sys.path so user modules can import their own helpers
    if USER_FUNCTIONS_DIR not in sys.path:
        sys.path.append(USER_FUNCTIONS_DIR)

    schemas_and_functions = {}
    for dir_path, module_files in [(function_sets_dir, example_module_files), (USER_FUNCTIONS_DIR, user_module_files)]:
        for file in module_files:
            tags = []
            module_name = file[:-3]  # Remove '.py' from filename
            if dir_path == USER_FUNCTIONS_DIR:
                # For user scripts, adjust the module name appropriately:
                # load by file path since the user dir is not an importable package
                module_full_path = os.path.join(dir_path, file)
                printd(f"Loading user function set from '{module_full_path}'")
                try:
                    spec = importlib.util.spec_from_file_location(module_name, module_full_path)
                    module = importlib.util.module_from_spec(spec)
                    spec.loader.exec_module(module)
                except ModuleNotFoundError as e:
                    # Handle missing module imports
                    missing_package = str(e).split("'")[1]  # Extract the name of the missing package
                    printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
                    printd(
                        f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT."
                    )
                    continue
                except SyntaxError as e:
                    # Handle syntax errors in the module
                    printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}")
                    continue
                except Exception as e:
                    # Handle other general exceptions
                    printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}")
                    continue
            else:
                # For built-in scripts, use the existing method (normal package import)
                full_module_name = f"memgpt.functions.function_sets.{module_name}"
                # built-in sets get a "memgpt-<set>" tag; user sets get no tags
                tags.append(f"memgpt-{module_name}")
                try:
                    module = importlib.import_module(full_module_name)
                except Exception as e:
                    # Handle other general exceptions
                    printd(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}")
                    continue

            try:
                # Load the function set
                function_set = load_function_set(module)
                # Add the metadata tags
                for k, v in function_set.items():
                    # print(function_set)
                    v["tags"] = tags
                schemas_and_functions[module_name] = function_set
            except ValueError as e:
                err = f"Error loading function set '{module_name}': {e}"
                printd(err)
                warnings.warn(err)

    if merge:
        # Put all functions from all sets into the same level dict
        merged_functions = {}
        for set_name, function_set in schemas_and_functions.items():
            for function_name, function_info in function_set.items():
                if function_name in merged_functions:
                    err_msg = f"Duplicate function name '{function_name}' found in function set '{set_name}'"
                    if ignore_duplicates:
                        warnings.warn(err_msg, category=UserWarning, stacklevel=2)
                    else:
                        raise ValueError(err_msg)
                else:
                    # first definition wins; later duplicates are dropped
                    merged_functions[function_name] = function_info
        return merged_functions
    else:
        # Nested dict where the top level is organized by the function set name
        return schemas_and_functions

View File

@@ -3,6 +3,8 @@ import uuid
from abc import ABC, abstractmethod
from typing import List, Optional, Tuple, Union
from pydantic import BaseModel, validator
from memgpt.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
from memgpt.data_types import AgentState, Message, Passage
from memgpt.embeddings import embedding_model, parse_and_chunk_text, query_embedding
@@ -16,96 +18,214 @@ from memgpt.utils import (
validate_date_format,
)
# from llama_index import Document
# from llama_index.node_parser import SimpleNodeParser
class MemoryModule(BaseModel):
    """Base class for memory modules.

    A single named section of core memory (e.g. "persona" or "human") with a
    character budget enforced on every assignment to ``value``.
    """

    # Optional human-readable description of what this section holds
    description: Optional[str] = None
    # Maximum number of characters `value` may hold (see check_value_length)
    limit: int = 2000
    # The section contents: a single string, or a list of strings whose
    # lengths are summed against `limit`
    value: Optional[Union[List[str], str]] = None

    def __setattr__(self, name, value):
        """Run validation if self.value is updated"""
        super().__setattr__(name, value)
        if name == "value":
            # run validation — pydantic v1 does not re-validate on attribute
            # assignment by default, so trigger it manually here
            self.__class__.validate(self.dict(exclude_unset=True))

    @validator("value", always=True)
    def check_value_length(cls, v, values):
        # Enforce the character budget; `values` holds previously-validated
        # fields (pydantic v1 validates fields in declaration order)
        if v is not None:
            # Fetching the limit from the values dictionary
            limit = values.get("limit", 2000)  # Default to 2000 if limit is not yet set

            # Check if the value exceeds the limit
            if isinstance(v, str):
                length = len(v)
            elif isinstance(v, list):
                # NOTE(review): assumes list items are strings — a non-str item
                # would raise TypeError in len(); confirm callers only pass str lists
                length = sum(len(item) for item in v)
            else:
                raise ValueError("Value must be either a string or a list of strings.")

            if length > limit:
                error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
                # TODO: add archival memory error?
                raise ValueError(error_msg)
        return v

    def __len__(self):
        # Length of the rendered string form (list values are comma-joined)
        return len(str(self))

    def __str__(self) -> str:
        # Render the contents for in-context display; None renders as ""
        if isinstance(self.value, list):
            return ",".join(self.value)
        elif isinstance(self.value, str):
            return self.value
        else:
            return ""
class CoreMemory(object):
"""Held in-context inside the system message
class BaseMemory:
Core Memory: Refers to the system block, which provides essential, foundational context to the AI.
This includes the persona information, essential user details,
and any other baseline data you deem necessary for the AI's basic functioning.
"""
def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True):
self.persona = persona
self.human = human
self.persona_char_limit = persona_char_limit
self.human_char_limit = human_char_limit
# affects the error message the AI will see on overflow inserts
self.archival_memory_exists = archival_memory_exists
def __repr__(self) -> str:
return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
def to_dict(self):
return {
"persona": self.persona,
"human": self.human,
}
def __init__(self):
self.memory = {}
@classmethod
def load(cls, state):
return cls(state["persona"], state["human"])
def load(cls, state: dict):
"""Load memory from dictionary object"""
obj = cls()
for key, value in state.items():
obj.memory[key] = MemoryModule(**value)
return obj
def edit_persona(self, new_persona):
if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})."
if self.archival_memory_exists:
error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
raise ValueError(error_msg)
def __str__(self) -> str:
"""Representation of the memory in-context"""
section_strs = []
for section, module in self.memory.items():
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
return "\n".join(section_strs)
self.persona = new_persona
return len(self.persona)
def to_dict(self):
"""Convert to dictionary representation"""
return {key: value.dict() for key, value in self.memory.items()}
def edit_human(self, new_human):
if self.human_char_limit and len(new_human) > self.human_char_limit:
error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})."
if self.archival_memory_exists:
error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
raise ValueError(error_msg)
self.human = new_human
return len(self.human)
class ChatMemory(BaseMemory):
def edit(self, field, content):
if field == "persona":
return self.edit_persona(content)
elif field == "human":
return self.edit_human(content)
else:
raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
def __init__(self, persona: str, human: str, limit: int = 2000):
self.memory = {
"persona": MemoryModule(name="persona", value=persona, limit=limit),
"human": MemoryModule(name="human", value=human, limit=limit),
}
def edit_append(self, field, content, sep="\n"):
if field == "persona":
new_content = self.persona + sep + content
return self.edit_persona(new_content)
elif field == "human":
new_content = self.human + sep + content
return self.edit_human(new_content)
else:
raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
def core_memory_append(self, name: str, content: str) -> Optional[str]:
"""
Append to the contents of core memory.
def edit_replace(self, field, old_content, new_content):
if len(old_content) == 0:
raise ValueError("old_content cannot be an empty string (must specify old_content to replace)")
Args:
name (str): Section of the memory to be edited (persona or human).
content (str): Content to write to the memory. All unicode (including emojis) are supported.
if field == "persona":
if old_content in self.persona:
new_persona = self.persona.replace(old_content, new_content)
return self.edit_persona(new_persona)
else:
raise ValueError("Content not found in persona (make sure to use exact string)")
elif field == "human":
if old_content in self.human:
new_human = self.human.replace(old_content, new_content)
return self.edit_human(new_human)
else:
raise ValueError("Content not found in human (make sure to use exact string)")
else:
raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.memory[name].value += "\n" + content
return None
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
"""
Replace the contents of core memory. To delete memories, use an empty string for new_content.
Args:
name (str): Section of the memory to be edited (persona or human).
old_content (str): String to replace. Must be an exact match.
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
Returns:
Optional[str]: None is always returned as this function does not produce a response.
"""
self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
return None
def get_memory_functions(cls: "BaseMemory") -> dict:
    """Collect the user-exposed memory-editing functions from a memory class.

    Args:
        cls: A BaseMemory subclass (or instance) whose public callables should
            be exposed as agent tools.

    Returns:
        dict: Mapping of function name -> callable. (The previous annotation
        of ``List[callable]`` was incorrect; this function has always
        returned a dict.)
    """
    functions = {}
    for func_name in dir(cls):
        # Skip private/dunder names and the serialization API from BaseMemory
        if func_name.startswith("_") or func_name in ["load", "to_dict"]:
            continue
        func = getattr(cls, func_name)
        if callable(func):
            functions[func_name] = func
    return functions
# class CoreMemory(object):
# """Held in-context inside the system message
#
# Core Memory: Refers to the system block, which provides essential, foundational context to the AI.
# This includes the persona information, essential user details,
# and any other baseline data you deem necessary for the AI's basic functioning.
# """
#
# def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True):
# self.persona = persona
# self.human = human
# self.persona_char_limit = persona_char_limit
# self.human_char_limit = human_char_limit
#
# # affects the error message the AI will see on overflow inserts
# self.archival_memory_exists = archival_memory_exists
#
# def __repr__(self) -> str:
# return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
#
# def to_dict(self):
# return {
# "persona": self.persona,
# "human": self.human,
# }
#
# @classmethod
# def load(cls, state):
# return cls(state["persona"], state["human"])
#
# def edit_persona(self, new_persona):
# if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
# error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})."
# if self.archival_memory_exists:
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
# raise ValueError(error_msg)
#
# self.persona = new_persona
# return len(self.persona)
#
# def edit_human(self, new_human):
# if self.human_char_limit and len(new_human) > self.human_char_limit:
# error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})."
# if self.archival_memory_exists:
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
# raise ValueError(error_msg)
#
# self.human = new_human
# return len(self.human)
#
# def edit(self, field, content):
# if field == "persona":
# return self.edit_persona(content)
# elif field == "human":
# return self.edit_human(content)
# else:
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
#
# def edit_append(self, field, content, sep="\n"):
# if field == "persona":
# new_content = self.persona + sep + content
# return self.edit_persona(new_content)
# elif field == "human":
# new_content = self.human + sep + content
# return self.edit_human(new_content)
# else:
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
#
# def edit_replace(self, field, old_content, new_content):
# if len(old_content) == 0:
# raise ValueError("old_content cannot be an empty string (must specify old_content to replace)")
#
# if field == "persona":
# if old_content in self.persona:
# new_persona = self.persona.replace(old_content, new_content)
# return self.edit_persona(new_persona)
# else:
# raise ValueError("Content not found in persona (make sure to use exact string)")
# elif field == "human":
# if old_content in self.human:
# new_human = self.human.replace(old_content, new_content)
# return self.edit_human(new_human)
# else:
# raise ValueError("Content not found in human (make sure to use exact string)")
# else:
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
def _format_summary_history(message_history: List[Message]):

View File

@@ -175,10 +175,7 @@ class AgentModel(Base):
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
user_id = Column(CommonUUID, nullable=False)
name = Column(String, nullable=False)
persona = Column(String)
human = Column(String)
system = Column(String)
preset = Column(String)
created_at = Column(DateTime(timezone=True), server_default=func.now())
# configs
@@ -187,6 +184,7 @@ class AgentModel(Base):
# state
state = Column(JSON)
_metadata = Column(JSON)
# tools
tools = Column(JSON)
@@ -199,15 +197,13 @@ class AgentModel(Base):
id=self.id,
user_id=self.user_id,
name=self.name,
persona=self.persona,
human=self.human,
preset=self.preset,
created_at=self.created_at,
llm_config=self.llm_config,
embedding_config=self.embedding_config,
state=self.state,
tools=self.tools,
system=self.system,
_metadata=self._metadata,
)
@@ -739,17 +735,21 @@ class MetadataStore:
@enforce_types
def add_human(self, human: HumanModel):
with self.session_maker() as session:
if self.get_human(human.name, human.user_id):
raise ValueError(f"Human with name {human.name} already exists for user_id {human.user_id}")
session.add(human)
session.commit()
@enforce_types
def add_persona(self, persona: PersonaModel):
with self.session_maker() as session:
if self.get_persona(persona.name, persona.user_id):
raise ValueError(f"Persona with name {persona.name} already exists for user_id {persona.user_id}")
session.add(persona)
session.commit()
@enforce_types
def add_preset(self, preset: PresetModel):
def add_preset(self, preset: PresetModel): # TODO: remove
with self.session_maker() as session:
session.add(preset)
session.commit()

View File

@@ -98,9 +98,6 @@ class AgentStateModel(BaseModel):
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
# preset information
preset: str = Field(..., description="The preset used by the agent.")
persona: str = Field(..., description="The persona used by the agent.")
human: str = Field(..., description="The human used by the agent.")
tools: List[str] = Field(..., description="The tools used by the agent.")
system: str = Field(..., description="The system prompt used by the agent.")
# functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
@@ -111,6 +108,7 @@ class AgentStateModel(BaseModel):
# agent state
state: Optional[Dict] = Field(None, description="The state of the agent.")
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
class CoreMemory(BaseModel):

View File

@@ -2,23 +2,14 @@ import importlib
import inspect
import os
import uuid
from typing import List
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
from memgpt.data_types import AgentState, Preset
from memgpt.functions.functions import load_all_function_sets, load_function_set
from memgpt.functions.functions import load_function_set
from memgpt.interface import AgentInterface
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel
from memgpt.presets.utils import load_all_presets, load_yaml_file
from memgpt.prompts import gpt_system
from memgpt.utils import (
get_human_text,
get_persona_text,
list_human_files,
list_persona_files,
printd,
)
from memgpt.presets.utils import load_all_presets
from memgpt.utils import list_human_files, list_persona_files, printd
available_presets = load_all_presets()
preset_options = list(available_presets.keys())
@@ -88,145 +79,13 @@ def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
printd(f"Human '{name}' already exists for user '{user_id}'")
continue
human = HumanModel(name=name, text=text, user_id=user_id)
print(human, user_id)
ms.add_human(human)
def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: MetadataStore) -> Preset:
preset_config = load_yaml_file(filename)
preset_system_prompt = preset_config["system_prompt"]
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)
if ms.get_preset(user_id=user_id, name=name) is not None:
printd(f"Preset '{name}' already exists for user '{user_id}'")
return ms.get_preset(user_id=user_id, name=name)
preset = Preset(
user_id=user_id,
name=name,
system=gpt_system.get_system_text(preset_system_prompt),
persona=get_persona_text(DEFAULT_PERSONA),
human=get_human_text(DEFAULT_HUMAN),
persona_name=DEFAULT_PERSONA,
human_name=DEFAULT_HUMAN,
functions_schema=functions_schema,
)
ms.create_preset(preset)
return preset
def load_preset(preset_name: str, user_id: uuid.UUID):
preset_config = available_presets[preset_name]
preset_system_prompt = preset_config["system_prompt"]
preset_function_set_names = preset_config["functions"]
functions_schema = generate_functions_json(preset_function_set_names)
preset = Preset(
user_id=user_id,
name=preset_name,
system=gpt_system.get_system_text(preset_system_prompt),
persona=get_persona_text(DEFAULT_PERSONA),
persona_name=DEFAULT_PERSONA,
human=get_human_text(DEFAULT_HUMAN),
human_name=DEFAULT_HUMAN,
functions_schema=functions_schema,
)
return preset
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
"""Add the default presets to the metadata store"""
# make sure humans/personas added
add_default_humans_and_personas(user_id=user_id, ms=ms)
# make sure base functions added
# TODO: pull from functions instead
add_default_tools(user_id=user_id, ms=ms)
# add default presets
for preset_name in preset_options:
if ms.get_preset(user_id=user_id, name=preset_name) is not None:
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
continue
preset = load_preset(preset_name, user_id)
ms.create_preset(preset)
def generate_functions_json(preset_functions: List[str]):
"""
Generate JSON schema for the functions based on what is locally available.
TODO: store function definitions in the DB, instead of locally
"""
# Available functions is a mapping from:
# function_name -> {
# json_schema: schema
# python_function: function
# }
available_functions = load_all_function_sets()
# Filter down the function set based on what the preset requested
preset_function_set = {}
for f_name in preset_functions:
if f_name not in available_functions:
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
preset_function_set[f_name] = available_functions[f_name]
assert len(preset_functions) == len(preset_function_set)
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
printd(f"Available functions:\n", list(preset_function_set.keys()))
return preset_function_set_schemas
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
def create_agent_from_preset(
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
):
"""Initialize a new agent from a preset (combination of system + function)"""
raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead")
# Input validation
if agent_state.persona is None:
raise ValueError(f"'persona' not specified in AgentState (required)")
if agent_state.human is None:
raise ValueError(f"'human' not specified in AgentState (required)")
if agent_state.preset is None:
raise ValueError(f"'preset' not specified in AgentState (required)")
if not (agent_state.state == {} or agent_state.state is None):
raise ValueError(f"'state' must be uninitialized (empty)")
assert preset is not None, "preset cannot be none"
preset_name = agent_state.preset
assert preset_name == preset.name, f"AgentState preset '{preset_name}' does not match preset name '{preset.name}'"
persona = agent_state.persona
human = agent_state.human
model = agent_state.llm_config.model
from memgpt.agent import Agent
# available_presets = load_all_presets()
# if preset_name not in available_presets:
# raise ValueError(f"Preset '{preset_name}.yaml' not found")
# preset = available_presets[preset_name]
# preset_system_prompt = preset["system_prompt"]
# preset_function_set_names = preset["functions"]
# preset_function_set_schemas = generate_functions_json(preset_function_set_names)
# Override the following in the AgentState:
# persona: str # the current persona text
# human: str # the current human text
# system: str, # system prompt (not required if initializing with a preset)
# functions: dict, # schema definitions ONLY (function code linked at runtime)
# messages: List[dict], # in-context messages
agent_state.state = {
"persona": get_persona_text(persona) if persona_is_file else persona,
"human": get_human_text(human) if human_is_file else human,
"system": preset.system,
"functions": preset.functions_schema,
"messages": None,
}
return Agent(
agent_state=agent_state,
interface=interface,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
)

View File

@@ -89,9 +89,6 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
try:
server.ms.create_user(new_user)
# initialize default presets automatically for user
server.initialize_default_presets(new_user.id)
# make sure we can retrieve the user from the DB too
new_user_ret = server.ms.get_user(new_user.id)
if new_user_ret is None:

View File

@@ -79,15 +79,13 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=agent_state.tools,
system=agent_state.system,
metadata=agent_state._metadata,
),
last_run_at=None, # TODO
sources=attached_sources,
@@ -125,9 +123,6 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,

View File

@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.constants import BASE_TOOLS
from memgpt.memory import ChatMemory
from memgpt.models.pydantic_models import (
AgentStateModel,
EmbeddingConfigModel,
@@ -69,8 +70,11 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
human = request.config["human"] if "human" in request.config else None
persona_name = request.config["persona_name"] if "persona_name" in request.config else None
persona = request.config["persona"] if "persona" in request.config else None
preset = request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
tool_names = request.config["function_names"]
metadata = request.config["metadata"] if "metadata" in request.config else {}
metadata["human"] = human_name
metadata["persona"] = persona_name
# TODO: remove this -- should be added based on create agent fields
if isinstance(tool_names, str): # TODO: fix this on clinet side?
@@ -82,56 +86,54 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
tool_names.append(name)
assert isinstance(tool_names, list), "Tool names must be a list of strings."
# TODO: eventually remove this - should support general memory at the REST endpoint
memory = ChatMemory(persona=persona, human=human)
try:
agent_state = server.create_agent(
user_id=user_id,
# **request.config
# TODO turn into a pydantic model
name=request.config["name"],
preset=preset,
persona_name=persona_name,
human_name=human_name,
persona=persona,
human=human,
memory=memory,
# persona_name=persona_name,
# human_name=human_name,
# persona=persona,
# human=human,
# llm_config=LLMConfigModel(
# model=request.config['model'],
# )
# tools
tools=tool_names,
metadata=metadata,
# function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
)
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
# TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line
# TODO: remove
preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id)
return CreateAgentResponse(
agent_state=AgentStateModel(
id=agent_state.id,
name=agent_state.name,
user_id=agent_state.user_id,
preset=agent_state.preset,
persona=agent_state.persona,
human=agent_state.human,
llm_config=llm_config,
embedding_config=embedding_config,
state=agent_state.state,
created_at=int(agent_state.created_at.timestamp()),
tools=tool_names,
system=agent_state.system,
metadata=agent_state._metadata,
),
preset=PresetModel(
name=preset.name,
id=preset.id,
user_id=preset.user_id,
description=preset.description,
created_at=preset.created_at,
system=preset.system,
persona=preset.persona,
human=preset.human,
functions_schema=preset.functions_schema,
preset=PresetModel( # TODO: remove (placeholder to avoid breaking frontend)
name="dummy_preset",
id=agent_state.id,
user_id=agent_state.user_id,
description="",
created_at=agent_state.created_at,
system=agent_state.system,
persona="",
human="",
functions_schema=[],
),
)
except Exception as e:

View File

@@ -2,7 +2,7 @@ import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import HumanModel
@@ -46,4 +46,24 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p
server.ms.add_human(new_human)
return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id)
@router.delete("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
async def delete_human(
human_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
human = server.ms.delete_human(human_name, user_id=user_id)
return human
@router.get("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
async def get_human(
human_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
human = server.ms.get_human(human_name, user_id=user_id)
if human is None:
raise HTTPException(status_code=404, detail="Human not found")
return human
return router

View File

@@ -2,7 +2,7 @@ import uuid
from functools import partial
from typing import List
from fastapi import APIRouter, Body, Depends
from fastapi import APIRouter, Body, Depends, HTTPException
from pydantic import BaseModel, Field
from memgpt.models.pydantic_models import PersonaModel
@@ -47,4 +47,24 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
server.ms.add_persona(new_persona)
return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)
@router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
async def delete_persona(
persona_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
persona = server.ms.delete_persona(persona_name, user_id=user_id)
return persona
@router.get("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
async def get_persona(
persona_name: str,
user_id: uuid.UUID = Depends(get_current_user_with_server),
):
interface.clear()
persona = server.ms.get_persona(persona_name, user_id=user_id)
if persona is None:
raise HTTPException(status_code=404, detail="Persona not found")
return persona
return router

View File

@@ -22,6 +22,7 @@ class CreateToolRequest(BaseModel):
source_code: str = Field(..., description="The source code of the function.")
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
update: Optional[bool] = Field(False, description="Update the tool if it already exists.")
class CreateToolResponse(BaseModel):
@@ -87,8 +88,10 @@ def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterfac
source_type=request.source_type,
tags=request.tags,
user_id=user_id,
exists_ok=request.update,
)
except Exception as e:
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
print(e)
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}, exists_ok={request.update}")
return router

View File

@@ -15,8 +15,6 @@ import memgpt.server.utils as server_utils
import memgpt.system as system
from memgpt.agent import Agent, save_agent
from memgpt.agent_store.storage import StorageConnector, TableType
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.cli.cli_config import get_model_options
from memgpt.config import MemGPTConfig
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
@@ -37,6 +35,7 @@ from memgpt.data_types import (
from memgpt.interface import AgentInterface # abstract
from memgpt.interface import CLIInterface # for printing to terminal
from memgpt.log import get_logger
from memgpt.memory import BaseMemory
from memgpt.metadata import MetadataStore
from memgpt.models.chat_completion_response import UsageStatistics
from memgpt.models.pydantic_models import (
@@ -44,10 +43,14 @@ from memgpt.models.pydantic_models import (
HumanModel,
MemGPTUsageStatistics,
PassageModel,
PersonaModel,
PresetModel,
SourceModel,
ToolModel,
)
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
from memgpt.prompts import gpt_system
from memgpt.utils import create_random_username
logger = get_logger(__name__)
@@ -212,7 +215,7 @@ class SyncServer(LockingServer):
# Initialize the connection to the DB
self.config = MemGPTConfig.load()
print(f"server :: loading configuration from '{self.config.config_path}'")
logger.info(f"loading configuration from '{self.config.config_path}'")
assert self.config.persona is not None, "Persona must be set in the config"
assert self.config.human is not None, "Human must be set in the config"
@@ -274,14 +277,9 @@ class SyncServer(LockingServer):
self.ms.update_user(user)
else:
self.ms.create_user(user)
presets.add_default_presets(user_id, self.ms)
# NOTE: removed, since server should be multi-user
## Create the default user
# base_user_id = uuid.UUID(self.config.anon_clientid)
# if not self.ms.get_user(user_id=base_user_id):
# base_user = User(id=base_user_id)
# self.ms.create_user(base_user)
# add global default tools
presets.add_default_tools(None, self.ms)
def save_agents(self):
"""Saves all the agents that are in the in-memory object store"""
@@ -336,7 +334,14 @@ class SyncServer(LockingServer):
# Instantiate an agent object using the state retrieved
logger.info(f"Creating an agent object")
tool_objs = [self.ms.get_tool(name, user_id) for name in agent_state.tools] # get tool objects
tool_objs = []
for name in agent_state.tools:
tool_obj = self.ms.get_tool(name, user_id)
if not tool_obj:
logger.exception(f"Tool {name} does not exist for user {user_id}")
raise ValueError(f"Tool {name} does not exist for user {user_id}")
tool_objs.append(tool_obj)
memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
# Add the agent to the in-memory store and return its reference
@@ -350,11 +355,11 @@ class SyncServer(LockingServer):
def _get_or_load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Agent:
"""Check if the agent is in-memory, then load"""
logger.info(f"Checking for agent user_id={user_id} agent_id={agent_id}")
logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}")
# TODO: consider disabling loading cached agents due to potential concurrency issues
memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
if not memgpt_agent:
logger.info(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
return memgpt_agent
@@ -426,7 +431,6 @@ class SyncServer(LockingServer):
# Get the agent object (loaded in memory)
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
# print("AGENT", memgpt_agent.agent_state.id, memgpt_agent.agent_state.user_id)
if command.lower() == "exit":
# exit not supported on server.py
@@ -657,49 +661,58 @@ class SyncServer(LockingServer):
def create_user(
self,
user_config: Optional[Union[dict, User]] = {},
exists_ok: bool = False,
):
"""Create a new user using a config"""
if not isinstance(user_config, dict):
raise ValueError(f"user_config must be provided as a dictionary")
if "id" in user_config:
existing_user = self.ms.get_user(user_id=user_config["id"])
if existing_user:
if exists_ok:
presets.add_default_humans_and_personas(existing_user.id, self.ms)
return existing_user
else:
raise ValueError(f"User with ID {existing_user.id} already exists")
user = User(
id=user_config["id"] if "id" in user_config else None,
)
self.ms.create_user(user)
logger.info(f"Created new user from config: {user}")
# add default for the user
presets.add_default_humans_and_personas(user.id, self.ms)
return user
def create_agent(
self,
user_id: uuid.UUID,
tools: List[str], # list of tool names (handles) to include
# system: str, # system prompt
memory: BaseMemory,
system: Optional[str] = None,
metadata: Optional[dict] = {}, # includes human/persona names
name: Optional[str] = None,
preset: Optional[str] = None, # TODO: remove eventually
# model config
llm_config: Optional[LLMConfig] = None,
embedding_config: Optional[EmbeddingConfig] = None,
# interface
interface: Union[AgentInterface, None] = None,
# TODO: refactor this to be a more general memory configuration
system: Optional[str] = None, # prompt value
persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
persona_name: Optional[str] = None, # TODO: remove
human_name: Optional[str] = None, # TODO: remove
) -> AgentState:
"""Create a new agent using a config"""
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
if interface is None:
# interface = self.default_interface
interface = self.default_interface_factory()
# if persistence_manager is None:
# persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config)
# system prompt (get default if None)
if system is None:
system = gpt_system.get_system_text(self.config.preset)
# create agent name
if name is None:
name = create_random_username()
@@ -708,90 +721,34 @@ class SyncServer(LockingServer):
if not user:
raise ValueError(f"cannot find user with associated client id: {user_id}")
# NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation
# TODO: fix this db dependency and remove
# self.ms.#create_agent(agent_state)
# TODO modify to do creation via preset
try:
preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id)
preset_override = False
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
# system prompt
if system is None:
system = preset_obj.system
else:
preset_obj.system = system
preset_override = True
# Overwrite fields in the preset if they were specified
if human is not None and human != preset_obj.human:
preset_override = True
preset_obj.human = human
# This is a check for a common bug where users were providing filenames instead of values
# try:
# get_human_text(human)
# raise ValueError(human)
# raise UserWarning(
# f"It looks like there is a human file named {human} - did you mean to pass the file contents to the `human` arg?"
# )
# except:
# pass
if persona is not None:
preset_override = True
preset_obj.persona = persona
# try:
# get_persona_text(persona)
# raise ValueError(persona)
# raise UserWarning(
# f"It looks like there is a persona file named {persona} - did you mean to pass the file contents to the `persona` arg?"
# )
# except:
# pass
if human_name is not None and human_name != preset_obj.human_name:
preset_override = True
preset_obj.human_name = human_name
if persona_name is not None and persona_name != preset_obj.persona_name:
preset_override = True
preset_obj.persona_name = persona_name
# model configuration
llm_config = llm_config if llm_config else self.server_llm_config
embedding_config = embedding_config if embedding_config else self.server_embedding_config
# get tools
# get tools + make sure they exist
tool_objs = []
for tool_name in tools:
tool_obj = self.ms.get_tool(tool_name, user_id=user_id)
assert tool_obj is not None, f"Tool {tool_name} does not exist"
assert tool_obj, f"Tool {tool_name} does not exist"
tool_objs.append(tool_obj)
# If the user overrode any parts of the preset, we need to create a new preset to refer back to
if preset_override:
# Change the name and uuid
preset_obj = Preset.clone(preset_obj=preset_obj)
# Then write out to the database for storage
self.ms.create_preset(preset=preset_obj)
# TODO: add metadata
agent_state = AgentState(
name=name,
user_id=user_id,
persona=preset_obj.persona_name, # TODO: remove
human=preset_obj.human_name, # TODO: remove
tools=tools, # name=id for tools
llm_config=llm_config,
embedding_config=embedding_config,
system=system,
preset=preset, # TODO: remove
state={"persona": preset_obj.persona, "human": preset_obj.human, "system": system, "messages": None},
state={"system": system, "messages": None, "memory": memory.to_dict()},
_metadata=metadata,
)
agent = Agent(
interface=interface,
agent_state=agent_state,
tools=tool_objs,
# embedding_config=embedding_config,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
)
@@ -805,10 +762,11 @@ class SyncServer(LockingServer):
logger.exception(f"Failed to delete_agent:\n{delete_e}")
raise e
# save agent
save_agent(agent, self.ms)
logger.info(f"Created new agent from config: {agent}")
# return AgentState
return agent.agent_state
def delete_agent(
@@ -872,8 +830,8 @@ class SyncServer(LockingServer):
agent_config = {
"id": agent_state.id,
"name": agent_state.name,
"human": agent_state.human,
"persona": agent_state.persona,
"human": agent_state._metadata.get("human", None),
"persona": agent_state._metadata.get("persona", None),
"created_at": agent_state.created_at.isoformat(),
}
return agent_config
@@ -905,8 +863,8 @@ class SyncServer(LockingServer):
# TODO hack for frontend, remove
# (top level .persona is persona_name, and nested memory.persona is the state)
# TODO: eventually modify this to be contained in the metadata
return_dict["persona"] = agent_state.human
return_dict["human"] = agent_state.persona
return_dict["persona"] = agent_state._metadata.get("persona", None)
return_dict["human"] = agent_state._metadata.get("human", None)
# Add information about tools
# TODO memgpt_agent should really have a field of List[ToolModel]
@@ -918,10 +876,7 @@ class SyncServer(LockingServer):
recall_memory = memgpt_agent.persistence_manager.recall_memory
archival_memory = memgpt_agent.persistence_manager.archival_memory
memory_obj = {
"core_memory": {
"persona": core_memory.persona,
"human": core_memory.human,
},
"core_memory": {section: module.value for (section, module) in core_memory.memory.items()},
"recall_memory": len(recall_memory) if recall_memory is not None else None,
"archival_memory": len(archival_memory) if archival_memory is not None else None,
}
@@ -941,12 +896,31 @@ class SyncServer(LockingServer):
# Sort agents by "last_run" in descending order, most recent first
agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True)
logger.info(f"Retrieved {len(agents_states)} agents for user {user_id}:\n{[vars(s) for s in agents_states]}")
logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}")
return {
"num_agents": len(agents_states),
"agents": agents_states_dicts,
}
def list_personas(self, user_id: uuid.UUID):
return self.ms.list_personas(user_id=user_id)
def get_persona(self, name: str, user_id: uuid.UUID):
return self.ms.get_persona(name=name, user_id=user_id)
def add_persona(self, persona: PersonaModel):
name = persona.name
user_id = persona.user_id
self.ms.add_persona(persona=persona)
persona = self.ms.get_persona(name=name, user_id=user_id)
return persona
def update_persona(self, persona: PersonaModel):
return self.ms.update_persona(persona=persona)
def delete_persona(self, name: str, user_id: uuid.UUID):
return self.ms.delete_persona(name=name, user_id=user_id)
def list_humans(self, user_id: uuid.UUID):
return self.ms.list_humans(user_id=user_id)
@@ -954,7 +928,11 @@ class SyncServer(LockingServer):
return self.ms.get_human(name=name, user_id=user_id)
def add_human(self, human: HumanModel):
return self.ms.add_human(human=human)
name = human.name
user_id = human.user_id
self.ms.add_human(human=human)
human = self.ms.get_human(name=name, user_id=user_id)
return human
def update_human(self, human: HumanModel):
return self.ms.update_human(human=human)
@@ -983,11 +961,9 @@ class SyncServer(LockingServer):
recall_memory = memgpt_agent.persistence_manager.recall_memory
archival_memory = memgpt_agent.persistence_manager.archival_memory
# NOTE
memory_obj = {
"core_memory": {
"persona": core_memory.persona,
"human": core_memory.human,
},
"core_memory": {key: value.value for key, value in core_memory.memory.items()},
"recall_memory": len(recall_memory) if recall_memory is not None else None,
"archival_memory": len(archival_memory) if archival_memory is not None else None,
}
@@ -1245,23 +1221,18 @@ class SyncServer(LockingServer):
new_core_memory = old_core_memory.copy()
modified = False
if "persona" in new_memory_contents and new_memory_contents["persona"] is not None:
new_persona = new_memory_contents["persona"]
if old_core_memory["persona"] != new_persona:
new_core_memory["persona"] = new_persona
memgpt_agent.memory.edit_persona(new_persona)
modified = True
if "human" in new_memory_contents and new_memory_contents["human"] is not None:
new_human = new_memory_contents["human"]
if old_core_memory["human"] != new_human:
new_core_memory["human"] = new_human
memgpt_agent.memory.edit_human(new_human)
for key, value in new_memory_contents.items():
if value is None:
continue
if key in old_core_memory and old_core_memory[key] != value:
memgpt_agent.memory.memory[key].value = value # update agent memory
modified = True
# If we modified the memory contents, we need to rebuild the memory block inside the system message
if modified:
memgpt_agent.rebuild_memory()
# save agent
save_agent(memgpt_agent, self.ms)
return {
"old_core_memory": old_core_memory,
@@ -1290,6 +1261,7 @@ class SyncServer(LockingServer):
logger.exception(f"Failed to update agent name with:\n{str(e)}")
raise ValueError(f"Failed to update agent name in database")
assert isinstance(memgpt_agent.agent_state.id, uuid.UUID)
return memgpt_agent.agent_state
def delete_user(self, user_id: uuid.UUID):
@@ -1504,7 +1476,7 @@ class SyncServer(LockingServer):
tool (ToolModel): Tool object
"""
name = json_schema["name"]
tool = self.ms.get_tool(name)
tool = self.ms.get_tool(name, user_id=user_id)
if tool: # check if function already exists
if exists_ok:
# update existing tool
@@ -1514,7 +1486,7 @@ class SyncServer(LockingServer):
tool.source_type = source_type
self.ms.update_tool(tool)
else:
raise ValueError(f"Tool with name {name} already exists.")
raise ValueError(f"[server] Tool with name {name} already exists.")
else:
# create new tool
tool = ToolModel(

View File

@@ -1,121 +0,0 @@
import inspect
import json
import os
import uuid
import pytest
from memgpt import constants, create_client
from memgpt.functions.functions import USER_FUNCTIONS_DIR
from memgpt.models import chat_completion_response
from memgpt.utils import assistant_function_to_tool
from tests import TEST_MEMGPT_CONFIG
from tests.utils import create_config, wipe_config
def hello_world(self) -> str:
"""Test function for agent to gain access to
Returns:
str: A message for the world
"""
return "hello, world!"
@pytest.fixture(scope="module")
def agent():
"""Create a test agent that we can call functions on"""
wipe_config()
global client
if os.getenv("OPENAI_API_KEY"):
create_config("openai")
else:
create_config("memgpt_hosted")
# create memgpt client
client = create_client()
# ensure user exists
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})
agent_state = client.create_agent()
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
@pytest.fixture(scope="module")
def hello_world_function():
with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w", encoding="utf-8") as f:
f.write(inspect.getsource(hello_world))
@pytest.fixture(scope="module")
def ai_function_call():
return chat_completion_response.Message(
**assistant_function_to_tool(
{
"role": "assistant",
"content": "I will now call hello world",
"function_call": {
"name": "hello_world",
"arguments": json.dumps({}),
},
}
)
)
return
def test_add_function_happy(agent, hello_world_function, ai_function_call):
agent.add_function("hello_world")
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
assert "hello_world" in agent.functions_python.keys()
msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call)
content = json.loads(msgs[-1].to_openai_dict()["content"], strict=constants.JSON_LOADS_STRICT)
assert content["message"] == "hello, world!"
assert content["status"] == "OK"
assert not function_failed
def test_add_function_already_loaded(agent, hello_world_function):
agent.add_function("hello_world")
# no exception for duplicate loading
agent.add_function("hello_world")
def test_add_function_not_exist(agent):
# pytest assert exception
with pytest.raises(ValueError):
agent.add_function("non_existent")
def test_remove_function_happy(agent, hello_world_function):
agent.add_function("hello_world")
# ensure function is loaded
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
assert "hello_world" in agent.functions_python.keys()
agent.remove_function("hello_world")
assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions]
assert "hello_world" not in agent.functions_python.keys()
def test_remove_function_not_exist(agent):
# do not raise error
agent.remove_function("non_existent")
def test_remove_base_function_fails(agent):
with pytest.raises(ValueError):
agent.remove_function("send_message")
if __name__ == "__main__":
pytest.main(["-vv", os.path.abspath(__file__)])

View File

@@ -1,11 +1,9 @@
import os
import uuid
import pytest
import memgpt.functions.function_sets.base as base_functions
from memgpt import create_client
from tests import TEST_MEMGPT_CONFIG
from .utils import create_config, wipe_config
@@ -28,8 +26,7 @@ def agent_obj():
agent_state = client.create_agent()
global agent_obj
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id)
yield agent_obj
client.delete_agent(agent_obj.agent_state.id)

View File

@@ -68,10 +68,7 @@ def run_server():
# Fixture to create clients with different configurations
@pytest.fixture(
params=[ # whether to use REST API server
{"server": True},
# {"server": False} # TODO: add when implemented
],
params=[{"server": True}, {"server": False}], # whether to use REST API server
scope="module",
)
def client(request):
@@ -90,8 +87,8 @@ def client(request):
# create user via admin client
admin = Admin(server_url, test_server_token)
response = admin.create_user(test_user_id) # Adjust as per your client's method
response.user_id
token = response.api_key
else:
# use local client (no server)
token = None
@@ -109,7 +106,7 @@ def client(request):
# Fixture for test agent
@pytest.fixture(scope="module")
def agent(client):
agent_state = client.create_agent(name=test_agent_name, preset=test_preset_name)
agent_state = client.create_agent(name=test_agent_name)
print("AGENT ID", agent_state.id)
yield agent_state
@@ -123,11 +120,11 @@ def test_agent(client, agent):
# test client.rename_agent
new_name = "RenamedTestAgent"
client.rename_agent(agent_id=agent.id, new_name=new_name)
renamed_agent = client.get_agent(agent_id=str(agent.id))
renamed_agent = client.get_agent(agent_id=agent.id)
assert renamed_agent.name == new_name, "Agent renaming failed"
# test client.delete_agent and client.agent_exists
delete_agent = client.create_agent(name="DeleteTestAgent", preset=test_preset_name)
delete_agent = client.create_agent(name="DeleteTestAgent")
assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed"
client.delete_agent(agent_id=delete_agent.id)
assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed"
@@ -140,7 +137,7 @@ def test_memory(client, agent):
print("MEMORY", memory_response)
updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"}
client.update_agent_core_memory(agent_id=str(agent.id), new_memory_contents=updated_memory)
client.update_agent_core_memory(agent_id=agent.id, new_memory_contents=updated_memory)
updated_memory_response = client.get_agent_memory(agent_id=agent.id)
assert (
updated_memory_response.core_memory.human == updated_memory["human"]
@@ -152,10 +149,10 @@ def test_agent_interactions(client, agent):
_reset_config()
message = "Hello, agent!"
message_response = client.user_message(agent_id=str(agent.id), message=message)
message_response = client.user_message(agent_id=agent.id, message=message)
command = "/memory"
command_response = client.run_command(agent_id=str(agent.id), command=command)
command_response = client.run_command(agent_id=agent.id, command=command)
print("command", command_response)
@@ -197,11 +194,15 @@ def test_humans_personas(client, agent):
print("PERSONAS", personas_response)
persona_name = "TestPersona"
if client.get_persona(persona_name):
client.delete_persona(persona_name)
persona = client.create_persona(name=persona_name, persona="Persona text")
assert persona.name == persona_name
assert persona.text == "Persona text", "Creating persona failed"
human_name = "TestHuman"
if client.get_human(human_name):
client.delete_human(human_name)
human = client.create_human(name=human_name, human="Human text")
assert human.name == human_name
assert human.text == "Human text", "Creating human failed"
@@ -285,55 +286,55 @@ def test_sources(client, agent):
client.delete_source(source.id)
def test_presets(client, agent):
_reset_config()
# new_preset = Preset(
# # user_id=client.user_id,
# name="pytest_test_preset",
# description="DUMMY_DESCRIPTION",
# system="DUMMY_SYSTEM",
# persona="DUMMY_PERSONA",
# persona_name="DUMMY_PERSONA_NAME",
# human="DUMMY_HUMAN",
# human_name="DUMMY_HUMAN_NAME",
# functions_schema=[
# {
# "name": "send_message",
# "json_schema": {
# "name": "send_message",
# "description": "Sends a message to the human user.",
# "parameters": {
# "type": "object",
# "properties": {
# "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
# },
# "required": ["message"],
# },
# },
# "tags": ["memgpt-base"],
# "source_type": "python",
# "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
# }
# ],
# )
## List all presets and make sure the preset is NOT in the list
# all_presets = client.list_presets()
# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# Create a preset
new_preset = client.create_preset(name="pytest_test_preset")
# List all presets and make sure the preset is in the list
all_presets = client.list_presets()
assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
# Delete the preset
client.delete_preset(preset_id=new_preset.id)
# List all presets and make sure the preset is NOT in the list
all_presets = client.list_presets()
assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# def test_presets(client, agent):
# _reset_config()
#
# # new_preset = Preset(
# # # user_id=client.user_id,
# # name="pytest_test_preset",
# # description="DUMMY_DESCRIPTION",
# # system="DUMMY_SYSTEM",
# # persona="DUMMY_PERSONA",
# # persona_name="DUMMY_PERSONA_NAME",
# # human="DUMMY_HUMAN",
# # human_name="DUMMY_HUMAN_NAME",
# # functions_schema=[
# # {
# # "name": "send_message",
# # "json_schema": {
# # "name": "send_message",
# # "description": "Sends a message to the human user.",
# # "parameters": {
# # "type": "object",
# # "properties": {
# # "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
# # },
# # "required": ["message"],
# # },
# # },
# # "tags": ["memgpt-base"],
# # "source_type": "python",
# # "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
# # }
# # ],
# # )
#
# ## List all presets and make sure the preset is NOT in the list
# # all_presets = client.list_presets()
# # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# # Create a preset
# new_preset = client.create_preset(name="pytest_test_preset")
#
# # List all presets and make sure the preset is in the list
# all_presets = client.list_presets()
# assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
#
# # Delete the preset
# client.delete_preset(preset_id=new_preset.id)
#
# # List all presets and make sure the preset is NOT in the list
# all_presets = client.list_presets()
# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
# def test_tools(client, agent):

View File

@@ -29,14 +29,11 @@ def run_llm_endpoint(filename):
agent_state = AgentState(
name="test_agent",
tools=[tool.name for tool in load_module_tools()],
system="",
persona="",
human="",
preset="memgpt_chat",
embedding_config=embedding_config,
llm_config=llm_config,
user_id=uuid.UUID(int=1),
state={"persona": "", "human": "", "messages": None},
state={"persona": "", "human": "", "messages": None, "memory": {}},
system="",
)
agent = Agent(
interface=None,

65
tests/test_memory.py Normal file
View File

@@ -0,0 +1,65 @@
import pytest
# Import the classes here, assuming the above definitions are in a module named memory_module
from memgpt.memory import BaseMemory, ChatMemory
@pytest.fixture
def sample_memory():
return ChatMemory(persona="Chat Agent", human="User")
def test_create_chat_memory():
"""Test creating an instance of ChatMemory"""
chat_memory = ChatMemory(persona="Chat Agent", human="User")
assert chat_memory.memory["persona"].value == "Chat Agent"
assert chat_memory.memory["human"].value == "User"
def test_dump_memory_as_json(sample_memory):
"""Test dumping ChatMemory as JSON compatible dictionary"""
memory_dict = sample_memory.to_dict()
assert isinstance(memory_dict, dict)
assert "persona" in memory_dict
assert memory_dict["persona"]["value"] == "Chat Agent"
def test_load_memory_from_json(sample_memory):
"""Test loading ChatMemory from a JSON compatible dictionary"""
memory_dict = sample_memory.to_dict()
print(memory_dict)
new_memory = BaseMemory.load(memory_dict)
assert new_memory.memory["persona"].value == "Chat Agent"
assert new_memory.memory["human"].value == "User"
# def test_memory_functionality(sample_memory):
# """Test memory modification functions"""
# # Get memory functions
# functions = get_memory_functions(ChatMemory)
# # Test core_memory_append function
# append_func = functions['core_memory_append']
# print("FUNCTIONS", functions)
# env = {}
# env.update(globals())
# for tool in functions:
# # WARNING: name may not be consistent?
# exec(tool.source_code, env)
#
# print(exec)
#
# append_func(sample_memory, 'persona', " is a test.")
# assert sample_memory.memory['persona'].value == "Chat Agent\n is a test."
# # Test core_memory_replace function
# replace_func = functions['core_memory_replace']
# replace_func(sample_memory, 'persona', " is a test.", " was a test.")
# assert sample_memory.memory['persona'].value == "Chat Agent\n was a test."
def test_memory_limit_validation(sample_memory):
"""Test exceeding memory limit"""
with pytest.raises(ValueError):
ChatMemory(persona="x" * 3000, human="y" * 3000)
with pytest.raises(ValueError):
sample_memory.memory["persona"].value = "x" * 3000

View File

@@ -1,139 +0,0 @@
import pytest
from memgpt.agent import Agent, save_agent
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
from memgpt.data_types import AgentState, LLMConfig, Source, User
from memgpt.metadata import MetadataStore
from memgpt.models.pydantic_models import HumanModel, PersonaModel
from memgpt.presets.presets import add_default_presets
from memgpt.settings import settings
from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
@pytest.mark.parametrize("storage_connector", ["sqlite"])
def test_storage(storage_connector):
if storage_connector == "postgres":
TEST_MEMGPT_CONFIG.archival_storage_uri = settings.pg_uri
TEST_MEMGPT_CONFIG.recall_storage_uri = settings.pg_uri
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
if storage_connector == "sqlite":
TEST_MEMGPT_CONFIG.recall_storage_type = "local"
ms = MetadataStore(TEST_MEMGPT_CONFIG)
# users
user_1 = User()
user_2 = User()
ms.create_user(user_1)
ms.create_user(user_2)
# test adding default humans/personas/presets
# add_default_humans_and_personas(user_id=user_1.id, ms=ms)
# add_default_humans_and_personas(user_id=user_2.id, ms=ms)
ms.add_human(human=HumanModel(name="test_human", text="This is a test human"))
ms.add_persona(persona=PersonaModel(name="test_persona", text="This is a test persona"))
add_default_presets(user_id=user_1.id, ms=ms)
add_default_presets(user_id=user_2.id, ms=ms)
assert len(ms.list_humans(user_id=user_1.id)) > 0, ms.list_humans(user_id=user_1.id)
assert len(ms.list_personas(user_id=user_1.id)) > 0, ms.list_personas(user_id=user_1.id)
# generate data
agent_1 = AgentState(
user_id=user_1.id,
name="agent_1",
preset=DEFAULT_PRESET,
persona=DEFAULT_PERSONA,
human=DEFAULT_HUMAN,
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
)
source_1 = Source(user_id=user_1.id, name="source_1")
# test creation
ms.create_agent(agent_1)
ms.create_source(source_1)
# test listing
len(ms.list_agents(user_id=user_1.id)) == 1
len(ms.list_agents(user_id=user_2.id)) == 0
len(ms.list_sources(user_id=user_1.id)) == 1
len(ms.list_sources(user_id=user_2.id)) == 0
# test agent_state saving
agent_state = ms.get_agent(agent_1.id).state
assert agent_state == {}, agent_state # when created via create_agent, it should be empty
from memgpt.presets.presets import add_default_presets
add_default_presets(user_1.id, ms)
preset_obj = ms.get_preset(name=DEFAULT_PRESET, user_id=user_1.id)
from memgpt.interface import CLIInterface as interface # for printing to terminal
# Overwrite fields in the preset if they were specified
preset_obj.human = get_human_text(DEFAULT_HUMAN)
preset_obj.persona = get_persona_text(DEFAULT_PERSONA)
# Create the agent
agent = Agent(
interface=interface(),
created_by=user_1.id,
name="agent_test_agent_state",
preset=preset_obj,
llm_config=config.default_llm_config,
embedding_config=config.default_embedding_config,
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
first_message_verify_mono=(
True if (config.default_llm_config.model is not None and "gpt-4" in config.default_llm_config.model) else False
),
)
agent_with_agent_state = agent.agent_state
save_agent(agent=agent, ms=ms)
agent_state = ms.get_agent(agent_with_agent_state.id).state
assert agent_state is not None, agent_state # when created via create_agent_from_preset, it should be non-empty
# test: updating
# test: update JSON-stored LLMConfig class
print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
llm_config = ms.get_agent(agent_1.id).llm_config
assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"
llm_config.model = "gpt3.5-turbo"
agent_1.llm_config = llm_config
ms.update_agent(agent_1)
assert ms.get_agent(agent_1.id).llm_config.model == "gpt3.5-turbo", f"Updated LLMConfig to {ms.get_agent(agent_1.id).llm_config.model}"
# test attaching sources
len(ms.list_attached_sources(agent_id=agent_1.id)) == 0
ms.attach_source(user_1.id, agent_1.id, source_1.id)
len(ms.list_attached_sources(agent_id=agent_1.id)) == 1
# test: detaching sources
ms.detach_source(agent_1.id, source_1.id)
len(ms.list_attached_sources(agent_id=agent_1.id)) == 0
# test getting
ms.get_user(user_1.id)
ms.get_agent(agent_1.id)
ms.get_source(source_1.id)
# test api keys
api_key = ms.create_api_key(user_id=user_1.id)
print("api_key=", api_key.token, api_key.user_id)
api_key_result = ms.get_api_key(api_key=api_key.token)
assert api_key.token == api_key_result.token, (api_key, api_key_result)
user_result = ms.get_user_from_api_key(api_key=api_key.token)
assert user_1.id == user_result.id, (user_1, user_result)
all_keys_for_user = ms.get_all_api_keys_for_user(user_id=user_1.id)
assert len(all_keys_for_user) > 0, all_keys_for_user
ms.delete_api_key(api_key=api_key.token)
# test deletion
ms.delete_user(user_1.id)
ms.delete_user(user_2.id)
ms.delete_agent(agent_1.id)
ms.delete_source(source_1.id)

View File

@@ -5,11 +5,12 @@ import pytest
from dotenv import load_dotenv
import memgpt.utils as utils
from memgpt.constants import BASE_TOOLS
utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.credentials import MemGPTCredentials
from memgpt.presets.presets import load_module_tools
from memgpt.memory import ChatMemory
from memgpt.server.server import SyncServer
from memgpt.settings import settings
@@ -59,8 +60,6 @@ def user_id(server):
user = server.create_user()
print(f"Created user\n{user.id}")
# initialize with default presets
server.initialize_default_presets(user.id)
yield user.id
# cleanup
@@ -71,10 +70,7 @@ def user_id(server):
def agent_id(server, user_id):
# create agent
agent_state = server.create_agent(
user_id=user_id,
name="test_agent",
preset="memgpt_chat",
tools=[tool.name for tool in load_module_tools()],
user_id=user_id, name="test_agent", tools=BASE_TOOLS, memory=ChatMemory(human="I am Chad", persona="I love testing")
)
print(f"Created agent\n{agent_state}")
yield agent_state.id
@@ -170,8 +166,21 @@ def test_get_recall_memory(server, user_id, agent_id):
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
assert len(messages_4) == 1
print("MESSAGES")
for m in messages_3:
print(m["id"], m["role"])
if m["role"] == "assistant":
print(m["text"])
print("------------")
# test in-context message ids
all_messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1000)
print("num messages", len(all_messages))
in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id)
print(in_context_ids)
for m in messages_3:
if str(m["id"]) not in [str(i) for i in in_context_ids]:
print("missing", m["id"], m["role"])
assert len(in_context_ids) == len(messages_3)
assert isinstance(in_context_ids[0], uuid.UUID)
message_ids = [m["id"] for m in messages_3]

View File

@@ -7,13 +7,12 @@ from sqlalchemy.ext.declarative import declarative_base
from memgpt.agent_store.storage import StorageConnector, TableType
from memgpt.config import MemGPTConfig
from memgpt.constants import MAX_EMBEDDING_DIM
from memgpt.constants import BASE_TOOLS, MAX_EMBEDDING_DIM
from memgpt.credentials import MemGPTCredentials
from memgpt.data_types import AgentState, Message, Passage, User
from memgpt.embeddings import embedding_model, query_embedding
from memgpt.metadata import MetadataStore
from memgpt.settings import settings
from memgpt.utils import get_human_text, get_persona_text
from tests import TEST_MEMGPT_CONFIG
from tests.utils import create_config, wipe_config
@@ -185,15 +184,10 @@ def test_storage(
user_id=user_id,
name="agent_1",
id=agent_1_id,
preset=TEST_MEMGPT_CONFIG.preset,
# persona_name=TEST_MEMGPT_CONFIG.persona,
# human_name=TEST_MEMGPT_CONFIG.human,
persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
human=get_human_text(TEST_MEMGPT_CONFIG.human),
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
system="",
tools=[],
tools=BASE_TOOLS,
state={
"persona": "",
"human": "",

View File

@@ -11,6 +11,7 @@ from memgpt.agent import Agent
from memgpt.config import MemGPTConfig
from memgpt.constants import DEFAULT_PRESET
from memgpt.credentials import MemGPTCredentials
from memgpt.memory import ChatMemory
from memgpt.settings import settings
from tests.utils import create_config
@@ -69,7 +70,7 @@ def run_server():
# Fixture to create clients with different configurations
@pytest.fixture(
params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented
# params=[{"server": True}], # whether to use REST API server # TODO: add when implemented
# params=[{"server": False}], # whether to use REST API server # TODO: add when implemented
scope="module",
)
def admin_client(request):
@@ -179,10 +180,9 @@ def test_create_agent_tool(client):
str: The agent that was deleted.
"""
self.memory.human = ""
self.memory.persona = ""
self.rebuild_memory()
print("UPDATED MEMORY", self.memory.human, self.memory.persona)
self.memory.memory["human"].value = ""
self.memory.memory["persona"].value = ""
print("UPDATED MEMORY", self.memory.memory)
return None
# TODO: test attaching and using function on agent
@@ -190,7 +190,8 @@ def test_create_agent_tool(client):
print(f"Created tool", tool.name)
# create agent with tool
agent = client.create_agent(name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you")
memory = ChatMemory(human="I am a human", persona="You must clear your memory if the human instructs you")
agent = client.create_agent(name=test_agent_name, tools=[tool.name], memory=memory)
assert str(tool.user_id) == str(agent.user_id), f"Expected {tool.user_id} to be {agent.user_id}"
# initial memory
@@ -207,6 +208,7 @@ def test_create_agent_tool(client):
print(response)
# updated memory
print("Query agent memory")
updated_memory = client.get_agent_memory(agent.id)
human = updated_memory.core_memory.human
persona = updated_memory.core_memory.persona