feat: convert compile system prompt to async (#3685)
This commit is contained in:
@@ -17,7 +17,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message_async
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.utils import united_diff
|
||||
@@ -142,7 +142,7 @@ class BaseAgent(ABC):
|
||||
if num_archival_memories is None:
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
new_system_message_str = compile_system_message(
|
||||
new_system_message_str = await compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
|
||||
@@ -36,7 +36,7 @@ from letta.server.rest_api.utils import (
|
||||
)
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message_async
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
@@ -145,7 +145,7 @@ class VoiceAgent(BaseAgent):
|
||||
|
||||
in_context_messages = await self.message_manager.get_messages_by_ids_async(message_ids=agent_state.message_ids, actor=self.actor)
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
in_context_messages[0].content[0].text = compile_system_message(
|
||||
in_context_messages[0].content[0].text = await compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
|
||||
@@ -11,6 +11,7 @@ if TYPE_CHECKING:
|
||||
from openai.types.beta.function_tool import FunctionTool as OpenAITool
|
||||
|
||||
from letta.constants import CORE_MEMORY_BLOCK_CHAR_LIMIT
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.block import Block, FileBlock
|
||||
from letta.schemas.message import Message
|
||||
|
||||
@@ -114,6 +115,7 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
"""Return the current Jinja2 template string."""
|
||||
return str(self.prompt_template)
|
||||
|
||||
@trace_method
|
||||
def set_prompt_template(self, prompt_template: str):
|
||||
"""
|
||||
Set a new Jinja2 template string.
|
||||
@@ -133,6 +135,7 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
@trace_method
|
||||
async def set_prompt_template_async(self, prompt_template: str):
|
||||
"""
|
||||
Async version of set_prompt_template that doesn't block the event loop.
|
||||
@@ -152,6 +155,7 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
@trace_method
|
||||
def compile(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
|
||||
"""Generate a string representation of the memory in-context using the Jinja2 template"""
|
||||
try:
|
||||
@@ -168,6 +172,7 @@ class Memory(BaseModel, validate_assignment=True):
|
||||
except Exception as e:
|
||||
raise ValueError(f"Prompt template is not compatible with current memory structure: {str(e)}")
|
||||
|
||||
@trace_method
|
||||
async def compile_async(self, tool_usage_rules=None, sources=None, max_files_open=None) -> str:
|
||||
"""Async version of compile that doesn't block the event loop"""
|
||||
try:
|
||||
|
||||
@@ -86,8 +86,10 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
calculate_multi_agent_tools,
|
||||
check_supports_structured_output,
|
||||
compile_system_message,
|
||||
compile_system_message_async,
|
||||
derive_system_message,
|
||||
initialize_message_sequence,
|
||||
initialize_message_sequence_async,
|
||||
package_initial_message_sequence,
|
||||
validate_agent_exists_async,
|
||||
)
|
||||
@@ -621,7 +623,7 @@ class AgentManager:
|
||||
|
||||
# initial message sequence (skip if _init_with_no_messages is True)
|
||||
if not _init_with_no_messages:
|
||||
init_messages = self._generate_initial_message_sequence(
|
||||
init_messages = await self._generate_initial_message_sequence_async(
|
||||
actor,
|
||||
agent_state=result,
|
||||
supplied_initial_message_sequence=agent_create.initial_message_sequence,
|
||||
@@ -666,6 +668,35 @@ class AgentManager:
|
||||
|
||||
return init_messages
|
||||
|
||||
@enforce_types
|
||||
async def _generate_initial_message_sequence_async(
|
||||
self, actor: PydanticUser, agent_state: PydanticAgentState, supplied_initial_message_sequence: Optional[List[MessageCreate]] = None
|
||||
) -> List[Message]:
|
||||
init_messages = await initialize_message_sequence_async(
|
||||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||||
)
|
||||
if supplied_initial_message_sequence is not None:
|
||||
# We always need the system prompt up front
|
||||
system_message_obj = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_state.id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict=init_messages[0],
|
||||
)
|
||||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||||
init_messages = [system_message_obj]
|
||||
init_messages.extend(
|
||||
package_initial_message_sequence(
|
||||
agent_state.id, supplied_initial_message_sequence, agent_state.llm_config.model, agent_state.timezone, actor
|
||||
)
|
||||
)
|
||||
else:
|
||||
init_messages = [
|
||||
PydanticMessage.dict_to_message(agent_id=agent_state.id, model=agent_state.llm_config.model, openai_message_dict=msg)
|
||||
for msg in init_messages
|
||||
]
|
||||
|
||||
return init_messages
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def append_initial_message_sequence_to_in_context_messages(
|
||||
@@ -679,7 +710,7 @@ class AgentManager:
|
||||
async def append_initial_message_sequence_to_in_context_messages_async(
|
||||
self, actor: PydanticUser, agent_state: PydanticAgentState, initial_message_sequence: Optional[List[MessageCreate]] = None
|
||||
) -> PydanticAgentState:
|
||||
init_messages = self._generate_initial_message_sequence(actor, agent_state, initial_message_sequence)
|
||||
init_messages = await self._generate_initial_message_sequence_async(actor, agent_state, initial_message_sequence)
|
||||
return await self.append_to_in_context_messages_async(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
@@ -1674,7 +1705,7 @@ class AgentManager:
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
|
||||
new_system_message_str = compile_system_message(
|
||||
new_system_message_str = await compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
@@ -1809,7 +1840,7 @@ class AgentManager:
|
||||
|
||||
# Optionally add default initial messages after the system message
|
||||
if add_default_initial_messages:
|
||||
init_messages = initialize_message_sequence(
|
||||
init_messages = await initialize_message_sequence_async(
|
||||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||||
)
|
||||
# Skip index 0 (system message) since we preserved the original
|
||||
|
||||
@@ -248,6 +248,7 @@ def safe_format(template: str, variables: dict) -> str:
|
||||
return escaped.format_map(PreserveMapping(variables))
|
||||
|
||||
|
||||
@trace_method
|
||||
def compile_system_message(
|
||||
system_prompt: str,
|
||||
in_context_memory: Memory,
|
||||
@@ -327,6 +328,87 @@ def compile_system_message(
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
@trace_method
|
||||
async def compile_system_message_async(
|
||||
system_prompt: str,
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime, # TODO move this inside of BaseMemory?
|
||||
timezone: str,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
tool_rules_solver: Optional[ToolRulesSolver] = None,
|
||||
sources: Optional[List] = None,
|
||||
max_files_open: Optional[int] = None,
|
||||
) -> str:
|
||||
"""Prepare the final/full system message that will be fed into the LLM API
|
||||
|
||||
The base system message may be templated, in which case we need to render the variables.
|
||||
|
||||
The following are reserved variables:
|
||||
- CORE_MEMORY: the in-context memory of the LLM
|
||||
"""
|
||||
|
||||
# Add tool rule constraints if available
|
||||
tool_constraint_block = None
|
||||
if tool_rules_solver is not None:
|
||||
tool_constraint_block = tool_rules_solver.compile_tool_rule_prompts()
|
||||
|
||||
if user_defined_variables is not None:
|
||||
# TODO eventually support the user defining their own variables to inject
|
||||
raise NotImplementedError
|
||||
else:
|
||||
variables = {}
|
||||
|
||||
# Add the protected memory variable
|
||||
if IN_CONTEXT_MEMORY_KEYWORD in variables:
|
||||
raise ValueError(f"Found protected variable '{IN_CONTEXT_MEMORY_KEYWORD}' in user-defined vars: {str(user_defined_variables)}")
|
||||
else:
|
||||
# TODO should this all put into the memory.__repr__ function?
|
||||
memory_metadata_string = compile_memory_metadata_block(
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
timezone=timezone,
|
||||
)
|
||||
|
||||
memory_with_sources = await in_context_memory.compile_async(
|
||||
tool_usage_rules=tool_constraint_block, sources=sources, max_files_open=max_files_open
|
||||
)
|
||||
full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string
|
||||
|
||||
# Add to the variables list to inject
|
||||
variables[IN_CONTEXT_MEMORY_KEYWORD] = full_memory_string
|
||||
|
||||
if template_format == "f-string":
|
||||
memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"
|
||||
|
||||
# Catch the special case where the system prompt is unformatted
|
||||
if append_icm_if_missing:
|
||||
if memory_variable_string not in system_prompt:
|
||||
# In this case, append it to the end to make sure memory is still injected
|
||||
# warnings.warn(f"{IN_CONTEXT_MEMORY_KEYWORD} variable was missing from system prompt, appending instead")
|
||||
system_prompt += "\n\n" + memory_variable_string
|
||||
|
||||
# render the variables using the built-in templater
|
||||
try:
|
||||
if user_defined_variables:
|
||||
formatted_prompt = safe_format(system_prompt, variables)
|
||||
else:
|
||||
formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
|
||||
except Exception as e:
|
||||
raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}")
|
||||
|
||||
else:
|
||||
# TODO support for mustache and jinja2
|
||||
raise NotImplementedError(template_format)
|
||||
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
@trace_method
|
||||
def initialize_message_sequence(
|
||||
agent_state: AgentState,
|
||||
memory_edit_timestamp: Optional[datetime] = None,
|
||||
@@ -396,6 +478,76 @@ def initialize_message_sequence(
|
||||
return messages
|
||||
|
||||
|
||||
@trace_method
|
||||
async def initialize_message_sequence_async(
|
||||
agent_state: AgentState,
|
||||
memory_edit_timestamp: Optional[datetime] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
previous_message_count: int = 0,
|
||||
archival_memory_size: int = 0,
|
||||
) -> List[dict]:
|
||||
if memory_edit_timestamp is None:
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
full_system_message = await compile_system_message_async(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
timezone=agent_state.timezone,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
previous_message_count=previous_message_count,
|
||||
archival_memory_size=archival_memory_size,
|
||||
sources=agent_state.sources,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
)
|
||||
first_user_message = get_login_event(agent_state.timezone) # event letting Letta know the user just logged in
|
||||
|
||||
if include_initial_boot_message:
|
||||
llm_config = agent_state.llm_config
|
||||
uuid_str = str(uuid.uuid4())
|
||||
|
||||
# Some LMStudio models (e.g. ministral) require the tool call ID to be 9 alphanumeric characters
|
||||
tool_call_id = uuid_str[:9] if llm_config.provider_name == "lmstudio_openai" else uuid_str
|
||||
|
||||
if agent_state.agent_type == AgentType.sleeptime_agent:
|
||||
initial_boot_messages = []
|
||||
elif llm_config.model is not None and "gpt-3.5" in llm_config.model:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message_gpt35", agent_state.timezone, tool_call_id)
|
||||
else:
|
||||
initial_boot_messages = get_initial_boot_messages("startup_with_send_message", agent_state.timezone, tool_call_id)
|
||||
|
||||
# Some LMStudio models (e.g. meta-llama-3.1) require the user message before any tool calls
|
||||
if llm_config.provider_name == "lmstudio_openai":
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
)
|
||||
else:
|
||||
messages = (
|
||||
[
|
||||
{"role": "system", "content": full_system_message},
|
||||
]
|
||||
+ initial_boot_messages
|
||||
+ [
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
)
|
||||
|
||||
else:
|
||||
messages = [
|
||||
{"role": "system", "content": full_system_message},
|
||||
{"role": "user", "content": first_user_message},
|
||||
]
|
||||
|
||||
return messages
|
||||
|
||||
|
||||
def package_initial_message_sequence(
|
||||
agent_id: str, initial_message_sequence: List[MessageCreate], model: str, timezone: str, actor: User
|
||||
) -> List[Message]:
|
||||
|
||||
@@ -2,6 +2,8 @@ import os
|
||||
|
||||
from jinja2 import Environment, FileSystemLoader, StrictUndefined, Template
|
||||
|
||||
from letta.otel.tracing import trace_method
|
||||
|
||||
TEMPLATE_DIR = os.path.dirname(__file__)
|
||||
|
||||
# Synchronous environment (for backward compatibility)
|
||||
@@ -22,18 +24,21 @@ jinja_async_env = Environment(
|
||||
)
|
||||
|
||||
|
||||
@trace_method
|
||||
def render_template(template_name: str, **kwargs):
|
||||
"""Synchronous template rendering function (kept for backward compatibility)"""
|
||||
template = jinja_env.get_template(template_name)
|
||||
return template.render(**kwargs)
|
||||
|
||||
|
||||
@trace_method
|
||||
async def render_template_async(template_name: str, **kwargs):
|
||||
"""Asynchronous template rendering function that doesn't block the event loop"""
|
||||
template = jinja_async_env.get_template(template_name)
|
||||
return await template.render_async(**kwargs)
|
||||
|
||||
|
||||
@trace_method
|
||||
async def render_string_async(template_string: str, **kwargs):
|
||||
"""Asynchronously render a template from a string"""
|
||||
template = Template(template_string, enable_async=True)
|
||||
|
||||
Reference in New Issue
Block a user