feat: convert more methods to async (#2200)
This commit is contained in:
@@ -247,7 +247,9 @@ class LettaAgentBatch(BaseAgent):
|
||||
log_event(name="prepare_next")
|
||||
next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map)
|
||||
if len(next_reqs) == 0:
|
||||
self.job_manager.update_job_by_id(job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor)
|
||||
await self.job_manager.update_job_by_id_async(
|
||||
job_id=letta_batch_id, job_update=JobUpdate(status=JobStatus.completed), actor=self.actor
|
||||
)
|
||||
return LettaBatchResponse(
|
||||
letta_batch_id=llm_batch_job.letta_batch_job_id,
|
||||
last_llm_batch_id=llm_batch_job.id,
|
||||
@@ -358,14 +360,14 @@ class LettaAgentBatch(BaseAgent):
|
||||
tool_params.append(param)
|
||||
|
||||
if rethink_memory_params:
|
||||
return self._bulk_rethink_memory(rethink_memory_params)
|
||||
return await self._bulk_rethink_memory_async(rethink_memory_params)
|
||||
|
||||
if tool_params:
|
||||
async with Pool() as pool:
|
||||
return await pool.map(execute_tool_wrapper, tool_params)
|
||||
|
||||
@trace_method
|
||||
def _bulk_rethink_memory(self, params: List[ToolExecutionParams]) -> Sequence[Tuple[str, Tuple[str, bool]]]:
|
||||
async def _bulk_rethink_memory_async(self, params: List[ToolExecutionParams]) -> Sequence[Tuple[str, Tuple[str, bool]]]:
|
||||
updates = {}
|
||||
result = []
|
||||
for param in params:
|
||||
@@ -386,7 +388,7 @@ class LettaAgentBatch(BaseAgent):
|
||||
# TODO: This is quite ugly and confusing - this is mostly to align with the returns of other tools
|
||||
result.append((param.agent_id, ("", True)))
|
||||
|
||||
self.block_manager.bulk_update_block_values(updates=updates, actor=self.actor)
|
||||
await self.block_manager.bulk_update_block_values_async(updates=updates, actor=self.actor)
|
||||
|
||||
return result
|
||||
|
||||
|
||||
@@ -63,7 +63,7 @@ async def create_messages_batch(
|
||||
)
|
||||
|
||||
try:
|
||||
batch_job = server.job_manager.create_job(pydantic_job=batch_job, actor=actor)
|
||||
batch_job = await server.job_manager.create_job_async(pydantic_job=batch_job, actor=actor)
|
||||
|
||||
# create the batch runner
|
||||
batch_runner = LettaAgentBatch(
|
||||
@@ -86,7 +86,7 @@ async def create_messages_batch(
|
||||
traceback.print_exc()
|
||||
|
||||
# mark job as failed
|
||||
server.job_manager.update_job_by_id(job_id=batch_job.id, job=BatchJob(status=JobStatus.failed), actor=actor)
|
||||
await server.job_manager.update_job_by_id_async(job_id=batch_job.id, job_update=JobUpdate(status=JobStatus.failed), actor=actor)
|
||||
raise
|
||||
return batch_job
|
||||
|
||||
@@ -103,7 +103,7 @@ async def retrieve_batch_run(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
|
||||
return BatchJob.from_job(job)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Batch not found")
|
||||
@@ -154,7 +154,7 @@ async def list_batch_messages(
|
||||
|
||||
# First, verify the batch job exists and the user has access to it
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
|
||||
BatchJob.from_job(job)
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Batch not found")
|
||||
@@ -180,8 +180,8 @@ async def cancel_batch_run(
|
||||
actor = server.user_manager.get_user_or_default(user_id=actor_id)
|
||||
|
||||
try:
|
||||
job = server.job_manager.get_job_by_id(job_id=batch_id, actor=actor)
|
||||
job = server.job_manager.update_job_by_id(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
|
||||
job = await server.job_manager.get_job_by_id_async(job_id=batch_id, actor=actor)
|
||||
job = await server.job_manager.update_job_by_id_async(job_id=job.id, job_update=JobUpdate(status=JobStatus.cancelled), actor=actor)
|
||||
|
||||
# Get related llm batch jobs
|
||||
llm_batch_jobs = server.batch_manager.list_llm_batch_jobs(letta_batch_id=job.id, actor=actor)
|
||||
|
||||
@@ -975,7 +975,6 @@ class AgentManager:
|
||||
@enforce_types
|
||||
async def get_agent_by_id_async(self, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""Fetch an agent by its ID."""
|
||||
print("ASYNC")
|
||||
async with db_registry.async_session() as session:
|
||||
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
|
||||
return agent.to_pydantic()
|
||||
|
||||
@@ -1,5 +1,6 @@
|
||||
from typing import Dict, List, Optional
|
||||
|
||||
from sqlalchemy import select
|
||||
from sqlalchemy.orm import Session
|
||||
|
||||
from letta.log import get_logger
|
||||
@@ -454,7 +455,7 @@ class BlockManager:
|
||||
return block.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def bulk_update_block_values(
|
||||
async def bulk_update_block_values_async(
|
||||
self, updates: Dict[str, str], actor: PydanticUser, return_hydrated: bool = False
|
||||
) -> Optional[List[PydanticBlock]]:
|
||||
"""
|
||||
@@ -469,12 +470,13 @@ class BlockManager:
|
||||
the updated Block objects as Pydantic schemas
|
||||
|
||||
Raises:
|
||||
NoResultFound if any block_id doesn’t exist or isn’t visible to this actor
|
||||
ValueError if any new value exceeds its block’s limit
|
||||
NoResultFound if any block_id doesn't exist or isn't visible to this actor
|
||||
ValueError if any new value exceeds its block's limit
|
||||
"""
|
||||
with db_registry.session() as session:
|
||||
q = session.query(BlockModel).filter(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
|
||||
blocks = q.all()
|
||||
async with db_registry.async_session() as session:
|
||||
query = select(BlockModel).where(BlockModel.id.in_(updates.keys()), BlockModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(query)
|
||||
blocks = result.scalars().all()
|
||||
|
||||
found_ids = {b.id for b in blocks}
|
||||
missing = set(updates.keys()) - found_ids
|
||||
@@ -488,8 +490,10 @@ class BlockManager:
|
||||
new_val = new_val[: block.limit]
|
||||
block.value = new_val
|
||||
|
||||
session.commit()
|
||||
await session.commit()
|
||||
|
||||
if return_hydrated:
|
||||
return [b.to_pydantic() for b in blocks]
|
||||
# TODO: implement for async
|
||||
pass
|
||||
|
||||
return None
|
||||
|
||||
@@ -81,6 +81,30 @@ class JobManager:
|
||||
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def update_job_by_id_async(self, job_id: str, job_update: JobUpdate, actor: PydanticUser) -> PydanticJob:
|
||||
"""Update a job by its ID with the given JobUpdate object asynchronously."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch the job by ID
|
||||
job = await self._verify_job_access_async(session=session, job_id=job_id, actor=actor, access=["write"])
|
||||
|
||||
# Update job attributes with only the fields that were explicitly set
|
||||
update_data = job_update.model_dump(to_orm=True, exclude_unset=True, exclude_none=True)
|
||||
|
||||
# Automatically update the completion timestamp if status is set to 'completed'
|
||||
for key, value in update_data.items():
|
||||
setattr(job, key, value)
|
||||
|
||||
if update_data.get("status") == JobStatus.completed and not job.completed_at:
|
||||
job.completed_at = get_utc_time()
|
||||
if job.callback_url:
|
||||
await self._dispatch_callback_async(session, job)
|
||||
|
||||
# Save the updated job to the database
|
||||
await job.update_async(db_session=session, actor=actor)
|
||||
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def get_job_by_id(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
"""Fetch a job by its ID."""
|
||||
@@ -89,6 +113,14 @@ class JobManager:
|
||||
job = JobModel.read(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER)
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
async def get_job_by_id_async(self, job_id: str, actor: PydanticUser) -> PydanticJob:
|
||||
"""Fetch a job by its ID asynchronously."""
|
||||
async with db_registry.async_session() as session:
|
||||
# Retrieve job by ID using the Job model's read method
|
||||
job = await JobModel.read_async(db_session=session, identifier=job_id, actor=actor, access_type=AccessType.USER)
|
||||
return job.to_pydantic()
|
||||
|
||||
@enforce_types
|
||||
def list_jobs(
|
||||
self,
|
||||
@@ -451,6 +483,35 @@ class JobManager:
|
||||
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
|
||||
return job
|
||||
|
||||
async def _verify_job_access_async(
|
||||
self,
|
||||
session: Session,
|
||||
job_id: str,
|
||||
actor: PydanticUser,
|
||||
access: List[Literal["read", "write", "delete"]] = ["read"],
|
||||
) -> JobModel:
|
||||
"""
|
||||
Verify that a job exists and the user has the required access.
|
||||
|
||||
Args:
|
||||
session: The database session
|
||||
job_id: The ID of the job to verify
|
||||
actor: The user making the request
|
||||
|
||||
Returns:
|
||||
The job if it exists and the user has access
|
||||
|
||||
Raises:
|
||||
NoResultFound: If the job does not exist or user does not have access
|
||||
"""
|
||||
job_query = select(JobModel).where(JobModel.id == job_id)
|
||||
job_query = JobModel.apply_access_predicate(job_query, actor, access, AccessType.USER)
|
||||
result = await session.execute(job_query)
|
||||
job = result.scalar_one_or_none()
|
||||
if not job:
|
||||
raise NoResultFound(f"Job with id {job_id} does not exist or user does not have access")
|
||||
return job
|
||||
|
||||
def _get_run_request_config(self, run_id: str) -> LettaRequestConfig:
|
||||
"""
|
||||
Get the request config for a job.
|
||||
@@ -489,3 +550,28 @@ class JobManager:
|
||||
|
||||
session.add(job)
|
||||
session.commit()
|
||||
|
||||
async def _dispatch_callback_async(self, session, job: JobModel) -> None:
|
||||
"""
|
||||
POST a standard JSON payload to job.callback_url
|
||||
and record timestamp + HTTP status asynchronously.
|
||||
"""
|
||||
|
||||
payload = {
|
||||
"job_id": job.id,
|
||||
"status": job.status,
|
||||
"completed_at": job.completed_at.isoformat(),
|
||||
}
|
||||
try:
|
||||
import httpx
|
||||
|
||||
async with httpx.AsyncClient() as client:
|
||||
resp = await client.post(job.callback_url, json=payload, timeout=5.0)
|
||||
job.callback_sent_at = get_utc_time()
|
||||
job.callback_status_code = resp.status_code
|
||||
|
||||
except Exception:
|
||||
return
|
||||
|
||||
session.add(job)
|
||||
await session.commit()
|
||||
|
||||
@@ -2832,7 +2832,10 @@ async def test_batch_create_multiple_blocks(server: SyncServer, default_user, ev
|
||||
assert expected_labels.issubset(all_labels)
|
||||
|
||||
|
||||
def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncServer, default_user: PydanticUser, caplog):
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_skips_missing_and_truncates_then_returns_none(
|
||||
server: SyncServer, default_user: PydanticUser, caplog, event_loop
|
||||
):
|
||||
mgr = BlockManager()
|
||||
|
||||
# create one block with a small limit
|
||||
@@ -2849,7 +2852,7 @@ def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncS
|
||||
}
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
result = mgr.bulk_update_block_values(updates, actor=default_user)
|
||||
result = await mgr.bulk_update_block_values_async(updates, actor=default_user)
|
||||
# default return_hydrated=False → should be None
|
||||
assert result is None
|
||||
|
||||
@@ -2863,7 +2866,9 @@ def test_bulk_update_skips_missing_and_truncates_then_returns_none(server: SyncS
|
||||
assert reloaded.value == long_val[:5]
|
||||
|
||||
|
||||
def test_bulk_update_return_hydrated_true(server: SyncServer, default_user: PydanticUser):
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skip(reason="TODO: implement for async")
|
||||
async def test_bulk_update_return_hydrated_true(server: SyncServer, default_user: PydanticUser, event_loop):
|
||||
mgr = BlockManager()
|
||||
|
||||
# create a block
|
||||
@@ -2873,7 +2878,7 @@ def test_bulk_update_return_hydrated_true(server: SyncServer, default_user: Pyda
|
||||
)
|
||||
|
||||
updates = {b.id: "new-val"}
|
||||
updated = mgr.bulk_update_block_values(updates, actor=default_user, return_hydrated=True)
|
||||
updated = await mgr.bulk_update_block_values_async(updates, actor=default_user, return_hydrated=True)
|
||||
|
||||
# with return_hydrated=True, we get back a list of schemas
|
||||
assert isinstance(updated, list) and len(updated) == 1
|
||||
@@ -2881,7 +2886,10 @@ def test_bulk_update_return_hydrated_true(server: SyncServer, default_user: Pyda
|
||||
assert updated[0].value == "new-val"
|
||||
|
||||
|
||||
def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: PydanticUser, other_user_different_org: PydanticUser, caplog):
|
||||
@pytest.mark.asyncio
|
||||
async def test_bulk_update_respects_org_scoping(
|
||||
server: SyncServer, default_user: PydanticUser, other_user_different_org: PydanticUser, caplog, event_loop
|
||||
):
|
||||
mgr = BlockManager()
|
||||
|
||||
# one block in each org
|
||||
@@ -2900,7 +2908,7 @@ def test_bulk_update_respects_org_scoping(server: SyncServer, default_user: Pyda
|
||||
}
|
||||
|
||||
caplog.set_level(logging.WARNING)
|
||||
mgr.bulk_update_block_values(updates, actor=default_user)
|
||||
await mgr.bulk_update_block_values_async(updates, actor=default_user)
|
||||
|
||||
# mine should be updated...
|
||||
reloaded_mine = mgr.get_block_by_id(actor=default_user, block_id=mine.id)
|
||||
|
||||
Reference in New Issue
Block a user