fix: Remove in-memory _messages field on Agent (#2295)

This commit is contained in:
Matthew Zhou
2024-12-20 15:52:04 -08:00
committed by GitHub
parent e9239cf1bf
commit 5bb4888cea
18 changed files with 650 additions and 1164 deletions

View File

@@ -40,14 +40,14 @@ from letta.providers import (
VLLMChatCompletionsProvider,
VLLMCompletionsProvider,
)
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
from letta.schemas.agent import AgentState, AgentType, CreateAgent
from letta.schemas.block import BlockUpdate
from letta.schemas.embedding_config import EmbeddingConfig
# openai schemas
from letta.schemas.enums import JobStatus
from letta.schemas.job import Job, JobUpdate
from letta.schemas.letta_message import ToolReturnMessage, LettaMessage
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
from letta.schemas.llm_config import LLMConfig
from letta.schemas.memory import (
ArchivalMemorySummary,
@@ -376,25 +376,6 @@ class SyncServer(Server):
)
)
def initialize_agent(self, agent_id, actor, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent:
    """Build an agent object from its stored state and persist it.

    The agent state is fetched through ``self.agent_manager.get_agent_by_id``.
    The concrete agent class is chosen from ``agent_state.agent_type``.

    Args:
        agent_id: ID of the agent whose state is looked up.
        actor: User performing the lookup; also passed to the agent as ``user``.
        interface: Interface handed to the agent. A falsy value is replaced
            by the result of ``self.default_interface_factory()``.
        initial_message_sequence: Optional starting messages. Forwarded only
            for the ``memgpt_agent`` and ``offline_memory_agent`` types.

    Returns:
        The constructed agent, after ``save_agent`` has been called on it.

    Raises:
        AssertionError: if ``initial_message_sequence`` is not ``None`` for
            an agent type that falls through to ``O1Agent``.
    """
    state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
    if not interface:
        interface = self.default_interface_factory()

    # Resolve the class that accepts an initial message sequence, if any.
    kind = state.agent_type
    if kind == AgentType.memgpt_agent:
        agent_cls = Agent
    elif kind == AgentType.offline_memory_agent:
        agent_cls = OfflineMemoryAgent
    else:
        agent_cls = None

    if agent_cls is None:
        # Every remaining agent type is built as an O1Agent, which takes no
        # initial message sequence.
        assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents"
        agent = O1Agent(agent_state=state, interface=interface, user=actor)
    else:
        agent = agent_cls(
            agent_state=state,
            interface=interface,
            user=actor,
            initial_message_sequence=initial_message_sequence,
        )

    # Persist to agent
    save_agent(agent)
    return agent
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
"""Updated method to load agents from persisted storage"""
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
@@ -413,11 +394,6 @@ class SyncServer(Server):
else:
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
# Rebuild the system prompt - may be linked to new blocks now
agent.rebuild_system_prompt()
# Persist to agent
save_agent(agent)
return agent
def _step(
@@ -456,7 +432,7 @@ class SyncServer(Server):
)
# save agent after step
save_agent(letta_agent)
# save_agent(letta_agent)
except Exception as e:
logger.error(f"Error in server._step: {e}")
@@ -790,129 +766,23 @@ class SyncServer(Server):
"""Create a new agent using a config"""
# Invoke manager
agent_state = self.agent_manager.create_agent(
return self.agent_manager.create_agent(
agent_create=request,
actor=actor,
)
# create the agent object
if request.initial_message_sequence is not None:
# init_messages = [Message(user_id=user_id, agent_id=agent_state.id, role=message.role, text=message.text) for message in request.initial_message_sequence]
init_messages = []
for message in request.initial_message_sequence:
if message.role == MessageRole.user:
packed_message = system.package_user_message(
user_message=message.text,
)
elif message.role == MessageRole.system:
packed_message = system.package_system_message(
system_message=message.text,
)
else:
raise ValueError(f"Invalid message role: {message.role}")
init_messages.append(Message(role=message.role, text=packed_message, agent_id=agent_state.id))
# init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence]
else:
init_messages = None
# initialize the agent (generates initial message list with system prompt)
if interface is None:
interface = self.default_interface_factory()
self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages, actor=actor)
in_memory_agent_state = self.agent_manager.get_agent_by_id(agent_state.id, actor=actor)
return in_memory_agent_state
# TODO: This is not good!
# TODO: Ideally, this should ALL be handled by the ORM
# TODO: The main blocker here IS the _message updates
def update_agent(
self,
agent_id: str,
request: UpdateAgent,
actor: User,
) -> AgentState:
"""Update the agents core memory block, return the new state"""
# Flow: (1) persist the update through agent_manager, (2) load the agent
# object, (3) mirror the system prompt / message buffer changes onto that
# object. The value returned is the state produced by step (1), not the
# state of the object loaded in step (2).
# Update agent state in the db first
agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# TODO: Everything below needs to get removed, no updating anything in memory
# update the system prompt
# Skipped when request.system is falsy (None or empty).
if request.system:
letta_agent.update_system_prompt(request.system)
# update in-context messages
# Skipped when request.message_ids is falsy (None or empty).
if request.message_ids:
# This means the user is trying to change what messages are in the message buffer
# Internally this requires (1) pulling from recall,
# then (2) setting the attributes ._messages and .state.message_ids
letta_agent.set_message_buffer(message_ids=request.message_ids)
# NOTE(review): indentation is not preserved in this view, so it cannot be
# told whether update_state() runs unconditionally or only when message_ids
# is set; confirm against the real file before changing this function.
letta_agent.update_state()
return agent_state
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
    """Return the tools found on an agent's loaded state.

    Args:
        agent_id: ID of the agent to load.
        user_id: ID resolved to an actor via
            ``self.user_manager.get_user_or_default``.

    Returns:
        The ``tools`` attribute of the loaded agent's ``agent_state``.
    """
    # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
    resolved_actor = self.user_manager.get_user_or_default(user_id=user_id)
    # Goes through the fully loaded agent object rather than the manager.
    return self.load_agent(agent_id=agent_id, actor=resolved_actor).agent_state.tools
def add_tool_to_agent(
    self,
    agent_id: str,
    tool_id: str,
    user_id: str,
):
    """Attach a tool to an existing agent.

    Args:
        agent_id: ID of the agent receiving the tool.
        tool_id: ID of the tool to attach.
        user_id: ID resolved to an actor via
            ``self.user_manager.get_user_or_default``.

    Returns:
        Whatever ``self.agent_manager.attach_tool`` returns.
    """
    # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
    resolved_actor = self.user_manager.get_user_or_default(user_id=user_id)
    return self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=resolved_actor)
