feat: Move message embedding to background task [LET-4189] (#4430)

* Test background message embedding

* Change to content
This commit is contained in:
Matthew Zhou
2025-09-04 15:05:35 -07:00
committed by GitHub
parent d23318b802
commit 5337e5bcac
3 changed files with 355 additions and 86 deletions

View File

@@ -21,7 +21,7 @@ from letta.server.db import db_registry
from letta.services.file_manager import FileManager
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
from letta.settings import DatabaseChoice, settings
from letta.utils import enforce_types
from letta.utils import enforce_types, fire_and_forget
logger = get_logger(__name__)
@@ -101,7 +101,7 @@ class MessageManager:
args = json.loads(tool_call.function.arguments)
actual_message = args.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
return json.dumps({"thinking": content_str, "message": actual_message})
return json.dumps({"thinking": content_str, "content": actual_message})
except (json.JSONDecodeError, KeyError):
# fallback if parsing fails
pass
@@ -324,6 +324,7 @@ class MessageManager:
pydantic_msgs: List of Pydantic message models to create
actor: User performing the action
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
strict_mode: If True, wait for embedding to complete; if False, run in background
Returns:
List of created Pydantic message models
@@ -363,59 +364,80 @@ class MessageManager:
await session.commit()
# embed messages in turbopuffer if enabled and embedding_config provided
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and embedding_config and result:
try:
# extract agent_id from the first message (all should have same agent_id)
agent_id = result[0].agent_id
if agent_id:
# extract text content from each message
message_texts = []
message_ids = []
roles = []
created_ats = []
# combine assistant+tool messages before embedding
combined_messages = self._combine_assistant_tool_messages(result)
for msg in combined_messages:
text = self._extract_message_text(msg).strip()
if text: # only embed messages with text content (role filtering is handled in _extract_message_text)
message_texts.append(text)
message_ids.append(msg.id)
roles.append(msg.role)
created_ats.append(msg.created_at)
if message_texts:
# generate embeddings using provided config
from letta.llm_api.llm_client import LLMClient
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
# insert to turbopuffer
tpuf_client = TurbopufferClient()
await tpuf_client.insert_messages(
agent_id=agent_id,
message_texts=message_texts,
embeddings=embeddings,
message_ids=message_ids,
organization_id=actor.organization_id,
roles=roles,
created_ats=created_ats,
)
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
except Exception as e:
logger.error(f"Failed to embed messages in Turbopuffer: {e}")
# extract agent_id from the first message (all should have same agent_id)
agent_id = result[0].agent_id
if agent_id:
if strict_mode:
raise # Re-raise the exception in strict mode
# wait for embedding to complete
await self._embed_messages_background(result, embedding_config, actor, agent_id)
else:
# fire and forget - run embedding in background
fire_and_forget(
self._embed_messages_background(result, embedding_config, actor, agent_id),
task_name=f"embed_messages_for_agent_{agent_id}",
)
return result
async def _embed_messages_background(
    self, messages: List[PydanticMessage], embedding_config: EmbeddingConfig, actor: PydanticUser, agent_id: str
) -> None:
    """Background task to embed and store messages in Turbopuffer.

    Best-effort: any failure is logged (with traceback) and swallowed so a
    background embedding error can never break message creation.

    Args:
        messages: List of messages to embed
        embedding_config: Embedding configuration
        actor: User performing the action
        agent_id: Agent ID for the messages
    """
    try:
        # local imports, matching the module's pattern of deferring these
        from letta.helpers.tpuf_client import TurbopufferClient
        from letta.llm_api.llm_client import LLMClient

        # combine assistant+tool messages before embedding so a tool call and
        # its result are indexed as one logical unit
        combined_messages = self._combine_assistant_tool_messages(messages)

        # extract text content from each message into parallel lists
        message_texts: List[str] = []
        message_ids: List[str] = []
        roles = []
        created_ats = []
        for msg in combined_messages:
            text = self._extract_message_text(msg).strip()
            # only embed messages with text content (role filtering is
            # handled in _extract_message_text)
            if text:
                message_texts.append(text)
                message_ids.append(msg.id)
                roles.append(msg.role)
                created_ats.append(msg.created_at)

        if not message_texts:
            # nothing with text content to embed
            return

        # generate embeddings using the provided config
        embedding_client = LLMClient.create(
            provider_type=embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)

        # insert into turbopuffer
        tpuf_client = TurbopufferClient()
        await tpuf_client.insert_messages(
            agent_id=agent_id,
            message_texts=message_texts,
            embeddings=embeddings,
            message_ids=message_ids,
            organization_id=actor.organization_id,
            roles=roles,
            created_ats=created_ats,
        )
        logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
    except Exception:
        # background mode: log with full traceback, never re-raise
        logger.exception(f"Failed to embed messages in Turbopuffer for agent {agent_id}")
@enforce_types
@trace_method
def update_message_by_letta_message(
@@ -525,6 +547,13 @@ class MessageManager:
"""
Updates an existing record in the database with values from the provided record object.
Async version of the function above.
Args:
message_id: ID of the message to update
message_update: Update data for the message
actor: User performing the action
embedding_config: Optional embedding configuration for Turbopuffer
strict_mode: If True, wait for embedding update to complete; if False, run in background
"""
async with db_registry.async_session() as session:
# Fetch existing message from database
@@ -540,49 +569,68 @@ class MessageManager:
await session.commit()
# update message in turbopuffer if enabled (delete and re-insert)
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id:
try:
# extract text content from updated message
text = self._extract_message_text(pydantic_message)
# extract text content from updated message
text = self._extract_message_text(pydantic_message)
# only update in turbopuffer if there's text content (role filtering is handled in _extract_message_text)
if text:
tpuf_client = TurbopufferClient()
# delete old message from turbopuffer
await tpuf_client.delete_messages(
agent_id=pydantic_message.agent_id, organization_id=actor.organization_id, message_ids=[message_id]
)
# generate new embedding
from letta.llm_api.llm_client import LLMClient
embedding_client = LLMClient.create(
provider_type=embedding_config.embedding_endpoint_type,
actor=actor,
)
embeddings = await embedding_client.request_embeddings([text], embedding_config)
# re-insert with updated content
await tpuf_client.insert_messages(
agent_id=pydantic_message.agent_id,
message_texts=[text],
embeddings=embeddings,
message_ids=[message_id],
organization_id=actor.organization_id,
roles=[pydantic_message.role],
created_ats=[pydantic_message.created_at],
)
logger.info(f"Successfully updated message {message_id} in Turbopuffer")
except Exception as e:
logger.error(f"Failed to update message in Turbopuffer: {e}")
# only update in turbopuffer if there's text content
if text:
if strict_mode:
raise # Re-raise the exception in strict mode
# wait for embedding update to complete
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor)
else:
# fire and forget - run embedding update in background
fire_and_forget(
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor),
task_name=f"update_message_embedding_{message_id}",
)
return pydantic_message
async def _update_message_embedding_background(
    self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser
) -> None:
    """Background task to update a message's embedding in Turbopuffer.

    The new embedding is generated BEFORE the old vector is deleted so that
    an embedding failure leaves the previous (stale but searchable) entry in
    place rather than dropping the message from the index entirely.
    Best-effort: failures are logged with traceback and never re-raised.

    Args:
        message: The updated message
        text: Extracted text content from the message
        embedding_config: Embedding configuration
        actor: User performing the action
    """
    try:
        # local imports, matching the module's pattern of deferring these
        from letta.helpers.tpuf_client import TurbopufferClient
        from letta.llm_api.llm_client import LLMClient

        # generate the new embedding first — if this raises, the old entry
        # in turbopuffer is left untouched
        embedding_client = LLMClient.create(
            provider_type=embedding_config.embedding_endpoint_type,
            actor=actor,
        )
        embeddings = await embedding_client.request_embeddings([text], embedding_config)

        tpuf_client = TurbopufferClient()
        # delete the old vector, then re-insert with the updated content
        await tpuf_client.delete_messages(agent_id=message.agent_id, organization_id=actor.organization_id, message_ids=[message.id])
        await tpuf_client.insert_messages(
            agent_id=message.agent_id,
            message_texts=[text],
            embeddings=embeddings,
            message_ids=[message.id],
            organization_id=actor.organization_id,
            roles=[message.role],
            created_ats=[message.created_at],
        )
        logger.info(f"Successfully updated message {message.id} in Turbopuffer")
    except Exception:
        # background mode: log with full traceback, never re-raise
        logger.exception(f"Failed to update message {message.id} in Turbopuffer")
