feat: offload jinja to threadpool LET-3615 (#3787)
This commit is contained in:
@@ -122,7 +122,7 @@ class BaseAgent(ABC):
|
||||
curr_dynamic_section = extract_dynamic_section(curr_system_message_text)
|
||||
|
||||
# generate just the memory string with current state for comparison
|
||||
curr_memory_str = await agent_state.memory.compile_async(
|
||||
curr_memory_str = await agent_state.memory.compile_in_thread_async(
|
||||
tool_usage_rules=tool_constraint_block, sources=agent_state.sources, max_files_open=agent_state.max_files_open
|
||||
)
|
||||
new_dynamic_section = extract_dynamic_section(curr_memory_str)
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import logging
|
||||
from typing import TYPE_CHECKING, List, Optional
|
||||
|
||||
@@ -142,11 +143,11 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
"""
|
||||
try:
|
||||
# Validate Jinja2 syntax by constructing the template
|
||||
Template(prompt_template, enable_async=True)
|
||||
Template(prompt_template)
|
||||
|
||||
# Validate compatibility with current memory structure by rendering the template
|
||||
template = Template(prompt_template, enable_async=True)
|
||||
await template.render_async(blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None)
|
||||
template = Template(prompt_template)
|
||||
await asyncio.to_thread(template.render, blocks=self.blocks, file_blocks=self.file_blocks, sources=[], max_files_open=None)
|
||||
|
||||
# If we get here, the template is valid and compatible
|
||||
self.prompt_template = prompt_template
|
||||
@@ -189,6 +190,11 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
@trace_method
|
||||
async def compile_in_thread_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
|
||||
"""Run the synchronous `compile` method in a worker thread via `asyncio.to_thread` and return its result; all three arguments are forwarded to `compile` unchanged as keyword arguments."""
|
||||
# NOTE(review): presumably offloaded so template rendering inside `compile` does not block the event loop - confirm against `compile`.
return await asyncio.to_thread(self.compile, tool_usage_rules=tool_usage_rules, sources=sources, max_files_open=max_files_open)
|
||||
|
||||
def list_block_labels(self) -> List[str]:
|
||||
"""Return a list of the block names held inside the memory object"""
|
||||
# return list(self.memory.keys())
|
||||
|
||||
@@ -1741,7 +1741,7 @@ class AgentManager:
|
||||
|
||||
# note: we only update the system prompt if the core memory is changed
|
||||
# this means that the archival/recall memory statistics may be somewhat out of date
|
||||
curr_memory_str = await agent_state.memory.compile_async(
|
||||
curr_memory_str = await agent_state.memory.compile_in_thread_async(
|
||||
sources=agent_state.sources,
|
||||
tool_usage_rules=tool_rules_solver.compile_tool_rule_prompts(),
|
||||
max_files_open=agent_state.max_files_open,
|
||||
@@ -1928,7 +1928,7 @@ class AgentManager:
|
||||
agent_state = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor, include_relationships=["memory", "sources"])
|
||||
system_message = await self.message_manager.get_message_by_id_async(message_id=agent_state.message_ids[0], actor=actor)
|
||||
temp_tool_rules_solver = ToolRulesSolver(agent_state.tool_rules)
|
||||
new_memory_str = await new_memory.compile_async(
|
||||
new_memory_str = await new_memory.compile_in_thread_async(
|
||||
sources=agent_state.sources,
|
||||
tool_usage_rules=temp_tool_rules_solver.compile_tool_rule_prompts(),
|
||||
max_files_open=agent_state.max_files_open,
|
||||
|
||||
@@ -443,7 +443,7 @@ async def compile_system_message_async(
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
memory_with_sources = await in_context_memory.compile_async(
|
||||
memory_with_sources = await in_context_memory.compile_in_thread_async(
|
||||
tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
|
||||
)
|
||||
full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
|
||||
|
||||
@@ -42,7 +42,7 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
|
||||
# Store original memory state
|
||||
if agent_state:
|
||||
orig_memory_str = await agent_state.memory.compile_async()
|
||||
orig_memory_str = await agent_state.memory.compile_in_thread_async()
|
||||
else:
|
||||
orig_memory_str = None
|
||||
|
||||
@@ -73,7 +73,7 @@ class SandboxToolExecutor(ToolExecutor):
|
||||
|
||||
# Verify memory integrity
|
||||
if agent_state:
|
||||
new_memory_str = await agent_state.memory.compile_async()
|
||||
new_memory_str = await agent_state.memory.compile_in_thread_async()
|
||||
assert orig_memory_str == new_memory_str, "Memory should not be modified in a sandbox tool"
|
||||
|
||||
# Update agent memory if needed
|
||||
|
||||
@@ -80,7 +80,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
Generate code to run inside of execution sandbox. Serialize the agent state and arguments, call the tool,
|
||||
then base64-encode/pickle the result. Runs a jinja2 template constructing the python file.
|
||||
"""
|
||||
from letta.templates.template_helper import render_template_async
|
||||
from letta.templates.template_helper import render_template_in_thread
|
||||
|
||||
# Select the appropriate template based on whether the function is async
|
||||
TEMPLATE_NAME = "sandbox_code_file_async.py.j2" if self.is_async_function else "sandbox_code_file.py.j2"
|
||||
@@ -107,7 +107,7 @@ class AsyncToolSandboxBase(ABC):
|
||||
|
||||
agent_state_pickle = pickle.dumps(agent_state) if self.inject_agent_state else None
|
||||
|
||||
return await render_template_async(
|
||||
return await render_template_in_thread(
|
||||
TEMPLATE_NAME,
|
||||
future_import=future_import,
|
||||
inject_agent_state=self.inject_agent_state,
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import os
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
|
||||
@@ -38,6 +39,13 @@ async def render_template_async(template_name: str, **kwargs):
|
||||
return await template.render_async(**kwargs)
|
||||
|
||||
|
||||
@trace_method
|
||||
async def render_template_in_thread(template_name: str, **kwargs):
|
||||
"""Render the template named `template_name` (looked up in `jinja_env`) with `kwargs` as the render context, running the synchronous `render` call in a worker thread via `asyncio.to_thread`."""
|
||||
# The template lookup runs on the calling thread; only the render call below is offloaded.
template = jinja_env.get_template(template_name)
|
||||
return await asyncio.to_thread(template.render, **kwargs)
|
||||
|
||||
|
||||
@trace_method
|
||||
async def render_string_async(template_string: str, **kwargs):
|
||||
"""Asynchronously render a template from a string"""
|
||||
|
||||
Reference in New Issue
Block a user