From 7f90746152d43176c5e96ecf95aa5b453c602555 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 15 May 2025 12:09:40 -0700 Subject: [PATCH] feat: convert many methods to async (#2193) --- letta/agents/helpers.py | 3 +- letta/agents/letta_agent.py | 10 +- letta/agents/letta_agent_batch.py | 14 +- letta/client/client.py | 17 ++- letta/jobs/llm_batch_job_polling.py | 8 +- letta/orm/sqlalchemy_base.py | 63 ++++++++ letta/server/rest_api/routers/v1/tools.py | 6 +- letta/services/agent_manager.py | 5 + letta/services/job_manager.py | 13 ++ letta/services/llm_batch_manager.py | 55 +++---- letta/services/message_manager.py | 41 ++++-- letta/services/tool_manager.py | 16 +- poetry.lock | 10 +- pyproject.toml | 2 +- tests/integration_test_batch_api_cron_jobs.py | 30 ++-- tests/integration_test_voice_agent.py | 19 +-- tests/test_letta_agent_batch.py | 55 +++---- tests/test_managers.py | 137 +++++++++++------- 18 files changed, 326 insertions(+), 178 deletions(-) diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index 17d9e676..5578d1fb 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -92,8 +92,7 @@ async def _prepare_in_context_messages_async( current_in_context_messages = await message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=actor) # Create a new user message from the input and store it - # TODO: make this async - new_in_context_messages = message_manager.create_many_messages( + new_in_context_messages = await message_manager.create_many_messages_async( create_input_messages(input_messages=input_messages, agent_id=agent_state.id, actor=actor), actor=actor ) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 62e6dfe0..1010540d 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -58,7 +58,7 @@ class LettaAgent(BaseAgent): self.passage_manager = passage_manager self.response_messages: List[Message] = [] - self.last_function_response = self._load_last_function_response() + self.last_function_response = None # Cached archival memory/message size self.num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_id) @@ -237,6 +237,8 @@ class LettaAgent(BaseAgent): ] # Mirror the sync agent loop: get allowed tools or allow all if none are allowed + if self.last_function_response is None: + self.last_function_response = await self._load_last_function_response_async() valid_tool_names = tool_rules_solver.get_allowed_tool_names( available_tools=set([t.name for t in tools]), last_function_response=self.last_function_response, @@ -330,7 +332,7 @@ class LettaAgent(BaseAgent): pre_computed_assistant_message_id=pre_computed_assistant_message_id, pre_computed_tool_message_id=pre_computed_tool_message_id, ) - persisted_messages = self.message_manager.create_many_messages(tool_call_messages, actor=self.actor) + persisted_messages = await self.message_manager.create_many_messages_async(tool_call_messages, actor=self.actor) self.last_function_response = function_response return persisted_messages, continue_stepping @@ -416,9 +418,9 @@ class LettaAgent(BaseAgent): results = await asyncio.gather(*tasks) return results - def _load_last_function_response(self): + async def _load_last_function_response_async(self): """Load the last function response from message history""" - in_context_messages = self.agent_manager.get_in_context_messages(agent_id=self.agent_id, actor=self.actor) + in_context_messages = await self.agent_manager.get_in_context_messages_async(agent_id=self.agent_id, actor=self.actor) for msg in reversed(in_context_messages): if msg.role == MessageRole.tool and msg.content and len(msg.content) == 1 and isinstance(msg.content[0], TextContent): text_content = msg.content[0].text diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index f844876e..201cd565 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -196,7 +196,7 @@ class LettaAgentBatch(BaseAgent): ) log_event(name="persist_llm_batch_job") - llm_batch_job = self.batch_manager.create_llm_batch_job( + llm_batch_job = await self.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, # TODO: Expand to more providers create_batch_response=batch_response, actor=self.actor, @@ -214,7 +214,7 @@ class LettaAgentBatch(BaseAgent): if batch_items: log_event(name="bulk_create_batch_items") - batch_items_persisted = self.batch_manager.create_llm_batch_items_bulk(batch_items, actor=self.actor) + batch_items_persisted = await self.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=self.actor) log_event(name="return_batch_response") return LettaBatchResponse( @@ -229,7 +229,7 @@ class LettaAgentBatch(BaseAgent): @trace_method async def resume_step_after_request(self, letta_batch_id: str, llm_batch_id: str) -> LettaBatchResponse: log_event(name="load_context") - llm_batch_job = self.batch_manager.get_llm_batch_job_by_id(llm_batch_id=llm_batch_id, actor=self.actor) + llm_batch_job = await self.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=llm_batch_id, actor=self.actor) ctx = await self._collect_resume_context(llm_batch_id) log_event(name="update_statuses") @@ -239,7 +239,7 @@ class LettaAgentBatch(BaseAgent): exec_results = await self._execute_tools(ctx) log_event(name="persist_messages") - msg_map = self._persist_tool_messages(exec_results, ctx) + msg_map = await self._persist_tool_messages(exec_results, ctx) log_event(name="mark_steps_done") self._mark_steps_complete(llm_batch_id, ctx.agent_ids) @@ -266,7 +266,7 @@ class LettaAgentBatch(BaseAgent): @trace_method async def _collect_resume_context(self, llm_batch_id: str) -> _ResumeContext: # NOTE: We only continue for items with successful results - batch_items = self.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_id, request_status=JobStatus.completed) + batch_items = await self.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_id, request_status=JobStatus.completed) agent_ids, agent_state_map = [], {} provider_results, name_map, args_map, cont_map = {}, {}, {}, {} @@ -386,7 +386,7 @@ class LettaAgentBatch(BaseAgent): return result - def _persist_tool_messages( + async def _persist_tool_messages( self, exec_results: Sequence[Tuple[str, Tuple[str, bool]]], ctx: _ResumeContext, @@ -408,7 +408,7 @@ class LettaAgentBatch(BaseAgent): ) msg_map[aid] = msgs # flatten & persist - self.message_manager.create_many_messages([m for msgs in msg_map.values() for m in msgs], actor=self.actor) + await self.message_manager.create_many_messages_async([m for msgs in msg_map.values() for m in msgs], actor=self.actor) return msg_map def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None: diff --git a/letta/client/client.py b/letta/client/client.py index 14fdc009..802ca451 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -1,3 +1,4 @@ +import asyncio import logging import sys import time @@ -3055,7 +3056,21 @@ class LocalClient(AbstractClient): Returns: tools (List[Tool]): List of tools """ - return self.server.tool_manager.list_tools(after=after, limit=limit, actor=self.user) + # Get the current event loop or create a new one if there isn't one + try: + loop = asyncio.get_event_loop() + if loop.is_running(): + # We're in an async context but can't await - use a new loop via run_coroutine_threadsafe + concurrent_future = asyncio.run_coroutine_threadsafe( + self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit), loop + ) + return concurrent_future.result() + else: + # We have a loop but it's not running - we can just run the coroutine + return loop.run_until_complete(self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit)) + except RuntimeError: + # No running event loop - create a new one with asyncio.run + return asyncio.run(self.server.tool_manager.list_tools_async(actor=self.user, after=after, limit=limit)) def get_tool(self, id: str) -> Optional[Tool]: """ diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py index a1227475..e0f51dd5 100644 --- a/letta/jobs/llm_batch_job_polling.py +++ b/letta/jobs/llm_batch_job_polling.py @@ -180,7 +180,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo try: # 1. Retrieve running batch jobs - batches = server.batch_manager.list_running_llm_batches() + batches = await server.batch_manager.list_running_llm_batches_async() metrics.total_batches = len(batches) # TODO: Expand to more providers @@ -220,7 +220,11 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo ) # launch them all at once - tasks = [_resume(server.batch_manager.get_llm_batch_job_by_id(bid)) for bid, *_ in completed] + async def get_and_resume(batch_id): + batch = await server.batch_manager.get_llm_batch_job_by_id_async(batch_id) + return await _resume(batch) + + tasks = [get_and_resume(bid) for bid, *_ in completed] new_batch_responses = await asyncio.gather(*tasks, return_exceptions=True) return new_batch_responses diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 2e5abe99..dda47c6c 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -631,6 +631,24 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): except (DBAPIError, IntegrityError) as e: self._handle_dbapi_error(e) + @handle_db_timeout + async def create_async(self, db_session: "AsyncSession", actor: Optional["User"] = None, no_commit: bool = False) -> "SqlalchemyBase": + """Async version of create function""" + logger.debug(f"Creating {self.__class__.__name__} with ID: {self.id} with actor={actor}") + + if actor: + self._set_created_and_updated_by_fields(actor.id) + try: + db_session.add(self) + if no_commit: + await db_session.flush() # no commit, just flush to get PK + else: + await db_session.commit() + await db_session.refresh(self) + return self + except (DBAPIError, IntegrityError) as e: + self._handle_dbapi_error(e) + @classmethod @handle_db_timeout def batch_create(cls, items: List["SqlalchemyBase"], db_session: "Session", actor: Optional["User"] = None) -> List["SqlalchemyBase"]: @@ -672,6 +690,51 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): except (DBAPIError, IntegrityError) as e: cls._handle_dbapi_error(e) + @classmethod + @handle_db_timeout + async def batch_create_async( + cls, items: List["SqlalchemyBase"], db_session: "AsyncSession", actor: Optional["User"] = None + ) -> List["SqlalchemyBase"]: + """ + Async version of batch_create method. + Create multiple records in a single transaction for better performance. + Args: + items: List of model instances to create + db_session: AsyncSession session + actor: Optional user performing the action + Returns: + List of created model instances + """ + logger.debug(f"Async batch creating {len(items)} {cls.__name__} items with actor={actor}") + if not items: + return [] + + # Set created/updated by fields if actor is provided + if actor: + for item in items: + item._set_created_and_updated_by_fields(actor.id) + + try: + async with db_session as session: + session.add_all(items) + await session.flush() # Flush to generate IDs but don't commit yet + + # Collect IDs to fetch the complete objects after commit + item_ids = [item.id for item in items] + + await session.commit() + + # Re-query the objects to get them with relationships loaded + query = select(cls).where(cls.id.in_(item_ids)) + if hasattr(cls, "created_at"): + query = query.order_by(cls.created_at) + + result = await session.execute(query) + return list(result.scalars()) + + except (DBAPIError, IntegrityError) as e: + cls._handle_dbapi_error(e) + @handle_db_timeout def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": logger.debug(f"Soft deleting {self.__class__.__name__} with ID: {self.id} with actor={actor}") diff --git a/letta/server/rest_api/routers/v1/tools.py b/letta/server/rest_api/routers/v1/tools.py index a1c7591b..8c9aeac0 100644 --- a/letta/server/rest_api/routers/v1/tools.py +++ b/letta/server/rest_api/routers/v1/tools.py @@ -76,7 +76,7 @@ def retrieve_tool( @router.get("/", response_model=List[Tool], operation_id="list_tools") -def list_tools( +async def list_tools( after: Optional[str] = None, limit: Optional[int] = 50, name: Optional[str] = None, @@ -89,9 +89,9 @@ def list_tools( try: actor = server.user_manager.get_user_or_default(user_id=actor_id) if name is not None: - tool = server.tool_manager.get_tool_by_name(tool_name=name, actor=actor) + tool = await server.tool_manager.get_tool_by_name_async(tool_name=name, actor=actor) return [tool] if tool else [] - return server.tool_manager.list_tools(actor=actor, after=after, limit=limit) + return await server.tool_manager.list_tools_async(actor=actor, after=after, limit=limit) except Exception as e: # Log or print the full exception here for debugging print(f"Error occurred: {e}") diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 9747ea60..54bb454c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -791,6 +791,11 @@ class AgentManager: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor) + @enforce_types + async def get_in_context_messages_async(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]: + message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids + return await self.message_manager.get_messages_by_ids_async(message_ids=message_ids, actor=actor) + @enforce_types def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage: message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids diff --git a/letta/services/job_manager.py b/letta/services/job_manager.py index 74576d4d..d92c817b 100644 --- a/letta/services/job_manager.py +++ b/letta/services/job_manager.py @@ -44,6 +44,19 @@ class JobManager: job.create(session, actor=actor) # Save job in the database return job.to_pydantic() + @enforce_types + async def create_job_async( + self, pydantic_job: Union[PydanticJob, PydanticRun, PydanticBatchJob], actor: PydanticUser + ) -> Union[PydanticJob, PydanticRun, PydanticBatchJob]: + """Create a new job based on the JobCreate schema.""" + async with db_registry.async_session() as session: + # Associate the job with the user + pydantic_job.user_id = actor.id + job_data = pydantic_job.model_dump(to_orm=True) + job = JobModel(**job_data) + await job.create_async(session, actor=actor) # Save job in the database + return job.to_pydantic() + @enforce_types def update_job_by_id(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob: """Update a job by its ID with the given JobUpdate object.""" diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index 7d7b4b54..052e2bbe 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -2,7 +2,7 @@ import datetime from typing import Any, Dict, List, Optional, Tuple from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse -from sqlalchemy import desc, func, tuple_ +from sqlalchemy import desc, func, select, tuple_ from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo from letta.log import get_logger @@ -26,7 +26,7 @@ class LLMBatchManager: """Manager for handling both LLMBatchJob and LLMBatchItem operations.""" @enforce_types - def create_llm_batch_job( + async def create_llm_batch_job_async( self, llm_provider: ProviderType, create_batch_response: BetaMessageBatch, @@ -35,7 +35,7 @@ class LLMBatchManager: status: JobStatus = JobStatus.created, ) -> PydanticLLMBatchJob: """Create a new LLM batch job.""" - with db_registry.session() as session: + async with db_registry.async_session() as session: batch = LLMBatchJob( status=status, llm_provider=llm_provider, @@ -43,14 +43,14 @@ class LLMBatchManager: organization_id=actor.organization_id, letta_batch_job_id=letta_batch_job_id, ) - batch.create(session, actor=actor) + await batch.create_async(session, actor=actor) return batch.to_pydantic() @enforce_types - def get_llm_batch_job_by_id(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob: + async def get_llm_batch_job_by_id_async(self, llm_batch_id: str, actor: Optional[PydanticUser] = None) -> PydanticLLMBatchJob: """Retrieve a single batch job by ID.""" - with db_registry.session() as session: - batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) + async with db_registry.async_session() as session: + batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor) return batch.to_pydantic() @enforce_types @@ -197,16 +197,16 @@ class LLMBatchManager: return [message.to_pydantic() for message in results] @enforce_types - def list_running_llm_batches(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]: + async def list_running_llm_batches_async(self, actor: Optional[PydanticUser] = None) -> List[PydanticLLMBatchJob]: """Return all running LLM batch jobs, optionally filtered by actor's organization.""" - with db_registry.session() as session: - query = session.query(LLMBatchJob).filter(LLMBatchJob.status == JobStatus.running) + async with db_registry.async_session() as session: + query = select(LLMBatchJob).where(LLMBatchJob.status == JobStatus.running) if actor is not None: - query = query.filter(LLMBatchJob.organization_id == actor.organization_id) + query = query.where(LLMBatchJob.organization_id == actor.organization_id) - results = query.all() - return [batch.to_pydantic() for batch in results] + results = await session.execute(query) + return [batch.to_pydantic() for batch in results.scalars().all()] @enforce_types def create_llm_batch_item( @@ -234,7 +234,9 @@ class LLMBatchManager: return item.to_pydantic() @enforce_types - def create_llm_batch_items_bulk(self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]: + async def create_llm_batch_items_bulk_async( + self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser + ) -> List[PydanticLLMBatchItem]: """ Create multiple batch items in bulk for better performance. @@ -245,7 +247,7 @@ class LLMBatchManager: Returns: List of created batch items as Pydantic models """ - with db_registry.session() as session: + async with db_registry.async_session() as session: # Convert Pydantic models to ORM objects orm_items = [] for item in llm_batch_items: @@ -261,8 +263,7 @@ class LLMBatchManager: ) orm_items.append(orm_item) - # Use the batch_create method to create all items at once - created_items = LLMBatchItem.batch_create(orm_items, session, actor=actor) + created_items = await LLMBatchItem.batch_create_async(orm_items, session, actor=actor) # Convert back to Pydantic models return [item.to_pydantic() for item in created_items] @@ -300,7 +301,7 @@ class LLMBatchManager: return item.update(db_session=session, actor=actor).to_pydantic() @enforce_types - def list_llm_batch_items( + async def list_llm_batch_items_async( self, llm_batch_id: str, limit: Optional[int] = None, @@ -321,29 +322,29 @@ class LLMBatchManager: The results are ordered by their id in ascending order. """ - with db_registry.session() as session: - query = session.query(LLMBatchItem).filter(LLMBatchItem.llm_batch_id == llm_batch_id) + async with db_registry.async_session() as session: + query = select(LLMBatchItem).where(LLMBatchItem.llm_batch_id == llm_batch_id) if actor is not None: - query = query.filter(LLMBatchItem.organization_id == actor.organization_id) + query = query.where(LLMBatchItem.organization_id == actor.organization_id) # Additional optional filters if agent_id is not None: - query = query.filter(LLMBatchItem.agent_id == agent_id) + query = query.where(LLMBatchItem.agent_id == agent_id) if request_status is not None: - query = query.filter(LLMBatchItem.request_status == request_status) + query = query.where(LLMBatchItem.request_status == request_status) if step_status is not None: - query = query.filter(LLMBatchItem.step_status == step_status) + query = query.where(LLMBatchItem.step_status == step_status) if after is not None: - query = query.filter(LLMBatchItem.id > after) + query = query.where(LLMBatchItem.id > after) query = query.order_by(LLMBatchItem.id.asc()) if limit is not None: query = query.limit(limit) - results = query.all() - return [item.to_pydantic() for item in results] + results = await session.execute(query) + return [item.to_pydantic() for item in results.scalars()] def bulk_update_llm_batch_items( self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 0dda1cfe..95666fec 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -71,6 +71,16 @@ class MessageManager: msg.create(session, actor=actor) # Persist to database return msg.to_pydantic() + def _create_many_preprocess(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[MessageModel]: + # Create ORM model instances for all messages + orm_messages = [] + for pydantic_msg in pydantic_msgs: + # Set the organization id of the Pydantic message + pydantic_msg.organization_id = actor.organization_id + msg_data = pydantic_msg.model_dump(to_orm=True) + orm_messages.append(MessageModel(**msg_data)) + return orm_messages + @enforce_types def create_many_messages(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]: """ @@ -83,23 +93,32 @@ class MessageManager: Returns: List of created Pydantic message models """ - if not pydantic_msgs: return [] - # Create ORM model instances for all messages - orm_messages = [] - for pydantic_msg in pydantic_msgs: - # Set the organization id of the Pydantic message - pydantic_msg.organization_id = actor.organization_id - msg_data = pydantic_msg.model_dump(to_orm=True) - orm_messages.append(MessageModel(**msg_data)) - - # Use the batch_create method for efficient creation + orm_messages = self._create_many_preprocess(pydantic_msgs, actor) with db_registry.session() as session: created_messages = MessageModel.batch_create(orm_messages, session, actor=actor) + return [msg.to_pydantic() for msg in created_messages] - # Convert back to Pydantic models + @enforce_types + async def create_many_messages_async(self, pydantic_msgs: List[PydanticMessage], actor: PydanticUser) -> List[PydanticMessage]: + """ + Create multiple messages in a single database transaction asynchronously. + + Args: + pydantic_msgs: List of Pydantic message models to create + actor: User performing the action + + Returns: + List of created Pydantic message models + """ + if not pydantic_msgs: + return [] + + orm_messages = self._create_many_preprocess(pydantic_msgs, actor) + async with db_registry.async_session() as session: + created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor) return [msg.to_pydantic() for msg in created_messages] @enforce_types diff --git a/letta/services/tool_manager.py b/letta/services/tool_manager.py index 5b0cff89..eebff5ea 100644 --- a/letta/services/tool_manager.py +++ b/letta/services/tool_manager.py @@ -115,6 +115,16 @@ class ToolManager: except NoResultFound: return None + @enforce_types + async def get_tool_by_name_async(self, tool_name: str, actor: PydanticUser) -> Optional[PydanticTool]: + """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" + try: + async with db_registry.async_session() as session: + tool = await ToolModel.read_async(db_session=session, name=tool_name, actor=actor) + return tool.to_pydantic() + except NoResultFound: + return None + @enforce_types def get_tool_id_by_name(self, tool_name: str, actor: PydanticUser) -> Optional[str]: """Retrieve a tool by its name and a user. We derive the organization from the user, and retrieve that tool.""" @@ -126,10 +136,10 @@ class ToolManager: return None @enforce_types - def list_tools(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: + async def list_tools_async(self, actor: PydanticUser, after: Optional[str] = None, limit: Optional[int] = 50) -> List[PydanticTool]: """List all tools with optional pagination.""" - with db_registry.session() as session: - tools = ToolModel.list( + async with db_registry.async_session() as session: + tools = await ToolModel.list_async( db_session=session, after=after, limit=limit, diff --git a/poetry.lock b/poetry.lock index c6402e43..6d001a9c 100644 --- a/poetry.lock +++ b/poetry.lock @@ -5533,19 +5533,19 @@ dev = ["argcomplete", "attrs (>=19.2)", "hypothesis (>=3.56)", "mock", "pygments [[package]] name = "pytest-asyncio" -version = "0.23.8" +version = "0.24.0" description = "Pytest support for asyncio" optional = true python-versions = ">=3.8" groups = ["main"] markers = "extra == \"dev\" or extra == \"all\"" files = [ - {file = "pytest_asyncio-0.23.8-py3-none-any.whl", hash = "sha256:50265d892689a5faefb84df80819d1ecef566eb3549cf915dfb33569359d1ce2"}, - {file = "pytest_asyncio-0.23.8.tar.gz", hash = "sha256:759b10b33a6dc61cce40a8bd5205e302978bbbcc00e279a8b61d9a6a3c82e4d3"}, + {file = "pytest_asyncio-0.24.0-py3-none-any.whl", hash = "sha256:a811296ed596b69bf0b6f3dc40f83bcaf341b155a269052d82efa2b25ac7037b"}, + {file = "pytest_asyncio-0.24.0.tar.gz", hash = "sha256:d081d828e576d85f875399194281e92bf8a68d60d72d1a2faf2feddb6c46b276"}, ] [package.dependencies] -pytest = ">=7.0.0,<9" +pytest = ">=8.2,<9" [package.extras] docs = ["sphinx (>=5.3)", "sphinx-rtd-theme (>=1.0)"] @@ -7570,4 +7570,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.1" python-versions = "<3.14,>=3.10" -content-hash = "7322eb70314d3f078fda8cf9b580d416f13663b399d4b0f6a9e0c4a8914808b5" +content-hash = "19eee9b3cd3d270cb748183bc332dd69706bb0bd3150c62e73e61ed437a40c78" diff --git a/pyproject.toml b/pyproject.toml index 77884854..0abde6b5 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -47,7 +47,7 @@ autoflake = {version = "^2.3.0", optional = true} python-multipart = "^0.0.19" sqlalchemy-utils = "^0.41.2" pytest-order = {version = "^1.2.0", optional = true} -pytest-asyncio = {version = "^0.23.2", optional = true} +pytest-asyncio = {version = "^0.24.0", optional = true} pydantic-settings = "^2.2.1" httpx-sse = "^0.4.0" isort = { version = "^5.13.2", optional = true } diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index 39306568..4479b0dd 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ b/tests/integration_test_batch_api_cron_jobs.py @@ -174,16 +174,16 @@ def create_test_agent(name, actor, test_id: Optional[str] = None, model="anthrop return agent_manager.create_agent(agent_create=agent_create, actor=actor, _test_only_force_id=test_id) -def create_test_letta_batch_job(server, default_user): +async def create_test_letta_batch_job_async(server, default_user): """Create a test batch job with the given batch response.""" - return server.job_manager.create_job(BatchJob(user_id=default_user.id), actor=default_user) + return await server.job_manager.create_job_async(BatchJob(user_id=default_user.id), actor=default_user) -def create_test_llm_batch_job(server, batch_response, default_user): +async def create_test_llm_batch_job_async(server, batch_response, default_user): """Create a test batch job with the given batch response.""" - letta_batch_job = create_test_letta_batch_job(server, default_user) + letta_batch_job = await create_test_letta_batch_job_async(server, default_user) - return server.batch_manager.create_llm_batch_job( + return await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=batch_response, actor=default_user, @@ -262,7 +262,7 @@ def mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b_id, agent_ # ----------------------------- # End-to-End Test # ----------------------------- -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_polling_simple_real_batch(client, default_user, server): # --- Step 1: Prepare test data --- # Create batch responses with different statuses @@ -276,7 +276,7 @@ async def test_polling_simple_real_batch(client, default_user, server): agent_c = create_test_agent("agent_c", default_user, test_id="agent-6156f470-a09d-4d51-aa62-7114e0971d56") # --- Step 2: Create batch jobs --- - job_a = create_test_llm_batch_job(server, batch_a_resp, default_user) + job_a = await create_test_llm_batch_job_async(server, batch_a_resp, default_user) # --- Step 3: Create batch items --- item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) @@ -293,7 +293,7 @@ async def test_polling_simple_real_batch(client, default_user, server): await poll_running_llm_batches(server) # --- Step 5: Verify batch job status updates --- - updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user) + updated_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user) assert updated_job_a.status == JobStatus.completed @@ -403,7 +403,7 @@ async def test_polling_simple_real_batch(client, default_user, server): ) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_polling_mixed_batch_jobs(client, default_user, server): """ End-to-end test for polling batch jobs with mixed statuses and idempotency. @@ -433,8 +433,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): agent_c = create_test_agent("agent_c", default_user) # --- Step 2: Create batch jobs --- - job_a = create_test_llm_batch_job(server, batch_a_resp, default_user) - job_b = create_test_llm_batch_job(server, batch_b_resp, default_user) + job_a = await create_test_llm_batch_job_async(server, batch_a_resp, default_user) + job_b = await create_test_llm_batch_job_async(server, batch_b_resp, default_user) # --- Step 3: Create batch items --- item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) @@ -449,8 +449,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): await poll_running_llm_batches(server) # --- Step 6: Verify batch job status updates --- - updated_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user) - updated_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user) + updated_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user) + updated_job_b = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_b.id, actor=default_user) # Job A should remain running since its processing_status is "in_progress" assert updated_job_a.status == JobStatus.running @@ -498,8 +498,8 @@ async def test_polling_mixed_batch_jobs(client, default_user, server): # --- Step 9: Verify that nothing changed for completed jobs --- # Refresh all objects - final_job_a = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_a.id, actor=default_user) - final_job_b = server.batch_manager.get_llm_batch_job_by_id(llm_batch_id=job_b.id, actor=default_user) + final_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user) + final_job_b = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_b.id, actor=default_user) final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) diff --git a/tests/integration_test_voice_agent.py b/tests/integration_test_voice_agent.py index 6ace4640..65bc5494 100644 --- a/tests/integration_test_voice_agent.py +++ b/tests/integration_test_voice_agent.py @@ -268,7 +268,7 @@ def _assert_valid_chunk(chunk, idx, chunks): # --- Tests --- # -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("model", ["openai/gpt-4o-mini", "anthropic/claude-3-5-sonnet-20241022"]) async def test_model_compatibility(disable_e2b_api_key, client, model, server, group_id, actor): request = _get_chat_request("How are you?") @@ -303,7 +303,7 @@ async def test_model_compatibility(disable_e2b_api_key, client, model, server, g print(chunk.choices[0].delta.content) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("message", ["Use search memory tool to recall what my name is."]) @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, message, endpoint): @@ -318,7 +318,7 @@ async def test_voice_recall_memory(disable_e2b_api_key, client, voice_agent, mes print(chunk.choices[0].delta.content) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") @pytest.mark.parametrize("endpoint", ["v1/voice-beta"]) async def test_trigger_summarization(disable_e2b_api_key, client, server, voice_agent, group_id, endpoint, actor): server.group_manager.modify_group( @@ -350,7 +350,7 @@ async def test_trigger_summarization(disable_e2b_api_key, client, server, voice_ print(chunk.choices[0].delta.content) -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_summarization(disable_e2b_api_key, voice_agent): agent_manager = AgentManager() user_manager = UserManager() @@ -422,16 +422,17 @@ async def test_summarization(disable_e2b_api_key, voice_agent): summarizer.fire_and_forget.assert_called_once() -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_voice_sleeptime_agent(disable_e2b_api_key, client, voice_agent): """Tests chat completion streaming using the Async OpenAI client.""" agent_manager = AgentManager() + tool_manager = ToolManager() user_manager = UserManager() actor = user_manager.get_default_user() - finish_rethinking_memory_tool = client.tools.list(name="finish_rethinking_memory")[0] - store_memories_tool = client.tools.list(name="store_memories")[0] - rethink_user_memory_tool = client.tools.list(name="rethink_user_memory")[0] + finish_rethinking_memory_tool = tool_manager.get_tool_by_name(tool_name="finish_rethinking_memory", actor=actor) + store_memories_tool = tool_manager.get_tool_by_name(tool_name="store_memories", actor=actor) + rethink_user_memory_tool = tool_manager.get_tool_by_name(tool_name="rethink_user_memory", actor=actor) request = CreateAgent( name=voice_agent.name + "-sleeptime", agent_type=AgentType.voice_sleeptime_agent, @@ -487,7 +488,7 @@ async def test_voice_sleeptime_agent(disable_e2b_api_key, client, voice_agent): assert not missing, f"Did not see calls to: {', '.join(missing)}" -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_init_voice_convo_agent(voice_agent, server, actor): assert voice_agent.enable_sleeptime == True diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index b043b216..11da3a19 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -1,4 +1,3 @@ -import asyncio import os import threading from datetime import datetime, timezone @@ -165,14 +164,6 @@ def step_state_map(agents): return {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents} -@pytest.fixture(scope="session") -def event_loop(request): - """Create an instance of the default event loop for each test case.""" - loop = asyncio.get_event_loop_policy().new_event_loop() - yield loop - loop.close() - - def create_batch_response(batch_id: str, processing_status: str = "in_progress") -> BetaMessageBatch: """Create a dummy BetaMessageBatch with the specified ID and status.""" now = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc) @@ -377,7 +368,7 @@ class MockAsyncIterable: # --------------------------------------------------------------------------- # -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, server, default_user, batch_job, rethink_tool): target_block_label = "human" new_memory = "banana" @@ -459,9 +450,9 @@ async def test_rethink_tool_modify_agent_state(client, disable_e2b_api_key, serv assert block.value == new_memory -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_partial_error_from_anthropic_batch( - disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): anthropic_batch_id = "msgbatch_test_12345" dummy_batch_response = create_batch_response( @@ -518,13 +509,13 @@ async def test_partial_error_from_anthropic_batch( new_batch_responses = await poll_running_llm_batches(server) # Verify database records were updated correctly - llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user) + llm_batch_job = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_job.id, actor=default_user) # Verify job properties assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'" # Verify batch items - items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" # Verify only one new batch response @@ -542,7 +533,7 @@ async def test_partial_error_from_anthropic_batch( assert post_resume_response.agent_count == 2 # New batch‑items should exist, initialised in (created, paused) state - new_items = server.batch_manager.list_llm_batch_items( + new_items = await server.batch_manager.list_llm_batch_items_async( llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user ) assert len(new_items) == 2, f"Expected 2 new batch item, got {len(new_items)}" @@ -563,7 +554,7 @@ async def test_partial_error_from_anthropic_batch( # Old items must have been flipped to completed / finished earlier # (sanity – we already asserted this above, but we keep it close for clarity) - old_items = server.batch_manager.list_llm_batch_items( + old_items = await server.batch_manager.list_llm_batch_items_async( llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user ) for item in old_items: @@ -619,9 +610,9 @@ async def test_partial_error_from_anthropic_batch( assert agent_messages[0].role == MessageRole.user, "Expected initial user message" -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_resume_step_some_stop( - disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): anthropic_batch_id = "msgbatch_test_12345" dummy_batch_response = create_batch_response( @@ -680,13 +671,13 @@ async def test_resume_step_some_stop( new_batch_responses = await poll_running_llm_batches(server) # Verify database records were updated correctly - llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user) + llm_batch_job = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_job.id, actor=default_user) # Verify job properties assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'" # Verify batch items - items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" assert all([item.request_status == JobStatus.completed for item in items]) @@ -705,7 +696,7 @@ async def test_resume_step_some_stop( assert post_resume_response.agent_count == 1 # New batch‑items should exist, initialised in (created, paused) state - new_items = server.batch_manager.list_llm_batch_items( + new_items = await server.batch_manager.list_llm_batch_items_async( llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user ) assert len(new_items) == 1, f"Expected 1 new batch item, got {len(new_items)}" @@ -726,7 +717,7 @@ async def test_resume_step_some_stop( # Old items must have been flipped to completed / finished earlier # (sanity – we already asserted this above, but we keep it close for clarity) - old_items = server.batch_manager.list_llm_batch_items( + old_items = await server.batch_manager.list_llm_batch_items_async( llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user ) assert {i.request_status for i in old_items} == {JobStatus.completed} @@ -782,9 +773,9 @@ def _assert_descending_order(messages): return True -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_resume_step_after_request_all_continue( - disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job, event_loop + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map, batch_job ): anthropic_batch_id = "msgbatch_test_12345" dummy_batch_response = create_batch_response( @@ -818,7 +809,7 @@ async def test_resume_step_after_request_all_continue( assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" llm_batch_job = llm_batch_jobs[0] - llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + llm_batch_items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}" # 2. Invoke the polling job and mock responses from Anthropic @@ -840,13 +831,13 @@ async def test_resume_step_after_request_all_continue( new_batch_responses = await poll_running_llm_batches(server) # Verify database records were updated correctly - llm_batch_job = server.batch_manager.get_llm_batch_job_by_id(llm_batch_job.id, actor=default_user) + llm_batch_job = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_job.id, actor=default_user) # Verify job properties assert llm_batch_job.status == JobStatus.completed, "Job status should be 'completed'" # Verify batch items - items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" assert all([item.request_status == JobStatus.completed for item in items]) @@ -864,7 +855,7 @@ async def test_resume_step_after_request_all_continue( assert post_resume_response.agent_count == 3 # New batch‑items should exist, initialised in (created, paused) state - new_items = server.batch_manager.list_llm_batch_items( + new_items = await server.batch_manager.list_llm_batch_items_async( llm_batch_id=post_resume_response.last_llm_batch_id, actor=default_user ) assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}" @@ -883,7 +874,7 @@ async def test_resume_step_after_request_all_continue( # Old items must have been flipped to completed / finished earlier # (sanity – we already asserted this above, but we keep it close for clarity) - old_items = server.batch_manager.list_llm_batch_items( + old_items = await server.batch_manager.list_llm_batch_items_async( llm_batch_id=pre_resume_response.last_llm_batch_id, actor=default_user ) assert {i.request_status for i in old_items} == {JobStatus.completed} @@ -920,9 +911,9 @@ async def test_resume_step_after_request_all_continue( assert agent_messages[-4].role == MessageRole.user, "Expected final system-level heartbeat user message" -@pytest.mark.asyncio +@pytest.mark.asyncio(loop_scope="session") async def test_step_until_request_prepares_and_submits_batch_correctly( - disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job, event_loop + disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response, batch_job ): """ Test that step_until_request correctly: @@ -1013,7 +1004,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly( assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" llm_batch_job = llm_batch_jobs[0] - llm_batch_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + llm_batch_items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(llm_batch_items) == 3, f"Expected 3 llm_batch_items, got {len(llm_batch_items)}" # Verify job properties diff --git a/tests/test_managers.py b/tests/test_managers.py index afb5353c..2a05ba4c 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2218,9 +2218,10 @@ def test_get_tool_with_actor(server: SyncServer, print_tool, default_user): assert fetched_tool.tool_type == ToolType.CUSTOM -def test_list_tools(server: SyncServer, print_tool, default_user): +@pytest.mark.asyncio +async def test_list_tools(server: SyncServer, print_tool, default_user, event_loop): # List tools (should include the one created by the fixture) - tools = server.tool_manager.list_tools(actor=default_user) + tools = await server.tool_manager.list_tools_async(actor=default_user) # Assertions to check that the created tool is listed assert len(tools) == 1 @@ -2344,11 +2345,12 @@ def test_update_tool_multi_user(server: SyncServer, print_tool, default_user, ot assert updated_tool.created_by_id == default_user.id -def test_delete_tool_by_id(server: SyncServer, print_tool, default_user): +@pytest.mark.asyncio +async def test_delete_tool_by_id(server: SyncServer, print_tool, default_user, event_loop): # Delete the print_tool using the manager method server.tool_manager.delete_tool_by_id(print_tool.id, actor=default_user) - tools = server.tool_manager.list_tools(actor=default_user) + tools = await server.tool_manager.list_tools_async(actor=default_user) assert len(tools) == 0 @@ -4997,8 +4999,9 @@ def test_list_tags(server: SyncServer, default_user, default_organization): # ====================================================================================================================== -def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch, letta_batch_job): - batch = server.batch_manager.create_llm_batch_job( +@pytest.mark.asyncio +async def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop): + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, @@ -5007,12 +5010,13 @@ def test_create_and_get_batch_request(server, default_user, dummy_beta_message_b ) assert batch.id.startswith("batch_req-") assert batch.create_batch_response == dummy_beta_message_batch - fetched = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user) + fetched = await server.batch_manager.get_llm_batch_job_by_id_async(batch.id, actor=default_user) assert fetched.id == batch.id -def test_update_batch_status(server, default_user, dummy_beta_message_batch, letta_batch_job): - batch = server.batch_manager.create_llm_batch_job( +@pytest.mark.asyncio +async def test_update_batch_status(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop): + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, @@ -5028,16 +5032,17 @@ def test_update_batch_status(server, default_user, dummy_beta_message_batch, let actor=default_user, ) - updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_job_by_id_async(batch.id, actor=default_user) assert updated.status == JobStatus.completed assert updated.latest_polling_response == dummy_beta_message_batch assert updated.last_polled_at >= before -def test_create_and_get_batch_item( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_create_and_get_batch_item( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, @@ -5061,7 +5066,8 @@ def test_create_and_get_batch_item( assert fetched.id == item.id -def test_update_batch_item( +@pytest.mark.asyncio +async def test_update_batch_item( server, default_user, sarah_agent, @@ -5070,8 +5076,9 @@ def test_update_batch_item( dummy_step_state, dummy_successful_response, letta_batch_job, + event_loop, ): - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, @@ -5103,10 +5110,11 @@ def test_update_batch_item( assert updated.batch_request_result == dummy_successful_response -def test_delete_batch_item( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_delete_batch_item( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, @@ -5128,8 +5136,9 @@ def test_delete_batch_item( server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) -def test_list_running_batches(server, default_user, dummy_beta_message_batch, letta_batch_job): - server.batch_manager.create_llm_batch_job( +@pytest.mark.asyncio +async def test_list_running_batches(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop): + await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.running, create_batch_response=dummy_beta_message_batch, @@ -5137,13 +5146,14 @@ def test_list_running_batches(server, default_user, dummy_beta_message_batch, le letta_batch_job_id=letta_batch_job.id, ) - running_batches = server.batch_manager.list_running_llm_batches(actor=default_user) + running_batches = await server.batch_manager.list_running_llm_batches_async(actor=default_user) assert len(running_batches) >= 1 assert all(batch.status == JobStatus.running for batch in running_batches) -def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch, letta_batch_job): - batch = server.batch_manager.create_llm_batch_job( +@pytest.mark.asyncio +async def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch, letta_batch_job, event_loop): + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch, @@ -5153,12 +5163,13 @@ def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_bat server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) - updated = server.batch_manager.get_llm_batch_job_by_id(batch.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_job_by_id_async(batch.id, actor=default_user) assert updated.status == JobStatus.completed assert updated.latest_polling_response == dummy_beta_message_batch -def test_bulk_update_batch_items_results_by_agent( +@pytest.mark.asyncio +async def test_bulk_update_batch_items_results_by_agent( server, default_user, sarah_agent, @@ -5167,8 +5178,9 @@ def test_bulk_update_batch_items_results_by_agent( dummy_step_state, dummy_successful_response, letta_batch_job, + event_loop, ): - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5191,10 +5203,11 @@ def test_bulk_update_batch_items_results_by_agent( assert updated.batch_request_result == dummy_successful_response -def test_bulk_update_batch_items_step_status_by_agent( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_bulk_update_batch_items_step_status_by_agent( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5216,10 +5229,11 @@ def test_bulk_update_batch_items_step_status_by_agent( assert updated.step_status == AgentStepStatus.resumed -def test_list_batch_items_limit_and_filter( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_list_batch_items_limit_and_filter( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5235,18 +5249,19 @@ def test_list_batch_items_limit_and_filter( actor=default_user, ) - all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user) - limited_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, limit=2, actor=default_user) + all_items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=batch.id, actor=default_user) + limited_items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=batch.id, limit=2, actor=default_user) assert len(all_items) >= 3 assert len(limited_items) == 2 -def test_list_batch_items_pagination( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_list_batch_items_pagination( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): # Create a batch job. - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5266,7 +5281,7 @@ def test_list_batch_items_pagination( created_items.append(item) # Retrieve all items (without pagination). - all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user) + all_items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=batch.id, actor=default_user) assert len(all_items) >= 10, f"Expected at least 10 items, got {len(all_items)}" # Verify the items are ordered ascending by id (based on our implementation). @@ -5278,7 +5293,7 @@ def test_list_batch_items_pagination( cursor = all_items[4].id # Retrieve items after the cursor. - paged_items = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor) + paged_items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=batch.id, actor=default_user, after=cursor) # All returned items should have an id greater than the cursor. for item in paged_items: @@ -5292,7 +5307,9 @@ def test_list_batch_items_pagination( # Test pagination with a limit. limit = 3 - limited_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=cursor, limit=limit) + limited_page = await server.batch_manager.list_llm_batch_items_async( + llm_batch_id=batch.id, actor=default_user, after=cursor, limit=limit + ) # If more than 'limit' items remain, we should only get exactly 'limit' items. assert len(limited_page) == min( limit, expected_remaining @@ -5300,15 +5317,16 @@ def test_list_batch_items_pagination( # Optional: Test with a cursor beyond the last item returns an empty list. last_cursor = sorted_ids[-1] - empty_page = server.batch_manager.list_llm_batch_items(llm_batch_id=batch.id, actor=default_user, after=last_cursor) + empty_page = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=batch.id, actor=default_user, after=last_cursor) assert empty_page == [], "Expected an empty list when cursor is after the last item" -def test_bulk_update_batch_items_request_status_by_agent( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_bulk_update_batch_items_request_status_by_agent( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): # Create a batch job - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5334,15 +5352,17 @@ def test_bulk_update_batch_items_request_status_by_agent( assert updated.request_status == JobStatus.expired -def test_bulk_update_nonexistent_items_should_error( +@pytest.mark.asyncio +async def test_bulk_update_nonexistent_items_should_error( server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job, + event_loop, ): # Create a batch job - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5375,9 +5395,12 @@ def test_bulk_update_nonexistent_items_should_error( ) -def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job): +@pytest.mark.asyncio +async def test_bulk_update_nonexistent_items( + server, default_user, dummy_beta_message_batch, dummy_successful_response, letta_batch_job, event_loop +): # Create a batch job - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5410,11 +5433,12 @@ def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_ ) -def test_create_batch_items_bulk( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_create_batch_items_bulk( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): # Create a batch job - llm_batch_job = server.batch_manager.create_llm_batch_job( + llm_batch_job = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, create_batch_response=dummy_beta_message_batch, actor=default_user, @@ -5437,7 +5461,7 @@ def test_create_batch_items_bulk( batch_items.append(batch_item) # Call the bulk create function - created_items = server.batch_manager.create_llm_batch_items_bulk(batch_items, actor=default_user) + created_items = await server.batch_manager.create_llm_batch_items_bulk_async(batch_items, actor=default_user) # Verify the correct number of items were created assert len(created_items) == len(agent_ids) @@ -5453,7 +5477,7 @@ def test_create_batch_items_bulk( assert item.step_state == dummy_step_state # Verify items can be retrieved from the database - all_items = server.batch_manager.list_llm_batch_items(llm_batch_id=llm_batch_job.id, actor=default_user) + all_items = await server.batch_manager.list_llm_batch_items_async(llm_batch_id=llm_batch_job.id, actor=default_user) assert len(all_items) >= len(agent_ids) # Verify the IDs of created items match what's in the database @@ -5463,11 +5487,12 @@ def test_create_batch_items_bulk( assert fetched.id in created_ids -def test_count_batch_items( - server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job +@pytest.mark.asyncio +async def test_count_batch_items( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, letta_batch_job, event_loop ): # Create a batch job first. - batch = server.batch_manager.create_llm_batch_job( + batch = await server.batch_manager.create_llm_batch_job_async( llm_provider=ProviderType.anthropic, status=JobStatus.created, create_batch_response=dummy_beta_message_batch,