def remove_tool_from_agent(
    self,
    agent_id: str,
    tool_id: str,
    user_id: str,
):
    """Detach a tool from an existing agent.

    Args:
        agent_id: ID of the agent losing the tool.
        tool_id: ID of the tool to detach.
        user_id: ID resolved to an actor via
            ``self.user_manager.get_user_or_default``.

    Returns:
        Whatever ``self.agent_manager.detach_tool`` returns.
    """
    # TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
    resolved_actor = self.user_manager.get_user_or_default(user_id=user_id)
    return self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=resolved_actor)
# convert name->id
# TODO: These can be moved to agent_manager
def get_agent_memory(self, agent_id: str, actor: User) -> Memory:
"""Return the memory of an agent (core memory)"""
# First path: load the full agent object and read memory off its agent_state.
agent = self.load_agent(agent_id=agent_id, actor=actor)
return agent.agent_state.memory
# NOTE(review): this return directly follows another return, so it is
# unreachable as written. It reads memory from agent_manager without loading
# the agent object. Presumably the two returns are the before/after versions
# of one change and only one should remain; confirm against the real file.
return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
    """Summarise an agent's archival memory.

    Returns:
        An ``ArchivalMemorySummary`` whose ``size`` is the value reported by
        ``self.agent_manager.passage_size`` for this agent and actor.
    """
    # The loaded agent is not used below; the call is kept so that whatever
    # load_agent does on load still happens.
    self.load_agent(agent_id=agent_id, actor=actor)
    passage_total = self.agent_manager.passage_size(actor=actor, agent_id=agent_id)
    return ArchivalMemorySummary(size=passage_total)
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
    """Summarise an agent's recall memory.

    Returns:
        A ``RecallMemorySummary`` whose ``size`` is ``len()`` of the loaded
        agent's ``message_manager``.
    """
    loaded = self.load_agent(agent_id=agent_id, actor=actor)
    message_total = len(loaded.message_manager)
    return RecallMemorySummary(size=message_total)
