feat: asyncify jinja templates (#3580)

This commit is contained in:
cthomas
2025-07-26 23:17:24 -07:00
committed by GitHub
parent 29bc80486d
commit f77a259d07
9 changed files with 98 additions and 29 deletions

View File

@@ -122,7 +122,7 @@ class BaseAgent(ABC):
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
# generate just the memory string with current state for comparison
curr_memory_str = agent_state.memory.compile(
curr_memory_str = await agent_state.memory.compile_async(
tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open
)
new_dynamic_section = extract_dynamic_section(curr_memory_str)

View File

@@ -133,6 +133,25 @@ class Memory(BaseModel, validate_assignment=True):
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
async def set_prompt_template_async(self, prompt_template: str):
"""
Async version of set_prompt_template that doesn't block the event loop.
"""
try:
# Validate Jinja2 syntax with async enabled
Template(prompt_template, enable_async=True)
# Validate compatibility with current memory structure - use async rendering
template = Template(prompt_template, enable_async=True)
await template.render_async(blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None)
# If we get here, the template is valid and compatible
self.prompt_template = prompt_template
except TemplateSyntaxError as e:
raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}")
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
"""Generate a string representation of the memory in-context using the Jinja2 template"""
try:
@@ -149,6 +168,22 @@ class Memory(BaseModel, validate_assignment=True):
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
async def compile_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
"""Async version of compile that doesn't block the event loop"""
try:
template = Template(self.prompt_template, enable_async=True)
return await template.render_async(
blocks=self.blocks,
file_blocks=self.file_blocks,
tool_usage_rules=tool_usage_rules,
sources=sources,
max_files_open=max_files_open,
)
except TemplateSyntaxError as e:
raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}")
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
def list_block_labels(self) -> List[str]:
"""Return a list of the block names held inside the memory object"""
# return list(self.memory.keys())

View File

