fix: Remove in-memory _messages field on Agent (#2295)
This commit is contained in:
@@ -40,14 +40,14 @@ from letta.providers import (
|
||||
VLLMChatCompletionsProvider,
|
||||
VLLMCompletionsProvider,
|
||||
)
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent, UpdateAgent
|
||||
from letta.schemas.agent import AgentState, AgentType, CreateAgent
|
||||
from letta.schemas.block import BlockUpdate
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
|
||||
# openai schemas
|
||||
from letta.schemas.enums import JobStatus
|
||||
from letta.schemas.job import Job, JobUpdate
|
||||
from letta.schemas.letta_message import ToolReturnMessage, LettaMessage
|
||||
from letta.schemas.letta_message import LettaMessage, ToolReturnMessage
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.memory import (
|
||||
ArchivalMemorySummary,
|
||||
@@ -376,25 +376,6 @@ class SyncServer(Server):
|
||||
)
|
||||
)
|
||||
|
||||
def initialize_agent(self, agent_id, actor, interface: Union[AgentInterface, None] = None, initial_message_sequence=None) -> Agent:
|
||||
"""Initialize an agent from the database"""
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
|
||||
interface = interface or self.default_interface_factory()
|
||||
if agent_state.agent_type == AgentType.memgpt_agent:
|
||||
agent = Agent(agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence)
|
||||
elif agent_state.agent_type == AgentType.offline_memory_agent:
|
||||
agent = OfflineMemoryAgent(
|
||||
agent_state=agent_state, interface=interface, user=actor, initial_message_sequence=initial_message_sequence
|
||||
)
|
||||
else:
|
||||
assert initial_message_sequence is None, f"Initial message sequence is not supported for O1Agents"
|
||||
agent = O1Agent(agent_state=agent_state, interface=interface, user=actor)
|
||||
|
||||
# Persist to agent
|
||||
save_agent(agent)
|
||||
return agent
|
||||
|
||||
def load_agent(self, agent_id: str, actor: User, interface: Union[AgentInterface, None] = None) -> Agent:
|
||||
"""Updated method to load agents from persisted storage"""
|
||||
agent_lock = self.per_agent_lock_manager.get_lock(agent_id)
|
||||
@@ -413,11 +394,6 @@ class SyncServer(Server):
|
||||
else:
|
||||
raise ValueError(f"Invalid agent type {agent_state.agent_type}")
|
||||
|
||||
# Rebuild the system prompt - may be linked to new blocks now
|
||||
agent.rebuild_system_prompt()
|
||||
|
||||
# Persist to agent
|
||||
save_agent(agent)
|
||||
return agent
|
||||
|
||||
def _step(
|
||||
@@ -456,7 +432,7 @@ class SyncServer(Server):
|
||||
)
|
||||
|
||||
# save agent after step
|
||||
save_agent(letta_agent)
|
||||
# save_agent(letta_agent)
|
||||
|
||||
except Exception as e:
|
||||
logger.error(f"Error in server._step: {e}")
|
||||
@@ -790,129 +766,23 @@ class SyncServer(Server):
|
||||
|
||||
"""Create a new agent using a config"""
|
||||
# Invoke manager
|
||||
agent_state = self.agent_manager.create_agent(
|
||||
return self.agent_manager.create_agent(
|
||||
agent_create=request,
|
||||
actor=actor,
|
||||
)
|
||||
|
||||
# create the agent object
|
||||
if request.initial_message_sequence is not None:
|
||||
# init_messages = [Message(user_id=user_id, agent_id=agent_state.id, role=message.role, text=message.text) for message in request.initial_message_sequence]
|
||||
init_messages = []
|
||||
for message in request.initial_message_sequence:
|
||||
|
||||
if message.role == MessageRole.user:
|
||||
packed_message = system.package_user_message(
|
||||
user_message=message.text,
|
||||
)
|
||||
elif message.role == MessageRole.system:
|
||||
packed_message = system.package_system_message(
|
||||
system_message=message.text,
|
||||
)
|
||||
else:
|
||||
raise ValueError(f"Invalid message role: {message.role}")
|
||||
|
||||
init_messages.append(Message(role=message.role, text=packed_message, agent_id=agent_state.id))
|
||||
# init_messages = [Message.dict_to_message(user_id=user_id, agent_id=agent_state.id, openai_message_dict=message.model_dump()) for message in request.initial_message_sequence]
|
||||
else:
|
||||
init_messages = None
|
||||
|
||||
# initialize the agent (generates initial message list with system prompt)
|
||||
if interface is None:
|
||||
interface = self.default_interface_factory()
|
||||
self.initialize_agent(agent_id=agent_state.id, interface=interface, initial_message_sequence=init_messages, actor=actor)
|
||||
|
||||
in_memory_agent_state = self.agent_manager.get_agent_by_id(agent_state.id, actor=actor)
|
||||
return in_memory_agent_state
|
||||
|
||||
# TODO: This is not good!
|
||||
# TODO: Ideally, this should ALL be handled by the ORM
|
||||
# TODO: The main blocker here IS the _message updates
|
||||
def update_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
request: UpdateAgent,
|
||||
actor: User,
|
||||
) -> AgentState:
|
||||
"""Update the agents core memory block, return the new state"""
|
||||
# Update agent state in the db first
|
||||
agent_state = self.agent_manager.update_agent(agent_id=agent_id, agent_update=request, actor=actor)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# TODO: Everything below needs to get removed, no updating anything in memory
|
||||
# update the system prompt
|
||||
if request.system:
|
||||
letta_agent.update_system_prompt(request.system)
|
||||
|
||||
# update in-context messages
|
||||
if request.message_ids:
|
||||
# This means the user is trying to change what messages are in the message buffer
|
||||
# Internally this requires (1) pulling from recall,
|
||||
# then (2) setting the attributes ._messages and .state.message_ids
|
||||
letta_agent.set_message_buffer(message_ids=request.message_ids)
|
||||
|
||||
letta_agent.update_state()
|
||||
|
||||
return agent_state
|
||||
|
||||
def get_tools_from_agent(self, agent_id: str, user_id: Optional[str]) -> List[Tool]:
|
||||
"""Get tools from an existing agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.agent_state.tools
|
||||
|
||||
def add_tool_to_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""Add tools from an existing agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
agent_state = self.agent_manager.attach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
return agent_state
|
||||
|
||||
def remove_tool_from_agent(
|
||||
self,
|
||||
agent_id: str,
|
||||
tool_id: str,
|
||||
user_id: str,
|
||||
):
|
||||
"""Remove tools from an existing agent"""
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
agent_state = self.agent_manager.detach_tool(agent_id=agent_id, tool_id=tool_id, actor=actor)
|
||||
|
||||
return agent_state
|
||||
|
||||
# convert name->id
|
||||
|
||||
# TODO: These can be moved to agent_manager
|
||||
def get_agent_memory(self, agent_id: str, actor: User) -> Memory:
|
||||
"""Return the memory of an agent (core memory)"""
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return agent.agent_state.memory
|
||||
return self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor).memory
|
||||
|
||||
def get_archival_memory_summary(self, agent_id: str, actor: User) -> ArchivalMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return ArchivalMemorySummary(size=self.agent_manager.passage_size(actor=actor, agent_id=agent_id))
|
||||
|
||||
def get_recall_memory_summary(self, agent_id: str, actor: User) -> RecallMemorySummary:
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return RecallMemorySummary(size=len(agent.message_manager))
|
||||
|
||||
def get_in_context_messages(self, agent_id: str, actor: User) -> List[Message]:
|
||||
"""Get the in-context messages in the agent's memory"""
|
||||
# Get the agent object (loaded in memory)
|
||||
agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return agent._messages
|
||||
return RecallMemorySummary(size=self.message_manager.size(actor=actor, agent_id=agent_id))
|
||||
|
||||
def get_agent_archival(self, user_id: str, agent_id: str, cursor: Optional[str] = None, limit: int = 50) -> List[Passage]:
|
||||
"""Paginated query of all messages in agent archival memory"""
|
||||
@@ -947,24 +817,17 @@ class SyncServer(Server):
|
||||
|
||||
def insert_archival_memory(self, agent_id: str, memory_contents: str, actor: User) -> List[Passage]:
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
agent_state = self.agent_manager.get_agent_by_id(agent_id=agent_id, actor=actor)
|
||||
# Insert into archival memory
|
||||
passages = self.passage_manager.insert_passage(
|
||||
agent_state=letta_agent.agent_state, agent_id=agent_id, text=memory_contents, actor=actor
|
||||
)
|
||||
|
||||
save_agent(letta_agent)
|
||||
# TODO: @mindy look at moving this to agent_manager to avoid above extra call
|
||||
passages = self.passage_manager.insert_passage(agent_state=agent_state, agent_id=agent_id, text=memory_contents, actor=actor)
|
||||
|
||||
return passages
|
||||
|
||||
def delete_archival_memory(self, agent_id: str, memory_id: str, actor: User):
|
||||
# Get the agent object (loaded in memory)
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
|
||||
# Delete by ID
|
||||
def delete_archival_memory(self, memory_id: str, actor: User):
|
||||
# TODO check if it exists first, and throw error if not
|
||||
letta_agent.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
# TODO: @mindy make this return the deleted passage instead
|
||||
self.passage_manager.delete_passage_by_id(passage_id=memory_id, actor=actor)
|
||||
|
||||
# TODO: return archival memory
|
||||
|
||||
@@ -1042,9 +905,8 @@ class SyncServer(Server):
|
||||
# update the block
|
||||
self.block_manager.update_block(block_id=block.id, block_update=BlockUpdate(value=value), actor=actor)
|
||||
|
||||
# load agent
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.agent_state.memory
|
||||
# rebuild system prompt for agent, potentially changed
|
||||
return self.agent_manager.rebuild_system_prompt(agent_id=agent_id, actor=actor).memory
|
||||
|
||||
def delete_source(self, source_id: str, actor: User):
|
||||
"""Delete a data source"""
|
||||
@@ -1214,36 +1076,11 @@ class SyncServer(Server):
|
||||
|
||||
return success
|
||||
|
||||
def update_agent_message(self, agent_id: str, message_id: str, request: MessageUpdate, actor: User) -> Message:
|
||||
def update_agent_message(self, message_id: str, request: MessageUpdate, actor: User) -> Message:
|
||||
"""Update the details of a message associated with an agent"""
|
||||
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.update_message(message_id=message_id, request=request)
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
|
||||
def rewrite_agent_message(self, agent_id: str, new_text: str, actor: User) -> Message:
|
||||
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.rewrite_message(new_text=new_text)
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
|
||||
def rethink_agent_message(self, agent_id: str, new_thought: str, actor: User) -> Message:
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.rethink_message(new_thought=new_thought)
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
|
||||
def retry_agent_message(self, agent_id: str, actor: User) -> List[Message]:
|
||||
# Get the current message
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
response = letta_agent.retry_message()
|
||||
save_agent(letta_agent)
|
||||
return response
|
||||
return self.message_manager.update_message_by_id(message_id=message_id, message_update=request, actor=actor)
|
||||
|
||||
def get_organization_or_default(self, org_id: Optional[str]) -> Organization:
|
||||
"""Get the organization object for org_id if it exists, otherwise return the default organization object"""
|
||||
@@ -1331,15 +1168,7 @@ class SyncServer(Server):
|
||||
def add_embedding_model(self, request: EmbeddingConfig) -> EmbeddingConfig:
|
||||
"""Add a new embedding model"""
|
||||
|
||||
def get_agent_context_window(
|
||||
self,
|
||||
user_id: str,
|
||||
agent_id: str,
|
||||
) -> ContextWindowOverview:
|
||||
# TODO: Thread actor directly through this function, since the top level caller most likely already retrieved the user
|
||||
actor = self.user_manager.get_user_or_default(user_id=user_id)
|
||||
|
||||
# Get the current message
|
||||
def get_agent_context_window(self, agent_id: str, actor: User) -> ContextWindowOverview:
|
||||
letta_agent = self.load_agent(agent_id=agent_id, actor=actor)
|
||||
return letta_agent.get_context_window()
|
||||
|
||||
|
||||
Reference in New Issue
Block a user