diff --git a/letta/agents/low_latency_agent.py b/letta/agents/low_latency_agent.py index 5d43b1e7..8fe3b7c0 100644 --- a/letta/agents/low_latency_agent.py +++ b/letta/agents/low_latency_agent.py @@ -40,6 +40,7 @@ from letta.services.agent_manager import AgentManager from letta.services.block_manager import BlockManager from letta.services.helpers.agent_manager_helper import compile_system_message from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager from letta.services.summarizer.enums import SummarizationMode from letta.services.summarizer.summarizer import Summarizer from letta.utils import united_diff @@ -75,6 +76,7 @@ class LowLatencyAgent(BaseAgent): # TODO: Make this more general, factorable # Summarizer settings self.block_manager = block_manager + self.passage_manager = PassageManager() # TODO: pass this in # TODO: This is not guaranteed to exist! self.summary_block_label = "human" self.summarizer = Summarizer( @@ -246,10 +248,16 @@ class LowLatencyAgent(BaseAgent): return in_context_messages memory_edit_timestamp = get_utc_time() + + num_messages = self.message_manager.size(actor=actor, agent_id=agent_id) + num_archival_memories = self.passage_manager.size(actor=actor, agent_id=agent_id) + new_system_message_str = compile_system_message( system_prompt=agent_state.system, in_context_memory=agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, ) diff = united_diff(curr_system_message_text, new_system_message_str) diff --git a/letta/functions/function_sets/base.py b/letta/functions/function_sets/base.py index ef32e65a..0ec88c96 100644 --- a/letta/functions/function_sets/base.py +++ b/letta/functions/function_sets/base.py @@ -77,6 +77,7 @@ def archival_memory_insert(self: "Agent", content: str) -> Optional[str]: text=content, actor=self.user, ) + self.agent_manager.rebuild_system_prompt(agent_id=self.agent_state.id, actor=self.user, force=True) return None diff --git a/letta/server/server.py b/letta/server/server.py index 0a7de95f..af6adbfb 100644 --- a/letta/server/server.py +++ b/letta/server/server.py @@ -850,6 +850,9 @@ class SyncServer(Server): # TODO: @mindy look at moving this to agent_manager to avoid above extra call passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor) + # rebuild agent system prompt - force since no archival change + self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True) + return passages def modify_archival_memory(self, agent_id: str, memory_id: str, passage: PassageUpdate, actor: User) -> List[Passage]: @@ -859,10 +862,14 @@ class SyncServer(Server): def delete_archival_memory(self, memory_id: str, actor: User): # TODO check if it exists first, and throw error if not - # TODO: @mindy make this return the deleted passage instead + # TODO: need to also rebuild the prompt here + passage = self.passage_manager.get_passage_by_id(passage_id=memory_id, actor=actor) + + # delete the passage self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor) - # TODO: return archival memory + # rebuild system prompt and force + self.agent_manager.rebuild_system_prompt(agent_id=passage.agent_id, actor=actor, force=True) def get_agent_recall( self, @@ -981,6 +988,9 @@ class SyncServer(Server): new_passage_size = self.agent_manager.passage_size(actor=actor, agent_id=agent_id) assert new_passage_size >= curr_passage_size # in case empty files are added + # rebuild system prompt and force + self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor, force=True) + return job def load_data( diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 51a35b14..d751544c 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -55,6 +55,7 @@ from letta.services.helpers.agent_manager_helper import ( ) from letta.services.identity_manager import IdentityManager from letta.services.message_manager import MessageManager +from letta.services.passage_manager import PassageManager from letta.services.source_manager import SourceManager from letta.services.tool_manager import ToolManager from letta.settings import settings @@ -76,6 +77,7 @@ class AgentManager: self.tool_manager = ToolManager() self.source_manager = SourceManager() self.message_manager = MessageManager() + self.passage_manager = PassageManager() self.identity_manager = IdentityManager() # ====================================================================================================================== @@ -625,12 +627,17 @@ class AgentManager: # NOTE: a bit of a hack - we pull the timestamp from the message created_by memory_edit_timestamp = curr_system_message.created_at + num_messages = self.message_manager.size(actor=actor, agent_id=agent_id) + num_archival_memories = self.passage_manager.size(actor=actor, agent_id=agent_id) + # update memory (TODO: potentially update recall/archival stats separately) new_system_message_str = compile_system_message( system_prompt=agent_state.system, in_context_memory=agent_state.memory, in_context_memory_last_edit=memory_edit_timestamp, recent_passages=self.list_passages(actor=actor, agent_id=agent_id, ascending=False, limit=10), + previous_message_count=num_messages, + archival_memory_size=num_archival_memories, ) diff = united_diff(curr_system_message_openai["content"], new_system_message_str) diff --git a/letta/services/passage_manager.py b/letta/services/passage_manager.py index 1758287b..8c02655b 100644 --- a/letta/services/passage_manager.py +++ b/letta/services/passage_manager.py @@ -203,3 +203,18 @@ class PassageManager: for passage in passages: self.delete_passage_by_id(passage_id=passage.id, actor=actor) return True + + @enforce_types + def size( + self, + actor: PydanticUser, + agent_id: Optional[str] = None, + ) -> int: + """Get the total count of messages with optional filters. + + Args: + actor: The user requesting the count + agent_id: The agent ID of the messages + """ + with self.session_maker() as session: + return AgentPassage.size(db_session=session, actor=actor, agent_id=agent_id)