feat: asyncify jinja templates (#3580)
This commit is contained in:
@@ -122,7 +122,7 @@ class BaseAgent(ABC):
|
||||
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
|
||||
|
||||
# generate just the memory string with current state for comparison
|
||||
curr_memory_str = agent_state.memory.compile(
|
||||
curr_memory_str = await agent_state.memory.compile_async(
|
||||
tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open
|
||||
)
|
||||
new_dynamic_section = extract_dynamic_section(curr_memory_str)
|
||||
|
||||
@@ -133,6 +133,25 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
async def set_prompt_template_async(self, prompt_template: str):
|
||||
"""
|
||||
Async version of set_prompt_template that doesn't block the event loop.
|
||||
"""
|
||||
try:
|
||||
# Validate Jinja2 syntax with async enabled
|
||||
Template(prompt_template, enable_async=True)
|
||||
|
||||
# Validate compatibility with current memory structure - use async rendering
|
||||
template = Template(prompt_template, enable_async=True)
|
||||
await template.render_async(blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None)
|
||||
|
||||
# If we get here, the template is valid and compatible
|
||||
self.prompt_template = prompt_template
|
||||
except TemplateSyntaxError as e:
|
||||
raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
|
||||
"""Generate a string representation of the memory in-context using the Jinja2 template"""
|
||||
try:
|
||||
@@ -149,6 +168,22 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
async def compile_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
|
||||
"""Async version of compile that doesn't block the event loop"""
|
||||
try:
|
||||
template = Template(self.prompt_template, enable_async=True)
|
||||
return await template.render_async(
|
||||
blocks=self.blocks,
|
||||
file_blocks=self.file_blocks,
|
||||
tool_usage_rules=tool_usage_rules,
|
||||
sources=sources,
|
||||
max_files_open=max_files_open,
|
||||
)
|
||||
except TemplateSyntaxError as e:
|
||||
raise ValueError(f"Invalid Jinja2 template syntax: {str(e)}")
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
def list_block_labels(self) -> List[str]:
|
||||
"""Return a list of the block names held inside the memory object"""
|
||||
# return list(self.memory.keys())
|
||||
|
||||
@@ -1646,7 +1646,7 @@ class AgentManager:
|
||||
|
||||
# note: we only update the system prompt if the core memory is changed
|
||||
# this means that the archival/recall memory statistics may be someout out of date
|
||||
curr_memory_str = agent_state.memory.compile(
|
||||
curr_memory_str = await agent_state.memory.compile_async(
|
||||
sources=agent_state.sources,
|
||||
tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
|
||||
max_files_open=agent_state.max_files_open,
|
||||
@@ -1836,14 +1836,12 @@ class AgentManager:
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
|
||||
system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||||
temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
if (
|
||||
new_memory.compile(
|
||||
sources=agent_state.sources,
|
||||
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
|
||||
max_files_open=agent_state.max_files_open,
|
||||
)
|
||||
not in system_message.content[0].text
|
||||
):
|
||||
new_memory_str = await new_memory.compile_async(
|
||||
sources=agent_state.sources,
|
||||
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
|
||||
max_files_open=agent_state.max_files_open,
|
||||
)
|
||||
if new_memory_str not in system_message.content[0].text:
|
||||
# update the blocks (LRW) in the DB
|
||||
for label in agent_state.memory.list_block_labels():
|
||||
updated_value = new_memory.get_block(label).value
|
||||
|
||||
@@ -36,7 +36,10 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
) -> ToolExecutionResult:
|
||||
|
||||
# Store original memory state
|
||||
orig_memory_str = agent_state.memory.compile() if agent_state else None
|
||||
if agent_state:
|
||||
orig_memory_str = await agent_state.memory.compile_async()
|
||||
else:
|
||||
orig_memory_str = None
|
||||
|
||||
try:
|
||||
# Prepare function arguments
|
||||
@@ -58,7 +61,8 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
|
||||
# Verify memory integrity
|
||||
if agent_state:
|
||||
assert orig_memory_str == agent_state.memory.compile(), "Memory should not be modified in a sandbox tool"
|
||||
new_memory_str = await agent_state.memory.compile_async()
|
||||
assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool"
|
||||
|
||||
# Update agent memory if needed
|
||||
if tool_execution_result.agent_state is not None:
|
||||
|
||||
@@ -74,12 +74,12 @@ class AsyncToolSandboxBase(ABC):
|
||||
"""
|
||||
raise NotImplementedError
|
||||
|
||||
def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
|
||||
async def generate_execution_script(self, agent_state: Optional[AgentState], wrap_print_with_markers: bool = False) -> str:
|
||||
"""
|
||||
Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool,
|
||||
then base64-encode/pickle the result. Runs a jinja2 template constructing the python file.
|
||||
"""
|
||||
from letta.templates.template_helper import render_template
|
||||
from letta.templates.template_helper import render_template_async
|
||||
|
||||
# Select the appropriate template based on whether the function is async
|
||||
TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"
|
||||
@@ -106,7 +106,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
|
||||
agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
|
||||
|
||||
return render_template(
|
||||
return await render_template_async(
|
||||
TEMPLATE_NAME,
|
||||
future_import=future_import,
|
||||
inject_agent_state=self.inject_agent_state,
|
||||
|
||||
@@ -92,7 +92,7 @@ class AsyncToolSandboxE2B(AsyncToolSandboxBase):
|
||||
# Finally, get any that are passed explicitly into the `run` function call
|
||||
if additional_env_vars:
|
||||
env_vars.update(additional_env_vars)
|
||||
code = self.generate_execution_script(agent_state=agent_state)
|
||||
code = await self.generate_execution_script(agent_state=agent_state)
|
||||
|
||||
try:
|
||||
log_event(
|
||||
|
||||
@@ -99,8 +99,8 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
|
||||
# Make sure sandbox directory exists
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
|
||||
if not os.path.exists(sandbox_dir) or not os.path.isdir(sandbox_dir):
|
||||
os.makedirs(sandbox_dir)
|
||||
if not await asyncio.to_thread(lambda: os.path.exists(sandbox_dir) and os.path.isdir(sandbox_dir)):
|
||||
await asyncio.to_thread(os.makedirs, sandbox_dir)
|
||||
|
||||
# If using a virtual environment, ensure it's prepared in parallel
|
||||
venv_preparation_task = None
|
||||
@@ -109,11 +109,18 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
venv_preparation_task = asyncio.create_task(self._prepare_venv(local_configs, venv_path, env))
|
||||
|
||||
# Generate and write execution script (always with markers, since we rely on stdout)
|
||||
with tempfile.NamedTemporaryFile(mode="w", dir=sandbox_dir, suffix=".py", delete=False) as temp_file:
|
||||
code = self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
|
||||
temp_file.write(code)
|
||||
temp_file.flush()
|
||||
temp_file_path = temp_file.name
|
||||
code = await self.generate_execution_script(agent_state=agent_state, wrap_print_with_markers=True)
|
||||
|
||||
async def write_temp_file(dir, content):
|
||||
def _write():
|
||||
with tempfile.NamedTemporaryFile(mode="w", dir=dir, suffix=".py", delete=False) as temp_file:
|
||||
temp_file.write(content)
|
||||
temp_file.flush()
|
||||
return temp_file.name
|
||||
|
||||
return await asyncio.to_thread(_write)
|
||||
|
||||
temp_file_path = await write_temp_file(sandbox_dir, code)
|
||||
|
||||
try:
|
||||
# If we started a venv preparation task, wait for it to complete
|
||||
@@ -159,14 +166,14 @@ class AsyncToolSandboxLocal(AsyncToolSandboxBase):
|
||||
from letta.settings import settings
|
||||
|
||||
if not settings.debug:
|
||||
os.remove(temp_file_path)
|
||||
await asyncio.to_thread(os.remove, temp_file_path)
|
||||
|
||||
@trace_method
|
||||
async def _prepare_venv(self, local_configs, venv_path: str, env: Dict[str, str]):
|
||||
"""
|
||||
Prepare virtual environment asynchronously (in a background thread).
|
||||
"""
|
||||
if self.force_recreate_venv or not os.path.isdir(venv_path):
|
||||
if self.force_recreate_venv or not await asyncio.to_thread(os.path.isdir, venv_path):
|
||||
sandbox_dir = os.path.expanduser(local_configs.sandbox_dir)
|
||||
log_event(name="start create_venv_for_local_sandbox", attributes={"venv_path": venv_path})
|
||||
await asyncio.to_thread(
|
||||
|
||||
@@ -1,8 +1,10 @@
|
||||
import os
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
|
||||
|
||||
TEMPLATE_DIR = os.path.dirname(__file__)
|
||||
|
||||
# Synchronous environment (for backward compatibility)
|
||||
jinja_env = Environment(
|
||||
loader=FileSystemLoader(TEMPLATE_DIR),
|
||||
undefined=StrictUndefined,
|
||||
@@ -10,7 +12,29 @@ jinja_env = Environment(
|
||||
lstrip_blocks=True,
|
||||
)
|
||||
|
||||
# Async-enabled environment
|
||||
jinja_async_env = Environment(
|
||||
loader=FileSystemLoader(TEMPLATE_DIR),
|
||||
undefined=StrictUndefined,
|
||||
trim_blocks=True,
|
||||
lstrip_blocks=True,
|
||||
enable_async=True, # Enable async support
|
||||
)
|
||||
|
||||
|
||||
def render_template(template_name: str, **kwargs):
|
||||
"""Synchronous template rendering function (kept for backward compatibility)"""
|
||||
template = jinja_env.get_template(template_name)
|
||||
return template.render(**kwargs)
|
||||
|
||||
|
||||
async def render_template_async(template_name: str, **kwargs):
|
||||
"""Asynchronous template rendering function that doesn't block the event loop"""
|
||||
template = jinja_async_env.get_template(template_name)
|
||||
return await template.render_async(**kwargs)
|
||||
|
||||
|
||||
async def render_string_async(template_string: str, **kwargs):
|
||||
"""Asynchronously render a template from a string"""
|
||||
template = Template(template_string, enable_async=True)
|
||||
return await template.render_async(**kwargs)
|
||||
|
||||
@@ -908,11 +908,12 @@ def test_async_function_detection(add_integers_tool, async_add_integers_tool, te
|
||||
assert async_sandbox.is_async_function
|
||||
|
||||
|
||||
def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user):
|
||||
@pytest.mark.asyncio
|
||||
async def test_async_template_selection(add_integers_tool, async_add_integers_tool, test_user):
|
||||
"""Test that correct templates are selected for sync vs async functions"""
|
||||
# Test sync function uses regular template
|
||||
sync_sandbox = AsyncToolSandboxE2B(add_integers_tool.name, {}, test_user, tool_object=add_integers_tool)
|
||||
sync_script = sync_sandbox.generate_execution_script(agent_state=None)
|
||||
sync_script = await sync_sandbox.generate_execution_script(agent_state=None)
|
||||
print("=== SYNC SCRIPT ===")
|
||||
print(sync_script)
|
||||
print("=== END SYNC SCRIPT ===")
|
||||
@@ -921,7 +922,7 @@ def test_async_template_selection(add_integers_tool, async_add_integers_tool, te
|
||||
|
||||
# Test async function uses async template
|
||||
async_sandbox = AsyncToolSandboxE2B(async_add_integers_tool.name, {}, test_user, tool_object=async_add_integers_tool)
|
||||
async_script = async_sandbox.generate_execution_script(agent_state=None)
|
||||
async_script = await async_sandbox.generate_execution_script(agent_state=None)
|
||||
print("=== ASYNC SCRIPT ===")
|
||||
print(async_script)
|
||||
print("=== END ASYNC SCRIPT ===")
|
||||
|
||||
Reference in New Issue
Block a user