fix: Remove usage of config.anon_clientid (#837)

This commit is contained in:
Sarah Wooders
2024-01-16 20:37:58 -08:00
committed by GitHub
parent 0c787a11e0
commit 9da53d1e1a
6 changed files with 139 additions and 83 deletions

View File

@@ -337,8 +337,8 @@ class SQLStorageConnector(StorageConnector):
db_records = self.session.query(self.db_model).filter(*filters).all()
return [record.to_record() for record in db_records]
def get(self, id: str) -> Optional[Record]:
db_record = self.session.query(self.db_model).get(id)
def get(self, id: uuid.UUID) -> Optional[Record]:
db_record = self.session.get(self.db_model, id)
if db_record is None:
return None
return db_record.to_record()

View File

@@ -386,7 +386,7 @@ class AgentState:
self.created_at = created_at if created_at is not None else datetime.now()
# state
self.state = state
self.state = {} if not state else state
# def __eq__(self, other):
# if not isinstance(other, AgentState):

View File

@@ -20,7 +20,7 @@ import memgpt.presets.presets as presets
import memgpt.utils as utils
import memgpt.server.utils as server_utils
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
from memgpt.data_types import Source, Passage, Document, User, AgentState
from memgpt.data_types import Source, Passage, Document, User, AgentState, LLMConfig, EmbeddingConfig, Message, ToolCall
# TODO use custom interface
from memgpt.interface import CLIInterface # for printing to terminal
@@ -167,13 +167,47 @@ class SyncServer(LockingServer):
# Initialize the connection to the DB
self.config = MemGPTConfig.load()
# Ensure valid database configuration
# TODO: add back once tests are matched
# assert (
# self.config.metadata_storage_type == "postgres"
# ), f"Invalid metadata_storage_type for server: {self.config.metadata_storage_type}"
# assert (
# self.config.archival_storage_type == "postgres"
# ), f"Invalid archival_storage_type for server: {self.config.archival_storage_type}"
# assert self.config.recall_storage_type == "postgres", f"Invalid recall_storage_type for server: {self.config.recall_storage_type}"
# Generate default LLM/Embedding configs for the server
# TODO: we may also want to do the same thing with default persona/human/etc.
self.server_llm_config = LLMConfig(
model=self.config.model,
model_endpoint_type=self.config.model_endpoint_type,
model_endpoint=self.config.model_endpoint,
model_wrapper=self.config.model_wrapper,
context_window=self.config.context_window,
openai_key=self.config.openai_key,
azure_key=self.config.azure_key,
azure_endpoint=self.config.azure_endpoint,
azure_version=self.config.azure_version,
azure_deployment=self.config.azure_deployment,
)
self.server_embedding_config = EmbeddingConfig(
embedding_endpoint_type=self.config.embedding_endpoint_type,
embedding_endpoint=self.config.embedding_endpoint,
embedding_dim=self.config.embedding_dim,
openai_key=self.config.openai_key,
)
# Initialize the metadata store
self.ms = MetadataStore(self.config)
# Create the default user
base_user_id = uuid.UUID(self.config.anon_clientid)
if not self.ms.get_user(user_id=base_user_id):
base_user = User(id=base_user_id)
self.ms.create_user(base_user)
# NOTE: removed, since server should be multi-user
## Create the default user
# base_user_id = uuid.UUID(self.config.anon_clientid)
# if not self.ms.get_user(user_id=base_user_id):
# base_user = User(id=base_user_id)
# self.ms.create_user(base_user)
def save_agents(self):
"""Saves all the agents that are in the in-memory object store"""
@@ -421,7 +455,6 @@ class SyncServer(LockingServer):
@LockingServer.agent_lock_decorator
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process an incoming user message and feed it through the MemGPT agent"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -443,7 +476,6 @@ class SyncServer(LockingServer):
@LockingServer.agent_lock_decorator
def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
"""Process an incoming system message and feed it through the MemGPT agent"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -465,7 +497,6 @@ class SyncServer(LockingServer):
@LockingServer.agent_lock_decorator
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
"""Run a command on the agent"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -475,6 +506,26 @@ class SyncServer(LockingServer):
command = command[1:] # strip the prefix
return self._command(user_id=user_id, agent_id=agent_id, command=command)
def create_user(
self,
user_config: Optional[Union[dict, User]] = {},
):
"""Create a new user using a config"""
if not isinstance(user_config, dict):
raise ValueError(f"user_config must be provided as a dictionary")
user = User(
id=user_config["id"] if "id" in user_config else None,
default_preset=user_config["default_preset"] if "default_preset" in user_config else "memgpt_chat",
default_persona=user_config["default_persona"] if "default_persona" in user_config else constants.DEFAULT_PERSONA,
default_human=user_config["default_human"] if "default_human" in user_config else constants.DEFAULT_HUMAN,
default_llm_config=self.server_llm_config,
default_embedding_config=self.server_embedding_config,
)
self.ms.create_user(user)
logger.info(f"Created new user from config: {user}")
return user
def create_agent(
self,
user_id: uuid.UUID,
@@ -483,7 +534,6 @@ class SyncServer(LockingServer):
# persistence_manager: Union[PersistenceManager, None] = None,
) -> AgentState:
"""Create a new agent using a config"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -519,6 +569,9 @@ class SyncServer(LockingServer):
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
try:
agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface)
# TODO: this is a hacky way to get the system prompts injected into agent into the DB
self.ms.update_agent(agent.agent_state)
except Exception as e:
logger.exception(e)
self.ms.delete_agent(agent_id=agent_state.id)
@@ -533,25 +586,19 @@ class SyncServer(LockingServer):
user_id: uuid.UUID,
agent_id: uuid.UUID,
):
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
# Make sure the user owns the agent
# TODO use real user_id
USER_ID = self.config.anon_clientid
agent = self.ms.get_agent(agent_id=agent_id, user_id=USER_ID)
# TODO: Make sure the user owns the agent
agent = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
if agent is not None:
self.ms.delete_agent(agent_id=agent_id)
def list_agents(self, user_id: uuid.UUID) -> dict:
"""List all available agents to a user"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
# TODO actually use the user_id that was passed into the server
user_id = uuid.UUID(self.config.anon_clientid)
agents_states = self.ms.list_agents(user_id=user_id)
logger.info(f"Retrieved {len(agents_states)} agents for user {user_id}:\n{[vars(s) for s in agents_states]}")
return {
@@ -568,9 +615,16 @@ class SyncServer(LockingServer):
],
}
def get_agent(self, agent_id: uuid.UUID):
"""Get the agent state"""
return self.ms.get_agent(agent_id=agent_id)
def get_user(self, user_id: uuid.UUID) -> User:
"""Get the user"""
return self.ms.get_user(user_id=user_id)
def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the memory of an agent (core memory + non-core statistics)"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -594,7 +648,6 @@ class SyncServer(LockingServer):
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of all messages in agent message queue"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -636,7 +689,6 @@ class SyncServer(LockingServer):
def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
"""Paginated query of all messages in agent archival memory"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -661,7 +713,6 @@ class SyncServer(LockingServer):
order_by: Optional[str] = "created_at",
reverse: Optional[bool] = False,
):
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -685,7 +736,6 @@ class SyncServer(LockingServer):
order_by: Optional[str] = "created_at",
reverse: Optional[bool] = False,
):
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -703,7 +753,6 @@ class SyncServer(LockingServer):
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
"""Return the config of an agent"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -730,7 +779,6 @@ class SyncServer(LockingServer):
def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict:
"""Update the agents core memory block, return the new state"""
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
if self.ms.get_user(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")

View File

@@ -34,6 +34,13 @@ def agent():
else:
client = MemGPT(quickstart="memgpt_hosted")
config = MemGPTConfig.load()
# ensure user exists
user_id = uuid.UUID(config.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})
agent_state = client.create_agent(
agent_config={
# "name": test_agent_id,
@@ -42,9 +49,6 @@ def agent():
}
)
config = MemGPTConfig.load()
user_id = uuid.UUID(config.anon_clientid)
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)

