feat: orm passage migration (#2180)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -28,7 +28,7 @@ from letta.interface import AgentInterface
|
||||
from letta.llm_api.helpers import is_context_overflow_error
|
||||
from letta.llm_api.llm_api_tools import create
|
||||
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
|
||||
from letta.memory import ArchivalMemory, EmbeddingArchivalMemory, summarize_messages
|
||||
from letta.memory import summarize_messages
|
||||
from letta.metadata import MetadataStore
|
||||
from letta.orm import User
|
||||
from letta.schemas.agent import AgentState, AgentStepResponse
|
||||
@@ -52,6 +52,7 @@ from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
|
||||
from letta.services.user_manager import UserManager
|
||||
@@ -85,7 +86,7 @@ def compile_memory_metadata_block(
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
memory_edit_timestamp: datetime.datetime,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
) -> str:
|
||||
# Put the timestamp in the local timezone (mimicking get_local_time())
|
||||
@@ -96,7 +97,7 @@ def compile_memory_metadata_block(
|
||||
[
|
||||
f"### Memory [last modified: {timestamp_str}]",
|
||||
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
|
||||
f"{archival_memory.count() if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
|
||||
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
|
||||
]
|
||||
)
|
||||
@@ -109,7 +110,7 @@ def compile_system_message(
|
||||
in_context_memory: Memory,
|
||||
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
|
||||
actor: PydanticUser,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
user_defined_variables: Optional[dict] = None,
|
||||
append_icm_if_missing: bool = True,
|
||||
@@ -138,7 +139,7 @@ def compile_system_message(
|
||||
actor=actor,
|
||||
agent_id=agent_id,
|
||||
memory_edit_timestamp=in_context_memory_last_edit,
|
||||
archival_memory=archival_memory,
|
||||
passage_manager=passage_manager,
|
||||
message_manager=message_manager,
|
||||
)
|
||||
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
|
||||
@@ -175,7 +176,7 @@ def initialize_message_sequence(
|
||||
agent_id: str,
|
||||
memory: Memory,
|
||||
actor: PydanticUser,
|
||||
archival_memory: Optional[ArchivalMemory] = None,
|
||||
passage_manager: Optional[PassageManager] = None,
|
||||
message_manager: Optional[MessageManager] = None,
|
||||
memory_edit_timestamp: Optional[datetime.datetime] = None,
|
||||
include_initial_boot_message: bool = True,
|
||||
@@ -184,7 +185,7 @@ def initialize_message_sequence(
|
||||
memory_edit_timestamp = get_local_time()
|
||||
|
||||
# full_system_message = construct_system_with_memory(
|
||||
# system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
|
||||
# system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory
|
||||
# )
|
||||
full_system_message = compile_system_message(
|
||||
agent_id=agent_id,
|
||||
@@ -192,7 +193,7 @@ def initialize_message_sequence(
|
||||
in_context_memory=memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=actor,
|
||||
archival_memory=archival_memory,
|
||||
passage_manager=passage_manager,
|
||||
message_manager=message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
@@ -294,7 +295,7 @@ class Agent(BaseAgent):
|
||||
self.interface = interface
|
||||
|
||||
# Create the persistence manager object based on the AgentState info
|
||||
self.archival_memory = EmbeddingArchivalMemory(agent_state)
|
||||
self.passage_manager = PassageManager()
|
||||
self.message_manager = MessageManager()
|
||||
|
||||
# State needed for heartbeat pausing
|
||||
@@ -325,7 +326,7 @@ class Agent(BaseAgent):
|
||||
agent_id=self.agent_state.id,
|
||||
memory=self.agent_state.memory,
|
||||
actor=self.user,
|
||||
archival_memory=None,
|
||||
passage_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
@@ -350,7 +351,7 @@ class Agent(BaseAgent):
|
||||
memory=self.agent_state.memory,
|
||||
agent_id=self.agent_state.id,
|
||||
actor=self.user,
|
||||
archival_memory=None,
|
||||
passage_manager=None,
|
||||
message_manager=None,
|
||||
memory_edit_timestamp=get_utc_time(),
|
||||
include_initial_boot_message=True,
|
||||
@@ -1306,7 +1307,7 @@ class Agent(BaseAgent):
|
||||
in_context_memory=self.agent_state.memory,
|
||||
in_context_memory_last_edit=memory_edit_timestamp,
|
||||
actor=self.user,
|
||||
archival_memory=self.archival_memory,
|
||||
passage_manager=self.passage_manager,
|
||||
message_manager=self.message_manager,
|
||||
user_defined_variables=None,
|
||||
append_icm_if_missing=True,
|
||||
@@ -1371,45 +1372,33 @@ class Agent(BaseAgent):
|
||||
# TODO: recall memory
|
||||
raise NotImplementedError()
|
||||
|
||||
def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
|
||||
def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore):
|
||||
"""Attach data with name `source_name` to the agent from source_connector."""
|
||||
# TODO: eventually, adding a data source should just give the retriever access to the source table, rather than modifying archival memory
|
||||
user = UserManager().get_user_by_id(self.agent_state.user_id)
|
||||
filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
|
||||
size = source_connector.size(filters)
|
||||
page_size = 100
|
||||
generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage]
|
||||
all_passages = []
|
||||
for i in tqdm(range(0, size, page_size)):
|
||||
passages = next(generator)
|
||||
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
|
||||
|
||||
# need to associate each passage with the agent (for filtering)
|
||||
for passage in passages:
|
||||
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
|
||||
passage.agent_id = self.agent_state.id
|
||||
for passage in passages:
|
||||
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
|
||||
passage.agent_id = self.agent_state.id
|
||||
self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)
|
||||
|
||||
# regenerate passage ID (avoid duplicates)
|
||||
# TODO: need to find another solution to the text duplication issue
|
||||
# passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
|
||||
|
||||
# insert into agent archival memory
|
||||
self.archival_memory.storage.insert_many(passages)
|
||||
all_passages += passages
|
||||
|
||||
assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
|
||||
|
||||
# save destination storage
|
||||
self.archival_memory.storage.save()
|
||||
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
|
||||
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
|
||||
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
|
||||
assert len(agents_passages) == passage_size # sanity check
|
||||
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
|
||||
|
||||
# attach to agent
|
||||
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
|
||||
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
|
||||
assert source is not None, f"Source {source_id} not found in metadata store"
|
||||
|
||||
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
|
||||
# TODO: delete @matt and remove
|
||||
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
|
||||
|
||||
total_agent_passages = self.archival_memory.storage.size()
|
||||
|
||||
printd(
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
|
||||
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
|
||||
)
|
||||
|
||||
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
|
||||
@@ -1565,13 +1554,13 @@ class Agent(BaseAgent):
|
||||
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
|
||||
)
|
||||
|
||||
num_archival_memory = self.archival_memory.storage.size()
|
||||
passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
|
||||
external_memory_summary = compile_memory_metadata_block(
|
||||
actor=self.user,
|
||||
agent_id=self.agent_state.id,
|
||||
memory_edit_timestamp=get_utc_time(), # dummy timestamp
|
||||
archival_memory=self.archival_memory,
|
||||
passage_manager=self.passage_manager,
|
||||
message_manager=self.message_manager,
|
||||
)
|
||||
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
|
||||
@@ -1597,7 +1586,7 @@ class Agent(BaseAgent):
|
||||
return ContextWindowOverview(
|
||||
# context window breakdown (in messages)
|
||||
num_messages=len(self._messages),
|
||||
num_archival_memory=num_archival_memory,
|
||||
num_archival_memory=passage_manager_size,
|
||||
num_recall_memory=message_manager_size,
|
||||
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
|
||||
# top-level information
|
||||
|
||||
Reference in New Issue
Block a user