feat: convert compile system prompt to async (#3685)

This commit is contained in:
cthomas
2025-07-31 15:49:59 -07:00
committed by GitHub
parent 774c4c1481
commit fb7615be0c
6 changed files with 201 additions and 8 deletions

View File

@@ -17,7 +17,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.services.agent_manager import AgentManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.helpers.agent_manager_helper import compile_system_message_async
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.utils import united_diff
@@ -142,7 +142,7 @@ class BaseAgent(ABC):
if num_archival_memories is None:
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
new_system_message_str = compile_system_message(
new_system_message_str = await compile_system_message_async(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,

View File

@@ -36,7 +36,7 @@ from letta.server.rest_api.utils import (
)
from letta.services.agent_manager import AgentManager
from letta.services.block_manager import BlockManager
from letta.services.helpers.agent_manager_helper import compile_system_message
from letta.services.helpers.agent_manager_helper import compile_system_message_async
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
@@ -145,7 +145,7 @@ class VoiceAgent(BaseAgent):
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor)
memory_edit_timestamp = get_utc_time()
in_context_messages[0].content[0].text = compile_system_message(
in_context_messages[0].content[0].text = await compile_system_message_async(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,

View File

@@ -11,6 +11,7 @@ if TYPE_CHECKING:
from openai.types.beta.function_tool import FunctionTool as OpenAITool
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
from letta.otel.tracing import trace_method
from letta.schemas.block import Block, FileBlock
from letta.schemas.message import Message
@@ -114,6 +115,7 @@ class Memory(BaseModel, validate_assignment=True):
"""Return the current Jinja2 template string."""
return str(self.prompt_template)
@trace_method
def set_prompt_template(self, prompt_template: str):
"""
Set a new Jinja2 template string.
@@ -133,6 +135,7 @@ class Memory(BaseModel, validate_assignment=True):
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
@trace_method
async def set_prompt_template_async(self, prompt_template: str):
"""
Async version of set_prompt_template that doesn't block the event loop.
@@ -152,6 +155,7 @@ class Memory(BaseModel, validate_assignment=True):
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
@trace_method
def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
"""Generate a string representation of the memory in-context using the Jinja2 template"""
try:
@@ -168,6 +172,7 @@ class Memory(BaseModel, validate_assignment=True):
except Exception as e:
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
@trace_method
async def compile_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
"""Async version of compile that doesn't block the event loop"""
try:

View File

@@ -86,8 +86,10 @@ from letta.services.helpers.agent_manager_helper import (
calculate_multi_agent_tools,
check_supports_structured_output,
compile_system_message,
compile_system_message_async,
derive_system_message,
initialize_message_sequence,
initialize_message_sequence_async,
package_initial_message_sequence,
validate_agent_exists_async,
)
@@ -621,7 +623,7 @@ class AgentManager:
# initial message sequence (skip if _init_with_no_messages is True)
if not _init_with_no_messages:
init_messages = self._generate_initial_message_sequence(
init_messages = await self._generate_initial_message_sequence_async(
actor,
agent_state=result,
supplied_initial_message_sequence=agent_create.initial_message_sequence,
@@ -666,6 +668,35 @@ class AgentManager:
return init_messages
@enforce_types
async def _generate_initial_message_sequence_async(
    self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
) -> List[Message]:
    """Build the agent's initial in-context message list (async variant).

    Pre-generates the default boot sequence for the agent, then either
    converts that sequence wholesale into Message objects, or — when the
    caller supplied its own sequence — keeps only the pre-generated system
    message and appends the supplied messages after it.
    """
    pregen = await initialize_message_sequence_async(
        agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
    )
    model = agent_state.llm_config.model

    if supplied_initial_message_sequence is None:
        # No override: convert every pre-generated message dict as-is.
        return [
            PydanticMessage.dict_to_message(agent_id=agent_state.id, model=model, openai_message_dict=raw)
            for raw in pregen
        ]

    # We always need the system prompt up front; everything else in the
    # pregen sequence is discarded in favor of the supplied messages.
    system_message_obj = PydanticMessage.dict_to_message(
        agent_id=agent_state.id,
        model=model,
        openai_message_dict=pregen[0],
    )
    return [system_message_obj] + package_initial_message_sequence(
        agent_state.id, supplied_initial_message_sequence, model, agent_state.timezone, actor
    )
@enforce_types
@trace_method
def append_initial_message_sequence_to_in_context_messages(
@@ -679,7 +710,7 @@ class AgentManager:
async def append_initial_message_sequence_to_in_context_messages_async(
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
) -> PydanticAgentState:
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
init_messages = await self._generate_initial_message_sequence_async(actor, agent_state, initial_message_sequence)
return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
@enforce_types
@@ -1674,7 +1705,7 @@ class AgentManager:
# update memory (TODO: potentially update recall/archival stats separately)
new_system_message_str = compile_system_message(
new_system_message_str = await compile_system_message_async(
system_prompt=agent_state.system,
in_context_memory=agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
@@ -1809,7 +1840,7 @@ class AgentManager:
# Optionally add default initial messages after the system message
if add_default_initial_messages:
init_messages = initialize_message_sequence(
init_messages = await initialize_message_sequence_async(
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
)
# Skip index 0 (system message) since we preserved the original

View File

@@ -248,6 +248,7 @@ def safe_format(template: str, variables: dict) -> str:
return escaped.format_map(PreserveMapping(variables))
@trace_method
def compile_system_message(
system_prompt: str,
in_context_memory: Memory,
@@ -327,6 +328,87 @@ def compile_system_message(
return formatted_prompt
@trace_method
async def compile_system_message_async(
    system_prompt: str,
    in_context_memory: Memory,
    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
    timezone: str,
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
    tool_rules_solver: Optional[ToolRulesSolver] = None,
    sources: Optional[List] = None,
    max_files_open: Optional[int] = None,
) -> str:
    """Prepare the final/full system message that will be fed into the LLM API.

    Async counterpart of ``compile_system_message``: the only awaited step is
    ``in_context_memory.compile_async`` so memory rendering does not block the
    event loop.

    The base system message may be templated, in which case we need to render
    the variables. The following are reserved variables:
      - CORE_MEMORY: the in-context memory of the LLM

    Raises:
        NotImplementedError: if ``user_defined_variables`` is supplied, or
            ``template_format`` is anything other than "f-string".
        ValueError: if the rendered prompt cannot be formatted.
    """
    # Add tool rule constraints if available
    tool_constraint_block = None
    if tool_rules_solver is not None:
        tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()

    if user_defined_variables is not None:
        # TODO eventually support the user defining their own variables to inject
        raise NotImplementedError
    else:
        variables = {}

    # Add the protected memory variable
    # NOTE(review): `variables` was just set to {} above, so this collision
    # check can never fire today — it guards the future user-defined-vars path.
    if IN_CONTEXT_MEMORY_KEYWORD in variables:
        raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
    else:
        # TODO should this all put into the memory.__repr__ function?
        memory_metadata_string = compile_memory_metadata_block(
            memory_edit_timestamp=in_context_memory_last_edit,
            previous_message_count=previous_message_count,
            archival_memory_size=archival_memory_size,
            timezone=timezone,
        )
        # Memory blocks + attached sources, rendered without blocking the loop.
        memory_with_sources = await in_context_memory.compile_async(
            tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
        )
        full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
        # Add to the variables list to inject
        variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string

    if template_format == "f-string":
        memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"

        # Catch the special case where the system prompt is unformatted
        if append_icm_if_missing:
            if memory_variable_string not in system_prompt:
                # In this case, append it to the end to make sure memory is still injected
                # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
                system_prompt += "\n\n" + memory_variable_string

        # render the variables using the built-in templater
        try:
            if user_defined_variables:
                formatted_prompt = safe_format(system_prompt, variables)
            else:
                # Fast path: single placeholder, plain string replacement.
                formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
        except Exception as e:
            raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
    else:
        # TODO support for mustache and jinja2
        raise NotImplementedError(template_format)

    return formatted_prompt
@trace_method
def initialize_message_sequence(
agent_state: AgentState,
memory_edit_timestamp: Optional[datetime] = None,
@@ -396,6 +478,76 @@ def initialize_message_sequence(
return messages
@trace_method
async def initialize_message_sequence_async(
    agent_state: AgentState,
    memory_edit_timestamp: Optional[datetime] = None,
    include_initial_boot_message: bool = True,
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> List[dict]:
    """Async construction of an agent's opening message sequence.

    Compiles the full system message (memory + metadata) without blocking the
    event loop, then arranges the system message, a login event, and — unless
    disabled or the agent is a sleeptime agent — the canned boot messages.
    """
    edit_ts = memory_edit_timestamp if memory_edit_timestamp is not None else get_local_time()

    system_text = await compile_system_message_async(
        system_prompt=agent_state.system,
        in_context_memory=agent_state.memory,
        in_context_memory_last_edit=edit_ts,
        timezone=agent_state.timezone,
        user_defined_variables=None,
        append_icm_if_missing=True,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
        sources=agent_state.sources,
        max_files_open=agent_state.max_files_open,
    )
    system_msg = {"role": "system", "content": system_text}
    # Event letting Letta know the user just logged in.
    login_msg = {"role": "user", "content": get_login_event(agent_state.timezone)}

    if not include_initial_boot_message:
        return [system_msg, login_msg]

    llm_config = agent_state.llm_config
    is_lmstudio = llm_config.provider_name == "lmstudio_openai"
    raw_id = str(uuid.uuid4())
    # Some LMStudio models (e.g. ministral) require the tool call ID to be 9 alphanumeric characters
    call_id = raw_id[:9] if is_lmstudio else raw_id

    if agent_state.agent_type == AgentType.sleeptime_agent:
        boot_msgs = []
    elif llm_config.model is not None and "gpt-3.5" in llm_config.model:
        boot_msgs = get_initial_boot_messages("startup_with_send_message_gpt35", agent_state.timezone, call_id)
    else:
        boot_msgs = get_initial_boot_messages("startup_with_send_message", agent_state.timezone, call_id)

    # Some LMStudio models (e.g. meta-llama-3.1) require the user message before any tool calls
    if is_lmstudio:
        return [system_msg, login_msg] + boot_msgs
    return [system_msg] + boot_msgs + [login_msg]
def package_initial_message_sequence(
agent_id: str, initial_message_sequence: List[MessageCreate], model: str, timezone: str, actor: User
) -> List[Message]:

View File

@@ -2,6 +2,8 @@ import os
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
from letta.otel.tracing import trace_method
TEMPLATE_DIR = os.path.dirname(__file__)
# Synchronous environment (for backward compatibility)
@@ -22,18 +24,21 @@ jinja_async_env = Environment(
)
@trace_method
def render_template(template_name: str, **kwargs):
    """Render *template_name* synchronously (kept for backward compatibility)."""
    return jinja_env.get_template(template_name).render(**kwargs)
@trace_method
async def render_template_async(template_name: str, **kwargs):
    """Render *template_name* via the async Jinja2 environment so the event loop is not blocked."""
    loaded = jinja_async_env.get_template(template_name)
    rendered = await loaded.render_async(**kwargs)
    return rendered
@trace_method
async def render_string_async(template_string: str, **kwargs):
"""Asynchronously render a template from a string"""
template = Template(template_string, enable_async=True)