feat(asyncify): migrate actors(users) endpoints (#2211)
This commit is contained in:
@@ -1,16 +1,14 @@
|
||||
from datetime import datetime, timezone
|
||||
from typing import Any, AsyncGenerator, Dict, List, Optional
|
||||
from typing import AsyncGenerator, List, Optional
|
||||
|
||||
from openai import AsyncStream
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk, Choice, ChoiceDelta
|
||||
from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
|
||||
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG, PRE_EXECUTION_MESSAGE_ARG
|
||||
from letta.interfaces.utils import _format_sse_chunk
|
||||
from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG
|
||||
from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage
|
||||
from letta.schemas.letta_message_content import ReasoningContent, RedactedReasoningContent, TextContent
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.server.rest_api.json_parser import OptimisticJSONParser
|
||||
from letta.streaming_utils import JSONInnerThoughtsExtractor
|
||||
|
||||
|
||||
@@ -745,6 +745,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
self.is_deleted = True
|
||||
return self.update(db_session)
|
||||
|
||||
@handle_db_timeout
|
||||
async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase":
|
||||
"""Soft delete a record asynchronously (mark as deleted)."""
|
||||
logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")
|
||||
|
||||
if actor:
|
||||
self._set_created_and_updated_by_fields(actor.id)
|
||||
|
||||
self.is_deleted = True
|
||||
return await self.update_async(db_session)
|
||||
|
||||
@handle_db_timeout
|
||||
def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None:
|
||||
"""Permanently removes the record from the database."""
|
||||
@@ -761,6 +772,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
else:
|
||||
logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted")
|
||||
|
||||
@handle_db_timeout
|
||||
async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None:
|
||||
"""Permanently removes the record from the database asynchronously."""
|
||||
logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)")
|
||||
|
||||
async with db_session as session:
|
||||
try:
|
||||
await session.delete(self)
|
||||
await session.commit()
|
||||
except Exception as e:
|
||||
await session.rollback()
|
||||
logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}")
|
||||
raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}")
|
||||
|
||||
@handle_db_timeout
|
||||
def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase":
|
||||
logger.debug(...)
|
||||
|
||||
@@ -83,7 +83,7 @@ async def list_agents(
|
||||
"""
|
||||
|
||||
# Retrieve the actor (user) details
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# Call list_agents directly without unnecessary dict handling
|
||||
return await server.agent_manager.list_agents_async(
|
||||
@@ -163,7 +163,7 @@ async def import_agent_serialized(
|
||||
"""
|
||||
Import a serialized agent file and recreate the agent in the system.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
serialized_data = await file.read()
|
||||
@@ -233,7 +233,7 @@ async def create_agent(
|
||||
Create a new agent with the specified configuration.
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.create_agent_async(agent, actor=actor)
|
||||
except Exception as e:
|
||||
traceback.print_exc()
|
||||
@@ -248,7 +248,7 @@ async def modify_agent(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Update an existing agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor)
|
||||
|
||||
|
||||
@@ -628,7 +628,7 @@ async def send_message(
|
||||
Process a user message and return the agent's response.
|
||||
This endpoint accepts a message from a user and processes it through the agent.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
@@ -688,7 +688,7 @@ async def send_message_streaming(
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
request_start_timestamp_ns = get_utc_timestamp_ns()
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
# TODO: This is redundant, remove soon
|
||||
agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor)
|
||||
agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent
|
||||
@@ -788,7 +788,7 @@ async def send_message_async(
|
||||
Asynchronously process a user message and return a run object.
|
||||
The actual processing happens in the background, and the status can be checked using the run ID.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# Create a new job
|
||||
run = Run(
|
||||
@@ -842,6 +842,6 @@ async def list_agent_groups(
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
"""Lists the groups for an agent"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
print("in list agents with manager_type", manager_type)
|
||||
return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor)
|
||||
|
||||
@@ -26,7 +26,7 @@ async def list_blocks(
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
return await server.block_manager.get_blocks_async(
|
||||
actor=actor,
|
||||
label=label,
|
||||
|
||||
@@ -135,7 +135,7 @@ async def send_group_message(
|
||||
Process a user message and return the group's response.
|
||||
This endpoint accepts a message from a user and processes it through through agents in the group based on the specified pattern
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
@@ -174,7 +174,7 @@ async def send_group_message_streaming(
|
||||
This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern.
|
||||
It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
result = await server.send_group_message_to_agent(
|
||||
group_id=group_id,
|
||||
actor=actor,
|
||||
|
||||
@@ -52,7 +52,7 @@ async def create_messages_batch(
|
||||
detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.",
|
||||
)
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
batch_job = BatchJob(
|
||||
user_id=actor.id,
|
||||
status=JobStatus.running,
|
||||
@@ -100,7 +100,7 @@ async def retrieve_batch_run(
|
||||
"""
|
||||
Get the status of a batch run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
|
||||
@@ -118,7 +118,7 @@ async def list_batch_runs(
|
||||
List all batch runs.
|
||||
"""
|
||||
# TODO: filter
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH)
|
||||
return [BatchJob.from_job(job) for job in jobs]
|
||||
@@ -150,7 +150,7 @@ async def list_batch_messages(
|
||||
- For subsequent pages, use the ID of the last message from the previous response as the cursor
|
||||
- Results will include messages before/after the cursor based on sort_descending
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
# First, verify the batch job exists and the user has access to it
|
||||
try:
|
||||
@@ -177,7 +177,7 @@ async def cancel_batch_run(
|
||||
"""
|
||||
Cancel a batch run.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
|
||||
|
||||
@@ -115,7 +115,7 @@ async def list_run_messages(
|
||||
if order not in ["asc", "desc"]:
|
||||
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
messages = server.job_manager.get_run_messages(
|
||||
@@ -182,7 +182,7 @@ async def list_run_steps(
|
||||
if order not in ["asc", "desc"]:
|
||||
raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'")
|
||||
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
|
||||
try:
|
||||
steps = server.job_manager.get_job_steps(
|
||||
|
||||
@@ -87,7 +87,7 @@ async def list_tools(
|
||||
Get a list of all tools available to agents belonging to the org of the user
|
||||
"""
|
||||
try:
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
|
||||
if name is not None:
|
||||
tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor)
|
||||
return [tool] if tool else []
|
||||
|
||||
@@ -14,7 +14,7 @@ router = APIRouter(prefix="/users", tags=["users", "admin"])
|
||||
|
||||
|
||||
@router.get("/", tags=["admin"], response_model=List[User], operation_id="list_users")
|
||||
def list_users(
|
||||
async def list_users(
|
||||
after: Optional[str] = Query(None),
|
||||
limit: Optional[int] = Query(50),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
@@ -23,7 +23,7 @@ def list_users(
|
||||
Get a list of all users in the database
|
||||
"""
|
||||
try:
|
||||
users = server.user_manager.list_users(after=after, limit=limit)
|
||||
users = await server.user_manager.list_actors_async(after=after, limit=limit)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
@@ -32,7 +32,7 @@ def list_users(
|
||||
|
||||
|
||||
@router.post("/", tags=["admin"], response_model=User, operation_id="create_user")
|
||||
def create_user(
|
||||
async def create_user(
|
||||
request: UserCreate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
@@ -40,33 +40,33 @@ def create_user(
|
||||
Create a new user in the database
|
||||
"""
|
||||
user = User(**request.model_dump())
|
||||
user = server.user_manager.create_user(user)
|
||||
user = await server.user_manager.create_actor_async(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.put("/", tags=["admin"], response_model=User, operation_id="update_user")
|
||||
def update_user(
|
||||
async def update_user(
|
||||
user: UserUpdate = Body(...),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
"""
|
||||
Update a user in the database
|
||||
"""
|
||||
user = server.user_manager.update_user(user)
|
||||
user = await server.user_manager.update_actor_async(user)
|
||||
return user
|
||||
|
||||
|
||||
@router.delete("/", tags=["admin"], response_model=User, operation_id="delete_user")
|
||||
def delete_user(
|
||||
async def delete_user(
|
||||
user_id: str = Query(..., description="The user_id key to be deleted."),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
):
|
||||
# TODO make a soft deletion, instead of a hard deletion
|
||||
try:
|
||||
user = server.user_manager.get_user_by_id(user_id=user_id)
|
||||
user = await server.user_manager.get_actor_by_id_async(actor_id=user_id)
|
||||
if user is None:
|
||||
raise HTTPException(status_code=404, detail=f"User does not exist")
|
||||
server.user_manager.delete_user_by_id(user_id=user_id)
|
||||
await server.user_manager.delete_actor_by_id_async(user_id=user_id)
|
||||
except HTTPException:
|
||||
raise
|
||||
except Exception as e:
|
||||
|
||||
@@ -38,7 +38,7 @@ async def create_voice_chat_completions(
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id)
|
||||
|
||||
# Create OpenAI async client
|
||||
client = openai.AsyncClient(
|
||||
|
||||
@@ -44,6 +44,14 @@ class UserManager:
|
||||
new_user.create(session)
|
||||
return new_user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser:
|
||||
"""Create a new user if it doesn't already exist (async version)."""
|
||||
async with db_registry.async_session() as session:
|
||||
new_user = UserModel(**pydantic_user.model_dump(to_orm=True))
|
||||
await new_user.create_async(session)
|
||||
return new_user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def update_user(self, user_update: UserUpdate) -> PydanticUser:
|
||||
"""Update user details."""
|
||||
@@ -60,6 +68,22 @@ class UserManager:
|
||||
existing_user.update(session)
|
||||
return existing_user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser:
|
||||
"""Update user details (async version)."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Retrieve the existing user by ID
|
||||
existing_user = await UserModel.read_async(db_session=session, identifier=user_update.id)
|
||||
|
||||
# Update only the fields that are provided in UserUpdate
|
||||
update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
for key, value in update_data.items():
|
||||
setattr(existing_user, key, value)
|
||||
|
||||
# Commit the updated user
|
||||
await existing_user.update_async(session)
|
||||
return existing_user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def delete_user_by_id(self, user_id: str):
|
||||
"""Delete a user and their associated records (agents, sources, mappings)."""
|
||||
@@ -70,6 +94,14 @@ class UserManager:
|
||||
|
||||
session.commit()
|
||||
|
||||
@enforce_types
|
||||
async def delete_actor_by_id_async(self, user_id: str):
|
||||
"""Delete a user and their associated records (agents, sources, mappings) asynchronously."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Delete from user table
|
||||
user = await UserModel.read_async(db_session=session, identifier=user_id)
|
||||
await user.hard_delete_async(session)
|
||||
|
||||
@enforce_types
|
||||
def get_user_by_id(self, user_id: str) -> PydanticUser:
|
||||
"""Fetch a user by ID."""
|
||||
@@ -77,6 +109,13 @@ class UserManager:
|
||||
user = UserModel.read(db_session=session, identifier=user_id)
|
||||
return user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser:
|
||||
"""Fetch a user by ID asynchronously."""
|
||||
async with db_registry.async_session() as session:
|
||||
user = await UserModel.read_async(db_session=session, identifier=actor_id)
|
||||
return user.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_default_user(self) -> PydanticUser:
|
||||
"""Fetch the default user. If it doesn't exist, create it."""
|
||||
@@ -96,6 +135,26 @@ class UserManager:
|
||||
except NoResultFound:
|
||||
return self.get_default_user()
|
||||
|
||||
@enforce_types
|
||||
async def get_default_actor_async(self) -> PydanticUser:
|
||||
"""Fetch the default user asynchronously. If it doesn't exist, create it."""
|
||||
try:
|
||||
return await self.get_actor_by_id_async(self.DEFAULT_USER_ID)
|
||||
except NoResultFound:
|
||||
# Fall back to synchronous version since create_default_user isn't async yet
|
||||
return self.create_default_user(org_id=self.DEFAULT_ORG_ID)
|
||||
|
||||
@enforce_types
|
||||
async def get_actor_or_default_async(self, actor_id: Optional[str] = None):
|
||||
"""Fetch the user or default user asynchronously."""
|
||||
if not actor_id:
|
||||
return await self.get_default_actor_async()
|
||||
|
||||
try:
|
||||
return await self.get_actor_by_id_async(actor_id=actor_id)
|
||||
except NoResultFound:
|
||||
return await self.get_default_actor_async()
|
||||
|
||||
@enforce_types
|
||||
def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
|
||||
"""List all users with optional pagination."""
|
||||
@@ -106,3 +165,14 @@ class UserManager:
|
||||
limit=limit,
|
||||
)
|
||||
return [user.to_pydantic() for user in users]
|
||||
|
||||
@enforce_types
|
||||
async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]:
|
||||
"""List all users with optional pagination (async version)."""
|
||||
async with db_registry.async_session() as session:
|
||||
users = await UserModel.list_async(
|
||||
db_session=session,
|
||||
after=after,
|
||||
limit=limit,
|
||||
)
|
||||
return [user.to_pydantic() for user in users]
|
||||
|
||||
@@ -124,16 +124,16 @@ def default_user(server: SyncServer, default_organization):
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user(server: SyncServer, default_organization):
|
||||
async def other_user(server: SyncServer, default_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=default_organization.id))
|
||||
yield user
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def other_user_different_org(server: SyncServer, other_organization):
|
||||
async def other_user_different_org(server: SyncServer, other_organization):
|
||||
"""Fixture to create and return the default user within the default organization."""
|
||||
user = server.user_manager.create_user(PydanticUser(name="other", organization_id=other_organization.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=other_organization.id))
|
||||
yield user
|
||||
|
||||
|
||||
@@ -2120,20 +2120,21 @@ def test_passage_cascade_deletion(
|
||||
# ======================================================================================================================
|
||||
# User Manager Tests
|
||||
# ======================================================================================================================
|
||||
def test_list_users(server: SyncServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_list_users(server: SyncServer, event_loop):
|
||||
# Create default organization
|
||||
org = server.organization_manager.create_default_organization()
|
||||
|
||||
user_name = "user"
|
||||
user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id))
|
||||
|
||||
users = server.user_manager.list_users()
|
||||
users = await server.user_manager.list_actors_async()
|
||||
assert len(users) == 1
|
||||
assert users[0].name == user_name
|
||||
|
||||
# Delete it after
|
||||
server.user_manager.delete_user_by_id(user.id)
|
||||
assert len(server.user_manager.list_users()) == 0
|
||||
await server.user_manager.delete_actor_by_id_async(user.id)
|
||||
assert len(await server.user_manager.list_actors_async()) == 0
|
||||
|
||||
|
||||
def test_create_default_user(server: SyncServer):
|
||||
@@ -2143,7 +2144,8 @@ def test_create_default_user(server: SyncServer):
|
||||
assert retrieved.name == server.user_manager.DEFAULT_USER_NAME
|
||||
|
||||
|
||||
def test_update_user(server: SyncServer):
|
||||
@pytest.mark.asyncio
|
||||
async def test_update_user(server: SyncServer, event_loop):
|
||||
# Create default organization
|
||||
default_org = server.organization_manager.create_default_organization()
|
||||
test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org"))
|
||||
@@ -2152,16 +2154,16 @@ def test_update_user(server: SyncServer):
|
||||
user_name_b = "b"
|
||||
|
||||
# Assert it's been created
|
||||
user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id))
|
||||
user = await server.user_manager.create_actor_async(PydanticUser(name=user_name_a, organization_id=default_org.id))
|
||||
assert user.name == user_name_a
|
||||
|
||||
# Adjust name
|
||||
user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b))
|
||||
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, name=user_name_b))
|
||||
assert user.name == user_name_b
|
||||
assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID
|
||||
|
||||
# Adjust org id
|
||||
user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id))
|
||||
user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, organization_id=test_org.id))
|
||||
assert user.name == user_name_b
|
||||
assert user.organization_id == test_org.id
|
||||
|
||||
|
||||
Reference in New Issue
Block a user