test: Add test for archival tag compilation into system prompt [LET-4114] (#4312)
Add tests and thread created_at
This commit is contained in:
@@ -1134,10 +1134,9 @@ class SyncServer(Server):
|
||||
agent_state = await self.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Use passage manager which handles dual-write to Turbopuffer if enabled
|
||||
passages = await self.passage_manager.insert_passage(agent_state=agent_state, text=memory_contents, actor=actor)
|
||||
|
||||
# TODO: Add support for tags and created_at parameters
|
||||
# Currently PassageManager.insert_passage doesn't support these parameters
|
||||
passages = await self.passage_manager.insert_passage(
|
||||
agent_state=agent_state, text=memory_contents, tags=tags, actor=actor, created_at=created_at
|
||||
)
|
||||
|
||||
return passages
|
||||
|
||||
|
||||
@@ -551,6 +551,7 @@ class PassageManager:
|
||||
text: str,
|
||||
actor: PydanticUser,
|
||||
tags: Optional[List[str]] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Insert passage(s) into archival memory
|
||||
|
||||
@@ -588,15 +589,20 @@ class PassageManager:
|
||||
|
||||
# Always write to SQL database first
|
||||
for chunk_text, embedding in zip(text_chunks, embeddings):
|
||||
passage_data = {
|
||||
"organization_id": actor.organization_id,
|
||||
"archive_id": archive.id,
|
||||
"text": chunk_text,
|
||||
"embedding": embedding,
|
||||
"embedding_config": agent_state.embedding_config,
|
||||
"tags": tags,
|
||||
}
|
||||
# only include created_at if provided
|
||||
if created_at is not None:
|
||||
passage_data["created_at"] = created_at
|
||||
|
||||
passage = await self.create_agent_passage_async(
|
||||
PydanticPassage(
|
||||
organization_id=actor.organization_id,
|
||||
archive_id=archive.id,
|
||||
text=chunk_text,
|
||||
embedding=embedding,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
tags=tags,
|
||||
),
|
||||
PydanticPassage(**passage_data),
|
||||
actor=actor,
|
||||
)
|
||||
passages.append(passage)
|
||||
|
||||
@@ -1224,89 +1224,80 @@ def test_preview_payload(client: LettaSDKClient):
|
||||
client.agents.delete(agent_id=temp_agent.id)
|
||||
|
||||
|
||||
# TODO: Re-enable
|
||||
# def test_archive_tags_in_system_prompt(client: LettaSDKClient):
|
||||
# """Test that archive tags are correctly compiled into the system prompt."""
|
||||
# # Create a test agent
|
||||
# temp_agent = client.agents.create(
|
||||
# memory_blocks=[
|
||||
# CreateBlock(
|
||||
# label="human",
|
||||
# value="username: test_user",
|
||||
# ),
|
||||
# ],
|
||||
# model="openai/gpt-4o-mini",
|
||||
# embedding="openai/text-embedding-3-small",
|
||||
# )
|
||||
#
|
||||
# try:
|
||||
# # Add passages with different tags to the agent's archive
|
||||
# test_tags = ["project_alpha", "meeting_notes", "research", "ideas", "todo_items"]
|
||||
#
|
||||
# # Create passages with tags
|
||||
# for i, tag in enumerate(test_tags):
|
||||
# client.agents.passages.create(
|
||||
# agent_id=temp_agent.id,
|
||||
# text=f"Test passage {i} with tag {tag}",
|
||||
# tags=[tag]
|
||||
# )
|
||||
#
|
||||
# # Also create a passage with multiple tags
|
||||
# client.agents.passages.create(
|
||||
# agent_id=temp_agent.id,
|
||||
# text="Passage with multiple tags",
|
||||
# tags=["multi_tag_1", "multi_tag_2"]
|
||||
# )
|
||||
#
|
||||
# # Get the raw payload to check the system prompt
|
||||
# payload = client.agents.messages.preview_raw_payload(
|
||||
# agent_id=temp_agent.id,
|
||||
# request=LettaRequest(
|
||||
# messages=[
|
||||
# MessageCreate(
|
||||
# role="user",
|
||||
# content=[
|
||||
# TextContent(
|
||||
# text="Hello",
|
||||
# )
|
||||
# ],
|
||||
# )
|
||||
# ],
|
||||
# ),
|
||||
# )
|
||||
#
|
||||
# # Extract the system message
|
||||
# assert isinstance(payload, dict)
|
||||
# assert "messages" in payload
|
||||
# assert len(payload["messages"]) > 0
|
||||
#
|
||||
# system_message = payload["messages"][0]
|
||||
# assert system_message["role"] == "system"
|
||||
# system_content = system_message["content"]
|
||||
#
|
||||
# # Check that the archive tags are included in the metadata
|
||||
# assert "Available archival memory tags:" in system_content
|
||||
#
|
||||
# # Check that all unique tags are present
|
||||
# all_unique_tags = set(test_tags + ["multi_tag_1", "multi_tag_2"])
|
||||
# for tag in all_unique_tags:
|
||||
# assert tag in system_content, f"Tag '{tag}' not found in system prompt"
|
||||
#
|
||||
# # Verify the tags are in the memory_metadata section
|
||||
# assert "<memory_metadata>" in system_content
|
||||
# assert "</memory_metadata>" in system_content
|
||||
#
|
||||
# # Extract the metadata section to verify format
|
||||
# metadata_start = system_content.index("<memory_metadata>")
|
||||
# metadata_end = system_content.index("</memory_metadata>")
|
||||
# metadata_section = system_content[metadata_start:metadata_end]
|
||||
#
|
||||
# # Verify the tags line is properly formatted
|
||||
# assert "- Available archival memory tags:" in metadata_section
|
||||
#
|
||||
# finally:
|
||||
# # Clean up the agent
|
||||
# client.agents.delete(agent_id=temp_agent.id)
|
||||
def test_archive_tags_in_system_prompt(client: LettaSDKClient):
|
||||
"""Test that archive tags are correctly compiled into the system prompt."""
|
||||
# Create a test agent
|
||||
temp_agent = client.agents.create(
|
||||
memory_blocks=[
|
||||
CreateBlock(
|
||||
label="human",
|
||||
value="username: test_user",
|
||||
),
|
||||
],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
try:
|
||||
# Add passages with different tags to the agent's archive
|
||||
test_tags = ["project_alpha", "meeting_notes", "research", "ideas", "todo_items"]
|
||||
|
||||
# Create passages with tags
|
||||
for i, tag in enumerate(test_tags):
|
||||
client.agents.passages.create(agent_id=temp_agent.id, text=f"Test passage {i} with tag {tag}", tags=[tag])
|
||||
|
||||
# Also create a passage with multiple tags
|
||||
client.agents.passages.create(agent_id=temp_agent.id, text="Passage with multiple tags", tags=["multi_tag_1", "multi_tag_2"])
|
||||
|
||||
# Get the raw payload to check the system prompt
|
||||
payload = client.agents.messages.preview_raw_payload(
|
||||
agent_id=temp_agent.id,
|
||||
request=LettaRequest(
|
||||
messages=[
|
||||
MessageCreate(
|
||||
role="user",
|
||||
content=[
|
||||
TextContent(
|
||||
text="Hello",
|
||||
)
|
||||
],
|
||||
)
|
||||
],
|
||||
),
|
||||
)
|
||||
|
||||
# Extract the system message
|
||||
assert isinstance(payload, dict)
|
||||
assert "messages" in payload
|
||||
assert len(payload["messages"]) > 0
|
||||
|
||||
system_message = payload["messages"][0]
|
||||
assert system_message["role"] == "system"
|
||||
system_content = system_message["content"]
|
||||
|
||||
# Check that the archive tags are included in the metadata
|
||||
assert "Available archival memory tags:" in system_content
|
||||
|
||||
# Check that all unique tags are present
|
||||
all_unique_tags = set(test_tags + ["multi_tag_1", "multi_tag_2"])
|
||||
for tag in all_unique_tags:
|
||||
assert tag in system_content, f"Tag '{tag}' not found in system prompt"
|
||||
|
||||
# Verify the tags are in the memory_metadata section
|
||||
assert "<memory_metadata>" in system_content
|
||||
assert "</memory_metadata>" in system_content
|
||||
|
||||
# Extract the metadata section to verify format
|
||||
metadata_start = system_content.index("<memory_metadata>")
|
||||
metadata_end = system_content.index("</memory_metadata>")
|
||||
metadata_section = system_content[metadata_start:metadata_end]
|
||||
|
||||
# Verify the tags line is properly formatted
|
||||
assert "- Available archival memory tags:" in metadata_section
|
||||
|
||||
finally:
|
||||
# Clean up the agent
|
||||
client.agents.delete(agent_id=temp_agent.id)
|
||||
|
||||
|
||||
def test_agent_tools_list(client: LettaSDKClient):
|
||||
|
||||
Reference in New Issue
Block a user