feat: separate Passages tables (#2245)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-16 15:24:20 -08:00
committed by GitHub
parent 10e610bb95
commit e2d916148e
19 changed files with 1026 additions and 546 deletions

View File

@@ -390,12 +390,16 @@ def test_user_message_memory(server, user_id, agent_id):
@pytest.mark.order(3)
def test_load_data(server, user_id, agent_id):
user = server.user_manager.get_user_or_default(user_id=user_id)
# create source
passages_before = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_before = server.agent_manager.list_passages(
actor=user, agent_id=agent_id, cursor=None, limit=10000
)
assert len(passages_before) == 0
source = server.source_manager.create_source(
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=server.default_user
PydanticSource(name="test_source", embedding_config=EmbeddingConfig.default_config(provider="openai")), actor=user
)
# load data
@@ -409,15 +413,11 @@ def test_load_data(server, user_id, agent_id):
connector = DummyDataConnector(archival_memories)
server.load_data(user_id, connector, source.name)
# @pytest.mark.order(3)
# def test_attach_source_to_agent(server, user_id, agent_id):
# check archival memory size
# attach source
server.attach_source_to_agent(user_id=user_id, agent_id=agent_id, source_name="test_source")
# check archival memory size
passages_after = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=None, limit=10000)
passages_after = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=None, limit=10000)
assert len(passages_after) == 5
@@ -465,7 +465,7 @@ def test_get_archival_memory(server, user_id, agent_id):
user = server.user_manager.get_user_by_id(user_id=user_id)
# List latest 2 passages
passages_1 = server.passage_manager.list_passages(
passages_1 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -475,7 +475,7 @@ def test_get_archival_memory(server, user_id, agent_id):
# List next 3 passages (earliest 3)
cursor1 = passages_1[-1].id
passages_2 = server.passage_manager.list_passages(
passages_2 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
@@ -484,24 +484,28 @@ def test_get_archival_memory(server, user_id, agent_id):
# List all 5
cursor2 = passages_1[0].created_at
passages_3 = server.passage_manager.list_passages(
passages_3 = server.agent_manager.list_passages(
actor=user,
agent_id=agent_id,
ascending=False,
end_date=cursor2,
limit=1000,
)
# assert passages_1[0].text == "Cinderella wore a blue dress"
assert len(passages_2) in [3, 4] # NOTE: exact size seems non-deterministic, so loosen test
assert len(passages_3) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
latest = passages_1[0]
earliest = passages_2[-1]
# test archival memory
passage_1 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, limit=1)
passage_1 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, limit=1, ascending=True)
assert len(passage_1) == 1
passage_2 = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passage_1[-1].id, limit=1000)
assert passage_1[0].text == "alpha"
passage_2 = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=earliest.id, limit=1000, ascending=True)
assert len(passage_2) in [4, 5] # NOTE: exact size seems non-deterministic, so loosen test
assert all("alpha" not in passage.text for passage in passage_2)
# test safe empty return
passage_none = server.get_agent_archival(user_id=user_id, agent_id=agent_id, cursor=passages_1[0].id, limit=1000)
passage_none = server.agent_manager.list_passages(actor=user, agent_id=agent_id, cursor=latest.id, limit=1000, ascending=True)
assert len(passage_none) == 0
@@ -955,6 +959,14 @@ def test_memory_rebuild_count(server, user_id, mock_e2b_api_key_none, base_tools
def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, other_agent_id: str, tmp_path):
actor = server.user_manager.get_user_or_default(user_id)
existing_sources = server.source_manager.list_sources(actor=actor)
if len(existing_sources) > 0:
for source in existing_sources:
server.agent_manager.detach_source(agent_id=agent_id, source_id=source.id, actor=actor)
initial_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert initial_passage_count == 0
# Create a source
source = server.source_manager.create_source(
PydanticSource(
@@ -973,10 +985,6 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
# Attach source to agent first
server.agent_manager.attach_source(agent_id=agent_id, source_id=source.id, actor=actor)
# Get initial passage count
initial_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
assert initial_passage_count == 0
# Create a job for loading the first file
job = server.job_manager.create_job(
PydanticJob(
@@ -1001,7 +1009,7 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert job.metadata_["num_documents"] == 1
# Verify passages were added
first_file_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
first_file_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert first_file_passage_count > initial_passage_count
# Create a second test file with different content
@@ -1032,14 +1040,13 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert job2.metadata_["num_documents"] == 1
# Verify passages were appended (not replaced)
final_passage_count = server.passage_manager.size(actor=actor, agent_id=agent_id, source_id=source.id)
final_passage_count = server.agent_manager.passage_size(agent_id=agent_id, actor=actor)
assert final_passage_count > first_file_passage_count
# Verify both old and new content is searchable
passages = server.passage_manager.list_passages(
actor=actor,
passages = server.agent_manager.list_passages(
agent_id=agent_id,
source_id=source.id,
actor=actor,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
@@ -1048,35 +1055,27 @@ def test_load_file_to_source(server: SyncServer, user_id: str, agent_id: str, ot
assert any("chicken" in passage.text.lower() for passage in passages)
assert any("Anna".lower() in passage.text.lower() for passage in passages)
# TODO: Add this test back in after separation of `Passage tables` (LET-449)
# # Load second agent
# agent2 = server.load_agent(agent_id=other_agent_id)
# Initially should have no passages
initial_agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
assert initial_agent2_passages == 0
# # Initially should have no passages
# initial_agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# assert initial_agent2_passages == 0
# Attach source to second agent
server.agent_manager.attach_source(agent_id=other_agent_id, source_id=source.id, actor=actor)
# # Attach source to second agent
# agent2.attach_source(user=user, source_id=source.id, source_manager=server.source_manager, ms=server.ms)
# Verify second agent has same number of passages as first agent
agent2_passages = server.agent_manager.passage_size(agent_id=other_agent_id, actor=actor, source_id=source.id)
agent1_passages = server.agent_manager.passage_size(agent_id=agent_id, actor=actor, source_id=source.id)
assert agent2_passages == agent1_passages
# # Verify second agent has same number of passages as first agent
# agent2_passages = server.passage_manager.size(actor=user, agent_id=other_agent_id, source_id=source.id)
# agent1_passages = server.passage_manager.size(actor=user, agent_id=agent_id, source_id=source.id)
# assert agent2_passages == agent1_passages
# # Verify second agent can query the same content
# passages2 = server.passage_manager.list_passages(
# actor=user,
# agent_id=other_agent_id,
# source_id=source.id,
# query_text="what does Timber like to eat",
# embedding_config=EmbeddingConfig.default_config(provider="openai"),
# embed_query=True,
# limit=10,
# )
# assert len(passages2) == len(passages)
# assert any("chicken" in passage.text.lower() for passage in passages2)
# assert any("sleep" in passage.text.lower() for passage in passages2)
# # Cleanup
# server.delete_agent(user_id=user_id, agent_id=agent2_state.id)
# Verify second agent can query the same content
passages2 = server.agent_manager.list_passages(
actor=actor,
agent_id=other_agent_id,
source_id=source.id,
query_text="what does Timber like to eat",
embedding_config=EmbeddingConfig.default_config(provider="openai"),
embed_query=True,
)
assert len(passages2) == len(passages)
assert any("chicken" in passage.text.lower() for passage in passages2)
assert any("Anna".lower() in passage.text.lower() for passage in passages2)