fix: Modify the list ORM function (#2208)

This commit is contained in:
Matthew Zhou
2024-12-09 19:35:58 -08:00
committed by GitHub
parent d61b2f9545
commit 2125421bd8
8 changed files with 83 additions and 106 deletions

View File

@@ -1,4 +1,5 @@
import os
import time
from datetime import datetime, timedelta
import pytest
@@ -73,8 +74,8 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
azure_version=None,
azure_deployment=None,
)
using_sqlite = not bool(os.getenv("LETTA_PG_URI"))
CREATE_DELAY_SQLITE = 1
USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
@pytest.fixture(autouse=True)
@@ -911,6 +912,8 @@ def test_list_sources(server: SyncServer, default_user):
"""Test listing sources with pagination."""
# Create multiple sources
server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
# List sources without pagination
@@ -1004,6 +1007,8 @@ def test_list_files(server: SyncServer, default_user, default_source):
PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.source_manager.create_file(
PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id),
actor=default_user,
@@ -1184,6 +1189,8 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
config=LocalSandboxConfig(sandbox_dir=""),
)
server.sandbox_config_manager.create_or_update_sandbox_config(config_a, actor=default_user)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.sandbox_config_manager.create_or_update_sandbox_config(config_b, actor=default_user)
# List configs without pagination
@@ -1239,6 +1246,8 @@ def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, defau
env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1")
env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2")
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
if USING_SQLITE:
time.sleep(CREATE_DELAY_SQLITE)
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
# List env vars without pagination
@@ -1299,7 +1308,7 @@ def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agen
assert default_block.label not in labels
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.blocks_agents_manager.add_block_to_agent(
@@ -1361,7 +1370,7 @@ def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_u
assert len(agent_ids) == 2
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
block_manager = BlockManager()
block_manager.delete_block(block_id=default_block.id, actor=default_user)
@@ -1401,7 +1410,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
assert print_tool.name not in names
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
with pytest.raises(ForeignKeyConstraintViolationError):
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name")
@@ -1447,7 +1456,7 @@ def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_us
assert len(agent_ids) == 2
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
tool_manager = ToolManager()
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)

View File

@@ -161,37 +161,25 @@ def test_user_message(server, user_id, agent_id):
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
# TODO: Add this back, this is broken on main
# @pytest.mark.order(5)
# def test_get_recall_memory(server, org_id, user_id, agent_id):
# # test recall memory cursor pagination
# messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
# cursor1 = messages_1[-1].id
# messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
# messages_2[-1].id
# messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
# messages_3[-1].id
# assert messages_3[-1].created_at >= messages_3[0].created_at
# assert len(messages_3) == len(messages_1) + len(messages_2)
# messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
# assert len(messages_4) == 1
#
# # test in-context message ids
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
# message_ids = [m.id for m in messages_3]
# for message_id in in_context_ids:
# assert message_id in message_ids, f"{message_id} not in {message_ids}"
#
# # test recall memory
# messages_1 = server.get_agent_messages(agent_id=agent_id, start=0, count=1)
# assert len(messages_1) == 1
# messages_2 = server.get_agent_messages(agent_id=agent_id, start=1, count=1000)
# messages_3 = server.get_agent_messages(agent_id=agent_id, start=1, count=2)
# # not sure exactly how many messages there should be
# assert len(messages_2) > len(messages_3)
# # test safe empty return
# messages_none = server.get_agent_messages(agent_id=agent_id, start=1000, count=1000)
# assert len(messages_none) == 0
@pytest.mark.order(5)
def test_get_recall_memory(server, org_id, user_id, agent_id):
# test recall memory cursor pagination
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
cursor1 = messages_1[-1].id
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
messages_2[-1].id
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
messages_3[-1].id
assert messages_3[-1].created_at >= messages_3[0].created_at
assert len(messages_3) == len(messages_1) + len(messages_2)
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
assert len(messages_4) == 1
# test in-context message ids
in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
message_ids = [m.id for m in messages_3]
for message_id in in_context_ids:
assert message_id in message_ids, f"{message_id} not in {message_ids}"
@pytest.mark.order(6)