feat: remove redundant memory compilation in agent step (#3785)
This commit is contained in:
@@ -17,7 +17,7 @@ from letta.schemas.message import Message, MessageCreate, MessageUpdate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.services.agent_manager import AgentManager
|
||||
from letta.services.helpers.agent_manager_helper import compile_system_message_async
|
||||
from letta.services.helpers.agent_manager_helper import get_system_message_from_compiled_memory
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.utils import united_diff
|
||||
@@ -142,16 +142,13 @@ class BaseAgent(ABC):
|
||||
if num_archival_memories is None:
|
||||
num_archival_memories = await self.passage_manager.agent_passage_size_async(actor=self.actor, agent_id=agent_state.id)
|
||||
|
||||
new_system_message_str = await compile_system_message_async(
|
||||
new_system_message_str = get_system_message_from_compiled_memory(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
memory_with_sources=curr_memory_str,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
timezone=agent_state.timezone,
|
||||
previous_message_count=num_messages - len(in_context_messages),
|
||||
archival_memory_size=num_archival_memories,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
sources=agent_state.sources,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_text, new_system_message_str)
|
||||
|
||||
@@ -87,8 +87,8 @@ from letta.services.helpers.agent_manager_helper import (
|
||||
calculate_multi_agent_tools,
|
||||
check_supports_structured_output,
|
||||
compile_system_message,
|
||||
compile_system_message_async,
|
||||
derive_system_message,
|
||||
get_system_message_from_compiled_memory,
|
||||
initialize_message_sequence,
|
||||
initialize_message_sequence_async,
|
||||
package_initial_message_sequence,
|
||||
@@ -1750,16 +1750,13 @@ class AgentManager:
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
|
||||
new_system_message_str = await compile_system_message_async(
|
||||
new_system_message_str = get_system_message_from_compiled_memory(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
memory_with_sources=curr_memory_str,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
timezone=agent_state.timezone,
|
||||
previous_message_count=num_messages - len(agent_state.message_ids),
|
||||
archival_memory_size=num_archival_memories,
|
||||
tool_rules_solver=tool_rules_solver,
|
||||
sources=agent_state.sources,
|
||||
max_files_open=agent_state.max_files_open,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
|
||||
@@ -329,6 +329,74 @@ def compile_system_message(
|
||||
return formatted_prompt
|
||||
|
||||
|
||||
@trace_method
def get_system_message_from_compiled_memory(
    system_prompt: str,
    memory_with_sources: str,
    in_context_memory_last_edit: datetime,  # TODO move this inside of BaseMemory?
    timezone: str,
    user_defined_variables: Optional[dict] = None,
    append_icm_if_missing: bool = True,
    template_format: Literal["f-string", "mustache", "jinja2"] = "f-string",
    previous_message_count: int = 0,
    archival_memory_size: int = 0,
) -> str:
    """Prepare the final/full system message from an already-compiled memory string.

    The caller supplies the rendered memory text (``memory_with_sources``); this
    function only appends the memory metadata block to it and substitutes the
    result for the in-context-memory placeholder in ``system_prompt``.

    The placeholder is ``"{" + IN_CONTEXT_MEMORY_KEYWORD + "}"``. That keyword is
    reserved (the original documentation names it CORE_MEMORY).

    Args:
        system_prompt: Base system prompt, optionally containing the placeholder.
        memory_with_sources: Pre-compiled memory text to inject.
        in_context_memory_last_edit: Forwarded to ``compile_memory_metadata_block``
            as ``memory_edit_timestamp``.
        timezone: Forwarded to ``compile_memory_metadata_block``.
        user_defined_variables: Must be ``None``; custom variables are not
            supported yet.
        append_icm_if_missing: When true and the placeholder is absent from
            ``system_prompt``, the placeholder is appended (after a blank line)
            so the memory is still injected.
        template_format: Only ``"f-string"`` is implemented.
        previous_message_count: Forwarded to ``compile_memory_metadata_block``.
        archival_memory_size: Forwarded to ``compile_memory_metadata_block``.

    Returns:
        The system prompt with every occurrence of the placeholder replaced by
        the memory text followed by the metadata block.

    Raises:
        NotImplementedError: If ``user_defined_variables`` is not ``None`` or
            ``template_format`` is anything other than ``"f-string"``.
        ValueError: If substituting the memory into ``system_prompt`` fails.
    """
    if user_defined_variables is not None:
        # TODO eventually support the user defining their own variables to inject.
        # When that lands: reject IN_CONTEXT_MEMORY_KEYWORD as a user-supplied key
        # and render through safe_format(system_prompt, variables) instead of the
        # plain replace below.
        raise NotImplementedError

    # Fail fast on unsupported formats, before any memory metadata is compiled.
    if template_format != "f-string":
        # TODO support for mustache and jinja2
        raise NotImplementedError(template_format)

    # TODO should this all put into the memory.__repr__ function?
    memory_metadata_string = compile_memory_metadata_block(
        memory_edit_timestamp=in_context_memory_last_edit,
        previous_message_count=previous_message_count,
        archival_memory_size=archival_memory_size,
        timezone=timezone,
    )
    full_memory_string = memory_with_sources + "\n\n" + memory_metadata_string

    memory_variable_string = "{" + IN_CONTEXT_MEMORY_KEYWORD + "}"

    # Catch the special case where the system prompt is unformatted: append the
    # placeholder to the end to make sure memory is still injected.
    if append_icm_if_missing and memory_variable_string not in system_prompt:
        system_prompt += "\n\n" + memory_variable_string

    # A literal replace (rather than str.format) leaves any other braces in the
    # prompt untouched.
    try:
        formatted_prompt = system_prompt.replace(memory_variable_string, full_memory_string)
    except Exception as e:
        # Chain the cause so the original traceback is preserved for debugging.
        raise ValueError(f"Failed to format system prompt - {str(e)}. System prompt value:\n{system_prompt}") from e

    return formatted_prompt
|
||||
|
||||
|
||||
@trace_method
|
||||
async def compile_system_message_async(
|
||||
system_prompt: str,
|
||||
|
||||
Reference in New Issue
Block a user