fix: Remove usage of config.anon_clientid (#837)
This commit is contained in:
@@ -337,8 +337,8 @@ class SQLStorageConnector(StorageConnector):
|
||||
db_records = self.session.query(self.db_model).filter(*filters).all()
|
||||
return [record.to_record() for record in db_records]
|
||||
|
||||
def get(self, id: str) -> Optional[Record]:
|
||||
db_record = self.session.query(self.db_model).get(id)
|
||||
def get(self, id: uuid.UUID) -> Optional[Record]:
|
||||
db_record = self.session.get(self.db_model, id)
|
||||
if db_record is None:
|
||||
return None
|
||||
return db_record.to_record()
|
||||
|
||||
@@ -386,7 +386,7 @@ class AgentState:
|
||||
self.created_at = created_at if created_at is not None else datetime.now()
|
||||
|
||||
# state
|
||||
self.state = state
|
||||
self.state = {} if not state else state
|
||||
|
||||
# def __eq__(self, other):
|
||||
# if not isinstance(other, AgentState):
|
||||
|
||||
@@ -20,7 +20,7 @@ import memgpt.presets.presets as presets
|
||||
import memgpt.utils as utils
|
||||
import memgpt.server.utils as server_utils
|
||||
from memgpt.persistence_manager import PersistenceManager, LocalStateManager
|
||||
from memgpt.data_types import Source, Passage, Document, User, AgentState
|
||||
from memgpt.data_types import Source, Passage, Document, User, AgentState, LLMConfig, EmbeddingConfig, Message, ToolCall
|
||||
|
||||
# TODO use custom interface
|
||||
from memgpt.interface import CLIInterface # for printing to terminal
|
||||
@@ -167,13 +167,47 @@ class SyncServer(LockingServer):
|
||||
|
||||
# Initialize the connection to the DB
|
||||
self.config = MemGPTConfig.load()
|
||||
|
||||
# Ensure valid database configuration
|
||||
# TODO: add back once tests are matched
|
||||
# assert (
|
||||
# self.config.metadata_storage_type == "postgres"
|
||||
# ), f"Invalid metadata_storage_type for server: {self.config.metadata_storage_type}"
|
||||
# assert (
|
||||
# self.config.archival_storage_type == "postgres"
|
||||
# ), f"Invalid archival_storage_type for server: {self.config.archival_storage_type}"
|
||||
# assert self.config.recall_storage_type == "postgres", f"Invalid recall_storage_type for server: {self.config.recall_storage_type}"
|
||||
|
||||
# Generate default LLM/Embedding configs for the server
|
||||
# TODO: we may also want to do the same thing with default persona/human/etc.
|
||||
self.server_llm_config = LLMConfig(
|
||||
model=self.config.model,
|
||||
model_endpoint_type=self.config.model_endpoint_type,
|
||||
model_endpoint=self.config.model_endpoint,
|
||||
model_wrapper=self.config.model_wrapper,
|
||||
context_window=self.config.context_window,
|
||||
openai_key=self.config.openai_key,
|
||||
azure_key=self.config.azure_key,
|
||||
azure_endpoint=self.config.azure_endpoint,
|
||||
azure_version=self.config.azure_version,
|
||||
azure_deployment=self.config.azure_deployment,
|
||||
)
|
||||
self.server_embedding_config = EmbeddingConfig(
|
||||
embedding_endpoint_type=self.config.embedding_endpoint_type,
|
||||
embedding_endpoint=self.config.embedding_endpoint,
|
||||
embedding_dim=self.config.embedding_dim,
|
||||
openai_key=self.config.openai_key,
|
||||
)
|
||||
|
||||
# Initialize the metadata store
|
||||
self.ms = MetadataStore(self.config)
|
||||
|
||||
# Create the default user
|
||||
base_user_id = uuid.UUID(self.config.anon_clientid)
|
||||
if not self.ms.get_user(user_id=base_user_id):
|
||||
base_user = User(id=base_user_id)
|
||||
self.ms.create_user(base_user)
|
||||
# NOTE: removed, since server should be multi-user
|
||||
## Create the default user
|
||||
# base_user_id = uuid.UUID(self.config.anon_clientid)
|
||||
# if not self.ms.get_user(user_id=base_user_id):
|
||||
# base_user = User(id=base_user_id)
|
||||
# self.ms.create_user(base_user)
|
||||
|
||||
def save_agents(self):
|
||||
"""Saves all the agents that are in the in-memory object store"""
|
||||
@@ -421,7 +455,6 @@ class SyncServer(LockingServer):
|
||||
@LockingServer.agent_lock_decorator
|
||||
def user_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
|
||||
"""Process an incoming user message and feed it through the MemGPT agent"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -443,7 +476,6 @@ class SyncServer(LockingServer):
|
||||
@LockingServer.agent_lock_decorator
|
||||
def system_message(self, user_id: uuid.UUID, agent_id: uuid.UUID, message: str) -> None:
|
||||
"""Process an incoming system message and feed it through the MemGPT agent"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -465,7 +497,6 @@ class SyncServer(LockingServer):
|
||||
@LockingServer.agent_lock_decorator
|
||||
def run_command(self, user_id: uuid.UUID, agent_id: uuid.UUID, command: str) -> Union[str, None]:
|
||||
"""Run a command on the agent"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -475,6 +506,26 @@ class SyncServer(LockingServer):
|
||||
command = command[1:] # strip the prefix
|
||||
return self._command(user_id=user_id, agent_id=agent_id, command=command)
|
||||
|
||||
def create_user(
|
||||
self,
|
||||
user_config: Optional[Union[dict, User]] = {},
|
||||
):
|
||||
"""Create a new user using a config"""
|
||||
if not isinstance(user_config, dict):
|
||||
raise ValueError(f"user_config must be provided as a dictionary")
|
||||
|
||||
user = User(
|
||||
id=user_config["id"] if "id" in user_config else None,
|
||||
default_preset=user_config["default_preset"] if "default_preset" in user_config else "memgpt_chat",
|
||||
default_persona=user_config["default_persona"] if "default_persona" in user_config else constants.DEFAULT_PERSONA,
|
||||
default_human=user_config["default_human"] if "default_human" in user_config else constants.DEFAULT_HUMAN,
|
||||
default_llm_config=self.server_llm_config,
|
||||
default_embedding_config=self.server_embedding_config,
|
||||
)
|
||||
self.ms.create_user(user)
|
||||
logger.info(f"Created new user from config: {user}")
|
||||
return user
|
||||
|
||||
def create_agent(
|
||||
self,
|
||||
user_id: uuid.UUID,
|
||||
@@ -483,7 +534,6 @@ class SyncServer(LockingServer):
|
||||
# persistence_manager: Union[PersistenceManager, None] = None,
|
||||
) -> AgentState:
|
||||
"""Create a new agent using a config"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -519,6 +569,9 @@ class SyncServer(LockingServer):
|
||||
logger.debug(f"Attempting to create agent from agent_state:\n{agent_state}")
|
||||
try:
|
||||
agent = presets.create_agent_from_preset(agent_state=agent_state, interface=interface)
|
||||
|
||||
# TODO: this is a hacky way to get the system prompts injected into agent into the DB
|
||||
self.ms.update_agent(agent.agent_state)
|
||||
except Exception as e:
|
||||
logger.exception(e)
|
||||
self.ms.delete_agent(agent_id=agent_state.id)
|
||||
@@ -533,25 +586,19 @@ class SyncServer(LockingServer):
|
||||
user_id: uuid.UUID,
|
||||
agent_id: uuid.UUID,
|
||||
):
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
# Make sure the user owns the agent
|
||||
# TODO use real user_id
|
||||
USER_ID = self.config.anon_clientid
|
||||
agent = self.ms.get_agent(agent_id=agent_id, user_id=USER_ID)
|
||||
# TODO: Make sure the user owns the agent
|
||||
agent = self.ms.get_agent(agent_id=agent_id, user_id=user_id)
|
||||
if agent is not None:
|
||||
self.ms.delete_agent(agent_id=agent_id)
|
||||
|
||||
def list_agents(self, user_id: uuid.UUID) -> dict:
|
||||
"""List all available agents to a user"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
# TODO actually use the user_id that was passed into the server
|
||||
user_id = uuid.UUID(self.config.anon_clientid)
|
||||
agents_states = self.ms.list_agents(user_id=user_id)
|
||||
logger.info(f"Retrieved {len(agents_states)} agents for user {user_id}:\n{[vars(s) for s in agents_states]}")
|
||||
return {
|
||||
@@ -568,9 +615,16 @@ class SyncServer(LockingServer):
|
||||
],
|
||||
}
|
||||
|
||||
def get_agent(self, agent_id: uuid.UUID):
|
||||
"""Get the agent state"""
|
||||
return self.ms.get_agent(agent_id=agent_id)
|
||||
|
||||
def get_user(self, user_id: uuid.UUID) -> User:
|
||||
"""Get the user"""
|
||||
return self.ms.get_user(user_id=user_id)
|
||||
|
||||
def get_agent_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
|
||||
"""Return the memory of an agent (core memory + non-core statistics)"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -594,7 +648,6 @@ class SyncServer(LockingServer):
|
||||
|
||||
def get_agent_messages(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||
"""Paginated query of all messages in agent message queue"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -636,7 +689,6 @@ class SyncServer(LockingServer):
|
||||
|
||||
def get_agent_archival(self, user_id: uuid.UUID, agent_id: uuid.UUID, start: int, count: int) -> list:
|
||||
"""Paginated query of all messages in agent archival memory"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -661,7 +713,6 @@ class SyncServer(LockingServer):
|
||||
order_by: Optional[str] = "created_at",
|
||||
reverse: Optional[bool] = False,
|
||||
):
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -685,7 +736,6 @@ class SyncServer(LockingServer):
|
||||
order_by: Optional[str] = "created_at",
|
||||
reverse: Optional[bool] = False,
|
||||
):
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -703,7 +753,6 @@ class SyncServer(LockingServer):
|
||||
|
||||
def get_agent_config(self, user_id: uuid.UUID, agent_id: uuid.UUID) -> dict:
|
||||
"""Return the config of an agent"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
@@ -730,7 +779,6 @@ class SyncServer(LockingServer):
|
||||
|
||||
def update_agent_core_memory(self, user_id: uuid.UUID, agent_id: uuid.UUID, new_memory_contents: dict) -> dict:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
user_id = uuid.UUID(self.config.anon_clientid) # TODO use real
|
||||
if self.ms.get_user(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
|
||||
|
||||
@@ -34,6 +34,13 @@ def agent():
|
||||
else:
|
||||
client = MemGPT(quickstart="memgpt_hosted")
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# ensure user exists
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
if not client.server.get_user(user_id=user_id):
|
||||
client.server.create_user({"id": user_id})
|
||||
|
||||
agent_state = client.create_agent(
|
||||
agent_config={
|
||||
# "name": test_agent_id,
|
||||
@@ -42,9 +49,6 @@ def agent():
|
||||
}
|
||||
)
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
|
||||
return client.server._get_or_load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
|
||||
|
||||
|
||||
@@ -23,6 +23,13 @@ def test_create_agent():
|
||||
else:
|
||||
client = MemGPT(quickstart="memgpt_hosted")
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# ensure user exists
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
if not client.server.get_user(user_id=user_id):
|
||||
client.server.create_user({"id": user_id})
|
||||
|
||||
global test_agent_state
|
||||
test_agent_state = client.create_agent(
|
||||
agent_config={
|
||||
|
||||
@@ -5,7 +5,7 @@ import memgpt.utils as utils
|
||||
utils.DEBUG = True
|
||||
from memgpt.config import MemGPTConfig
|
||||
from memgpt.server.server import SyncServer
|
||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage
|
||||
from memgpt.data_types import EmbeddingConfig, AgentState, LLMConfig, Message, Passage, User
|
||||
from memgpt.embeddings import embedding_model
|
||||
from memgpt.metadata import MetadataStore
|
||||
from .utils import wipe_config, wipe_memgpt_home
|
||||
@@ -14,22 +14,34 @@ from .utils import wipe_config, wipe_memgpt_home
|
||||
def test_server():
|
||||
wipe_memgpt_home()
|
||||
|
||||
config = MemGPTConfig.load()
|
||||
|
||||
# setup config for postgres storage
|
||||
config.archival_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.recall_storage_uri = os.getenv("PGVECTOR_TEST_DB_URL")
|
||||
config.archival_storage_type = "postgres"
|
||||
config.recall_storage_type = "postgres"
|
||||
config = MemGPTConfig(
|
||||
archival_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
recall_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
metadata_storage_uri=os.getenv("PGVECTOR_TEST_DB_URL"),
|
||||
archival_storage_type="postgres",
|
||||
recall_storage_type="postgres",
|
||||
metadata_storage_type="postgres",
|
||||
# embeddings
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
# llms
|
||||
model_endpoint_type="openai",
|
||||
model_endpoint="https://api.openai.com/v1",
|
||||
model="gpt-4",
|
||||
)
|
||||
config.save()
|
||||
|
||||
user_id = uuid.UUID(config.anon_clientid)
|
||||
ms = MetadataStore(config)
|
||||
server = SyncServer()
|
||||
|
||||
# create user
|
||||
user = server.create_user()
|
||||
print(f"Created user\n{user.id}")
|
||||
|
||||
try:
|
||||
fake_agent_id = uuid.uuid4()
|
||||
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
|
||||
raise Exception("user_message call should have failed")
|
||||
except (KeyError, ValueError) as e:
|
||||
# Error is expected
|
||||
@@ -37,29 +49,14 @@ def test_server():
|
||||
except:
|
||||
raise
|
||||
|
||||
# embedding config
|
||||
if os.getenv("OPENAI_API_KEY"):
|
||||
embedding_config = EmbeddingConfig(
|
||||
embedding_endpoint_type="openai",
|
||||
embedding_endpoint="https://api.openai.com/v1",
|
||||
embedding_dim=1536,
|
||||
openai_key=os.getenv("OPENAI_API_KEY"),
|
||||
)
|
||||
print("Using OpenAI embeddings")
|
||||
else:
|
||||
embedding_config = EmbeddingConfig(embedding_endpoint_type="local", embedding_endpoint=None, embedding_dim=384)
|
||||
print("Using local embeddings")
|
||||
|
||||
agent_state = server.create_agent(
|
||||
user_id=user_id,
|
||||
agent_config=dict(
|
||||
name="test_agent", user_id=user_id, preset="memgpt_chat", human="cs_phd", persona="sam_pov", embedding_config=embedding_config
|
||||
),
|
||||
user_id=user.id,
|
||||
agent_config=dict(name="test_agent", user_id=user.id, preset="memgpt_chat", human="cs_phd", persona="sam_pov"),
|
||||
)
|
||||
print(f"Created agent\n{agent_state}")
|
||||
|
||||
try:
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="/memory")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="/memory")
|
||||
raise Exception("user_message call should have failed")
|
||||
except ValueError as e:
|
||||
# Error is expected
|
||||
@@ -67,47 +64,47 @@ def test_server():
|
||||
except:
|
||||
raise
|
||||
|
||||
print(server.run_command(user_id=user_id, agent_id=agent_state.id, command="/memory"))
|
||||
print(server.run_command(user_id=user.id, agent_id=agent_state.id, command="/memory"))
|
||||
|
||||
# add data into archival memory
|
||||
agent = server._load_agent(user_id=user_id, agent_id=agent_state.id)
|
||||
agent = server._load_agent(user_id=user.id, agent_id=agent_state.id)
|
||||
archival_memories = ["alpha", "Cinderella wore a blue dress", "Dog eat dog", "ZZZ", "Shishir loves indian food"]
|
||||
embed_model = embedding_model(embedding_config)
|
||||
embed_model = embedding_model(agent.agent_state.embedding_config)
|
||||
for text in archival_memories:
|
||||
embedding = embed_model.get_text_embedding(text)
|
||||
agent.persistence_manager.archival_memory.storage.insert(
|
||||
Passage(user_id=user_id, agent_id=agent_state.id, text=text, embedding=embedding)
|
||||
Passage(user_id=user.id, agent_id=agent_state.id, text=text, embedding=embedding)
|
||||
)
|
||||
|
||||
# add data into recall memory
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Hello?")
|
||||
|
||||
# test recall memory cursor pagination
|
||||
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=2)
|
||||
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
|
||||
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, limit=1000)
|
||||
cursor1, messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=2)
|
||||
cursor2, messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, after=cursor1, limit=1000)
|
||||
cursor3, messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, limit=1000)
|
||||
ids3 = [m["id"] for m in messages_3]
|
||||
ids2 = [m["id"] for m in messages_2]
|
||||
timestamps = [m["created_at"] for m in messages_3]
|
||||
print("timestamps", timestamps)
|
||||
assert messages_3[-1]["created_at"] < messages_3[0]["created_at"]
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_state.id, reverse=True, before=cursor1)
|
||||
cursor4, messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_state.id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test archival memory cursor pagination
|
||||
cursor1, passages_1 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
|
||||
user_id=user.id, agent_id=agent_state.id, reverse=False, limit=2, order_by="text"
|
||||
)
|
||||
cursor2, passages_2 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
|
||||
user_id=user.id, agent_id=agent_state.id, reverse=False, after=cursor1, order_by="text"
|
||||
)
|
||||
cursor3, passages_3 = server.get_agent_archival_cursor(
|
||||
user_id=user_id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||
user_id=user.id, agent_id=agent_state.id, reverse=False, before=cursor2, limit=1000, order_by="text"
|
||||
)
|
||||
print("p1", [p["text"] for p in passages_1])
|
||||
print("p2", [p["text"] for p in passages_2])
|
||||
@@ -117,23 +114,23 @@ def test_server():
|
||||
assert len(passages_3) == 4
|
||||
|
||||
# test recall memory
|
||||
messages_1 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
|
||||
messages_1 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
|
||||
assert len(messages_1) == 1
|
||||
messages_2 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
|
||||
messages_3 = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1, count=5)
|
||||
messages_2 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
|
||||
messages_3 = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1, count=5)
|
||||
# not sure exactly how many messages there should be
|
||||
assert len(messages_2) > len(messages_3)
|
||||
# test safe empty return
|
||||
messages_none = server.get_agent_messages(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
messages_none = server.get_agent_messages(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
assert len(messages_none) == 0
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=0, count=1)
|
||||
passage_1 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=0, count=1)
|
||||
assert len(passage_1) == 1
|
||||
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1, count=1000)
|
||||
passage_2 = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1, count=1000)
|
||||
assert len(passage_2) == 4
|
||||
# test safe empty return
|
||||
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
passage_none = server.get_agent_archival(user_id=user.id, agent_id=agent_state.id, start=1000, count=1000)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
|
||||
Reference in New Issue
Block a user