From 68fbcf33d8f9bdd1176f28c1effae328ffc840cc Mon Sep 17 00:00:00 2001
From: Matthew Zhou
Date: Thu, 10 Apr 2025 10:19:06 -0700
Subject: [PATCH] feat: Finish `step_until_request` in new batch agent loop
 (#1656)

---
 letta/agents/letta_agent_batch.py             | 164 ++++++++++++++++++
 letta/llm_api/llm_client_base.py              |   8 +-
 letta/schemas/letta_request.py                |   4 +
 letta/schemas/letta_response.py               |  10 +-
 letta/services/agent_manager.py               |   3 +
 letta/services/llm_batch_manager.py           |   2 +-
 ...> integration_test_batch_api_cron_jobs.py} |   2 +-
 tests/test_managers.py                        |  10 +-
 8 files changed, 193 insertions(+), 10 deletions(-)
 create mode 100644 letta/agents/letta_agent_batch.py
 rename tests/{integration_test_batch_api.py => integration_test_batch_api_cron_jobs.py} (99%)

diff --git a/letta/agents/letta_agent_batch.py b/letta/agents/letta_agent_batch.py
new file mode 100644
index 00000000..0dca824f
--- /dev/null
+++ b/letta/agents/letta_agent_batch.py
@@ -0,0 +1,164 @@
+from typing import Dict, List
+
+from letta.agents.helpers import _prepare_in_context_messages
+from letta.helpers import ToolRulesSolver
+from letta.helpers.datetime_helpers import get_utc_time
+from letta.helpers.tool_execution_helper import enable_strict_mode
+from letta.llm_api.llm_client import LLMClient
+from letta.log import get_logger
+from letta.orm.enums import ToolType
+from letta.schemas.agent import AgentState, AgentStepState
+from letta.schemas.enums import JobStatus, ProviderType
+from letta.schemas.letta_request import LettaBatchRequest
+from letta.schemas.letta_response import LettaBatchResponse
+from letta.schemas.message import Message, MessageCreate, MessageUpdate
+from letta.schemas.user import User
+from letta.services.agent_manager import AgentManager
+from letta.services.block_manager import BlockManager
+from letta.services.helpers.agent_manager_helper import compile_system_message
+from letta.services.llm_batch_manager import LLMBatchManager
+from letta.services.message_manager import MessageManager
+from letta.services.passage_manager import PassageManager
+from letta.utils import united_diff
+
+logger = get_logger(__name__)
+
+
+# TODO: Limitations ->
+# TODO: Only works with anthropic for now
+class LettaAgentBatch:
+
+    def __init__(
+        self,
+        batch_id: str,
+        message_manager: MessageManager,
+        agent_manager: AgentManager,
+        block_manager: BlockManager,
+        passage_manager: PassageManager,
+        batch_manager: LLMBatchManager,
+        actor: User,
+        use_assistant_message: bool = True,
+        max_steps: int = 10,
+    ):
+        self.batch_id = batch_id
+        self.message_manager = message_manager
+        self.agent_manager = agent_manager
+        self.block_manager = block_manager
+        self.passage_manager = passage_manager
+        self.batch_manager = batch_manager
+        self.use_assistant_message = use_assistant_message
+        self.actor = actor
+        self.max_steps = max_steps
+
+    async def step_until_request(
+        self, batch_requests: List[LettaBatchRequest], agent_step_state_mapping: Dict[str, AgentStepState]
+    ) -> LettaBatchResponse:
+        agent_messages_mapping: Dict[str, List[Message]] = {}
+        agent_tools_mapping: Dict[str, List[dict]] = {}
+        agent_states = []
+
+        for batch_request in batch_requests:
+            agent_id = batch_request.agent_id
+            agent_state = self.agent_manager.get_agent_by_id(agent_id)
+            agent_states.append(agent_state)
+            agent_messages_mapping[agent_id] = self.get_in_context_messages_per_agent(
+                agent_state=agent_state, input_messages=batch_request.messages
+            )
+            agent_tools_mapping[agent_id] = self.prepare_tools_per_agent(
+                agent_state,
+                agent_step_state_mapping.get(agent_id).tool_rules_solver
+            )
+
+        # TODO: This is a hack, because the LLM client expects an LLM config,
+        # TODO: but that doesn't really work in batch land
+        # TODO: @caren will factor this out
+        llm_client = LLMClient.create(
+            llm_config=agent_states[0].llm_config,
+            put_inner_thoughts_first=True,
+        )
+        agent_llm_config_mapping = {agent_state.id: agent_state.llm_config for agent_state in agent_states}
+        batch_response = await llm_client.send_llm_batch_request_async(
+            agent_messages_mapping=agent_messages_mapping,
+            agent_tools_mapping=agent_tools_mapping,
+            agent_llm_config_mapping=agent_llm_config_mapping,
+        )
+
+        # Write the response into the jobs table, where it will get picked up by the next cron run
+        batch_job = self.batch_manager.create_batch_job(
+            llm_provider=ProviderType.anthropic,  # TODO: Expand to more
+            create_batch_response=batch_response,
+            actor=self.actor,
+            status=JobStatus.running,
+        )
+
+        # TODO: Make this much more efficient by doing creates in bulk
+        for agent_state in agent_states:
+            agent_step_state = agent_step_state_mapping.get(agent_state.id)
+            self.batch_manager.create_batch_item(
+                batch_id=batch_job.id,
+                agent_id=agent_state.id,
+                llm_config=agent_state.llm_config,
+                actor=self.actor,
+                step_state=agent_step_state,
+            )
+
+        return LettaBatchResponse(
+            batch_id=batch_job.id, status=batch_job.status, last_polled_at=batch_job.last_polled_at, created_at=batch_job.created_at
+        )
+
+    async def resume_step_after_request(self, batch_id: str):
+        pass
+
+    def prepare_tools_per_agent(self, agent_state: AgentState, tool_rules_solver: ToolRulesSolver) -> List[dict]:
+        tools = [t for t in agent_state.tools if t.tool_type in {ToolType.CUSTOM, ToolType.LETTA_CORE, ToolType.LETTA_MEMORY_CORE}]
+        valid_tool_names = tool_rules_solver.get_allowed_tool_names(available_tools=set([t.name for t in tools]))
+        return [enable_strict_mode(t.json_schema) for t in tools if t.name in set(valid_tool_names)]
+
+    def get_in_context_messages_per_agent(self, agent_state: AgentState, input_messages: List[MessageCreate]) -> List[Message]:
+        current_in_context_messages, new_in_context_messages = _prepare_in_context_messages(
+            input_messages, agent_state, self.message_manager, self.actor
+        )
+
+        in_context_messages = self._rebuild_memory(current_in_context_messages + new_in_context_messages, agent_state)
+        return in_context_messages
+
+    # TODO: Make this a bulk function
+    def _rebuild_memory(self, in_context_messages: List[Message], agent_state: AgentState) -> List[Message]:
+        self.agent_manager.refresh_memory(agent_state=agent_state, actor=self.actor)
+
+        # TODO: This is a pretty brittle pattern established all over our code, need to get rid of this
+        curr_system_message = in_context_messages[0]
+        curr_memory_str = agent_state.memory.compile()
+        curr_system_message_text = curr_system_message.content[0].text
+        if curr_memory_str in curr_system_message_text:
+            # NOTE: could this cause issues if a block is removed?
+            # (substring match would still work)
+            logger.debug(
+                f"Memory hasn't changed for agent id={agent_state.id} and actor=({self.actor.id}, {self.actor.name}), skipping system prompt rebuild"
+            )
+            return in_context_messages
+
+        memory_edit_timestamp = get_utc_time()
+
+        num_messages = self.message_manager.size(actor=self.actor, agent_id=agent_state.id)
+        num_archival_memories = self.passage_manager.size(actor=self.actor, agent_id=agent_state.id)
+
+        new_system_message_str = compile_system_message(
+            system_prompt=agent_state.system,
+            in_context_memory=agent_state.memory,
+            in_context_memory_last_edit=memory_edit_timestamp,
+            previous_message_count=num_messages,
+            archival_memory_size=num_archival_memories,
+        )
+
+        diff = united_diff(curr_system_message_text, new_system_message_str)
+        if len(diff) > 0:
+            logger.debug(f"Rebuilding system with new memory...\nDiff:\n{diff}")
+
+            new_system_message = self.message_manager.update_message_by_id(
+                curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
+            )
+
+            # Skip pulling down the agent's memory again to save on a db call
+            return [new_system_message] + in_context_messages[1:]
+
+        else:
+            return in_context_messages
diff --git a/letta/llm_api/llm_client_base.py b/letta/llm_api/llm_client_base.py
index bc6f5be5..12cf2fec 100644
--- a/letta/llm_api/llm_client_base.py
+++ b/letta/llm_api/llm_client_base.py
@@ -1,6 +1,7 @@
 from abc import abstractmethod
 from typing import Dict, List, Optional, Union
 
+from anthropic.types.beta.messages import BetaMessageBatch
 from openai import AsyncStream, Stream
 from openai.types.chat.chat_completion_chunk import ChatCompletionChunk
 
@@ -80,8 +81,11 @@ class LLMClientBase:
         return self.convert_response_to_chat_completion(response_data, messages)
 
     async def send_llm_batch_request_async(
-        self, agent_messages_mapping: Dict[str, List[Message]], agent_tools_mapping: Dict[str, List[dict]]
-    ):
+        self,
+        agent_messages_mapping: Dict[str, List[Message]],
+        agent_tools_mapping: Dict[str, List[dict]],
+        agent_llm_config_mapping: Dict[str, LLMConfig],
+    ) -> Union[BetaMessageBatch]:
         raise NotImplementedError
 
     @abstractmethod
diff --git a/letta/schemas/letta_request.py b/letta/schemas/letta_request.py
index 2547fe68..4aa66e62 100644
--- a/letta/schemas/letta_request.py
+++ b/letta/schemas/letta_request.py
@@ -27,3 +27,7 @@ class LettaStreamingRequest(LettaRequest):
         default=False,
         description="Flag to determine if individual tokens should be streamed. Set to True for token streaming (requires stream_steps = True).",
     )
+
+
+class LettaBatchRequest(LettaRequest):
+    agent_id: str = Field(..., description="The ID of the agent to send this batch request for.")
diff --git a/letta/schemas/letta_response.py b/letta/schemas/letta_response.py
index acdf1265..9a179261 100644
--- a/letta/schemas/letta_response.py
+++ b/letta/schemas/letta_response.py
@@ -1,12 +1,13 @@
 import html
 import json
 import re
+from datetime import datetime
 from typing import List, Union
 
 from pydantic import BaseModel, Field
 
 from letta.helpers.json_helpers import json_dumps
-from letta.schemas.enums import MessageStreamStatus
+from letta.schemas.enums import JobStatus, MessageStreamStatus
 from letta.schemas.letta_message import LettaMessage, LettaMessageUnion
 from letta.schemas.usage import LettaUsageStatistics
 
@@ -165,3 +166,10 @@ class LettaResponse(BaseModel):
 
 # The streaming response is either [DONE], [DONE_STEP], [DONE], an error, or a LettaMessage
 LettaStreamingResponse = Union[LettaMessage, MessageStreamStatus, LettaUsageStatistics]
+
+
+class LettaBatchResponse(BaseModel):
+    batch_id: str = Field(..., description="A unique identifier for this batch request.")
+    status: JobStatus = Field(..., description="The current status of the batch request.")
+    last_polled_at: datetime = Field(..., description="The timestamp when the batch was last polled for updates.")
+    created_at: datetime = Field(..., description="The timestamp when the batch request was created.")
diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py
index ffbef850..ee320caf 100644
--- a/letta/services/agent_manager.py
+++ b/letta/services/agent_manager.py
@@ -658,6 +658,9 @@ class AgentManager:
         message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
         return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
 
+    # TODO: This is duplicated below
+    # TODO: This is legacy code and should be cleaned up
+    # TODO: A lot of the memory "compilation" should be offset to a separate class
     @enforce_types
     def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
         """Rebuilds the system message with the latest memory object and any shared memory block updates
diff --git a/letta/services/llm_batch_manager.py b/letta/services/llm_batch_manager.py
index 405fd36e..8de38d48 100644
--- a/letta/services/llm_batch_manager.py
+++ b/letta/services/llm_batch_manager.py
@@ -28,7 +28,7 @@ class LLMBatchManager:
         self.session_maker = db_context
 
     @enforce_types
-    def create_batch_request(
+    def create_batch_job(
         self,
         llm_provider: ProviderType,
         create_batch_response: BetaMessageBatch,
diff --git a/tests/integration_test_batch_api.py b/tests/integration_test_batch_api_cron_jobs.py
similarity index 99%
rename from tests/integration_test_batch_api.py
rename to tests/integration_test_batch_api_cron_jobs.py
index 576a9a6c..7c0322e4 100644
--- a/tests/integration_test_batch_api.py
+++ b/tests/integration_test_batch_api_cron_jobs.py
@@ -147,7 +147,7 @@ def create_test_agent(client, name, model="anthropic/claude-3-5-sonnet-20241022"
 
 def create_test_batch_job(server, batch_response, default_user):
     """Create a test batch job with the given batch response."""
-    return server.batch_manager.create_batch_request(
+    return server.batch_manager.create_batch_job(
         llm_provider=ProviderType.anthropic,
         create_batch_response=batch_response,
         actor=default_user,
diff --git a/tests/test_managers.py b/tests/test_managers.py
index 0bd509aa..01835ce7 100644
--- a/tests/test_managers.py
+++ b/tests/test_managers.py
@@ -4680,7 +4680,7 @@ def test_list_tags(server: SyncServer, default_user, default_organization):
 
 
 def test_create_and_get_batch_request(server, default_user, dummy_beta_message_batch):
-    batch = server.batch_manager.create_batch_request(
+    batch = server.batch_manager.create_batch_job(
         llm_provider=ProviderType.anthropic,
         status=JobStatus.created,
         create_batch_response=dummy_beta_message_batch,
@@ -4693,7 +4693,7 @@ def test_create_and_get_batch_request(server, default_user, dummy_beta_message_b
 
 
 def test_update_batch_status(server, default_user, dummy_beta_message_batch):
-    batch = server.batch_manager.create_batch_request(
+    batch = server.batch_manager.create_batch_job(
         llm_provider=ProviderType.anthropic,
         status=JobStatus.created,
         create_batch_response=dummy_beta_message_batch,
@@ -4715,7 +4715,7 @@ def test_update_batch_status(server, default_user, dummy_beta_message_batch):
 
 
 def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
-    batch = server.batch_manager.create_batch_request(
+    batch = server.batch_manager.create_batch_job(
         llm_provider=ProviderType.anthropic,
         status=JobStatus.created,
         create_batch_response=dummy_beta_message_batch,
@@ -4741,7 +4741,7 @@ def test_create_and_get_batch_item(server, default_user, sarah_agent, dummy_beta
 def test_update_batch_item(
     server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state, dummy_successful_response
 ):
-    batch = server.batch_manager.create_batch_request(
+    batch = server.batch_manager.create_batch_job(
         llm_provider=ProviderType.anthropic,
         status=JobStatus.created,
         create_batch_response=dummy_beta_message_batch,
@@ -4773,7 +4773,7 @@ def test_update_batch_item(
 
 
 def test_delete_batch_item(server, default_user, sarah_agent, dummy_beta_message_batch, dummy_llm_config, dummy_step_state):
-    batch = server.batch_manager.create_batch_request(
+    batch = server.batch_manager.create_batch_job(
         llm_provider=ProviderType.anthropic,
         status=JobStatus.created,
         create_batch_response=dummy_beta_message_batch,
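
--
Usage sketch (illustrative, not part of the patch): one way the new `step_until_request`
entrypoint might be driven, assuming already-constructed service managers, an actor, and an
AgentStepState. The `managers`, `actor`, and `step_state` names below are hypothetical
placeholders; only the signatures introduced in this patch are relied on.

from letta.agents.letta_agent_batch import LettaAgentBatch
from letta.schemas.letta_request import LettaBatchRequest
from letta.schemas.message import MessageCreate


async def run_batch_step(managers, actor, agent_id, step_state):
    # Assemble the batch agent from existing service managers (hypothetical wiring).
    batch_agent = LettaAgentBatch(
        batch_id="batch-demo",
        message_manager=managers.message_manager,
        agent_manager=managers.agent_manager,
        block_manager=managers.block_manager,
        passage_manager=managers.passage_manager,
        batch_manager=managers.batch_manager,
        actor=actor,
    )
    # One LettaBatchRequest per agent; messages reuse the standard MessageCreate schema.
    requests = [
        LettaBatchRequest(
            agent_id=agent_id,
            messages=[MessageCreate(role="user", content="Hello from the batch loop")],
        )
    ]
    # Submits the Anthropic batch and records the job plus per-agent items, returning
    # once the provider-side batch has been created (status=running).
    response = await batch_agent.step_until_request(
        batch_requests=requests,
        agent_step_state_mapping={agent_id: step_state},
    )
    return response.batch_id, response.status

The returned LettaBatchResponse carries the batch_id that the cron-driven
resume_step_after_request (still a stub in this patch) is presumably meant to pick up later.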