chore: Clean up .load_agent usage (#2298)

This commit is contained in:
Matthew Zhou
2024-12-20 16:56:53 -08:00
committed by GitHub
parent a5b1aac1fd
commit 9ad5fd64cf
10 changed files with 134 additions and 164 deletions

View File

@@ -362,10 +362,10 @@ def other_agent_id(server, user_id, base_tools):
server.agent_manager.delete_agent(agent_state.id, actor=actor)
def test_error_on_nonexistent_agent(server, user_id, agent_id):
def test_error_on_nonexistent_agent(server, user, agent_id):
try:
fake_agent_id = str(uuid.uuid4())
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
raise Exception("user_message call should have failed")
except (KeyError, ValueError) as e:
# Error is expected
@@ -375,9 +375,9 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id):
@pytest.mark.order(1)
def test_user_message_memory(server, user_id, agent_id):
def test_user_message_memory(server, user, agent_id):
try:
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
server.user_message(user_id=user.id, agent_id=agent_id, message="/memory")
raise Exception("user_message call should have failed")
except ValueError as e:
# Error is expected
@@ -385,13 +385,11 @@ def test_user_message_memory(server, user_id, agent_id):
except:
raise
server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
server.run_command(user_id=user.id, agent_id=agent_id, command="/memory")
@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)
def test_load_data(server, user, agent_id):
# create source
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_before) == 0
@@ -409,10 +407,10 @@ def test_load_data(server, user_id, agent_id):
"Shishir loves indian food",
]
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)
server.load_data(user.id, connector, source.name)
# attach source
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user)
# check archival memory size
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
@@ -425,9 +423,9 @@ def test_save_archival_memory(server, user_id, agent_id):
@pytest.mark.order(4)
def test_user_message(server, user_id, agent_id):
def test_user_message(server, user, agent_id):
# add data into recall memory
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
@@ -435,21 +433,20 @@ def test_user_message(server, user_id, agent_id):
@pytest.mark.order(5)
def test_get_recall_memory(server, org_id, user_id, agent_id):
def test_get_recall_memory(server, org_id, user, agent_id):
# test recall memory cursor pagination
actor = server.user_manager.get_user_or_default(user_id=user_id)
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
actor = user
messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=2)
cursor1 = messages_1[-1].id
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, after=cursor1, limit=1000)
messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=1000)
messages_3[-1].id
assert messages_3[-1].created_at >= messages_3[0].created_at
assert len(messages_3) == len(messages_1) + len(messages_2)
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test in-context message ids
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
message_ids = [m.id for m in messages_3]
@@ -458,13 +455,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
@pytest.mark.order(6)
def test_get_archival_memory(server, user_id, agent_id):
def test_get_archival_memory(server, user, agent_id):
# test archival memory cursor pagination
user = server.user_manager.get_user_by_id(user_id=user_id)
actor = user
# List latest 2 passages
passages_1 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
limit=2,
@@ -474,7 +471,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
cursor=cursor1,
@@ -483,7 +480,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.agent_manager.list_passages(
actor=user,
actor=actor,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
@@ -496,20 +493,20 @@ def test_get_archival_memory(server, user_id, agent_id):
earliest = passages_2[-1]
# test archival memory
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0
def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
def test_get_context_window_overview(server: SyncServer, user, agent_id):
"""Test that the context window overview fetch works"""
overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id))
overview = server.get_agent_context_window(agent_id=agent_id, actor=user)
assert overview is not None
# Run some basic checks
@@ -546,7 +543,7 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
)
def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
agent_state = server.create_agent(
request=CreateAgent(
name="nonexistent_tools_agent",
@@ -554,7 +551,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
llm="openai/gpt-4",
embedding="openai/text-embedding-ada-002",
),
actor=server.user_manager.get_user_or_default(user_id),
actor=user,
)
# create another user in the same org
@@ -566,14 +563,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
def _test_get_messages_letta_format(
server,
user_id,
user,
agent_id,
reverse=False,
):
"""Test mapping between messages and letta_messages with reverse=False."""
messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
@@ -582,7 +579,7 @@ def _test_get_messages_letta_format(
assert all(isinstance(m, Message) for m in messages)
letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_id,
limit=1000,
reverse=reverse,
@@ -675,10 +672,10 @@ def _test_get_messages_letta_format(
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
def test_get_messages_letta_format(server, user_id, agent_id):
def test_get_messages_letta_format(server, user, agent_id):
# for reverse in [False, True]:
for reverse in [False]:
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
_test_get_messages_letta_format(server, user, agent_id, reverse=reverse)
EXAMPLE_TOOL_SOURCE = '''
@@ -825,9 +822,9 @@ def test_composio_client_simple(server):
assert len(actions) > 0
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools):
def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, base_memory_tools):
"""Test that the memory rebuild is generating the correct number of role=system messages"""
actor = server.user_manager.get_user_or_default(user_id)
actor = user
# create agent
agent_state = server.create_agent(
request=CreateAgent(
@@ -848,7 +845,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
# At this stage, there should only be 1 system message inside of recall storage
letta_messages = server.get_agent_recall_cursor(
user_id=user_id,
user_id=user.id,
agent_id=agent_state.id,
limit=1000,
# reverse=reverse,
@@ -870,7 +867,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
assert num_system_messages == 1, (num_system_messages, all_messages)
# Assuming core memory append actually ran correctly, at this point there should be 2 messages
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
# At this stage, there should be 2 system message inside of recall storage
num_system_messages, all_messages = count_system_messages_in_recall()