feat: refactor CoreMemory to support generalized memory fields and memory editing functions (#1479)
Co-authored-by: cpacker <packercharles@gmail.com> Co-authored-by: Maximilian-Winter <maximilian.winter.91@gmail.com>
This commit is contained in:
@@ -56,7 +56,7 @@ jobs:
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
pipx install poetry==1.8.2
|
||||
poetry install -E dev
|
||||
poetry install -E dev -E postgres
|
||||
poetry run pytest -s tests/test_client.py
|
||||
poetry run pytest -s tests/test_concurrent_connections.py
|
||||
|
||||
|
||||
270
memgpt/agent.py
270
memgpt/agent.py
@@ -3,7 +3,6 @@ import inspect
|
||||
import json
|
||||
import traceback
|
||||
import uuid
|
||||
from pathlib import Path
|
||||
from typing import List, Optional, Tuple, Union, cast
|
||||
|
||||
from tqdm import tqdm
|
||||
@@ -11,8 +10,6 @@ from tqdm import tqdm
|
||||
from memgpt.agent_store.storage import StorageConnector
|
||||
from memgpt.constants import (
|
||||
CLI_WARNING_PREFIX,
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT,
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
FIRST_MESSAGE_ATTEMPTS,
|
||||
JSON_ENSURE_ASCII,
|
||||
JSON_LOADS_STRICT,
|
||||
@@ -24,9 +21,7 @@ from memgpt.constants import (
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, Message, Passage
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.llm_api.llm_api_tools import create, is_context_overflow_error
|
||||
from memgpt.memory import ArchivalMemory
|
||||
from memgpt.memory import CoreMemory as InContextMemory
|
||||
from memgpt.memory import RecallMemory, summarize_messages
|
||||
from memgpt.memory import ArchivalMemory, BaseMemory, RecallMemory, summarize_messages
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models import chat_completion_response
|
||||
from memgpt.models.pydantic_models import ToolModel
|
||||
@@ -41,7 +36,6 @@ from memgpt.utils import (
|
||||
count_tokens,
|
||||
create_uuid_from_string,
|
||||
get_local_time,
|
||||
get_schema_diff,
|
||||
get_tool_call_id,
|
||||
get_utc_time,
|
||||
is_utc_datetime,
|
||||
@@ -53,81 +47,17 @@ from memgpt.utils import (
|
||||
)
|
||||
|
||||
from .errors import LLMError
|
||||
from .functions.functions import USER_FUNCTIONS_DIR, load_all_function_sets
|
||||
|
||||
|
||||
def link_functions(function_schemas: list):
|
||||
"""Link function definitions to list of function schemas"""
|
||||
|
||||
# need to dynamically link the functions
|
||||
# the saved agent.functions will just have the schemas, but we need to
|
||||
# go through the functions library and pull the respective python functions
|
||||
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
# agent.functions is a list of schemas (OpenAI kwarg functions style, see: https://platform.openai.com/docs/api-reference/chat/create)
|
||||
# [{'name': ..., 'description': ...}, {...}]
|
||||
available_functions = load_all_function_sets()
|
||||
linked_function_set = {}
|
||||
for f_schema in function_schemas:
|
||||
# Attempt to find the function in the existing function library
|
||||
f_name = f_schema.get("name")
|
||||
if f_name is None:
|
||||
raise ValueError(f"While loading agent.state.functions encountered a bad function schema object with no name:\n{f_schema}")
|
||||
linked_function = available_functions.get(f_name)
|
||||
if linked_function is None:
|
||||
# raise ValueError(
|
||||
# f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}"
|
||||
# )
|
||||
print(
|
||||
f"Function '{f_name}' was specified in agent.state.functions, but is not in function library:\n{available_functions.keys()}"
|
||||
)
|
||||
continue
|
||||
|
||||
# Once we find a matching function, make sure the schema is identical
|
||||
if json.dumps(f_schema, ensure_ascii=JSON_ENSURE_ASCII) != json.dumps(
|
||||
linked_function["json_schema"], ensure_ascii=JSON_ENSURE_ASCII
|
||||
):
|
||||
# error_message = (
|
||||
# f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different."
|
||||
# + f"\n>>>agent.state.functions\n{json.dumps(f_schema, indent=2, ensure_ascii=JSON_ENSURE_ASCII)}"
|
||||
# + f"\n>>>function library\n{json.dumps(linked_function['json_schema'], indent=2, ensure_ascii=JSON_ENSURE_ASCII)}"
|
||||
# )
|
||||
schema_diff = get_schema_diff(f_schema, linked_function["json_schema"])
|
||||
error_message = (
|
||||
f"Found matching function '{f_name}' from agent.state.functions inside function library, but schemas are different.\n"
|
||||
+ "".join(schema_diff)
|
||||
)
|
||||
|
||||
# NOTE to handle old configs, instead of erroring here let's just warn
|
||||
# raise ValueError(error_message)
|
||||
printd(error_message)
|
||||
linked_function_set[f_name] = linked_function
|
||||
return linked_function_set
|
||||
|
||||
|
||||
def initialize_memory(ai_notes: Union[str, None], human_notes: Union[str, None]):
|
||||
if ai_notes is None:
|
||||
raise ValueError(ai_notes)
|
||||
if human_notes is None:
|
||||
raise ValueError(human_notes)
|
||||
memory = InContextMemory(human_char_limit=CORE_MEMORY_HUMAN_CHAR_LIMIT, persona_char_limit=CORE_MEMORY_PERSONA_CHAR_LIMIT)
|
||||
memory.edit_persona(ai_notes)
|
||||
memory.edit_human(human_notes)
|
||||
return memory
|
||||
|
||||
|
||||
def construct_system_with_memory(
|
||||
system: str,
|
||||
memory: InContextMemory,
|
||||
memory: BaseMemory,
|
||||
memory_edit_timestamp: str,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
recall_memory: Optional[RecallMemory] = None,
|
||||
include_char_count: bool = True,
|
||||
):
|
||||
# TODO: modify this to be generalized
|
||||
full_system_message = "\n".join(
|
||||
[
|
||||
system,
|
||||
@@ -136,12 +66,13 @@ def construct_system_with_memory(
|
||||
f"{len(recall_memory) if recall_memory else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{len(archival_memory) if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
f'<persona characters="{len(memory.persona)}/{memory.persona_char_limit}">' if include_char_count else "<persona>",
|
||||
memory.persona,
|
||||
"</persona>",
|
||||
f'<human characters="{len(memory.human)}/{memory.human_char_limit}">' if include_char_count else "<human>",
|
||||
memory.human,
|
||||
"</human>",
|
||||
str(memory),
|
||||
# f'<persona characters="{len(memory.persona)}/{memory.persona_char_limit}">' if include_char_count else "<persona>",
|
||||
# memory.persona,
|
||||
# "</persona>",
|
||||
# f'<human characters="{len(memory.human)}/{memory.human_char_limit}">' if include_char_count else "<human>",
|
||||
# memory.human,
|
||||
# "</human>",
|
||||
]
|
||||
)
|
||||
return full_system_message
|
||||
@@ -150,7 +81,7 @@ def construct_system_with_memory(
|
||||
def initialize_message_sequence(
|
||||
model: str,
|
||||
system: str,
|
||||
memory: InContextMemory,
|
||||
memory: BaseMemory,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
recall_memory: Optional[RecallMemory] = None,
|
||||
memory_edit_timestamp: Optional[str] = None,
|
||||
@@ -195,13 +126,14 @@ class Agent(object):
|
||||
# agents can be created from providing agent_state
|
||||
agent_state: AgentState,
|
||||
tools: List[ToolModel],
|
||||
# memory: BaseMemory,
|
||||
# extras
|
||||
messages_total: Optional[int] = None, # TODO remove?
|
||||
first_message_verify_mono: bool = True, # TODO move to config?
|
||||
):
|
||||
|
||||
# tools
|
||||
for tool in tools:
|
||||
assert tool, f"Tool is None - must be error in querying tool from DB"
|
||||
assert tool.name in agent_state.tools, f"Tool {tool} not found in agent_state.tools"
|
||||
for tool_name in agent_state.tools:
|
||||
assert tool_name in [tool.name for tool in tools], f"Tool name {tool_name} not included in agent tool list"
|
||||
@@ -230,13 +162,8 @@ class Agent(object):
|
||||
self.system = self.agent_state.system
|
||||
|
||||
# Initialize the memory object
|
||||
# TODO: support more general memory types
|
||||
if "persona" not in self.agent_state.state: # TODO: remove
|
||||
raise ValueError(f"'persona' not found in provided AgentState")
|
||||
if "human" not in self.agent_state.state: # TODO: remove
|
||||
raise ValueError(f"'human' not found in provided AgentState")
|
||||
self.memory = initialize_memory(ai_notes=self.agent_state.state["persona"], human_notes=self.agent_state.state["human"])
|
||||
printd("INITIALIZED MEMORY", self.memory.persona, self.memory.human)
|
||||
self.memory = BaseMemory.load(self.agent_state.state["memory"])
|
||||
printd("Initialized memory object", self.memory)
|
||||
|
||||
# Interface must implement:
|
||||
# - internal_monologue
|
||||
@@ -285,7 +212,7 @@ class Agent(object):
|
||||
m.created_at = m.created_at.replace(tzinfo=datetime.timezone.utc)
|
||||
|
||||
else:
|
||||
# print(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
|
||||
printd(f"Agent.__init__ :: creating, state={agent_state.state['messages']}")
|
||||
init_messages = initialize_message_sequence(
|
||||
self.model,
|
||||
self.system,
|
||||
@@ -311,12 +238,10 @@ class Agent(object):
|
||||
|
||||
# Keep track of the total number of messages throughout all time
|
||||
self.messages_total = messages_total if messages_total is not None else (len(self._messages) - 1) # (-system)
|
||||
# self.messages_total_init = self.messages_total
|
||||
self.messages_total_init = len(self._messages) - 1
|
||||
printd(f"Agent initialized, self.messages_total={self.messages_total}")
|
||||
|
||||
# Create the agent in the DB
|
||||
# self.save()
|
||||
self.update_state()
|
||||
|
||||
@property
|
||||
@@ -609,6 +534,10 @@ class Agent(object):
|
||||
heartbeat_request = False
|
||||
function_failed = False
|
||||
|
||||
# rebuild memory
|
||||
# TODO: @charles please check this
|
||||
self.rebuild_memory()
|
||||
|
||||
return messages, heartbeat_request, function_failed
|
||||
|
||||
def step(
|
||||
@@ -770,6 +699,10 @@ class Agent(object):
|
||||
|
||||
self._append_to_messages(all_new_messages)
|
||||
messages_to_return = [msg.to_openai_dict() for msg in all_new_messages] if return_dicts else all_new_messages
|
||||
|
||||
# update state after each step
|
||||
self.update_state()
|
||||
|
||||
return messages_to_return, heartbeat_request, function_failed, active_memory_warning, response.usage
|
||||
|
||||
except Exception as e:
|
||||
@@ -913,6 +846,14 @@ class Agent(object):
|
||||
def rebuild_memory(self):
|
||||
"""Rebuilds the system message with the latest memory object"""
|
||||
curr_system_message = self.messages[0] # this is the system + memory bank, not just the system prompt
|
||||
|
||||
# NOTE: This is a hacky way to check if the memory has changed
|
||||
memory_repr = str(self.memory)
|
||||
if memory_repr == curr_system_message["content"][-(len(memory_repr)) :]:
|
||||
printd(f"Memory has not changed, not rebuilding system")
|
||||
return
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats seperately)
|
||||
new_system_message = initialize_message_sequence(
|
||||
self.model,
|
||||
self.system,
|
||||
@@ -922,136 +863,85 @@ class Agent(object):
|
||||
)[0]
|
||||
|
||||
diff = united_diff(curr_system_message["content"], new_system_message["content"])
|
||||
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
if len(diff) > 0: # there was a diff
|
||||
printd(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Swap the system message out
|
||||
self._swap_system_message(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message
|
||||
# Swap the system message out (only if there is a diff)
|
||||
self._swap_system_message(
|
||||
Message.dict_to_message(
|
||||
agent_id=self.agent_state.id, user_id=self.agent_state.user_id, model=self.model, openai_message_dict=new_system_message
|
||||
)
|
||||
)
|
||||
assert self.messages[0]["content"] == new_system_message["content"], (
|
||||
self.messages[0]["content"],
|
||||
new_system_message["content"],
|
||||
)
|
||||
)
|
||||
|
||||
# def to_agent_state(self) -> AgentState:
|
||||
# # The state may have change since the last time we wrote it
|
||||
# updated_state = {
|
||||
# "persona": self.memory.persona,
|
||||
# "human": self.memory.human,
|
||||
# "system": self.system,
|
||||
# "functions": self.functions,
|
||||
# "messages": [str(msg.id) for msg in self._messages],
|
||||
# }
|
||||
|
||||
# agent_state = AgentState(
|
||||
# name=self.agent_state.name,
|
||||
# user_id=self.agent_state.user_id,
|
||||
# persona=self.agent_state.persona,
|
||||
# human=self.agent_state.human,
|
||||
# llm_config=self.agent_state.llm_config,
|
||||
# embedding_config=self.agent_state.embedding_config,
|
||||
# preset=self.agent_state.preset,
|
||||
# id=self.agent_state.id,
|
||||
# created_at=self.agent_state.created_at,
|
||||
# state=updated_state,
|
||||
# )
|
||||
|
||||
# return agent_state
|
||||
|
||||
def add_function(self, function_name: str) -> str:
|
||||
if function_name in self.functions_python.keys():
|
||||
msg = f"Function {function_name} already loaded"
|
||||
printd(msg)
|
||||
return msg
|
||||
# TODO: refactor
|
||||
raise NotImplementedError
|
||||
# if function_name in self.functions_python.keys():
|
||||
# msg = f"Function {function_name} already loaded"
|
||||
# printd(msg)
|
||||
# return msg
|
||||
|
||||
available_functions = load_all_function_sets()
|
||||
if function_name not in available_functions.keys():
|
||||
raise ValueError(f"Function {function_name} not found in function library")
|
||||
# available_functions = load_all_function_sets()
|
||||
# if function_name not in available_functions.keys():
|
||||
# raise ValueError(f"Function {function_name} not found in function library")
|
||||
|
||||
self.functions.append(available_functions[function_name]["json_schema"])
|
||||
self.functions_python[function_name] = available_functions[function_name]["python_function"]
|
||||
# self.functions.append(available_functions[function_name]["json_schema"])
|
||||
# self.functions_python[function_name] = available_functions[function_name]["python_function"]
|
||||
|
||||
msg = f"Added function {function_name}"
|
||||
# self.save()
|
||||
self.update_state()
|
||||
printd(msg)
|
||||
return msg
|
||||
# msg = f"Added function {function_name}"
|
||||
## self.save()
|
||||
# self.update_state()
|
||||
# printd(msg)
|
||||
# return msg
|
||||
|
||||
def remove_function(self, function_name: str) -> str:
|
||||
if function_name not in self.functions_python.keys():
|
||||
msg = f"Function {function_name} not loaded, ignoring"
|
||||
printd(msg)
|
||||
return msg
|
||||
# TODO: refactor
|
||||
raise NotImplementedError
|
||||
# if function_name not in self.functions_python.keys():
|
||||
# msg = f"Function {function_name} not loaded, ignoring"
|
||||
# printd(msg)
|
||||
# return msg
|
||||
|
||||
# only allow removal of user defined functions
|
||||
user_func_path = Path(USER_FUNCTIONS_DIR)
|
||||
func_path = Path(inspect.getfile(self.functions_python[function_name]))
|
||||
is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts
|
||||
## only allow removal of user defined functions
|
||||
# user_func_path = Path(USER_FUNCTIONS_DIR)
|
||||
# func_path = Path(inspect.getfile(self.functions_python[function_name]))
|
||||
# is_subpath = func_path.resolve().parts[: len(user_func_path.resolve().parts)] == user_func_path.resolve().parts
|
||||
|
||||
if not is_subpath:
|
||||
raise ValueError(f"Function {function_name} is not user defined and cannot be removed")
|
||||
# if not is_subpath:
|
||||
# raise ValueError(f"Function {function_name} is not user defined and cannot be removed")
|
||||
|
||||
self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name]
|
||||
self.functions_python.pop(function_name)
|
||||
# self.functions = [f_schema for f_schema in self.functions if f_schema["name"] != function_name]
|
||||
# self.functions_python.pop(function_name)
|
||||
|
||||
msg = f"Removed function {function_name}"
|
||||
# self.save()
|
||||
self.update_state()
|
||||
printd(msg)
|
||||
return msg
|
||||
|
||||
# def save(self):
|
||||
# """Save agent state locally"""
|
||||
|
||||
# new_agent_state = self.to_agent_state()
|
||||
|
||||
# # without this, even after Agent.__init__, agent.config.state["messages"] will be None
|
||||
# self.agent_state = new_agent_state
|
||||
|
||||
# # Check if we need to create the agent
|
||||
# if not self.ms.get_agent(agent_id=new_agent_state.id, user_id=new_agent_state.user_id, agent_name=new_agent_state.name):
|
||||
# # print(f"Agent.save {new_agent_state.id} :: agent does not exist, creating...")
|
||||
# self.ms.create_agent(agent=new_agent_state)
|
||||
# # Otherwise, we should update the agent
|
||||
# else:
|
||||
# # print(f"Agent.save {new_agent_state.id} :: agent already exists, updating...")
|
||||
# print(f"Agent.save {new_agent_state.id} :: preupdate:\n\tmessages={new_agent_state.state['messages']}")
|
||||
# self.ms.update_agent(agent=new_agent_state)
|
||||
# msg = f"Removed function {function_name}"
|
||||
## self.save()
|
||||
# self.update_state()
|
||||
# printd(msg)
|
||||
# return msg
|
||||
|
||||
def update_state(self) -> AgentState:
|
||||
# updated_state = {
|
||||
# "persona": self.memory.persona,
|
||||
# "human": self.memory.human,
|
||||
# "system": self.system,
|
||||
# "functions": self.functions,
|
||||
# "messages": [str(msg.id) for msg in self._messages],
|
||||
# }
|
||||
memory = {
|
||||
"system": self.system,
|
||||
"persona": self.memory.persona,
|
||||
"human": self.memory.human,
|
||||
"memory": self.memory.to_dict(),
|
||||
"messages": [str(msg.id) for msg in self._messages], # TODO: move out into AgentState.message_ids
|
||||
}
|
||||
|
||||
# TODO: add this field
|
||||
metadata = { # TODO
|
||||
"human_name": self.agent_state.persona,
|
||||
"persona_name": self.agent_state.human,
|
||||
}
|
||||
|
||||
self.agent_state = AgentState(
|
||||
name=self.agent_state.name,
|
||||
user_id=self.agent_state.user_id,
|
||||
tools=self.agent_state.tools,
|
||||
system=self.system,
|
||||
persona=self.agent_state.persona, # TODO: remove (stores persona_name)
|
||||
human=self.agent_state.human, # TODO: remove (stores human_name)
|
||||
## "model_state"
|
||||
llm_config=self.agent_state.llm_config,
|
||||
embedding_config=self.agent_state.embedding_config,
|
||||
preset=self.agent_state.preset,
|
||||
id=self.agent_state.id,
|
||||
created_at=self.agent_state.created_at,
|
||||
## "agent_state"
|
||||
state=memory,
|
||||
_metadata=self.agent_state._metadata,
|
||||
)
|
||||
return self.agent_state
|
||||
|
||||
|
||||
@@ -493,9 +493,7 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
if isinstance(records[0], Passage):
|
||||
with self.engine.connect() as conn:
|
||||
db_records = [vars(record) for record in records]
|
||||
# print("records", db_records)
|
||||
stmt = insert(self.db_model.__table__).values(db_records)
|
||||
# print(stmt)
|
||||
if exists_ok:
|
||||
upsert_stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
|
||||
@@ -594,9 +592,7 @@ class SQLLiteStorageConnector(SQLStorageConnector):
|
||||
if isinstance(records[0], Passage):
|
||||
with self.engine.connect() as conn:
|
||||
db_records = [vars(record) for record in records]
|
||||
# print("records", db_records)
|
||||
stmt = insert(self.db_model.__table__).values(db_records)
|
||||
# print(stmt)
|
||||
if exists_ok:
|
||||
upsert_stmt = stmt.on_conflict_do_update(
|
||||
index_elements=["id"], set_={c.name: c for c in stmt.excluded} # Replace with your primary key column
|
||||
|
||||
@@ -13,16 +13,19 @@ import requests
|
||||
import typer
|
||||
|
||||
import memgpt.utils as utils
|
||||
from memgpt import create_client
|
||||
from memgpt.agent import Agent, save_agent
|
||||
from memgpt.cli.cli_config import configure
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import CLI_WARNING_PREFIX, MEMGPT_DIR
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, User
|
||||
from memgpt.data_types import EmbeddingConfig, LLMConfig, User
|
||||
from memgpt.log import get_logger
|
||||
from memgpt.memory import ChatMemory
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.migrate import migrate_all_agents, migrate_all_sources
|
||||
from memgpt.server.constants import WS_DEFAULT_PORT
|
||||
from memgpt.server.server import logger as server_logger
|
||||
|
||||
# from memgpt.interface import CLIInterface as interface # for printing to terminal
|
||||
from memgpt.streaming_interface import (
|
||||
@@ -391,7 +394,6 @@ def run(
|
||||
persona: Annotated[Optional[str], typer.Option(help="Specify persona")] = None,
|
||||
agent: Annotated[Optional[str], typer.Option(help="Specify agent name")] = None,
|
||||
human: Annotated[Optional[str], typer.Option(help="Specify human")] = None,
|
||||
preset: Annotated[Optional[str], typer.Option(help="Specify preset")] = None,
|
||||
# model flags
|
||||
model: Annotated[Optional[str], typer.Option(help="Specify the LLM model")] = None,
|
||||
model_wrapper: Annotated[Optional[str], typer.Option(help="Specify the LLM model wrapper")] = None,
|
||||
@@ -427,8 +429,10 @@ def run(
|
||||
|
||||
if debug:
|
||||
logger.setLevel(logging.DEBUG)
|
||||
server_logger.setLevel(logging.DEBUG)
|
||||
else:
|
||||
logger.setLevel(logging.CRITICAL)
|
||||
server_logger.setLevel(logging.CRITICAL)
|
||||
|
||||
from memgpt.migrate import (
|
||||
VERSION_CUTOFF,
|
||||
@@ -511,8 +515,6 @@ def run(
|
||||
# read user id from config
|
||||
ms = MetadataStore(config)
|
||||
user = create_default_user_or_exit(config, ms)
|
||||
human = human if human else config.human
|
||||
persona = persona if persona else config.persona
|
||||
|
||||
# determine agent to use, if not provided
|
||||
if not yes and not agent:
|
||||
@@ -529,6 +531,8 @@ def run(
|
||||
|
||||
# create agent config
|
||||
agent_state = ms.get_agent(agent_name=agent, user_id=user.id) if agent else None
|
||||
human = human if human else config.human
|
||||
persona = persona if persona else config.persona
|
||||
if agent and agent_state: # use existing agent
|
||||
typer.secho(f"\n🔁 Using existing agent {agent}", fg=typer.colors.GREEN)
|
||||
# agent_config = AgentConfig.load(agent)
|
||||
@@ -540,14 +544,6 @@ def run(
|
||||
# printd("Index path:", agent_config.save_agent_index_dir())
|
||||
# persistence_manager = LocalStateManager(agent_config).load() # TODO: implement load
|
||||
# TODO: load prior agent state
|
||||
if persona and persona != agent_state.persona:
|
||||
typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing persona {agent_state.persona} with {persona}", fg=typer.colors.YELLOW)
|
||||
agent_state.persona = persona
|
||||
# raise ValueError(f"Cannot override {agent_state.name} existing persona {agent_state.persona} with {persona}")
|
||||
if human and human != agent_state.human:
|
||||
typer.secho(f"{CLI_WARNING_PREFIX}Overriding existing human {agent_state.human} with {human}", fg=typer.colors.YELLOW)
|
||||
agent_state.human = human
|
||||
# raise ValueError(f"Cannot override {agent_config.name} existing human {agent_config.human} with {human}")
|
||||
|
||||
# Allow overriding model specifics (model, model wrapper, model endpoint IP + type, context_window)
|
||||
if model and model != agent_state.llm_config.model:
|
||||
@@ -582,7 +578,12 @@ def run(
|
||||
|
||||
# Update the agent with any overrides
|
||||
ms.update_agent(agent_state)
|
||||
tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools]
|
||||
tools = []
|
||||
for tool_name in agent_state.tools:
|
||||
tool = ms.get_tool(tool_name, agent_state.user_id)
|
||||
if tool is None:
|
||||
typer.secho(f"Couldn't find tool {tool_name} in database, please run `memgpt add tool`", fg=typer.colors.RED)
|
||||
tools.append(tool)
|
||||
|
||||
# create agent
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface(), tools=tools)
|
||||
@@ -626,45 +627,30 @@ def run(
|
||||
|
||||
# create agent
|
||||
try:
|
||||
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
|
||||
client = create_client()
|
||||
human_obj = ms.get_human(human, user.id)
|
||||
persona_obj = ms.get_persona(persona, user.id)
|
||||
if preset_obj is None:
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(user.id, ms)
|
||||
# try again
|
||||
preset_obj = ms.get_preset(name=preset if preset else config.preset, user_id=user.id)
|
||||
if preset_obj is None:
|
||||
typer.secho("Couldn't find presets in database, please run `memgpt configure`", fg=typer.colors.RED)
|
||||
sys.exit(1)
|
||||
if human_obj is None:
|
||||
typer.secho("Couldn't find human {human} in database, please run `memgpt add human`", fg=typer.colors.RED)
|
||||
if persona_obj is None:
|
||||
typer.secho("Couldn't find persona {persona} in database, please run `memgpt add persona`", fg=typer.colors.RED)
|
||||
|
||||
# Overwrite fields in the preset if they were specified
|
||||
preset_obj.human = ms.get_human(human, user.id).text
|
||||
preset_obj.persona = ms.get_persona(persona, user.id).text
|
||||
memory = ChatMemory(human=human_obj.text, persona=persona_obj.text)
|
||||
metadata = {"human": human_obj.name, "persona": persona_obj.name}
|
||||
|
||||
typer.secho(f"-> 🤖 Using persona profile: '{preset_obj.persona_name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{preset_obj.human_name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🤖 Using persona profile: '{persona_obj.name}'", fg=typer.colors.WHITE)
|
||||
typer.secho(f"-> 🧑 Using human profile: '{human_obj.name}'", fg=typer.colors.WHITE)
|
||||
|
||||
agent_state = AgentState(
|
||||
# add tools
|
||||
agent_state = client.create_agent(
|
||||
name=agent_name,
|
||||
user_id=user.id,
|
||||
tools=list([schema["name"] for schema in preset_obj.functions_schema]),
|
||||
system=preset_obj.system,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
human=preset_obj.human,
|
||||
persona=preset_obj.persona,
|
||||
preset=preset_obj.name,
|
||||
state={"messages": None, "persona": preset_obj.persona, "human": preset_obj.human},
|
||||
llm_config=llm_config,
|
||||
memory=memory,
|
||||
metadata=metadata,
|
||||
)
|
||||
typer.secho(f"-> 🛠️ {len(agent_state.tools)} tools: {', '.join([t for t in agent_state.tools])}", fg=typer.colors.WHITE)
|
||||
tools = [ms.get_tool(tool_name) for tool_name in agent_state.tools]
|
||||
tools = [ms.get_tool(tool_name, user_id=client.user_id) for tool_name in agent_state.tools]
|
||||
|
||||
memgpt_agent = Agent(
|
||||
interface=interface(),
|
||||
|
||||
@@ -38,8 +38,7 @@ from memgpt.local_llm.constants import (
|
||||
)
|
||||
from memgpt.local_llm.utils import get_available_wrappers
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
from memgpt.presets.presets import create_preset_from_file
|
||||
from memgpt.models.pydantic_models import PersonaModel
|
||||
from memgpt.server.utils import shorten_key_middle
|
||||
|
||||
app = typer.Typer()
|
||||
@@ -1085,18 +1084,12 @@ def configure():
|
||||
else:
|
||||
ms.create_user(user)
|
||||
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(user_id, ms)
|
||||
|
||||
|
||||
class ListChoice(str, Enum):
|
||||
agents = "agents"
|
||||
humans = "humans"
|
||||
personas = "personas"
|
||||
sources = "sources"
|
||||
presets = "presets"
|
||||
|
||||
|
||||
@app.command()
|
||||
@@ -1133,7 +1126,7 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
elif arg == ListChoice.humans:
|
||||
"""List all humans"""
|
||||
table.field_names = ["Name", "Text"]
|
||||
for human in client.list_humans(user_id=user_id):
|
||||
for human in client.list_humans():
|
||||
table.add_row([human.name, human.text.replace("\n", "")[:100]])
|
||||
print(table)
|
||||
elif arg == ListChoice.personas:
|
||||
@@ -1170,21 +1163,6 @@ def list(arg: Annotated[ListChoice, typer.Argument]):
|
||||
)
|
||||
|
||||
print(table)
|
||||
elif arg == ListChoice.presets:
|
||||
"""List all available presets"""
|
||||
table.field_names = ["Name", "Description", "Sources", "Functions"]
|
||||
for preset in ms.list_presets(user_id=user_id):
|
||||
sources = ms.get_preset_sources(preset_id=preset.id)
|
||||
table.add_row(
|
||||
[
|
||||
preset.name,
|
||||
preset.description,
|
||||
",".join([source.name for source in sources]),
|
||||
# json.dumps(preset.functions_schema, indent=4)
|
||||
",\n".join([f["name"] for f in preset.functions_schema]),
|
||||
]
|
||||
)
|
||||
print(table)
|
||||
else:
|
||||
raise ValueError(f"Unknown argument {arg}")
|
||||
|
||||
@@ -1208,7 +1186,7 @@ def add(
|
||||
with open(filename, "r", encoding="utf-8") as f:
|
||||
text = f.read()
|
||||
if option == "persona":
|
||||
persona = ms.get_persona(name=name, user_id=user_id)
|
||||
persona = ms.get_persona(name=name)
|
||||
if persona:
|
||||
# config if user wants to overwrite
|
||||
if not questionary.confirm(f"Persona {name} already exists. Overwrite?").ask():
|
||||
@@ -1220,7 +1198,7 @@ def add(
|
||||
ms.add_persona(persona)
|
||||
|
||||
elif option == "human":
|
||||
human = client.get_human(name=name, user_id=user_id)
|
||||
human = client.get_human(name=name)
|
||||
if human:
|
||||
# config if user wants to overwrite
|
||||
if not questionary.confirm(f"Human {name} already exists. Overwrite?").ask():
|
||||
@@ -1228,11 +1206,7 @@ def add(
|
||||
human.text = text
|
||||
client.update_human(human)
|
||||
else:
|
||||
human = HumanModel(name=name, text=text, user_id=user_id)
|
||||
client.add_human(HumanModel(name=name, text=text, user_id=user_id))
|
||||
elif option == "preset":
|
||||
assert filename, "Must specify filename for preset"
|
||||
create_preset_from_file(filename, name, user_id, ms)
|
||||
human = client.create_human(name=name, human=text)
|
||||
else:
|
||||
raise ValueError(f"Unknown kind {option}")
|
||||
|
||||
@@ -1281,18 +1255,14 @@ def delete(option: str, name: str):
|
||||
ms.delete_agent(agent_id=agent.id)
|
||||
|
||||
elif option == "human":
|
||||
human = client.get_human(name=name, user_id=user_id)
|
||||
human = client.get_human(name=name)
|
||||
assert human is not None, f"Human {name} does not exist"
|
||||
client.delete_human(name=name, user_id=user_id)
|
||||
client.delete_human(name=name)
|
||||
elif option == "persona":
|
||||
persona = ms.get_persona(name=name, user_id=user_id)
|
||||
persona = ms.get_persona(name=name)
|
||||
assert persona is not None, f"Persona {name} does not exist"
|
||||
ms.delete_persona(name=name, user_id=user_id)
|
||||
assert ms.get_persona(name=name, user_id=user_id) is None, f"Persona {name} still exists"
|
||||
elif option == "preset":
|
||||
preset = ms.get_preset(name=name, user_id=user_id)
|
||||
assert preset is not None, f"Preset {name} does not exist"
|
||||
ms.delete_preset(name=name, user_id=user_id)
|
||||
ms.delete_persona(name=name)
|
||||
assert ms.get_persona(name=name) is None, f"Persona {name} still exists"
|
||||
else:
|
||||
raise ValueError(f"Option {option} not implemented")
|
||||
|
||||
|
||||
@@ -43,7 +43,6 @@ class Admin:
|
||||
def create_key(self, user_id: uuid.UUID, key_name: str):
|
||||
payload = {"user_id": str(user_id), "key_name": key_name}
|
||||
response = requests.post(f"{self.base_url}/admin/users/keys", headers=self.headers, json=payload)
|
||||
print(response.json())
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
return CreateAPIKeyResponse(**response.json())
|
||||
@@ -53,7 +52,6 @@ class Admin:
|
||||
response = requests.get(f"{self.base_url}/admin/users/keys", params=params, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise HTTPError(response.json())
|
||||
print(response.text, response.status_code)
|
||||
return GetAPIKeysResponse(**response.json()).api_key_list
|
||||
|
||||
def delete_key(self, api_key: str):
|
||||
@@ -114,6 +112,11 @@ class Admin:
|
||||
source_type = "python"
|
||||
json_schema["name"]
|
||||
|
||||
if "memory" in tags:
|
||||
# special modifications to memory functions
|
||||
# self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
|
||||
source_code = source_code.replace("self.memory", "self.memory.memory")
|
||||
|
||||
# create data
|
||||
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
|
||||
CreateToolRequest(**data) # validate
|
||||
|
||||
@@ -6,23 +6,17 @@ from typing import Dict, List, Optional, Tuple, Union
|
||||
import requests
|
||||
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import BASE_TOOLS, DEFAULT_PRESET
|
||||
from memgpt.constants import BASE_TOOLS, DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
||||
from memgpt.data_sources.connectors import DataConnector
|
||||
from memgpt.data_types import (
|
||||
AgentState,
|
||||
EmbeddingConfig,
|
||||
LLMConfig,
|
||||
Preset,
|
||||
Source,
|
||||
User,
|
||||
)
|
||||
from memgpt.data_types import AgentState, EmbeddingConfig, LLMConfig, Preset, Source
|
||||
from memgpt.functions.functions import parse_source_code
|
||||
from memgpt.functions.schema_generator import generate_schema
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.memory import BaseMemory, ChatMemory, get_memory_functions
|
||||
from memgpt.models.pydantic_models import (
|
||||
HumanModel,
|
||||
JobModel,
|
||||
JobStatus,
|
||||
LLMConfigModel,
|
||||
PersonaModel,
|
||||
PresetModel,
|
||||
SourceModel,
|
||||
@@ -32,6 +26,7 @@ from memgpt.server.rest_api.agents.command import CommandResponse
|
||||
from memgpt.server.rest_api.agents.config import GetAgentResponse
|
||||
from memgpt.server.rest_api.agents.index import CreateAgentResponse, ListAgentsResponse
|
||||
from memgpt.server.rest_api.agents.memory import (
|
||||
ArchivalMemoryObject,
|
||||
GetAgentArchivalMemoryResponse,
|
||||
GetAgentMemoryResponse,
|
||||
InsertAgentArchivalMemoryResponse,
|
||||
@@ -56,6 +51,7 @@ from memgpt.server.rest_api.sources.index import ListSourcesResponse
|
||||
# import pydantic response objects from memgpt.server.rest_api
|
||||
from memgpt.server.rest_api.tools.index import CreateToolRequest, ListToolsResponse
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.utils import get_human_text
|
||||
|
||||
|
||||
def create_client(base_url: Optional[str] = None, token: Optional[str] = None):
|
||||
@@ -242,8 +238,6 @@ class RESTClient(AbstractClient):
|
||||
|
||||
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
|
||||
response = requests.get(f"{self.base_url}/api/agents/{str(agent_id)}/config", headers=self.headers)
|
||||
print(response.text, response.status_code)
|
||||
print(response)
|
||||
if response.status_code == 404:
|
||||
# not found error
|
||||
return False
|
||||
@@ -252,17 +246,24 @@ class RESTClient(AbstractClient):
|
||||
else:
|
||||
raise ValueError(f"Failed to check if agent exists: {response.text}")
|
||||
|
||||
def get_tool(self, tool_name: str):
|
||||
response = requests.get(f"{self.base_url}/api/tools/{tool_name}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get tool: {response.text}")
|
||||
return ToolModel(**response.json())
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
preset: Optional[str] = None, # TODO: this should actually be re-named preset_name
|
||||
persona: Optional[str] = None,
|
||||
human: Optional[str] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
# memory
|
||||
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
|
||||
# tools
|
||||
tools: Optional[List[str]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
|
||||
) -> AgentState:
|
||||
"""
|
||||
Create an agent
|
||||
@@ -274,7 +275,6 @@ class RESTClient(AbstractClient):
|
||||
|
||||
Returns:
|
||||
agent_state (AgentState): State of the the created agent.
|
||||
|
||||
"""
|
||||
if embedding_config or llm_config:
|
||||
raise ValueError("Cannot override embedding_config or llm_config when creating agent via REST API")
|
||||
@@ -286,14 +286,22 @@ class RESTClient(AbstractClient):
|
||||
if include_base_tools:
|
||||
tool_names += BASE_TOOLS
|
||||
|
||||
# add memory tools
|
||||
memory_functions = get_memory_functions(memory)
|
||||
for func_name, func in memory_functions.items():
|
||||
tool = self.create_tool(func, name=func_name, tags=["memory", "memgpt-base"], update=True)
|
||||
tool_names.append(tool.name)
|
||||
|
||||
# TODO: distinguish between name and objects
|
||||
# TODO: add metadata
|
||||
payload = {
|
||||
"config": {
|
||||
"name": name,
|
||||
"preset": preset,
|
||||
"persona": persona,
|
||||
"human": human,
|
||||
"persona": memory.memory["persona"].value,
|
||||
"human": memory.memory["human"].value,
|
||||
"function_names": tool_names,
|
||||
"metadata": metadata,
|
||||
}
|
||||
}
|
||||
response = requests.post(f"{self.base_url}/api/agents", json=payload, headers=self.headers)
|
||||
@@ -322,14 +330,12 @@ class RESTClient(AbstractClient):
|
||||
id=response.agent_state.id,
|
||||
name=response.agent_state.name,
|
||||
user_id=response.agent_state.user_id,
|
||||
preset=response.agent_state.preset,
|
||||
persona=response.agent_state.persona,
|
||||
human=response.agent_state.human,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
state=response.agent_state.state,
|
||||
system=response.agent_state.system,
|
||||
tools=response.agent_state.tools,
|
||||
_metadata=response.agent_state.metadata,
|
||||
# load datetime from timestampe
|
||||
created_at=datetime.datetime.fromtimestamp(response.agent_state.created_at, tz=datetime.timezone.utc),
|
||||
)
|
||||
@@ -352,25 +358,8 @@ class RESTClient(AbstractClient):
|
||||
response_obj = GetAgentResponse(**response.json())
|
||||
return self.get_agent_response_to_state(response_obj)
|
||||
|
||||
## presets
|
||||
# def create_preset(self, preset: Preset) -> CreatePresetResponse:
|
||||
# # TODO should the arg type here be PresetModel, not Preset?
|
||||
# payload = CreatePresetsRequest(
|
||||
# id=str(preset.id),
|
||||
# name=preset.name,
|
||||
# description=preset.description,
|
||||
# system=preset.system,
|
||||
# persona=preset.persona,
|
||||
# human=preset.human,
|
||||
# persona_name=preset.persona_name,
|
||||
# human_name=preset.human_name,
|
||||
# functions_schema=preset.functions_schema,
|
||||
# )
|
||||
# response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
|
||||
# assert response.status_code == 200, f"Failed to create preset: {response.text}"
|
||||
# return CreatePresetResponse(**response.json())
|
||||
|
||||
def get_preset(self, name: str) -> PresetModel:
|
||||
# TODO: remove
|
||||
response = requests.get(f"{self.base_url}/api/presets/{name}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to get preset: {response.text}"
|
||||
return PresetModel(**response.json())
|
||||
@@ -385,6 +374,7 @@ class RESTClient(AbstractClient):
|
||||
tools: Optional[List[ToolModel]] = None,
|
||||
default_tools: bool = True,
|
||||
) -> PresetModel:
|
||||
# TODO: remove
|
||||
"""Create an agent preset
|
||||
|
||||
:param name: Name of the preset
|
||||
@@ -406,7 +396,6 @@ class RESTClient(AbstractClient):
|
||||
schema = []
|
||||
if tools:
|
||||
for tool in tools:
|
||||
print("CUSOTM TOOL", tool.json_schema)
|
||||
schema.append(tool.json_schema)
|
||||
|
||||
# include default tools
|
||||
@@ -427,9 +416,6 @@ class RESTClient(AbstractClient):
|
||||
human_name=human_name,
|
||||
functions_schema=schema,
|
||||
)
|
||||
print(schema)
|
||||
print(human_name, persona_name, system_name, name)
|
||||
print(payload.model_dump())
|
||||
response = requests.post(f"{self.base_url}/api/presets", json=payload.model_dump(), headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to create preset: {response.text}"
|
||||
return CreatePresetResponse(**response.json()).preset
|
||||
@@ -482,7 +468,6 @@ class RESTClient(AbstractClient):
|
||||
response = requests.post(f"{self.base_url}/api/agents/{agent_id}/archival", json={"content": memory}, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to insert archival memory: {response.text}")
|
||||
print(response.json())
|
||||
return InsertAgentArchivalMemoryResponse(**response.json())
|
||||
|
||||
def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID):
|
||||
@@ -518,8 +503,6 @@ class RESTClient(AbstractClient):
|
||||
response = requests.post(f"{self.base_url}/api/humans", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create human: {response.text}")
|
||||
|
||||
print(response.json())
|
||||
return HumanModel(**response.json())
|
||||
|
||||
def list_personas(self) -> ListPersonasResponse:
|
||||
@@ -531,9 +514,24 @@ class RESTClient(AbstractClient):
|
||||
response = requests.post(f"{self.base_url}/api/personas", json=data, headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to create persona: {response.text}")
|
||||
print(response.json())
|
||||
return PersonaModel(**response.json())
|
||||
|
||||
def get_persona(self, name: str) -> PersonaModel:
|
||||
response = requests.get(f"{self.base_url}/api/personas/{name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to get persona: {response.text}")
|
||||
return PersonaModel(**response.json())
|
||||
|
||||
def get_human(self, name: str) -> HumanModel:
|
||||
response = requests.get(f"{self.base_url}/api/humans/{name}", headers=self.headers)
|
||||
if response.status_code == 404:
|
||||
return None
|
||||
elif response.status_code != 200:
|
||||
raise ValueError(f"Failed to get human: {response.text}")
|
||||
return HumanModel(**response.json())
|
||||
|
||||
# sources
|
||||
|
||||
def list_sources(self):
|
||||
@@ -638,7 +636,7 @@ class RESTClient(AbstractClient):
|
||||
json_schema["name"]
|
||||
|
||||
# create data
|
||||
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema}
|
||||
data = {"source_code": source_code, "source_type": source_type, "tags": tags, "json_schema": json_schema, "update": update}
|
||||
try:
|
||||
CreateToolRequest(**data) # validate data
|
||||
except Exception as e:
|
||||
@@ -694,23 +692,12 @@ class LocalClient(AbstractClient):
|
||||
else:
|
||||
self.user_id = uuid.UUID(config.anon_clientid)
|
||||
|
||||
# create user if does not exist
|
||||
ms = MetadataStore(config)
|
||||
self.user = User(id=self.user_id)
|
||||
if ms.get_user(self.user_id):
|
||||
# update user
|
||||
ms.update_user(self.user)
|
||||
else:
|
||||
ms.create_user(self.user)
|
||||
|
||||
# create preset records in metadata store
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(self.user_id, ms)
|
||||
|
||||
self.interface = QueuingInterface(debug=debug)
|
||||
self.server = SyncServer(default_interface_factory=lambda: self.interface)
|
||||
|
||||
# create user if does not exist
|
||||
self.server.create_user({"id": self.user_id}, exists_ok=True)
|
||||
|
||||
# messages
|
||||
def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse:
|
||||
self.interface.clear()
|
||||
@@ -740,14 +727,16 @@ class LocalClient(AbstractClient):
|
||||
def create_agent(
|
||||
self,
|
||||
name: Optional[str] = None,
|
||||
preset: Optional[str] = None, # TODO: this should actually be re-named preset_name
|
||||
persona: Optional[str] = None,
|
||||
human: Optional[str] = None,
|
||||
# model configs
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
# memory
|
||||
memory: BaseMemory = ChatMemory(human=get_human_text(DEFAULT_HUMAN), persona=get_human_text(DEFAULT_PERSONA)),
|
||||
# tools
|
||||
tools: Optional[List[str]] = None,
|
||||
include_base_tools: Optional[bool] = True,
|
||||
# metadata
|
||||
metadata: Optional[Dict] = {"human:": DEFAULT_HUMAN, "persona": DEFAULT_PERSONA},
|
||||
) -> AgentState:
|
||||
if name and self.agent_exists(agent_name=name):
|
||||
raise ValueError(f"Agent with name {name} already exists (user_id={self.user_id})")
|
||||
@@ -759,23 +748,35 @@ class LocalClient(AbstractClient):
|
||||
if include_base_tools:
|
||||
tool_names += BASE_TOOLS
|
||||
|
||||
# add memory tools
|
||||
memory_functions = get_memory_functions(memory)
|
||||
for func_name, func in memory_functions.items():
|
||||
tool = self.create_tool(func, name=func_name, tags=["memory", "memgpt-base"])
|
||||
tool_names.append(tool.name)
|
||||
|
||||
self.interface.clear()
|
||||
|
||||
# create agent
|
||||
agent_state = self.server.create_agent(
|
||||
user_id=self.user_id,
|
||||
name=name,
|
||||
preset=preset,
|
||||
persona=persona,
|
||||
human=human,
|
||||
memory=memory,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
tools=tool_names,
|
||||
metadata=metadata,
|
||||
)
|
||||
return agent_state
|
||||
|
||||
def rename_agent(self, agent_id: uuid.UUID, new_name: str):
|
||||
# TODO: check valid name
|
||||
agent_state = self.server.rename_agent(user_id=self.user_id, agent_id=agent_id, new_agent_name=new_name)
|
||||
return agent_state
|
||||
|
||||
def delete_agent(self, agent_id: uuid.UUID):
|
||||
self.server.delete_agent(user_id=self.user_id, agent_id=agent_id)
|
||||
|
||||
def get_agent_config(self, agent_id: str) -> AgentState:
|
||||
def get_agent(self, agent_id: uuid.UUID) -> AgentState:
|
||||
self.interface.clear()
|
||||
return self.server.get_agent_config(user_id=self.user_id, agent_id=agent_id)
|
||||
|
||||
@@ -803,6 +804,19 @@ class LocalClient(AbstractClient):
|
||||
|
||||
# agent interactions
|
||||
|
||||
def send_message(self, agent_id: uuid.UUID, message: str, role: str, stream: Optional[bool] = False) -> UserMessageResponse:
|
||||
self.interface.clear()
|
||||
if role == "system":
|
||||
usage = self.server.system_message(user_id=self.user_id, agent_id=agent_id, message=message)
|
||||
elif role == "user":
|
||||
usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
|
||||
else:
|
||||
raise ValueError(f"Role {role} not supported")
|
||||
if self.auto_save:
|
||||
self.save()
|
||||
else:
|
||||
return UserMessageResponse(messages=self.interface.to_list(), usage=usage)
|
||||
|
||||
def user_message(self, agent_id: str, message: str) -> Union[List[Dict], Tuple[List[Dict], int]]:
|
||||
self.interface.clear()
|
||||
usage = self.server.user_message(user_id=self.user_id, agent_id=agent_id, message=message)
|
||||
@@ -820,36 +834,37 @@ class LocalClient(AbstractClient):
|
||||
|
||||
# archival memory
|
||||
|
||||
def get_agent_archival_memory(
|
||||
self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
|
||||
):
|
||||
_, archival_json_records = self.server.get_agent_archival_cursor(
|
||||
user_id=self.user_id,
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
)
|
||||
return archival_json_records
|
||||
|
||||
# messages
|
||||
|
||||
# humans / personas
|
||||
|
||||
def list_humans(self, user_id: uuid.UUID):
|
||||
return self.server.list_humans(user_id=user_id if user_id else self.user_id)
|
||||
def create_human(self, name: str, human: str):
|
||||
return self.server.add_human(HumanModel(name=name, text=human, user_id=self.user_id))
|
||||
|
||||
def get_human(self, name: str, user_id: uuid.UUID):
|
||||
return self.server.get_human(name=name, user_id=user_id)
|
||||
def create_persona(self, name: str, persona: str):
|
||||
return self.server.add_persona(PersonaModel(name=name, text=persona, user_id=self.user_id))
|
||||
|
||||
def add_human(self, human: HumanModel):
|
||||
return self.server.add_human(human=human)
|
||||
def list_humans(self):
|
||||
return self.server.list_humans(user_id=self.user_id if self.user_id else self.user_id)
|
||||
|
||||
def get_human(self, name: str):
|
||||
return self.server.get_human(name=name, user_id=self.user_id)
|
||||
|
||||
def update_human(self, human: HumanModel):
|
||||
return self.server.update_human(human=human)
|
||||
|
||||
def delete_human(self, name: str, user_id: uuid.UUID):
|
||||
return self.server.delete_human(name, user_id)
|
||||
def delete_human(self, name: str):
|
||||
return self.server.delete_human(name, self.user_id)
|
||||
|
||||
def list_personas(self):
|
||||
return self.server.list_personas(user_id=self.user_id)
|
||||
|
||||
def get_persona(self, name: str):
|
||||
return self.server.get_persona(name=name, user_id=self.user_id)
|
||||
|
||||
def update_persona(self, persona: PersonaModel):
|
||||
return self.server.update_persona(persona=persona)
|
||||
|
||||
def delete_persona(self, name: str):
|
||||
return self.server.delete_persona(name, self.user_id)
|
||||
|
||||
# tools
|
||||
def create_tool(
|
||||
@@ -879,6 +894,11 @@ class LocalClient(AbstractClient):
|
||||
source_type = "python"
|
||||
tool_name = json_schema["name"]
|
||||
|
||||
if "memory" in tags:
|
||||
# special modifications to memory functions
|
||||
# self.memory -> self.memory.memory, since Agent.memory.memory needs to be modified (not BaseMemory.memory)
|
||||
source_code = source_code.replace("self.memory", "self.memory.memory")
|
||||
|
||||
# check if already exists:
|
||||
existing_tool = self.server.ms.get_tool(tool_name, self.user_id)
|
||||
if existing_tool:
|
||||
@@ -924,3 +944,48 @@ class LocalClient(AbstractClient):
|
||||
|
||||
def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID):
|
||||
self.server.attach_source_to_agent(user_id=self.user_id, source_id=source_id, agent_id=agent_id)
|
||||
|
||||
def get_agent_archival_memory(
|
||||
self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
|
||||
):
|
||||
self.interface.clear()
|
||||
# TODO need to add support for non-postgres here
|
||||
# chroma will throw:
|
||||
# raise ValueError("Cannot run get_all_cursor with chroma")
|
||||
_, archival_json_records = self.server.get_agent_archival_cursor(
|
||||
user_id=self.user_id,
|
||||
agent_id=agent_id,
|
||||
after=after,
|
||||
before=before,
|
||||
limit=limit,
|
||||
)
|
||||
archival_memory_objects = [ArchivalMemoryObject(id=passage["id"], contents=passage["text"]) for passage in archival_json_records]
|
||||
return GetAgentArchivalMemoryResponse(archival_memory=archival_memory_objects)
|
||||
|
||||
def insert_archival_memory(self, agent_id: uuid.UUID, memory: str) -> GetAgentArchivalMemoryResponse:
|
||||
memory_ids = self.server.insert_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_contents=memory)
|
||||
return InsertAgentArchivalMemoryResponse(ids=memory_ids)
|
||||
|
||||
def delete_archival_memory(self, agent_id: uuid.UUID, memory_id: uuid.UUID):
|
||||
self.server.delete_archival_memory(user_id=self.user_id, agent_id=agent_id, memory_id=memory_id)
|
||||
|
||||
def get_messages(
|
||||
self, agent_id: uuid.UUID, before: Optional[uuid.UUID] = None, after: Optional[uuid.UUID] = None, limit: Optional[int] = 1000
|
||||
) -> GetAgentMessagesResponse:
|
||||
self.interface.clear()
|
||||
[_, messages] = self.server.get_agent_recall_cursor(
|
||||
user_id=self.user_id, agent_id=agent_id, before=before, limit=limit, reverse=True
|
||||
)
|
||||
return GetAgentMessagesResponse(messages=messages)
|
||||
|
||||
def list_models(self) -> ListModelsResponse:
|
||||
|
||||
llm_config = LLMConfigModel(
|
||||
model=self.server.server_llm_config.model,
|
||||
model_endpoint=self.server.server_llm_config.model_endpoint,
|
||||
model_endpoint_type=self.server.server_llm_config.model_endpoint_type,
|
||||
model_wrapper=self.server.server_llm_config.model_wrapper,
|
||||
context_window=self.server.server_llm_config.context_window,
|
||||
)
|
||||
|
||||
return ListModelsResponse(models=[llm_config])
|
||||
|
||||
@@ -38,7 +38,7 @@ class MemGPTConfig:
|
||||
anon_clientid: str = str(uuid.UUID(int=0))
|
||||
|
||||
# preset
|
||||
preset: str = DEFAULT_PRESET
|
||||
preset: str = DEFAULT_PRESET # TODO: rename to system prompt
|
||||
|
||||
# persona parameters
|
||||
persona: str = DEFAULT_PERSONA
|
||||
|
||||
@@ -24,8 +24,6 @@ DEFAULT_PRESET = "memgpt_chat"
|
||||
# Tools
|
||||
BASE_TOOLS = [
|
||||
"send_message",
|
||||
"core_memory_replace",
|
||||
"core_memory_append",
|
||||
"pause_heartbeats",
|
||||
"conversation_search",
|
||||
"conversation_search_date",
|
||||
|
||||
@@ -763,22 +763,14 @@ class AgentState:
|
||||
# system prompt
|
||||
system: str,
|
||||
# config
|
||||
persona: str, # the filename where the persona was originally sourced from # TODO: remove
|
||||
human: str, # the filename where the human was originally sourced from # TODO: remove
|
||||
llm_config: LLMConfig,
|
||||
embedding_config: EmbeddingConfig,
|
||||
preset: str, # TODO: remove
|
||||
# (in-context) state contains:
|
||||
# persona: str # the current persona text
|
||||
# human: str # the current human text
|
||||
# system: str, # system prompt (not required if initializing with a preset)
|
||||
# functions: dict, # schema definitions ONLY (function code linked at runtime)
|
||||
# messages: List[dict], # in-context messages
|
||||
id: Optional[uuid.UUID] = None,
|
||||
state: Optional[dict] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
# messages (TODO: implement this)
|
||||
# _metadata: Optional[dict] = None,
|
||||
_metadata: Optional[dict] = None,
|
||||
):
|
||||
if id is None:
|
||||
self.id = uuid.uuid4()
|
||||
@@ -792,11 +784,8 @@ class AgentState:
|
||||
self.name = name
|
||||
assert self.name, f"AgentState name must be a non-empty string"
|
||||
self.user_id = user_id
|
||||
self.preset = preset
|
||||
# The INITIAL values of the persona and human
|
||||
# The values inside self.state['persona'], self.state['human'] are the CURRENT values
|
||||
self.persona = persona
|
||||
self.human = human
|
||||
|
||||
self.llm_config = llm_config
|
||||
self.embedding_config = embedding_config
|
||||
@@ -811,9 +800,10 @@ class AgentState:
|
||||
|
||||
# system
|
||||
self.system = system
|
||||
assert self.system is not None, f"Must provide system prompt, cannot be None"
|
||||
|
||||
# metadata
|
||||
# self._metadata = _metadata
|
||||
self._metadata = _metadata
|
||||
|
||||
|
||||
class Source:
|
||||
@@ -866,6 +856,7 @@ class Token:
|
||||
|
||||
|
||||
class Preset(BaseModel):
|
||||
# TODO: remove Preset
|
||||
name: str = Field(..., description="The name of the preset.")
|
||||
id: uuid.UUID = Field(default_factory=uuid.uuid4, description="The unique identifier of the preset.")
|
||||
user_id: Optional[uuid.UUID] = Field(None, description="The unique identifier of the user who created the preset.")
|
||||
|
||||
@@ -56,39 +56,6 @@ def pause_heartbeats(self: Agent, minutes: int) -> Optional[str]:
|
||||
pause_heartbeats.__doc__ = pause_heartbeats_docstring
|
||||
|
||||
|
||||
def core_memory_append(self: Agent, name: str, content: str) -> Optional[str]:
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.memory.edit_append(name, content)
|
||||
self.rebuild_memory()
|
||||
return None
|
||||
|
||||
|
||||
def core_memory_replace(self: Agent, name: str, old_content: str, new_content: str) -> Optional[str]:
|
||||
"""
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
old_content (str): String to replace. Must be an exact match.
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.memory.edit_replace(name, old_content, new_content)
|
||||
self.rebuild_memory()
|
||||
return None
|
||||
|
||||
|
||||
def conversation_search(self: Agent, query: str, page: Optional[int] = 0) -> Optional[str]:
|
||||
"""
|
||||
Search prior conversation history using case-insensitive string matching.
|
||||
|
||||
@@ -2,7 +2,6 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
import sys
|
||||
import warnings
|
||||
from textwrap import dedent # remove indentation
|
||||
from types import ModuleType
|
||||
|
||||
@@ -104,97 +103,3 @@ def load_function_file(filepath: str) -> dict:
|
||||
# load all functions in the module
|
||||
function_dict = load_function_set(module)
|
||||
return function_dict
|
||||
|
||||
|
||||
def load_all_function_sets(merge: bool = True, ignore_duplicates: bool = True) -> dict:
|
||||
from memgpt.utils import printd
|
||||
|
||||
# functions/examples/*.py
|
||||
scripts_dir = os.path.dirname(os.path.abspath(__file__)) # Get the directory of the current script
|
||||
function_sets_dir = os.path.join(scripts_dir, "function_sets") # Path to the function_sets directory
|
||||
# List all .py files in the directory (excluding __init__.py)
|
||||
example_module_files = [f for f in os.listdir(function_sets_dir) if f.endswith(".py") and f != "__init__.py"]
|
||||
|
||||
# ~/.memgpt/functions/*.py
|
||||
# create if missing
|
||||
if not os.path.exists(USER_FUNCTIONS_DIR):
|
||||
os.makedirs(USER_FUNCTIONS_DIR)
|
||||
user_module_files = [f for f in os.listdir(USER_FUNCTIONS_DIR) if f.endswith(".py") and f != "__init__.py"]
|
||||
|
||||
# combine them both (pull from both examples and user-provided)
|
||||
# all_module_files = example_module_files + user_module_files
|
||||
|
||||
# Add user_scripts_dir to sys.path
|
||||
if USER_FUNCTIONS_DIR not in sys.path:
|
||||
sys.path.append(USER_FUNCTIONS_DIR)
|
||||
|
||||
schemas_and_functions = {}
|
||||
for dir_path, module_files in [(function_sets_dir, example_module_files), (USER_FUNCTIONS_DIR, user_module_files)]:
|
||||
for file in module_files:
|
||||
tags = []
|
||||
module_name = file[:-3] # Remove '.py' from filename
|
||||
if dir_path == USER_FUNCTIONS_DIR:
|
||||
# For user scripts, adjust the module name appropriately
|
||||
module_full_path = os.path.join(dir_path, file)
|
||||
printd(f"Loading user function set from '{module_full_path}'")
|
||||
try:
|
||||
spec = importlib.util.spec_from_file_location(module_name, module_full_path)
|
||||
module = importlib.util.module_from_spec(spec)
|
||||
spec.loader.exec_module(module)
|
||||
except ModuleNotFoundError as e:
|
||||
# Handle missing module imports
|
||||
missing_package = str(e).split("'")[1] # Extract the name of the missing package
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{module_full_path}'!")
|
||||
printd(
|
||||
f"'{file}' imports '{missing_package}', but '{missing_package}' is not installed locally - install python package '{missing_package}' to link functions from '{file}' to MemGPT."
|
||||
)
|
||||
continue
|
||||
except SyntaxError as e:
|
||||
# Handle syntax errors in the module
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}' due to a syntax error: {e}")
|
||||
continue
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python file '{file}': {e}")
|
||||
continue
|
||||
else:
|
||||
# For built-in scripts, use the existing method
|
||||
full_module_name = f"memgpt.functions.function_sets.{module_name}"
|
||||
tags.append(f"memgpt-{module_name}")
|
||||
try:
|
||||
module = importlib.import_module(full_module_name)
|
||||
except Exception as e:
|
||||
# Handle other general exceptions
|
||||
printd(f"{CLI_WARNING_PREFIX}skipped loading python module '{full_module_name}': {e}")
|
||||
continue
|
||||
|
||||
try:
|
||||
# Load the function set
|
||||
function_set = load_function_set(module)
|
||||
# Add the metadata tags
|
||||
for k, v in function_set.items():
|
||||
# print(function_set)
|
||||
v["tags"] = tags
|
||||
schemas_and_functions[module_name] = function_set
|
||||
except ValueError as e:
|
||||
err = f"Error loading function set '{module_name}': {e}"
|
||||
printd(err)
|
||||
warnings.warn(err)
|
||||
|
||||
if merge:
|
||||
# Put all functions from all sets into the same level dict
|
||||
merged_functions = {}
|
||||
for set_name, function_set in schemas_and_functions.items():
|
||||
for function_name, function_info in function_set.items():
|
||||
if function_name in merged_functions:
|
||||
err_msg = f"Duplicate function name '{function_name}' found in function set '{set_name}'"
|
||||
if ignore_duplicates:
|
||||
warnings.warn(err_msg, category=UserWarning, stacklevel=2)
|
||||
else:
|
||||
raise ValueError(err_msg)
|
||||
else:
|
||||
merged_functions[function_name] = function_info
|
||||
return merged_functions
|
||||
else:
|
||||
# Nested dict where the top level is organized by the function set name
|
||||
return schemas_and_functions
|
||||
|
||||
274
memgpt/memory.py
274
memgpt/memory.py
@@ -3,6 +3,8 @@ import uuid
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Optional, Tuple, Union
|
||||
|
||||
from pydantic import BaseModel, validator
|
||||
|
||||
from memgpt.constants import MESSAGE_SUMMARY_REQUEST_ACK, MESSAGE_SUMMARY_WARNING_FRAC
|
||||
from memgpt.data_types import AgentState, Message, Passage
|
||||
from memgpt.embeddings import embedding_model, parse_and_chunk_text, query_embedding
|
||||
@@ -16,96 +18,214 @@ from memgpt.utils import (
|
||||
validate_date_format,
|
||||
)
|
||||
|
||||
# from llama_index import Document
|
||||
# from llama_index.node_parser import SimpleNodeParser
|
||||
|
||||
class MemoryModule(BaseModel):
|
||||
"""Base class for memory modules"""
|
||||
|
||||
description: Optional[str] = None
|
||||
limit: int = 2000
|
||||
value: Optional[Union[List[str], str]] = None
|
||||
|
||||
def __setattr__(self, name, value):
|
||||
"""Run validation if self.value is updated"""
|
||||
super().__setattr__(name, value)
|
||||
if name == "value":
|
||||
# run validation
|
||||
self.__class__.validate(self.dict(exclude_unset=True))
|
||||
|
||||
@validator("value", always=True)
|
||||
def check_value_length(cls, v, values):
|
||||
if v is not None:
|
||||
# Fetching the limit from the values dictionary
|
||||
limit = values.get("limit", 2000) # Default to 2000 if limit is not yet set
|
||||
|
||||
# Check if the value exceeds the limit
|
||||
if isinstance(v, str):
|
||||
length = len(v)
|
||||
elif isinstance(v, list):
|
||||
length = sum(len(item) for item in v)
|
||||
else:
|
||||
raise ValueError("Value must be either a string or a list of strings.")
|
||||
|
||||
if length > limit:
|
||||
error_msg = f"Edit failed: Exceeds {limit} character limit (requested {length})."
|
||||
# TODO: add archival memory error?
|
||||
raise ValueError(error_msg)
|
||||
return v
|
||||
|
||||
def __len__(self):
|
||||
return len(str(self))
|
||||
|
||||
def __str__(self) -> str:
|
||||
if isinstance(self.value, list):
|
||||
return ",".join(self.value)
|
||||
elif isinstance(self.value, str):
|
||||
return self.value
|
||||
else:
|
||||
return ""
|
||||
|
||||
|
||||
class CoreMemory(object):
|
||||
"""Held in-context inside the system message
|
||||
class BaseMemory:
|
||||
|
||||
Core Memory: Refers to the system block, which provides essential, foundational context to the AI.
|
||||
This includes the persona information, essential user details,
|
||||
and any other baseline data you deem necessary for the AI's basic functioning.
|
||||
"""
|
||||
|
||||
def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True):
|
||||
self.persona = persona
|
||||
self.human = human
|
||||
self.persona_char_limit = persona_char_limit
|
||||
self.human_char_limit = human_char_limit
|
||||
|
||||
# affects the error message the AI will see on overflow inserts
|
||||
self.archival_memory_exists = archival_memory_exists
|
||||
|
||||
def __repr__(self) -> str:
|
||||
return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
|
||||
|
||||
def to_dict(self):
|
||||
return {
|
||||
"persona": self.persona,
|
||||
"human": self.human,
|
||||
}
|
||||
def __init__(self):
|
||||
self.memory = {}
|
||||
|
||||
@classmethod
|
||||
def load(cls, state):
|
||||
return cls(state["persona"], state["human"])
|
||||
def load(cls, state: dict):
|
||||
"""Load memory from dictionary object"""
|
||||
obj = cls()
|
||||
for key, value in state.items():
|
||||
obj.memory[key] = MemoryModule(**value)
|
||||
return obj
|
||||
|
||||
def edit_persona(self, new_persona):
|
||||
if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
|
||||
error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})."
|
||||
if self.archival_memory_exists:
|
||||
error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
|
||||
raise ValueError(error_msg)
|
||||
def __str__(self) -> str:
|
||||
"""Representation of the memory in-context"""
|
||||
section_strs = []
|
||||
for section, module in self.memory.items():
|
||||
section_strs.append(f'<{section} characters="{len(module)}/{module.limit}">\n{module.value}\n</{section}>')
|
||||
return "\n".join(section_strs)
|
||||
|
||||
self.persona = new_persona
|
||||
return len(self.persona)
|
||||
def to_dict(self):
|
||||
"""Convert to dictionary representation"""
|
||||
return {key: value.dict() for key, value in self.memory.items()}
|
||||
|
||||
def edit_human(self, new_human):
|
||||
if self.human_char_limit and len(new_human) > self.human_char_limit:
|
||||
error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})."
|
||||
if self.archival_memory_exists:
|
||||
error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
|
||||
raise ValueError(error_msg)
|
||||
|
||||
self.human = new_human
|
||||
return len(self.human)
|
||||
class ChatMemory(BaseMemory):
|
||||
|
||||
def edit(self, field, content):
|
||||
if field == "persona":
|
||||
return self.edit_persona(content)
|
||||
elif field == "human":
|
||||
return self.edit_human(content)
|
||||
else:
|
||||
raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
def __init__(self, persona: str, human: str, limit: int = 2000):
|
||||
self.memory = {
|
||||
"persona": MemoryModule(name="persona", value=persona, limit=limit),
|
||||
"human": MemoryModule(name="human", value=human, limit=limit),
|
||||
}
|
||||
|
||||
def edit_append(self, field, content, sep="\n"):
|
||||
if field == "persona":
|
||||
new_content = self.persona + sep + content
|
||||
return self.edit_persona(new_content)
|
||||
elif field == "human":
|
||||
new_content = self.human + sep + content
|
||||
return self.edit_human(new_content)
|
||||
else:
|
||||
raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
def core_memory_append(self, name: str, content: str) -> Optional[str]:
|
||||
"""
|
||||
Append to the contents of core memory.
|
||||
|
||||
def edit_replace(self, field, old_content, new_content):
|
||||
if len(old_content) == 0:
|
||||
raise ValueError("old_content cannot be an empty string (must specify old_content to replace)")
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
if field == "persona":
|
||||
if old_content in self.persona:
|
||||
new_persona = self.persona.replace(old_content, new_content)
|
||||
return self.edit_persona(new_persona)
|
||||
else:
|
||||
raise ValueError("Content not found in persona (make sure to use exact string)")
|
||||
elif field == "human":
|
||||
if old_content in self.human:
|
||||
new_human = self.human.replace(old_content, new_content)
|
||||
return self.edit_human(new_human)
|
||||
else:
|
||||
raise ValueError("Content not found in human (make sure to use exact string)")
|
||||
else:
|
||||
raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.memory[name].value += "\n" + content
|
||||
return None
|
||||
|
||||
def core_memory_replace(self, name: str, old_content: str, new_content: str) -> Optional[str]:
|
||||
"""
|
||||
Replace the contents of core memory. To delete memories, use an empty string for new_content.
|
||||
|
||||
Args:
|
||||
name (str): Section of the memory to be edited (persona or human).
|
||||
old_content (str): String to replace. Must be an exact match.
|
||||
new_content (str): Content to write to the memory. All unicode (including emojis) are supported.
|
||||
|
||||
Returns:
|
||||
Optional[str]: None is always returned as this function does not produce a response.
|
||||
"""
|
||||
self.memory[name].value = self.memory[name].value.replace(old_content, new_content)
|
||||
return None
|
||||
|
||||
|
||||
def get_memory_functions(cls: BaseMemory) -> List[callable]:
|
||||
"""Get memory functions for a memory class"""
|
||||
functions = {}
|
||||
for func_name in dir(cls):
|
||||
if func_name.startswith("_") or func_name in ["load", "to_dict"]: # skip base functions
|
||||
continue
|
||||
func = getattr(cls, func_name)
|
||||
if callable(func):
|
||||
functions[func_name] = func
|
||||
return functions
|
||||
|
||||
|
||||
# class CoreMemory(object):
|
||||
# """Held in-context inside the system message
|
||||
#
|
||||
# Core Memory: Refers to the system block, which provides essential, foundational context to the AI.
|
||||
# This includes the persona information, essential user details,
|
||||
# and any other baseline data you deem necessary for the AI's basic functioning.
|
||||
# """
|
||||
#
|
||||
# def __init__(self, persona=None, human=None, persona_char_limit=None, human_char_limit=None, archival_memory_exists=True):
|
||||
# self.persona = persona
|
||||
# self.human = human
|
||||
# self.persona_char_limit = persona_char_limit
|
||||
# self.human_char_limit = human_char_limit
|
||||
#
|
||||
# # affects the error message the AI will see on overflow inserts
|
||||
# self.archival_memory_exists = archival_memory_exists
|
||||
#
|
||||
# def __repr__(self) -> str:
|
||||
# return f"\n### CORE MEMORY ###" + f"\n=== Persona ===\n{self.persona}" + f"\n\n=== Human ===\n{self.human}"
|
||||
#
|
||||
# def to_dict(self):
|
||||
# return {
|
||||
# "persona": self.persona,
|
||||
# "human": self.human,
|
||||
# }
|
||||
#
|
||||
# @classmethod
|
||||
# def load(cls, state):
|
||||
# return cls(state["persona"], state["human"])
|
||||
#
|
||||
# def edit_persona(self, new_persona):
|
||||
# if self.persona_char_limit and len(new_persona) > self.persona_char_limit:
|
||||
# error_msg = f"Edit failed: Exceeds {self.persona_char_limit} character limit (requested {len(new_persona)})."
|
||||
# if self.archival_memory_exists:
|
||||
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'persona' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
|
||||
# raise ValueError(error_msg)
|
||||
#
|
||||
# self.persona = new_persona
|
||||
# return len(self.persona)
|
||||
#
|
||||
# def edit_human(self, new_human):
|
||||
# if self.human_char_limit and len(new_human) > self.human_char_limit:
|
||||
# error_msg = f"Edit failed: Exceeds {self.human_char_limit} character limit (requested {len(new_human)})."
|
||||
# if self.archival_memory_exists:
|
||||
# error_msg = f"{error_msg} Consider summarizing existing core memories in 'human' and/or moving lower priority content to archival memory to free up space in core memory, then trying again."
|
||||
# raise ValueError(error_msg)
|
||||
#
|
||||
# self.human = new_human
|
||||
# return len(self.human)
|
||||
#
|
||||
# def edit(self, field, content):
|
||||
# if field == "persona":
|
||||
# return self.edit_persona(content)
|
||||
# elif field == "human":
|
||||
# return self.edit_human(content)
|
||||
# else:
|
||||
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
#
|
||||
# def edit_append(self, field, content, sep="\n"):
|
||||
# if field == "persona":
|
||||
# new_content = self.persona + sep + content
|
||||
# return self.edit_persona(new_content)
|
||||
# elif field == "human":
|
||||
# new_content = self.human + sep + content
|
||||
# return self.edit_human(new_content)
|
||||
# else:
|
||||
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
#
|
||||
# def edit_replace(self, field, old_content, new_content):
|
||||
# if len(old_content) == 0:
|
||||
# raise ValueError("old_content cannot be an empty string (must specify old_content to replace)")
|
||||
#
|
||||
# if field == "persona":
|
||||
# if old_content in self.persona:
|
||||
# new_persona = self.persona.replace(old_content, new_content)
|
||||
# return self.edit_persona(new_persona)
|
||||
# else:
|
||||
# raise ValueError("Content not found in persona (make sure to use exact string)")
|
||||
# elif field == "human":
|
||||
# if old_content in self.human:
|
||||
# new_human = self.human.replace(old_content, new_content)
|
||||
# return self.edit_human(new_human)
|
||||
# else:
|
||||
# raise ValueError("Content not found in human (make sure to use exact string)")
|
||||
# else:
|
||||
# raise KeyError(f'No memory section named {field} (must be either "persona" or "human")')
|
||||
|
||||
|
||||
def _format_summary_history(message_history: List[Message]):
|
||||
|
||||
@@ -175,10 +175,7 @@ class AgentModel(Base):
|
||||
id = Column(CommonUUID, primary_key=True, default=uuid.uuid4)
|
||||
user_id = Column(CommonUUID, nullable=False)
|
||||
name = Column(String, nullable=False)
|
||||
persona = Column(String)
|
||||
human = Column(String)
|
||||
system = Column(String)
|
||||
preset = Column(String)
|
||||
created_at = Column(DateTime(timezone=True), server_default=func.now())
|
||||
|
||||
# configs
|
||||
@@ -187,6 +184,7 @@ class AgentModel(Base):
|
||||
|
||||
# state
|
||||
state = Column(JSON)
|
||||
_metadata = Column(JSON)
|
||||
|
||||
# tools
|
||||
tools = Column(JSON)
|
||||
@@ -199,15 +197,13 @@ class AgentModel(Base):
|
||||
id=self.id,
|
||||
user_id=self.user_id,
|
||||
name=self.name,
|
||||
persona=self.persona,
|
||||
human=self.human,
|
||||
preset=self.preset,
|
||||
created_at=self.created_at,
|
||||
llm_config=self.llm_config,
|
||||
embedding_config=self.embedding_config,
|
||||
state=self.state,
|
||||
tools=self.tools,
|
||||
system=self.system,
|
||||
_metadata=self._metadata,
|
||||
)
|
||||
|
||||
|
||||
@@ -739,17 +735,21 @@ class MetadataStore:
|
||||
@enforce_types
|
||||
def add_human(self, human: HumanModel):
|
||||
with self.session_maker() as session:
|
||||
if self.get_human(human.name, human.user_id):
|
||||
raise ValueError(f"Human with name {human.name} already exists for user_id {human.user_id}")
|
||||
session.add(human)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def add_persona(self, persona: PersonaModel):
|
||||
with self.session_maker() as session:
|
||||
if self.get_persona(persona.name, persona.user_id):
|
||||
raise ValueError(f"Persona with name {persona.name} already exists for user_id {persona.user_id}")
|
||||
session.add(persona)
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
def add_preset(self, preset: PresetModel):
|
||||
def add_preset(self, preset: PresetModel): # TODO: remove
|
||||
with self.session_maker() as session:
|
||||
session.add(preset)
|
||||
session.commit()
|
||||
|
||||
@@ -98,9 +98,6 @@ class AgentStateModel(BaseModel):
|
||||
created_at: int = Field(..., description="The unix timestamp of when the agent was created.")
|
||||
|
||||
# preset information
|
||||
preset: str = Field(..., description="The preset used by the agent.")
|
||||
persona: str = Field(..., description="The persona used by the agent.")
|
||||
human: str = Field(..., description="The human used by the agent.")
|
||||
tools: List[str] = Field(..., description="The tools used by the agent.")
|
||||
system: str = Field(..., description="The system prompt used by the agent.")
|
||||
# functions_schema: List[Dict] = Field(..., description="The functions schema used by the agent.")
|
||||
@@ -111,6 +108,7 @@ class AgentStateModel(BaseModel):
|
||||
|
||||
# agent state
|
||||
state: Optional[Dict] = Field(None, description="The state of the agent.")
|
||||
metadata: Optional[Dict] = Field(None, description="The metadata of the agent.")
|
||||
|
||||
|
||||
class CoreMemory(BaseModel):
|
||||
|
||||
@@ -2,23 +2,14 @@ import importlib
|
||||
import inspect
|
||||
import os
|
||||
import uuid
|
||||
from typing import List
|
||||
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA
|
||||
from memgpt.data_types import AgentState, Preset
|
||||
from memgpt.functions.functions import load_all_function_sets, load_function_set
|
||||
from memgpt.functions.functions import load_function_set
|
||||
from memgpt.interface import AgentInterface
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel, ToolModel
|
||||
from memgpt.presets.utils import load_all_presets, load_yaml_file
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.utils import (
|
||||
get_human_text,
|
||||
get_persona_text,
|
||||
list_human_files,
|
||||
list_persona_files,
|
||||
printd,
|
||||
)
|
||||
from memgpt.presets.utils import load_all_presets
|
||||
from memgpt.utils import list_human_files, list_persona_files, printd
|
||||
|
||||
available_presets = load_all_presets()
|
||||
preset_options = list(available_presets.keys())
|
||||
@@ -88,145 +79,13 @@ def add_default_humans_and_personas(user_id: uuid.UUID, ms: MetadataStore):
|
||||
printd(f"Human '{name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
human = HumanModel(name=name, text=text, user_id=user_id)
|
||||
print(human, user_id)
|
||||
ms.add_human(human)
|
||||
|
||||
|
||||
def create_preset_from_file(filename: str, name: str, user_id: uuid.UUID, ms: MetadataStore) -> Preset:
|
||||
preset_config = load_yaml_file(filename)
|
||||
preset_system_prompt = preset_config["system_prompt"]
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
if ms.get_preset(user_id=user_id, name=name) is not None:
|
||||
printd(f"Preset '{name}' already exists for user '{user_id}'")
|
||||
return ms.get_preset(user_id=user_id, name=name)
|
||||
|
||||
preset = Preset(
|
||||
user_id=user_id,
|
||||
name=name,
|
||||
system=gpt_system.get_system_text(preset_system_prompt),
|
||||
persona=get_persona_text(DEFAULT_PERSONA),
|
||||
human=get_human_text(DEFAULT_HUMAN),
|
||||
persona_name=DEFAULT_PERSONA,
|
||||
human_name=DEFAULT_HUMAN,
|
||||
functions_schema=functions_schema,
|
||||
)
|
||||
ms.create_preset(preset)
|
||||
return preset
|
||||
|
||||
|
||||
def load_preset(preset_name: str, user_id: uuid.UUID):
|
||||
preset_config = available_presets[preset_name]
|
||||
preset_system_prompt = preset_config["system_prompt"]
|
||||
preset_function_set_names = preset_config["functions"]
|
||||
functions_schema = generate_functions_json(preset_function_set_names)
|
||||
|
||||
preset = Preset(
|
||||
user_id=user_id,
|
||||
name=preset_name,
|
||||
system=gpt_system.get_system_text(preset_system_prompt),
|
||||
persona=get_persona_text(DEFAULT_PERSONA),
|
||||
persona_name=DEFAULT_PERSONA,
|
||||
human=get_human_text(DEFAULT_HUMAN),
|
||||
human_name=DEFAULT_HUMAN,
|
||||
functions_schema=functions_schema,
|
||||
)
|
||||
return preset
|
||||
|
||||
|
||||
def add_default_presets(user_id: uuid.UUID, ms: MetadataStore):
|
||||
"""Add the default presets to the metadata store"""
|
||||
# make sure humans/personas added
|
||||
add_default_humans_and_personas(user_id=user_id, ms=ms)
|
||||
|
||||
# make sure base functions added
|
||||
# TODO: pull from functions instead
|
||||
add_default_tools(user_id=user_id, ms=ms)
|
||||
|
||||
# add default presets
|
||||
for preset_name in preset_options:
|
||||
if ms.get_preset(user_id=user_id, name=preset_name) is not None:
|
||||
printd(f"Preset '{preset_name}' already exists for user '{user_id}'")
|
||||
continue
|
||||
|
||||
preset = load_preset(preset_name, user_id)
|
||||
ms.create_preset(preset)
|
||||
|
||||
|
||||
def generate_functions_json(preset_functions: List[str]):
|
||||
"""
|
||||
Generate JSON schema for the functions based on what is locally available.
|
||||
|
||||
TODO: store function definitions in the DB, instead of locally
|
||||
"""
|
||||
# Available functions is a mapping from:
|
||||
# function_name -> {
|
||||
# json_schema: schema
|
||||
# python_function: function
|
||||
# }
|
||||
available_functions = load_all_function_sets()
|
||||
# Filter down the function set based on what the preset requested
|
||||
preset_function_set = {}
|
||||
for f_name in preset_functions:
|
||||
if f_name not in available_functions:
|
||||
raise ValueError(f"Function '{f_name}' was specified in preset, but is not in function library:\n{available_functions.keys()}")
|
||||
preset_function_set[f_name] = available_functions[f_name]
|
||||
assert len(preset_functions) == len(preset_function_set)
|
||||
preset_function_set_schemas = [f_dict["json_schema"] for f_name, f_dict in preset_function_set.items()]
|
||||
printd(f"Available functions:\n", list(preset_function_set.keys()))
|
||||
return preset_function_set_schemas
|
||||
|
||||
|
||||
# def create_agent_from_preset(preset_name, agent_config, model, persona, human, interface, persistence_manager):
|
||||
def create_agent_from_preset(
|
||||
agent_state: AgentState, preset: Preset, interface: AgentInterface, persona_is_file: bool = True, human_is_file: bool = True
|
||||
):
|
||||
"""Initialize a new agent from a preset (combination of system + function)"""
|
||||
raise DeprecationWarning("Function no longer supported - pass a Preset object to Agent.__init__ instead")
|
||||
|
||||
# Input validation
|
||||
if agent_state.persona is None:
|
||||
raise ValueError(f"'persona' not specified in AgentState (required)")
|
||||
if agent_state.human is None:
|
||||
raise ValueError(f"'human' not specified in AgentState (required)")
|
||||
if agent_state.preset is None:
|
||||
raise ValueError(f"'preset' not specified in AgentState (required)")
|
||||
if not (agent_state.state == {} or agent_state.state is None):
|
||||
raise ValueError(f"'state' must be uninitialized (empty)")
|
||||
|
||||
assert preset is not None, "preset cannot be none"
|
||||
preset_name = agent_state.preset
|
||||
assert preset_name == preset.name, f"AgentState preset '{preset_name}' does not match preset name '{preset.name}'"
|
||||
persona = agent_state.persona
|
||||
human = agent_state.human
|
||||
model = agent_state.llm_config.model
|
||||
|
||||
from memgpt.agent import Agent
|
||||
|
||||
# available_presets = load_all_presets()
|
||||
# if preset_name not in available_presets:
|
||||
# raise ValueError(f"Preset '{preset_name}.yaml' not found")
|
||||
# preset = available_presets[preset_name]
|
||||
# preset_system_prompt = preset["system_prompt"]
|
||||
# preset_function_set_names = preset["functions"]
|
||||
# preset_function_set_schemas = generate_functions_json(preset_function_set_names)
|
||||
# Override the following in the AgentState:
|
||||
# persona: str # the current persona text
|
||||
# human: str # the current human text
|
||||
# system: str, # system prompt (not required if initializing with a preset)
|
||||
# functions: dict, # schema definitions ONLY (function code linked at runtime)
|
||||
# messages: List[dict], # in-context messages
|
||||
agent_state.state = {
|
||||
"persona": get_persona_text(persona) if persona_is_file else persona,
|
||||
"human": get_human_text(human) if human_is_file else human,
|
||||
"system": preset.system,
|
||||
"functions": preset.functions_schema,
|
||||
"messages": None,
|
||||
}
|
||||
|
||||
return Agent(
|
||||
agent_state=agent_state,
|
||||
interface=interface,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (model is not None and "gpt-4" in model) else False,
|
||||
)
|
||||
|
||||
@@ -89,9 +89,6 @@ def setup_admin_router(server: SyncServer, interface: QueuingInterface):
|
||||
try:
|
||||
server.ms.create_user(new_user)
|
||||
|
||||
# initialize default presets automatically for user
|
||||
server.initialize_default_presets(new_user.id)
|
||||
|
||||
# make sure we can retrieve the user from the DB too
|
||||
new_user_ret = server.ms.get_user(new_user.id)
|
||||
if new_user_ret is None:
|
||||
|
||||
@@ -79,15 +79,13 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
id=agent_state.id,
|
||||
name=agent_state.name,
|
||||
user_id=agent_state.user_id,
|
||||
preset=agent_state.preset,
|
||||
persona=agent_state.persona,
|
||||
human=agent_state.human,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
tools=agent_state.tools,
|
||||
system=agent_state.system,
|
||||
metadata=agent_state._metadata,
|
||||
),
|
||||
last_run_at=None, # TODO
|
||||
sources=attached_sources,
|
||||
@@ -125,9 +123,6 @@ def setup_agents_config_router(server: SyncServer, interface: QueuingInterface,
|
||||
id=agent_state.id,
|
||||
name=agent_state.name,
|
||||
user_id=agent_state.user_id,
|
||||
preset=agent_state.preset,
|
||||
persona=agent_state.persona,
|
||||
human=agent_state.human,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
state=agent_state.state,
|
||||
|
||||
@@ -6,6 +6,7 @@ from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.constants import BASE_TOOLS
|
||||
from memgpt.memory import ChatMemory
|
||||
from memgpt.models.pydantic_models import (
|
||||
AgentStateModel,
|
||||
EmbeddingConfigModel,
|
||||
@@ -69,8 +70,11 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
human = request.config["human"] if "human" in request.config else None
|
||||
persona_name = request.config["persona_name"] if "persona_name" in request.config else None
|
||||
persona = request.config["persona"] if "persona" in request.config else None
|
||||
preset = request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
|
||||
request.config["preset"] if ("preset" in request.config and request.config["preset"]) else settings.default_preset
|
||||
tool_names = request.config["function_names"]
|
||||
metadata = request.config["metadata"] if "metadata" in request.config else {}
|
||||
metadata["human"] = human_name
|
||||
metadata["persona"] = persona_name
|
||||
|
||||
# TODO: remove this -- should be added based on create agent fields
|
||||
if isinstance(tool_names, str): # TODO: fix this on clinet side?
|
||||
@@ -82,56 +86,54 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
tool_names.append(name)
|
||||
assert isinstance(tool_names, list), "Tool names must be a list of strings."
|
||||
|
||||
# TODO: eventually remove this - should support general memory at the REST endpoint
|
||||
memory = ChatMemory(persona=persona, human=human)
|
||||
|
||||
try:
|
||||
agent_state = server.create_agent(
|
||||
user_id=user_id,
|
||||
# **request.config
|
||||
# TODO turn into a pydantic model
|
||||
name=request.config["name"],
|
||||
preset=preset,
|
||||
persona_name=persona_name,
|
||||
human_name=human_name,
|
||||
persona=persona,
|
||||
human=human,
|
||||
memory=memory,
|
||||
# persona_name=persona_name,
|
||||
# human_name=human_name,
|
||||
# persona=persona,
|
||||
# human=human,
|
||||
# llm_config=LLMConfigModel(
|
||||
# model=request.config['model'],
|
||||
# )
|
||||
# tools
|
||||
tools=tool_names,
|
||||
metadata=metadata,
|
||||
# function_names=request.config["function_names"].split(",") if "function_names" in request.config else None,
|
||||
)
|
||||
llm_config = LLMConfigModel(**vars(agent_state.llm_config))
|
||||
embedding_config = EmbeddingConfigModel(**vars(agent_state.embedding_config))
|
||||
|
||||
# TODO when get_preset returns a PresetModel instead of Preset, we can remove this packing/unpacking line
|
||||
# TODO: remove
|
||||
preset = server.ms.get_preset(name=agent_state.preset, user_id=user_id)
|
||||
|
||||
return CreateAgentResponse(
|
||||
agent_state=AgentStateModel(
|
||||
id=agent_state.id,
|
||||
name=agent_state.name,
|
||||
user_id=agent_state.user_id,
|
||||
preset=agent_state.preset,
|
||||
persona=agent_state.persona,
|
||||
human=agent_state.human,
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
state=agent_state.state,
|
||||
created_at=int(agent_state.created_at.timestamp()),
|
||||
tools=tool_names,
|
||||
system=agent_state.system,
|
||||
metadata=agent_state._metadata,
|
||||
),
|
||||
preset=PresetModel(
|
||||
name=preset.name,
|
||||
id=preset.id,
|
||||
user_id=preset.user_id,
|
||||
description=preset.description,
|
||||
created_at=preset.created_at,
|
||||
system=preset.system,
|
||||
persona=preset.persona,
|
||||
human=preset.human,
|
||||
functions_schema=preset.functions_schema,
|
||||
preset=PresetModel( # TODO: remove (placeholder to avoid breaking frontend)
|
||||
name="dummy_preset",
|
||||
id=agent_state.id,
|
||||
user_id=agent_state.user_id,
|
||||
description="",
|
||||
created_at=agent_state.created_at,
|
||||
system=agent_state.system,
|
||||
persona="",
|
||||
human="",
|
||||
functions_schema=[],
|
||||
),
|
||||
)
|
||||
except Exception as e:
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.models.pydantic_models import HumanModel
|
||||
@@ -46,4 +46,24 @@ def setup_humans_index_router(server: SyncServer, interface: QueuingInterface, p
|
||||
server.ms.add_human(new_human)
|
||||
return HumanModel(id=human_id, text=request.text, name=request.name, user_id=user_id)
|
||||
|
||||
@router.delete("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
|
||||
async def delete_human(
|
||||
human_name: str,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
human = server.ms.delete_human(human_name, user_id=user_id)
|
||||
return human
|
||||
|
||||
@router.get("/humans/{human_name}", tags=["humans"], response_model=HumanModel)
|
||||
async def get_human(
|
||||
human_name: str,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
human = server.ms.get_human(human_name, user_id=user_id)
|
||||
if human is None:
|
||||
raise HTTPException(status_code=404, detail="Human not found")
|
||||
return human
|
||||
|
||||
return router
|
||||
|
||||
@@ -2,7 +2,7 @@ import uuid
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Body, Depends
|
||||
from fastapi import APIRouter, Body, Depends, HTTPException
|
||||
from pydantic import BaseModel, Field
|
||||
|
||||
from memgpt.models.pydantic_models import PersonaModel
|
||||
@@ -47,4 +47,24 @@ def setup_personas_index_router(server: SyncServer, interface: QueuingInterface,
|
||||
server.ms.add_persona(new_persona)
|
||||
return PersonaModel(id=persona_id, text=request.text, name=request.name, user_id=user_id)
|
||||
|
||||
@router.delete("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
|
||||
async def delete_persona(
|
||||
persona_name: str,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
persona = server.ms.delete_persona(persona_name, user_id=user_id)
|
||||
return persona
|
||||
|
||||
@router.get("/personas/{persona_name}", tags=["personas"], response_model=PersonaModel)
|
||||
async def get_persona(
|
||||
persona_name: str,
|
||||
user_id: uuid.UUID = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
persona = server.ms.get_persona(persona_name, user_id=user_id)
|
||||
if persona is None:
|
||||
raise HTTPException(status_code=404, detail="Persona not found")
|
||||
return persona
|
||||
|
||||
return router
|
||||
|
||||
@@ -22,6 +22,7 @@ class CreateToolRequest(BaseModel):
|
||||
source_code: str = Field(..., description="The source code of the function.")
|
||||
source_type: Optional[Literal["python"]] = Field(None, description="The type of the source code.")
|
||||
tags: Optional[List[str]] = Field(None, description="Metadata tags.")
|
||||
update: Optional[bool] = Field(False, description="Update the tool if it already exists.")
|
||||
|
||||
|
||||
class CreateToolResponse(BaseModel):
|
||||
@@ -87,8 +88,10 @@ def setup_user_tools_index_router(server: SyncServer, interface: QueuingInterfac
|
||||
source_type=request.source_type,
|
||||
tags=request.tags,
|
||||
user_id=user_id,
|
||||
exists_ok=request.update,
|
||||
)
|
||||
except Exception as e:
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}")
|
||||
print(e)
|
||||
raise HTTPException(status_code=500, detail=f"Failed to create tool: {e}, exists_ok={request.update}")
|
||||
|
||||
return router
|
||||
|
||||
@@ -15,8 +15,6 @@ import memgpt.server.utils as server_utils
|
||||
import memgpt.system as system
|
||||
from memgpt.agent import Agent, save_agent
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
|
||||
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||
from memgpt.cli.cli_config import get_model_options
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import JSON_ENSURE_ASCII, JSON_LOADS_STRICT
|
||||
@@ -37,6 +35,7 @@ from memgpt.data_types import (
|
||||
from memgpt.interface import AgentInterface # abstract
|
||||
from memgpt.interface import CLIInterface # for printing to terminal
|
||||
from memgpt.log import get_logger
|
||||
from memgpt.memory import BaseMemory
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models.chat_completion_response import UsageStatistics
|
||||
from memgpt.models.pydantic_models import (
|
||||
@@ -44,10 +43,14 @@ from memgpt.models.pydantic_models import (
|
||||
HumanModel,
|
||||
MemGPTUsageStatistics,
|
||||
PassageModel,
|
||||
PersonaModel,
|
||||
PresetModel,
|
||||
SourceModel,
|
||||
ToolModel,
|
||||
)
|
||||
|
||||
# from memgpt.llm_api_tools import openai_get_model_list, azure_openai_get_model_list, smart_urljoin
|
||||
from memgpt.prompts import gpt_system
|
||||
from memgpt.utils import create_random_username
|
||||
|
||||
logger = get_logger(__name__)
|
||||
@@ -212,7 +215,7 @@ class SyncServer(LockingServer):
|
||||
|
||||
# Initialize the connection to the DB
|
||||
self.config = MemGPTConfig.load()
|
||||
print(f"server :: loading configuration from '{self.config.config_path}'")
|
||||
logger.info(f"loading configuration from '{self.config.config_path}'")
|
||||
assert self.config.persona is not None, "Persona must be set in the config"
|
||||
assert self.config.human is not None, "Human must be set in the config"
|
||||
|
||||
@@ -274,14 +277,9 @@ class SyncServer(LockingServer):
|
||||
self.ms.update_user(user)
|
||||
else:
|
||||
self.ms.create_user(user)
|
||||
presets.add_default_presets(user_id, self.ms)
|
||||
|
||||
# NOTE: removed, since server should be multi-user
|
||||
## Create the default user
|
||||
# base_user_id = uuid.UUID(self.config.anon_clientid)
|
||||
# if not self.ms.get_user(user_id=base_user_id):
|
||||
# base_user = User(id=base_user_id)
|
||||
# self.ms.create_user(base_user)
|
||||
# add global default tools
|
||||
presets.add_default_tools(None, self.ms)
|
||||
|
||||
def save_agents(self):
|
||||
"""Saves all the agents that are in the in-memory object store"""
|
||||
@@ -336,7 +334,14 @@ class SyncServer(LockingServer):
|
||||
|
||||
# Instantiate an agent object using the state retrieved
|
||||
logger.info(f"Creating an agent object")
|
||||
tool_objs = [self.ms.get_tool(name, user_id) for name in agent_state.tools] # get tool objects
|
||||
tool_objs = []
|
||||
for name in agent_state.tools:
|
||||
tool_obj = self.ms.get_tool(name, user_id)
|
||||
if not tool_obj:
|
||||
logger.exception(f"Tool {name} does not exist for user {user_id}")
|
||||
raise ValueError(f"Tool {name} does not exist for user {user_id}")
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
memgpt_agent = Agent(agent_state=agent_state, interface=interface, tools=tool_objs)
|
||||
|
||||
# Add the agent to the in-memory store and return its reference
|
||||
@@ -350,11 +355,11 @@ class SyncServer(LockingServer):
|
||||
|
||||
def _get_or_load_agent(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> Agent:
|
||||
"""Check if the agent is in-memory, then load"""
|
||||
logger.info(f"Checking for agent user_id={user_id} agent_id={agent_id}")
|
||||
logger.debug(f"Checking for agent user_id={user_id} agent_id={agent_id}")
|
||||
# TODO: consider disabling loading cached agents due to potential concurrency issues
|
||||
memgpt_agent = self._get_agent(user_id=user_id, agent_id=agent_id)
|
||||
if not memgpt_agent:
|
||||
logger.info(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
|
||||
logger.debug(f"Agent not loaded, loading agent user_id={user_id} agent_id={agent_id}")
|
||||
memgpt_agent = self._load_agent(user_id=user_id, agent_id=agent_id)
|
||||
return memgpt_agent
|
||||
|
||||
@@ -426,7 +431,6 @@ class SyncServer(LockingServer):
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
memgpt_agent = self._get_or_load_agent(user_id=user_id, agent_id=agent_id)
|
||||
# print("AGENT", memgpt_agent.agent_state.id, memgpt_agent.agent_state.user_id)
|
||||
|
||||
if command.lower() == "exit":
|
||||
# exit not supported on server.py
|
||||
@@ -657,49 +661,58 @@ class SyncServer(LockingServer):
|
||||
def create_user(
|
||||
self,
|
||||
user_config: Optional[Union[dict, User]] = {},
|
||||
exists_ok: bool = False,
|
||||
):
|
||||
"""Create a new user using a config"""
|
||||
if not isinstance(user_config, dict):
|
||||
raise ValueError(f"user_config must be provided as a dictionary")
|
||||
|
||||
if "id" in user_config:
|
||||
existing_user = self.ms.get_user(user_id=user_config["id"])
|
||||
if existing_user:
|
||||
if exists_ok:
|
||||
presets.add_default_humans_and_personas(existing_user.id, self.ms)
|
||||
return existing_user
|
||||
else:
|
||||
raise ValueError(f"User with ID {existing_user.id} already exists")
|
||||
|
||||
user = User(
|
||||
id=user_config["id"] if "id" in user_config else None,
|
||||
)
|
||||
self.ms.create_user(user)
|
||||
logger.info(f"Created new user from config: {user}")
|
||||
|
||||
# add default for the user
|
||||
presets.add_default_humans_and_personas(user.id, self.ms)
|
||||
|
||||
return user
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
tools: List[str], # list of tool names (handles) to include
|
||||
# system: str, # system prompt
|
||||
memory: BaseMemory,
|
||||
system: Optional[str] = None,
|
||||
metadata: Optional[dict] = {}, # includes human/persona names
|
||||
name: Optional[str] = None,
|
||||
preset: Optional[str] = None, # TODO: remove eventually
|
||||
# model config
|
||||
llm_config: Optional[LLMConfig] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
# interface
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
# TODO: refactor this to be a more general memory configuration
|
||||
system: Optional[str] = None, # prompt value
|
||||
persona: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
human: Optional[str] = None, # NOTE: this is not the name, it's the memory init value
|
||||
persona_name: Optional[str] = None, # TODO: remove
|
||||
human_name: Optional[str] = None, # TODO: remove
|
||||
) -> AgentState:
|
||||
"""Create a new agent using a config"""
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
if interface is None:
|
||||
# interface = self.default_interface
|
||||
interface = self.default_interface_factory()
|
||||
|
||||
# if persistence_manager is None:
|
||||
# persistence_manager = self.default_persistence_manager_cls(agent_config=agent_config)
|
||||
# system prompt (get default if None)
|
||||
if system is None:
|
||||
system = gpt_system.get_system_text(self.config.preset)
|
||||
|
||||
# create agent name
|
||||
if name is None:
|
||||
name = create_random_username()
|
||||
|
||||
@@ -708,90 +721,34 @@ class SyncServer(LockingServer):
|
||||
if not user:
|
||||
raise ValueError(f"cannot find user with associated client id: {user_id}")
|
||||
|
||||
# NOTE: you MUST add to the metadata store before creating the agent, otherwise the storage connectors will error on creation
|
||||
# TODO: fix this db dependency and remove
|
||||
# self.ms.#create_agent(agent_state)
|
||||
|
||||
# TODO modify to do creation via preset
|
||||
try:
|
||||
preset_obj = self.ms.get_preset(name=preset if preset else self.config.preset, user_id=user_id)
|
||||
preset_override = False
|
||||
assert preset_obj is not None, f"preset {preset if preset else self.config.preset} does not exist"
|
||||
logger.debug(f"Attempting to create agent from preset:\n{preset_obj}")
|
||||
|
||||
# system prompt
|
||||
if system is None:
|
||||
system = preset_obj.system
|
||||
else:
|
||||
preset_obj.system = system
|
||||
preset_override = True
|
||||
|
||||
# Overwrite fields in the preset if they were specified
|
||||
if human is not None and human != preset_obj.human:
|
||||
preset_override = True
|
||||
preset_obj.human = human
|
||||
# This is a check for a common bug where users were providing filenames instead of values
|
||||
# try:
|
||||
# get_human_text(human)
|
||||
# raise ValueError(human)
|
||||
# raise UserWarning(
|
||||
# f"It looks like there is a human file named {human} - did you mean to pass the file contents to the `human` arg?"
|
||||
# )
|
||||
# except:
|
||||
# pass
|
||||
if persona is not None:
|
||||
preset_override = True
|
||||
preset_obj.persona = persona
|
||||
# try:
|
||||
# get_persona_text(persona)
|
||||
# raise ValueError(persona)
|
||||
# raise UserWarning(
|
||||
# f"It looks like there is a persona file named {persona} - did you mean to pass the file contents to the `persona` arg?"
|
||||
# )
|
||||
# except:
|
||||
# pass
|
||||
if human_name is not None and human_name != preset_obj.human_name:
|
||||
preset_override = True
|
||||
preset_obj.human_name = human_name
|
||||
if persona_name is not None and persona_name != preset_obj.persona_name:
|
||||
preset_override = True
|
||||
preset_obj.persona_name = persona_name
|
||||
|
||||
# model configuration
|
||||
llm_config = llm_config if llm_config else self.server_llm_config
|
||||
embedding_config = embedding_config if embedding_config else self.server_embedding_config
|
||||
|
||||
# get tools
|
||||
# get tools + make sure they exist
|
||||
tool_objs = []
|
||||
for tool_name in tools:
|
||||
tool_obj = self.ms.get_tool(tool_name, user_id=user_id)
|
||||
assert tool_obj is not None, f"Tool {tool_name} does not exist"
|
||||
assert tool_obj, f"Tool {tool_name} does not exist"
|
||||
tool_objs.append(tool_obj)
|
||||
|
||||
# If the user overrode any parts of the preset, we need to create a new preset to refer back to
|
||||
if preset_override:
|
||||
# Change the name and uuid
|
||||
preset_obj = Preset.clone(preset_obj=preset_obj)
|
||||
# Then write out to the database for storage
|
||||
self.ms.create_preset(preset=preset_obj)
|
||||
|
||||
# TODO: add metadata
|
||||
agent_state = AgentState(
|
||||
name=name,
|
||||
user_id=user_id,
|
||||
persona=preset_obj.persona_name, # TODO: remove
|
||||
human=preset_obj.human_name, # TODO: remove
|
||||
tools=tools, # name=id for tools
|
||||
llm_config=llm_config,
|
||||
embedding_config=embedding_config,
|
||||
system=system,
|
||||
preset=preset, # TODO: remove
|
||||
state={"persona": preset_obj.persona, "human": preset_obj.human, "system": system, "messages": None},
|
||||
state={"system": system, "messages": None, "memory": memory.to_dict()},
|
||||
_metadata=metadata,
|
||||
)
|
||||
|
||||
agent = Agent(
|
||||
interface=interface,
|
||||
agent_state=agent_state,
|
||||
tools=tool_objs,
|
||||
# embedding_config=embedding_config,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=True if (llm_config.model is not None and "gpt-4" in llm_config.model) else False,
|
||||
)
|
||||
@@ -805,10 +762,11 @@ class SyncServer(LockingServer):
|
||||
logger.exception(f"Failed to delete_agent:\n{delete_e}")
|
||||
raise e
|
||||
|
||||
# save agent
|
||||
save_agent(agent, self.ms)
|
||||
|
||||
logger.info(f"Created new agent from config: {agent}")
|
||||
|
||||
# return AgentState
|
||||
return agent.agent_state
|
||||
|
||||
def delete_agent(
|
||||
@@ -872,8 +830,8 @@ class SyncServer(LockingServer):
|
||||
agent_config = {
|
||||
"id": agent_state.id,
|
||||
"name": agent_state.name,
|
||||
"human": agent_state.human,
|
||||
"persona": agent_state.persona,
|
||||
"human": agent_state._metadata.get("human", None),
|
||||
"persona": agent_state._metadata.get("persona", None),
|
||||
"created_at": agent_state.created_at.isoformat(),
|
||||
}
|
||||
return agent_config
|
||||
@@ -905,8 +863,8 @@ class SyncServer(LockingServer):
|
||||
# TODO hack for frontend, remove
|
||||
# (top level .persona is persona_name, and nested memory.persona is the state)
|
||||
# TODO: eventually modify this to be contained in the metadata
|
||||
return_dict["persona"] = agent_state.human
|
||||
return_dict["human"] = agent_state.persona
|
||||
return_dict["persona"] = agent_state._metadata.get("persona", None)
|
||||
return_dict["human"] = agent_state._metadata.get("human", None)
|
||||
|
||||
# Add information about tools
|
||||
# TODO memgpt_agent should really have a field of List[ToolModel]
|
||||
@@ -918,10 +876,7 @@ class SyncServer(LockingServer):
|
||||
recall_memory = memgpt_agent.persistence_manager.recall_memory
|
||||
archival_memory = memgpt_agent.persistence_manager.archival_memory
|
||||
memory_obj = {
|
||||
"core_memory": {
|
||||
"persona": core_memory.persona,
|
||||
"human": core_memory.human,
|
||||
},
|
||||
"core_memory": {section: module.value for (section, module) in core_memory.memory.items()},
|
||||
"recall_memory": len(recall_memory) if recall_memory is not None else None,
|
||||
"archival_memory": len(archival_memory) if archival_memory is not None else None,
|
||||
}
|
||||
@@ -941,12 +896,31 @@ class SyncServer(LockingServer):
|
||||
# Sort agents by "last_run" in descending order, most recent first
|
||||
agents_states_dicts.sort(key=lambda x: x["last_run"], reverse=True)
|
||||
|
||||
logger.info(f"Retrieved {len(agents_states)} agents for user {user_id}:\n{[vars(s) for s in agents_states]}")
|
||||
logger.debug(f"Retrieved {len(agents_states)} agents for user {user_id}")
|
||||
return {
|
||||
"num_agents": len(agents_states),
|
||||
"agents": agents_states_dicts,
|
||||
}
|
||||
|
||||
def list_personas(self, user_id: uuid.UUID):
|
||||
return self.ms.list_personas(user_id=user_id)
|
||||
|
||||
def get_persona(self, name: str, user_id: uuid.UUID):
|
||||
return self.ms.get_persona(name=name, user_id=user_id)
|
||||
|
||||
def add_persona(self, persona: PersonaModel):
|
||||
name = persona.name
|
||||
user_id = persona.user_id
|
||||
self.ms.add_persona(persona=persona)
|
||||
persona = self.ms.get_persona(name=name, user_id=user_id)
|
||||
return persona
|
||||
|
||||
def update_persona(self, persona: PersonaModel):
|
||||
return self.ms.update_persona(persona=persona)
|
||||
|
||||
def delete_persona(self, name: str, user_id: uuid.UUID):
|
||||
return self.ms.delete_persona(name=name, user_id=user_id)
|
||||
|
||||
def list_humans(self, user_id: uuid.UUID):
|
||||
return self.ms.list_humans(user_id=user_id)
|
||||
|
||||
@@ -954,7 +928,11 @@ class SyncServer(LockingServer):
|
||||
return self.ms.get_human(name=name, user_id=user_id)
|
||||
|
||||
def add_human(self, human: HumanModel):
|
||||
return self.ms.add_human(human=human)
|
||||
name = human.name
|
||||
user_id = human.user_id
|
||||
self.ms.add_human(human=human)
|
||||
human = self.ms.get_human(name=name, user_id=user_id)
|
||||
return human
|
||||
|
||||
def update_human(self, human: HumanModel):
|
||||
return self.ms.update_human(human=human)
|
||||
@@ -983,11 +961,9 @@ class SyncServer(LockingServer):
|
||||
recall_memory = memgpt_agent.persistence_manager.recall_memory
|
||||
archival_memory = memgpt_agent.persistence_manager.archival_memory
|
||||
|
||||
# NOTE
|
||||
memory_obj = {
|
||||
"core_memory": {
|
||||
"persona": core_memory.persona,
|
||||
"human": core_memory.human,
|
||||
},
|
||||
"core_memory": {key: value.value for key, value in core_memory.memory.items()},
|
||||
"recall_memory": len(recall_memory) if recall_memory is not None else None,
|
||||
"archival_memory": len(archival_memory) if archival_memory is not None else None,
|
||||
}
|
||||
@@ -1245,23 +1221,18 @@ class SyncServer(LockingServer):
|
||||
new_core_memory = old_core_memory.copy()
|
||||
|
||||
modified = False
|
||||
if "persona" in new_memory_contents and new_memory_contents["persona"] is not None:
|
||||
new_persona = new_memory_contents["persona"]
|
||||
if old_core_memory["persona"] != new_persona:
|
||||
new_core_memory["persona"] = new_persona
|
||||
memgpt_agent.memory.edit_persona(new_persona)
|
||||
modified = True
|
||||
|
||||
if "human" in new_memory_contents and new_memory_contents["human"] is not None:
|
||||
new_human = new_memory_contents["human"]
|
||||
if old_core_memory["human"] != new_human:
|
||||
new_core_memory["human"] = new_human
|
||||
memgpt_agent.memory.edit_human(new_human)
|
||||
for key, value in new_memory_contents.items():
|
||||
if value is None:
|
||||
continue
|
||||
if key in old_core_memory and old_core_memory[key] != value:
|
||||
memgpt_agent.memory.memory[key].value = value # update agent memory
|
||||
modified = True
|
||||
|
||||
# If we modified the memory contents, we need to rebuild the memory block inside the system message
|
||||
if modified:
|
||||
memgpt_agent.rebuild_memory()
|
||||
# save agent
|
||||
save_agent(memgpt_agent, self.ms)
|
||||
|
||||
return {
|
||||
"old_core_memory": old_core_memory,
|
||||
@@ -1290,6 +1261,7 @@ class SyncServer(LockingServer):
|
||||
logger.exception(f"Failed to update agent name with:\n{str(e)}")
|
||||
raise ValueError(f"Failed to update agent name in database")
|
||||
|
||||
assert isinstance(memgpt_agent.agent_state.id, uuid.UUID)
|
||||
return memgpt_agent.agent_state
|
||||
|
||||
def delete_user(self, user_id: uuid.UUID):
|
||||
@@ -1504,7 +1476,7 @@ class SyncServer(LockingServer):
|
||||
tool (ToolModel): Tool object
|
||||
"""
|
||||
name = json_schema["name"]
|
||||
tool = self.ms.get_tool(name)
|
||||
tool = self.ms.get_tool(name, user_id=user_id)
|
||||
if tool: # check if function already exists
|
||||
if exists_ok:
|
||||
# update existing tool
|
||||
@@ -1514,7 +1486,7 @@ class SyncServer(LockingServer):
|
||||
tool.source_type = source_type
|
||||
self.ms.update_tool(tool)
|
||||
else:
|
||||
raise ValueError(f"Tool with name {name} already exists.")
|
||||
raise ValueError(f"[server] Tool with name {name} already exists.")
|
||||
else:
|
||||
# create new tool
|
||||
tool = ToolModel(
|
||||
|
||||
@@ -1,121 +0,0 @@
|
||||
import inspect
|
||||
import json
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
from memgpt import constants, create_client
|
||||
from memgpt.functions.functions import USER_FUNCTIONS_DIR
|
||||
from memgpt.models import chat_completion_response
|
||||
from memgpt.utils import assistant_function_to_tool
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from tests.utils import create_config, wipe_config
|
||||
|
||||
|
||||
def hello_world(self) -> str:
|
||||
"""Test function for agent to gain access to
|
||||
|
||||
Returns:
|
||||
str: A message for the world
|
||||
"""
|
||||
return "hello, world!"
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def agent():
|
||||
"""Create a test agent that we can call functions on"""
|
||||
wipe_config()
|
||||
global client
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
create_config("openai")
|
||||
else:
|
||||
create_config("memgpt_hosted")
|
||||
|
||||
# create memgpt client
|
||||
client = create_client()
|
||||
|
||||
# ensure user exists
|
||||
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
|
||||
if not client.server.get_user(user_id=user_id):
|
||||
client.server.create_user({"id": user_id})
|
||||
|
||||
agent_state = client.create_agent()
|
||||
|
||||
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def hello_world_function():
|
||||
with open(os.path.join(USER_FUNCTIONS_DIR, "hello_world.py"), "w", encoding="utf-8") as f:
|
||||
f.write(inspect.getsource(hello_world))
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def ai_function_call():
|
||||
return chat_completion_response.Message(
|
||||
**assistant_function_to_tool(
|
||||
{
|
||||
"role": "assistant",
|
||||
"content": "I will now call hello world",
|
||||
"function_call": {
|
||||
"name": "hello_world",
|
||||
"arguments": json.dumps({}),
|
||||
},
|
||||
}
|
||||
)
|
||||
)
|
||||
|
||||
return
|
||||
|
||||
|
||||
def test_add_function_happy(agent, hello_world_function, ai_function_call):
|
||||
agent.add_function("hello_world")
|
||||
|
||||
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" in agent.functions_python.keys()
|
||||
|
||||
msgs, heartbeat_req, function_failed = agent._handle_ai_response(ai_function_call)
|
||||
content = json.loads(msgs[-1].to_openai_dict()["content"], strict=constants.JSON_LOADS_STRICT)
|
||||
assert content["message"] == "hello, world!"
|
||||
assert content["status"] == "OK"
|
||||
assert not function_failed
|
||||
|
||||
|
||||
def test_add_function_already_loaded(agent, hello_world_function):
|
||||
agent.add_function("hello_world")
|
||||
# no exception for duplicate loading
|
||||
agent.add_function("hello_world")
|
||||
|
||||
|
||||
def test_add_function_not_exist(agent):
|
||||
# pytest assert exception
|
||||
with pytest.raises(ValueError):
|
||||
agent.add_function("non_existent")
|
||||
|
||||
|
||||
def test_remove_function_happy(agent, hello_world_function):
|
||||
agent.add_function("hello_world")
|
||||
|
||||
# ensure function is loaded
|
||||
assert "hello_world" in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" in agent.functions_python.keys()
|
||||
|
||||
agent.remove_function("hello_world")
|
||||
|
||||
assert "hello_world" not in [f_schema["name"] for f_schema in agent.functions]
|
||||
assert "hello_world" not in agent.functions_python.keys()
|
||||
|
||||
|
||||
def test_remove_function_not_exist(agent):
|
||||
# do not raise error
|
||||
agent.remove_function("non_existent")
|
||||
|
||||
|
||||
def test_remove_base_function_fails(agent):
|
||||
with pytest.raises(ValueError):
|
||||
agent.remove_function("send_message")
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
pytest.main(["-vv", os.path.abspath(__file__)])
|
||||
@@ -1,11 +1,9 @@
|
||||
import os
|
||||
import uuid
|
||||
|
||||
import pytest
|
||||
|
||||
import memgpt.functions.function_sets.base as base_functions
|
||||
from memgpt import create_client
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
from .utils import create_config, wipe_config
|
||||
|
||||
@@ -28,8 +26,7 @@ def agent_obj():
|
||||
agent_state = client.create_agent()
|
||||
|
||||
global agent_obj
|
||||
user_id = uuid.UUID(TEST_MEMGPT_CONFIG.anon_clientid)
|
||||
agent_obj = client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
agent_obj = client.server._get_or_load_agent(user_id=client.user_id, agent_id=agent_state.id)
|
||||
yield agent_obj
|
||||
|
||||
client.delete_agent(agent_obj.agent_state.id)
|
||||
|
||||
@@ -68,10 +68,7 @@ def run_server():
|
||||
|
||||
# Fixture to create clients with different configurations
|
||||
@pytest.fixture(
|
||||
params=[ # whether to use REST API server
|
||||
{"server": True},
|
||||
# {"server": False} # TODO: add when implemented
|
||||
],
|
||||
params=[{"server": True}, {"server": False}], # whether to use REST API server
|
||||
scope="module",
|
||||
)
|
||||
def client(request):
|
||||
@@ -90,8 +87,8 @@ def client(request):
|
||||
# create user via admin client
|
||||
admin = Admin(server_url, test_server_token)
|
||||
response = admin.create_user(test_user_id) # Adjust as per your client's method
|
||||
response.user_id
|
||||
token = response.api_key
|
||||
|
||||
else:
|
||||
# use local client (no server)
|
||||
token = None
|
||||
@@ -109,7 +106,7 @@ def client(request):
|
||||
# Fixture for test agent
|
||||
@pytest.fixture(scope="module")
|
||||
def agent(client):
|
||||
agent_state = client.create_agent(name=test_agent_name, preset=test_preset_name)
|
||||
agent_state = client.create_agent(name=test_agent_name)
|
||||
print("AGENT ID", agent_state.id)
|
||||
yield agent_state
|
||||
|
||||
@@ -123,11 +120,11 @@ def test_agent(client, agent):
|
||||
# test client.rename_agent
|
||||
new_name = "RenamedTestAgent"
|
||||
client.rename_agent(agent_id=agent.id, new_name=new_name)
|
||||
renamed_agent = client.get_agent(agent_id=str(agent.id))
|
||||
renamed_agent = client.get_agent(agent_id=agent.id)
|
||||
assert renamed_agent.name == new_name, "Agent renaming failed"
|
||||
|
||||
# test client.delete_agent and client.agent_exists
|
||||
delete_agent = client.create_agent(name="DeleteTestAgent", preset=test_preset_name)
|
||||
delete_agent = client.create_agent(name="DeleteTestAgent")
|
||||
assert client.agent_exists(agent_id=delete_agent.id), "Agent creation failed"
|
||||
client.delete_agent(agent_id=delete_agent.id)
|
||||
assert client.agent_exists(agent_id=delete_agent.id) == False, "Agent deletion failed"
|
||||
@@ -140,7 +137,7 @@ def test_memory(client, agent):
|
||||
print("MEMORY", memory_response)
|
||||
|
||||
updated_memory = {"human": "Updated human memory", "persona": "Updated persona memory"}
|
||||
client.update_agent_core_memory(agent_id=str(agent.id), new_memory_contents=updated_memory)
|
||||
client.update_agent_core_memory(agent_id=agent.id, new_memory_contents=updated_memory)
|
||||
updated_memory_response = client.get_agent_memory(agent_id=agent.id)
|
||||
assert (
|
||||
updated_memory_response.core_memory.human == updated_memory["human"]
|
||||
@@ -152,10 +149,10 @@ def test_agent_interactions(client, agent):
|
||||
_reset_config()
|
||||
|
||||
message = "Hello, agent!"
|
||||
message_response = client.user_message(agent_id=str(agent.id), message=message)
|
||||
message_response = client.user_message(agent_id=agent.id, message=message)
|
||||
|
||||
command = "/memory"
|
||||
command_response = client.run_command(agent_id=str(agent.id), command=command)
|
||||
command_response = client.run_command(agent_id=agent.id, command=command)
|
||||
print("command", command_response)
|
||||
|
||||
|
||||
@@ -197,11 +194,15 @@ def test_humans_personas(client, agent):
|
||||
print("PERSONAS", personas_response)
|
||||
|
||||
persona_name = "TestPersona"
|
||||
if client.get_persona(persona_name):
|
||||
client.delete_persona(persona_name)
|
||||
persona = client.create_persona(name=persona_name, persona="Persona text")
|
||||
assert persona.name == persona_name
|
||||
assert persona.text == "Persona text", "Creating persona failed"
|
||||
|
||||
human_name = "TestHuman"
|
||||
if client.get_human(human_name):
|
||||
client.delete_human(human_name)
|
||||
human = client.create_human(name=human_name, human="Human text")
|
||||
assert human.name == human_name
|
||||
assert human.text == "Human text", "Creating human failed"
|
||||
@@ -285,55 +286,55 @@ def test_sources(client, agent):
|
||||
client.delete_source(source.id)
|
||||
|
||||
|
||||
def test_presets(client, agent):
|
||||
_reset_config()
|
||||
|
||||
# new_preset = Preset(
|
||||
# # user_id=client.user_id,
|
||||
# name="pytest_test_preset",
|
||||
# description="DUMMY_DESCRIPTION",
|
||||
# system="DUMMY_SYSTEM",
|
||||
# persona="DUMMY_PERSONA",
|
||||
# persona_name="DUMMY_PERSONA_NAME",
|
||||
# human="DUMMY_HUMAN",
|
||||
# human_name="DUMMY_HUMAN_NAME",
|
||||
# functions_schema=[
|
||||
# {
|
||||
# "name": "send_message",
|
||||
# "json_schema": {
|
||||
# "name": "send_message",
|
||||
# "description": "Sends a message to the human user.",
|
||||
# "parameters": {
|
||||
# "type": "object",
|
||||
# "properties": {
|
||||
# "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
|
||||
# },
|
||||
# "required": ["message"],
|
||||
# },
|
||||
# },
|
||||
# "tags": ["memgpt-base"],
|
||||
# "source_type": "python",
|
||||
# "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
|
||||
# }
|
||||
# ],
|
||||
# )
|
||||
|
||||
## List all presets and make sure the preset is NOT in the list
|
||||
# all_presets = client.list_presets()
|
||||
# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
# Create a preset
|
||||
new_preset = client.create_preset(name="pytest_test_preset")
|
||||
|
||||
# List all presets and make sure the preset is in the list
|
||||
all_presets = client.list_presets()
|
||||
assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
|
||||
# Delete the preset
|
||||
client.delete_preset(preset_id=new_preset.id)
|
||||
|
||||
# List all presets and make sure the preset is NOT in the list
|
||||
all_presets = client.list_presets()
|
||||
assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
# def test_presets(client, agent):
|
||||
# _reset_config()
|
||||
#
|
||||
# # new_preset = Preset(
|
||||
# # # user_id=client.user_id,
|
||||
# # name="pytest_test_preset",
|
||||
# # description="DUMMY_DESCRIPTION",
|
||||
# # system="DUMMY_SYSTEM",
|
||||
# # persona="DUMMY_PERSONA",
|
||||
# # persona_name="DUMMY_PERSONA_NAME",
|
||||
# # human="DUMMY_HUMAN",
|
||||
# # human_name="DUMMY_HUMAN_NAME",
|
||||
# # functions_schema=[
|
||||
# # {
|
||||
# # "name": "send_message",
|
||||
# # "json_schema": {
|
||||
# # "name": "send_message",
|
||||
# # "description": "Sends a message to the human user.",
|
||||
# # "parameters": {
|
||||
# # "type": "object",
|
||||
# # "properties": {
|
||||
# # "message": {"type": "string", "description": "Message contents. All unicode (including emojis) are supported."}
|
||||
# # },
|
||||
# # "required": ["message"],
|
||||
# # },
|
||||
# # },
|
||||
# # "tags": ["memgpt-base"],
|
||||
# # "source_type": "python",
|
||||
# # "source_code": 'def send_message(self, message: str) -> Optional[str]:\n """\n Sends a message to the human user.\n\n Args:\n message (str): Message contents. All unicode (including emojis) are supported.\n\n Returns:\n Optional[str]: None is always returned as this function does not produce a response.\n """\n self.interface.assistant_message(message)\n return None\n',
|
||||
# # }
|
||||
# # ],
|
||||
# # )
|
||||
#
|
||||
# ## List all presets and make sure the preset is NOT in the list
|
||||
# # all_presets = client.list_presets()
|
||||
# # assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
# # Create a preset
|
||||
# new_preset = client.create_preset(name="pytest_test_preset")
|
||||
#
|
||||
# # List all presets and make sure the preset is in the list
|
||||
# all_presets = client.list_presets()
|
||||
# assert new_preset.id in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
#
|
||||
# # Delete the preset
|
||||
# client.delete_preset(preset_id=new_preset.id)
|
||||
#
|
||||
# # List all presets and make sure the preset is NOT in the list
|
||||
# all_presets = client.list_presets()
|
||||
# assert new_preset.id not in [p.id for p in all_presets], (new_preset, all_presets)
|
||||
|
||||
|
||||
# def test_tools(client, agent):
|
||||
|
||||
@@ -29,14 +29,11 @@ def run_llm_endpoint(filename):
|
||||
agent_state = AgentState(
|
||||
name="test_agent",
|
||||
tools=[tool.name for tool in load_module_tools()],
|
||||
system="",
|
||||
persona="",
|
||||
human="",
|
||||
preset="memgpt_chat",
|
||||
embedding_config=embedding_config,
|
||||
llm_config=llm_config,
|
||||
user_id=uuid.UUID(int=1),
|
||||
state={"persona": "", "human": "", "messages": None},
|
||||
state={"persona": "", "human": "", "messages": None, "memory": {}},
|
||||
system="",
|
||||
)
|
||||
agent = Agent(
|
||||
interface=None,
|
||||
|
||||
65
tests/test_memory.py
Normal file
65
tests/test_memory.py
Normal file
@@ -0,0 +1,65 @@
|
||||
import pytest
|
||||
|
||||
# Import the classes here, assuming the above definitions are in a module named memory_module
|
||||
from memgpt.memory import BaseMemory, ChatMemory
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def sample_memory():
|
||||
return ChatMemory(persona="Chat Agent", human="User")
|
||||
|
||||
|
||||
def test_create_chat_memory():
|
||||
"""Test creating an instance of ChatMemory"""
|
||||
chat_memory = ChatMemory(persona="Chat Agent", human="User")
|
||||
assert chat_memory.memory["persona"].value == "Chat Agent"
|
||||
assert chat_memory.memory["human"].value == "User"
|
||||
|
||||
|
||||
def test_dump_memory_as_json(sample_memory):
|
||||
"""Test dumping ChatMemory as JSON compatible dictionary"""
|
||||
memory_dict = sample_memory.to_dict()
|
||||
assert isinstance(memory_dict, dict)
|
||||
assert "persona" in memory_dict
|
||||
assert memory_dict["persona"]["value"] == "Chat Agent"
|
||||
|
||||
|
||||
def test_load_memory_from_json(sample_memory):
|
||||
"""Test loading ChatMemory from a JSON compatible dictionary"""
|
||||
memory_dict = sample_memory.to_dict()
|
||||
print(memory_dict)
|
||||
new_memory = BaseMemory.load(memory_dict)
|
||||
assert new_memory.memory["persona"].value == "Chat Agent"
|
||||
assert new_memory.memory["human"].value == "User"
|
||||
|
||||
|
||||
# def test_memory_functionality(sample_memory):
|
||||
# """Test memory modification functions"""
|
||||
# # Get memory functions
|
||||
# functions = get_memory_functions(ChatMemory)
|
||||
# # Test core_memory_append function
|
||||
# append_func = functions['core_memory_append']
|
||||
# print("FUNCTIONS", functions)
|
||||
# env = {}
|
||||
# env.update(globals())
|
||||
# for tool in functions:
|
||||
# # WARNING: name may not be consistent?
|
||||
# exec(tool.source_code, env)
|
||||
#
|
||||
# print(exec)
|
||||
#
|
||||
# append_func(sample_memory, 'persona', " is a test.")
|
||||
# assert sample_memory.memory['persona'].value == "Chat Agent\n is a test."
|
||||
# # Test core_memory_replace function
|
||||
# replace_func = functions['core_memory_replace']
|
||||
# replace_func(sample_memory, 'persona', " is a test.", " was a test.")
|
||||
# assert sample_memory.memory['persona'].value == "Chat Agent\n was a test."
|
||||
|
||||
|
||||
def test_memory_limit_validation(sample_memory):
|
||||
"""Test exceeding memory limit"""
|
||||
with pytest.raises(ValueError):
|
||||
ChatMemory(persona="x" * 3000, human="y" * 3000)
|
||||
|
||||
with pytest.raises(ValueError):
|
||||
sample_memory.memory["persona"].value = "x" * 3000
|
||||
@@ -1,139 +0,0 @@
|
||||
import pytest
|
||||
|
||||
from memgpt.agent import Agent, save_agent
|
||||
from memgpt.constants import DEFAULT_HUMAN, DEFAULT_PERSONA, DEFAULT_PRESET
|
||||
from memgpt.data_types import AgentState, LLMConfig, Source, User
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.models.pydantic_models import HumanModel, PersonaModel
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
from memgpt.settings import settings
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
|
||||
|
||||
# @pytest.mark.parametrize("storage_connector", ["postgres", "sqlite"])
|
||||
@pytest.mark.parametrize("storage_connector", ["sqlite"])
|
||||
def test_storage(storage_connector):
|
||||
if storage_connector == "postgres":
|
||||
TEST_MEMGPT_CONFIG.archival_storage_uri = settings.pg_uri
|
||||
TEST_MEMGPT_CONFIG.recall_storage_uri = settings.pg_uri
|
||||
TEST_MEMGPT_CONFIG.archival_storage_type = "postgres"
|
||||
TEST_MEMGPT_CONFIG.recall_storage_type = "postgres"
|
||||
if storage_connector == "sqlite":
|
||||
TEST_MEMGPT_CONFIG.recall_storage_type = "local"
|
||||
|
||||
ms = MetadataStore(TEST_MEMGPT_CONFIG)
|
||||
|
||||
# users
|
||||
user_1 = User()
|
||||
user_2 = User()
|
||||
ms.create_user(user_1)
|
||||
ms.create_user(user_2)
|
||||
|
||||
# test adding default humans/personas/presets
|
||||
# add_default_humans_and_personas(user_id=user_1.id, ms=ms)
|
||||
# add_default_humans_and_personas(user_id=user_2.id, ms=ms)
|
||||
ms.add_human(human=HumanModel(name="test_human", text="This is a test human"))
|
||||
ms.add_persona(persona=PersonaModel(name="test_persona", text="This is a test persona"))
|
||||
add_default_presets(user_id=user_1.id, ms=ms)
|
||||
add_default_presets(user_id=user_2.id, ms=ms)
|
||||
assert len(ms.list_humans(user_id=user_1.id)) > 0, ms.list_humans(user_id=user_1.id)
|
||||
assert len(ms.list_personas(user_id=user_1.id)) > 0, ms.list_personas(user_id=user_1.id)
|
||||
|
||||
# generate data
|
||||
agent_1 = AgentState(
|
||||
user_id=user_1.id,
|
||||
name="agent_1",
|
||||
preset=DEFAULT_PRESET,
|
||||
persona=DEFAULT_PERSONA,
|
||||
human=DEFAULT_HUMAN,
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
)
|
||||
source_1 = Source(user_id=user_1.id, name="source_1")
|
||||
|
||||
# test creation
|
||||
ms.create_agent(agent_1)
|
||||
ms.create_source(source_1)
|
||||
|
||||
# test listing
|
||||
len(ms.list_agents(user_id=user_1.id)) == 1
|
||||
len(ms.list_agents(user_id=user_2.id)) == 0
|
||||
len(ms.list_sources(user_id=user_1.id)) == 1
|
||||
len(ms.list_sources(user_id=user_2.id)) == 0
|
||||
|
||||
# test agent_state saving
|
||||
agent_state = ms.get_agent(agent_1.id).state
|
||||
assert agent_state == {}, agent_state # when created via create_agent, it should be empty
|
||||
|
||||
from memgpt.presets.presets import add_default_presets
|
||||
|
||||
add_default_presets(user_1.id, ms)
|
||||
preset_obj = ms.get_preset(name=DEFAULT_PRESET, user_id=user_1.id)
|
||||
from memgpt.interface import CLIInterface as interface # for printing to terminal
|
||||
|
||||
# Overwrite fields in the preset if they were specified
|
||||
preset_obj.human = get_human_text(DEFAULT_HUMAN)
|
||||
preset_obj.persona = get_persona_text(DEFAULT_PERSONA)
|
||||
|
||||
# Create the agent
|
||||
agent = Agent(
|
||||
interface=interface(),
|
||||
created_by=user_1.id,
|
||||
name="agent_test_agent_state",
|
||||
preset=preset_obj,
|
||||
llm_config=config.default_llm_config,
|
||||
embedding_config=config.default_embedding_config,
|
||||
# gpt-3.5-turbo tends to omit inner monologue, relax this requirement for now
|
||||
first_message_verify_mono=(
|
||||
True if (config.default_llm_config.model is not None and "gpt-4" in config.default_llm_config.model) else False
|
||||
),
|
||||
)
|
||||
agent_with_agent_state = agent.agent_state
|
||||
save_agent(agent=agent, ms=ms)
|
||||
|
||||
agent_state = ms.get_agent(agent_with_agent_state.id).state
|
||||
assert agent_state is not None, agent_state # when created via create_agent_from_preset, it should be non-empty
|
||||
|
||||
# test: updating
|
||||
|
||||
# test: update JSON-stored LLMConfig class
|
||||
print(agent_1.llm_config, TEST_MEMGPT_CONFIG.default_llm_config)
|
||||
llm_config = ms.get_agent(agent_1.id).llm_config
|
||||
assert isinstance(llm_config, LLMConfig), f"LLMConfig is {type(llm_config)}"
|
||||
assert llm_config.model == "gpt-4", f"LLMConfig model is {llm_config.model}"
|
||||
llm_config.model = "gpt3.5-turbo"
|
||||
agent_1.llm_config = llm_config
|
||||
ms.update_agent(agent_1)
|
||||
assert ms.get_agent(agent_1.id).llm_config.model == "gpt3.5-turbo", f"Updated LLMConfig to {ms.get_agent(agent_1.id).llm_config.model}"
|
||||
|
||||
# test attaching sources
|
||||
len(ms.list_attached_sources(agent_id=agent_1.id)) == 0
|
||||
ms.attach_source(user_1.id, agent_1.id, source_1.id)
|
||||
len(ms.list_attached_sources(agent_id=agent_1.id)) == 1
|
||||
|
||||
# test: detaching sources
|
||||
ms.detach_source(agent_1.id, source_1.id)
|
||||
len(ms.list_attached_sources(agent_id=agent_1.id)) == 0
|
||||
|
||||
# test getting
|
||||
ms.get_user(user_1.id)
|
||||
ms.get_agent(agent_1.id)
|
||||
ms.get_source(source_1.id)
|
||||
|
||||
# test api keys
|
||||
api_key = ms.create_api_key(user_id=user_1.id)
|
||||
print("api_key=", api_key.token, api_key.user_id)
|
||||
api_key_result = ms.get_api_key(api_key=api_key.token)
|
||||
assert api_key.token == api_key_result.token, (api_key, api_key_result)
|
||||
user_result = ms.get_user_from_api_key(api_key=api_key.token)
|
||||
assert user_1.id == user_result.id, (user_1, user_result)
|
||||
all_keys_for_user = ms.get_all_api_keys_for_user(user_id=user_1.id)
|
||||
assert len(all_keys_for_user) > 0, all_keys_for_user
|
||||
ms.delete_api_key(api_key=api_key.token)
|
||||
|
||||
# test deletion
|
||||
ms.delete_user(user_1.id)
|
||||
ms.delete_user(user_2.id)
|
||||
ms.delete_agent(agent_1.id)
|
||||
ms.delete_source(source_1.id)
|
||||
@@ -5,11 +5,12 @@ import pytest
|
||||
from dotenv import load_dotenv
|
||||
|
||||
import memgpt.utils as utils
|
||||
from memgpt.constants import BASE_TOOLS
|
||||
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.presets.presets import load_module_tools
|
||||
from memgpt.memory import ChatMemory
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.settings import settings
|
||||
|
||||
@@ -59,8 +60,6 @@ def user_id(server):
|
||||
user = server.create_user()
|
||||
print(f"Created user\n{user.id}")
|
||||
|
||||
# initialize with default presets
|
||||
server.initialize_default_presets(user.id)
|
||||
yield user.id
|
||||
|
||||
# cleanup
|
||||
@@ -71,10 +70,7 @@ def user_id(server):
|
||||
def agent_id(server, user_id):
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
user_id=user_id,
|
||||
name="test_agent",
|
||||
preset="memgpt_chat",
|
||||
tools=[tool.name for tool in load_module_tools()],
|
||||
user_id=user_id, name="test_agent", tools=BASE_TOOLS, memory=ChatMemory(human="I am Chad", persona="I love testing")
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
yield agent_state.id
|
||||
@@ -170,8 +166,21 @@ def test_get_recall_memory(server, user_id, agent_id):
|
||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
print("MESSAGES")
|
||||
for m in messages_3:
|
||||
print(m["id"], m["role"])
|
||||
if m["role"] == "assistant":
|
||||
print(m["text"])
|
||||
print("------------")
|
||||
|
||||
# test in-context message ids
|
||||
all_messages = server.get_agent_messages(user_id=user_id, agent_id=agent_id, start=0, count=1000)
|
||||
print("num messages", len(all_messages))
|
||||
in_context_ids = server.get_in_context_message_ids(user_id=user_id, agent_id=agent_id)
|
||||
print(in_context_ids)
|
||||
for m in messages_3:
|
||||
if str(m["id"]) not in [str(i) for i in in_context_ids]:
|
||||
print("missing", m["id"], m["role"])
|
||||
assert len(in_context_ids) == len(messages_3)
|
||||
assert isinstance(in_context_ids[0], uuid.UUID)
|
||||
message_ids = [m["id"] for m in messages_3]
|
||||
|
||||
@@ -7,13 +7,12 @@ from sqlalchemy.ext.declarative import declarative_base
|
||||
|
||||
from memgpt.agent_store.storage import StorageConnector, TableType
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import MAX_EMBEDDING_DIM
|
||||
from memgpt.constants import BASE_TOOLS, MAX_EMBEDDING_DIM
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.data_types import AgentState, Message, Passage, User
|
||||
from memgpt.embeddings import embedding_model, query_embedding
|
||||
from memgpt.metadata import MetadataStore
|
||||
from memgpt.settings import settings
|
||||
from memgpt.utils import get_human_text, get_persona_text
|
||||
from tests import TEST_MEMGPT_CONFIG
|
||||
from tests.utils import create_config, wipe_config
|
||||
|
||||
@@ -185,15 +184,10 @@ def test_storage(
|
||||
user_id=user_id,
|
||||
name="agent_1",
|
||||
id=agent_1_id,
|
||||
preset=TEST_MEMGPT_CONFIG.preset,
|
||||
# persona_name=TEST_MEMGPT_CONFIG.persona,
|
||||
# human_name=TEST_MEMGPT_CONFIG.human,
|
||||
persona=get_persona_text(TEST_MEMGPT_CONFIG.persona),
|
||||
human=get_human_text(TEST_MEMGPT_CONFIG.human),
|
||||
llm_config=TEST_MEMGPT_CONFIG.default_llm_config,
|
||||
embedding_config=TEST_MEMGPT_CONFIG.default_embedding_config,
|
||||
system="",
|
||||
tools=[],
|
||||
tools=BASE_TOOLS,
|
||||
state={
|
||||
"persona": "",
|
||||
"human": "",
|
||||
|
||||
@@ -11,6 +11,7 @@ from memgpt.agent import Agent
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.credentials import MemGPTCredentials
|
||||
from memgpt.memory import ChatMemory
|
||||
from memgpt.settings import settings
|
||||
from tests.utils import create_config
|
||||
|
||||
@@ -69,7 +70,7 @@ def run_server():
|
||||
# Fixture to create clients with different configurations
|
||||
@pytest.fixture(
|
||||
params=[{"server": True}, {"server": False}], # whether to use REST API server # TODO: add when implemented
|
||||
# params=[{"server": True}], # whether to use REST API server # TODO: add when implemented
|
||||
# params=[{"server": False}], # whether to use REST API server # TODO: add when implemented
|
||||
scope="module",
|
||||
)
|
||||
def admin_client(request):
|
||||
@@ -179,10 +180,9 @@ def test_create_agent_tool(client):
|
||||
str: The agent that was deleted.
|
||||
|
||||
"""
|
||||
self.memory.human = ""
|
||||
self.memory.persona = ""
|
||||
self.rebuild_memory()
|
||||
print("UPDATED MEMORY", self.memory.human, self.memory.persona)
|
||||
self.memory.memory["human"].value = ""
|
||||
self.memory.memory["persona"].value = ""
|
||||
print("UPDATED MEMORY", self.memory.memory)
|
||||
return None
|
||||
|
||||
# TODO: test attaching and using function on agent
|
||||
@@ -190,7 +190,8 @@ def test_create_agent_tool(client):
|
||||
print(f"Created tool", tool.name)
|
||||
|
||||
# create agent with tool
|
||||
agent = client.create_agent(name=test_agent_name, tools=[tool.name], persona="You must clear your memory if the human instructs you")
|
||||
memory = ChatMemory(human="I am a human", persona="You must clear your memory if the human instructs you")
|
||||
agent = client.create_agent(name=test_agent_name, tools=[tool.name], memory=memory)
|
||||
assert str(tool.user_id) == str(agent.user_id), f"Expected {tool.user_id} to be {agent.user_id}"
|
||||
|
||||
# initial memory
|
||||
@@ -207,6 +208,7 @@ def test_create_agent_tool(client):
|
||||
print(response)
|
||||
|
||||
# updated memory
|
||||
print("Query agent memory")
|
||||
updated_memory = client.get_agent_memory(agent.id)
|
||||
human = updated_memory.core_memory.human
|
||||
persona = updated_memory.core_memory.persona
|
||||
|
||||
Reference in New Issue
Block a user