From 13c916afaa9902992e16afd2234ddae4d98fc0da Mon Sep 17 00:00:00 2001 From: cthomas Date: Wed, 6 Aug 2025 16:34:28 -0700 Subject: [PATCH] feat: offload jinja to threadpool LET-3615 (#3787) --- letta/agents/base_agent.py | 2 +- letta/schemas/memory.py | 12 +++++++++--- letta/services/agent_manager.py | 4 ++-- letta/services/helpers/agent_manager_helper.py | 2 +- .../services/tool_executor/sandbox_tool_executor.py | 4 ++-- letta/services/tool_sandbox/base.py | 4 ++-- letta/templates/template_helper.py | 8 ++++++++ 7 files changed, 25 insertions(+), 11 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index f26c6047..d351eb10 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -122,7 +122,7 @@ class BaseAgent(ABC): curr_dynamic_section = extract_dynamic_section(curr_system_message_text) # generate just the memory string with current state for comparison - curr_memory_str = await agent_state.memory.compile_async( + curr_memory_str = await agent_state.memory.compile_in_thread_async( tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open ) new_dynamic_section = extract_dynamic_section(curr_memory_str) diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index 3952a647..802f2292 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -1,3 +1,4 @@ +import asyncio import logging from typing import TYPE_CHECKING, List, Optional @@ -142,11 +143,11 @@ class Memory(BaseModel, validate_assignment=True): """ try: # Validate Jinja2 syntax with async enabled - Template(prompt_template, enable_async=True) + Template(prompt_template) # Validate compatibility with current memory structure - use async rendering - template = Template(prompt_template, enable_async=True) - await template.render_async(blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None) + template = Template(prompt_template) + await asyncio.to_thread(template.render, blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None) # If we get here, the template is valid and compatible self.prompt_template = prompt_template @@ -189,6 +190,11 @@ class Memory(BaseModel, validate_assignment=True): except Exception as e: raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + @trace_method + async def compile_in_thread_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str: + """Compile the memory in a thread""" + return await asyncio.to_thread(self.compile, tool_usage_rules=tool_usage_rules, sources=sources, max_files_open=max_files_open) + def list_block_labels(self) -> List[str]: """Return a list of the block names held inside the memory object""" # return list(self.memory.keys()) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 9ad5e41a..d7eea3a0 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -1741,7 +1741,7 @@ class AgentManager: # note: we only update the system prompt if the core memory is changed # this means that the archival/recall memory statistics may be someout out of date - curr_memory_str = await agent_state.memory.compile_async( + curr_memory_str = await agent_state.memory.compile_in_thread_async( sources=agent_state.sources, tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(), max_files_open=agent_state.max_files_open, @@ -1928,7 +1928,7 @@ class AgentManager: agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"]) system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor) temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules) - new_memory_str = await new_memory.compile_async( + new_memory_str = await new_memory.compile_in_thread_async( sources=agent_state.sources, tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(), max_files_open=agent_state.max_files_open, diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index 092b34ff..a97509f4 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -443,7 +443,7 @@ async def compile_system_message_async( timezone=timezone, ) - memory_with_sources = await in_context_memory.compile_async( + memory_with_sources = await in_context_memory.compile_in_thread_async( tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open ) full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string diff --git a/letta/services/tool_executor/sandbox_tool_executor.py b/letta/services/tool_executor/sandbox_tool_executor.py index 5189b5e1..816b832c 100644 --- a/letta/services/tool_executor/sandbox_tool_executor.py +++ b/letta/services/tool_executor/sandbox_tool_executor.py @@ -42,7 +42,7 @@ class SandboxToolExecutor(ToolExecutor): # Store original memory state if agent_state: - orig_memory_str = await agent_state.memory.compile_async() + orig_memory_str = await agent_state.memory.compile_in_thread_async() else: orig_memory_str = None @@ -73,7 +73,7 @@ class SandboxToolExecutor(ToolExecutor): # Verify memory integrity if agent_state: - new_memory_str = await agent_state.memory.compile_async() + new_memory_str = await agent_state.memory.compile_in_thread_async() assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool" # Update agent memory if needed diff --git a/letta/services/tool_sandbox/base.py b/letta/services/tool_sandbox/base.py index fb953772..97f8514f 100644 --- a/letta/services/tool_sandbox/base.py +++ b/letta/services/tool_sandbox/base.py @@ -80,7 +80,7 @@ class AsyncToolSandboxBase(ABC): Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool, then base64-encode/pickle the result. Runs a jinja2 template constructing the python file. """ - from letta.templates.template_helper import render_template_async + from letta.templates.template_helper import render_template_in_thread # Select the appropriate template based on whether the function is async TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2" @@ -107,7 +107,7 @@ class AsyncToolSandboxBase(ABC): agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None - return await render_template_async( + return await render_template_in_thread( TEMPLATE_NAME, future_import=future_import, inject_agent_state=self.inject_agent_state, diff --git a/letta/templates/template_helper.py b/letta/templates/template_helper.py index 54de2b5c..af4463fc 100644 --- a/letta/templates/template_helper.py +++ b/letta/templates/template_helper.py @@ -1,3 +1,4 @@ +import asyncio import os from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template @@ -38,6 +39,13 @@ async def render_template_async(template_name: str, **kwargs): return await template.render_async(**kwargs) +@trace_method +async def render_template_in_thread(template_name: str, **kwargs): + """Asynchronously render a template from a string""" + template = jinja_env.get_template(template_name) + return await asyncio.to_thread(template.render, **kwargs) + + @trace_method async def render_string_async(template_string: str, **kwargs): """Asynchronously render a template from a string"""