feat: orm passage migration (#2180)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -16,7 +16,6 @@ import letta.constants as constants
|
||||
import letta.server.utils as server_utils
|
||||
import letta.system as system
|
||||
from letta.agent import Agent, save_agent
|
||||
from letta.agent_store.db import attach_base
|
||||
from letta.agent_store.storage import StorageConnector, TableType
|
||||
from letta.chat_only_agent import ChatOnlyAgent
|
||||
from letta.credentials import LettaCredentials
|
||||
@@ -70,17 +69,18 @@ from letta.schemas.memory import (
|
||||
)
|
||||
from letta.schemas.message import Message, MessageCreate, MessageRole, MessageUpdate
|
||||
from letta.schemas.organization import Organization
|
||||
from letta.schemas.passage import Passage
|
||||
from letta.schemas.passage import Passage as PydanticPassage
|
||||
from letta.schemas.source import Source
|
||||
from letta.schemas.tool import Tool, ToolCreate
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from letta.schemas.user import User
|
||||
from letta.schemas.user import User as PydanticUser
|
||||
from letta.services.agents_tags_manager import AgentsTagsManager
|
||||
from letta.services.block_manager import BlockManager
|
||||
from letta.services.blocks_agents_manager import BlocksAgentsManager
|
||||
from letta.services.job_manager import JobManager
|
||||
from letta.services.message_manager import MessageManager
|
||||
from letta.services.organization_manager import OrganizationManager
|
||||
from letta.services.passage_manager import PassageManager
|
||||
from letta.services.per_agent_lock_manager import PerAgentLockManager
|
||||
from letta.services.sandbox_config_manager import SandboxConfigManager
|
||||
from letta.services.source_manager import SourceManager
|
||||
@@ -125,7 +125,7 @@ class Server(object):
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
actor: User,
|
||||
actor: PydanticUser,
|
||||
# interface
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
) -> AgentState:
|
||||
@@ -166,8 +166,6 @@ from letta.settings import model_settings, settings, tool_settings
|
||||
|
||||
config = LettaConfig.load()
|
||||
|
||||
attach_base()
|
||||
|
||||
if settings.letta_pg_uri_no_default:
|
||||
config.recall_storage_type = "postgres"
|
||||
config.recall_storage_uri = settings.letta_pg_uri_no_default
|
||||
@@ -245,6 +243,7 @@ class SyncServer(Server):
|
||||
|
||||
# Managers that interface with data models
|
||||
self.organization_manager = OrganizationManager()
|
||||
self.passage_manager = PassageManager()
|
||||
self.user_manager = UserManager()
|
||||
self.tool_manager = ToolManager()
|
||||
self.block_manager = BlockManager()
|
||||
@@ -498,7 +497,12 @@ class SyncServer(Server):
|
||||
|
||||
# attach data to agent from source
|
||||
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
letta_agent.attach_source(data_source, source_connector, self.ms)
|
||||
letta_agent.attach_source(
|
||||
user=self.user_manager.get_user_by_id(user_id=user_id),
|
||||
source_id=data_source,
|
||||
source_manager=letta_agent.source_manager,
|
||||
ms=self.ms
|
||||
)
|
||||
|
||||
elif command.lower() == "dump" or command.lower().startswith("dump "):
|
||||
# Check if there's an additional argument that's an integer
|
||||
@@ -513,7 +517,7 @@ class SyncServer(Server):
|
||||
letta_agent.interface.print_messages_raw(letta_agent.messages)
|
||||
|
||||
elif command.lower() == "memory":
|
||||
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.archival_memory)}"
|
||||
ret_str = f"\nDumping memory contents:\n" + f"\n{str(letta_agent.agent_state.memory)}" + f"\n{str(letta_agent.passage_manager)}"
|
||||
return ret_str
|
||||
|
||||
elif command.lower() == "pop" or command.lower().startswith("pop "):
|
||||
@@ -769,7 +773,7 @@ class SyncServer(Server):
|
||||
def create_agent(
|
||||
self,
|
||||
request: CreateAgent,
|
||||
actor: User,
|
||||
actor: PydanticUser,
|
||||
# interface
|
||||
interface: Union[AgentInterface, None] = None,
|
||||
) -> AgentState:
|
||||
@@ -921,6 +925,7 @@ class SyncServer(Server):
|
||||
|
||||
# get `Tool` objects
|
||||
tools = [self.tool_manager.get_tool_by_name(tool_name=tool_name, actor=user) for tool_name in agent_state.tool_names]
|
||||
tools = [tool for tool in tools if tool is not None]
|
||||
|
||||
# get `Source` objects
|
||||
sources = self.list_attached_sources(agent_id=agent_id)
|
||||
@@ -934,7 +939,7 @@ class SyncServer(Server):
|
||||
def update_agent(
|
||||
self,
|
||||
request: UpdateAgentState,
|
||||
actor: User,
|
||||
actor: PydanticUser,
|
||||
) -> AgentState:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
try:
|
||||
@@ -1151,7 +1156,7 @@ class SyncServer(Server):
|
||||
|
||||
def get_archival_memory_summary(self, agent_id: str) -> ArchivalMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
return ArchivalMemorySummary(size=len(agent.archival_memory))
|
||||
return ArchivalMemorySummary(size=agent.passage_manager.size(actor=self.default_user))
|
||||
|
||||
def get_recall_memory_summary(self, agent_id: str) -> RecallMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
@@ -1176,7 +1181,56 @@ class SyncServer(Server):
|
||||
message = agent.message_manager.get_message_by_id(id=message_id, actor=self.default_user)
|
||||
return message
|
||||
|
||||
def get_agent_archival(self, user_id: str, agent_id: str, start: int, count: int) -> List[Passage]:
|
||||
def get_agent_messages(
|
||||
self,
|
||||
agent_id: str,
|
||||
start: int,
|
||||
count: int,
|
||||
) -> Union[List[Message], List[LettaMessage]]:
|
||||
"""Paginated query of all messages in agent message queue"""
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
if start < 0 or count < 0:
|
||||
raise ValueError("Start and count values should be non-negative")
|
||||
|
||||
if start + count < len(letta_agent._messages): # messages can be returned from whats in memory
|
||||
# Reverse the list to make it in reverse chronological order
|
||||
reversed_messages = letta_agent._messages[::-1]
|
||||
# Check if start is within the range of the list
|
||||
if start >= len(reversed_messages):
|
||||
raise IndexError("Start index is out of range")
|
||||
|
||||
# Calculate the end index, ensuring it does not exceed the list length
|
||||
end_index = min(start + count, len(reversed_messages))
|
||||
|
||||
# Slice the list for pagination
|
||||
messages = reversed_messages[start:end_index]
|
||||
|
||||
else:
|
||||
# need to access persistence manager for additional messages
|
||||
|
||||
# get messages using message manager
|
||||
page = letta_agent.message_manager.list_user_messages_for_agent(
|
||||
agent_id=agent_id,
|
||||
actor=self.default_user,
|
||||
cursor=start,
|
||||
limit=count,
|
||||
)
|
||||
|
||||
messages = page
|
||||
assert all(isinstance(m, Message) for m in messages)
|
||||
|
||||
## Convert to json
|
||||
## Add a tag indicating in-context or not
|
||||
# json_messages = [record.to_json() for record in messages]
|
||||
# in_context_message_ids = [str(m.id) for m in letta_agent._messages]
|
||||
# for d in json_messages:
|
||||
# d["in_context"] = True if str(d["id"]) in in_context_message_ids else False
|
||||
|
||||
return messages
|
||||
|
||||
def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[PydanticPassage]:
|
||||
"""Paginated query of all messages in agent archival memory"""
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
@@ -1187,22 +1241,22 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# iterate over records
|
||||
db_iterator = letta_agent.archival_memory.storage.get_all_paginated(page_size=count, offset=start)
|
||||
records = letta_agent.passage_manager.list_passages(
|
||||
actor=self.default_user,
|
||||
agent_id=agent_id,
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
)
|
||||
|
||||
# get a single page of messages
|
||||
page = next(db_iterator, [])
|
||||
return page
|
||||
return records
|
||||
|
||||
def get_agent_archival_cursor(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
after: Optional[str] = None,
|
||||
before: Optional[str] = None,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 100,
|
||||
order_by: Optional[str] = "created_at",
|
||||
reverse: Optional[bool] = False,
|
||||
) -> List[Passage]:
|
||||
) -> List[PydanticPassage]:
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
raise LettaUserNotFoundError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
@@ -1211,14 +1265,15 @@ class SyncServer(Server):
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# iterate over recorde
|
||||
cursor, records = letta_agent.archival_memory.storage.get_all_cursor(
|
||||
after=after, before=before, limit=limit, order_by=order_by, reverse=reverse
|
||||
# iterate over records
|
||||
records = letta_agent.passage_manager.list_passages(
|
||||
actor=self.default_user, agent_id=agent_id, cursor=cursor, limit=limit,
|
||||
)
|
||||
return records
|
||||
|
||||
def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[Passage]:
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
def insert_archival_memory(self, user_id: str, agent_id: str, memory_contents: str) -> List[PydanticPassage]:
|
||||
actor = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
if actor is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
@@ -1227,17 +1282,20 @@ class SyncServer(Server):
|
||||
letta_agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# Insert into archival memory
|
||||
passage_ids = letta_agent.archival_memory.insert(memory_string=memory_contents, return_ids=True)
|
||||
passage_ids = self.passage_manager.insert_passage(
|
||||
agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor, return_ids=True
|
||||
)
|
||||
|
||||
# Update the agent
|
||||
# TODO: should this update the system prompt?
|
||||
save_agent(letta_agent, self.ms)
|
||||
|
||||
# TODO: this is gross, fix
|
||||
return [letta_agent.archival_memory.storage.get(id=passage_id) for passage_id in passage_ids]
|
||||
return [self.passage_manager.get_passage_by_id(passage_id=passage_id, actor=actor) for passage_id in passage_ids]
|
||||
|
||||
def delete_archival_memory(self, user_id: str, agent_id: str, memory_id: str):
|
||||
if self.user_manager.get_user_by_id(user_id=user_id) is None:
|
||||
actor = self.user_manager.get_user_by_id(user_id=user_id)
|
||||
if actor is None:
|
||||
raise ValueError(f"User user_id={user_id} does not exist")
|
||||
if self.ms.get_agent(agent_id=agent_id, user_id=user_id) is None:
|
||||
raise ValueError(f"Agent agent_id={agent_id} does not exist")
|
||||
@@ -1249,7 +1307,7 @@ class SyncServer(Server):
|
||||
|
||||
# Delete by ID
|
||||
# TODO check if it exists first, and throw error if not
|
||||
letta_agent.archival_memory.storage.delete({"id": memory_id})
|
||||
letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
|
||||
# TODO: return archival memory
|
||||
|
||||
@@ -1395,6 +1453,12 @@ class SyncServer(Server):
|
||||
except NoResultFound:
|
||||
logger.error(f"Agent with id {agent_state.id} has nonexistent user {agent_state.user_id}")
|
||||
|
||||
# delete all passages associated with this agent
|
||||
# TODO: REMOVE THIS ONCE WE MIGRATE AGENTMODEL TO ORM
|
||||
passages = self.passage_manager.list_passages(actor=actor, agent_id=agent_state.id)
|
||||
for passage in passages:
|
||||
self.passage_manager.delete_passage_by_id(passage.id, actor=actor)
|
||||
|
||||
# First, if the agent is in the in-memory cache we should remove it
|
||||
# List of {'user_id': user_id, 'agent_id': agent_id, 'agent': agent_obj} dicts
|
||||
try:
|
||||
@@ -1437,7 +1501,7 @@ class SyncServer(Server):
|
||||
self.ms.delete_api_key(api_key=api_key)
|
||||
return api_key_obj
|
||||
|
||||
def delete_source(self, source_id: str, actor: User):
|
||||
def delete_source(self, source_id: str, actor: PydanticUser):
|
||||
"""Delete a data source"""
|
||||
self.source_manager.delete_source(source_id=source_id, actor=actor)
|
||||
|
||||
@@ -1447,7 +1511,7 @@ class SyncServer(Server):
|
||||
|
||||
# TODO: delete data from agent passage stores (?)
|
||||
|
||||
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: User) -> Job:
|
||||
def load_file_to_source(self, source_id: str, file_path: str, job_id: str, actor: PydanticUser) -> Job:
|
||||
|
||||
# update job
|
||||
job = self.job_manager.get_job_by_id(job_id, actor=actor)
|
||||
@@ -1474,6 +1538,7 @@ class SyncServer(Server):
|
||||
user_id: str,
|
||||
connector: DataConnector,
|
||||
source_name: str,
|
||||
agent_id: Optional[str] = None,
|
||||
) -> Tuple[int, int]:
|
||||
"""Load data from a DataConnector into a source for a specified user_id"""
|
||||
# TODO: this should be implemented as a batch job or at least async, since it may take a long time
|
||||
@@ -1488,14 +1553,13 @@ class SyncServer(Server):
|
||||
passage_store = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
|
||||
# load data into the document store
|
||||
passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user)
|
||||
passage_count, document_count = load_data(connector, source, passage_store, self.source_manager, actor=user, agent_id=agent_id)
|
||||
return passage_count, document_count
|
||||
|
||||
def attach_source_to_agent(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
# source_id: str,
|
||||
source_id: Optional[str] = None,
|
||||
source_name: Optional[str] = None,
|
||||
) -> Source:
|
||||
@@ -1507,15 +1571,14 @@ class SyncServer(Server):
|
||||
data_source = self.source_manager.get_source_by_name(source_name=source_name, actor=user)
|
||||
else:
|
||||
raise ValueError(f"Need to provide at least source_id or source_name to find the source.")
|
||||
# get connection to data source storage
|
||||
source_connector = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id)
|
||||
|
||||
assert data_source, f"Data source with id={source_id} or name={source_name} does not exist"
|
||||
|
||||
# load agent
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
|
||||
# attach source to agent
|
||||
agent.attach_source(data_source.id, source_connector, self.ms)
|
||||
agent.attach_source(user=user, source_id=data_source.id, source_manager=self.source_manager, ms=self.ms)
|
||||
|
||||
return data_source
|
||||
|
||||
@@ -1538,8 +1601,7 @@ class SyncServer(Server):
|
||||
|
||||
# delete all Passage objects with source_id==source_id from agent's archival memory
|
||||
agent = self.load_agent(agent_id=agent_id)
|
||||
archival_memory = agent.archival_memory
|
||||
archival_memory.storage.delete({"source_id": source_id})
|
||||
agent.passage_manager.delete_passages(actor=user, limit=100, source_id=source_id)
|
||||
|
||||
# delete agent-source mapping
|
||||
self.ms.detach_source(agent_id=agent_id, source_id=source_id)
|
||||
@@ -1553,11 +1615,11 @@ class SyncServer(Server):
|
||||
|
||||
return [self.source_manager.get_source_by_id(source_id=id) for id in source_ids]
|
||||
|
||||
def list_data_source_passages(self, user_id: str, source_id: str) -> List[Passage]:
|
||||
def list_data_source_passages(self, user_id: str, source_id: str) -> List[PydanticPassage]:
|
||||
warnings.warn("list_data_source_passages is not yet implemented, returning empty list.", category=UserWarning)
|
||||
return []
|
||||
|
||||
def list_all_sources(self, actor: User) -> List[Source]:
|
||||
def list_all_sources(self, actor: PydanticUser) -> List[Source]:
|
||||
"""List all sources (w/ extra metadata) belonging to a user"""
|
||||
|
||||
sources = self.source_manager.list_sources(actor=actor)
|
||||
@@ -1597,7 +1659,7 @@ class SyncServer(Server):
|
||||
|
||||
return sources_with_metadata
|
||||
|
||||
def add_default_external_tools(self, actor: User) -> bool:
|
||||
def add_default_external_tools(self, actor: PydanticUser) -> bool:
|
||||
"""Add default langchain tools. Return true if successful, false otherwise."""
|
||||
success = True
|
||||
tool_creates = ToolCreate.load_default_langchain_tools()
|
||||
@@ -1654,7 +1716,7 @@ class SyncServer(Server):
|
||||
save_agent(letta_agent, self.ms)
|
||||
return response
|
||||
|
||||
def get_user_or_default(self, user_id: Optional[str]) -> User:
|
||||
def get_user_or_default(self, user_id: Optional[str]) -> PydanticUser:
|
||||
"""Get the user object for user_id if it exists, otherwise return the default user object"""
|
||||
if user_id is None:
|
||||
user_id = self.user_manager.DEFAULT_USER_ID
|
||||
|
||||
Reference in New Issue
Block a user