feat: orm passage migration (#2180)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2024-12-10 18:09:35 -08:00
committed by GitHub
parent a187488f4f
commit 9deacbd89e
31 changed files with 1282 additions and 531 deletions

View File

@@ -16,7 +16,6 @@ import letta.constants as constants
import letta.server.utils as server_utils
import letta.system as system
from letta.agent import Agent, save_agent
from letta.agent_store.db import attach_base
from letta.agent_store.storage import StorageConnector, TableType
from letta.chat_only_agent import ChatOnlyAgent
from letta.credentials import LettaCredentials
@@ -70,17 +69,18 @@ from letta.schemas.memory import (
)
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
from letta.schemas.organization import Organization
from letta.schemas.passage import Passage
from letta.schemas.passage import Passage as PydanticPassage
from letta.schemas.source import Source
from letta.schemas.tool import Tool, ToolCreate
from letta.schemas.usage import LettaUsageStatistics
from letta.schemas.user import User
from letta.schemas.user import User as PydanticUser
from letta.services.agents_tags_manager import AgentsTagsManager
from letta.services.block_manager import BlockManager
from letta.services.blocks_agents_manager import BlocksAgentsManager
from letta.services.job_manager import JobManager
from letta.services.message_manager import MessageManager
from letta.services.organization_manager import OrganizationManager
from letta.services.passage_manager import PassageManager
from letta.services.per_agent_lock_manager import PerAgentLockManager
from letta.services.sandbox_config_manager import SandboxConfigManager
from letta.services.source_manager import SourceManager
@@ -125,7 +125,7 @@ class Server(object):
def create_agent(
self,
request: CreateAgent,
actor: User,
actor: PydanticUser,
# interface
interface: Union[AgentInterface, None] = None,
) -> AgentState:
@@ -166,8 +166,6 @@ from letta.settings import model_settings, settings, tool_settings
config = LettaConfig.load()
attach_base()
if settings.letta_pg_uri_no_default:
config.recall_storage_type = "postgres"
config.recall_storage_uri = settings.letta_pg_uri_no_default
@@ -245,6 +243,7 @@ class SyncServer(Server):
# Managers that interface with data models
self.organization_manager = OrganizationManager()
self.passage_manager = PassageManager()
self.user_manager = UserManager()
self.tool_manager = ToolManager()
self.block_manager = BlockManager()
@@ -498,7 +497,12 @@ class SyncServer(Server):
# attach data to agent from source
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
letta_agent.attach_source(data_source, source_connector, self.ms)
letta_agent.attach_source(
user=self.user_manager.get_user_by_id(user_id=user_id),
source_id=data_source,
source_manager=letta_agent.source_manager,
ms=self.ms
)
elif command.lower() == "dump" or command.lower().startswith("dump "):
# Check if there's an additional argument that's an integer
@@ -513,7 +517,7 @@ class SyncServer(Server):
letta_agent.interface.print_messages_raw(letta_agent.messages)
elif command.lower() == "memory":
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.archival_memory)}"
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}"
return ret_str
elif command.lower() == "pop" or command.lower().startswith("pop "):
@@ -769,7 +773,7 @@ class SyncServer(Server):
def create_agent(
self,
request: CreateAgent,
actor: User,
actor: PydanticUser,
# interface
interface: Union[AgentInterface, None] = None,
) -> AgentState:
@@ -921,6 +925,7 @@ class SyncServer(Server):
# get `Tool` objects
tools = [self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=user) for tool_name in agent_state.tool_names]
tools = [tool for tool in tools if tool is not None]
# get `Source` objects
sources = self.list_attached_sources(agent_id=agent_id)
@@ -934,7 +939,7 @@ class SyncServer(Server):
def update_agent(
self,
request: UpdateAgentState,
actor: User,
actor: PydanticUser,
) -> AgentState:
"""Update the agents core memory block, return the new state"""
try:
@@ -1151,7 +1156,7 @@ class SyncServer(Server):
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
agent = self.load_agent(agent_id=agent_id)
return ArchivalMemorySummary(size=len(agent.archival_memory))
return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user))
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
agent = self.load_agent(agent_id=agent_id)
@@ -1176,7 +1181,56 @@ class SyncServer(Server):
message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user)
return message
def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
def get_agent_messages(
    self,
    agent_id: str,
    start: int,
    count: int,
) -> Union[List[Message], List[LettaMessage]]:
    """Return one page of an agent's messages, addressed by offset and size.

    The page is served from the agent's in-memory message list when
    ``start + count`` is strictly smaller than that list's length; otherwise
    the agent's message manager is queried for the page.

    Args:
        agent_id: ID of the agent whose messages are listed.
        start: Zero-based offset of the first message; must be non-negative.
        count: Maximum number of messages to return; must be non-negative.

    Returns:
        The requested page of messages. On the in-memory path the page is
        taken from the reversed in-memory list (intended to be reverse
        chronological order).

    Raises:
        ValueError: If ``start`` or ``count`` is negative.
    """
    # Get the agent object (loaded in memory)
    letta_agent = self.load_agent(agent_id=agent_id)

    if start < 0 or count < 0:
        raise ValueError("Start and count values should be non-negative")

    in_memory_messages = letta_agent._messages
    if start + count < len(in_memory_messages):
        # Both values are non-negative and start + count is below the list
        # length, so the slice is always in range: no separate bounds check
        # or end-index clamping is needed here.
        messages = in_memory_messages[::-1][start : start + count]
    else:
        # The requested window reaches past what is held in memory, so fall
        # back to the message manager.
        # NOTE(review): `start` is an integer offset but is passed as
        # `cursor`; confirm list_user_messages_for_agent accepts an offset
        # here and returns results in the same order as the in-memory path.
        messages = letta_agent.message_manager.list_user_messages_for_agent(
            agent_id=agent_id,
            actor=self.default_user,
            cursor=start,
            limit=count,
        )

    # Internal sanity check: both paths are expected to yield Message objects.
    assert all(isinstance(m, Message) for m in messages)

    return messages