def get_in_context_messages(self, agent_id: str, actor: User) -> List[Message]:
"""Get the in-context messages in the agent's memory"""
# Get the agent object (loaded in memory)
agent = self.load_agent(agent_id=agent_id, actor=actor)
# Reads the agent's private _messages attribute directly.
return agent._messages
# NOTE(review): unreachable as written (it follows a return), and it builds a
# RecallMemorySummary, which does not match this function's List[Message]
# annotation. Presumably this line belongs to get_recall_memory_summary;
# confirm against the real file.
return RecallMemorySummary(size=self.message_manager.size(actor=actor, agent_id=agent_id))
def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[Passage]:
"""Paginated query of all messages in agent archival memory"""
@@ -947,24 +817,17 @@ class SyncServer(Server):
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
# Stores memory_contents as archival passages for the agent and returns
# whatever passage_manager.insert_passage returns.
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# Same agent's state, fetched from agent_manager without loading the agent object.
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
# Insert into archival memory
passages = self.passage_manager.insert_passage(
agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor
)
save_agent(letta_agent)
# TODO: @mindy look at moving this to agent_manager to avoid above extra call
# NOTE(review): as written, insert_passage is called a second time with the
# same text, and its result overwrites `passages` from the first call.
# Presumably the two calls are the before/after versions of one change and
# only one should remain; confirm against the real file.
passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor)
return passages
def delete_archival_memory(self, agent_id: str, memory_id: str, actor: User):
# Get the agent object (loaded in memory)
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
# Delete by ID
# NOTE(review): a second signature for the same name follows, this one without
# the agent_id parameter. Presumably the two `def` lines are the before/after
# versions of one change; confirm which signature callers should use.
def delete_archival_memory(self, memory_id: str, actor: User):
# TODO check if it exists first, and throw error if not
# Deletes through the passage_manager of the loaded agent object.
letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
# TODO: @mindy make this return the deleted passage instead
# Deletes through the server's own passage_manager, with the same passage id
# as the call above. No value is returned.
self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
# TODO: return archival memory
@@ -1042,9 +905,8 @@ class SyncServer(Server):
# update the block
self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor)
# load agent
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return letta_agent.agent_state.memory
# rebuild system prompt for agent, potentially changed
return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory
def delete_source(self, source_id: str, actor: User):
"""Delete a data source"""
@@ -1214,36 +1076,11 @@ class SyncServer(Server):
return success
# NOTE(review): two signatures for the same name appear back to back; the
# second drops the agent_id parameter. Presumably they are the before/after
# versions of one change; confirm against the real file.
def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate, actor: User) -> Message:
def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
"""Update the details of a message associated with an agent"""
# Get the current message
# This body uses agent_id, which only the first signature declares.
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
response = letta_agent.update_message(message_id=message_id, request=request)
# Persist the agent after the message has been updated on it.
save_agent(letta_agent)
return response
def rewrite_agent_message(self, agent_id: str, new_text: str, actor: User) -> Message:
    """Rewrite a message on an agent and persist the agent.

    Loads the agent, delegates to its ``rewrite_message`` with ``new_text``,
    then calls ``save_agent`` on it.

    Returns:
        Whatever the agent's ``rewrite_message`` returns.
    """
    target_agent = self.load_agent(agent_id=agent_id, actor=actor)
    rewritten = target_agent.rewrite_message(new_text=new_text)
    # Persisting happens only if the rewrite above did not raise.
    save_agent(target_agent)
    return rewritten
def rethink_agent_message(self, agent_id: str, new_thought: str, actor: User) -> Message:
    """Replace the thought on an agent message and persist the agent.

    Loads the agent, delegates to its ``rethink_message`` with
    ``new_thought``, then calls ``save_agent`` on it.

    Returns:
        Whatever the agent's ``rethink_message`` returns.
    """
    target_agent = self.load_agent(agent_id=agent_id, actor=actor)
    rethought = target_agent.rethink_message(new_thought=new_thought)
    # Persisting happens only if the rethink above did not raise.
    save_agent(target_agent)
    return rethought
def retry_agent_message(self, agent_id: str, actor: User) -> List[Message]:
# Loads the agent, delegates to its retry_message(), persists the agent and
# returns what retry_message() produced.
# Get the current message
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
response = letta_agent.retry_message()
save_agent(letta_agent)
return response
# NOTE(review): unreachable as written (it follows a return), and it uses
# message_id and request, neither of which is a parameter of this function.
# Presumably this line belongs to update_agent_message; confirm against the
# real file.
return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
"""Get the organization object for org_id if it exists, otherwise return the default organization object"""
@@ -1331,15 +1168,7 @@ class SyncServer(Server):
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
"""Add a new embedding model"""
# NOTE(review): two signatures for the same name appear here. The first takes
# user_id and resolves the actor itself; the second takes the actor directly.
# Presumably they are the before/after versions of one change; confirm which
# one callers should use.
def get_agent_context_window(
self,
user_id: str,
agent_id: str,
) -> ContextWindowOverview:
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
actor = self.user_manager.get_user_or_default(user_id=user_id)
# Get the current message
def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview:
# Loads the agent object and returns the result of its get_context_window().
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
return letta_agent.get_context_window()