feat: offload jinja to threadpool LET-3615 (#3787)

This commit is contained in:
cthomas
2025-08-06 16:34:28 -07:00
committed by GitHub
parent fd9dc0c1f8
commit 9a2caeb0bc
7 changed files with 25 additions and 11 deletions

View File

@@ -122,7 +122,7 @@ class BaseAgent(ABC):
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
# generate just the memory string with current state for comparison
curr_memory_str = await agent_state.memory.compile_async(
curr_memory_str = await agent_state.memory.compile_in_thread_async(
tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open
)
new_dynamic_section = extract_dynamic_section(curr_memory_str)

View File

@@ -1,3 +1,4 @@
import asyncio
import logging
from typing import TYPE_CHECKING, List, Optional
@@ -142,11 +143,11 @@ class Memory(BaseModel, validate_assignment=True):
"""
try:
# Validate Jinja2 syntax by constructing the template
Template(prompt_template, enable_async=True)
Template(prompt_template)
# Validate compatibility with current memory structure by rendering the template
template = Template(prompt_template, enable_async=True)
await template.render_async(blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None)
template = Template(prompt_template)
await asyncio.to_thread(template.render, blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None)
# If we get here, the template is valid and compatible
self.prompt_template = prompt_template
@@ -189,6 +190,11 @@ class Memory(BaseModel, validate_assignment=True):
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
@trace_method
async def compile_in_thread_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
    """Run ``self.compile`` on a worker thread and return what it produces.

    The three keyword arguments are forwarded unchanged to ``self.compile``.
    The call is dispatched with ``asyncio.to_thread``, so this coroutine
    awaits the result instead of running the compile step inline.
    """
    # Collect the pass-through arguments once so the dispatch line stays short.
    compile_kwargs = {
        "tool_usage_rules": tool_usage_rules,
        "sources": sources,
        "max_files_open": max_files_open,
    }
    compiled = await asyncio.to_thread(self.compile, **compile_kwargs)
    return compiled
def list_block_labels(self) -> List[str]:
"""Return a list of the block names held inside the memory object"""
# return list(self.memory.keys())

View File

@@ -1741,7 +1741,7 @@ class AgentManager:
# note: we only update the system prompt if the core memory is changed
# this means that the archival/recall memory statistics may be somewhat out of date
curr_memory_str = await agent_state.memory.compile_async(
curr_memory_str = await agent_state.memory.compile_in_thread_async(
sources=agent_state.sources,
tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
max_files_open=agent_state.max_files_open,
@@ -1928,7 +1928,7 @@ class AgentManager:
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
new_memory_str = await new_memory.compile_async(
new_memory_str = await new_memory.compile_in_thread_async(
sources=agent_state.sources,
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
max_files_open=agent_state.max_files_open,

View File

@@ -443,7 +443,7 @@ async def compile_system_message_async(
timezone=timezone,
)
memory_with_sources = await in_context_memory.compile_async(
memory_with_sources = await in_context_memory.compile_in_thread_async(
tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
)
full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string

View File

@@ -42,7 +42,7 @@ class SandboxToolExecutor(ToolExecutor):
# Store original memory state
if agent_state:
orig_memory_str = await agent_state.memory.compile_async()
orig_memory_str = await agent_state.memory.compile_in_thread_async()
else:
orig_memory_str = None
@@ -73,7 +73,7 @@ class SandboxToolExecutor(ToolExecutor):
# Verify memory integrity
if agent_state:
new_memory_str = await agent_state.memory.compile_async()
new_memory_str = await agent_state.memory.compile_in_thread_async()
assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool"
# Update agent memory if needed

View File

@@ -80,7 +80,7 @@ class AsyncToolSandboxBase(ABC):
Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool,
then base64-encode/pickle the result. Runs a jinja2 template constructing the python file.
"""
from letta.templates.template_helper import render_template_async
from letta.templates.template_helper import render_template_in_thread
# Select the appropriate template based on whether the function is async
TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"
@@ -107,7 +107,7 @@ class AsyncToolSandboxBase(ABC):
agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
return await render_template_async(
return await render_template_in_thread(
TEMPLATE_NAME,
future_import=future_import,
inject_agent_state=self.inject_agent_state,

View File

@@ -1,3 +1,4 @@
import asyncio
import os
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
@@ -38,6 +39,13 @@ async def render_template_async(template_name: str, **kwargs):
return await template.render_async(**kwargs)
@trace_method
async def render_template_in_thread(template_name: str, **kwargs):
    """Render the template named ``template_name`` on a worker thread.

    The template is looked up through ``jinja_env.get_template``; its
    ``render`` method is then run via ``asyncio.to_thread`` with ``kwargs``
    as the render context, and the rendered result is returned.
    """
    loaded_template = jinja_env.get_template(template_name)
    rendered = await asyncio.to_thread(loaded_template.render, **kwargs)
    return rendered
@trace_method
async def render_string_async(template_string: str, **kwargs):
"""Asynchronously render a template from a string"""