View File

@@ -23,6 +23,13 @@ def test_create_agent():
else:
client = MemGPT(quickstart="memgpt_hosted")
config = MemGPTConfig.load()
# ensure user exists
user_id = uuid.UUID(config.anon_clientid)
if not client.server.get_user(user_id=user_id):
client.server.create_user({"id": user_id})
global test_agent_state
test_agent_state = client.create_agent(
agent_config={

View File

@@ -5,7 +5,7 @@ import memgpt.utils as utils
utils.DEBUG = True
from memgpt.config import MemGPTConfig
from memgpt.server.server import SyncServer
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
from memgpt.embeddings import embedding_model
from memgpt.metadata import MetadataStore
from .utils import wipe_config, wipe_memgpt_home
@@ -14,22 +14,34 @@ from .utils import wipe_config, wipe_memgpt_home
def test_server():
wipe_memgpt_home()
config = MemGPTConfig.load()
# setup config for postgres storage
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
config.archival_storage_type = "postgres"
config.recall_storage_type = "postgres"
config = MemGPTConfig(
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
archival_storage_type="postgres",
recall_storage_type="postgres",
metadata_storage_type="postgres",
# embeddings
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
openai_key=os.getenv("OPENAI_API_KEY"),
# llms
model_endpoint_type="openai",
model_endpoint="https://api.openai.com/v1",
model="gpt-4",
)
config.save()
user_id = uuid.UUID(config.anon_clientid)
ms = MetadataStore(config)
server = SyncServer()
# create user
user = server.create_user()
print(f"Created user\n{user.id}")
try:
fake_agent_id = uuid.uuid4()
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
# Error is expected
@@ -37,29 +49,14 @@ def test_server():
except:
raise
# embedding config
if os.getenv("OPENAI_API_KEY"):
embedding_config = EmbeddingConfig(
embedding_endpoint_type="openai",
embedding_endpoint="https://api.openai.com/v1",
embedding_dim=1536,
openai_key=os.getenv("OPENAI_API_KEY"),
)
print("Using OpenAI embeddings")
else:
embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384)
print("Using local embeddings")
agent_state = server.create_agent(
user_id=user_id,
agent_config=dict(
name="test_agent", user_id=user_id, preset="memgpt_chat", human="cs_phd", persona="sam_pov", embedding_config=embedding_config
),
user_id=user.id,
agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"),
)
print(f"Created agent\n{agent_state}")
try:
server.user_message(user_id=user_id, agent_id=agent_state.id, message="/memory")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory")
raise Exception("user_message call should have failed")
except ValueError as e:
# Error is expected
@@ -67,47 +64,47 @@ def test_server():
except:
raise
print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory"))
print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory"))
# add data into archival memory
agent = server._load_agent(user_id=user_id, agent_id=agent_state.id)
agent = server._load_agent(user_id=user.id, agent_id=agent_state.id)
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
embed_model = embedding_model(embedding_config)
embed_model = embedding_model(agent.agent_state.embedding_config)
for text in archival_memories:
embedding = embed_model.get_text_embedding(text)
agent.persistence_manager.archival_memory.storage.insert(
Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding)
Passage(user_id=user.id, agent_id=agent_state.id, text=text, embedding=embedding)
)
# add data into recall memory
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
# test recall memory cursor pagination
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=2)
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=1000)
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=2)
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=1000)
ids3 = [m["id"] for m in messages_3]
ids2 = [m["id"] for m in messages_2]
timestamps = [m["created_at"] for m in messages_3]
print("timestamps", timestamps)
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
assert len(messages_3) == len(messages_1) + len(messages_2)
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, before=cursor1)
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test archival memory cursor pagination
cursor1, passages_1 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
)
cursor2, passages_2 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
)
cursor3, passages_3 = server.get_agent_archival_cursor(
user_id=user_id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
)
print("p1", [p["text"] for p in passages_1])
print("p2", [p["text"] for p in passages_2])
@@ -117,23 +114,23 @@ def test_server():
assert len(passages_3) == 4
# test recall memory
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
assert len(messages_1) == 1
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5)
# not sure exactly how many messages there should be
assert len(messages_2) > len(messages_3)
# test safe empty return
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
assert len(messages_none) == 0
# test archival memory
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
assert len(passage_2) == 4
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
assert len(passage_none) == 0