feat: Extend crud lifecycle of messages [LET-4158] (#4364)
Extend crud lifecycle of messages
This commit is contained in:
@@ -775,6 +775,27 @@ class TurbopufferClient:
|
||||
logger.error(f"Failed to delete all passages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def delete_messages(self, agent_id: str, message_ids: List[str]) -> bool:
|
||||
"""Delete multiple messages from Turbopuffer."""
|
||||
from turbopuffer import AsyncTurbopuffer
|
||||
|
||||
if not message_ids:
|
||||
return True
|
||||
|
||||
namespace_name = await self._get_message_namespace_name(agent_id)
|
||||
|
||||
try:
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
namespace = client.namespace(namespace_name)
|
||||
# Use write API with deletes parameter as per Turbopuffer docs
|
||||
await namespace.write(deletes=message_ids)
|
||||
logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer for agent {agent_id}")
|
||||
return True
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete messages from Turbopuffer: {e}")
|
||||
raise
|
||||
|
||||
@trace_method
|
||||
async def delete_all_messages(self, agent_id: str) -> bool:
|
||||
"""Delete all messages for an agent from Turbopuffer."""
|
||||
|
||||
@@ -10,6 +10,7 @@ from letta.orm.agent import Agent as AgentModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.message import Message as MessageModel
|
||||
from letta.otel.tracing import trace_method
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import MessageRole
|
||||
from letta.schemas.letta_message import LettaMessageUpdateUnion
|
||||
from letta.schemas.letta_message_content import ImageSourceType, LettaImage, MessageContentType, TextContent
|
||||
@@ -161,7 +162,8 @@ class MessageManager:
|
||||
self,
|
||||
pydantic_msgs: List[PydanticMessage],
|
||||
actor: PydanticUser,
|
||||
embedding_config: Optional[dict] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
strict_mode: bool = False,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Create multiple messages in a single database transaction asynchronously.
|
||||
@@ -253,8 +255,9 @@ class MessageManager:
|
||||
)
|
||||
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
|
||||
except Exception as e:
|
||||
# log error but don't fail the message creation
|
||||
logger.error(f"Failed to embed messages in Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
return result
|
||||
|
||||
@@ -356,7 +359,14 @@ class MessageManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def update_message_by_id_async(self, message_id: str, message_update: MessageUpdate, actor: PydanticUser) -> PydanticMessage:
|
||||
async def update_message_by_id_async(
|
||||
self,
|
||||
message_id: str,
|
||||
message_update: MessageUpdate,
|
||||
actor: PydanticUser,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
strict_mode: bool = False,
|
||||
) -> PydanticMessage:
|
||||
"""
|
||||
Updates an existing record in the database with values from the provided record object.
|
||||
Async version of the function above.
|
||||
@@ -373,6 +383,47 @@ class MessageManager:
|
||||
await message.update_async(db_session=session, actor=actor, no_commit=True, no_refresh=True)
|
||||
pydantic_message = message.to_pydantic()
|
||||
await session.commit()
|
||||
|
||||
# update message in turbopuffer if enabled (delete and re-insert)
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id:
|
||||
try:
|
||||
# extract text content from updated message
|
||||
text = self._extract_message_text(pydantic_message)
|
||||
|
||||
# only update in turbopuffer if there's text content (role filtering is handled in _extract_message_text)
|
||||
if text:
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# delete old message from turbopuffer
|
||||
await tpuf_client.delete_messages(agent_id=pydantic_message.agent_id, message_ids=[message_id])
|
||||
|
||||
# generate new embedding
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([text], embedding_config)
|
||||
|
||||
# re-insert with updated content
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=pydantic_message.agent_id,
|
||||
message_texts=[text],
|
||||
embeddings=embeddings,
|
||||
message_ids=[message_id],
|
||||
organization_id=actor.organization_id,
|
||||
roles=[pydantic_message.role],
|
||||
created_ats=[pydantic_message.created_at],
|
||||
)
|
||||
logger.info(f"Successfully updated message {message_id} in Turbopuffer")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update message in Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
return pydantic_message
|
||||
|
||||
def _update_message_by_id_impl(
|
||||
@@ -412,6 +463,39 @@ class MessageManager:
|
||||
actor=actor,
|
||||
)
|
||||
msg.hard_delete(session, actor=actor)
|
||||
# Note: Turbopuffer deletion requires async, use delete_message_by_id_async for full deletion
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Message with id {message_id} not found.")
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_message_by_id_async(self, message_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool:
|
||||
"""Delete a message (async version with turbopuffer support)."""
|
||||
async with db_registry.async_session() as session:
|
||||
try:
|
||||
msg = await MessageModel.read_async(
|
||||
db_session=session,
|
||||
identifier=message_id,
|
||||
actor=actor,
|
||||
)
|
||||
agent_id = msg.agent_id
|
||||
await msg.hard_delete_async(session, actor=actor)
|
||||
|
||||
# delete from turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages() and agent_id:
|
||||
try:
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_messages(agent_id=agent_id, message_ids=[message_id])
|
||||
logger.info(f"Successfully deleted message {message_id} from Turbopuffer")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete message from Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
return True
|
||||
|
||||
except NoResultFound:
|
||||
raise ValueError(f"Message with id {message_id} not found.")
|
||||
|
||||
@@ -712,7 +796,9 @@ class MessageManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_all_messages_for_agent_async(self, agent_id: str, actor: PydanticUser, exclude_ids: Optional[List[str]] = None) -> int:
|
||||
async def delete_all_messages_for_agent_async(
|
||||
self, agent_id: str, actor: PydanticUser, exclude_ids: Optional[List[str]] = None, strict_mode: bool = False
|
||||
) -> int:
|
||||
"""
|
||||
Efficiently deletes all messages associated with a given agent_id,
|
||||
while enforcing permission checks and avoiding any ORM‑level loads.
|
||||
@@ -736,12 +822,31 @@ class MessageManager:
|
||||
# 4) commit once
|
||||
await session.commit()
|
||||
|
||||
# 5) return the number of rows deleted
|
||||
# 5) delete from turbopuffer if enabled
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages():
|
||||
try:
|
||||
tpuf_client = TurbopufferClient()
|
||||
if exclude_ids:
|
||||
# if we're excluding some IDs, we can't use delete_all
|
||||
# would need to query all messages first then delete specific ones
|
||||
# for now, log a warning
|
||||
logger.warning(f"Turbopuffer deletion with exclude_ids not fully supported, using delete_all for agent {agent_id}")
|
||||
# delete all messages for the agent from turbopuffer
|
||||
await tpuf_client.delete_all_messages(agent_id)
|
||||
logger.info(f"Successfully deleted all messages for agent {agent_id} from Turbopuffer")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete messages from Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
# 6) return the number of rows deleted
|
||||
return result.rowcount
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser) -> int:
|
||||
async def delete_messages_by_ids_async(self, message_ids: List[str], actor: PydanticUser, strict_mode: bool = False) -> int:
|
||||
"""
|
||||
Efficiently deletes messages by their specific IDs,
|
||||
while enforcing permission checks.
|
||||
@@ -750,6 +855,20 @@ class MessageManager:
|
||||
return 0
|
||||
|
||||
async with db_registry.async_session() as session:
|
||||
# get agent_ids BEFORE deleting (for turbopuffer)
|
||||
agent_ids = []
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages():
|
||||
agent_query = (
|
||||
select(MessageModel.agent_id)
|
||||
.where(MessageModel.id.in_(message_ids))
|
||||
.where(MessageModel.organization_id == actor.organization_id)
|
||||
.distinct()
|
||||
)
|
||||
agent_result = await session.execute(agent_query)
|
||||
agent_ids = [row[0] for row in agent_result.fetchall() if row[0]]
|
||||
|
||||
# issue a CORE DELETE against the mapped class for specific message IDs
|
||||
stmt = delete(MessageModel).where(MessageModel.id.in_(message_ids)).where(MessageModel.organization_id == actor.organization_id)
|
||||
result = await session.execute(stmt)
|
||||
@@ -757,6 +876,19 @@ class MessageManager:
|
||||
# commit once
|
||||
await session.commit()
|
||||
|
||||
# delete from turbopuffer if enabled
|
||||
if should_use_tpuf_for_messages() and agent_ids:
|
||||
try:
|
||||
tpuf_client = TurbopufferClient()
|
||||
# delete from each affected agent's namespace
|
||||
for agent_id in agent_ids:
|
||||
await tpuf_client.delete_messages(agent_id=agent_id, message_ids=message_ids)
|
||||
logger.info(f"Successfully deleted {len(message_ids)} messages from Turbopuffer")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete messages from Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
# return the number of rows deleted
|
||||
return result.rowcount
|
||||
|
||||
@@ -773,7 +905,7 @@ class MessageManager:
|
||||
limit: int = 50,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
embedding_config: Optional[dict] = None,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Search messages using Turbopuffer if enabled, otherwise fall back to SQL search.
|
||||
|
||||
@@ -11,6 +11,7 @@ from letta.constants import MAX_EMBEDDING_DIM
|
||||
from letta.embeddings import parse_and_chunk_text
|
||||
from letta.helpers.decorators import async_redis_cache
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
from letta.log import get_logger
|
||||
from letta.orm import ArchivesAgents
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.orm.passage import ArchivalPassage, SourcePassage
|
||||
@@ -25,6 +26,8 @@ from letta.server.db import db_registry
|
||||
from letta.services.archive_manager import ArchiveManager
|
||||
from letta.utils import enforce_types
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
|
||||
# TODO: Add redis-backed caching for backend
|
||||
@lru_cache(maxsize=8192)
|
||||
@@ -552,6 +555,7 @@ class PassageManager:
|
||||
actor: PydanticUser,
|
||||
tags: Optional[List[str]] = None,
|
||||
created_at: Optional[datetime] = None,
|
||||
strict_mode: bool = False,
|
||||
) -> List[PydanticPassage]:
|
||||
"""Insert passage(s) into archival memory
|
||||
|
||||
@@ -609,24 +613,29 @@ class PassageManager:
|
||||
|
||||
# If archive uses Turbopuffer, also write to Turbopuffer (dual-write)
|
||||
if archive.vector_db_provider == VectorDBProvider.TPUF:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# Extract IDs and texts from the created passages
|
||||
passage_ids = [p.id for p in passages]
|
||||
passage_texts = [p.text for p in passages]
|
||||
# Extract IDs and texts from the created passages
|
||||
passage_ids = [p.id for p in passages]
|
||||
passage_texts = [p.text for p in passages]
|
||||
|
||||
# Insert to Turbopuffer with the same IDs as SQL
|
||||
await tpuf_client.insert_archival_memories(
|
||||
archive_id=archive.id,
|
||||
text_chunks=passage_texts,
|
||||
embeddings=embeddings,
|
||||
passage_ids=passage_ids, # Use same IDs as SQL
|
||||
organization_id=actor.organization_id,
|
||||
tags=tags,
|
||||
created_at=passages[0].created_at if passages else None,
|
||||
)
|
||||
# Insert to Turbopuffer with the same IDs as SQL
|
||||
await tpuf_client.insert_archival_memories(
|
||||
archive_id=archive.id,
|
||||
text_chunks=passage_texts,
|
||||
embeddings=embeddings,
|
||||
passage_ids=passage_ids, # Use same IDs as SQL
|
||||
organization_id=actor.organization_id,
|
||||
tags=tags,
|
||||
created_at=passages[0].created_at if passages else None,
|
||||
)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to insert passages to Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
return passages
|
||||
|
||||
@@ -801,7 +810,7 @@ class PassageManager:
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
async def delete_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser) -> bool:
|
||||
async def delete_agent_passage_by_id_async(self, passage_id: str, actor: PydanticUser, strict_mode: bool = False) -> bool:
|
||||
"""Delete an agent passage."""
|
||||
if not passage_id:
|
||||
raise ValueError("Passage ID must be provided.")
|
||||
@@ -818,10 +827,15 @@ class PassageManager:
|
||||
if archive_id:
|
||||
archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
|
||||
if archive.vector_db_provider == VectorDBProvider.TPUF:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_passage(archive_id=archive_id, passage_id=passage_id)
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_passage(archive_id=archive_id, passage_id=passage_id)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete passage from Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
return True
|
||||
except NoResultFound:
|
||||
@@ -981,6 +995,7 @@ class PassageManager:
|
||||
self,
|
||||
passages: List[PydanticPassage],
|
||||
actor: PydanticUser,
|
||||
strict_mode: bool = False,
|
||||
) -> bool:
|
||||
"""Delete multiple agent passages."""
|
||||
if not passages:
|
||||
@@ -1002,10 +1017,15 @@ class PassageManager:
|
||||
for archive_id, passage_ids in passages_by_archive.items():
|
||||
archive = await self.archive_manager.get_archive_by_id_async(archive_id=archive_id, actor=actor)
|
||||
if archive.vector_db_provider == VectorDBProvider.TPUF:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_passages(archive_id=archive_id, passage_ids=passage_ids)
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.delete_passages(archive_id=archive_id, passage_ids=passage_ids)
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to delete passages from Turbopuffer: {e}")
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
|
||||
return True
|
||||
|
||||
|
||||
@@ -178,7 +178,9 @@ class TestTurbopufferIntegration:
|
||||
]
|
||||
|
||||
for text in test_passages:
|
||||
passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text, actor=default_user)
|
||||
passages = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent, text=text, actor=default_user, strict_mode=True
|
||||
)
|
||||
assert passages is not None
|
||||
assert len(passages) > 0
|
||||
|
||||
@@ -208,7 +210,7 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# Test deletion - should delete from both
|
||||
passage_to_delete = sql_passages[0]
|
||||
await server.passage_manager.delete_agent_passages_async([passage_to_delete], default_user)
|
||||
await server.passage_manager.delete_agent_passages_async([passage_to_delete], default_user, strict_mode=True)
|
||||
|
||||
# Verify deleted from SQL
|
||||
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
@@ -310,7 +312,9 @@ class TestTurbopufferIntegration:
|
||||
|
||||
# Insert passages - should only write to SQL
|
||||
text_content = "This is a test passage for native PostgreSQL only."
|
||||
passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=text_content, actor=default_user)
|
||||
passages = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent, text=text_content, actor=default_user, strict_mode=True
|
||||
)
|
||||
|
||||
assert passages is not None
|
||||
assert len(passages) > 0
|
||||
@@ -721,7 +725,9 @@ class TestTurbopufferParametrized:
|
||||
|
||||
# Test inserting a passage (should work in both modes)
|
||||
test_text = f"Test passage for {expected_provider} mode"
|
||||
passages = await server.passage_manager.insert_passage(agent_state=sarah_agent, text=test_text, actor=default_user)
|
||||
passages = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent, text=test_text, actor=default_user, strict_mode=True
|
||||
)
|
||||
|
||||
assert passages is not None
|
||||
assert len(passages) > 0
|
||||
@@ -732,7 +738,7 @@ class TestTurbopufferParametrized:
|
||||
assert any(p.text == test_text for p in listed)
|
||||
|
||||
# Delete should work in both modes
|
||||
await server.passage_manager.delete_agent_passages_async(passages, default_user)
|
||||
await server.passage_manager.delete_agent_passages_async(passages, default_user, strict_mode=True)
|
||||
|
||||
# Verify deletion
|
||||
remaining = await server.agent_manager.query_agent_passages_async(actor=default_user, agent_id=sarah_agent.id, limit=10)
|
||||
@@ -750,11 +756,11 @@ class TestTurbopufferParametrized:
|
||||
|
||||
# Insert passages with specific timestamps
|
||||
recent_passage = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent, text="Recent update from today", actor=default_user, created_at=now
|
||||
agent_state=sarah_agent, text="Recent update from today", actor=default_user, created_at=now, strict_mode=True
|
||||
)
|
||||
|
||||
old_passage = await server.passage_manager.insert_passage(
|
||||
agent_state=sarah_agent, text="Old update from last week", actor=default_user, created_at=last_week
|
||||
agent_state=sarah_agent, text="Old update from last week", actor=default_user, created_at=last_week, strict_mode=True
|
||||
)
|
||||
|
||||
# Query with date range that includes only recent passage
|
||||
@@ -785,8 +791,8 @@ class TestTurbopufferParametrized:
|
||||
assert not any("Recent update from today" in p.text for p in old_results)
|
||||
|
||||
# Clean up
|
||||
await server.passage_manager.delete_agent_passages_async(recent_passage, default_user)
|
||||
await server.passage_manager.delete_agent_passages_async(old_passage, default_user)
|
||||
await server.passage_manager.delete_agent_passages_async(recent_passage, default_user, strict_mode=True)
|
||||
await server.passage_manager.delete_agent_passages_async(old_passage, default_user, strict_mode=True)
|
||||
|
||||
|
||||
class TestTurbopufferMessagesIntegration:
|
||||
@@ -1361,6 +1367,369 @@ class TestTurbopufferMessagesIntegration:
|
||||
settings.use_tpuf = original_use_tpuf
|
||||
settings.embed_all_messages = original_embed_messages
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_update_reindexes_in_turbopuffer(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that updating a message properly deletes and re-inserts with new embedding in Turbopuffer"""
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create initial message
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Original content about Python programming")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
message_id = messages[0].id
|
||||
|
||||
# Search for "Python" - should find it
|
||||
python_results = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="Python",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(python_results) > 0
|
||||
assert any(msg.id == message_id for msg in python_results)
|
||||
|
||||
# Update the message content
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about JavaScript development"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
assert updated_message.id == message_id # ID should remain the same
|
||||
|
||||
# Search for "Python" - should NOT find it anymore
|
||||
python_results_after = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="Python",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg in python_results_after)
|
||||
|
||||
# Search for "JavaScript" - should find the updated message
|
||||
js_results = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="JavaScript",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(js_results) > 0
|
||||
assert any(msg.id == message_id for msg in js_results)
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_deletion_syncs_with_turbopuffer(self, server, default_user, enable_message_embedding):
|
||||
"""Test that all deletion methods properly sync with Turbopuffer"""
|
||||
from letta.schemas.agent import CreateAgent
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
|
||||
# Create two test agents
|
||||
agent_a = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="Agent A",
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent_b = await server.agent_manager.create_agent_async(
|
||||
agent_create=CreateAgent(
|
||||
name="Agent B",
|
||||
memory_blocks=[],
|
||||
llm_config=LLMConfig.default_config("gpt-4o-mini"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
include_base_tools=False,
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
embedding_config = agent_a.embedding_config
|
||||
|
||||
try:
|
||||
# Create 5 messages for agent A
|
||||
agent_a_messages = []
|
||||
for i in range(5):
|
||||
msgs = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Agent A message {i + 1}")],
|
||||
agent_id=agent_a.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
agent_a_messages.extend(msgs)
|
||||
|
||||
# Create 3 messages for agent B
|
||||
agent_b_messages = []
|
||||
for i in range(3):
|
||||
msgs = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text=f"Agent B message {i + 1}")],
|
||||
agent_id=agent_b.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
agent_b_messages.extend(msgs)
|
||||
|
||||
# Verify initial state - all messages are searchable
|
||||
agent_a_search = await server.message_manager.search_messages_async(
|
||||
agent_id=agent_a.id,
|
||||
actor=default_user,
|
||||
query_text="Agent A",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_a_search) == 5
|
||||
|
||||
agent_b_search = await server.message_manager.search_messages_async(
|
||||
agent_id=agent_b.id,
|
||||
actor=default_user,
|
||||
query_text="Agent B",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_b_search) == 3
|
||||
|
||||
# Test 1: Delete single message from agent A
|
||||
await server.message_manager.delete_message_by_id_async(agent_a_messages[0].id, default_user, strict_mode=True)
|
||||
|
||||
# Test 2: Batch delete 2 messages from agent A
|
||||
await server.message_manager.delete_messages_by_ids_async(
|
||||
[agent_a_messages[1].id, agent_a_messages[2].id], default_user, strict_mode=True
|
||||
)
|
||||
|
||||
# Test 3: Delete all messages for agent B
|
||||
await server.message_manager.delete_all_messages_for_agent_async(agent_b.id, default_user, strict_mode=True)
|
||||
|
||||
# Verify final state
|
||||
# Agent A should have 2 messages left (5 - 1 - 2 = 2)
|
||||
agent_a_final = await server.message_manager.search_messages_async(
|
||||
agent_id=agent_a.id,
|
||||
actor=default_user,
|
||||
query_text="Agent A",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_a_final) == 2
|
||||
# Verify the remaining messages are the correct ones
|
||||
remaining_ids = {msg.id for msg in agent_a_final}
|
||||
assert agent_a_messages[3].id in remaining_ids
|
||||
assert agent_a_messages[4].id in remaining_ids
|
||||
|
||||
# Agent B should have 0 messages
|
||||
agent_b_final = await server.message_manager.search_messages_async(
|
||||
agent_id=agent_b.id,
|
||||
actor=default_user,
|
||||
query_text="Agent B",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(agent_b_final) == 0
|
||||
|
||||
finally:
|
||||
# Clean up agents
|
||||
await server.agent_manager.delete_agent_async(agent_a.id, default_user)
|
||||
await server.agent_manager.delete_agent_async(agent_b.id, default_user)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_crud_operations_without_embedding_config(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that CRUD operations handle missing embedding_config gracefully"""
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create message WITH embedding_config
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message with searchable content about databases")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
message_id = messages[0].id
|
||||
|
||||
# Verify message is searchable initially
|
||||
initial_search = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(initial_search) > 0
|
||||
assert any(msg.id == message_id for msg in initial_search)
|
||||
|
||||
# Update message WITHOUT embedding_config - should update postgres but not turbopuffer
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about algorithms"),
|
||||
actor=default_user,
|
||||
embedding_config=None, # No config provided
|
||||
)
|
||||
|
||||
# Verify postgres was updated
|
||||
assert updated_message.id == message_id
|
||||
updated_text = server.message_manager._extract_message_text(updated_message)
|
||||
assert "algorithms" in updated_text
|
||||
assert "databases" not in updated_text
|
||||
|
||||
# Original search term should STILL find the message (turbopuffer wasn't updated)
|
||||
still_searchable = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(still_searchable) > 0
|
||||
assert any(msg.id == message_id for msg in still_searchable)
|
||||
|
||||
# New content should NOT be searchable (wasn't re-indexed)
|
||||
not_searchable = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="algorithms",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# Should either find no results or results that don't include our message
|
||||
assert not any(msg.id == message_id for msg in not_searchable)
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_turbopuffer_failure_does_not_break_postgres(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that postgres operations succeed even if turbopuffer fails"""
|
||||
from unittest.mock import AsyncMock, patch
|
||||
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create initial messages
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Test message for error handling")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
message_id = messages[0].id
|
||||
|
||||
# Mock turbopuffer client to raise exceptions
|
||||
with patch(
|
||||
"letta.helpers.tpuf_client.TurbopufferClient.delete_messages",
|
||||
new=AsyncMock(side_effect=Exception("Turbopuffer connection failed")),
|
||||
):
|
||||
with patch(
|
||||
"letta.helpers.tpuf_client.TurbopufferClient.insert_messages",
|
||||
new=AsyncMock(side_effect=Exception("Turbopuffer insert failed")),
|
||||
):
|
||||
# Test 1: Update should succeed in postgres despite turbopuffer failure
|
||||
# NOTE: strict_mode=False here because we're testing error resilience
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated despite turbopuffer failure"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Don't fail on turbopuffer errors - that's what we're testing!
|
||||
)
|
||||
|
||||
# Verify postgres was updated successfully
|
||||
assert updated_message.id == message_id
|
||||
updated_text = server.message_manager._extract_message_text(updated_message)
|
||||
assert "Updated despite turbopuffer failure" in updated_text
|
||||
|
||||
# Test 2: Delete should succeed in postgres despite turbopuffer failure
|
||||
# First create another message to delete
|
||||
messages2 = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message to delete")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=None, # Create without embedding to avoid mock issues
|
||||
)
|
||||
message_to_delete_id = messages2[0].id
|
||||
|
||||
# Delete with mocked turbopuffer failure
|
||||
# NOTE: strict_mode=False here because we're testing error resilience
|
||||
deletion_result = await server.message_manager.delete_message_by_id_async(
|
||||
message_to_delete_id, default_user, strict_mode=False
|
||||
)
|
||||
assert deletion_result == True
|
||||
|
||||
# Verify message is deleted from postgres
|
||||
deleted_msg = await server.message_manager.get_message_by_id_async(message_to_delete_id, default_user)
|
||||
assert deleted_msg is None
|
||||
|
||||
# Clean up remaining message (use strict_mode=False since turbopuffer might be mocked)
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
|
||||
|
||||
Reference in New Issue
Block a user