From 8919f06b673033bd4fb06ef6aea6fa64011ffd47 Mon Sep 17 00:00:00 2001 From: cthomas Date: Thu, 31 Jul 2025 15:49:59 -0700 Subject: [PATCH] feat: convert compile system prompt to async (#3685) --- letta/agents/base_agent.py | 4 +- letta/agents/voice_agent.py | 4 +- letta/schemas/memory.py | 5 + letta/services/agent_manager.py | 39 ++++- .../services/helpers/agent_manager_helper.py | 152 ++++++++++++++++++ letta/templates/template_helper.py | 5 + 6 files changed, 201 insertions(+), 8 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 325090ec..420f78d4 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -17,7 +17,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.agent_manager import AgentManager -from letta.services.helpers.agent_manager_helper import compile_system_message +from letta.services.helpers.agent_manager_helper import compile_system_message_async from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.utils import united_diff @@ -142,7 +142,7 @@ class BaseAgent(ABC): if num_archival_memories is None: num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id) - new_system_message_str = compile_system_message( + new_system_message_str = await compile_system_message_async( system_prompt=agent_state.system, in_context_memory=agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, diff --git a/letta/agents/voice_agent.py b/letta/agents/voice_agent.py index e28fb17f..a75e7e8d 100644 --- a/letta/agents/voice_agent.py +++ b/letta/agents/voice_agent.py @@ -36,7 +36,7 @@ from letta.server.rest_api.utils import ( ) from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager -from letta.services.helpers.agent_manager_helper import compile_system_message +from letta.services.helpers.agent_manager_helper import compile_system_message_async from letta.services.job_manager import JobManager from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager @@ -145,7 +145,7 @@ class VoiceAgent(BaseAgent): in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor) memory_edit_timestamp = get_utc_time() - in_context_messages[0].content[0].text = compile_system_message( + in_context_messages[0].content[0].text = await compile_system_message_async( system_prompt=agent_state.system, in_context_memory=agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, diff --git a/letta/schemas/memory.py b/letta/schemas/memory.py index e59a6437..3952a647 100644 --- a/letta/schemas/memory.py +++ b/letta/schemas/memory.py @@ -11,6 +11,7 @@ if TYPE_CHECKING: from openai.types.beta.function_tool import FunctionTool as OpenAITool from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT +from letta.otel.tracing import trace_method from letta.schemas.block import Block, FileBlock from letta.schemas.message import Message @@ -114,6 +115,7 @@ class Memory(BaseModel, validate_assignment=True): """Return the current Jinja2 template string.""" return str(self.prompt_template) + @trace_method def set_prompt_template(self, prompt_template: str): """ Set a new Jinja2 template string. @@ -133,6 +135,7 @@ class Memory(BaseModel, validate_assignment=True): except Exception as e: raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + @trace_method async def set_prompt_template_async(self, prompt_template: str): """ Async version of set_prompt_template that doesn't block the event loop. @@ -152,6 +155,7 @@ class Memory(BaseModel, validate_assignment=True): except Exception as e: raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + @trace_method def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str: """Generate a string representation of the memory in-context using the Jinja2 template""" try: @@ -168,6 +172,7 @@ class Memory(BaseModel, validate_assignment=True): except Exception as e: raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}") + @trace_method async def compile_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str: """Async version of compile that doesn't block the event loop""" try: diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 7220af38..5696a11c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -86,8 +86,10 @@ from letta.services.helpers.agent_manager_helper import ( calculate_multi_agent_tools, check_supports_structured_output, compile_system_message, + compile_system_message_async, derive_system_message, initialize_message_sequence, + initialize_message_sequence_async, package_initial_message_sequence, validate_agent_exists_async, ) @@ -621,7 +623,7 @@ class AgentManager: # initial message sequence (skip if _init_with_no_messages is True) if not _init_with_no_messages: - init_messages = self._generate_initial_message_sequence( + init_messages = await self._generate_initial_message_sequence_async( actor, agent_state=result, supplied_initial_message_sequence=agent_create.initial_message_sequence, @@ -666,6 +668,35 @@ class AgentManager: return init_messages + @enforce_types + async def _generate_initial_message_sequence_async( + self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None + ) -> List[Message]: + init_messages = await initialize_message_sequence_async( + agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True + ) + if supplied_initial_message_sequence is not None: + # We always need the system prompt up front + system_message_obj = PydanticMessage.dict_to_message( + agent_id=agent_state.id, + model=agent_state.llm_config.model, + openai_message_dict=init_messages[0], + ) + # Don't use anything else in the pregen sequence, instead use the provided sequence + init_messages = [system_message_obj] + init_messages.extend( + package_initial_message_sequence( + agent_state.id, supplied_initial_message_sequence, agent_state.llm_config.model, agent_state.timezone, actor + ) + ) + else: + init_messages = [ + PydanticMessage.dict_to_message(agent_id=agent_state.id, model=agent_state.llm_config.model, openai_message_dict=msg) + for msg in init_messages + ] + + return init_messages + @enforce_types @trace_method def append_initial_message_sequence_to_in_context_messages( @@ -679,7 +710,7 @@ class AgentManager: async def append_initial_message_sequence_to_in_context_messages_async( self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None ) -> PydanticAgentState: - init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence) + init_messages = await self._generate_initial_message_sequence_async(actor, agent_state, initial_message_sequence) return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor) @enforce_types @@ -1674,7 +1705,7 @@ class AgentManager: # update memory (TODO: potentially update recall/archival stats separately) - new_system_message_str = compile_system_message( + new_system_message_str = await compile_system_message_async( system_prompt=agent_state.system, in_context_memory=agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, @@ -1809,7 +1840,7 @@ class AgentManager: # Optionally add default initial messages after the system message if add_default_initial_messages: - init_messages = initialize_message_sequence( + init_messages = await initialize_message_sequence_async( agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True ) # Skip index 0 (system message) since we preserved the original diff --git a/letta/services/helpers/agent_manager_helper.py b/letta/services/helpers/agent_manager_helper.py index ccebf35b..62fd1de1 100644 --- a/letta/services/helpers/agent_manager_helper.py +++ b/letta/services/helpers/agent_manager_helper.py @@ -248,6 +248,7 @@ def safe_format(template: str, variables: dict) -> str: return escaped.format_map(PreserveMapping(variables)) +@trace_method def compile_system_message( system_prompt: str, in_context_memory: Memory, @@ -327,6 +328,87 @@ def compile_system_message( return formatted_prompt +@trace_method +async def compile_system_message_async( + system_prompt: str, + in_context_memory: Memory, + in_context_memory_last_edit: datetime, # TODO move this inside of BaseMemory? + timezone: str, + user_defined_variables: Optional[dict] = None, + append_icm_if_missing: bool = True, + template_format: Literal["f-string", "mustache", "jinja2"] = "f-string", + previous_message_count: int = 0, + archival_memory_size: int = 0, + tool_rules_solver: Optional[ToolRulesSolver] = None, + sources: Optional[List] = None, + max_files_open: Optional[int] = None, +) -> str: + """Prepare the final/full system message that will be fed into the LLM API + + The base system message may be templated, in which case we need to render the variables. + + The following are reserved variables: + - CORE_MEMORY: the in-context memory of the LLM + """ + + # Add tool rule constraints if available + tool_constraint_block = None + if tool_rules_solver is not None: + tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts() + + if user_defined_variables is not None: + # TODO eventually support the user defining their own variables to inject + raise NotImplementedError + else: + variables = {} + + # Add the protected memory variable + if IN_CONTEXT_MEMORY_KEYWORD in variables: + raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}") + else: + # TODO should this all put into the memory.__repr__ function? + memory_metadata_string = compile_memory_metadata_block( + memory_edit_timestamp=in_context_memory_last_edit, + previous_message_count=previous_message_count, + archival_memory_size=archival_memory_size, + timezone=timezone, + ) + + memory_with_sources = await in_context_memory.compile_async( + tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open + ) + full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string + + # Add to the variables list to inject + variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string + + if template_format == "f-string": + memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}" + + # Catch the special case where the system prompt is unformatted + if append_icm_if_missing: + if memory_variable_string not in system_prompt: + # In this case, append it to the end to make sure memory is still injected + # warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead") + system_prompt += "\n\n" + memory_variable_string + + # render the variables using the built-in templater + try: + if user_defined_variables: + formatted_prompt = safe_format(system_prompt, variables) + else: + formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string) + except Exception as e: + raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") + + else: + # TODO support for mustache and jinja2 + raise NotImplementedError(template_format) + + return formatted_prompt + + +@trace_method def initialize_message_sequence( agent_state: AgentState, memory_edit_timestamp: Optional[datetime] = None, @@ -396,6 +478,76 @@ def initialize_message_sequence( return messages +@trace_method +async def initialize_message_sequence_async( + agent_state: AgentState, + memory_edit_timestamp: Optional[datetime] = None, + include_initial_boot_message: bool = True, + previous_message_count: int = 0, + archival_memory_size: int = 0, +) -> List[dict]: + if memory_edit_timestamp is None: + memory_edit_timestamp = get_local_time() + + full_system_message = await compile_system_message_async( + system_prompt=agent_state.system, + in_context_memory=agent_state.memory, + in_context_memory_last_edit=memory_edit_timestamp, + timezone=agent_state.timezone, + user_defined_variables=None, + append_icm_if_missing=True, + previous_message_count=previous_message_count, + archival_memory_size=archival_memory_size, + sources=agent_state.sources, + max_files_open=agent_state.max_files_open, + ) + first_user_message = get_login_event(agent_state.timezone) # event letting Letta know the user just logged in + + if include_initial_boot_message: + llm_config = agent_state.llm_config + uuid_str = str(uuid.uuid4()) + + # Some LMStudio models (e.g. ministral) require the tool call ID to be 9 alphanumeric characters + tool_call_id = uuid_str[:9] if llm_config.provider_name == "lmstudio_openai" else uuid_str + + if agent_state.agent_type == AgentType.sleeptime_agent: + initial_boot_messages = [] + elif llm_config.model is not None and "gpt-3.5" in llm_config.model: + initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35", agent_state.timezone, tool_call_id) + else: + initial_boot_messages = get_initial_boot_messages("startup_with_send_message", agent_state.timezone, tool_call_id) + + # Some LMStudio models (e.g. meta-llama-3.1) require the user message before any tool calls + if llm_config.provider_name == "lmstudio_openai": + messages = ( + [ + {"role": "system", "content": full_system_message}, + ] + + [ + {"role": "user", "content": first_user_message}, + ] + + initial_boot_messages + ) + else: + messages = ( + [ + {"role": "system", "content": full_system_message}, + ] + + initial_boot_messages + + [ + {"role": "user", "content": first_user_message}, + ] + ) + + else: + messages = [ + {"role": "system", "content": full_system_message}, + {"role": "user", "content": first_user_message}, + ] + + return messages + + def package_initial_message_sequence( agent_id: str, initial_message_sequence: List[MessageCreate], model: str, timezone: str, actor: User ) -> List[Message]: diff --git a/letta/templates/template_helper.py b/letta/templates/template_helper.py index 428e2bd2..54de2b5c 100644 --- a/letta/templates/template_helper.py +++ b/letta/templates/template_helper.py @@ -2,6 +2,8 @@ import os from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template +from letta.otel.tracing import trace_method + TEMPLATE_DIR = os.path.dirname(__file__) # Synchronous environment (for backward compatibility) @@ -22,18 +24,21 @@ jinja_async_env = Environment( ) +@trace_method def render_template(template_name: str, **kwargs): """Synchronous template rendering function (kept for backward compatibility)""" template = jinja_env.get_template(template_name) return template.render(**kwargs) +@trace_method async def render_template_async(template_name: str, **kwargs): """Asynchronous template rendering function that doesn't block the event loop""" template = jinja_async_env.get_template(template_name) return await template.render_async(**kwargs) +@trace_method async def render_string_async(template_string: str, **kwargs): """Asynchronously render a template from a string""" template = Template(template_string, enable_async=True)