def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[PydanticPassage]:
"""Paginated query of all messages in agent archival memory"""
if self.user_manager.get_user_by_id(user_id=user_id) is None:
raise ValueError(f"User user_id={user_id} does not exist")
@@ -1187,22 +1241,22 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id)
# iterate over records
db_iterator = letta_agent.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
records = letta_agent.passage_manager.list_passages(
actor=self.default_user,
agent_id=agent_id,
cursor=cursor,
limit=limit,
)
# get a single page of messages
page = next(db_iterator, [])
return page
return records
def get_agent_archival_cursor(
self,
user_id: str,
agent_id: str,
after: Optional[str] = None,
before: Optional[str] = None,
cursor: Optional[str] = None,
limit: Optional[int] = 100,
order_by: Optional[str] = "created_at",
reverse: Optional[bool] = False,
) -> List[Passage]:
) -> List[PydanticPassage]:
if self.user_manager.get_user_by_id(user_id=user_id) is None:
raise LettaUserNotFoundError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
@@ -1211,14 +1265,15 @@ class SyncServer(Server):
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id)
# iterate over recorde
cursor, records = letta_agent.archival_memory.storage.get_all_cursor(
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
# iterate over records
records = letta_agent.passage_manager.list_passages(
actor=self.default_user, agent_id=agent_id, cursor=cursor, limit=limit,
)
return records
def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[Passage]:
if self.user_manager.get_user_by_id(user_id=user_id) is None:
def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[PydanticPassage]:
actor = self.user_manager.get_user_by_id(user_id=user_id)
if actor is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
@@ -1227,17 +1282,20 @@ class SyncServer(Server):
letta_agent = self.load_agent(agent_id=agent_id)
# Insert into archival memory
passage_ids = letta_agent.archival_memory.insert(memory_string=memory_contents, return_ids=True)
passage_ids = self.passage_manager.insert_passage(
agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor, return_ids=True
)
# Update the agent
# TODO: should this update the system prompt?
save_agent(letta_agent, self.ms)
# TODO: this is gross, fix
return [letta_agent.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids]
return [self.passage_manager.get_passage_by_id(passage_id=passage_id, actor=actor) for passage_id in passage_ids]
def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str):
if self.user_manager.get_user_by_id(user_id=user_id) is None:
actor = self.user_manager.get_user_by_id(user_id=user_id)
if actor is None:
raise ValueError(f"User user_id={user_id} does not exist")
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
raise ValueError(f"Agent agent_id={agent_id} does not exist")
@@ -1249,7 +1307,7 @@ class SyncServer(Server):
# Delete by ID
# TODO check if it exists first, and throw error if not
letta_agent.archival_memory.storage.delete({"id": memory_id})
letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
# TODO: return archival memory
@@ -1395,6 +1453,12 @@ class SyncServer(Server):
except NoResultFound:
logger.error(f"Agent with id {agent_state.id} has nonexistent user {agent_state.user_id}")
# delete all passages associated with this agent
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM
passages = self.passage_manager.list_passages(actor=actor, agent_id=agent_state.id)
for passage in passages:
self.passage_manager.delete_passage_by_id(passage.id, actor=actor)
# First, if the agent is in the in-memory cache we should remove it
# List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
try:
@@ -1437,7 +1501,7 @@ class SyncServer(Server):
self.ms.delete_api_key(api_key=api_key)
return api_key_obj
def delete_source(self, source_id: str, actor: User):
def delete_source(self, source_id: str, actor: PydanticUser):
"""Delete a data source"""
self.source_manager.delete_source(source_id=source_id, actor=actor)
@@ -1447,7 +1511,7 @@ class SyncServer(Server):
# TODO: delete data from agent passage stores (?)
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: PydanticUser) -> Job:
# update job
job = self.job_manager.get_job_by_id(job_id, actor=actor)
@@ -1474,6 +1538,7 @@ class SyncServer(Server):
user_id: str,
connector: DataConnector,
source_name: str,
agent_id: Optional[str] = None,
) -> Tuple[int, int]:
"""Load data from a DataConnector into a source for a specified user_id"""
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
@@ -1488,14 +1553,13 @@ class SyncServer(Server):
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
# load data into the document store
passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user)
passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user, agent_id=agent_id)
return passage_count, document_count
def attach_source_to_agent(
self,
user_id: str,
agent_id: str,
# source_id: str,
source_id: Optional[str] = None,
source_name: Optional[str] = None,
) -> Source:
@@ -1507,15 +1571,14 @@ class SyncServer(Server):
data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
else:
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
# get connection to data source storage
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
assert data_source, f"Data source with id={source_id} or name={source_name} does not exist"
# load agent
agent = self.load_agent(agent_id=agent_id)
# attach source to agent
agent.attach_source(data_source.id, source_connector, self.ms)
agent.attach_source(user=user, source_id=data_source.id, source_manager=self.source_manager, ms=self.ms)
return data_source
@@ -1538,8 +1601,7 @@ class SyncServer(Server):
# delete all Passage objects with source_id==source_id from agent's archival memory
agent = self.load_agent(agent_id=agent_id)
archival_memory = agent.archival_memory
archival_memory.storage.delete({"source_id": source_id})
agent.passage_manager.delete_passages(actor=user, limit=100, source_id=source_id)
# delete agent-source mapping
self.ms.detach_source(agent_id=agent_id, source_id=source_id)
@@ -1553,11 +1615,11 @@ class SyncServer(Server):
return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids]
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
def list_data_source_passages(self, user_id: str, source_id: str) -> List[PydanticPassage]:
    """Placeholder for listing the passages that belong to a data source.

    Not implemented yet: emits a ``UserWarning`` and returns an empty list.
    ``user_id`` and ``source_id`` are currently ignored.

    Args:
        user_id: ID of the requesting user (unused).
        source_id: ID of the data source (unused).

    Returns:
        Always an empty list.
    """
    # stacklevel=2 attributes the warning to the caller's line, so users can
    # see which call site hit the unimplemented path.
    warnings.warn(
        "list_data_source_passages is not yet implemented, returning empty list.",
        category=UserWarning,
        stacklevel=2,
    )
    return []
def list_all_sources(self, actor: User) -> List[Source]:
def list_all_sources(self, actor: PydanticUser) -> List[Source]:
"""List all sources (w/ extra metadata) belonging to a user"""
sources = self.source_manager.list_sources(actor=actor)
@@ -1597,7 +1659,7 @@ class SyncServer(Server):
return sources_with_metadata
def add_default_external_tools(self, actor: User) -> bool:
def add_default_external_tools(self, actor: PydanticUser) -> bool:
"""Add default langchain tools. Return true if successful, false otherwise."""
success = True
tool_creates = ToolCreate.load_default_langchain_tools()
@@ -1654,7 +1716,7 @@ class SyncServer(Server):
save_agent(letta_agent, self.ms)
return response
def get_user_or_default(self, user_id: Optional[str]) -> User:
def get_user_or_default(self, user_id: Optional[str]) -> PydanticUser:
"""Get the user object for user_id if it exists, otherwise return the default user object"""
if user_id is None:
user_id = self.user_manager.DEFAULT_USER_ID