feat: Add ORM for user model (#1924)

This commit is contained in:
Matthew Zhou
2024-10-23 10:28:00 -07:00
committed by GitHub
parent a70dea15a4
commit ff4be4576b
16 changed files with 422 additions and 282 deletions

View File

@@ -3,17 +3,9 @@ import uuid
import warnings
import pytest
from sqlalchemy import delete
import letta.utils as utils
from letta.constants import (
BASE_TOOLS,
DEFAULT_MESSAGE_TOOL,
DEFAULT_MESSAGE_TOOL_KWARG,
DEFAULT_ORG_ID,
DEFAULT_ORG_NAME,
)
from letta.orm.organization import Organization
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
from letta.schemas.enums import MessageRole
utils.DEBUG = True
@@ -39,14 +31,6 @@ from letta.server.server import SyncServer
from .utils import DummyDataConnector
@pytest.fixture(autouse=True)
def clear_organization_table(server: SyncServer):
"""Fixture to clear the organization table before each test."""
with server.organization_manager.session_maker() as session:
session.execute(delete(Organization)) # Clear all records from the organization table
session.commit() # Commit the deletion
@pytest.fixture(scope="module")
def server():
# if os.getenv("OPENAI_API_KEY"):
@@ -76,15 +60,27 @@ def server():
@pytest.fixture(scope="module")
def user_id(server):
def org_id(server):
# create org
org = server.organization_manager.create_organization(name="test_org")
print(f"Created org\n{org.id}")
yield org.id
# cleanup
server.organization_manager.delete_organization_by_id(org.id)
@pytest.fixture(scope="module")
def user_id(server, org_id):
# create user
user = server.create_user(UserCreate(name="test_user"))
user = server.create_user(UserCreate(name="test_user", organization_id=org_id))
print(f"Created user\n{user.id}")
yield user.id
# cleanup
server.delete_user(user.id)
server.user_manager.delete_user_by_id(user.id)
@pytest.fixture(scope="module")
@@ -183,7 +179,7 @@ def test_user_message(server, user_id, agent_id):
@pytest.mark.order(5)
def test_get_recall_memory(server, user_id, agent_id):
def test_get_recall_memory(server, org_id, user_id, agent_id):
# test recall memory cursor pagination
messages_1 = server.get_agent_recall_cursor(user_id=user_id, agent_id=agent_id, limit=2)
cursor1 = messages_1[-1].id
@@ -563,47 +559,3 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
+ overview.num_tokens_functions_definitions
+ overview.num_tokens_external_memory_summary
)
def test_list_organizations(server: SyncServer):
# Create a new org and confirm that it is created correctly
org_name = "test"
org = server.organization_manager.create_organization(name=org_name)
orgs = server.organization_manager.list_organizations()
assert len(orgs) == 1
assert orgs[0].name == org_name
# Delete it after
server.organization_manager.delete_organization(org.id)
assert len(server.organization_manager.list_organizations()) == 0
def test_create_default_organization(server: SyncServer):
server.organization_manager.create_default_organization()
retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
assert retrieved.name == DEFAULT_ORG_NAME
def test_update_organization_name(server: SyncServer):
org_name_a = "a"
org_name_b = "b"
org = server.organization_manager.create_organization(name=org_name_a)
assert org.name == org_name_a
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
assert org.name == org_name_b
def test_list_organizations_pagination(server: SyncServer):
server.organization_manager.create_organization(name="a")
server.organization_manager.create_organization(name="b")
orgs_x = server.organization_manager.list_organizations(limit=1)
assert len(orgs_x) == 1
orgs_y = server.organization_manager.list_organizations(cursor=orgs_x[0].id, limit=1)
assert len(orgs_y) == 1
assert orgs_y[0].name != orgs_x[0].name
orgs = server.organization_manager.list_organizations(cursor=orgs_y[0].id, limit=1)
assert len(orgs) == 0