feat: Add ORM for organization model (#1914)
This commit is contained in:
@@ -3,9 +3,17 @@ import uuid
|
||||
import warnings
|
||||
|
||||
import pytest
|
||||
from sqlalchemy import delete
|
||||
|
||||
import letta.utils as utils
|
||||
from letta.constants import BASE_TOOLS, DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.constants import (
|
||||
BASE_TOOLS,
|
||||
DEFAULT_MESSAGE_TOOL,
|
||||
DEFAULT_MESSAGE_TOOL_KWARG,
|
||||
DEFAULT_ORG_ID,
|
||||
DEFAULT_ORG_NAME,
|
||||
)
|
||||
from letta.orm.organization import Organization
|
||||
from letta.schemas.enums import MessageRole
|
||||
|
||||
utils.DEBUG = True
|
||||
@@ -31,6 +39,14 @@ from letta.server.server import SyncServer
|
||||
from .utils import DummyDataConnector
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_organization_table(server: SyncServer):
|
||||
"""Fixture to clear the organization table before each test."""
|
||||
with server.organization_manager.session_maker() as session:
|
||||
session.execute(delete(Organization)) # Clear all records from the organization table
|
||||
session.commit() # Commit the deletion
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
# if os.getenv("OPENAI_API_KEY"):
|
||||
@@ -547,3 +563,47 @@ def test_get_context_window_overview(server: SyncServer, user_id: str, agent_id:
|
||||
+ overview.num_tokens_functions_definitions
|
||||
+ overview.num_tokens_external_memory_summary
|
||||
)
|
||||
|
||||
|
||||
def test_list_organizations(server: SyncServer):
|
||||
# Create a new org and confirm that it is created correctly
|
||||
org_name = "test"
|
||||
org = server.organization_manager.create_organization(name=org_name)
|
||||
|
||||
orgs = server.organization_manager.list_organizations()
|
||||
assert len(orgs) == 1
|
||||
assert orgs[0].name == org_name
|
||||
|
||||
# Delete it after
|
||||
server.organization_manager.delete_organization(org.id)
|
||||
assert len(server.organization_manager.list_organizations()) == 0
|
||||
|
||||
|
||||
def test_create_default_organization(server: SyncServer):
|
||||
server.organization_manager.create_default_organization()
|
||||
retrieved = server.organization_manager.get_organization_by_id(DEFAULT_ORG_ID)
|
||||
assert retrieved.name == DEFAULT_ORG_NAME
|
||||
|
||||
|
||||
def test_update_organization_name(server: SyncServer):
|
||||
org_name_a = "a"
|
||||
org_name_b = "b"
|
||||
org = server.organization_manager.create_organization(name=org_name_a)
|
||||
assert org.name == org_name_a
|
||||
org = server.organization_manager.update_organization_name_using_id(org_id=org.id, name=org_name_b)
|
||||
assert org.name == org_name_b
|
||||
|
||||
|
||||
def test_list_organizations_pagination(server: SyncServer):
|
||||
server.organization_manager.create_organization(name="a")
|
||||
server.organization_manager.create_organization(name="b")
|
||||
|
||||
orgs_x = server.organization_manager.list_organizations(limit=1)
|
||||
assert len(orgs_x) == 1
|
||||
|
||||
orgs_y = server.organization_manager.list_organizations(cursor=orgs_x[0].id, limit=1)
|
||||
assert len(orgs_y) == 1
|
||||
assert orgs_y[0].name != orgs_x[0].name
|
||||
|
||||
orgs = server.organization_manager.list_organizations(cursor=orgs_y[0].id, limit=1)
|
||||
assert len(orgs) == 0
|
||||
|
||||
Reference in New Issue
Block a user