chore: Clean up .load_agent usage (#2298)
This commit is contained in:
@@ -362,10 +362,10 @@ def other_agent_id(server, user_id, base_tools):
|
||||
server.agent_manager.delete_agent(agent_state.id, actor=actor)
|
||||
|
||||
|
||||
def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
def test_error_on_nonexistent_agent(server, user, agent_id):
|
||||
try:
|
||||
fake_agent_id = str(uuid.uuid4())
|
||||
server.user_message(user_id=user_id, agent_id=fake_agent_id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=fake_agent_id, message="Hello?")
|
||||
raise Exception("user_message call should have failed")
|
||||
except (KeyError, ValueError) as e:
|
||||
# Error is expected
|
||||
@@ -375,9 +375,9 @@ def test_error_on_nonexistent_agent(server, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(1)
|
||||
def test_user_message_memory(server, user_id, agent_id):
|
||||
def test_user_message_memory(server, user, agent_id):
|
||||
try:
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="/memory")
|
||||
server.user_message(user_id=user.id, agent_id=agent_id, message="/memory")
|
||||
raise Exception("user_message call should have failed")
|
||||
except ValueError as e:
|
||||
# Error is expected
|
||||
@@ -385,13 +385,11 @@ def test_user_message_memory(server, user_id, agent_id):
|
||||
except:
|
||||
raise
|
||||
|
||||
server.run_command(user_id=user_id, agent_id=agent_id, command="/memory")
|
||||
server.run_command(user_id=user.id, agent_id=agent_id, command="/memory")
|
||||
|
||||
|
||||
@pytest.mark.order(3)
|
||||
def test_load_data(server, user_id, agent_id):
|
||||
user = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
def test_load_data(server, user, agent_id):
|
||||
# create source
|
||||
passages_before = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||
assert len(passages_before) == 0
|
||||
@@ -409,10 +407,10 @@ def test_load_data(server, user_id, agent_id):
|
||||
"Shishir loves indian food",
|
||||
]
|
||||
connector = DummyDataConnector(archival_memories)
|
||||
server.load_data(user_id, connector, source.name)
|
||||
server.load_data(user.id, connector, source.name)
|
||||
|
||||
# attach source
|
||||
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
|
||||
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=user)
|
||||
|
||||
# check archival memory size
|
||||
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
|
||||
@@ -425,9 +423,9 @@ def test_save_archival_memory(server, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(4)
|
||||
def test_user_message(server, user_id, agent_id):
|
||||
def test_user_message(server, user, agent_id):
|
||||
# add data into recall memory
|
||||
server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
server.user_message(user_id=user.id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
@@ -435,21 +433,20 @@ def test_user_message(server, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(5)
|
||||
def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
def test_get_recall_memory(server, org_id, user, agent_id):
|
||||
# test recall memory cursor pagination
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
actor = user
|
||||
messages_1 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=2)
|
||||
cursor1 = messages_1[-1].id
|
||||
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
messages_2 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
messages_3 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, limit=1000)
|
||||
messages_3[-1].id
|
||||
assert messages_3[-1].created_at >= messages_3[0].created_at
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
messages_4 = server.get_agent_recall_cursor(user_id=user.id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test in-context message ids
|
||||
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
in_context_ids = server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
|
||||
message_ids = [m.id for m in messages_3]
|
||||
@@ -458,13 +455,13 @@ def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
|
||||
|
||||
@pytest.mark.order(6)
|
||||
def test_get_archival_memory(server, user_id, agent_id):
|
||||
def test_get_archival_memory(server, user, agent_id):
|
||||
# test archival memory cursor pagination
|
||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||
actor = user
|
||||
|
||||
# List latest 2 passages
|
||||
passages_1 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
limit=2,
|
||||
@@ -474,7 +471,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
# List next 3 passages (earliest 3)
|
||||
cursor1 = passages_1[-1].id
|
||||
passages_2 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
cursor=cursor1,
|
||||
@@ -483,7 +480,7 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
# List all 5
|
||||
cursor2 = passages_1[0].created_at
|
||||
passages_3 = server.agent_manager.list_passages(
|
||||
actor=user,
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
ascending=False,
|
||||
end_date=cursor2,
|
||||
@@ -496,20 +493,20 @@ def test_get_archival_memory(server, user_id, agent_id):
|
||||
earliest = passages_2[-1]
|
||||
|
||||
# test archival memory
|
||||
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
|
||||
passage_1 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, limit=1, ascending=True)
|
||||
assert len(passage_1) == 1
|
||||
assert passage_1[0].text == "alpha"
|
||||
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
|
||||
passage_2 = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
|
||||
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
|
||||
assert all("alpha" not in passage.text for passage in passage_2)
|
||||
# test safe empty return
|
||||
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
|
||||
passage_none = server.agent_manager.list_passages(actor=actor, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
|
||||
assert len(passage_none) == 0
|
||||
|
||||
|
||||
def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id: str):
|
||||
def test_get_context_window_overview(server: SyncServer, user, agent_id):
|
||||
"""Test that the context window overview fetch works"""
|
||||
overview = server.get_agent_context_window(agent_id=agent_id, actor=server.user_manager.get_user_or_default(user_id))
|
||||
overview = server.get_agent_context_window(agent_id=agent_id, actor=user)
|
||||
assert overview is not None
|
||||
|
||||
# Run some basic checks
|
||||
@@ -546,7 +543,7 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
|
||||
)
|
||||
|
||||
|
||||
def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
def test_delete_agent_same_org(server: SyncServer, org_id: str, user: User):
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
name="nonexistent_tools_agent",
|
||||
@@ -554,7 +551,7 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
llm="openai/gpt-4",
|
||||
embedding="openai/text-embedding-ada-002",
|
||||
),
|
||||
actor=server.user_manager.get_user_or_default(user_id),
|
||||
actor=user,
|
||||
)
|
||||
|
||||
# create another user in the same org
|
||||
@@ -566,14 +563,14 @@ def test_delete_agent_same_org(server: SyncServer, org_id: str, user_id: str):
|
||||
|
||||
def _test_get_messages_letta_format(
|
||||
server,
|
||||
user_id,
|
||||
user,
|
||||
agent_id,
|
||||
reverse=False,
|
||||
):
|
||||
"""Test mapping between messages and letta_messages with reverse=False."""
|
||||
|
||||
messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
agent_id=agent_id,
|
||||
limit=1000,
|
||||
reverse=reverse,
|
||||
@@ -582,7 +579,7 @@ def _test_get_messages_letta_format(
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
letta_messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
agent_id=agent_id,
|
||||
limit=1000,
|
||||
reverse=reverse,
|
||||
@@ -675,10 +672,10 @@ def _test_get_messages_letta_format(
|
||||
warnings.warn(f"Extra letta_messages found: {len(letta_messages) - letta_message_index}")
|
||||
|
||||
|
||||
def test_get_messages_letta_format(server, user_id, agent_id):
|
||||
def test_get_messages_letta_format(server, user, agent_id):
|
||||
# for reverse in [False, True]:
|
||||
for reverse in [False]:
|
||||
_test_get_messages_letta_format(server, user_id, agent_id, reverse=reverse)
|
||||
_test_get_messages_letta_format(server, user, agent_id, reverse=reverse)
|
||||
|
||||
|
||||
EXAMPLE_TOOL_SOURCE = '''
|
||||
@@ -825,9 +822,9 @@ def test_composio_client_simple(server):
|
||||
assert len(actions) > 0
|
||||
|
||||
|
||||
def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools, base_memory_tools):
|
||||
def test_memory_rebuild_count(server, user, mock_e2b_api_key_none, base_tools, base_memory_tools):
|
||||
"""Test that the memory rebuild is generating the correct number of role=system messages"""
|
||||
actor = server.user_manager.get_user_or_default(user_id)
|
||||
actor = user
|
||||
# create agent
|
||||
agent_state = server.create_agent(
|
||||
request=CreateAgent(
|
||||
@@ -848,7 +845,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
|
||||
|
||||
# At this stage, there should only be 1 system message inside of recall storage
|
||||
letta_messages = server.get_agent_recall_cursor(
|
||||
user_id=user_id,
|
||||
user_id=user.id,
|
||||
agent_id=agent_state.id,
|
||||
limit=1000,
|
||||
# reverse=reverse,
|
||||
@@ -870,7 +867,7 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
|
||||
assert num_system_messages == 1, (num_system_messages, all_messages)
|
||||
|
||||
# Assuming core memory append actually ran correctly, at this point there should be 2 messages
|
||||
server.user_message(user_id=user_id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
|
||||
server.user_message(user_id=user.id, agent_id=agent_state.id, message="Append 'banana' to your core memory")
|
||||
|
||||
# At this stage, there should be 2 system message inside of recall storage
|
||||
num_system_messages, all_messages = count_system_messages_in_recall()
|
||||
|
||||
Reference in New Issue
Block a user