fix: Modify the list ORM function (#2208)
This commit is contained in:
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import time
|
||||
from datetime import datetime, timedelta
|
||||
|
||||
import pytest
|
||||
@@ -73,8 +74,8 @@ DEFAULT_EMBEDDING_CONFIG = EmbeddingConfig(
|
||||
azure_version=None,
|
||||
azure_deployment=None,
|
||||
)
|
||||
|
||||
using_sqlite = not bool(os.getenv("LETTA_PG_URI"))
|
||||
CREATE_DELAY_SQLITE = 1
|
||||
USING_SQLITE = not bool(os.getenv("LETTA_PG_URI"))
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
@@ -911,6 +912,8 @@ def test_list_sources(server: SyncServer, default_user):
|
||||
"""Test listing sources with pagination."""
|
||||
# Create multiple sources
|
||||
server.source_manager.create_source(PydanticSource(name="Source 1", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
server.source_manager.create_source(PydanticSource(name="Source 2", embedding_config=DEFAULT_EMBEDDING_CONFIG), actor=default_user)
|
||||
|
||||
# List sources without pagination
|
||||
@@ -1004,6 +1007,8 @@ def test_list_files(server: SyncServer, default_user, default_source):
|
||||
PydanticFileMetadata(file_name="File 1", file_path="/path/to/file1.txt", file_type="text/plain", source_id=default_source.id),
|
||||
actor=default_user,
|
||||
)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
server.source_manager.create_file(
|
||||
PydanticFileMetadata(file_name="File 2", file_path="/path/to/file2.txt", file_type="text/plain", source_id=default_source.id),
|
||||
actor=default_user,
|
||||
@@ -1184,6 +1189,8 @@ def test_list_sandbox_configs(server: SyncServer, default_user):
|
||||
config=LocalSandboxConfig(sandbox_dir=""),
|
||||
)
|
||||
server.sandbox_config_manager.create_or_update_sandbox_config(config_a, actor=default_user)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
server.sandbox_config_manager.create_or_update_sandbox_config(config_b, actor=default_user)
|
||||
|
||||
# List configs without pagination
|
||||
@@ -1239,6 +1246,8 @@ def test_list_sandbox_env_vars(server: SyncServer, sandbox_config_fixture, defau
|
||||
env_var_create_a = SandboxEnvironmentVariableCreate(key="VAR1", value="value1")
|
||||
env_var_create_b = SandboxEnvironmentVariableCreate(key="VAR2", value="value2")
|
||||
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_a, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
|
||||
if USING_SQLITE:
|
||||
time.sleep(CREATE_DELAY_SQLITE)
|
||||
server.sandbox_config_manager.create_sandbox_env_var(env_var_create_b, sandbox_config_id=sandbox_config_fixture.id, actor=default_user)
|
||||
|
||||
# List env vars without pagination
|
||||
@@ -1299,7 +1308,7 @@ def test_change_label_on_block_reflects_in_block_agents_table(server, sarah_agen
|
||||
assert default_block.label not in labels
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
|
||||
def test_add_block_to_agent_nonexistent_block(server, sarah_agent, default_user):
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.blocks_agents_manager.add_block_to_agent(
|
||||
@@ -1361,7 +1370,7 @@ def test_list_agent_ids_with_block(server, sarah_agent, charles_agent, default_u
|
||||
assert len(agent_ids) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
|
||||
def test_add_block_to_agent_with_deleted_block(server, sarah_agent, default_user, default_block):
|
||||
block_manager = BlockManager()
|
||||
block_manager.delete_block(block_id=default_block.id, actor=default_user)
|
||||
@@ -1401,7 +1410,7 @@ def test_change_name_on_tool_reflects_in_tool_agents_table(server, sarah_agent,
|
||||
assert print_tool.name not in names
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
|
||||
def test_add_tool_to_agent_nonexistent_tool(server, sarah_agent, default_user):
|
||||
with pytest.raises(ForeignKeyConstraintViolationError):
|
||||
server.tools_agents_manager.add_tool_to_agent(agent_id=sarah_agent.id, tool_id="nonexistent_tool", tool_name="nonexistent_name")
|
||||
@@ -1447,7 +1456,7 @@ def test_list_agent_ids_with_tool(server, sarah_agent, charles_agent, default_us
|
||||
assert len(agent_ids) == 2
|
||||
|
||||
|
||||
@pytest.mark.skipif(using_sqlite, reason="Skipped because using SQLite")
|
||||
@pytest.mark.skipif(USING_SQLITE, reason="Skipped because using SQLite")
|
||||
def test_add_tool_to_agent_with_deleted_tool(server, sarah_agent, default_user, print_tool):
|
||||
tool_manager = ToolManager()
|
||||
tool_manager.delete_tool_by_id(tool_id=print_tool.id, actor=default_user)
|
||||
|
||||
@@ -161,37 +161,25 @@ def test_user_message(server, user_id, agent_id):
|
||||
# server.user_message(user_id=user_id, agent_id=agent_id, message="Hello?")
|
||||
|
||||
|
||||
# TODO: Add this back, this is broken on main
|
||||
# @pytest.mark.order(5)
|
||||
# def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
# # test recall memory cursor pagination
|
||||
# messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
# cursor1 = messages_1[-1].id
|
||||
# messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
# messages_2[-1].id
|
||||
# messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
# messages_3[-1].id
|
||||
# assert messages_3[-1].created_at >= messages_3[0].created_at
|
||||
# assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
# messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
# assert len(messages_4) == 1
|
||||
#
|
||||
# # test in-context message ids
|
||||
# in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
# message_ids = [m.id for m in messages_3]
|
||||
# for message_id in in_context_ids:
|
||||
# assert message_id in message_ids, f"{message_id} not in {message_ids}"
|
||||
#
|
||||
# # test recall memory
|
||||
# messages_1 = server.get_agent_messages(agent_id=agent_id, start=0, count=1)
|
||||
# assert len(messages_1) == 1
|
||||
# messages_2 = server.get_agent_messages(agent_id=agent_id, start=1, count=1000)
|
||||
# messages_3 = server.get_agent_messages(agent_id=agent_id, start=1, count=2)
|
||||
# # not sure exactly how many messages there should be
|
||||
# assert len(messages_2) > len(messages_3)
|
||||
# # test safe empty return
|
||||
# messages_none = server.get_agent_messages(agent_id=agent_id, start=1000, count=1000)
|
||||
# assert len(messages_none) == 0
|
||||
@pytest.mark.order(5)
|
||||
def test_get_recall_memory(server, org_id, user_id, agent_id):
|
||||
# test recall memory cursor pagination
|
||||
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
|
||||
cursor1 = messages_1[-1].id
|
||||
messages_2 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, after=cursor1, limit=1000)
|
||||
messages_2[-1].id
|
||||
messages_3 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=1000)
|
||||
messages_3[-1].id
|
||||
assert messages_3[-1].created_at >= messages_3[0].created_at
|
||||
assert len(messages_3) == len(messages_1) + len(messages_2)
|
||||
messages_4 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, reverse=True, before=cursor1)
|
||||
assert len(messages_4) == 1
|
||||
|
||||
# test in-context message ids
|
||||
in_context_ids = server.get_in_context_message_ids(agent_id=agent_id)
|
||||
message_ids = [m.id for m in messages_3]
|
||||
for message_id in in_context_ids:
|
||||
assert message_id in message_ids, f"{message_id} not in {message_ids}"
|
||||
|
||||
|
||||
@pytest.mark.order(6)
|
||||
|
||||
Reference in New Issue
Block a user