From 983f75099050cc83be14cc18ffc5cf91453d1a97 Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Tue, 15 Apr 2025 13:56:22 -0700 Subject: [PATCH] feat: Implement resume step after request in new batch agent loop (#1676) --- letta/agents/letta_agent.py | 1 + letta/agents/letta_agent_batch.py | 348 +++++++++++++++++- letta/jobs/llm_batch_job_polling.py | 14 +- letta/jobs/types.py | 32 +- letta/llm_api/anthropic_client.py | 4 +- letta/orm/message.py | 2 +- letta/orm/sqlalchemy_base.py | 64 +++- letta/schemas/enums.py | 4 +- letta/schemas/letta_response.py | 1 + letta/schemas/llm_batch_job.py | 4 +- letta/server/rest_api/utils.py | 32 +- letta/services/agent_manager.py | 3 +- letta/services/llm_batch_manager.py | 126 +++++-- .../tool_executor/tool_execution_sandbox.py | 1 - letta/services/tool_sandbox/base.py | 3 - poetry.lock | 69 ++-- pyproject.toml | 1 + tests/integration_test_experimental.py | 4 - tests/test_letta_agent_batch.py | 278 +++++++++++++- tests/test_managers.py | 251 ++++++++++++- 20 files changed, 1059 insertions(+), 183 deletions(-) diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index a158c10e..7fb4ff57 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -76,6 +76,7 @@ class LettaAgent(BaseAgent): agent_state=agent_state, tool_rules_solver=tool_rules_solver, stream=False, + # TODO: also pass in reasoning content ) tool_call = response.choices[0].message.tool_calls[0] diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py index 2499a5ca..7d2f3c30 100644 --- a/letta/agents/letta_agent_batch.py +++ b/letta/agents/letta_agent_batch.py @@ -1,51 +1,117 @@ -from typing import Dict, List +import json +import uuid +from dataclasses import dataclass +from typing import Any, Dict, List, Optional, Sequence, Tuple, Union + +from aiomultiprocess import Pool +from anthropic.types.beta.messages import BetaMessageBatchCanceledResult, BetaMessageBatchErroredResult, 
BetaMessageBatchSucceededResult from letta.agents.helpers import _prepare_in_context_messages from letta.helpers import ToolRulesSolver from letta.helpers.datetime_helpers import get_utc_time from letta.helpers.tool_execution_helper import enable_strict_mode +from letta.jobs.types import RequestStatusUpdateInfo, StepStatusUpdateInfo from letta.llm_api.llm_client import LLMClient +from letta.local_llm.constants import INNER_THOUGHTS_KWARG from letta.log import get_logger from letta.orm.enums import ToolType from letta.schemas.agent import AgentState, AgentStepState -from letta.schemas.enums import JobStatus, ProviderType +from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType +from letta.schemas.letta_message_content import OmittedReasoningContent, ReasoningContent, RedactedReasoningContent, TextContent from letta.schemas.letta_request import LettaBatchRequest from letta.schemas.letta_response import LettaBatchResponse +from letta.schemas.llm_batch_job import LLMBatchItem from letta.schemas.message import Message, MessageCreate, MessageUpdate +from letta.schemas.openai.chat_completion_response import ToolCall as OpenAIToolCall +from letta.schemas.sandbox_config import SandboxConfig, SandboxType from letta.schemas.user import User +from letta.server.rest_api.utils import create_heartbeat_system_message, create_letta_messages_from_llm_response from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import compile_system_message from letta.services.llm_batch_manager import LLMBatchManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager +from letta.services.sandbox_config_manager import SandboxConfigManager +from letta.services.tool_executor.tool_execution_manager import ToolExecutionManager +from letta.settings import tool_settings from letta.utils import united_diff logger = 
get_logger(__name__) +@dataclass +class ToolExecutionParams: + agent_id: str + tool_call_name: str + tool_args: Dict[str, Any] + agent_state: AgentState + actor: User + sbx_config: SandboxConfig + sbx_env_vars: Dict[str, Any] + + +@dataclass +class _ResumeContext: + batch_items: List[LLMBatchItem] + agent_ids: List[str] + agent_state_map: Dict[str, AgentState] + provider_results: Dict[str, Any] + tool_call_name_map: Dict[str, str] + tool_call_args_map: Dict[str, Dict[str, Any]] + should_continue_map: Dict[str, bool] + request_status_updates: List[RequestStatusUpdateInfo] + + +async def execute_tool_wrapper(params: ToolExecutionParams): + """ + Executes the tool in an out‑of‑process worker and returns: + (agent_id, (tool_result:str, success_flag:bool)) + """ + # locate the tool on the agent + target_tool = next((t for t in params.agent_state.tools if t.name == params.tool_call_name), None) + if not target_tool: + return params.agent_id, (f"Tool not found: {params.tool_call_name}", False) + + try: + mgr = ToolExecutionManager( + agent_state=params.agent_state, + actor=params.actor, + sandbox_config=params.sbx_config, + sandbox_env_vars=params.sbx_env_vars, + ) + result, _ = await mgr.execute_tool_async( + function_name=params.tool_call_name, + function_args=params.tool_args, + tool=target_tool, + ) + return params.agent_id, (result, True) + except Exception as e: + return params.agent_id, (f"Failed to call tool. 
Error: {e}", False) + + # TODO: Limitations -> # TODO: Only works with anthropic for now class LettaAgentBatch: def __init__( self, - batch_id: str, message_manager: MessageManager, agent_manager: AgentManager, block_manager: BlockManager, passage_manager: PassageManager, batch_manager: LLMBatchManager, + sandbox_config_manager: SandboxConfigManager, actor: User, use_assistant_message: bool = True, max_steps: int = 10, ): - self.batch_id = batch_id self.message_manager = message_manager self.agent_manager = agent_manager self.block_manager = block_manager self.passage_manager = passage_manager self.batch_manager = batch_manager + self.sandbox_config_manager = sandbox_config_manager self.use_assistant_message = use_assistant_message self.actor = actor self.max_steps = max_steps @@ -53,6 +119,10 @@ class LettaAgentBatch: async def step_until_request( self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Dict[str, AgentStepState] ) -> LettaBatchResponse: + # Basic checks + if not batch_requests: + raise ValueError("Empty list of batch_requests passed in!") + agent_messages_mapping: Dict[str, List[Message]] = {} agent_tools_mapping: Dict[str, List[dict]] = {} agent_states = [] @@ -61,10 +131,10 @@ class LettaAgentBatch: agent_id = batch_request.agent_id agent_state = self.agent_manager.get_agent_by_id(agent_id, actor=self.actor) agent_states.append(agent_state) - agent_messages_mapping[agent_id] = self.get_in_context_messages_per_agent( + agent_messages_mapping[agent_id] = self._get_in_context_messages_per_agent( agent_state=agent_state, input_messages=batch_request.messages ) - agent_tools_mapping[agent_id] = self.prepare_tools_per_agent( + agent_tools_mapping[agent_id] = self._prepare_tools_per_agent( agent_state, agent_step_state_mapping.get(agent_id).tool_rules_solver ) @@ -84,36 +154,284 @@ class LettaAgentBatch: # Write the response into the jobs table, where it will get picked up by the next cron run batch_job = 
self.batch_manager.create_batch_job( - llm_provider=ProviderType.anthropic, # TODO: Expand to more + llm_provider=ProviderType.anthropic, # TODO: Expand to more providers create_batch_response=batch_response, actor=self.actor, status=JobStatus.running, ) - # TODO: Make this much more efficient by doing creates in bulk + # Create batch items in bulk for all agents + batch_items = [] for agent_state in agent_states: agent_step_state = agent_step_state_mapping.get(agent_state.id) - self.batch_manager.create_batch_item( + batch_item = LLMBatchItem( batch_id=batch_job.id, agent_id=agent_state.id, llm_config=agent_state.llm_config, - actor=self.actor, + request_status=JobStatus.created, + step_status=AgentStepStatus.paused, step_state=agent_step_state, ) + batch_items.append(batch_item) + + # Create all batch items at once using the bulk operation + if batch_items: + self.batch_manager.create_batch_items_bulk(batch_items, actor=self.actor) return LettaBatchResponse( - batch_id=batch_job.id, status=batch_job.status, last_polled_at=get_utc_time(), created_at=batch_job.created_at + batch_id=batch_job.id, + status=batch_job.status, + agent_count=len(agent_states), + last_polled_at=get_utc_time(), + created_at=batch_job.created_at, ) - async def resume_step_after_request(self, batch_id: str): - pass + async def resume_step_after_request(self, batch_id: str) -> LettaBatchResponse: + # 1. gather everything we need + ctx = await self._collect_resume_context(batch_id) - def prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]: + # 2. persist request‑level status updates + self._update_request_statuses(ctx.request_status_updates) + + # 3. run the tools in parallel + exec_results = await self._execute_tools(ctx) + + # 4. create + save assistant/tool messages + msg_map = self._persist_tool_messages(exec_results, ctx) + + # 5. mark steps complete + self._mark_steps_complete(batch_id, ctx.agent_ids) + + # 6. 
build next‑round requests / step‑state map + next_reqs, next_step_state = self._prepare_next_iteration(exec_results, ctx, msg_map) + + # 7. recurse into the normal stepping pipeline + return await self.step_until_request( + batch_requests=next_reqs, + agent_step_state_mapping=next_step_state, + ) + + async def _collect_resume_context(self, batch_id: str) -> _ResumeContext: + batch_items = self.batch_manager.list_batch_items(batch_id=batch_id) + + agent_ids, agent_state_map = [], {} + provider_results, name_map, args_map, cont_map = {}, {}, {}, {} + request_status_updates: List[RequestStatusUpdateInfo] = [] + + for item in batch_items: + aid = item.agent_id + agent_ids.append(aid) + agent_state_map[aid] = self.agent_manager.get_agent_by_id(aid, actor=self.actor) + provider_results[aid] = item.batch_request_result.result + + # status bookkeeping + pr = provider_results[aid] + status = ( + JobStatus.completed + if isinstance(pr, BetaMessageBatchSucceededResult) + else ( + JobStatus.failed + if isinstance(pr, BetaMessageBatchErroredResult) + else JobStatus.cancelled if isinstance(pr, BetaMessageBatchCanceledResult) else JobStatus.expired + ) + ) + request_status_updates.append(RequestStatusUpdateInfo(batch_id=batch_id, agent_id=aid, request_status=status)) + + # translate provider‑specific response → OpenAI‑style tool call (unchanged) + llm_client = LLMClient.create(llm_config=item.llm_config, put_inner_thoughts_first=True) + tool_call = ( + llm_client.convert_response_to_chat_completion(response_data=pr.message.model_dump(), input_messages=[]) + .choices[0] + .message.tool_calls[0] + ) + + name, args, cont = self._extract_tool_call_and_decide_continue(tool_call, item.step_state) + name_map[aid], args_map[aid], cont_map[aid] = name, args, cont + + return _ResumeContext( + batch_items=batch_items, + agent_ids=agent_ids, + agent_state_map=agent_state_map, + provider_results=provider_results, + tool_call_name_map=name_map, + tool_call_args_map=args_map, + 
should_continue_map=cont_map, + request_status_updates=request_status_updates, + ) + + def _update_request_statuses(self, updates: List[RequestStatusUpdateInfo]) -> None: + if updates: + self.batch_manager.bulk_update_batch_items_request_status_by_agent(updates=updates) + + def _build_sandbox(self) -> Tuple[SandboxConfig, Dict[str, Any]]: + sbx_type = SandboxType.E2B if tool_settings.e2b_api_key else SandboxType.LOCAL + cfg = self.sandbox_config_manager.get_or_create_default_sandbox_config(sandbox_type=sbx_type, actor=self.actor) + env = self.sandbox_config_manager.get_sandbox_env_vars_as_dict(cfg.id, actor=self.actor, limit=100) + return cfg, env + + async def _execute_tools(self, ctx: _ResumeContext) -> Sequence[Tuple[str, Tuple[str, bool]]]: + sbx_cfg, sbx_env = self._build_sandbox() + params = [ + ToolExecutionParams( + agent_id=aid, + tool_call_name=ctx.tool_call_name_map[aid], + tool_args=ctx.tool_call_args_map[aid], + agent_state=ctx.agent_state_map[aid], + actor=self.actor, + sbx_config=sbx_cfg, + sbx_env_vars=sbx_env, + ) + for aid in ctx.agent_ids + ] + async with Pool() as pool: + return await pool.map(execute_tool_wrapper, params) + + def _persist_tool_messages( + self, + exec_results: Sequence[Tuple[str, Tuple[str, bool]]], + ctx: _ResumeContext, + ) -> Dict[str, List[Message]]: + msg_map: Dict[str, List[Message]] = {} + for aid, (tool_res, success) in exec_results: + msgs = self._create_tool_call_messages( + agent_state=ctx.agent_state_map[aid], + tool_call_name=ctx.tool_call_name_map[aid], + tool_call_args=ctx.tool_call_args_map[aid], + tool_exec_result=tool_res, + success_flag=success, + reasoning_content=None, + ) + msg_map[aid] = msgs + # flatten & persist + self.message_manager.create_many_messages([m for msgs in msg_map.values() for m in msgs], actor=self.actor) + return msg_map + + def _mark_steps_complete(self, batch_id: str, agent_ids: List[str]) -> None: + updates = [StepStatusUpdateInfo(batch_id=batch_id, agent_id=aid, 
step_status=AgentStepStatus.completed) for aid in agent_ids] + self.batch_manager.bulk_update_batch_items_step_status_by_agent(updates) + + def _prepare_next_iteration( + self, + exec_results: Sequence[Tuple[str, Tuple[str, bool]]], + ctx: _ResumeContext, + msg_map: Dict[str, List[Message]], + ) -> Tuple[List[LettaBatchRequest], Dict[str, AgentStepState]]: + # who continues? + continues = [aid for aid, cont in ctx.should_continue_map.items() if cont] + + success_flag_map = {aid: flag for aid, (_res, flag) in exec_results} + + batch_reqs: List[LettaBatchRequest] = [] + for aid in continues: + heartbeat = create_heartbeat_system_message( + agent_id=aid, + model=ctx.agent_state_map[aid].llm_config.model, + function_call_success=success_flag_map[aid], + actor=self.actor, + ) + batch_reqs.append( + LettaBatchRequest( + agent_id=aid, messages=[MessageCreate.model_validate(heartbeat.model_dump(include={"role", "content", "name", "otid"}))] + ) + ) + + # extend in‑context ids when necessary + for aid, new_msgs in msg_map.items(): + ast = ctx.agent_state_map[aid] + if not ast.message_buffer_autoclear: + self.agent_manager.set_in_context_messages( + agent_id=aid, + message_ids=ast.message_ids + [m.id for m in new_msgs], + actor=self.actor, + ) + + # bump step number + step_map = { + item.agent_id: item.step_state.model_copy(update={"step_number": item.step_state.step_number + 1}) for item in ctx.batch_items + } + return batch_reqs, step_map + + def _create_tool_call_messages( + self, + agent_state: AgentState, + tool_call_name: str, + tool_call_args: Dict[str, Any], + tool_exec_result: str, + success_flag: bool, + reasoning_content: Optional[List[Union[TextContent, ReasoningContent, RedactedReasoningContent, OmittedReasoningContent]]] = None, + ) -> List[Message]: + tool_call_id = f"call_{uuid.uuid4().hex[:8]}" + + tool_call_messages = create_letta_messages_from_llm_response( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + function_name=tool_call_name, + 
function_arguments=tool_call_args, + tool_call_id=tool_call_id, + function_call_success=success_flag, + function_response=tool_exec_result, + actor=self.actor, + add_heartbeat_request_system_message=False, + reasoning_content=reasoning_content, + pre_computed_assistant_message_id=None, + pre_computed_tool_message_id=None, + ) + + return tool_call_messages + + # TODO: This is doing a lot of dict passing + # TODO: Make the passing here typed + def _extract_tool_call_and_decide_continue( + self, tool_call: OpenAIToolCall, agent_step_state: AgentStepState + ) -> Tuple[str, Dict[str, Any], bool]: + """ + Now that streaming is done, handle the final AI response. + This might yield additional SSE tokens if we do stalling. + At the end, set self._continue_execution accordingly. + """ + tool_call_name = tool_call.function.name + tool_call_args_str = tool_call.function.arguments + + try: + tool_args = json.loads(tool_call_args_str) + except json.JSONDecodeError: + logger.warning(f"Failed to JSON decode tool call argument string: {tool_call_args_str}") + tool_args = {} + + # Get request heartbeats and coerce to bool + request_heartbeat = tool_args.pop("request_heartbeat", False) + # Pre-emptively pop out inner_thoughts + tool_args.pop(INNER_THOUGHTS_KWARG, "") + + # So this is necessary, because sometimes non-structured outputs makes mistakes + if isinstance(request_heartbeat, str): + request_heartbeat = request_heartbeat.lower() == "true" + else: + request_heartbeat = bool(request_heartbeat) + + continue_stepping = request_heartbeat + tool_rules_solver = agent_step_state.tool_rules_solver + tool_rules_solver.register_tool_call(tool_name=tool_call_name) + if tool_rules_solver.is_terminal_tool(tool_name=tool_call_name): + continue_stepping = False + elif tool_rules_solver.has_children_tools(tool_name=tool_call_name): + continue_stepping = True + elif tool_rules_solver.is_continue_tool(tool_name=tool_call_name): + continue_stepping = True + + step_count = 
agent_step_state.step_number + if step_count >= self.max_steps: + logger.warning("Hit max steps, stopping agent loop prematurely.") + continue_stepping = False + + return tool_call_name, tool_args, continue_stepping + + def _prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]: tools = [t for t in agent_state.tools if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}] valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools])) return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)] - def get_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]: + def _get_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]: current_in_context_messages, new_in_context_messages = _prepare_in_context_messages( input_messages, agent_state, self.message_manager, self.actor ) diff --git a/letta/jobs/llm_batch_job_polling.py b/letta/jobs/llm_batch_job_polling.py index 79788aa5..479f33b2 100644 --- a/letta/jobs/llm_batch_job_polling.py +++ b/letta/jobs/llm_batch_job_polling.py @@ -3,7 +3,7 @@ import datetime from typing import List from letta.jobs.helpers import map_anthropic_batch_job_status_to_job_status, map_anthropic_individual_batch_item_status_to_job_status -from letta.jobs.types import BatchId, BatchPollingResult, ItemUpdateInfo +from letta.jobs.types import BatchPollingResult, ItemUpdateInfo from letta.log import get_logger from letta.schemas.enums import JobStatus, ProviderType from letta.schemas.llm_batch_job import LLMBatchJob @@ -49,14 +49,14 @@ async def fetch_batch_status(server: SyncServer, batch_job: LLMBatchJob) -> Batc response = await server.anthropic_async_client.beta.messages.batches.retrieve(batch_id_str) new_status = 
map_anthropic_batch_job_status_to_job_status(response.processing_status) logger.debug(f"[Poll BatchJob] Batch {batch_job.id}: provider={response.processing_status} → internal={new_status}") - return (batch_job.id, new_status, response) + return BatchPollingResult(batch_job.id, new_status, response) except Exception as e: - logger.warning(f"[Poll BatchJob] Batch {batch_job.id}: failed to retrieve {batch_id_str}: {e}") + logger.error(f"[Poll BatchJob] Batch {batch_job.id}: failed to retrieve {batch_id_str}: {e}") # We treat a retrieval error as still running to try again next cycle - return (batch_job.id, JobStatus.running, None) + return BatchPollingResult(batch_job.id, JobStatus.running, None) -async def fetch_batch_items(server: SyncServer, batch_id: BatchId, batch_resp_id: str) -> List[ItemUpdateInfo]: +async def fetch_batch_items(server: SyncServer, batch_id: str, batch_resp_id: str) -> List[ItemUpdateInfo]: """ Fetch individual item results for a completed batch. @@ -73,7 +73,7 @@ async def fetch_batch_items(server: SyncServer, batch_id: BatchId, batch_resp_id async for item_result in server.anthropic_async_client.beta.messages.batches.results(batch_resp_id): # Here, custom_id should be the agent_id item_status = map_anthropic_individual_batch_item_status_to_job_status(item_result) - updates.append((batch_id, item_result.custom_id, item_status, item_result)) + updates.append(ItemUpdateInfo(batch_id, item_result.custom_id, item_status, item_result)) logger.info(f"[Poll BatchJob] Fetched {len(updates)} item updates for batch {batch_id}.") except Exception as e: logger.error(f"[Poll BatchJob] Error fetching item updates for batch {batch_id}: {e}") @@ -193,7 +193,7 @@ async def poll_running_llm_batches(server: "SyncServer") -> None: # 6. 
Bulk update all items for newly completed batch(es) if item_updates: metrics.updated_items_count = len(item_updates) - server.batch_manager.bulk_update_batch_items_by_agent(item_updates) + server.batch_manager.bulk_update_batch_items_results_by_agent(item_updates) else: logger.info("[Poll BatchJob] No item-level updates needed.") diff --git a/letta/jobs/types.py b/letta/jobs/types.py index 217ee85a..854e0fef 100644 --- a/letta/jobs/types.py +++ b/letta/jobs/types.py @@ -1,10 +1,30 @@ -from typing import Optional, Tuple +from typing import NamedTuple, Optional from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse -from letta.schemas.enums import JobStatus +from letta.schemas.enums import AgentStepStatus, JobStatus -BatchId = str -AgentId = str -BatchPollingResult = Tuple[BatchId, JobStatus, Optional[BetaMessageBatch]] -ItemUpdateInfo = Tuple[BatchId, AgentId, JobStatus, BetaMessageBatchIndividualResponse] + +class BatchPollingResult(NamedTuple): + batch_id: str + request_status: JobStatus + batch_response: Optional[BetaMessageBatch] + + +class ItemUpdateInfo(NamedTuple): + batch_id: str + agent_id: str + request_status: JobStatus + batch_request_result: Optional[BetaMessageBatchIndividualResponse] + + +class StepStatusUpdateInfo(NamedTuple): + batch_id: str + agent_id: str + step_status: AgentStepStatus + + +class RequestStatusUpdateInfo(NamedTuple): + batch_id: str + agent_id: str + request_status: JobStatus diff --git a/letta/llm_api/anthropic_client.py b/letta/llm_api/anthropic_client.py index 89e437df..cd9c0815 100644 --- a/letta/llm_api/anthropic_client.py +++ b/letta/llm_api/anthropic_client.py @@ -4,7 +4,7 @@ from typing import Dict, List, Optional, Union import anthropic from anthropic import AsyncStream -from anthropic.types import Message as AnthropicMessage +from anthropic.types.beta import BetaMessage as AnthropicMessage from anthropic.types.beta import BetaRawMessageStreamEvent from 
anthropic.types.beta.message_create_params import MessageCreateParamsNonStreaming from anthropic.types.beta.messages import BetaMessageBatch @@ -304,6 +304,8 @@ class AnthropicClient(LLMClientBase): return super().handle_llm_error(e) + # TODO: Input messages doesn't get used here + # TODO: Clean up this interface def convert_response_to_chat_completion( self, response_data: dict, diff --git a/letta/orm/message.py b/letta/orm/message.py index 9f678bb1..cba74000 100644 --- a/letta/orm/message.py +++ b/letta/orm/message.py @@ -80,7 +80,7 @@ class Message(SqlalchemyBase, OrganizationMixin, AgentMixin): @event.listens_for(Message, "before_insert") def set_sequence_id_for_sqlite(mapper, connection, target): # TODO: Kind of hacky, used to detect if we are using sqlite or not - if not settings.pg_uri: + if not settings.letta_pg_uri_no_default: session = Session.object_session(target) if not hasattr(session, "_sequence_id_counter"): diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index ee85d988..92bb6965 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -390,7 +390,14 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): @classmethod @handle_db_timeout - def batch_create(cls, items: List["SqlalchemyBase"], db_session: "Session", actor: Optional["User"] = None) -> List["SqlalchemyBase"]: + def batch_create( + cls, + items: List["SqlalchemyBase"], + db_session: "Session", + actor: Optional["User"] = None, + batch_size: int = 1000, # TODO: Make this a configurable setting + requery: bool = True, + ) -> List["SqlalchemyBase"]: """ Create multiple records in a single transaction for better performance. 
@@ -398,6 +405,8 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): items: List of model instances to create db_session: SQLAlchemy session actor: Optional user performing the action + batch_size: Maximum number of items to process in a single batch + requery: Whether to requery the objects after creation Returns: List of created model instances @@ -407,30 +416,47 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): if not items: return [] - # Set created/updated by fields if actor is provided - if actor: - for item in items: - item._set_created_and_updated_by_fields(actor.id) + result_items = [] - try: - with db_session as session: - session.add_all(items) - session.flush() # Flush to generate IDs but don't commit yet + # Process in batches to avoid memory issues with very large sets + for i in range(0, len(items), batch_size): + batch = items[i : i + batch_size] - # Collect IDs to fetch the complete objects after commit - item_ids = [item.id for item in items] + # Set created/updated by fields if actor is provided + if actor: + for item in batch: + item._set_created_and_updated_by_fields(actor.id) - session.commit() + try: + with db_session as session: + session.add_all(batch) + session.flush() # Flush to generate IDs but don't commit yet - # Re-query the objects to get them with relationships loaded - query = select(cls).where(cls.id.in_(item_ids)) - if hasattr(cls, "created_at"): - query = query.order_by(cls.created_at) + # Collect IDs to fetch the complete objects after commit + item_ids = [item.id for item in batch] - return list(session.execute(query).scalars()) + session.commit() - except (DBAPIError, IntegrityError) as e: - cls._handle_dbapi_error(e) + if requery: + # Re-query the objects to get them with relationships loaded + query = select(cls).where(cls.id.in_(item_ids)) + if hasattr(cls, "created_at"): + query = query.order_by(cls.created_at) + + batch_result = list(session.execute(query).scalars()) + else: + # Use the objects we already 
have in memory + batch_result = batch + + result_items.extend(batch_result) + + except (DBAPIError, IntegrityError) as e: + logger.error(f"Database error during batch creation: {e}") + # Log which items we were processing when the error occurred + logger.error(f"Failed batch starting at index {i} of {len(items)}") + cls._handle_dbapi_error(e) + + return result_items @handle_db_timeout def delete(self, db_session: "Session", actor: Optional["User"] = None) -> "SqlalchemyBase": diff --git a/letta/schemas/enums.py b/letta/schemas/enums.py index f566908b..f4d1aef6 100644 --- a/letta/schemas/enums.py +++ b/letta/schemas/enums.py @@ -33,6 +33,7 @@ class JobStatus(str, Enum): failed = "failed" pending = "pending" cancelled = "cancelled" + expired = "expired" class AgentStepStatus(str, Enum): @@ -41,7 +42,8 @@ class AgentStepStatus(str, Enum): """ paused = "paused" - running = "running" + resumed = "resumed" + completed = "completed" class MessageStreamStatus(str, Enum): diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py index 9a179261..662f0f8f 100644 --- a/letta/schemas/letta_response.py +++ b/letta/schemas/letta_response.py @@ -171,5 +171,6 @@ LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStat class LettaBatchResponse(BaseModel): batch_id: str = Field(..., description="A unique identifier for this batch request.") status: JobStatus = Field(..., description="The current status of the batch request.") + agent_count: int = Field(..., description="The number of agents in the batch request.") last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.") created_at: datetime = Field(..., description="The timestamp when the batch request was created.") diff --git a/letta/schemas/llm_batch_job.py b/letta/schemas/llm_batch_job.py index 7f21fc27..5e178d32 100644 --- a/letta/schemas/llm_batch_job.py +++ b/letta/schemas/llm_batch_job.py @@ -19,7 +19,7 @@ class 
LLMBatchItem(OrmMetadataBase, validate_assignment=True): __id_prefix__ = "batch_item" - id: str = Field(..., description="The id of the batch item. Assigned by the database.") + id: Optional[str] = Field(None, description="The id of the batch item. Assigned by the database.") batch_id: str = Field(..., description="The id of the parent LLM batch job this item belongs to.") agent_id: str = Field(..., description="The id of the agent associated with this LLM request.") @@ -42,7 +42,7 @@ class LLMBatchJob(OrmMetadataBase, validate_assignment=True): __id_prefix__ = "batch_req" - id: str = Field(..., description="The id of the batch job. Assigned by the database.") + id: Optional[str] = Field(None, description="The id of the batch job. Assigned by the database.") status: JobStatus = Field(..., description="The current status of the batch (e.g., created, in_progress, done).") llm_provider: ProviderType = Field(..., description="The LLM provider used for the batch (e.g., anthropic, openai).") diff --git a/letta/server/rest_api/utils.py b/letta/server/rest_api/utils.py index 2daa5d3e..7fd49e2e 100644 --- a/letta/server/rest_api/utils.py +++ b/letta/server/rest_api/utils.py @@ -225,22 +225,34 @@ def create_letta_messages_from_llm_response( messages.append(tool_message) if add_heartbeat_request_system_message: - text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE - heartbeat_system_message = Message( - role=MessageRole.user, - content=[TextContent(text=get_heartbeat(text_content))], - organization_id=actor.organization_id, - agent_id=agent_id, - model=model, - tool_calls=[], - tool_call_id=None, - created_at=get_utc_time(), + heartbeat_system_message = create_heartbeat_system_message( + agent_id=agent_id, model=model, function_call_success=function_call_success, actor=actor ) messages.append(heartbeat_system_message) return messages +def create_heartbeat_system_message( + agent_id: str, + model: str, + function_call_success: bool, 
+ actor: User, +) -> Message: + text_content = REQ_HEARTBEAT_MESSAGE if function_call_success else FUNC_FAILED_HEARTBEAT_MESSAGE + heartbeat_system_message = Message( + role=MessageRole.user, + content=[TextContent(text=get_heartbeat(text_content))], + organization_id=actor.organization_id, + agent_id=agent_id, + model=model, + tool_calls=[], + tool_call_id=None, + created_at=get_utc_time(), + ) + return heartbeat_system_message + + def create_assistant_messages_from_openai_response( response_text: str, agent_id: str, diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index e158ba5c..bb56abb6 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -423,7 +423,8 @@ class AgentManager: query = _apply_tag_filter(query, tags, match_all_tags) query = _apply_pagination(query, before, after, session, ascending=ascending) - query = query.limit(limit) + if limit: + query = query.limit(limit) agents = session.execute(query).scalars().all() return [agent.to_pydantic(include_relationships=include_relationships) for agent in agents] diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py index b903b955..a3ee9611 100644 --- a/letta/services/llm_batch_manager.py +++ b/letta/services/llm_batch_manager.py @@ -1,10 +1,10 @@ import datetime -from typing import List, Optional +from typing import Any, Dict, List, Optional, Tuple from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchIndividualResponse from sqlalchemy import tuple_ -from letta.jobs.types import BatchPollingResult, ItemUpdateInfo +from letta.jobs.types import BatchPollingResult, ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo from letta.log import get_logger from letta.orm.llm_batch_items import LLMBatchItem from letta.orm.llm_batch_job import LLMBatchJob @@ -140,6 +140,39 @@ class LLMBatchManager: item.create(session, actor=actor) return item.to_pydantic() + @enforce_types + def 
create_batch_items_bulk(self, llm_batch_items: List[PydanticLLMBatchItem], actor: PydanticUser) -> List[PydanticLLMBatchItem]: + """ + Create multiple batch items in bulk for better performance. + + Args: + llm_batch_items: List of batch items to create + actor: User performing the action + + Returns: + List of created batch items as Pydantic models + """ + with self.session_maker() as session: + # Convert Pydantic models to ORM objects + orm_items = [] + for item in llm_batch_items: + orm_item = LLMBatchItem( + batch_id=item.batch_id, + agent_id=item.agent_id, + llm_config=item.llm_config, + request_status=item.request_status, + step_status=item.step_status, + step_state=item.step_state, + organization_id=actor.organization_id, + ) + orm_items.append(orm_item) + + # Use the batch_create method to create all items at once + created_items = LLMBatchItem.batch_create(orm_items, session, actor=actor) + + # Convert back to Pydantic models + return [item.to_pydantic() for item in created_items] + @enforce_types def get_batch_item_by_id(self, item_id: str, actor: PydanticUser) -> PydanticLLMBatchItem: """Retrieve a single batch item by ID.""" @@ -172,6 +205,7 @@ class LLMBatchManager: return item.update(db_session=session, actor=actor).to_pydantic() + # TODO: Maybe make this paginated? @enforce_types def list_batch_items( self, @@ -192,56 +226,86 @@ class LLMBatchManager: results = query.all() return [item.to_pydantic() for item in results] - def bulk_update_batch_items_by_agent( + def bulk_update_batch_items( self, - updates: List[ItemUpdateInfo], + batch_id_agent_id_pairs: List[Tuple[str, str]], + field_updates: List[Dict[str, Any]], ) -> None: """ - Efficiently update LLMBatchItem rows by (batch_id, agent_id). + Efficiently update multiple LLMBatchItem rows by (batch_id, agent_id) pairs. 
Args: - updates: List of tuples: - (batch_id, agent_id, new_request_status, batch_request_result) + batch_id_agent_id_pairs: List of (batch_id, agent_id) tuples identifying items to update + field_updates: List of dictionaries containing the fields to update for each item """ + if not batch_id_agent_id_pairs or not field_updates: + return + + if len(batch_id_agent_id_pairs) != len(field_updates): + raise ValueError("batch_id_agent_id_pairs and field_updates must have the same length") + with self.session_maker() as session: - # For bulk_update_mappings, we need the primary key of each row - # So we must map (batch_id, agent_id) → actual PK (id) - # We'll do it in one DB query using the (batch_id, agent_id) sets - - # 1. Gather the pairs - key_pairs = [(b_id, a_id) for (b_id, a_id, *_rest) in updates] - - # 2. Query items in a single step + # Lookup primary keys items = ( session.query(LLMBatchItem.id, LLMBatchItem.batch_id, LLMBatchItem.agent_id) - .filter(tuple_(LLMBatchItem.batch_id, LLMBatchItem.agent_id).in_(key_pairs)) + .filter(tuple_(LLMBatchItem.batch_id, LLMBatchItem.agent_id).in_(batch_id_agent_id_pairs)) .all() ) + pair_to_pk = {(b, a): id for id, b, a in items} - # Build a map from (batch_id, agent_id) → PK id - pair_to_pk = {} - for row_id, row_batch_id, row_agent_id in items: - pair_to_pk[(row_batch_id, row_agent_id)] = row_id - - # 3. 
Construct mappings for the PK-based bulk update mappings = [] - for batch_id, agent_id, new_status, new_result in updates: + for (batch_id, agent_id), fields in zip(batch_id_agent_id_pairs, field_updates): pk_id = pair_to_pk.get((batch_id, agent_id)) if not pk_id: - # Nonexistent or mismatch → skip continue - mappings.append( - { - "id": pk_id, - "request_status": new_status, - "batch_request_result": new_result, - } - ) + + update_fields = fields.copy() + update_fields["id"] = pk_id + mappings.append(update_fields) if mappings: session.bulk_update_mappings(LLMBatchItem, mappings) session.commit() + @enforce_types + def bulk_update_batch_items_results_by_agent( + self, + updates: List[ItemUpdateInfo], + ) -> None: + """Update request status and batch results for multiple batch items.""" + batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates] + field_updates = [ + { + "request_status": update.request_status, + "batch_request_result": update.batch_request_result, + } + for update in updates + ] + + self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates) + + @enforce_types + def bulk_update_batch_items_step_status_by_agent( + self, + updates: List[StepStatusUpdateInfo], + ) -> None: + """Update step status for multiple batch items.""" + batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates] + field_updates = [{"step_status": update.step_status} for update in updates] + + self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates) + + @enforce_types + def bulk_update_batch_items_request_status_by_agent( + self, + updates: List[RequestStatusUpdateInfo], + ) -> None: + """Update request status for multiple batch items.""" + batch_id_agent_id_pairs = [(update.batch_id, update.agent_id) for update in updates] + field_updates = [{"request_status": update.request_status} for update in updates] + + self.bulk_update_batch_items(batch_id_agent_id_pairs, field_updates) + @enforce_types def 
delete_batch_item(self, item_id: str, actor: PydanticUser) -> None: """Hard delete a batch item by ID.""" diff --git a/letta/services/tool_executor/tool_execution_sandbox.py b/letta/services/tool_executor/tool_execution_sandbox.py index 4a82fc13..375f6b00 100644 --- a/letta/services/tool_executor/tool_execution_sandbox.py +++ b/letta/services/tool_executor/tool_execution_sandbox.py @@ -148,7 +148,6 @@ class ToolExecutionSandbox: temp_file.write(code) temp_file.flush() temp_file_path = temp_file.name - try: if local_configs.use_venv: return self.run_local_dir_sandbox_venv(sbx_config, env, temp_file_path) diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py index 87f16c57..dbbd632a 100644 --- a/letta/services/tool_sandbox/base.py +++ b/letta/services/tool_sandbox/base.py @@ -10,7 +10,6 @@ from letta.schemas.agent import AgentState from letta.schemas.sandbox_config import SandboxConfig, SandboxRunResult from letta.schemas.tool import Tool from letta.services.helpers.tool_execution_helper import add_imports_and_pydantic_schemas_for_args -from letta.services.organization_manager import OrganizationManager from letta.services.sandbox_config_manager import SandboxConfigManager from letta.services.tool_manager import ToolManager @@ -33,8 +32,6 @@ class AsyncToolSandboxBase(ABC): self.tool_name = tool_name self.args = args self.user = user - self.organization = OrganizationManager().get_organization_by_id(self.user.organization_id) - self.privileged_tools = self.organization.privileged_tools self.tool = tool_object or ToolManager().get_tool_by_name(tool_name=tool_name, actor=self.user) if self.tool is None: diff --git a/poetry.lock b/poetry.lock index 1a8e8e6d..e94f5c25 100644 --- a/poetry.lock +++ b/poetry.lock @@ -1,4 +1,4 @@ -# This file is automatically @generated by Poetry 1.8.5 and should not be changed by hand. +# This file is automatically @generated by Poetry 1.8.3 and should not be changed by hand. 
[[package]] name = "aiohappyeyeballs" @@ -114,6 +114,21 @@ yarl = ">=1.17.0,<2.0" [package.extras] speedups = ["Brotli", "aiodns (>=3.2.0)", "brotlicffi"] +[[package]] +name = "aiomultiprocess" +version = "0.9.1" +description = "AsyncIO version of the standard multiprocessing module" +optional = false +python-versions = ">=3.8" +files = [ + {file = "aiomultiprocess-0.9.1-py3-none-any.whl", hash = "sha256:3a7b3bb3c38dbfb4d9d1194ece5934b6d32cf0280e8edbe64a7d215bba1322c6"}, + {file = "aiomultiprocess-0.9.1.tar.gz", hash = "sha256:f0231dbe0291e15325d7896ebeae0002d95a4f2675426ca05eb35f24c60e495b"}, +] + +[package.extras] +dev = ["attribution (==1.7.1)", "black (==24.4.0)", "coverage (==7.4.4)", "flake8 (==7.0.0)", "flake8-bugbear (==24.4.21)", "flit (==3.9.0)", "mypy (==1.9.0)", "usort (==1.0.8.post1)", "uvloop (==0.19.0)"] +docs = ["sphinx (==7.3.7)", "sphinx-mdinclude (==0.6.0)"] + [[package]] name = "aiosignal" version = "1.3.2" @@ -548,10 +563,6 @@ files = [ {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_i686.whl", hash = "sha256:a37b8f0391212d29b3a91a799c8e4a2855e0576911cdfb2515487e30e322253d"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_ppc64le.whl", hash = "sha256:e84799f09591700a4154154cab9787452925578841a94321d5ee8fb9a9a328f0"}, {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_1_x86_64.whl", hash = "sha256:f66b5337fa213f1da0d9000bc8dc0cb5b896b726eefd9c6046f699b169c41b9e"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_aarch64.whl", hash = "sha256:5dab0844f2cf82be357a0eb11a9087f70c5430b2c241493fc122bb6f2bb0917c"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_i686.whl", hash = "sha256:e4fe605b917c70283db7dfe5ada75e04561479075761a0b3866c081d035b01c1"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_ppc64le.whl", hash = "sha256:1e9a65b5736232e7a7f91ff3d02277f11d339bf34099a56cdab6a8b3410a02b2"}, - {file = "Brotli-1.1.0-cp310-cp310-musllinux_1_2_x86_64.whl", hash = "sha256:58d4b711689366d4a03ac7957ab8c28890415e267f9b6589969e74b6e42225ec"}, {file = 
"Brotli-1.1.0-cp310-cp310-win32.whl", hash = "sha256:be36e3d172dc816333f33520154d708a2657ea63762ec16b62ece02ab5e4daf2"}, {file = "Brotli-1.1.0-cp310-cp310-win_amd64.whl", hash = "sha256:0c6244521dda65ea562d5a69b9a26120769b7a9fb3db2fe9545935ed6735b128"}, {file = "Brotli-1.1.0-cp311-cp311-macosx_10_9_universal2.whl", hash = "sha256:a3daabb76a78f829cafc365531c972016e4aa8d5b4bf60660ad8ecee19df7ccc"}, @@ -564,14 +575,8 @@ files = [ {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_i686.whl", hash = "sha256:19c116e796420b0cee3da1ccec3b764ed2952ccfcc298b55a10e5610ad7885f9"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_ppc64le.whl", hash = "sha256:510b5b1bfbe20e1a7b3baf5fed9e9451873559a976c1a78eebaa3b86c57b4265"}, {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_1_x86_64.whl", hash = "sha256:a1fd8a29719ccce974d523580987b7f8229aeace506952fa9ce1d53a033873c8"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_aarch64.whl", hash = "sha256:c247dd99d39e0338a604f8c2b3bc7061d5c2e9e2ac7ba9cc1be5a69cb6cd832f"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_i686.whl", hash = "sha256:1b2c248cd517c222d89e74669a4adfa5577e06ab68771a529060cf5a156e9757"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_ppc64le.whl", hash = "sha256:2a24c50840d89ded6c9a8fdc7b6ed3692ed4e86f1c4a4a938e1e92def92933e0"}, - {file = "Brotli-1.1.0-cp311-cp311-musllinux_1_2_x86_64.whl", hash = "sha256:f31859074d57b4639318523d6ffdca586ace54271a73ad23ad021acd807eb14b"}, {file = "Brotli-1.1.0-cp311-cp311-win32.whl", hash = "sha256:39da8adedf6942d76dc3e46653e52df937a3c4d6d18fdc94a7c29d263b1f5b50"}, {file = "Brotli-1.1.0-cp311-cp311-win_amd64.whl", hash = "sha256:aac0411d20e345dc0920bdec5548e438e999ff68d77564d5e9463a7ca9d3e7b1"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_universal2.whl", hash = "sha256:32d95b80260d79926f5fab3c41701dbb818fde1c9da590e77e571eefd14abe28"}, - {file = "Brotli-1.1.0-cp312-cp312-macosx_10_13_x86_64.whl", hash = 
"sha256:b760c65308ff1e462f65d69c12e4ae085cff3b332d894637f6273a12a482d09f"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_universal2.whl", hash = "sha256:316cc9b17edf613ac76b1f1f305d2a748f1b976b033b049a6ecdfd5612c70409"}, {file = "Brotli-1.1.0-cp312-cp312-macosx_10_9_x86_64.whl", hash = "sha256:caf9ee9a5775f3111642d33b86237b05808dafcd6268faa492250e9b78046eb2"}, {file = "Brotli-1.1.0-cp312-cp312-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:70051525001750221daa10907c77830bc889cb6d865cc0b813d9db7fefc21451"}, @@ -582,24 +587,8 @@ files = [ {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_i686.whl", hash = "sha256:4093c631e96fdd49e0377a9c167bfd75b6d0bad2ace734c6eb20b348bc3ea180"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_ppc64le.whl", hash = "sha256:7e4c4629ddad63006efa0ef968c8e4751c5868ff0b1c5c40f76524e894c50248"}, {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_1_x86_64.whl", hash = "sha256:861bf317735688269936f755fa136a99d1ed526883859f86e41a5d43c61d8966"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_aarch64.whl", hash = "sha256:87a3044c3a35055527ac75e419dfa9f4f3667a1e887ee80360589eb8c90aabb9"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_i686.whl", hash = "sha256:c5529b34c1c9d937168297f2c1fde7ebe9ebdd5e121297ff9c043bdb2ae3d6fb"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_ppc64le.whl", hash = "sha256:ca63e1890ede90b2e4454f9a65135a4d387a4585ff8282bb72964fab893f2111"}, - {file = "Brotli-1.1.0-cp312-cp312-musllinux_1_2_x86_64.whl", hash = "sha256:e79e6520141d792237c70bcd7a3b122d00f2613769ae0cb61c52e89fd3443839"}, {file = "Brotli-1.1.0-cp312-cp312-win32.whl", hash = "sha256:5f4d5ea15c9382135076d2fb28dde923352fe02951e66935a9efaac8f10e81b0"}, {file = "Brotli-1.1.0-cp312-cp312-win_amd64.whl", hash = "sha256:906bc3a79de8c4ae5b86d3d75a8b77e44404b0f4261714306e3ad248d8ab0951"}, - {file = "Brotli-1.1.0-cp313-cp313-macosx_10_13_universal2.whl", hash = "sha256:8bf32b98b75c13ec7cf774164172683d6e7891088f6316e54425fde1efc276d5"}, - {file = 
"Brotli-1.1.0-cp313-cp313-macosx_10_13_x86_64.whl", hash = "sha256:7bc37c4d6b87fb1017ea28c9508b36bbcb0c3d18b4260fcdf08b200c74a6aee8"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:3c0ef38c7a7014ffac184db9e04debe495d317cc9c6fb10071f7fefd93100a4f"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:91d7cc2a76b5567591d12c01f019dd7afce6ba8cba6571187e21e2fc418ae648"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_17_x86_64.manylinux2014_x86_64.whl", hash = "sha256:a93dde851926f4f2678e704fadeb39e16c35d8baebd5252c9fd94ce8ce68c4a0"}, - {file = "Brotli-1.1.0-cp313-cp313-manylinux_2_5_i686.manylinux1_i686.manylinux_2_17_i686.manylinux2014_i686.whl", hash = "sha256:f0db75f47be8b8abc8d9e31bc7aad0547ca26f24a54e6fd10231d623f183d089"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_aarch64.whl", hash = "sha256:6967ced6730aed543b8673008b5a391c3b1076d834ca438bbd70635c73775368"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_i686.whl", hash = "sha256:7eedaa5d036d9336c95915035fb57422054014ebdeb6f3b42eac809928e40d0c"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_ppc64le.whl", hash = "sha256:d487f5432bf35b60ed625d7e1b448e2dc855422e87469e3f450aa5552b0eb284"}, - {file = "Brotli-1.1.0-cp313-cp313-musllinux_1_2_x86_64.whl", hash = "sha256:832436e59afb93e1836081a20f324cb185836c617659b07b129141a8426973c7"}, - {file = "Brotli-1.1.0-cp313-cp313-win32.whl", hash = "sha256:43395e90523f9c23a3d5bdf004733246fba087f2948f87ab28015f12359ca6a0"}, - {file = "Brotli-1.1.0-cp313-cp313-win_amd64.whl", hash = "sha256:9011560a466d2eb3f5a6e4929cf4a09be405c64154e12df0dd72713f6500e32b"}, {file = "Brotli-1.1.0-cp36-cp36m-macosx_10_9_x86_64.whl", hash = "sha256:a090ca607cbb6a34b0391776f0cb48062081f5f60ddcce5d11838e67a01928d1"}, {file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_aarch64.manylinux2014_aarch64.whl", hash = "sha256:2de9d02f5bda03d27ede52e8cfe7b865b066fa49258cbab568720aa5be80a47d"}, 
{file = "Brotli-1.1.0-cp36-cp36m-manylinux_2_17_ppc64le.manylinux2014_ppc64le.whl", hash = "sha256:2333e30a5e00fe0fe55903c8832e08ee9c3b1382aacf4db26664a16528d51b4b"}, @@ -609,10 +598,6 @@ files = [ {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_i686.whl", hash = "sha256:fd5f17ff8f14003595ab414e45fce13d073e0762394f957182e69035c9f3d7c2"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_ppc64le.whl", hash = "sha256:069a121ac97412d1fe506da790b3e69f52254b9df4eb665cd42460c837193354"}, {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_1_x86_64.whl", hash = "sha256:e93dfc1a1165e385cc8239fab7c036fb2cd8093728cbd85097b284d7b99249a2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_aarch64.whl", hash = "sha256:aea440a510e14e818e67bfc4027880e2fb500c2ccb20ab21c7a7c8b5b4703d75"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_i686.whl", hash = "sha256:6974f52a02321b36847cd19d1b8e381bf39939c21efd6ee2fc13a28b0d99348c"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_ppc64le.whl", hash = "sha256:a7e53012d2853a07a4a79c00643832161a910674a893d296c9f1259859a289d2"}, - {file = "Brotli-1.1.0-cp36-cp36m-musllinux_1_2_x86_64.whl", hash = "sha256:d7702622a8b40c49bffb46e1e3ba2e81268d5c04a34f460978c6b5517a34dd52"}, {file = "Brotli-1.1.0-cp36-cp36m-win32.whl", hash = "sha256:a599669fd7c47233438a56936988a2478685e74854088ef5293802123b5b2460"}, {file = "Brotli-1.1.0-cp36-cp36m-win_amd64.whl", hash = "sha256:d143fd47fad1db3d7c27a1b1d66162e855b5d50a89666af46e1679c496e8e579"}, {file = "Brotli-1.1.0-cp37-cp37m-macosx_10_9_x86_64.whl", hash = "sha256:11d00ed0a83fa22d29bc6b64ef636c4552ebafcef57154b4ddd132f5638fbd1c"}, @@ -624,10 +609,6 @@ files = [ {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_i686.whl", hash = "sha256:919e32f147ae93a09fe064d77d5ebf4e35502a8df75c29fb05788528e330fe74"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_ppc64le.whl", hash = "sha256:23032ae55523cc7bccb4f6a0bf368cd25ad9bcdcc1990b64a647e7bbcce9cb5b"}, {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_1_x86_64.whl", hash = 
"sha256:224e57f6eac61cc449f498cc5f0e1725ba2071a3d4f48d5d9dffba42db196438"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_aarch64.whl", hash = "sha256:cb1dac1770878ade83f2ccdf7d25e494f05c9165f5246b46a621cc849341dc01"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_i686.whl", hash = "sha256:3ee8a80d67a4334482d9712b8e83ca6b1d9bc7e351931252ebef5d8f7335a547"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_ppc64le.whl", hash = "sha256:5e55da2c8724191e5b557f8e18943b1b4839b8efc3ef60d65985bcf6f587dd38"}, - {file = "Brotli-1.1.0-cp37-cp37m-musllinux_1_2_x86_64.whl", hash = "sha256:d342778ef319e1026af243ed0a07c97acf3bad33b9f29e7ae6a1f68fd083e90c"}, {file = "Brotli-1.1.0-cp37-cp37m-win32.whl", hash = "sha256:587ca6d3cef6e4e868102672d3bd9dc9698c309ba56d41c2b9c85bbb903cdb95"}, {file = "Brotli-1.1.0-cp37-cp37m-win_amd64.whl", hash = "sha256:2954c1c23f81c2eaf0b0717d9380bd348578a94161a65b3a2afc62c86467dd68"}, {file = "Brotli-1.1.0-cp38-cp38-macosx_10_9_universal2.whl", hash = "sha256:efa8b278894b14d6da122a72fefcebc28445f2d3f880ac59d46c90f4c13be9a3"}, @@ -640,10 +621,6 @@ files = [ {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_i686.whl", hash = "sha256:1ab4fbee0b2d9098c74f3057b2bc055a8bd92ccf02f65944a241b4349229185a"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_ppc64le.whl", hash = "sha256:141bd4d93984070e097521ed07e2575b46f817d08f9fa42b16b9b5f27b5ac088"}, {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_1_x86_64.whl", hash = "sha256:fce1473f3ccc4187f75b4690cfc922628aed4d3dd013d047f95a9b3919a86596"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_aarch64.whl", hash = "sha256:d2b35ca2c7f81d173d2fadc2f4f31e88cc5f7a39ae5b6db5513cf3383b0e0ec7"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_i686.whl", hash = "sha256:af6fa6817889314555aede9a919612b23739395ce767fe7fcbea9a80bf140fe5"}, - {file = "Brotli-1.1.0-cp38-cp38-musllinux_1_2_ppc64le.whl", hash = "sha256:2feb1d960f760a575dbc5ab3b1c00504b24caaf6986e2dc2b01c09c87866a943"}, - {file = 
"Brotli-1.1.0-cp38-cp38-musllinux_1_2_x86_64.whl", hash = "sha256:4410f84b33374409552ac9b6903507cdb31cd30d2501fc5ca13d18f73548444a"}, {file = "Brotli-1.1.0-cp38-cp38-win32.whl", hash = "sha256:db85ecf4e609a48f4b29055f1e144231b90edc90af7481aa731ba2d059226b1b"}, {file = "Brotli-1.1.0-cp38-cp38-win_amd64.whl", hash = "sha256:3d7954194c36e304e1523f55d7042c59dc53ec20dd4e9ea9d151f1b62b4415c0"}, {file = "Brotli-1.1.0-cp39-cp39-macosx_10_9_universal2.whl", hash = "sha256:5fb2ce4b8045c78ebbc7b8f3c15062e435d47e7393cc57c25115cfd49883747a"}, @@ -656,10 +633,6 @@ files = [ {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_i686.whl", hash = "sha256:949f3b7c29912693cee0afcf09acd6ebc04c57af949d9bf77d6101ebb61e388c"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_ppc64le.whl", hash = "sha256:89f4988c7203739d48c6f806f1e87a1d96e0806d44f0fba61dba81392c9e474d"}, {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_1_x86_64.whl", hash = "sha256:de6551e370ef19f8de1807d0a9aa2cdfdce2e85ce88b122fe9f6b2b076837e59"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_aarch64.whl", hash = "sha256:0737ddb3068957cf1b054899b0883830bb1fec522ec76b1098f9b6e0f02d9419"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_i686.whl", hash = "sha256:4f3607b129417e111e30637af1b56f24f7a49e64763253bbc275c75fa887d4b2"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_ppc64le.whl", hash = "sha256:6c6e0c425f22c1c719c42670d561ad682f7bfeeef918edea971a79ac5252437f"}, - {file = "Brotli-1.1.0-cp39-cp39-musllinux_1_2_x86_64.whl", hash = "sha256:494994f807ba0b92092a163a0a283961369a65f6cbe01e8891132b7a320e61eb"}, {file = "Brotli-1.1.0-cp39-cp39-win32.whl", hash = "sha256:f0d8a7a6b5983c2496e364b969f0e526647a06b075d034f3297dc66f3b360c64"}, {file = "Brotli-1.1.0-cp39-cp39-win_amd64.whl", hash = "sha256:cdad5b9014d83ca68c25d2e9444e28e967ef16e80f6b436918c700c117a85467"}, {file = "Brotli-1.1.0.tar.gz", hash = "sha256:81de08ac11bcb85841e440c13611c00b67d3bf82698314928d0b676362546724"}, @@ -1083,9 +1056,9 @@ isort = ">=4.3.21,<6.0" jinja2 = 
">=2.10.1,<4.0" packaging = "*" pydantic = [ + {version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, {version = ">=1.10.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.11\" and python_version < \"3.12\""}, {version = ">=1.9.0,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.10\" and python_version < \"3.11\""}, - {version = ">=1.10.0,<2.0.0 || >2.0.0,<2.0.1 || >2.0.1,<2.4.0 || >2.4.0,<3.0", extras = ["email"], markers = "python_version >= \"3.12\" and python_version < \"4.0\""}, ] pyyaml = ">=6.0.1" toml = {version = ">=0.10.0,<1.0.0", markers = "python_version < \"3.11\""} @@ -3109,8 +3082,8 @@ psutil = ">=5.9.1" pywin32 = {version = "*", markers = "sys_platform == \"win32\""} pyzmq = ">=25.0.0" requests = [ - {version = ">=2.26.0", markers = "python_version <= \"3.11\""}, {version = ">=2.32.2", markers = "python_version > \"3.11\""}, + {version = ">=2.26.0", markers = "python_version <= \"3.11\""}, ] setuptools = ">=70.0.0" tomli = {version = ">=1.1.0", markers = "python_version < \"3.11\""} @@ -3975,9 +3948,9 @@ files = [ [package.dependencies] numpy = [ + {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, {version = ">=1.23.2", markers = "python_version == \"3.11\""}, {version = ">=1.22.4", markers = "python_version < \"3.11\""}, - {version = ">=1.26.0", markers = "python_version >= \"3.12\""}, ] python-dateutil = ">=2.8.2" pytz = ">=2020.1" @@ -5274,8 +5247,8 @@ grpcio = ">=1.41.0" grpcio-tools = ">=1.41.0" httpx = {version = ">=0.20.0", extras = ["http2"]} numpy = [ - {version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""}, {version = ">=1.26", markers = "python_version == \"3.12\""}, + {version = ">=1.21", markers = "python_version >= \"3.10\" and python_version < \"3.12\""}, ] portalocker = ">=2.7.0,<3.0.0" pydantic = ">=1.10.8" @@ -6873,4 
+6846,4 @@ tests = ["wikipedia"] [metadata] lock-version = "2.0" python-versions = "<3.14,>=3.10" -content-hash = "980cd598eaa6fb9f0c7f5f28587bb86cab36a17859a8211fd78c65c8c90755d6" +content-hash = "16f6a0c089d3eeca4107a9191201138340570cd40d52ca21a827ac0189fbf15d" diff --git a/pyproject.toml b/pyproject.toml index ce7d95a0..5b4d70db 100644 --- a/pyproject.toml +++ b/pyproject.toml @@ -89,6 +89,7 @@ datamodel-code-generator = {extras = ["http"], version = "^0.25.0"} mcp = "^1.3.0" firecrawl-py = "^1.15.0" apscheduler = "^3.11.0" +aiomultiprocess = "^0.9.1" [tool.poetry.extras] diff --git a/tests/integration_test_experimental.py b/tests/integration_test_experimental.py index 6ade107b..e690023c 100644 --- a/tests/integration_test_experimental.py +++ b/tests/integration_test_experimental.py @@ -108,12 +108,8 @@ def weather_tool(client): Raises: RuntimeError: If the request to fetch weather data fails. """ - import time - import requests - time.sleep(5) - url = f"https://wttr.in/{location}?format=%C+%t" response = requests.get(url) diff --git a/tests/test_letta_agent_batch.py b/tests/test_letta_agent_batch.py index 9987d583..1e13ec3b 100644 --- a/tests/test_letta_agent_batch.py +++ b/tests/test_letta_agent_batch.py @@ -10,19 +10,29 @@ import os import threading import time from datetime import datetime, timezone -from unittest.mock import patch +from typing import Tuple +from unittest.mock import AsyncMock, Mock, patch import pytest -from anthropic.types.beta.messages import BetaMessageBatch, BetaMessageBatchRequestCounts +from anthropic.types import BetaErrorResponse, BetaRateLimitError +from anthropic.types.beta import BetaMessage +from anthropic.types.beta.messages import ( + BetaMessageBatch, + BetaMessageBatchErroredResult, + BetaMessageBatchIndividualResponse, + BetaMessageBatchRequestCounts, + BetaMessageBatchSucceededResult, +) from dotenv import load_dotenv from letta_client import Letta from letta.agents.letta_agent_batch import LettaAgentBatch from 
letta.config import LettaConfig from letta.helpers import ToolRulesSolver +from letta.jobs.llm_batch_job_polling import poll_running_llm_batches from letta.orm import Base -from letta.schemas.agent import AgentStepState -from letta.schemas.enums import JobStatus, ProviderType +from letta.schemas.agent import AgentState, AgentStepState +from letta.schemas.enums import AgentStepStatus, JobStatus, ProviderType from letta.schemas.letta_message_content import TextContent from letta.schemas.letta_request import LettaBatchRequest from letta.schemas.message import MessageCreate @@ -50,8 +60,39 @@ EXPECTED_ROLES = ["system", "assistant", "tool", "user", "user"] # --------------------------------------------------------------------------- # +@pytest.fixture(scope="function") +def weather_tool(client): + def get_weather(location: str) -> str: + """ + Fetches the current weather for a given location. + + Parameters: + location (str): The location to get the weather for. + + Returns: + str: A formatted string describing the weather in the given location. + + Raises: + RuntimeError: If the request to fetch weather data fails. + """ + import requests + + url = f"https://wttr.in/{location}?format=%C+%t" + + response = requests.get(url) + if response.status_code == 200: + weather_data = response.text + return f"The weather in {location} is {weather_data}." + else: + raise RuntimeError(f"Failed to get weather data, status code: {response.status_code}") + + tool = client.tools.upsert_from_function(func=get_weather) + # Yield the created tool + yield tool + + @pytest.fixture -def agents(client): +def agents(client, weather_tool): """ Create three test agents with different models. 
@@ -66,6 +107,7 @@ def agents(client): model=model_name, tags=["test_agents"], embedding="letta/letta-free", + tool_ids=[weather_tool.id], ) return ( @@ -107,6 +149,90 @@ def step_state_map(agents): return {agent.id: AgentStepState(step_number=0, tool_rules_solver=solver) for agent in agents} +def create_batch_response(batch_id: str, processing_status: str = "in_progress") -> BetaMessageBatch: + """Create a dummy BetaMessageBatch with the specified ID and status.""" + now = datetime(2024, 8, 20, 18, 37, 24, 100435, tzinfo=timezone.utc) + return BetaMessageBatch( + id=batch_id, + archived_at=now, + cancel_initiated_at=now, + created_at=now, + ended_at=now, + expires_at=now, + processing_status=processing_status, + request_counts=BetaMessageBatchRequestCounts( + canceled=10, + errored=30, + expired=10, + processing=100, + succeeded=50, + ), + results_url=None, + type="message_batch", + ) + + +def create_successful_response(custom_id: str) -> BetaMessageBatchIndividualResponse: + """Create a dummy successful batch response.""" + return BetaMessageBatchIndividualResponse( + custom_id=custom_id, + result=BetaMessageBatchSucceededResult( + type="succeeded", + message=BetaMessage( + id="msg_abc123", + role="assistant", + type="message", + model="claude-3-5-sonnet-20240620", + content=[{"type": "text", "text": "hi!"}], + usage={"input_tokens": 5, "output_tokens": 7}, + stop_reason="end_turn", + ), + ), + ) + + +def create_complete_tool_response(custom_id: str, model: str, request_heartbeat: bool) -> BetaMessageBatchIndividualResponse: + """Create a dummy successful batch response with a tool call after user asks about weather.""" + return BetaMessageBatchIndividualResponse( + custom_id=custom_id, + result=BetaMessageBatchSucceededResult( + type="succeeded", + message=BetaMessage( + id="msg_abc123", + role="assistant", + type="message", + model=model, + content=[ + {"type": "text", "text": "Let me check the current weather in San Francisco for you."}, + { + "type": 
"tool_use", + "id": "tu_01234567890123456789012345", + "name": "get_weather", + "input": { + "location": "Las Vegas", + "inner_thoughts": "I should get the weather", + "request_heartbeat": request_heartbeat, + }, + }, + ], + usage={"input_tokens": 7, "output_tokens": 17}, + stop_reason="end_turn", + ), + ), + ) + + +def create_failed_response(custom_id: str) -> BetaMessageBatchIndividualResponse: + """Create a dummy failed batch response with a rate limit error.""" + return BetaMessageBatchIndividualResponse( + custom_id=custom_id, + result=BetaMessageBatchErroredResult( + type="errored", + error=BetaErrorResponse(type="error", error=BetaRateLimitError(type="rate_limit_error", message="Rate limit hit.")), + ), + ) + + @pytest.fixture def dummy_batch_response(): """ @@ -115,14 +241,8 @@ def dummy_batch_response(): Returns: BetaMessageBatch: A dummy batch response """ - now = datetime.now(timezone.utc) - return BetaMessageBatch( - id="msgbatch_test_12345", - created_at=now, - expires_at=now, - processing_status="in_progress", - request_counts=BetaMessageBatchRequestCounts(canceled=0, errored=0, expired=0, processing=3, succeeded=0), - type="message_batch", + return create_batch_response( + batch_id="msgbatch_test_12345", ) @@ -185,14 +305,142 @@ def client(server_url): return Letta(base_url=server_url) +class MockAsyncIterable: + def __init__(self, items): + self.items = items + + def __aiter__(self): + return self + + async def __anext__(self): + if not self.items: + raise StopAsyncIteration + return self.items.pop(0) + + # --------------------------------------------------------------------------- # # Test # --------------------------------------------------------------------------- # +@pytest.mark.asyncio +async def test_resume_step_after_request_happy_path( + disable_e2b_api_key, server, default_user, agents: Tuple[AgentState], batch_requests, step_state_map +): + anthropic_batch_id = "msgbatch_test_12345" + dummy_batch_response = create_batch_response( + 
batch_id=anthropic_batch_id, + ) + + # 1. Invoke `step_until_request` + with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): + # Create batch runner + batch_runner = LettaAgentBatch( + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + batch_manager=server.batch_manager, + sandbox_config_manager=server.sandbox_config_manager, + actor=default_user, + ) + + # Run the method under test + pre_resume_response = await batch_runner.step_until_request( + batch_requests=batch_requests, + agent_step_state_mapping=step_state_map, + ) + + # Basic sanity checks (This is tested more thoroughly in `test_step_until_request_prepares_and_submits_batch_correctly`) + # Verify batch items + items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user) + assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" + + # 2. 
Invoke the polling job and mock responses from Anthropic + mock_retrieve = AsyncMock(return_value=create_batch_response(batch_id=pre_resume_response.batch_id, processing_status="ended")) + + with patch.object(server.anthropic_async_client.beta.messages.batches, "retrieve", mock_retrieve): + mock_items = [ + create_complete_tool_response(custom_id=agent.id, model=agent.llm_config.model, request_heartbeat=True) for agent in agents + ] + + # Create the mock for results + mock_results = Mock() + mock_results.return_value = MockAsyncIterable(mock_items.copy()) # Using copy to preserve the original list + + with patch.object(server.anthropic_async_client.beta.messages.batches, "results", mock_results): + await poll_running_llm_batches(server) + + # Verify database records were updated correctly + job = server.batch_manager.get_batch_job_by_id(pre_resume_response.batch_id, actor=default_user) + + # Verify job properties + assert job.status == JobStatus.completed, "Job status should be 'completed'" + + # Verify batch items + items = server.batch_manager.list_batch_items(batch_id=job.id, actor=default_user) + assert len(items) == 3, f"Expected 3 batch items, got {len(items)}" + assert all([item.request_status == JobStatus.completed for item in items]) + + # 3. 
Call resume_step_after_request + letta_batch_agent = LettaAgentBatch( + message_manager=server.message_manager, + agent_manager=server.agent_manager, + block_manager=server.block_manager, + passage_manager=server.passage_manager, + batch_manager=server.batch_manager, + sandbox_config_manager=server.sandbox_config_manager, + actor=default_user, + ) + with patch("letta.llm_api.anthropic_client.AnthropicClient.send_llm_batch_request_async", return_value=dummy_batch_response): + msg_counts_before = {agent.id: server.message_manager.size(actor=default_user, agent_id=agent.id) for agent in agents} + + post_resume_response = await letta_batch_agent.resume_step_after_request(batch_id=pre_resume_response.batch_id) + + # A *new* batch job should have been spawned + assert ( + post_resume_response.batch_id != pre_resume_response.batch_id + ), "resume_step_after_request is expected to enqueue a follow‑up batch job." + assert post_resume_response.status == JobStatus.running + assert post_resume_response.agent_count == 3 + + # New batch‑items should exist, initialised in (created, paused) state + new_items = server.batch_manager.list_batch_items(batch_id=post_resume_response.batch_id, actor=default_user) + assert len(new_items) == 3, f"Expected 3 new batch items, got {len(new_items)}" + assert {i.request_status for i in new_items} == {JobStatus.created} + assert {i.step_status for i in new_items} == {AgentStepStatus.paused} + + # Confirm that tool_rules_solver state was preserved correctly + # Assert every new item's step_state's tool_rules_solver has "get_weather" in the tool_call_history + assert all( + "get_weather" in item.step_state.tool_rules_solver.tool_call_history for item in new_items + ), "Expected 'get_weather' in tool_call_history for all new_items" + # Assert that each new item's step_number was incremented to 1 + assert all(item.step_state.step_number == 1 for item in new_items), "Expected step_number to be incremented to 1 for all new_items" + + # Old items must 
have been flipped to completed / finished earlier + # (sanity – we already asserted this above, but we keep it close for clarity) + old_items = server.batch_manager.list_batch_items(batch_id=pre_resume_response.batch_id, actor=default_user) + assert {i.request_status for i in old_items} == {JobStatus.completed} + assert {i.step_status for i in old_items} == {AgentStepStatus.completed} + + # Tool‑call side‑effects – each agent gets at least 2 extra messages + for agent in agents: + before = msg_counts_before[agent.id] # captured just before resume + after = server.message_manager.size(actor=default_user, agent_id=agent.id) + assert after - before >= 2, f"Agent {agent.id} should have an assistant tool‑call " f"and tool‑response message persisted." + + # Check that agent states have been properly modified to have extended in-context messages + for agent in agents: + refreshed_agent = server.agent_manager.get_agent_by_id(agent_id=agent.id, actor=default_user) + assert ( + len(refreshed_agent.message_ids) == 6 + ), f"Agent's in-context messages have not been extended, are length: {len(refreshed_agent.message_ids)}" + + @pytest.mark.asyncio async def test_step_until_request_prepares_and_submits_batch_correctly( - server, default_user, agents, batch_requests, step_state_map, dummy_batch_response + disable_e2b_api_key, server, default_user, agents, batch_requests, step_state_map, dummy_batch_response ): """ Test that step_until_request correctly: @@ -258,12 +506,12 @@ async def test_step_until_request_prepares_and_submits_batch_correctly( # Create batch runner batch_runner = LettaAgentBatch( - batch_id="test_batch", message_manager=server.message_manager, agent_manager=server.agent_manager, block_manager=server.block_manager, passage_manager=server.passage_manager, batch_manager=server.batch_manager, + sandbox_config_manager=server.sandbox_config_manager, actor=default_user, ) diff --git a/tests/test_managers.py b/tests/test_managers.py index 1734a1e9..5bca8587 100644 --- 
a/tests/test_managers.py +++ b/tests/test_managers.py @@ -26,6 +26,7 @@ from letta.embeddings import embedding_model from letta.functions.functions import derive_openai_json_schema, parse_source_code from letta.functions.mcp_client.types import MCPTool from letta.helpers import ToolRulesSolver +from letta.jobs.types import ItemUpdateInfo, RequestStatusUpdateInfo, StepStatusUpdateInfo from letta.orm import Base, Block from letta.orm.block_history import BlockHistory from letta.orm.enums import ActorType, JobType, ToolType @@ -42,6 +43,7 @@ from letta.schemas.job import Job as PydanticJob from letta.schemas.job import JobUpdate, LettaRequestConfig from letta.schemas.letta_message import UpdateAssistantMessage, UpdateReasoningMessage, UpdateSystemMessage, UpdateUserMessage from letta.schemas.letta_message_content import TextContent +from letta.schemas.llm_batch_job import LLMBatchItem from letta.schemas.llm_config import LLMConfig from letta.schemas.message import Message as PydanticMessage from letta.schemas.message import MessageCreate, MessageUpdate @@ -2713,7 +2715,7 @@ def test_multiple_checkpoints(server: SyncServer, default_user): block_manager.checkpoint_block(block_id=block.id, actor=default_user) # 2) Update block content - updated_block_data = PydanticBlock(**block.dict()) + updated_block_data = PydanticBlock(**block.model_dump()) updated_block_data.value = "v2" block_manager.create_or_update_block(updated_block_data, actor=default_user) @@ -2827,7 +2829,7 @@ def test_checkpoint_no_future_states(server: SyncServer, default_user): block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) # 2) Create "v2" and checkpoint => seq=2 - updated_data = PydanticBlock(**block_v1.dict()) + updated_data = PydanticBlock(**block_v1.model_dump()) updated_data.value = "v2" block_manager.create_or_update_block(updated_data, actor=default_user) block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) @@ -2872,7 +2874,7 @@ def 
test_undo_checkpoint_block(server: SyncServer, default_user): block_manager.checkpoint_block(block_id=created_block.id, actor=default_user) # 3) Update block content to "Version 2" - updated_data = PydanticBlock(**created_block.dict()) + updated_data = PydanticBlock(**created_block.model_dump()) updated_data.value = "Version 2 content" block_manager.create_or_update_block(updated_data, actor=default_user) @@ -2901,13 +2903,13 @@ def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) # 2) Update to "v2", checkpoint => seq=2 - block_v2 = PydanticBlock(**block_v1.dict()) + block_v2 = PydanticBlock(**block_v1.model_dump()) block_v2.value = "v2" block_manager.create_or_update_block(block_v2, actor=default_user) block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) # 3) Update to "v3", checkpoint => seq=3 - block_v3 = PydanticBlock(**block_v1.dict()) + block_v3 = PydanticBlock(**block_v1.model_dump()) block_v3.value = "v3" block_manager.create_or_update_block(block_v3, actor=default_user) block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) @@ -2926,7 +2928,7 @@ def test_checkpoint_deletes_future_states_after_undo(server: SyncServer, default # because the new code truncates future states beyond seq=1. 
# Let's do a new edit: "v1.5" - block_v1_5 = PydanticBlock(**block_undo_2.dict()) + block_v1_5 = PydanticBlock(**block_undo_2.model_dump()) block_v1_5.value = "v1.5" block_manager.create_or_update_block(block_v1_5, actor=default_user) @@ -3000,13 +3002,13 @@ def test_undo_multiple_checkpoints(server: SyncServer, default_user): block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) # Step 2: Update to v2, checkpoint => seq=2 - block_data_v2 = PydanticBlock(**block_v1.dict()) + block_data_v2 = PydanticBlock(**block_v1.model_dump()) block_data_v2.value = "v2" block_manager.create_or_update_block(block_data_v2, actor=default_user) block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) # Step 3: Update to v3, checkpoint => seq=3 - block_data_v3 = PydanticBlock(**block_v1.dict()) + block_data_v3 = PydanticBlock(**block_v1.model_dump()) block_data_v3.value = "v3" block_manager.create_or_update_block(block_data_v3, actor=default_user) block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) @@ -3040,7 +3042,7 @@ def test_undo_concurrency_stale(server: SyncServer, default_user): block_manager.checkpoint_block(block_v1.id, actor=default_user) # 2) update to v2 - block_data_v2 = PydanticBlock(**block_v1.dict()) + block_data_v2 = PydanticBlock(**block_v1.model_dump()) block_data_v2.value = "v2" block_manager.create_or_update_block(block_data_v2, actor=default_user) # checkpoint => seq=2 @@ -3088,13 +3090,13 @@ def test_redo_checkpoint_block(server: SyncServer, default_user): block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) # 2) Update to 'v2'; checkpoint => seq=2 - block_v2 = PydanticBlock(**block_v1.dict()) + block_v2 = PydanticBlock(**block_v1.model_dump()) block_v2.value = "v2" block_manager.create_or_update_block(block_v2, actor=default_user) block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) # 3) Update to 'v3'; checkpoint => seq=3 - block_v3 = 
PydanticBlock(**block_v1.dict()) + block_v3 = PydanticBlock(**block_v1.model_dump()) block_v3.value = "v3" block_manager.create_or_update_block(block_v3, actor=default_user) block_manager.checkpoint_block(block_id=block_v1.id, actor=default_user) @@ -3135,7 +3137,7 @@ def test_redo_at_highest_checkpoint(server: SyncServer, default_user): block_manager.checkpoint_block(b_init.id, actor=default_user) # 2) Another edit => seq=2 - b_next = PydanticBlock(**b_init.dict()) + b_next = PydanticBlock(**b_init.model_dump()) b_next.value = "v2" block_manager.create_or_update_block(b_next, actor=default_user) block_manager.checkpoint_block(b_init.id, actor=default_user) @@ -3159,19 +3161,19 @@ def test_redo_after_multiple_undo(server: SyncServer, default_user): block_manager.checkpoint_block(b_init.id, actor=default_user) # seq=2 - b_v2 = PydanticBlock(**b_init.dict()) + b_v2 = PydanticBlock(**b_init.model_dump()) b_v2.value = "v2" block_manager.create_or_update_block(b_v2, actor=default_user) block_manager.checkpoint_block(b_init.id, actor=default_user) # seq=3 - b_v3 = PydanticBlock(**b_init.dict()) + b_v3 = PydanticBlock(**b_init.model_dump()) b_v3.value = "v3" block_manager.create_or_update_block(b_v3, actor=default_user) block_manager.checkpoint_block(b_init.id, actor=default_user) # seq=4 - b_v4 = PydanticBlock(**b_init.dict()) + b_v4 = PydanticBlock(**b_init.model_dump()) b_v4.value = "v4" block_manager.create_or_update_block(b_v4, actor=default_user) block_manager.checkpoint_block(b_init.id, actor=default_user) @@ -3197,13 +3199,13 @@ def test_redo_concurrency_stale(server: SyncServer, default_user): block_manager.checkpoint_block(block.id, actor=default_user) # 2) Another edit => checkpoint => seq=2 - block_v2 = PydanticBlock(**block.dict()) + block_v2 = PydanticBlock(**block.model_dump()) block_v2.value = "v2" block_manager.create_or_update_block(block_v2, actor=default_user) block_manager.checkpoint_block(block.id, actor=default_user) # 3) Another edit => checkpoint 
=> seq=3 - block_v3 = PydanticBlock(**block.dict()) + block_v3 = PydanticBlock(**block.model_dump()) block_v3.value = "v3" block_manager.create_or_update_block(block_v3, actor=default_user) block_manager.checkpoint_block(block.id, actor=default_user) @@ -4812,7 +4814,7 @@ def test_update_batch_item( server.batch_manager.update_batch_item( item_id=item.id, request_status=JobStatus.completed, - step_status=AgentStepStatus.running, + step_status=AgentStepStatus.resumed, llm_request_response=dummy_successful_response, step_state=updated_step_state, actor=default_user, @@ -4843,3 +4845,216 @@ def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message with pytest.raises(NoResultFound): server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + + +def test_list_running_batches(server, default_user, dummy_beta_message_batch): + server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + status=JobStatus.running, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + running_batches = server.batch_manager.list_running_batches(actor=default_user) + assert len(running_batches) >= 1 + assert all(batch.status == JobStatus.running for batch in running_batches) + + +def test_bulk_update_batch_statuses(server, default_user, dummy_beta_message_batch): + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + status=JobStatus.created, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + server.batch_manager.bulk_update_batch_statuses([(batch.id, JobStatus.completed, dummy_beta_message_batch)]) + + updated = server.batch_manager.get_batch_job_by_id(batch.id, actor=default_user) + assert updated.status == JobStatus.completed + assert updated.latest_polling_response == dummy_beta_message_batch + + +def test_bulk_update_batch_items_results_by_agent( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, 
dummy_successful_response +): + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + item = server.batch_manager.create_batch_item( + batch_id=batch.id, + agent_id=sarah_agent.id, + llm_config=dummy_llm_config, + step_state=dummy_step_state, + actor=default_user, + ) + + server.batch_manager.bulk_update_batch_items_results_by_agent( + [ItemUpdateInfo(batch.id, sarah_agent.id, JobStatus.completed, dummy_successful_response)] + ) + + updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + assert updated.request_status == JobStatus.completed + assert updated.batch_request_result == dummy_successful_response + + +def test_bulk_update_batch_items_step_status_by_agent( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state +): + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + item = server.batch_manager.create_batch_item( + batch_id=batch.id, + agent_id=sarah_agent.id, + llm_config=dummy_llm_config, + step_state=dummy_step_state, + actor=default_user, + ) + + server.batch_manager.bulk_update_batch_items_step_status_by_agent( + [StepStatusUpdateInfo(batch.id, sarah_agent.id, AgentStepStatus.resumed)] + ) + + updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + assert updated.step_status == AgentStepStatus.resumed + + +def test_list_batch_items_limit_and_filter(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + for _ in range(3): + server.batch_manager.create_batch_item( + batch_id=batch.id, + agent_id=sarah_agent.id, + 
llm_config=dummy_llm_config, + step_state=dummy_step_state, + actor=default_user, + ) + + all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user) + limited_items = server.batch_manager.list_batch_items(batch_id=batch.id, limit=2, actor=default_user) + + assert len(all_items) >= 3 + assert len(limited_items) == 2 + + +def test_bulk_update_batch_items_request_status_by_agent( + server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state +): + # Create a batch job + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + # Create a batch item + item = server.batch_manager.create_batch_item( + batch_id=batch.id, + agent_id=sarah_agent.id, + llm_config=dummy_llm_config, + step_state=dummy_step_state, + actor=default_user, + ) + + # Update the request status using the bulk update method + server.batch_manager.bulk_update_batch_items_request_status_by_agent( + [RequestStatusUpdateInfo(batch.id, sarah_agent.id, JobStatus.expired)] + ) + + # Verify the update was applied + updated = server.batch_manager.get_batch_item_by_id(item.id, actor=default_user) + assert updated.request_status == JobStatus.expired + + +def test_bulk_update_nonexistent_items(server, default_user, dummy_beta_message_batch, dummy_successful_response): + # Create a batch job + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + # Attempt to update non-existent items should not raise errors + + # Test with the direct bulk_update_batch_items method + nonexistent_pairs = [(batch.id, "nonexistent-agent-id")] + nonexistent_updates = [{"request_status": JobStatus.expired}] + + # This should not raise an error, just silently skip non-existent items + server.batch_manager.bulk_update_batch_items(nonexistent_pairs, 
nonexistent_updates) + + # Test with higher-level methods + # Results by agent + server.batch_manager.bulk_update_batch_items_results_by_agent( + [ItemUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired, dummy_successful_response)] + ) + + # Step status by agent + server.batch_manager.bulk_update_batch_items_step_status_by_agent( + [StepStatusUpdateInfo(batch.id, "nonexistent-agent-id", AgentStepStatus.resumed)] + ) + + # Request status by agent + server.batch_manager.bulk_update_batch_items_request_status_by_agent( + [RequestStatusUpdateInfo(batch.id, "nonexistent-agent-id", JobStatus.expired)] + ) + + +def test_create_batch_items_bulk(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state): + # Create a batch job + batch = server.batch_manager.create_batch_job( + llm_provider=ProviderType.anthropic, + create_batch_response=dummy_beta_message_batch, + actor=default_user, + ) + + # Prepare data for multiple batch items + batch_items = [] + agent_ids = [sarah_agent.id, sarah_agent.id, sarah_agent.id] # Using the same agent for simplicity + + for agent_id in agent_ids: + batch_item = LLMBatchItem( + batch_id=batch.id, + agent_id=agent_id, + llm_config=dummy_llm_config, + request_status=JobStatus.created, + step_status=AgentStepStatus.paused, + step_state=dummy_step_state, + ) + batch_items.append(batch_item) + + # Call the bulk create function + created_items = server.batch_manager.create_batch_items_bulk(batch_items, actor=default_user) + + # Verify the correct number of items were created + assert len(created_items) == len(agent_ids) + + # Verify each item has expected properties + for item in created_items: + assert item.id.startswith("batch_item-") + assert item.batch_id == batch.id + assert item.agent_id in agent_ids + assert item.llm_config == dummy_llm_config + assert item.request_status == JobStatus.created + assert item.step_status == AgentStepStatus.paused + assert item.step_state == dummy_step_state + + # 
Verify items can be retrieved from the database + all_items = server.batch_manager.list_batch_items(batch_id=batch.id, actor=default_user) + assert len(all_items) >= len(agent_ids) + + # Verify the IDs of created items match what's in the database + created_ids = [item.id for item in created_items] + for item_id in created_ids: + fetched = server.batch_manager.get_batch_item_by_id(item_id, actor=default_user) + assert fetched.id in created_ids