diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 420f78d4..f26c6047 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -17,7 +17,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate from letta.schemas.usage import LettaUsageStatistics from letta.schemas.user import User from letta.services.agent_manager import AgentManager -from letta.services.helpers.agent_manager_helper import compile_system_message_async +from letta.services.helpers.agent_manager_helper import get_system_message_from_compiled_memory from letta.services.message_manager import MessageManager from letta.services.passage_manager import PassageManager from letta.utils import united_diff @@ -142,16 +142,13 @@ class BaseAgent(ABC): if num_archival_memories is None: num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id) - new_system_message_str = await compile_system_message_async( + new_system_message_str = get_system_message_from_compiled_memory( system_prompt=agent_state.system, - in_context_memory=agent_state.memory, + memory_with_sources=curr_memory_str, in_context_memory_last_edit=memory_edit_timestamp, timezone=agent_state.timezone, previous_message_count=num_messages - len(in_context_messages), archival_memory_size=num_archival_memories, - tool_rules_solver=tool_rules_solver, - sources=agent_state.sources, - max_files_open=agent_state.max_files_open, ) diff = united_diff(curr_system_message_text, new_system_message_str) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 6d4d6404..c3e57c5c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -87,8 +87,8 @@ from letta.services.helpers.agent_manager_helper import ( calculate_multi_agent_tools, check_supports_structured_output, compile_system_message, - compile_system_message_async, derive_system_message, + get_system_message_from_compiled_memory, 
@trace_method
def get_system_message_from_compiled_memory(
    system_prompt: str,
    memory_with_sources: str,
    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
    timezone: str,
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> str:
    """Prepare the final/full system message that will be fed into the LLM API.

    Unlike the ``compile_system_message*`` helpers, this variant receives the
    memory section already rendered as a string (``memory_with_sources``) and
    only has to append the memory metadata block and inject the result into
    the system prompt.

    The base system message may be templated. The only reserved variable is
    the one named by ``IN_CONTEXT_MEMORY_KEYWORD`` (the in-context memory of
    the LLM); its ``{...}`` placeholder is replaced by the full memory string.

    Args:
        system_prompt: Base system prompt, optionally containing the
            in-context memory placeholder.
        memory_with_sources: Pre-compiled memory string to inject.
        in_context_memory_last_edit: Timestamp forwarded to
            ``compile_memory_metadata_block`` as ``memory_edit_timestamp``.
        timezone: Timezone forwarded to ``compile_memory_metadata_block``.
        user_defined_variables: Not supported yet; must be ``None``.
        append_icm_if_missing: When true and the placeholder is absent from
            ``system_prompt``, the placeholder is appended to the end of the
            prompt so the memory is still injected.
        template_format: Only ``"f-string"`` is currently supported.
        previous_message_count: Forwarded to ``compile_memory_metadata_block``.
        archival_memory_size: Forwarded to ``compile_memory_metadata_block``.

    Returns:
        The system prompt with every occurrence of the memory placeholder
        replaced by ``memory_with_sources``, a blank line, and the memory
        metadata block.

    Raises:
        NotImplementedError: If ``user_defined_variables`` is not ``None`` or
            ``template_format`` is anything other than ``"f-string"``.
        ValueError: If substituting the memory into the prompt fails.
    """
    if user_defined_variables is not None:
        # TODO eventually support the user defining their own variables to inject.
        # When that lands, reject a user-supplied IN_CONTEXT_MEMORY_KEYWORD
        # (it is a protected variable) and render through safe_format instead
        # of the plain replace below.
        raise NotImplementedError("user_defined_variables is not supported yet")

    if template_format != "f-string":
        # TODO support for mustache and jinja2
        # Checked up front so no work is done for a format we cannot render.
        raise NotImplementedError(template_format)

    # TODO should this all put into the memory.__repr__ function?
    memory_metadata_string = compile_memory_metadata_block(
        memory_edit_timestamp=in_context_memory_last_edit,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
        timezone=timezone,
    )
    full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string

    memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"

    # Catch the special case where the system prompt is unformatted: append the
    # placeholder to the end to make sure memory is still injected.
    if append_icm_if_missing and memory_variable_string not in system_prompt:
        system_prompt += "\n\n" + memory_variable_string

    # A literal replace (rather than str.format) leaves any other braces in the
    # prompt untouched.
    try:
        return system_prompt.replace(memory_variable_string, full_memory_string)
    except Exception as e:
        # Keep the ValueError contract, but chain the cause for debuggability.
        raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") from e