@@ -1646,7 +1646,7 @@ class AgentManager:
# note: we only update the system prompt if the core memory is changed
# this means that the archival/recall memory statistics may be someout out of date
curr_memory_str = agent_state.memory.compile(
curr_memory_str = await agent_state.memory.compile_async(
sources=agent_state.sources,
tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
max_files_open=agent_state.max_files_open,
@@ -1836,14 +1836,12 @@ class AgentManager:
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
if (
new_memory.compile(
sources=agent_state.sources,
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
max_files_open=agent_state.max_files_open,
)
not in system_message.content[0].text
):
new_memory_str = await new_memory.compile_async(
sources=agent_state.sources,
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
max_files_open=agent_state.max_files_open,
)
if new_memory_str not in system_message.content[0].text:
# update the blocks (LRW) in the DB
for label in agent_state.memory.list_block_labels():
updated_value = new_memory.get_block(label).value

View File

@@ -36,7 +36,10 @@ class SandboxToolExecutor(ToolExecutor):
) -> ToolExecutionResult:
# Store original memory state
orig_memory_str = agent_state.memory.compile() if agent_state else None
if agent_state:
orig_memory_str = await agent_state.memory.compile_async()
else:
orig_memory_str = None
try:
# Prepare function arguments
@@ -58,7 +61,8 @@ class SandboxToolExecutor(ToolExecutor):
# Verify memory integrity
if agent_state:
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
new_memory_str = await agent_state.memory.compile_async()
assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool"
# Update agent memory if needed
if tool_execution_result.agent_state is not None:

View File

@@ -74,12 +74,12 @@ class AsyncToolSandboxBase(ABC):
"""
raise NotImplementedError
def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
"""
Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool,
then base64-encode/pickle the result. Runs a jinja2 template constructing the python file.
"""
from letta.templates.template_helper import render_template
from letta.templates.template_helper import render_template_async
# Select the appropriate template based on whether the function is async
TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"
@@ -106,7 +106,7 @@ class AsyncToolSandboxBase(ABC):
agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
return render_template(
return await render_template_async(
TEMPLATE_NAME,
future_import=future_import,
inject_agent_state=self.inject_agent_state,

View File

@@ -92,7 +92,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
# Finally, get any that are passed explicitly into the `run` function call
if additional_env_vars:
env_vars.update(additional_env_vars)
code = self.generate_execution_script(agent_state=agent_state)
code = await self.generate_execution_script(agent_state=agent_state)
try:
log_event(

View File

@@ -99,8 +99,8 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
# Make sure sandbox directory exists
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
if not os.path.exists(sandbox_dir) or not os.path.isdir(sandbox_dir):
os.makedirs(sandbox_dir)
if not await asyncio.to_thread(lambda: os.path.exists(sandbox_dir) and os.path.isdir(sandbox_dir)):
await asyncio.to_thread(os.makedirs, sandbox_dir)
# If using a virtual environment, ensure it's prepared in parallel
venv_preparation_task = None
@@ -109,11 +109,18 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
venv_preparation_task = asyncio.create_task(self._prepare_venv(local_configs, venv_path, env))
# Generate and write execution script (always with markers, since we rely on stdout)
with tempfile.NamedTemporaryFile(mode="w", dir=sandbox_dir, suffix=".py", delete=False) as temp_file:
code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
temp_file.write(code)
temp_file.flush()
temp_file_path = temp_file.name
code = await self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
async def write_temp_file(dir, content):
def _write():
with tempfile.NamedTemporaryFile(mode="w", dir=dir, suffix=".py", delete=False) as temp_file:
temp_file.write(content)
temp_file.flush()
return temp_file.name
return await asyncio.to_thread(_write)
temp_file_path = await write_temp_file(sandbox_dir, code)
try:
# If we started a venv preparation task, wait for it to complete
@@ -159,14 +166,14 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
from letta.settings import settings
if not settings.debug:
os.remove(temp_file_path)
await asyncio.to_thread(os.remove, temp_file_path)
@trace_method
async def _prepare_venv(self, local_configs, venv_path: str, env: Dict[str, str]):
"""
Prepare virtual environment asynchronously (in a background thread).
"""
if self.force_recreate_venv or not os.path.isdir(venv_path):
if self.force_recreate_venv or not await asyncio.to_thread(os.path.isdir, venv_path):
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path})
await asyncio.to_thread(

View File

@@ -1,8 +1,10 @@
import os
from jinja2 import Environment, FileSystemLoader, StrictUndefined
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
TEMPLATE_DIR = os.path.dirname(__file__)
# Synchronous environment (for backward compatibility)
jinja_env = Environment(
loader=FileSystemLoader(TEMPLATE_DIR),
undefined=StrictUndefined,
@@ -10,7 +12,29 @@ jinja_env = Environment(
lstrip_blocks=True,
)
# Async-enabled environment
jinja_async_env = Environment(
loader=FileSystemLoader(TEMPLATE_DIR),
undefined=StrictUndefined,
trim_blocks=True,
lstrip_blocks=True,
enable_async=True, # Enable async support
)
def render_template(template_name: str, **kwargs):
"""Synchronous template rendering function (kept for backward compatibility)"""
template = jinja_env.get_template(template_name)
return template.render(**kwargs)
async def render_template_async(template_name: str, **kwargs):
"""Asynchronous template rendering function that doesn't block the event loop"""
template = jinja_async_env.get_template(template_name)
return await template.render_async(**kwargs)
async def render_string_async(template_string: str, **kwargs):
"""Asynchronously render a template from a string"""
template = Template(template_string, enable_async=True)
return await template.render_async(**kwargs)

View File

@@ -908,11 +908,12 @@ def test_async_function_detection(add_integers_tool, async_add_integers_tool, te
assert async_sandbox.is_async_function
def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user):
@pytest.mark.asyncio
async def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user):
"""Test that correct templates are selected for sync vs async functions"""
# Test sync function uses regular template
sync_sandbox = AsyncToolSandboxE2B(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool)
sync_script = sync_sandbox.generate_execution_script(agent_state=None)
sync_script = await sync_sandbox.generate_execution_script(agent_state=None)
print("=== SYNC SCRIPT ===")
print(sync_script)
print("=== END SYNC SCRIPT ===")
@@ -921,7 +922,7 @@ def test_async_template_selection(add_integers_tool, async_add_integers_tool, te
# Test async function uses async template
async_sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool)
async_script = async_sandbox.generate_execution_script(agent_state=None)
async_script = await async_sandbox.generate_execution_script(agent_state=None)
print("=== ASYNC SCRIPT ===")
print(async_script)
print("=== END ASYNC SCRIPT ===")