diff --git a/letta/__init__.py b/letta/__init__.py index 876aac38..06054e86 100644 --- a/letta/__init__.py +++ b/letta/__init__.py @@ -1,4 +1,4 @@ -__version__ = "0.7.16" +__version__ = "0.7.17" # import clients from letta.client.client import LocalClient, RESTClient, create_client diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index bc754de5..78bc5c62 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -8,10 +8,11 @@ from openai.types import CompletionUsage from openai.types.chat import ChatCompletion, ChatCompletionChunk from letta.agents.base_agent import BaseAgent -from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages +from letta.agents.helpers import _create_letta_response, _prepare_in_context_messages_async from letta.helpers import ToolRulesSolver from letta.helpers.tool_execution_helper import enable_strict_mode from letta.interfaces.anthropic_streaming_interface import AnthropicStreamingInterface +from letta.interfaces.openai_streaming_interface import OpenAIStreamingInterface from letta.llm_api.llm_client import LLMClient from letta.llm_api.llm_client_base import LLMClientBase from letta.local_llm.constants import INNER_THOUGHTS_KWARG @@ -61,12 +62,8 @@ class LettaAgent(BaseAgent): self.last_function_response = None # Cached archival memory/message size - self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id) - self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id) - - # Cached archival memory/message size - self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id) - self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_id) + self.num_messages = 0 + self.num_archival_memories = 0 @trace_method async def step(self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True) -> LettaResponse: @@ -81,7 +78,7 @@ class LettaAgent(BaseAgent): async def _step( self, agent_state: AgentState, input_messages: List[MessageCreate], max_steps: int = 10 ) -> Tuple[List[Message], List[Message], CompletionUsage]: - current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( + current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) @@ -129,14 +126,14 @@ class LettaAgent(BaseAgent): @trace_method async def step_stream( - self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True + self, input_messages: List[MessageCreate], max_steps: int = 10, use_assistant_message: bool = True, stream_tokens: bool = False ) -> AsyncGenerator[str, None]: """ Main streaming loop that yields partial tokens. Whenever we detect a tool call, we yield from _handle_ai_response as well. """ agent_state = await self.agent_manager.get_agent_by_id_async(self.agent_id, actor=self.actor) - current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( + current_in_context_messages, new_in_context_messages = await _prepare_in_context_messages_async( input_messages, agent_state, self.message_manager, self.actor ) tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) @@ -157,9 +154,16 @@ class LettaAgent(BaseAgent): ) # TODO: THIS IS INCREDIBLY UGLY # TODO: THERE ARE MULTIPLE COPIES OF THE LLM_CONFIG EVERYWHERE THAT ARE GETTING MANIPULATED - interface = AnthropicStreamingInterface( - use_assistant_message=use_assistant_message, put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs - ) + if agent_state.llm_config.model_endpoint_type == "anthropic": + interface = AnthropicStreamingInterface( + use_assistant_message=use_assistant_message, + put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, + ) + elif agent_state.llm_config.model_endpoint_type == "openai": + interface = OpenAIStreamingInterface( + use_assistant_message=use_assistant_message, + put_inner_thoughts_in_kwarg=agent_state.llm_config.put_inner_thoughts_in_kwargs, + ) async for chunk in interface.process(stream): yield f"data: {chunk.model_dump_json()}\n\n" @@ -197,8 +201,8 @@ class LettaAgent(BaseAgent): # TODO: This may be out of sync, if in between steps users add files # NOTE (cliandy): temporary for now for particlar use cases. - self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id) - self.num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id) + self.num_messages = await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id) + self.num_archival_memories = await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) # TODO: Also yield out a letta usage stats SSE yield f"data: {usage.model_dump_json()}\n\n" @@ -215,6 +219,10 @@ class LettaAgent(BaseAgent): stream: bool, ) -> ChatCompletion | AsyncStream[ChatCompletionChunk]: if settings.experimental_enable_async_db_engine: + self.num_messages = self.num_messages or (await self.message_manager.size_async(actor=self.actor, agent_id=agent_state.id)) + self.num_archival_memories = self.num_archival_memories or ( + await self.passage_manager.size_async(actor=self.actor, agent_id=agent_state.id) + ) in_context_messages = await self._rebuild_memory_async( in_context_messages, agent_state, num_messages=self.num_messages, num_archival_memories=self.num_archival_memories ) diff --git a/letta/interfaces/openai_streaming_interface.py b/letta/interfaces/openai_streaming_interface.py new file mode 100644 index 00000000..168d0521 --- /dev/null +++ b/letta/interfaces/openai_streaming_interface.py @@ -0,0 +1,303 @@ +from datetime import datetime, timezone +from typing import AsyncGenerator, List, Optional + +from openai import AsyncStream +from openai.types.chat.chat_completion_chunk import ChatCompletionChunk + +from letta.constants import DEFAULT_MESSAGE_TOOL, DEFAULT_MESSAGE_TOOL_KWARG +from letta.schemas.letta_message import AssistantMessage, LettaMessage, ReasoningMessage, ToolCallDelta, ToolCallMessage +from letta.schemas.letta_message_content import TextContent +from letta.schemas.message import Message +from letta.schemas.openai.chat_completion_response import FunctionCall, ToolCall +from letta.server.rest_api.json_parser import OptimisticJSONParser +from letta.streaming_utils import JSONInnerThoughtsExtractor + + +class OpenAIStreamingInterface: + """ + Encapsulates the logic for streaming responses from OpenAI. + This class handles parsing of partial tokens, pre-execution messages, + and detection of tool call events. + """ + + def __init__(self, use_assistant_message: bool = False, put_inner_thoughts_in_kwarg: bool = False): + self.use_assistant_message = use_assistant_message + self.assistant_message_tool_name = DEFAULT_MESSAGE_TOOL + self.assistant_message_tool_kwarg = DEFAULT_MESSAGE_TOOL_KWARG + + self.optimistic_json_parser: OptimisticJSONParser = OptimisticJSONParser() + self.function_args_reader = JSONInnerThoughtsExtractor(wait_for_first_key=True) # TODO: pass in kward + self.function_name_buffer = None + self.function_args_buffer = None + self.function_id_buffer = None + self.last_flushed_function_name = None + + # Buffer to hold function arguments until inner thoughts are complete + self.current_function_arguments = "" + self.current_json_parse_result = {} + + # Premake IDs for database writes + self.letta_assistant_message_id = Message.generate_id() + self.letta_tool_message_id = Message.generate_id() + + # token counters + self.input_tokens = 0 + self.output_tokens = 0 + + self.content_buffer: List[str] = [] + self.tool_call_name: Optional[str] = None + self.tool_call_id: Optional[str] = None + self.reasoning_messages = [] + + def get_reasoning_content(self) -> List[TextContent]: + content = "".join(self.reasoning_messages) + return [TextContent(text=content)] + + def get_tool_call_object(self) -> ToolCall: + """Useful for agent loop""" + return ToolCall( + id=self.letta_tool_message_id, + function=FunctionCall(arguments=self.current_function_arguments, name=self.last_flushed_function_name), + ) + + async def process(self, stream: AsyncStream[ChatCompletionChunk]) -> AsyncGenerator[LettaMessage, None]: + """ + Iterates over the OpenAI stream, yielding SSE events. + It also collects tokens and detects if a tool call is triggered. + """ + async with stream: + prev_message_type = None + message_index = 0 + async for chunk in stream: + # track usage + if chunk.usage: + self.input_tokens += len(chunk.usage.prompt_tokens) + self.output_tokens += len(chunk.usage.completion_tokens) + + if chunk.choices: + choice = chunk.choices[0] + message_delta = choice.delta + + if message_delta.tool_calls is not None and len(message_delta.tool_calls) > 0: + tool_call = message_delta.tool_calls[0] + + if tool_call.function.name: + # If we're waiting for the first key, then we should hold back the name + # ie add it to a buffer instead of returning it as a chunk + if self.function_name_buffer is None: + self.function_name_buffer = tool_call.function.name + else: + self.function_name_buffer += tool_call.function.name + + if tool_call.id: + # Buffer until next time + if self.function_id_buffer is None: + self.function_id_buffer = tool_call.id + else: + self.function_id_buffer += tool_call.id + + if tool_call.function.arguments: + # updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment(tool_call.function.arguments) + self.current_function_arguments += tool_call.function.arguments + updates_main_json, updates_inner_thoughts = self.function_args_reader.process_fragment( + tool_call.function.arguments + ) + + # If we have inner thoughts, we should output them as a chunk + if updates_inner_thoughts: + if prev_message_type and prev_message_type != "reasoning_message": + message_index += 1 + self.reasoning_messages.append(updates_inner_thoughts) + reasoning_message = ReasoningMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + reasoning=updates_inner_thoughts, + # name=name, + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = reasoning_message.message_type + yield reasoning_message + + # Additionally inner thoughts may stream back with a chunk of main JSON + # In that case, since we can only return a chunk at a time, we should buffer it + if updates_main_json: + if self.function_args_buffer is None: + self.function_args_buffer = updates_main_json + else: + self.function_args_buffer += updates_main_json + + # If we have main_json, we should output a ToolCallMessage + elif updates_main_json: + + # If there's something in the function_name buffer, we should release it first + # NOTE: we could output it as part of a chunk that has both name and args, + # however the frontend may expect name first, then args, so to be + # safe we'll output name first in a separate chunk + if self.function_name_buffer: + + # use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..." + if self.use_assistant_message and self.function_name_buffer == self.assistant_message_tool_name: + + # Store the ID of the tool call so allow skipping the corresponding response + if self.function_id_buffer: + self.prev_assistant_message_id = self.function_id_buffer + + else: + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 + self.tool_call_name = str(self.function_name_buffer) + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=self.function_name_buffer, + arguments=None, + tool_call_id=self.function_id_buffer, + ), + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = tool_call_msg.message_type + yield tool_call_msg + + # Record what the last function name we flushed was + self.last_flushed_function_name = self.function_name_buffer + # Clear the buffer + self.function_name_buffer = None + self.function_id_buffer = None + # Since we're clearing the name buffer, we should store + # any updates to the arguments inside a separate buffer + + # Add any main_json updates to the arguments buffer + if self.function_args_buffer is None: + self.function_args_buffer = updates_main_json + else: + self.function_args_buffer += updates_main_json + + # If there was nothing in the name buffer, we can proceed to + # output the arguments chunk as a ToolCallMessage + else: + + # use_assisitant_message means that we should also not release main_json raw, and instead should only release the contents of "message": "..." + if self.use_assistant_message and ( + self.last_flushed_function_name is not None + and self.last_flushed_function_name == self.assistant_message_tool_name + ): + # do an additional parse on the updates_main_json + if self.function_args_buffer: + updates_main_json = self.function_args_buffer + updates_main_json + self.function_args_buffer = None + + # Pretty gross hardcoding that assumes that if we're toggling into the keywords, we have the full prefix + match_str = '{"' + self.assistant_message_tool_kwarg + '":"' + if updates_main_json == match_str: + updates_main_json = None + + else: + # Some hardcoding to strip off the trailing "}" + if updates_main_json in ["}", '"}']: + updates_main_json = None + if updates_main_json and len(updates_main_json) > 0 and updates_main_json[-1:] == '"': + updates_main_json = updates_main_json[:-1] + + if not updates_main_json: + # early exit to turn into content mode + continue + + # There may be a buffer from a previous chunk, for example + # if the previous chunk had arguments but we needed to flush name + if self.function_args_buffer: + # In this case, we should release the buffer + new data at once + combined_chunk = self.function_args_buffer + updates_main_json + + if prev_message_type and prev_message_type != "assistant_message": + message_index += 1 + assistant_message = AssistantMessage( + id=self.letta_assistant_message_id, + date=datetime.now(timezone.utc), + content=combined_chunk, + otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + ) + prev_message_type = assistant_message.message_type + yield assistant_message + # Store the ID of the tool call so allow skipping the corresponding response + if self.function_id_buffer: + self.prev_assistant_message_id = self.function_id_buffer + # clear buffer + self.function_args_buffer = None + self.function_id_buffer = None + + else: + # If there's no buffer to clear, just output a new chunk with new data + # TODO: THIS IS HORRIBLE + # TODO: WE USE THE OLD JSON PARSER EARLIER (WHICH DOES NOTHING) AND NOW THE NEW JSON PARSER + # TODO: THIS IS TOTALLY WRONG AND BAD, BUT SAVING FOR A LARGER REWRITE IN THE NEAR FUTURE + parsed_args = self.optimistic_json_parser.parse(self.current_function_arguments) + + if parsed_args.get(self.assistant_message_tool_kwarg) and parsed_args.get( + self.assistant_message_tool_kwarg + ) != self.current_json_parse_result.get(self.assistant_message_tool_kwarg): + new_content = parsed_args.get(self.assistant_message_tool_kwarg) + prev_content = self.current_json_parse_result.get(self.assistant_message_tool_kwarg, "") + # TODO: Assumes consistent state and that prev_content is subset of new_content + diff = new_content.replace(prev_content, "", 1) + self.current_json_parse_result = parsed_args + if prev_message_type and prev_message_type != "assistant_message": + message_index += 1 + assistant_message = AssistantMessage( + id=self.letta_assistant_message_id, + date=datetime.now(timezone.utc), + content=diff, + # name=name, + otid=Message.generate_otid_from_id(self.letta_assistant_message_id, message_index), + ) + prev_message_type = assistant_message.message_type + yield assistant_message + + # Store the ID of the tool call so allow skipping the corresponding response + if self.function_id_buffer: + self.prev_assistant_message_id = self.function_id_buffer + # clear buffers + self.function_id_buffer = None + else: + + # There may be a buffer from a previous chunk, for example + # if the previous chunk had arguments but we needed to flush name + if self.function_args_buffer: + # In this case, we should release the buffer + new data at once + combined_chunk = self.function_args_buffer + updates_main_json + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=None, + arguments=combined_chunk, + tool_call_id=self.function_id_buffer, + ), + # name=name, + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = tool_call_msg.message_type + yield tool_call_msg + # clear buffer + self.function_args_buffer = None + self.function_id_buffer = None + else: + # If there's no buffer to clear, just output a new chunk with new data + if prev_message_type and prev_message_type != "tool_call_message": + message_index += 1 + tool_call_msg = ToolCallMessage( + id=self.letta_tool_message_id, + date=datetime.now(timezone.utc), + tool_call=ToolCallDelta( + name=None, + arguments=updates_main_json, + tool_call_id=self.function_id_buffer, + ), + # name=name, + otid=Message.generate_otid_from_id(self.letta_tool_message_id, message_index), + ) + prev_message_type = tool_call_msg.message_type + yield tool_call_msg + self.function_id_buffer = None diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index dda47c6c..d167e5e9 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -745,6 +745,17 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): self.is_deleted = True return self.update(db_session) + @handle_db_timeout + async def delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> "SqlalchemyBase": + """Soft delete a record asynchronously (mark as deleted).""" + logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)") + + if actor: + self._set_created_and_updated_by_fields(actor.id) + + self.is_deleted = True + return await self.update_async(db_session) + @handle_db_timeout def hard_delete(self, db_session: "Session", actor: Optional["User"] = None) -> None: """Permanently removes the record from the database.""" @@ -761,6 +772,20 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): else: logger.debug(f"{self.__class__.__name__} with ID {self.id} successfully hard deleted") + @handle_db_timeout + async def hard_delete_async(self, db_session: "AsyncSession", actor: Optional["User"] = None) -> None: + """Permanently removes the record from the database asynchronously.""" + logger.debug(f"Hard deleting {self.__class__.__name__} with ID: {self.id} with actor={actor} (async)") + + async with db_session as session: + try: + await session.delete(self) + await session.commit() + except Exception as e: + await session.rollback() + logger.exception(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}") + raise ValueError(f"Failed to hard delete {self.__class__.__name__} with ID {self.id}: {e}") + @handle_db_timeout def update(self, db_session: Session, actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": logger.debug(...) @@ -793,6 +818,39 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): await db_session.refresh(self) return self + @classmethod + def _size_preprocess( + cls, + *, + db_session: "Session", + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + access_type: AccessType = AccessType.ORGANIZATION, + **kwargs, + ): + logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}") + query = select(func.count()).select_from(cls) + + if actor: + query = cls.apply_access_predicate(query, actor, access, access_type) + + # Apply filtering logic based on kwargs + for key, value in kwargs.items(): + if value: + column = getattr(cls, key, None) + if not column: + raise AttributeError(f"{cls.__name__} has no attribute '{key}'") + if isinstance(value, (list, tuple, set)): # Check for iterables + query = query.where(column.in_(value)) + else: # Single value for equality filtering + query = query.where(column == value) + + # Handle soft deletes if the class has the 'is_deleted' attribute + if hasattr(cls, "is_deleted"): + query = query.where(cls.is_deleted == False) + + return query + @classmethod @handle_db_timeout def size( @@ -817,28 +875,8 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): Raises: DBAPIError: If a database error occurs """ - logger.debug(f"Calculating size for {cls.__name__} with filters {kwargs}") - with db_session as session: - query = select(func.count()).select_from(cls) - - if actor: - query = cls.apply_access_predicate(query, actor, access, access_type) - - # Apply filtering logic based on kwargs - for key, value in kwargs.items(): - if value: - column = getattr(cls, key, None) - if not column: - raise AttributeError(f"{cls.__name__} has no attribute '{key}'") - if isinstance(value, (list, tuple, set)): # Check for iterables - query = query.where(column.in_(value)) - else: # Single value for equality filtering - query = query.where(column == value) - - # Handle soft deletes if the class has the 'is_deleted' attribute - if hasattr(cls, "is_deleted"): - query = query.where(cls.is_deleted == False) + query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs) try: count = session.execute(query).scalar() @@ -847,6 +885,37 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): logger.exception(f"Failed to calculate size for {cls.__name__}") raise e + @classmethod + @handle_db_timeout + async def size_async( + cls, + *, + db_session: "AsyncSession", + actor: Optional["User"] = None, + access: Optional[List[Literal["read", "write", "admin"]]] = ["read"], + access_type: AccessType = AccessType.ORGANIZATION, + **kwargs, + ) -> int: + """ + Get the count of rows that match the provided filters. + Args: + db_session: SQLAlchemy session + **kwargs: Filters to apply to the query (e.g., column_name=value) + Returns: + int: The count of rows that match the filters + Raises: + DBAPIError: If a database error occurs + """ + async with db_session as session: + query = cls._size_preprocess(db_session=session, actor=actor, access=access, access_type=access_type, **kwargs) + + try: + count = await session.execute(query).scalar() + return count if count else 0 + except DBAPIError as e: + logger.exception(f"Failed to calculate size for {cls.__name__}") + raise e + @classmethod def apply_access_predicate( cls, diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 4f56be88..96f153f3 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -83,7 +83,7 @@ async def list_agents( """ # Retrieve the actor (user) details - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # Call list_agents directly without unnecessary dict handling return await server.agent_manager.list_agents_async( @@ -163,7 +163,7 @@ async def import_agent_serialized( """ Import a serialized agent file and recreate the agent in the system. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: serialized_data = await file.read() @@ -233,7 +233,7 @@ async def create_agent( Create a new agent with the specified configuration. """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.create_agent_async(agent, actor=actor) except Exception as e: traceback.print_exc() @@ -248,7 +248,7 @@ async def modify_agent( actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Update an existing agent""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.update_agent_async(agent_id=agent_id, request=update_agent, actor=actor) @@ -333,7 +333,7 @@ def detach_source( @router.get("/{agent_id}", response_model=AgentState, operation_id="retrieve_agent") -def retrieve_agent( +async def retrieve_agent( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -344,7 +344,7 @@ def retrieve_agent( actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - return server.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor) + return await server.agent_manager.get_agent_by_id_async(agent_id=agent_id, actor=actor) except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @@ -414,7 +414,7 @@ def retrieve_block( @router.get("/{agent_id}/core-memory/blocks", response_model=List[Block], operation_id="list_core_memory_blocks") -def list_blocks( +async def list_blocks( agent_id: str, server: "SyncServer" = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present @@ -424,7 +424,7 @@ def list_blocks( """ actor = server.user_manager.get_user_or_default(user_id=actor_id) try: - agent = server.agent_manager.get_agent_by_id(agent_id, actor) + agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) return agent.memory.blocks except NoResultFound as e: raise HTTPException(status_code=404, detail=str(e)) @@ -628,9 +628,9 @@ async def send_message( Process a user message and return the agent's response. This endpoint accepts a message from a user and processes it through the agent. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # TODO: This is redundant, remove soon - agent = server.agent_manager.get_agent_by_id(agent_id, actor) + agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false" feature_enabled = settings.use_experimental or experimental_header.lower() == "true" @@ -686,13 +686,13 @@ async def send_message_streaming( It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. """ request_start_timestamp_ns = get_utc_timestamp_ns() - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # TODO: This is redundant, remove soon - agent = server.agent_manager.get_agent_by_id(agent_id, actor) + agent = await server.agent_manager.get_agent_by_id_async(agent_id, actor) agent_eligible = not agent.enable_sleeptime and not agent.multi_agent_group and agent.agent_type != AgentType.sleeptime_agent experimental_header = request_obj.headers.get("X-EXPERIMENTAL") or "false" feature_enabled = settings.use_experimental or experimental_header.lower() == "true" - model_compatible = agent.llm_config.model_endpoint_type == "anthropic" + model_compatible = agent.llm_config.model_endpoint_type in ["anthropic", "openai"] if agent_eligible and feature_enabled and model_compatible and request.stream_tokens: experimental_agent = LettaAgent( @@ -705,7 +705,9 @@ async def send_message_streaming( ) result = StreamingResponse( - experimental_agent.step_stream(request.messages, max_steps=10, use_assistant_message=request.use_assistant_message), + experimental_agent.step_stream( + request.messages, max_steps=10, use_assistant_message=request.use_assistant_message, stream_tokens=request.stream_tokens + ), media_type="text/event-stream", ) else: @@ -784,7 +786,7 @@ async def send_message_async( Asynchronously process a user message and return a run object. The actual processing happens in the background, and the status can be checked using the run ID. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # Create a new job run = Run( @@ -838,6 +840,6 @@ async def list_agent_groups( actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): """Lists the groups for an agent""" - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) print("in list agents with manager_type", manager_type) return server.agent_manager.list_groups(agent_id=agent_id, manager_type=manager_type, actor=actor) diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 4a9ea8da..c9506906 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -26,7 +26,7 @@ async def list_blocks( server: SyncServer = Depends(get_letta_server), actor_id: Optional[str] = Header(None, alias="user_id"), # Extract user_id from header, default to None if not present ): - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) return await server.block_manager.get_blocks_async( actor=actor, label=label, diff --git a/letta/server/rest_api/routers/v1/groups.py b/letta/server/rest_api/routers/v1/groups.py index 3ed71153..c6c6fb12 100644 --- a/letta/server/rest_api/routers/v1/groups.py +++ b/letta/server/rest_api/routers/v1/groups.py @@ -135,7 +135,7 @@ async def send_group_message( Process a user message and return the group's response. This endpoint accepts a message from a user and processes it through through agents in the group based on the specified pattern """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) result = await server.send_group_message_to_agent( group_id=group_id, actor=actor, @@ -174,7 +174,7 @@ async def send_group_message_streaming( This endpoint accepts a message from a user and processes it through agents in the group based on the specified pattern. It will stream the steps of the response always, and stream the tokens if 'stream_tokens' is set to True. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) result = await server.send_group_message_to_agent( group_id=group_id, actor=actor, diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index fe5e0f91..4d7d3588 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -52,7 +52,7 @@ async def create_messages_batch( detail=f"Server misconfiguration: LETTA_ENABLE_BATCH_JOB_POLLING is set to False.", ) - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) batch_job = BatchJob( user_id=actor.id, status=JobStatus.running, @@ -100,7 +100,7 @@ async def retrieve_batch_run( """ Get the status of a batch run. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor) @@ -118,7 +118,7 @@ async def list_batch_runs( List all batch runs. """ # TODO: filter - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) jobs = server.job_manager.list_jobs(actor=actor, statuses=[JobStatus.created, JobStatus.running], job_type=JobType.BATCH) return [BatchJob.from_job(job) for job in jobs] @@ -150,7 +150,7 @@ async def list_batch_messages( - For subsequent pages, use the ID of the last message from the previous response as the cursor - Results will include messages before/after the cursor based on sort_descending """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) # First, verify the batch job exists and the user has access to it try: @@ -177,7 +177,7 @@ async def cancel_batch_run( """ Cancel a batch run. """ - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor) diff --git a/letta/server/rest_api/routers/v1/runs.py b/letta/server/rest_api/routers/v1/runs.py index fd7e5131..8a8793a3 100644 --- a/letta/server/rest_api/routers/v1/runs.py +++ b/letta/server/rest_api/routers/v1/runs.py @@ -115,7 +115,7 @@ async def list_run_messages( if order not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'") - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: messages = server.job_manager.get_run_messages( @@ -182,7 +182,7 @@ async def list_run_steps( if order not in ["asc", "desc"]: raise HTTPException(status_code=400, detail="Order must be 'asc' or 'desc'") - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) try: steps = server.job_manager.get_job_steps( diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index bd5dd80e..ce8acc46 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -87,7 +87,7 @@ async def list_tools( Get a list of all tools available to agents belonging to the org of the user """ try: - actor = server.user_manager.get_user_or_default(user_id=actor_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id) if name is not None: tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor) return [tool] if tool else [] diff --git a/letta/server/rest_api/routers/v1/users.py b/letta/server/rest_api/routers/v1/users.py index bf2de7ef..4b4bfd91 100644 --- a/letta/server/rest_api/routers/v1/users.py +++ b/letta/server/rest_api/routers/v1/users.py @@ -14,7 +14,7 @@ router = APIRouter(prefix="/users", tags=["users", "admin"]) @router.get("/", tags=["admin"], response_model=List[User], operation_id="list_users") -def list_users( +async def list_users( after: Optional[str] = Query(None), limit: Optional[int] = Query(50), server: "SyncServer" = Depends(get_letta_server), @@ -23,7 +23,7 @@ def list_users( Get a list of all users in the database """ try: - users = server.user_manager.list_users(after=after, limit=limit) + users = await server.user_manager.list_actors_async(after=after, limit=limit) except HTTPException: raise except Exception as e: @@ -32,7 +32,7 @@ def list_users( @router.post("/", tags=["admin"], response_model=User, operation_id="create_user") -def create_user( +async def create_user( request: UserCreate = Body(...), server: "SyncServer" = Depends(get_letta_server), ): @@ -40,33 +40,33 @@ def create_user( Create a new user in the database """ user = User(**request.model_dump()) - user = server.user_manager.create_user(user) + user = await server.user_manager.create_actor_async(user) return user @router.put("/", tags=["admin"], response_model=User, operation_id="update_user") -def update_user( +async def update_user( user: UserUpdate = Body(...), server: "SyncServer" = Depends(get_letta_server), ): """ Update a user in the database """ - user = server.user_manager.update_user(user) + user = await server.user_manager.update_actor_async(user) return user @router.delete("/", tags=["admin"], response_model=User, operation_id="delete_user") -def delete_user( +async def delete_user( user_id: str = Query(..., description="The user_id key to be deleted."), server: "SyncServer" = Depends(get_letta_server), ): # TODO make a soft deletion, instead of a hard deletion try: - user = server.user_manager.get_user_by_id(user_id=user_id) + user = await server.user_manager.get_actor_by_id_async(actor_id=user_id) if user is None: raise HTTPException(status_code=404, detail=f"User does not exist") - server.user_manager.delete_user_by_id(user_id=user_id) + await server.user_manager.delete_actor_by_id_async(user_id=user_id) except HTTPException: raise except Exception as e: diff --git a/letta/server/rest_api/routers/v1/voice.py b/letta/server/rest_api/routers/v1/voice.py index 4517a1a0..694f8946 100644 --- a/letta/server/rest_api/routers/v1/voice.py +++ b/letta/server/rest_api/routers/v1/voice.py @@ -36,7 +36,7 @@ async def create_voice_chat_completions( server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), ): - actor = server.user_manager.get_user_or_default(user_id=user_id) + actor = await server.user_manager.get_actor_or_default_async(actor_id=user_id) # Create OpenAI async client client = openai.AsyncClient( diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index b861cd49..91cdffce 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1,3 +1,4 @@ +import asyncio from datetime import datetime, timezone from typing import Dict, List, Optional, Set, Tuple @@ -905,12 +906,7 @@ class AgentManager: result = await session.execute(query) agents = result.scalars().all() - pydantic_agents = [] - for agent in agents: - pydantic_agent = await agent.to_pydantic_async(include_relationships=include_relationships) - pydantic_agents.append(pydantic_agent) - - return pydantic_agents + return await asyncio.gather(*[agent.to_pydantic_async(include_relationships=include_relationships) for agent in agents]) @enforce_types def list_agents_matching_tags( @@ -1195,8 +1191,8 @@ class AgentManager: @enforce_types async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: - message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids - return await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor) + agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) + return await self.message_manager.get_messages_by_ids_async(message_ids=agent.message_ids, actor=actor) @enforce_types def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 426743bf..2cc13f3f 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -286,6 +286,21 @@ class MessageManager: with db_registry.session() as session: return MessageModel.size(db_session=session, actor=actor, role=role, agent_id=agent_id) + @enforce_types + async def size_async( + self, + actor: PydanticUser, + role: Optional[MessageRole] = None, + agent_id: Optional[str] = None, + ) -> int: + """Get the total count of messages with optional filters. + Args: + actor: The user requesting the count + role: The role of the message + """ + async with db_registry.async_session() as session: + return await MessageModel.size_async(db_session=session, actor=actor, role=role, agent_id=agent_id) + @enforce_types def list_user_messages_for_agent( self, diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 8d735d9b..3cd581b3 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -216,6 +216,20 @@ class PassageManager: with db_registry.session() as session: return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id) + @enforce_types + async def size_async( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + ) -> int: + """Get the total count of messages with optional filters. + Args: + actor: The user requesting the count + agent_id: The agent ID of the messages + """ + async with db_registry.async_session() as session: + return await AgentPassage.size_async(db_session=session, actor=actor, agent_id=agent_id) + def estimate_embeddings_size( self, actor: PydanticUser, diff --git a/letta/services/user_manager.py b/letta/services/user_manager.py index 9f6a72a5..b1c64100 100644 --- a/letta/services/user_manager.py +++ b/letta/services/user_manager.py @@ -44,6 +44,14 @@ class UserManager: new_user.create(session) return new_user.to_pydantic() + @enforce_types + async def create_actor_async(self, pydantic_user: PydanticUser) -> PydanticUser: + """Create a new user if it doesn't already exist (async version).""" + async with db_registry.async_session() as session: + new_user = UserModel(**pydantic_user.model_dump(to_orm=True)) + await new_user.create_async(session) + return new_user.to_pydantic() + @enforce_types def update_user(self, user_update: UserUpdate) -> PydanticUser: """Update user details.""" @@ -60,6 +68,22 @@ class UserManager: existing_user.update(session) return existing_user.to_pydantic() + @enforce_types + async def update_actor_async(self, user_update: UserUpdate) -> PydanticUser: + """Update user details (async version).""" + async with db_registry.async_session() as session: + # Retrieve the existing user by ID + existing_user = await UserModel.read_async(db_session=session, identifier=user_update.id) + + # Update only the fields that are provided in UserUpdate + update_data = user_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True) + for key, value in update_data.items(): + setattr(existing_user, key, value) + + # Commit the updated user + await existing_user.update_async(session) + return existing_user.to_pydantic() + @enforce_types def delete_user_by_id(self, user_id: str): """Delete a user and their associated records (agents, sources, mappings).""" @@ -70,6 +94,14 @@ class UserManager: session.commit() + @enforce_types + async def delete_actor_by_id_async(self, user_id: str): + """Delete a user and their associated records (agents, sources, mappings) asynchronously.""" + async with db_registry.async_session() as session: + # Delete from user table + user = await UserModel.read_async(db_session=session, identifier=user_id) + await user.hard_delete_async(session) + @enforce_types def get_user_by_id(self, user_id: str) -> PydanticUser: """Fetch a user by ID.""" @@ -77,6 +109,13 @@ class UserManager: user = UserModel.read(db_session=session, identifier=user_id) return user.to_pydantic() + @enforce_types + async def get_actor_by_id_async(self, actor_id: str) -> PydanticUser: + """Fetch a user by ID asynchronously.""" + async with db_registry.async_session() as session: + user = await UserModel.read_async(db_session=session, identifier=actor_id) + return user.to_pydantic() + @enforce_types def get_default_user(self) -> PydanticUser: """Fetch the default user. If it doesn't exist, create it.""" @@ -96,6 +135,26 @@ class UserManager: except NoResultFound: return self.get_default_user() + @enforce_types + async def get_default_actor_async(self) -> PydanticUser: + """Fetch the default user asynchronously. If it doesn't exist, create it.""" + try: + return await self.get_actor_by_id_async(self.DEFAULT_USER_ID) + except NoResultFound: + # Fall back to synchronous version since create_default_user isn't async yet + return self.create_default_user(org_id=self.DEFAULT_ORG_ID) + + @enforce_types + async def get_actor_or_default_async(self, actor_id: Optional[str] = None): + """Fetch the user or default user asynchronously.""" + if not actor_id: + return await self.get_default_actor_async() + + try: + return await self.get_actor_by_id_async(actor_id=actor_id) + except NoResultFound: + return await self.get_default_actor_async() + @enforce_types def list_users(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: """List all users with optional pagination.""" @@ -106,3 +165,14 @@ class UserManager: limit=limit, ) return [user.to_pydantic() for user in users] + + @enforce_types + async def list_actors_async(self, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticUser]: + """List all users with optional pagination (async version).""" + async with db_registry.async_session() as session: + users = await UserModel.list_async( + db_session=session, + after=after, + limit=limit, + ) + return [user.to_pydantic() for user in users] diff --git a/pyproject.toml b/pyproject.toml index 046a2bd0..31a6aa2f 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -1,6 +1,6 @@ [tool.poetry] name = "letta" -version = "0.7.16" +version = "0.7.17" packages = [ {include = "letta"}, ] diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index f928baf5..246611dd 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -4,7 +4,7 @@ from unittest.mock import MagicMock import pytest from dotenv import load_dotenv -from letta_client import Letta +from letta_client import AsyncLetta from openai import AsyncOpenAI from openai.types.chat import ChatCompletionChunk @@ -130,12 +130,12 @@ def server_url(): @pytest.fixture(scope="session") def client(server_url): """Creates a REST client for testing.""" - client = Letta(base_url=server_url) + client = AsyncLetta(base_url=server_url) yield client @pytest.fixture(scope="function") -def roll_dice_tool(client): +async def roll_dice_tool(client): def roll_dice(): """ Rolls a 6 sided die. @@ -145,13 +145,13 @@ def roll_dice_tool(client): """ return "Rolled a 10!" - tool = client.tools.upsert_from_function(func=roll_dice) + tool = await client.tools.upsert_from_function(func=roll_dice) # Yield the created tool yield tool @pytest.fixture(scope="function") -def weather_tool(client): +async def weather_tool(client): def get_weather(location: str) -> str: """ Fetches the current weather for a given location. @@ -176,7 +176,7 @@ def weather_tool(client): else: raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") - tool = client.tools.upsert_from_function(func=get_weather) + tool = await client.tools.upsert_from_function(func=get_weather) # Yield the created tool yield tool @@ -270,7 +270,7 @@ def _assert_valid_chunk(chunk, idx, chunks): @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"]) -async def test_model_compatibility(disable_e2b_api_key, client, model, server, group_id, actor): +async def test_model_compatibility(disable_e2b_api_key, voice_agent, model, server, group_id, actor): request = _get_chat_request("How are you?") server.tool_manager.upsert_base_tools(actor=actor) @@ -306,7 +306,7 @@ async def test_model_compatibility(disable_e2b_api_key, client, model, server, g @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."]) @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) -async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, message, endpoint): +async def test_voice_recall_memory(disable_e2b_api_key, voice_agent, message, endpoint): """Tests chat completion streaming using the Async OpenAI client.""" request = _get_chat_request(message) @@ -320,7 +320,7 @@ async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, mes @pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) -async def test_trigger_summarization(disable_e2b_api_key, client, server, voice_agent, group_id, endpoint, actor): +async def test_trigger_summarization(disable_e2b_api_key, server, voice_agent, group_id, endpoint, actor): server.group_manager.modify_group( group_id=group_id, group_update=GroupUpdate( @@ -423,7 +423,7 @@ async def test_summarization(disable_e2b_api_key, voice_agent): @pytest.mark.asyncio(loop_scope="session") -async def test_voice_sleeptime_agent(disable_e2b_api_key, client, voice_agent): +async def test_voice_sleeptime_agent(disable_e2b_api_key, voice_agent): """Tests chat completion streaming using the Async OpenAI client.""" agent_manager = AgentManager() tool_manager = ToolManager() diff --git a/tests/test_managers.py b/tests/test_managers.py index 0be3d7a6..719f867d 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -124,16 +124,16 @@ def default_user(server: SyncServer, default_organization): @pytest.fixture -def other_user(server: SyncServer, default_organization): +async def other_user(server: SyncServer, default_organization): """Fixture to create and return the default user within the default organization.""" - user = server.user_manager.create_user(PydanticUser(name="other", organization_id=default_organization.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=default_organization.id)) yield user @pytest.fixture -def other_user_different_org(server: SyncServer, other_organization): +async def other_user_different_org(server: SyncServer, other_organization): """Fixture to create and return the default user within the default organization.""" - user = server.user_manager.create_user(PydanticUser(name="other", organization_id=other_organization.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name="other", organization_id=other_organization.id)) yield user @@ -2160,20 +2160,21 @@ def test_passage_cascade_deletion( # ====================================================================================================================== # User Manager Tests # ====================================================================================================================== -def test_list_users(server: SyncServer): +@pytest.mark.asyncio +async def test_list_users(server: SyncServer, event_loop): # Create default organization org = server.organization_manager.create_default_organization() user_name = "user" - user = server.user_manager.create_user(PydanticUser(name=user_name, organization_id=org.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name=user_name, organization_id=org.id)) - users = server.user_manager.list_users() + users = await server.user_manager.list_actors_async() assert len(users) == 1 assert users[0].name == user_name # Delete it after - server.user_manager.delete_user_by_id(user.id) - assert len(server.user_manager.list_users()) == 0 + await server.user_manager.delete_actor_by_id_async(user.id) + assert len(await server.user_manager.list_actors_async()) == 0 def test_create_default_user(server: SyncServer): @@ -2183,7 +2184,8 @@ def test_create_default_user(server: SyncServer): assert retrieved.name == server.user_manager.DEFAULT_USER_NAME -def test_update_user(server: SyncServer): +@pytest.mark.asyncio +async def test_update_user(server: SyncServer, event_loop): # Create default organization default_org = server.organization_manager.create_default_organization() test_org = server.organization_manager.create_organization(PydanticOrganization(name="test_org")) @@ -2192,16 +2194,16 @@ def test_update_user(server: SyncServer): user_name_b = "b" # Assert it's been created - user = server.user_manager.create_user(PydanticUser(name=user_name_a, organization_id=default_org.id)) + user = await server.user_manager.create_actor_async(PydanticUser(name=user_name_a, organization_id=default_org.id)) assert user.name == user_name_a # Adjust name - user = server.user_manager.update_user(UserUpdate(id=user.id, name=user_name_b)) + user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, name=user_name_b)) assert user.name == user_name_b assert user.organization_id == OrganizationManager.DEFAULT_ORG_ID # Adjust org id - user = server.user_manager.update_user(UserUpdate(id=user.id, organization_id=test_org.id)) + user = await server.user_manager.update_actor_async(UserUpdate(id=user.id, organization_id=test_org.id)) assert user.name == user_name_b assert user.organization_id == test_org.id