feat: Move message embedding to background task [LET-4189] (#4430)
* Test background message embedding * Change to content
This commit is contained in:
@@ -21,7 +21,7 @@ from letta.server.db import db_registry
|
||||
from letta.services.file_manager import FileManager
|
||||
from letta.services.helpers.agent_manager_helper import validate_agent_exists_async
|
||||
from letta.settings import DatabaseChoice, settings
|
||||
from letta.utils import enforce_types
|
||||
from letta.utils import enforce_types, fire_and_forget
|
||||
|
||||
logger = get_logger(__name__)
|
||||
|
||||
@@ -101,7 +101,7 @@ class MessageManager:
|
||||
args = json.loads(tool_call.function.arguments)
|
||||
actual_message = args.get(DEFAULT_MESSAGE_TOOL_KWARG, "")
|
||||
|
||||
return json.dumps({"thinking": content_str, "message": actual_message})
|
||||
return json.dumps({"thinking": content_str, "content": actual_message})
|
||||
except (json.JSONDecodeError, KeyError):
|
||||
# fallback if parsing fails
|
||||
pass
|
||||
@@ -324,6 +324,7 @@ class MessageManager:
|
||||
pydantic_msgs: List of Pydantic message models to create
|
||||
actor: User performing the action
|
||||
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
|
||||
strict_mode: If True, wait for embedding to complete; if False, run in background
|
||||
|
||||
Returns:
|
||||
List of created Pydantic message models
|
||||
@@ -363,59 +364,80 @@ class MessageManager:
|
||||
await session.commit()
|
||||
|
||||
# embed messages in turbopuffer if enabled and embedding_config provided
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages() and embedding_config and result:
|
||||
try:
|
||||
# extract agent_id from the first message (all should have same agent_id)
|
||||
agent_id = result[0].agent_id
|
||||
if agent_id:
|
||||
# extract text content from each message
|
||||
message_texts = []
|
||||
message_ids = []
|
||||
roles = []
|
||||
created_ats = []
|
||||
# combine assistant+tool messages before embedding
|
||||
combined_messages = self._combine_assistant_tool_messages(result)
|
||||
|
||||
for msg in combined_messages:
|
||||
text = self._extract_message_text(msg).strip()
|
||||
if text: # only embed messages with text content (role filtering is handled in _extract_message_text)
|
||||
message_texts.append(text)
|
||||
message_ids.append(msg.id)
|
||||
roles.append(msg.role)
|
||||
created_ats.append(msg.created_at)
|
||||
|
||||
if message_texts:
|
||||
# generate embeddings using provided config
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
|
||||
|
||||
# insert to turbopuffer
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=actor.organization_id,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed messages in Turbopuffer: {e}")
|
||||
|
||||
# extract agent_id from the first message (all should have same agent_id)
|
||||
agent_id = result[0].agent_id
|
||||
if agent_id:
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
# wait for embedding to complete
|
||||
await self._embed_messages_background(result, embedding_config, actor, agent_id)
|
||||
else:
|
||||
# fire and forget - run embedding in background
|
||||
fire_and_forget(
|
||||
self._embed_messages_background(result, embedding_config, actor, agent_id),
|
||||
task_name=f"embed_messages_for_agent_{agent_id}",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _embed_messages_background(
|
||||
self, messages: List[PydanticMessage], embedding_config: EmbeddingConfig, actor: PydanticUser, agent_id: str
|
||||
) -> None:
|
||||
"""Background task to embed and store messages in Turbopuffer.
|
||||
|
||||
Args:
|
||||
messages: List of messages to embed
|
||||
embedding_config: Embedding configuration
|
||||
actor: User performing the action
|
||||
agent_id: Agent ID for the messages
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
# extract text content from each message
|
||||
message_texts = []
|
||||
message_ids = []
|
||||
roles = []
|
||||
created_ats = []
|
||||
|
||||
# combine assistant+tool messages before embedding
|
||||
combined_messages = self._combine_assistant_tool_messages(messages)
|
||||
|
||||
for msg in combined_messages:
|
||||
text = self._extract_message_text(msg).strip()
|
||||
if text: # only embed messages with text content (role filtering is handled in _extract_message_text)
|
||||
message_texts.append(text)
|
||||
message_ids.append(msg.id)
|
||||
roles.append(msg.role)
|
||||
created_ats.append(msg.created_at)
|
||||
|
||||
if message_texts:
|
||||
# generate embeddings using provided config
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(message_texts, embedding_config)
|
||||
|
||||
# insert to turbopuffer
|
||||
tpuf_client = TurbopufferClient()
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=agent_id,
|
||||
message_texts=message_texts,
|
||||
embeddings=embeddings,
|
||||
message_ids=message_ids,
|
||||
organization_id=actor.organization_id,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
)
|
||||
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to embed messages in Turbopuffer for agent {agent_id}: {e}")
|
||||
# don't re-raise the exception in background mode - just log it
|
||||
|
||||
@enforce_types
|
||||
@trace_method
|
||||
def update_message_by_letta_message(
|
||||
@@ -525,6 +547,13 @@ class MessageManager:
|
||||
"""
|
||||
Updates an existing record in the database with values from the provided record object.
|
||||
Async version of the function above.
|
||||
|
||||
Args:
|
||||
message_id: ID of the message to update
|
||||
message_update: Update data for the message
|
||||
actor: User performing the action
|
||||
embedding_config: Optional embedding configuration for Turbopuffer
|
||||
strict_mode: If True, wait for embedding update to complete; if False, run in background
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch existing message from database
|
||||
@@ -540,49 +569,68 @@ class MessageManager:
|
||||
await session.commit()
|
||||
|
||||
# update message in turbopuffer if enabled (delete and re-insert)
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
from letta.helpers.tpuf_client import should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id:
|
||||
try:
|
||||
# extract text content from updated message
|
||||
text = self._extract_message_text(pydantic_message)
|
||||
# extract text content from updated message
|
||||
text = self._extract_message_text(pydantic_message)
|
||||
|
||||
# only update in turbopuffer if there's text content (role filtering is handled in _extract_message_text)
|
||||
if text:
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# delete old message from turbopuffer
|
||||
await tpuf_client.delete_messages(
|
||||
agent_id=pydantic_message.agent_id, organization_id=actor.organization_id, message_ids=[message_id]
|
||||
)
|
||||
|
||||
# generate new embedding
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([text], embedding_config)
|
||||
|
||||
# re-insert with updated content
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=pydantic_message.agent_id,
|
||||
message_texts=[text],
|
||||
embeddings=embeddings,
|
||||
message_ids=[message_id],
|
||||
organization_id=actor.organization_id,
|
||||
roles=[pydantic_message.role],
|
||||
created_ats=[pydantic_message.created_at],
|
||||
)
|
||||
logger.info(f"Successfully updated message {message_id} in Turbopuffer")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update message in Turbopuffer: {e}")
|
||||
# only update in turbopuffer if there's text content
|
||||
if text:
|
||||
if strict_mode:
|
||||
raise # Re-raise the exception in strict mode
|
||||
# wait for embedding update to complete
|
||||
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor)
|
||||
else:
|
||||
# fire and forget - run embedding update in background
|
||||
fire_and_forget(
|
||||
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor),
|
||||
task_name=f"update_message_embedding_{message_id}",
|
||||
)
|
||||
|
||||
return pydantic_message
|
||||
|
||||
async def _update_message_embedding_background(
|
||||
self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser
|
||||
) -> None:
|
||||
"""Background task to update a message's embedding in Turbopuffer.
|
||||
|
||||
Args:
|
||||
message: The updated message
|
||||
text: Extracted text content from the message
|
||||
embedding_config: Embedding configuration
|
||||
actor: User performing the action
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# delete old message from turbopuffer
|
||||
await tpuf_client.delete_messages(agent_id=message.agent_id, organization_id=actor.organization_id, message_ids=[message.id])
|
||||
|
||||
# generate new embedding
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=embedding_config.embedding_endpoint_type,
|
||||
actor=actor,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings([text], embedding_config)
|
||||
|
||||
# re-insert with updated content
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=message.agent_id,
|
||||
message_texts=[text],
|
||||
embeddings=embeddings,
|
||||
message_ids=[message.id],
|
||||
organization_id=actor.organization_id,
|
||||
roles=[message.role],
|
||||
created_ats=[message.created_at],
|
||||
)
|
||||
logger.info(f"Successfully updated message {message.id} in Turbopuffer")
|
||||
except Exception as e:
|
||||
logger.error(f"Failed to update message {message.id} in Turbopuffer: {e}")
|
||||
# don't re-raise the exception in background mode - just log it
|
||||
|
||||
def _update_message_by_id_impl(
|
||||
self, message_id: str, message_update: MessageUpdate, actor: PydanticUser, message: MessageModel
|
||||
) -> MessageModel:
|
||||
|
||||
@@ -17,7 +17,7 @@ from contextlib import contextmanager
|
||||
from datetime import datetime, timezone
|
||||
from functools import wraps
|
||||
from logging import Logger
|
||||
from typing import Any, Coroutine, Optional, Union, _GenericAlias, get_args, get_origin, get_type_hints
|
||||
from typing import Any, Callable, Coroutine, Optional, Union, _GenericAlias, get_args, get_origin, get_type_hints
|
||||
from urllib.parse import urljoin, urlparse
|
||||
|
||||
import demjson3 as demjson
|
||||
@@ -1271,3 +1271,36 @@ def truncate_file_visible_content(visible_content: str, is_open: bool, per_file_
|
||||
visible_content += truncated_warning
|
||||
|
||||
return visible_content
|
||||
|
||||
|
||||
def fire_and_forget(coro, task_name: Optional[str] = None, error_callback: Optional[Callable[[Exception], None]] = None) -> asyncio.Task:
|
||||
"""
|
||||
Execute an async coroutine in the background without waiting for completion.
|
||||
|
||||
Args:
|
||||
coro: The coroutine to execute
|
||||
task_name: Optional name for logging purposes
|
||||
error_callback: Optional callback to execute if the task fails
|
||||
|
||||
Returns:
|
||||
The created asyncio Task object
|
||||
"""
|
||||
import traceback
|
||||
|
||||
task = asyncio.create_task(coro)
|
||||
|
||||
def callback(t):
|
||||
try:
|
||||
t.result() # this re-raises exceptions from the task
|
||||
except Exception as e:
|
||||
task_desc = f"Background task {task_name}" if task_name else "Background task"
|
||||
logger.error(f"{task_desc} failed: {str(e)}\n{traceback.format_exc()}")
|
||||
|
||||
if error_callback:
|
||||
try:
|
||||
error_callback(e)
|
||||
except Exception as callback_error:
|
||||
logger.error(f"Error callback failed: {callback_error}")
|
||||
|
||||
task.add_done_callback(callback)
|
||||
return task
|
||||
|
||||
@@ -1,3 +1,4 @@
|
||||
import asyncio
|
||||
import uuid
|
||||
from datetime import datetime, timezone
|
||||
|
||||
@@ -1739,6 +1740,193 @@ class TestTurbopufferMessagesIntegration:
|
||||
# Clean up remaining message (use strict_mode=False since turbopuffer might be mocked)
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False)
|
||||
|
||||
async def wait_for_embedding(
|
||||
self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5
|
||||
) -> bool:
|
||||
"""Poll Turbopuffer directly to check if a message has been embedded.
|
||||
|
||||
Args:
|
||||
agent_id: Agent ID for the message
|
||||
message_id: ID of the message to find
|
||||
organization_id: Organization ID
|
||||
max_wait: Maximum time to wait in seconds
|
||||
poll_interval: Time between polls in seconds
|
||||
|
||||
Returns:
|
||||
True if message was found in Turbopuffer within timeout, False otherwise
|
||||
"""
|
||||
import asyncio
|
||||
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
|
||||
client = TurbopufferClient()
|
||||
start_time = asyncio.get_event_loop().time()
|
||||
|
||||
while asyncio.get_event_loop().time() - start_time < max_wait:
|
||||
try:
|
||||
# Query Turbopuffer directly using timestamp mode to get all messages
|
||||
results = await client.query_messages(
|
||||
agent_id=agent_id,
|
||||
organization_id=organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=100, # Get more messages to ensure we find it
|
||||
)
|
||||
|
||||
# Check if our message ID is in the results
|
||||
if any(msg["id"] == message_id for msg, _, _ in results):
|
||||
return True
|
||||
|
||||
except Exception as e:
|
||||
# Log but don't fail - Turbopuffer might still be processing
|
||||
pass
|
||||
|
||||
await asyncio.sleep(poll_interval)
|
||||
|
||||
return False
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_creation_background_mode(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that messages are embedded in background when strict_mode=False"""
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create message in background mode
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Background test message about Python programming")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Background mode
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
message_id = messages[0].id
|
||||
|
||||
# Message should be in PostgreSQL immediately
|
||||
sql_message = await server.message_manager.get_message_by_id_async(message_id, default_user)
|
||||
assert sql_message is not None
|
||||
assert sql_message.id == message_id
|
||||
|
||||
# Poll for embedding completion by querying Turbopuffer directly
|
||||
embedded = await self.wait_for_embedding(
|
||||
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
|
||||
)
|
||||
assert embedded, "Message was not embedded in Turbopuffer within timeout"
|
||||
|
||||
# Now verify it's also searchable through the search API
|
||||
search_results = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="Python programming",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert len(search_results) > 0
|
||||
assert any(msg.id == message_id for msg, _ in search_results)
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_update_background_mode(self, server, default_user, sarah_agent, enable_message_embedding):
|
||||
"""Test that message updates work in background mode"""
|
||||
from letta.schemas.message import MessageUpdate
|
||||
|
||||
embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai")
|
||||
|
||||
# Create initial message with strict_mode=True to ensure it's embedded
|
||||
messages = await server.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[
|
||||
PydanticMessage(
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Original content about databases")],
|
||||
agent_id=sarah_agent.id,
|
||||
)
|
||||
],
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=True, # Ensure initial embedding
|
||||
)
|
||||
|
||||
assert len(messages) == 1
|
||||
message_id = messages[0].id
|
||||
|
||||
# Verify initial content is searchable
|
||||
initial_results = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert any(msg.id == message_id for msg, _ in initial_results)
|
||||
|
||||
# Update message in background mode
|
||||
updated_message = await server.message_manager.update_message_by_id_async(
|
||||
message_id=message_id,
|
||||
message_update=MessageUpdate(content="Updated content about machine learning"),
|
||||
actor=default_user,
|
||||
embedding_config=embedding_config,
|
||||
strict_mode=False, # Background mode
|
||||
)
|
||||
|
||||
assert updated_message.id == message_id
|
||||
|
||||
# PostgreSQL should be updated immediately
|
||||
sql_message = await server.message_manager.get_message_by_id_async(message_id, default_user)
|
||||
assert "machine learning" in server.message_manager._extract_message_text(sql_message)
|
||||
|
||||
# Wait a bit for the background update to process
|
||||
await asyncio.sleep(1.0)
|
||||
|
||||
# Poll for the update to be reflected in Turbopuffer
|
||||
# We check by searching for the new content
|
||||
embedded = await self.wait_for_embedding(
|
||||
agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5
|
||||
)
|
||||
assert embedded, "Updated message was not re-embedded within timeout"
|
||||
|
||||
# Now verify the new content is searchable
|
||||
new_results = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="machine learning",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
assert any(msg.id == message_id for msg, _ in new_results)
|
||||
|
||||
# Old content should eventually no longer be searchable
|
||||
# (may take a moment for the delete to process)
|
||||
await asyncio.sleep(2.0)
|
||||
old_results = await server.message_manager.search_messages_async(
|
||||
agent_id=sarah_agent.id,
|
||||
actor=default_user,
|
||||
query_text="databases",
|
||||
search_mode="fts",
|
||||
limit=10,
|
||||
embedding_config=embedding_config,
|
||||
)
|
||||
# The message shouldn't match the old search term anymore
|
||||
if len(old_results) > 0:
|
||||
# If we find results, verify our message doesn't contain the old content
|
||||
for msg, _ in old_results:
|
||||
if msg.id == message_id:
|
||||
text = server.message_manager._extract_message_text(msg)
|
||||
assert "databases" not in text.lower()
|
||||
|
||||
# Clean up
|
||||
await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True)
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):
|
||||
|
||||
Reference in New Issue
Block a user