fix: Refactor reset messages async to reduce idle-in-transaction time [LET-4529] (#4958)

* Reduce comments

* Move turbopuffer outside of reset messages
This commit is contained in:
Matthew Zhou
2025-09-26 13:03:03 -07:00
committed by Caren Thomas
parent b43193bf53
commit 171f9b0b19
2 changed files with 6 additions and 34 deletions

View File

@@ -1303,10 +1303,8 @@ class AgentManager:
PydanticAgentState: The updated agent state with only the original system message preserved.
"""
async with db_registry.async_session() as session:
# Retrieve the existing agent (will raise NoResultFound if invalid)
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
# Ensure agent has message_ids with at least one message
if not agent.message_ids or len(agent.message_ids) == 0:
logger.error(
f"Agent {agent_id} has no message_ids. Agent details: "
@@ -1315,13 +1313,12 @@ class AgentManager:
)
raise ValueError(f"Agent {agent_id} has no message_ids - cannot preserve system message")
# Get the system message ID (first message)
system_message_id = agent.message_ids[0]
# Delete all messages for the agent except the system message
await self.message_manager.delete_all_messages_for_agent_async(agent_id=agent_id, actor=actor, exclude_ids=[system_message_id])
await self.message_manager.delete_all_messages_for_agent_async(agent_id=agent_id, actor=actor, exclude_ids=[system_message_id])
# Update agent to only keep the system message
async with db_registry.async_session() as session:
agent = await AgentModel.read_async(db_session=session, identifier=agent_id, actor=actor)
agent.message_ids = [system_message_id]
await agent.update_async(db_session=session, actor=actor)
agent_state = await agent.to_pydantic_async(include_relationships=["sources"])

View File

@@ -377,28 +377,21 @@ class MessageManager:
created_messages = await MessageModel.batch_create_async(orm_messages, session, actor=actor, no_commit=True, no_refresh=True)
result = [msg.to_pydantic() for msg in created_messages]
await session.commit()
# session is now fully closed
# embed messages in turbopuffer if enabled (outside of DB session)
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and result:
# extract agent_id from the first message (all should have same agent_id)
agent_id = result[0].agent_id
if agent_id:
if strict_mode:
# wait for embedding to complete
await self._embed_messages_background(result, actor, agent_id, project_id, template_id)
else:
# fire and forget - run embedding in background
fire_and_forget(
self._embed_messages_background(result, actor, agent_id, project_id, template_id),
task_name=f"embed_messages_for_agent_{agent_id}",
)
# if allow_partial, combine newly created with existing
if allow_partial and existing_messages:
# fetch the existing messages to return complete data
async with db_registry.async_session() as session:
existing_ids = [msg.id for msg in existing_messages if msg.id]
query = select(MessageModel).where(MessageModel.id.in_(existing_ids), MessageModel.organization_id == actor.organization_id)
@@ -538,22 +531,16 @@ class MessageManager:
await message.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
pydantic_message = message.to_pydantic()
await session.commit()
# session is now fully closed
# update message in turbopuffer if enabled (delete and re-insert) - outside of DB session
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and pydantic_message.agent_id:
# extract text content from updated message
text = self._extract_message_text(pydantic_message)
# only update in turbopuffer if there's text content
if text:
if strict_mode:
# wait for embedding update to complete
await self._update_message_embedding_background(pydantic_message, text, actor, project_id, template_id)
else:
# fire and forget - run embedding update in background
fire_and_forget(
self._update_message_embedding_background(pydantic_message, text, actor, project_id, template_id),
task_name=f"update_message_embedding_{message_id}",
@@ -640,9 +627,7 @@ class MessageManager:
await msg.hard_delete_async(session, actor=actor)
except NoResultFound:
raise ValueError(f"Message with id {message_id} not found.")
# session is now fully closed
# delete from turbopuffer if enabled (outside of DB session)
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and agent_id:
@@ -653,7 +638,7 @@ class MessageManager:
except Exception as e:
logger.error(f"Failed to delete message from Turbopuffer: {e}")
if strict_mode:
raise # Re-raise the exception in strict mode
raise
return True
@@ -838,7 +823,6 @@ class MessageManager:
# 4) commit once
await session.commit()
# session is now fully closed
# 5) delete from turbopuffer if enabled (outside of DB session)
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
@@ -847,17 +831,13 @@ class MessageManager:
try:
tpuf_client = TurbopufferClient()
if exclude_ids:
# if we're excluding some IDs, we can't use delete_all
# would need to query all messages first then delete specific ones
# for now, log a warning
logger.warning(f"Turbopuffer deletion with exclude_ids not fully supported, using delete_all for agent {agent_id}")
# delete all messages for the agent from turbopuffer
await tpuf_client.delete_all_messages(agent_id, actor.organization_id)
logger.info(f"Successfully deleted all messages for agent {agent_id} from Turbopuffer")
except Exception as e:
logger.error(f"Failed to delete messages from Turbopuffer: {e}")
if strict_mode:
raise # Re-raise the exception in strict mode
raise
# 6) return the number of rows deleted
return rowcount
@@ -872,7 +852,6 @@ class MessageManager:
if not message_ids:
return 0
# get agent_ids BEFORE deleting (for turbopuffer)
agent_ids = []
rowcount = 0
@@ -896,22 +875,18 @@ class MessageManager:
# commit once
await session.commit()
# session is now fully closed
# delete from turbopuffer if enabled (outside of DB session)
if should_use_tpuf_for_messages() and agent_ids:
try:
tpuf_client = TurbopufferClient()
# delete from each affected agent's namespace
for agent_id in agent_ids:
await tpuf_client.delete_messages(agent_id=agent_id, organization_id=actor.organization_id, message_ids=message_ids)
logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer")
except Exception as e:
logger.error(f"Failed to delete messages from Turbopuffer: {e}")
if strict_mode:
raise # Re-raise the exception in strict mode
raise
# return the number of rows deleted
return rowcount
@enforce_types