From f77a259d07e4d88b28f7ce3caac7a50014d13522 Mon Sep 17 00:00:00 2001 From: cthomas Date: Sat, 26 Jul 2025 23:17:24 -0700 Subject: [PATCH] feat: asyncify jinja templates (#3580) --- letta/agents/base_agent.py | 2 +- letta/schemas/memory.py | 35 +++++++++++++++++++ letta/services/agent_manager.py | 16 ++++----- letta/services/tool_executor/tool_executor.py | 8 +++-- letta/services/tool_sandbox/base.py | 6 ++-- letta/services/tool_sandbox/e2b_sandbox.py | 2 +- letta/services/tool_sandbox/local_sandbox.py | 25 ++++++++----- letta/templates/template_helper.py | 26 +++++++++++++- tests/integration_test_async_tool_sandbox.py | 7 ++-- 9 files changed, 98 insertions(+), 29 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 1e5401b2..325090ec 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -122,7 +122,7 @@ class BaseAgent(ABC): curr_dynamic_section = extract_dynamic_section(curr_system_message_text) # generate just the memory string with current state for comparison - curr_memory_str = agent_state.memory.compile( + curr_memory_str = await agent_state.memory.compile_async( tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open ) new_dynamic_section = extract_dynamic_section(curr_memory_str) diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index c240bded..e59a6437 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -133,6 +133,25 @@ class Memory(BaseModel, validate_assignment=True): except Exception as e: raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + async def set_prompt_template_async(self, prompt_template: str): + """ + Async version of set_prompt_template that doesn't block the event loop. + """ + try: + # Validate Jinja2 syntax with async enabled + Template(prompt_template, enable_async=True) + + # Validate compatibility with current memory structure - use async rendering + template = Template(prompt_template, enable_async=True) + await template.render_async(blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None) + + # If we get here, the template is valid and compatible + self.prompt_template = prompt_template + except TemplateSyntaxError as e: + raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}") + except Exception as e: + raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str: """Generate a string representation of the memory in-context using the Jinja2 template""" try: @@ -149,6 +168,22 @@ class Memory(BaseModel, validate_assignment=True): except Exception as e: raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + async def compile_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str: + """Async version of compile that doesn't block the event loop""" + try: + template = Template(self.prompt_template, enable_async=True) + return await template.render_async( + blocks=self.blocks, + file_blocks=self.file_blocks, + tool_usage_rules=tool_usage_rules, + sources=sources, + max_files_open=max_files_open, + ) + except TemplateSyntaxError as e: + raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}") + except Exception as e: + raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + def list_block_labels(self) -> List[str]: """Return a list of the block names held inside the memory object""" # return list(self.memory.keys()) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index be12684f..fed81a5b 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1646,7 +1646,7 @@ class AgentManager: # note: we only update the system prompt if the core memory is changed # this means that the archival/recall memory statistics may be someout out of date - curr_memory_str = agent_state.memory.compile( + curr_memory_str = await agent_state.memory.compile_async( sources=agent_state.sources, tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(), max_files_open=agent_state.max_files_open, @@ -1836,14 +1836,12 @@ class AgentManager: agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"]) system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor) temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) - if ( - new_memory.compile( - sources=agent_state.sources, - tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(), - max_files_open=agent_state.max_files_open, - ) - not in system_message.content[0].text - ): + new_memory_str = await new_memory.compile_async( + sources=agent_state.sources, + tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(), + max_files_open=agent_state.max_files_open, + ) + if new_memory_str not in system_message.content[0].text: # update the blocks (LRW) in the DB for label in agent_state.memory.list_block_labels(): updated_value = new_memory.get_block(label).value diff --git a/letta/services/tool_executor/tool_executor.py b/letta/services/tool_executor/tool_executor.py index b57d87cb..9476e498 100644 --- a/letta/services/tool_executor/tool_executor.py +++ b/letta/services/tool_executor/tool_executor.py @@ -36,7 +36,10 @@ class SandboxToolExecutor(ToolExecutor): ) -> ToolExecutionResult: # Store original memory state - orig_memory_str = agent_state.memory.compile() if agent_state else None + if agent_state: + orig_memory_str = await agent_state.memory.compile_async() + else: + orig_memory_str = None try: # Prepare function arguments @@ -58,7 +61,8 @@ class SandboxToolExecutor(ToolExecutor): # Verify memory integrity if agent_state: - assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool" + new_memory_str = await agent_state.memory.compile_async() + assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool" # Update agent memory if needed if tool_execution_result.agent_state is not None: diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py index 03083981..68e4d109 100644 --- a/letta/services/tool_sandbox/base.py +++ b/letta/services/tool_sandbox/base.py @@ -74,12 +74,12 @@ class AsyncToolSandboxBase(ABC): """ raise NotImplementedError - def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str: + async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str: """ Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool, then base64-encode/pickle the result. Runs a jinja2 template constructing the python file. """ - from letta.templates.template_helper import render_template + from letta.templates.template_helper import render_template_async # Select the appropriate template based on whether the function is async TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2" @@ -106,7 +106,7 @@ class AsyncToolSandboxBase(ABC): agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None - return render_template( + return await render_template_async( TEMPLATE_NAME, future_import=future_import, inject_agent_state=self.inject_agent_state, diff --git a/letta/services/tool_sandbox/e2b_sandbox.py b/letta/services/tool_sandbox/e2b_sandbox.py index ca7ca907..bdc65e5c 100644 --- a/letta/services/tool_sandbox/e2b_sandbox.py +++ b/letta/services/tool_sandbox/e2b_sandbox.py @@ -92,7 +92,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase): # Finally, get any that are passed explicitly into the `run` function call if additional_env_vars: env_vars.update(additional_env_vars) - code = self.generate_execution_script(agent_state=agent_state) + code = await self.generate_execution_script(agent_state=agent_state) try: log_event( diff --git a/letta/services/tool_sandbox/local_sandbox.py b/letta/services/tool_sandbox/local_sandbox.py index 5056adde..8f2b2871 100644 --- a/letta/services/tool_sandbox/local_sandbox.py +++ b/letta/services/tool_sandbox/local_sandbox.py @@ -99,8 +99,8 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): # Make sure sandbox directory exists sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) - if not os.path.exists(sandbox_dir) or not os.path.isdir(sandbox_dir): - os.makedirs(sandbox_dir) + if not await asyncio.to_thread(lambda: os.path.exists(sandbox_dir) and os.path.isdir(sandbox_dir)): + await asyncio.to_thread(os.makedirs, sandbox_dir) # If using a virtual environment, ensure it's prepared in parallel venv_preparation_task = None @@ -109,11 +109,18 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): venv_preparation_task = asyncio.create_task(self._prepare_venv(local_configs, venv_path, env)) # Generate and write execution script (always with markers, since we rely on stdout) - with tempfile.NamedTemporaryFile(mode="w", dir=sandbox_dir, suffix=".py", delete=False) as temp_file: - code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True) - temp_file.write(code) - temp_file.flush() - temp_file_path = temp_file.name + code = await self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True) + + async def write_temp_file(dir, content): + def _write(): + with tempfile.NamedTemporaryFile(mode="w", dir=dir, suffix=".py", delete=False) as temp_file: + temp_file.write(content) + temp_file.flush() + return temp_file.name + + return await asyncio.to_thread(_write) + + temp_file_path = await write_temp_file(sandbox_dir, code) try: # If we started a venv preparation task, wait for it to complete @@ -159,14 +166,14 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase): from letta.settings import settings if not settings.debug: - os.remove(temp_file_path) + await asyncio.to_thread(os.remove, temp_file_path) @trace_method async def _prepare_venv(self, local_configs, venv_path: str, env: Dict[str, str]): """ Prepare virtual environment asynchronously (in a background thread). """ - if self.force_recreate_venv or not os.path.isdir(venv_path): + if self.force_recreate_venv or not await asyncio.to_thread(os.path.isdir, venv_path): sandbox_dir = os.path.expanduser(local_configs.sandbox_dir) log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path}) await asyncio.to_thread( diff --git a/letta/templates/template_helper.py b/letta/templates/template_helper.py index 0d2359ce..428e2bd2 100644 --- a/letta/templates/template_helper.py +++ b/letta/templates/template_helper.py @@ -1,8 +1,10 @@ import os -from jinja2 import Environment, FileSystemLoader, StrictUndefined +from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template TEMPLATE_DIR = os.path.dirname(__file__) + +# Synchronous environment (for backward compatibility) jinja_env = Environment( loader=FileSystemLoader(TEMPLATE_DIR), undefined=StrictUndefined, @@ -10,7 +12,29 @@ jinja_env = Environment( lstrip_blocks=True, ) +# Async-enabled environment +jinja_async_env = Environment( + loader=FileSystemLoader(TEMPLATE_DIR), + undefined=StrictUndefined, + trim_blocks=True, + lstrip_blocks=True, + enable_async=True, # Enable async support +) + def render_template(template_name: str, **kwargs): + """Synchronous template rendering function (kept for backward compatibility)""" template = jinja_env.get_template(template_name) return template.render(**kwargs) + + +async def render_template_async(template_name: str, **kwargs): + """Asynchronous template rendering function that doesn't block the event loop""" + template = jinja_async_env.get_template(template_name) + return await template.render_async(**kwargs) + + +async def render_string_async(template_string: str, **kwargs): + """Asynchronously render a template from a string""" + template = Template(template_string, enable_async=True) + return await template.render_async(**kwargs) diff --git a/tests/integration_test_async_tool_sandbox.py b/tests/integration_test_async_tool_sandbox.py index d1f599aa..bd01460e 100644 --- a/tests/integration_test_async_tool_sandbox.py +++ b/tests/integration_test_async_tool_sandbox.py @@ -908,11 +908,12 @@ def test_async_function_detection(add_integers_tool, async_add_integers_tool, te assert async_sandbox.is_async_function -def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user): +@pytest.mark.asyncio +async def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user): """Test that correct templates are selected for sync vs async functions""" # Test sync function uses regular template sync_sandbox = AsyncToolSandboxE2B(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool) - sync_script = sync_sandbox.generate_execution_script(agent_state=None) + sync_script = await sync_sandbox.generate_execution_script(agent_state=None) print("=== SYNC SCRIPT ===") print(sync_script) print("=== END SYNC SCRIPT ===") @@ -921,7 +922,7 @@ def test_async_template_selection(add_integers_tool, async_add_integers_tool, te # Test async function uses async template async_sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool) - async_script = async_sandbox.generate_execution_script(agent_state=None) + async_script = await async_sandbox.generate_execution_script(agent_state=None) print("=== ASYNC SCRIPT ===") print(async_script) print("=== END ASYNC SCRIPT ===")