diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index e2355ab5..10a50a58 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -233,7 +233,7 @@ class LettaAgentBatch(BaseAgent): ctx = await self._collect_resume_context(llm_batch_id) log_event(name="update_statuses") - self._update_request_statuses(ctx.request_status_updates) + await self._update_request_statuses_async(ctx.request_status_updates) log_event(name="exec_tools") exec_results = await self._execute_tools(ctx) @@ -242,7 +242,7 @@ class LettaAgentBatch(BaseAgent): msg_map = await self._persist_tool_messages(exec_results, ctx) log_event(name="mark_steps_done") - self._mark_steps_complete(llm_batch_id, ctx.agent_ids) + await self._mark_steps_complete_async(llm_batch_id, ctx.agent_ids) log_event(name="prepare_next") next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map) @@ -382,9 +382,9 @@ class LettaAgentBatch(BaseAgent): return self._extract_tool_call_and_decide_continue(tool_call, item.step_state) - def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None: + async def _update_request_statuses_async(self, updates: List[RequestStatusUpdateInfo]) -> None: if updates: - self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent(updates=updates) + await self.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async(updates=updates) def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]: sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL @@ -474,11 +474,11 @@ class LettaAgentBatch(BaseAgent): await self.message_manager.create_many_messages_async([m for msgs in msg_map.values() for m in msgs], actor=self.actor) return msg_map - def _mark_steps_complete(self, llm_batch_id: str, agent_ids: List[str]) -> None: + async def _mark_steps_complete_async(self, llm_batch_id: str, agent_ids: List[str]) -> None: updates = [ StepStatusUpdateInfo(llm_batch_id=llm_batch_id, agent_id=aid, step_status=AgentStepStatus.completed) for aid in agent_ids ] - self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent(updates) + await self.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async(updates) def _prepare_next_iteration( self, diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py index e0f51dd5..401860e8 100644 --- a/letta/jobs/llm_batch_job_polling.py +++ b/letta/jobs/llm_batch_job_polling.py @@ -106,7 +106,7 @@ async def poll_batch_updates(server: SyncServer, batch_jobs: List[LLMBatchJob], results: List[BatchPollingResult] = await asyncio.gather(*coros) # Update the server with batch status changes - server.batch_manager.bulk_update_llm_batch_statuses(updates=results) + await server.batch_manager.bulk_update_llm_batch_statuses_async(updates=results) logger.info(f"[Poll BatchJob] Bulk-updated {len(results)} LLM batch(es) in the DB at job level.") return results @@ -197,13 +197,13 @@ async def poll_running_llm_batches(server: "SyncServer") -> List[LettaBatchRespo # 6. Bulk update all items for newly completed batch(es) if item_updates: metrics.updated_items_count = len(item_updates) - server.batch_manager.bulk_update_batch_llm_items_results_by_agent(item_updates) + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async(item_updates) # ─── Kick off post‑processing for each batch that just completed ─── completed = [r for r in batch_results if r.request_status == JobStatus.completed] async def _resume(batch_row: LLMBatchJob) -> LettaBatchResponse: - actor: User = server.user_manager.get_user_by_id(batch_row.created_by_id) + actor: User = await server.user_manager.get_actor_by_id_async(batch_row.created_by_id) runner = LettaAgentBatch( message_manager=server.message_manager, agent_manager=server.agent_manager, diff --git a/letta/jobs/scheduler.py b/letta/jobs/scheduler.py index 6e7dad00..7a6b105f 100644 --- a/letta/jobs/scheduler.py +++ b/letta/jobs/scheduler.py @@ -7,7 +7,7 @@ from apscheduler.triggers.interval import IntervalTrigger from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.log import get_logger -from letta.server.db import db_context +from letta.server.db import db_registry from letta.server.server import SyncServer from letta.settings import settings @@ -34,18 +34,15 @@ async def _try_acquire_lock_and_start_scheduler(server: SyncServer) -> bool: acquired_lock = False try: # Use a temporary connection context for the attempt initially - with db_context() as session: - engine = session.get_bind() - # Get raw connection - MUST be kept open if lock is acquired - raw_conn = engine.raw_connection() - cur = raw_conn.cursor() + async with db_registry.async_session() as session: + raw_conn = await session.connection() - cur.execute("SELECT pg_try_advisory_lock(CAST(%s AS bigint))", (ADVISORY_LOCK_KEY,)) - acquired_lock = cur.fetchone()[0] + # Try to acquire the advisory lock + result = await session.execute(f"SELECT pg_try_advisory_lock(CAST({ADVISORY_LOCK_KEY} AS bigint))") + acquired_lock = result.scalar_one() if not acquired_lock: - cur.close() - raw_conn.close() + await raw_conn.close() logger.info("Scheduler lock held by another instance.") return False diff --git a/letta/server/rest_api/routers/v1/messages.py b/letta/server/rest_api/routers/v1/messages.py index 4d7d3588..e156d05d 100644 --- a/letta/server/rest_api/routers/v1/messages.py +++ b/letta/server/rest_api/routers/v1/messages.py @@ -161,7 +161,7 @@ async def list_batch_messages( # Get messages directly using our efficient method # We'll need to update the underlying implementation to use message_id as cursor - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=batch_id, limit=limit, actor=actor, agent_id=agent_id, sort_descending=sort_descending, cursor=cursor ) @@ -184,7 +184,7 @@ async def cancel_batch_run( job = await server.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor) # Get related llm batch jobs - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=job.id, actor=actor) for llm_batch_job in llm_batch_jobs: if llm_batch_job.status in {JobStatus.running, JobStatus.created}: # TODO: Extend to providers beyond anthropic @@ -194,6 +194,8 @@ async def cancel_batch_run( await server.anthropic_async_client.messages.batches.cancel(anthropic_batch_id) # Update all the batch_job statuses - server.batch_manager.update_llm_batch_status(llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor) + await server.batch_manager.update_llm_batch_status_async( + llm_batch_id=llm_batch_job.id, status=JobStatus.cancelled, actor=actor + ) except NoResultFound: raise HTTPException(status_code=404, detail="Run not found") diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index 7ac6d9de..c296a64e 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -58,7 +58,7 @@ class LLMBatchManager: @enforce_types @trace_method - def update_llm_batch_status( + async def update_llm_batch_status_async( self, llm_batch_id: str, status: JobStatus, @@ -66,15 +66,15 @@ class LLMBatchManager: latest_polling_response: Optional[BetaMessageBatch] = None, ) -> PydanticLLMBatchJob: """Update a batch job’s status and optionally its polling response.""" - with db_registry.session() as session: - batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) + async with db_registry.async_session() as session: + batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor) batch.status = status batch.latest_polling_response = latest_polling_response batch.last_polled_at = datetime.datetime.now(datetime.timezone.utc) - batch = batch.update(db_session=session, actor=actor) + batch = await batch.update_async(db_session=session, actor=actor) return batch.to_pydantic() - def bulk_update_llm_batch_statuses( + async def bulk_update_llm_batch_statuses_async( self, updates: List[BatchPollingResult], ) -> None: @@ -85,7 +85,7 @@ class LLMBatchManager: """ now = datetime.datetime.now(datetime.timezone.utc) - with db_registry.session() as session: + async with db_registry.async_session() as session: mappings = [] for llm_batch_id, status, response in updates: mappings.append( @@ -97,18 +97,18 @@ class LLMBatchManager: } ) - session.bulk_update_mappings(LLMBatchJob, mappings) - session.commit() + await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchJob, mappings)) + await session.commit() @enforce_types @trace_method - def list_llm_batch_jobs( + async def list_llm_batch_jobs_async( self, letta_batch_id: str, limit: Optional[int] = None, actor: Optional[PydanticUser] = None, after: Optional[str] = None, - ) -> List[PydanticLLMBatchItem]: + ) -> List[PydanticLLMBatchJob]: """ List all batch items for a given llm_batch_id, optionally filtered by additional criteria and limited in count. @@ -120,35 +120,35 @@ class LLMBatchManager: The results are ordered by their id in ascending order. """ - with db_registry.session() as session: - query = session.query(LLMBatchJob).filter(LLMBatchJob.letta_batch_job_id == letta_batch_id) + async with db_registry.async_session() as session: + query = select(LLMBatchJob).where(LLMBatchJob.letta_batch_job_id == letta_batch_id) if actor is not None: - query = query.filter(LLMBatchJob.organization_id == actor.organization_id) + query = query.where(LLMBatchJob.organization_id == actor.organization_id) # Additional optional filters if after is not None: - query = query.filter(LLMBatchJob.id > after) + query = query.where(LLMBatchJob.id > after) query = query.order_by(LLMBatchJob.id.asc()) if limit is not None: query = query.limit(limit) - results = query.all() - return [item.to_pydantic() for item in results] + results = await session.execute(query) + return [item.to_pydantic() for item in results.scalars().all()] @enforce_types @trace_method - def delete_llm_batch_request(self, llm_batch_id: str, actor: PydanticUser) -> None: + async def delete_llm_batch_request_async(self, llm_batch_id: str, actor: PydanticUser) -> None: """Hard delete a batch job by ID.""" - with db_registry.session() as session: - batch = LLMBatchJob.read(db_session=session, identifier=llm_batch_id, actor=actor) - batch.hard_delete(db_session=session, actor=actor) + async with db_registry.async_session() as session: + batch = await LLMBatchJob.read_async(db_session=session, identifier=llm_batch_id, actor=actor) + await batch.hard_delete_async(db_session=session, actor=actor) @enforce_types @trace_method - def get_messages_for_letta_batch( + async def get_messages_for_letta_batch_async( self, letta_batch_job_id: str, limit: int = 100, @@ -161,12 +161,12 @@ class LLMBatchManager: Retrieve messages across all LLM batch jobs associated with a Letta batch job. Optimized for PostgreSQL performance using ID-based keyset pagination. """ - with db_registry.session() as session: + async with db_registry.async_session() as session: # If cursor is provided, get sequence_id for that message cursor_sequence_id = None if cursor: - cursor_query = session.query(MessageModel.sequence_id).filter(MessageModel.id == cursor).limit(1) - cursor_result = cursor_query.first() + cursor_query = select(MessageModel.sequence_id).where(MessageModel.id == cursor).limit(1) + cursor_result = await session.execute(cursor_query) if cursor_result: cursor_sequence_id = cursor_result[0] else: @@ -174,24 +174,24 @@ class LLMBatchManager: pass query = ( - session.query(MessageModel) + select(MessageModel) .join(LLMBatchItem, MessageModel.batch_item_id == LLMBatchItem.id) .join(LLMBatchJob, LLMBatchItem.llm_batch_id == LLMBatchJob.id) - .filter(LLMBatchJob.letta_batch_job_id == letta_batch_job_id) + .where(LLMBatchJob.letta_batch_job_id == letta_batch_job_id) ) if actor is not None: - query = query.filter(MessageModel.organization_id == actor.organization_id) + query = query.where(MessageModel.organization_id == actor.organization_id) if agent_id is not None: - query = query.filter(MessageModel.agent_id == agent_id) + query = query.where(MessageModel.agent_id == agent_id) # Apply cursor-based pagination if cursor exists if cursor_sequence_id is not None: if sort_descending: - query = query.filter(MessageModel.sequence_id < cursor_sequence_id) + query = query.where(MessageModel.sequence_id < cursor_sequence_id) else: - query = query.filter(MessageModel.sequence_id > cursor_sequence_id) + query = query.where(MessageModel.sequence_id > cursor_sequence_id) if sort_descending: query = query.order_by(desc(MessageModel.sequence_id)) @@ -200,8 +200,8 @@ class LLMBatchManager: query = query.limit(limit) - results = query.all() - return [message.to_pydantic() for message in results] + results = await session.execute(query) + return [message.to_pydantic() for message in results.scalars().all()] @enforce_types @trace_method @@ -218,7 +218,7 @@ class LLMBatchManager: @enforce_types @trace_method - def create_llm_batch_item( + async def create_llm_batch_item_async( self, llm_batch_id: str, agent_id: str, @@ -229,7 +229,7 @@ class LLMBatchManager: step_state: Optional[AgentStepState] = None, ) -> PydanticLLMBatchItem: """Create a new batch item.""" - with db_registry.session() as session: + async with db_registry.async_session() as session: item = LLMBatchItem( llm_batch_id=llm_batch_id, agent_id=agent_id, @@ -239,7 +239,7 @@ class LLMBatchManager: step_state=step_state, organization_id=actor.organization_id, ) - item.create(session, actor=actor) + await item.create_async(session, actor=actor) return item.to_pydantic() @enforce_types @@ -280,15 +280,15 @@ class LLMBatchManager: @enforce_types @trace_method - def get_llm_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: + async def get_llm_batch_item_by_id_async(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: """Retrieve a single batch item by ID.""" - with db_registry.session() as session: - item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) + async with db_registry.async_session() as session: + item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor) return item.to_pydantic() @enforce_types @trace_method - def update_llm_batch_item( + async def update_llm_batch_item_async( self, item_id: str, actor: PydanticUser, @@ -298,8 +298,8 @@ class LLMBatchManager: step_state: Optional[AgentStepState] = None, ) -> PydanticLLMBatchItem: """Update fields on a batch item.""" - with db_registry.session() as session: - item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) + async with db_registry.async_session() as session: + item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor) if request_status: item.request_status = request_status @@ -310,7 +310,8 @@ class LLMBatchManager: if step_state: item.step_state = step_state - return item.update(db_session=session, actor=actor).to_pydantic() + result = await item.update_async(db_session=session, actor=actor) + return result.to_pydantic() @enforce_types @trace_method @@ -360,7 +361,7 @@ class LLMBatchManager: return [item.to_pydantic() for item in results.scalars()] @trace_method - def bulk_update_llm_batch_items( + async def bulk_update_llm_batch_items_async( self, llm_batch_id_agent_id_pairs: List[Tuple[str, str]], field_updates: List[Dict[str, Any]], strict: bool = True ) -> None: """ @@ -378,13 +379,13 @@ class LLMBatchManager: if len(llm_batch_id_agent_id_pairs) != len(field_updates): raise ValueError("llm_batch_id_agent_id_pairs and field_updates must have the same length") - with db_registry.session() as session: + async with db_registry.async_session() as session: # Lookup primary keys for all requested (batch_id, agent_id) pairs - items = ( - session.query(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id) - .filter(tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs)) - .all() + query = select(LLMBatchItem.id, LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).filter( + tuple_(LLMBatchItem.llm_batch_id, LLMBatchItem.agent_id).in_(llm_batch_id_agent_id_pairs) ) + result = await session.execute(query) + items = result.all() pair_to_pk = {(batch_id, agent_id): pk for pk, batch_id, agent_id in items} if strict: @@ -409,12 +410,12 @@ class LLMBatchManager: mappings.append(update_fields) if mappings: - session.bulk_update_mappings(LLMBatchItem, mappings) - session.commit() + await session.run_sync(lambda ses: ses.bulk_update_mappings(LLMBatchItem, mappings)) + await session.commit() @enforce_types @trace_method - def bulk_update_batch_llm_items_results_by_agent(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None: + async def bulk_update_batch_llm_items_results_by_agent_async(self, updates: List[ItemUpdateInfo], strict: bool = True) -> None: """Update request status and batch results for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [ @@ -425,37 +426,41 @@ class LLMBatchManager: for update in updates ] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) + await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types @trace_method - def bulk_update_llm_batch_items_step_status_by_agent(self, updates: List[StepStatusUpdateInfo], strict: bool = True) -> None: + async def bulk_update_llm_batch_items_step_status_by_agent_async( + self, updates: List[StepStatusUpdateInfo], strict: bool = True + ) -> None: """Update step status for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"step_status": update.step_status} for update in updates] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) + await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types @trace_method - def bulk_update_llm_batch_items_request_status_by_agent(self, updates: List[RequestStatusUpdateInfo], strict: bool = True) -> None: + async def bulk_update_llm_batch_items_request_status_by_agent_async( + self, updates: List[RequestStatusUpdateInfo], strict: bool = True + ) -> None: """Update request status for multiple batch items.""" batch_id_agent_id_pairs = [(update.llm_batch_id, update.agent_id) for update in updates] field_updates = [{"request_status": update.request_status} for update in updates] - self.bulk_update_llm_batch_items(batch_id_agent_id_pairs, field_updates, strict=strict) + await self.bulk_update_llm_batch_items_async(batch_id_agent_id_pairs, field_updates, strict=strict) @enforce_types @trace_method - def delete_llm_batch_item(self, item_id: str, actor: PydanticUser) -> None: + async def delete_llm_batch_item_async(self, item_id: str, actor: PydanticUser) -> None: """Hard delete a batch item by ID.""" - with db_registry.session() as session: - item = LLMBatchItem.read(db_session=session, identifier=item_id, actor=actor) - item.hard_delete(db_session=session, actor=actor) + async with db_registry.async_session() as session: + item = await LLMBatchItem.read_async(db_session=session, identifier=item_id, actor=actor) + await item.hard_delete_async(db_session=session, actor=actor) @enforce_types @trace_method - def count_llm_batch_items(self, llm_batch_id: str) -> int: + async def count_llm_batch_items_async(self, llm_batch_id: str) -> int: """ Efficiently count the number of batch items for a given llm_batch_id. @@ -465,6 +470,6 @@ class LLMBatchManager: Returns: int: The total number of batch items associated with the given llm_batch_id. """ - with db_registry.session() as session: - count = session.query(func.count(LLMBatchItem.id)).filter(LLMBatchItem.llm_batch_id == llm_batch_id).scalar() - return count or 0 + async with db_registry.async_session() as session: + count = await session.execute(select(func.count(LLMBatchItem.id)).where(LLMBatchItem.llm_batch_id == llm_batch_id)) + return count.scalar() or 0 diff --git a/tests/integration_test_batch_api_cron_jobs.py b/tests/integration_test_batch_api_cron_jobs.py index 406d06cd..8a07b9a4 100644 --- a/tests/integration_test_batch_api_cron_jobs.py +++ b/tests/integration_test_batch_api_cron_jobs.py @@ -185,7 +185,7 @@ async def create_test_llm_batch_job_async(server, batch_response, default_user): ) -def create_test_batch_item(server, batch_id, agent_id, default_user): +async def create_test_batch_item(server, batch_id, agent_id, default_user): """Create a test batch item for the given batch and agent.""" dummy_llm_config = LLMConfig( model="claude-3-7-sonnet-latest", @@ -201,7 +201,7 @@ def create_test_batch_item(server, batch_id, agent_id, default_user): step_number=1, tool_rules_solver=ToolRulesSolver(tool_rules=[InitToolRule(tool_name="send_message")]) ) - return server.batch_manager.create_llm_batch_item( + return await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch_id, agent_id=agent_id, llm_config=dummy_llm_config, @@ -289,9 +289,9 @@ async def test_polling_mixed_batch_jobs(default_user, server): job_b = await create_test_llm_batch_job_async(server, batch_b_resp, default_user) # --- Step 3: Create batch items --- - item_a = create_test_batch_item(server, job_a.id, agent_a.id, default_user) - item_b = create_test_batch_item(server, job_b.id, agent_b.id, default_user) - item_c = create_test_batch_item(server, job_b.id, agent_c.id, default_user) + item_a = await create_test_batch_item(server, job_a.id, agent_a.id, default_user) + item_b = await create_test_batch_item(server, job_b.id, agent_b.id, default_user) + item_c = await create_test_batch_item(server, job_b.id, agent_c.id, default_user) # --- Step 4: Mock the Anthropic client --- mock_anthropic_client(server, batch_a_resp, batch_b_resp, agent_b.id, agent_c.id) @@ -316,17 +316,17 @@ async def test_polling_mixed_batch_jobs(default_user, server): # --- Step 7: Verify batch item status updates --- # Item A should remain unchanged - updated_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) + updated_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user) assert updated_item_a.request_status == JobStatus.created assert updated_item_a.batch_request_result is None # Item B should be marked as completed with a successful result - updated_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) + updated_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user) assert updated_item_b.request_status == JobStatus.completed assert updated_item_b.batch_request_result is not None # Item C should be marked as failed with an error result - updated_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) + updated_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user) assert updated_item_c.request_status == JobStatus.failed assert updated_item_c.batch_request_result is not None @@ -352,9 +352,9 @@ async def test_polling_mixed_batch_jobs(default_user, server): # Refresh all objects final_job_a = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_a.id, actor=default_user) final_job_b = await server.batch_manager.get_llm_batch_job_by_id_async(llm_batch_id=job_b.id, actor=default_user) - final_item_a = server.batch_manager.get_llm_batch_item_by_id(item_a.id, actor=default_user) - final_item_b = server.batch_manager.get_llm_batch_item_by_id(item_b.id, actor=default_user) - final_item_c = server.batch_manager.get_llm_batch_item_by_id(item_c.id, actor=default_user) + final_item_a = await server.batch_manager.get_llm_batch_item_by_id_async(item_a.id, actor=default_user) + final_item_b = await server.batch_manager.get_llm_batch_item_by_id_async(item_b.id, actor=default_user) + final_item_c = await server.batch_manager.get_llm_batch_item_by_id_async(item_c.id, actor=default_user) # Job A should still be polling (last_polled_at should update) assert final_job_a.status == JobStatus.running diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index bc8f7e8e..3a14a856 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -458,7 +458,9 @@ async def test_partial_error_from_anthropic_batch( letta_batch_job_id=batch_job.id, ) - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async( + letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user + ) llm_batch_job = llm_batch_jobs[0] # 2. Invoke the polling job and mock responses from Anthropic @@ -571,7 +573,7 @@ async def test_partial_error_from_anthropic_batch( ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" # Check the total list of messages - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == (len(agents) - 1) * 4 + 1 @@ -621,7 +623,9 @@ async def test_resume_step_some_stop( letta_batch_job_id=batch_job.id, ) - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async( + letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user + ) llm_batch_job = llm_batch_jobs[0] # 2. Invoke the polling job and mock responses from Anthropic @@ -723,7 +727,7 @@ async def test_resume_step_some_stop( ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" # Check the total list of messages - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == len(agents) * 3 + 1 @@ -789,7 +793,9 @@ async def test_resume_step_after_request_all_continue( # Basic sanity checks (This is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly` # Verify batch items - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async( + letta_batch_id=pre_resume_response.letta_batch_id, actor=default_user + ) assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" llm_batch_job = llm_batch_jobs[0] @@ -883,7 +889,7 @@ async def test_resume_step_after_request_all_continue( ), f"Agent's in-context messages have been extended, are length: {len(refreshed_agent.message_ids)}" # Check the total list of messages - messages = server.batch_manager.get_messages_for_letta_batch( + messages = await server.batch_manager.get_messages_for_letta_batch_async( letta_batch_job_id=pre_resume_response.letta_batch_id, limit=200, actor=default_user ) assert len(messages) == len(agents) * 4 @@ -987,7 +993,7 @@ async def test_step_until_request_prepares_and_submits_batch_correctly( mock_send.assert_called_once() # Verify database records were created correctly - llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=response.letta_batch_id, actor=default_user) + llm_batch_jobs = await server.batch_manager.list_llm_batch_jobs_async(letta_batch_id=response.letta_batch_id, actor=default_user) assert len(llm_batch_jobs) == 1, f"Expected 1 llm_batch_jobs, got {len(llm_batch_jobs)}" llm_batch_job = llm_batch_jobs[0] diff --git a/tests/test_managers.py b/tests/test_managers.py index bddac2a2..c51ed0f9 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -5107,7 +5107,7 @@ async def test_update_batch_status(server, default_user, dummy_beta_message_batc ) before = datetime.now(timezone.utc) - server.batch_manager.update_llm_batch_status( + await server.batch_manager.update_llm_batch_status_async( llm_batch_id=batch.id, status=JobStatus.completed, latest_polling_response=dummy_beta_message_batch, @@ -5132,7 +5132,7 @@ async def test_create_and_get_batch_item( letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5144,7 +5144,7 @@ async def test_create_and_get_batch_item( assert item.agent_id == sarah_agent.id assert item.step_state == dummy_step_state - fetched = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert fetched.id == item.id @@ -5168,7 +5168,7 @@ async def test_update_batch_item( letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5178,7 +5178,7 @@ async def test_update_batch_item( updated_step_state = AgentStepState(step_number=2, tool_rules_solver=dummy_step_state.tool_rules_solver) - server.batch_manager.update_llm_batch_item( + await server.batch_manager.update_llm_batch_item_async( item_id=item.id, request_status=JobStatus.completed, step_status=AgentStepStatus.resumed, @@ -5187,7 +5187,7 @@ async def test_update_batch_item( actor=default_user, ) - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.request_status == JobStatus.completed assert updated.batch_request_result == dummy_successful_response @@ -5204,7 +5204,7 @@ async def test_delete_batch_item( letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5212,10 +5212,10 @@ async def test_delete_batch_item( actor=default_user, ) - server.batch_manager.delete_llm_batch_item(item_id=item.id, actor=default_user) + await server.batch_manager.delete_llm_batch_item_async(item_id=item.id, actor=default_user) with pytest.raises(NoResultFound): - server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) @pytest.mark.asyncio @@ -5243,7 +5243,7 @@ async def test_bulk_update_batch_statuses(server, default_user, dummy_beta_messa letta_batch_job_id=letta_batch_job.id, ) - server.batch_manager.bulk_update_llm_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) + await server.batch_manager.bulk_update_llm_batch_statuses_async([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) updated = await server.batch_manager.get_llm_batch_job_by_id_async(batch.id, actor=default_user) assert updated.status == JobStatus.completed @@ -5268,7 +5268,7 @@ async def test_bulk_update_batch_items_results_by_agent( actor=default_user, letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5276,11 +5276,11 @@ async def test_bulk_update_batch_items_results_by_agent( actor=default_user, ) - server.batch_manager.bulk_update_batch_llm_items_results_by_agent( + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async( [ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)] ) - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.request_status == JobStatus.completed assert updated.batch_request_result == dummy_successful_response @@ -5295,7 +5295,7 @@ async def test_bulk_update_batch_items_step_status_by_agent( actor=default_user, letta_batch_job_id=letta_batch_job.id, ) - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5303,11 +5303,11 @@ async def test_bulk_update_batch_items_step_status_by_agent( actor=default_user, ) - server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async( [StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)] ) - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.step_status == AgentStepStatus.resumed @@ -5323,7 +5323,7 @@ async def test_list_batch_items_limit_and_filter( ) for _ in range(3): - server.batch_manager.create_llm_batch_item( + await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5353,7 +5353,7 @@ async def test_list_batch_items_pagination( # Create 10 batch items. created_items = [] for i in range(10): - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5416,7 +5416,7 @@ async def test_bulk_update_batch_items_request_status_by_agent( ) # Create a batch item - item = server.batch_manager.create_llm_batch_item( + item = await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5425,12 +5425,12 @@ async def test_bulk_update_batch_items_request_status_by_agent( ) # Update the request status using the bulk update method - server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async( [RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)] ) # Verify the update was applied - updated = server.batch_manager.get_llm_batch_item_by_id(item.id, actor=default_user) + updated = await server.batch_manager.get_llm_batch_item_by_id_async(item.id, actor=default_user) assert updated.request_status == JobStatus.expired @@ -5459,20 +5459,20 @@ async def test_bulk_update_nonexistent_items_should_error( ) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates) + await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_batch_llm_items_results_by_agent( + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async( [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)] ) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async( [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)] ) with pytest.raises(ValueError, match=re.escape(expected_err_msg)): - server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async( [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)] ) @@ -5496,21 +5496,21 @@ async def test_bulk_update_nonexistent_items( nonexistent_updates = [{"request_status": JobStatus.expired}] # This should not raise an error, just silently skip non-existent items - server.batch_manager.bulk_update_llm_batch_items(nonexistent_pairs, nonexistent_updates, strict=False) + await server.batch_manager.bulk_update_llm_batch_items_async(nonexistent_pairs, nonexistent_updates, strict=False) # Test with higher-level methods # Results by agent - server.batch_manager.bulk_update_batch_llm_items_results_by_agent( + await server.batch_manager.bulk_update_batch_llm_items_results_by_agent_async( [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)], strict=False ) # Step status by agent - server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_step_status_by_agent_async( [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)], strict=False ) # Request status by agent - server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent( + await server.batch_manager.bulk_update_llm_batch_items_request_status_by_agent_async( [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)], strict=False ) @@ -5565,7 +5565,7 @@ async def test_create_batch_items_bulk( # Verify the IDs of created items match what's in the database created_ids = [item.id for item in created_items] for item_id in created_ids: - fetched = server.batch_manager.get_llm_batch_item_by_id(item_id, actor=default_user) + fetched = await server.batch_manager.get_llm_batch_item_by_id_async(item_id, actor=default_user) assert fetched.id in created_ids @@ -5585,7 +5585,7 @@ async def test_count_batch_items( # Create a specific number of batch items for this batch. num_items = 5 for _ in range(num_items): - server.batch_manager.create_llm_batch_item( + await server.batch_manager.create_llm_batch_item_async( llm_batch_id=batch.id, agent_id=sarah_agent.id, llm_config=dummy_llm_config, @@ -5594,7 +5594,7 @@ async def test_count_batch_items( ) # Use the count_llm_batch_items method to count the items. - count = server.batch_manager.count_llm_batch_items(llm_batch_id=batch.id) + count = await server.batch_manager.count_llm_batch_items_async(llm_batch_id=batch.id) # Assert that the count matches the expected number. assert count == num_items, f"Expected {num_items} items, got {count}"