fix: Remove in-memory _messages field on Agent (#2295)
This commit is contained in:
@@ -20,6 +20,7 @@ from letta.schemas.agent import AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message as PydanticMessage
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source as PydanticSource
|
||||
from letta.schemas.tool_rule import ToolRule as PydanticToolRule
|
||||
@@ -28,12 +29,17 @@ from letta.services.block_manager import BlockManager
|
||||
from letta.services.helpers.agent_manager_helper import (
|
||||
_process_relationship,
|
||||
_process_tags,
|
||||
check_supports_structured_output,
|
||||
compile_system_message,
|
||||
derive_system_message,
|
||||
initialize_message_sequence,
|
||||
package_initial_message_sequence,
|
||||
)
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_manager import ToolManager
|
||||
from letta.settings import settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.utils import enforce_types, get_utc_time, united_diff
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -49,6 +55,7 @@ class AgentManager:
|
||||
self.block_manager = BlockManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.source_manager = SourceManager()
|
||||
self.message_manager = MessageManager()
|
||||
|
||||
# ======================================================================================================================
|
||||
# Basic CRUD operations
|
||||
@@ -64,6 +71,10 @@ class AgentManager:
|
||||
if not agent_create.llm_config or not agent_create.embedding_config:
|
||||
raise ValueError("llm_config and embedding_config are required")
|
||||
|
||||
# Check tool rules are valid
|
||||
if agent_create.tool_rules:
|
||||
check_supports_structured_output(model=agent_create.llm_config.model, tool_rules=agent_create.tool_rules)
|
||||
|
||||
# create blocks (note: cannot be linked into the agent_id is created)
|
||||
block_ids = list(agent_create.block_ids or []) # Create a local copy to avoid modifying the original
|
||||
for create_block in agent_create.memory_blocks:
|
||||
@@ -88,7 +99,8 @@ class AgentManager:
|
||||
# Remove duplicates
|
||||
tool_ids = list(set(tool_ids))
|
||||
|
||||
return self._create_agent(
|
||||
# Create the agent
|
||||
agent_state = self._create_agent(
|
||||
name=agent_create.name,
|
||||
system=system,
|
||||
agent_type=agent_create.agent_type,
|
||||
@@ -104,6 +116,35 @@ class AgentManager:
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# TODO: See if we can merge this into the above SQL create call for performance reasons
|
||||
# Generate a sequence of initial messages to put in the buffer
|
||||
init_messages = initialize_message_sequence(
|
||||
agent_state=agent_state, memory_edit_timestamp=get_utc_time(), include_initial_boot_message=True
|
||||
)
|
||||
|
||||
if agent_create.initial_message_sequence is not None:
|
||||
# We always need the system prompt up front
|
||||
system_message_obj = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_state.id,
|
||||
user_id=agent_state.created_by_id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict=init_messages[0],
|
||||
)
|
||||
# Don't use anything else in the pregen sequence, instead use the provided sequence
|
||||
init_messages = [system_message_obj]
|
||||
init_messages.extend(
|
||||
package_initial_message_sequence(agent_state.id, agent_create.initial_message_sequence, agent_state.llm_config.model, actor)
|
||||
)
|
||||
else:
|
||||
init_messages = [
|
||||
PydanticMessage.dict_to_message(
|
||||
agent_id=agent_state.id, user_id=agent_state.created_by_id, model=agent_state.llm_config.model, openai_message_dict=msg
|
||||
)
|
||||
for msg in init_messages
|
||||
]
|
||||
|
||||
return self.append_to_in_context_messages(init_messages, agent_id=agent_state.id, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def _create_agent(
|
||||
self,
|
||||
@@ -149,6 +190,16 @@ class AgentManager:
|
||||
|
||||
@enforce_types
|
||||
def update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
agent_state = self._update_agent(agent_id=agent_id, agent_update=agent_update, actor=actor)
|
||||
|
||||
# Rebuild the system prompt if it's different
|
||||
if agent_update.system and agent_update.system != agent_state.system:
|
||||
agent_state = self.rebuild_system_prompt(agent_id=agent_state.id, actor=actor, force=True, update_timestamp=False)
|
||||
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
def _update_agent(self, agent_id: str, agent_update: UpdateAgent, actor: PydanticUser) -> PydanticAgentState:
|
||||
"""
|
||||
Update an existing agent.
|
||||
|
||||
@@ -247,6 +298,105 @@ class AgentManager:
|
||||
agent.hard_delete(session)
|
||||
return agent_state
|
||||
|
||||
# ======================================================================================================================
|
||||
# In Context Messages Management
|
||||
# ======================================================================================================================
|
||||
# TODO: There are several assumptions here that are not explicitly checked
|
||||
# TODO: 1) These message ids are valid
|
||||
# TODO: 2) These messages are ordered from oldest to newest
|
||||
# TODO: This can be fixed by having an actual relationship in the ORM for message_ids
|
||||
# TODO: This can also be made more efficient, instead of getting, setting, we can do it all in one db session for one query.
|
||||
@enforce_types
|
||||
def get_in_context_messages(self, agent_id: str, actor: PydanticUser) -> List[PydanticMessage]:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return self.message_manager.get_messages_by_ids(message_ids=message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def get_system_message(self, agent_id: str, actor: PydanticUser) -> PydanticMessage:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
return self.message_manager.get_message_by_id(message_id=message_ids[0], actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def rebuild_system_prompt(self, agent_id: str, actor: PydanticUser, force=False, update_timestamp=True) -> PydanticAgentState:
|
||||
"""Rebuilds the system message with the latest memory object and any shared memory block updates
|
||||
|
||||
Updates to core memory blocks should trigger a "rebuild", which itself will create a new message object
|
||||
|
||||
Updates to the memory header should *not* trigger a rebuild, since that will simply flood recall storage with excess messages
|
||||
"""
|
||||
agent_state = self.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
|
||||
curr_system_message = self.get_system_message(
|
||||
agent_id=agent_id, actor=actor
|
||||
) # this is the system + memory bank, not just the system prompt
|
||||
curr_system_message_openai = curr_system_message.to_openai_dict()
|
||||
|
||||
# note: we only update the system prompt if the core memory is changed
|
||||
# this means that the archival/recall memory statistics may be someout out of date
|
||||
curr_memory_str = agent_state.memory.compile()
|
||||
if curr_memory_str in curr_system_message_openai["content"] and not force:
|
||||
# NOTE: could this cause issues if a block is removed? (substring match would still work)
|
||||
logger.info(
|
||||
f"Memory hasn't changed for agent id={agent_id} and actor=({actor.id}, {actor.name}), skipping system prompt rebuild"
|
||||
)
|
||||
return agent_state
|
||||
|
||||
# If the memory didn't update, we probably don't want to update the timestamp inside
|
||||
# For example, if we're doing a system prompt swap, this should probably be False
|
||||
if update_timestamp:
|
||||
memory_edit_timestamp = get_utc_time()
|
||||
else:
|
||||
# NOTE: a bit of a hack - we pull the timestamp from the message created_by
|
||||
memory_edit_timestamp = curr_system_message.created_at
|
||||
|
||||
# update memory (TODO: potentially update recall/archival stats separately)
|
||||
new_system_message_str = compile_system_message(
|
||||
system_prompt=agent_state.system,
|
||||
in_context_memory=agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
)
|
||||
|
||||
diff = united_diff(curr_system_message_openai["content"], new_system_message_str)
|
||||
if len(diff) > 0: # there was a diff
|
||||
logger.info(f"Rebuilding system with new memory...\nDiff:\n{diff}")
|
||||
|
||||
# Swap the system message out (only if there is a diff)
|
||||
message = PydanticMessage.dict_to_message(
|
||||
agent_id=agent_id,
|
||||
user_id=actor.id,
|
||||
model=agent_state.llm_config.model,
|
||||
openai_message_dict={"role": "system", "content": new_system_message_str},
|
||||
)
|
||||
message = self.message_manager.create_message(message, actor=actor)
|
||||
message_ids = [message.id] + agent_state.message_ids[1:] # swap index 0 (system)
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
else:
|
||||
return agent_state
|
||||
|
||||
@enforce_types
|
||||
def set_in_context_messages(self, agent_id: str, message_ids: List[str], actor: PydanticUser) -> PydanticAgentState:
|
||||
return self.update_agent(agent_id=agent_id, agent_update=UpdateAgent(message_ids=message_ids), actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def trim_older_in_context_messages(self, num: int, agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = [message_ids[0]] + message_ids[num:] # 0 is system message
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=new_messages, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def prepend_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids
|
||||
new_messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = [message_ids[0]] + [m.id for m in new_messages] + message_ids[1:]
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def append_to_in_context_messages(self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser) -> PydanticAgentState:
|
||||
messages = self.message_manager.create_many_messages(messages, actor=actor)
|
||||
message_ids = self.get_agent_by_id(agent_id=agent_id, actor=actor).message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return self.set_in_context_messages(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
# ======================================================================================================================
|
||||
# Source Management
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user