def _update_message_by_id_impl(
self, message_id: str, message_update: MessageUpdate, actor: PydanticUser, message: MessageModel
) -> MessageModel:

View File

@@ -17,7 +17,7 @@ from contextlib import contextmanager
from datetime import datetime, timezone
from functools import wraps
from logging import Logger
from typing import Any, Coroutine, Optional, Union, _GenericAlias, get_args, get_origin, get_type_hints
from typing import Any, Callable, Coroutine, Optional, Union, _GenericAlias, get_args, get_origin, get_type_hints
from urllib.parse import urljoin, urlparse
import demjson3 as demjson
@@ -1271,3 +1271,36 @@ def truncate_file_visible_content(visible_content: str, is_open: bool, per_file_
visible_content += truncated_warning
return visible_content
def fire_and_forget(coro, task_name: Optional[str] = None, error_callback: Optional[Callable[[Exception], None]] = None) -> asyncio.Task:
    """
    Execute an async coroutine in the background without waiting for completion.

    Must be called from within a running event loop (uses asyncio.create_task).
    Exceptions raised by the task are logged (with traceback) and optionally
    forwarded to ``error_callback``; they are never propagated to the caller.
    NOTE(review): asyncio.CancelledError is a BaseException and is deliberately
    not caught here, so cancellations are not reported as failures.

    Args:
        coro: The coroutine to execute
        task_name: Optional name for logging purposes (also set as the asyncio
            task name so it is visible in debugging/introspection tools)
        error_callback: Optional callback to execute if the task fails

    Returns:
        The created asyncio Task object
    """
    # name the task so it shows up in asyncio introspection (e.g. all_tasks())
    task = asyncio.create_task(coro, name=task_name)

    def _on_done(finished: asyncio.Task) -> None:
        try:
            finished.result()  # re-raises any exception from the task
        except Exception as e:
            task_desc = f"Background task {task_name}" if task_name else "Background task"
            # logger.exception attaches the active traceback automatically
            logger.exception(f"{task_desc} failed: {str(e)}")
            if error_callback:
                try:
                    error_callback(e)
                except Exception as callback_error:
                    logger.error(f"Error callback failed: {callback_error}")

    task.add_done_callback(_on_done)
    return task

View File

@@ -1,3 +1,4 @@
import asyncio
import uuid
from datetime import datetime, timezone
@@ -1739,6 +1740,193 @@ class TestTurbopufferMessagesIntegration:
# Clean up remaining message (use strict_mode=False since turbopuffer might be mocked)
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False)
async def wait_for_embedding(
    self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5
) -> bool:
    """Poll Turbopuffer directly to check if a message has been embedded.

    Args:
        agent_id: Agent ID for the message
        message_id: ID of the message to find
        organization_id: Organization ID
        max_wait: Maximum time to wait in seconds
        poll_interval: Time between polls in seconds

    Returns:
        True if message was found in Turbopuffer within timeout, False otherwise
    """
    import asyncio

    from letta.helpers.tpuf_client import TurbopufferClient

    client = TurbopufferClient()
    loop = asyncio.get_running_loop()
    deadline = loop.time() + max_wait
    while loop.time() < deadline:
        try:
            # Query Turbopuffer directly using timestamp mode to get all messages
            results = await client.query_messages(
                agent_id=agent_id,
                organization_id=organization_id,
                search_mode="timestamp",
                top_k=100,  # Get more messages to ensure we find it
            )
            # Check if our message ID is in the results
            if any(msg["id"] == message_id for msg, _, _ in results):
                return True
        except Exception:
            # Don't fail — Turbopuffer might still be processing; retry until deadline
            pass
        await asyncio.sleep(poll_interval)
    return False
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_creation_background_mode(self, server, default_user, sarah_agent, enable_message_embedding):
    """Test that messages are embedded in background when strict_mode=False.

    Verifies: (1) the message lands in SQL immediately, (2) the background
    embedding eventually appears in Turbopuffer, (3) the message is then
    reachable via the search API.
    """
    embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
    # Create message in background mode
    messages = await server.message_manager.create_many_messages_async(
        pydantic_msgs=[
            PydanticMessage(
                role=MessageRole.user,
                content=[TextContent(text="Background test message about Python programming")],
                agent_id=sarah_agent.id,
            )
        ],
        actor=default_user,
        embedding_config=embedding_config,
        strict_mode=False,  # Background mode
    )
    assert len(messages) == 1
    message_id = messages[0].id
    # Message should be in PostgreSQL immediately
    sql_message = await server.message_manager.get_message_by_id_async(message_id, default_user)
    assert sql_message is not None
    assert sql_message.id == message_id
    # Poll for embedding completion by querying Turbopuffer directly
    embedded = await self.wait_for_embedding(
        agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
    )
    assert embedded, "Message was not embedded in Turbopuffer within timeout"
    # Now verify it's also searchable through the search API
    search_results = await server.message_manager.search_messages_async(
        agent_id=sarah_agent.id,
        actor=default_user,
        query_text="Python programming",
        search_mode="fts",
        limit=10,
        embedding_config=embedding_config,
    )
    assert len(search_results) > 0
    assert any(msg.id == message_id for msg, _ in search_results)
    # Clean up
    await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_update_background_mode(self, server, default_user, sarah_agent, enable_message_embedding):
    """Test that message updates work in background mode.

    Creates a message with strict_mode=True (so the initial embedding is
    guaranteed), updates it with strict_mode=False, then polls Turbopuffer
    for the re-embedded content and checks the old content no longer matches.
    """
    from letta.schemas.message import MessageUpdate

    embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
    # Create initial message with strict_mode=True to ensure it's embedded
    messages = await server.message_manager.create_many_messages_async(
        pydantic_msgs=[
            PydanticMessage(
                role=MessageRole.user,
                content=[TextContent(text="Original content about databases")],
                agent_id=sarah_agent.id,
            )
        ],
        actor=default_user,
        embedding_config=embedding_config,
        strict_mode=True,  # Ensure initial embedding
    )
    assert len(messages) == 1
    message_id = messages[0].id
    # Verify initial content is searchable
    initial_results = await server.message_manager.search_messages_async(
        agent_id=sarah_agent.id,
        actor=default_user,
        query_text="databases",
        search_mode="fts",
        limit=10,
        embedding_config=embedding_config,
    )
    assert any(msg.id == message_id for msg, _ in initial_results)
    # Update message in background mode
    updated_message = await server.message_manager.update_message_by_id_async(
        message_id=message_id,
        message_update=MessageUpdate(content="Updated content about machine learning"),
        actor=default_user,
        embedding_config=embedding_config,
        strict_mode=False,  # Background mode
    )
    assert updated_message.id == message_id
    # PostgreSQL should be updated immediately
    sql_message = await server.message_manager.get_message_by_id_async(message_id, default_user)
    assert "machine learning" in server.message_manager._extract_message_text(sql_message)
    # Wait a bit for the background update to process
    await asyncio.sleep(1.0)
    # Poll for the update to be reflected in Turbopuffer
    # We check by searching for the new content
    embedded = await self.wait_for_embedding(
        agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
    )
    assert embedded, "Updated message was not re-embedded within timeout"
    # Now verify the new content is searchable
    new_results = await server.message_manager.search_messages_async(
        agent_id=sarah_agent.id,
        actor=default_user,
        query_text="machine learning",
        search_mode="fts",
        limit=10,
        embedding_config=embedding_config,
    )
    assert any(msg.id == message_id for msg, _ in new_results)
    # Old content should eventually no longer be searchable
    # (may take a moment for the delete to process)
    await asyncio.sleep(2.0)
    old_results = await server.message_manager.search_messages_async(
        agent_id=sarah_agent.id,
        actor=default_user,
        query_text="databases",
        search_mode="fts",
        limit=10,
        embedding_config=embedding_config,
    )
    # The message shouldn't match the old search term anymore
    if len(old_results) > 0:
        # If we find results, verify our message doesn't contain the old content
        for msg, _ in old_results:
            if msg.id == message_id:
                text = server.message_manager._extract_message_text(msg)
                assert "databases" not in text.lower()
    # Clean up
    await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):