feat: orm passage migration (#2180)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-10 18:09:35 -08:00
committed by GitHub
parent a187488f4f
commit 9deacbd89e
31 changed files with 1282 additions and 531 deletions

View File

@@ -28,7 +28,7 @@ from letta.interface import AgentInterface
from letta.llm_api.helpers import is_context_overflow_error
from letta.llm_api.llm_api_tools import create
from letta.local_llm.utils import num_tokens_from_functions, num_tokens_from_messages
from letta.memory import ArchivalMemory, EmbeddingArchivalMemory, summarize_messages
from letta.memory import summarize_messages
from letta.metadata import MetadataStore
from letta.orm import User
from letta.schemas.agent import AgentState, AgentStepResponse
@@ -52,6 +52,7 @@ from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User as PydanticUser
from letta.services.block_manager import BlockManager
from letta.services.message_manager import MessageManager
from letta.services.passage_manager import PassageManager
from letta.services.source_manager import SourceManager
from letta.services.tool_execution_sandbox import ToolExecutionSandbox
from letta.services.user_manager import UserManager
@@ -85,7 +86,7 @@ def compile_memory_metadata_block(
actor: PydanticUser,
agent_id: str,
memory_edit_timestamp: datetime.datetime,
archival_memory: Optional[ArchivalMemory] = None,
passage_manager: Optional[PassageManager] = None,
message_manager: Optional[MessageManager] = None,
) -> str:
# Put the timestamp in the local timezone (mimicking get_local_time())
@@ -96,7 +97,7 @@ def compile_memory_metadata_block(
[
f"### Memory [last modified: {timestamp_str}]",
f"{message_manager.size(actor=actor, agent_id=agent_id) if message_manager else 0} previous messages between you and the user are stored in recall memory (use functions to access them)",
f"{archival_memory.count() if archival_memory else 0} total memories you created are stored in archival memory (use functions to access them)",
f"{passage_manager.size(actor=actor, agent_id=agent_id) if passage_manager else 0} total memories you created are stored in archival memory (use functions to access them)",
"\nCore memory shown below (limited in size, additional information stored in archival / recall memory):",
]
)
@@ -109,7 +110,7 @@ def compile_system_message(
in_context_memory: Memory,
in_context_memory_last_edit: datetime.datetime, # TODO move this inside of BaseMemory?
actor: PydanticUser,
archival_memory: Optional[ArchivalMemory] = None,
passage_manager: Optional[PassageManager] = None,
message_manager: Optional[MessageManager] = None,
user_defined_variables: Optional[dict] = None,
append_icm_if_missing: bool = True,
@@ -138,7 +139,7 @@ def compile_system_message(
actor=actor,
agent_id=agent_id,
memory_edit_timestamp=in_context_memory_last_edit,
archival_memory=archival_memory,
passage_manager=passage_manager,
message_manager=message_manager,
)
full_memory_string = memory_metadata_string + "\n" + in_context_memory.compile()
@@ -175,7 +176,7 @@ def initialize_message_sequence(
agent_id: str,
memory: Memory,
actor: PydanticUser,
archival_memory: Optional[ArchivalMemory] = None,
passage_manager: Optional[PassageManager] = None,
message_manager: Optional[MessageManager] = None,
memory_edit_timestamp: Optional[datetime.datetime] = None,
include_initial_boot_message: bool = True,
@@ -184,7 +185,7 @@ def initialize_message_sequence(
memory_edit_timestamp = get_local_time()
# full_system_message = construct_system_with_memory(
# system, memory, memory_edit_timestamp, archival_memory=archival_memory, recall_memory=recall_memory
# system, memory, memory_edit_timestamp, passage_manager=passage_manager, recall_memory=recall_memory
# )
full_system_message = compile_system_message(
agent_id=agent_id,
@@ -192,7 +193,7 @@ def initialize_message_sequence(
in_context_memory=memory,
in_context_memory_last_edit=memory_edit_timestamp,
actor=actor,
archival_memory=archival_memory,
passage_manager=passage_manager,
message_manager=message_manager,
user_defined_variables=None,
append_icm_if_missing=True,
@@ -294,7 +295,7 @@ class Agent(BaseAgent):
self.interface = interface
# Create the persistence manager object based on the AgentState info
self.archival_memory = EmbeddingArchivalMemory(agent_state)
self.passage_manager = PassageManager()
self.message_manager = MessageManager()
# State needed for heartbeat pausing
@@ -325,7 +326,7 @@ class Agent(BaseAgent):
agent_id=self.agent_state.id,
memory=self.agent_state.memory,
actor=self.user,
archival_memory=None,
passage_manager=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
@@ -350,7 +351,7 @@ class Agent(BaseAgent):
memory=self.agent_state.memory,
agent_id=self.agent_state.id,
actor=self.user,
archival_memory=None,
passage_manager=None,
message_manager=None,
memory_edit_timestamp=get_utc_time(),
include_initial_boot_message=True,
@@ -1306,7 +1307,7 @@ class Agent(BaseAgent):
in_context_memory=self.agent_state.memory,
in_context_memory_last_edit=memory_edit_timestamp,
actor=self.user,
archival_memory=self.archival_memory,
passage_manager=self.passage_manager,
message_manager=self.message_manager,
user_defined_variables=None,
append_icm_if_missing=True,
@@ -1371,45 +1372,33 @@ class Agent(BaseAgent):
# TODO: recall memory
raise NotImplementedError()
def attach_source(self, source_id: str, source_connector: StorageConnector, ms: MetadataStore):
def attach_source(self, user: PydanticUser, source_id: str, source_manager: SourceManager, ms: MetadataStore):
"""Attach data with name `source_name` to the agent from source_connector."""
# TODO: eventually, adding a data source should just give access to the retriever the source table, rather than modifying archival memory
user = UserManager().get_user_by_id(self.agent_state.user_id)
filters = {"user_id": self.agent_state.user_id, "source_id": source_id}
size = source_connector.size(filters)
page_size = 100
generator = source_connector.get_all_paginated(filters=filters, page_size=page_size) # yields List[Passage]
all_passages = []
for i in tqdm(range(0, size, page_size)):
passages = next(generator)
passages = self.passage_manager.list_passages(actor=user, source_id=source_id, limit=page_size)
# need to associated passage with agent (for filtering)
for passage in passages:
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
passage.agent_id = self.agent_state.id
for passage in passages:
assert isinstance(passage, Passage), f"Generate yielded bad non-Passage type: {type(passage)}"
passage.agent_id = self.agent_state.id
self.passage_manager.update_passage_by_id(passage_id=passage.id, passage=passage, actor=user)
# regenerate passage ID (avoid duplicates)
# TODO: need to find another solution to the text duplication issue
# passage.id = create_uuid_from_string(f"{source_id}_{str(passage.agent_id)}_{passage.text}")
# insert into agent archival memory
self.archival_memory.storage.insert_many(passages)
all_passages += passages
assert size == len(all_passages), f"Expected {size} passages, but only got {len(all_passages)}"
# save destination storage
self.archival_memory.storage.save()
agents_passages = self.passage_manager.list_passages(actor=user, agent_id=self.agent_state.id, source_id=source_id, limit=page_size)
passage_size = self.passage_manager.size(actor=user, agent_id=self.agent_state.id, source_id=source_id)
assert all([p.agent_id == self.agent_state.id for p in agents_passages])
assert len(agents_passages) == passage_size # sanity check
assert passage_size == len(passages), f"Expected {len(passages)} passages, got {passage_size}"
# attach to agent
source = SourceManager().get_source_by_id(source_id=source_id, actor=user)
source = source_manager.get_source_by_id(source_id=source_id, actor=user)
assert source is not None, f"Source {source_id} not found in metadata store"
# NOTE: need this redundant line here because we haven't migrated agent to ORM yet
# TODO: delete @matt and remove
ms.attach_source(agent_id=self.agent_state.id, source_id=source_id, user_id=self.agent_state.user_id)
total_agent_passages = self.archival_memory.storage.size()
printd(
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(all_passages)}. Agent now has {total_agent_passages} embeddings in archival memory.",
f"Attached data source {source.name} to agent {self.agent_state.name}, consisting of {len(passages)}. Agent now has {passage_size} embeddings in archival memory.",
)
def update_message(self, message_id: str, request: MessageUpdate) -> Message:
@@ -1565,13 +1554,13 @@ class Agent(BaseAgent):
num_tokens_from_messages(messages=messages_openai_format[1:], model=self.model) if len(messages_openai_format) > 1 else 0
)
num_archival_memory = self.archival_memory.storage.size()
passage_manager_size = self.passage_manager.size(actor=self.user, agent_id=self.agent_state.id)
message_manager_size = self.message_manager.size(actor=self.user, agent_id=self.agent_state.id)
external_memory_summary = compile_memory_metadata_block(
actor=self.user,
agent_id=self.agent_state.id,
memory_edit_timestamp=get_utc_time(), # dummy timestamp
archival_memory=self.archival_memory,
passage_manager=self.passage_manager,
message_manager=self.message_manager,
)
num_tokens_external_memory_summary = count_tokens(external_memory_summary)
@@ -1597,7 +1586,7 @@ class Agent(BaseAgent):
return ContextWindowOverview(
# context window breakdown (in messages)
num_messages=len(self._messages),
num_archival_memory=num_archival_memory,
num_archival_memory=passage_manager_size,
num_recall_memory=message_manager_size,
num_tokens_external_memory_summary=num_tokens_external_memory_summary,
# top-level information