From f305d3bfac81429c181a7a5d2262721452cea17a Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 4 Sep 2025 15:05:35 -0700 Subject: [PATCH] feat: Move message embedding to background task [LET-4189] (#4430) * Test background message embedding * Change to content --- letta/services/message_manager.py | 218 ++++++++++++++++---------- letta/utils.py | 35 ++++- tests/integration_test_turbopuffer.py | 188 ++++++++++++++++++++++ 3 files changed, 355 insertions(+), 86 deletions(-) diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 774eac69..57cf7cc5 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -21,7 +21,7 @@ from letta.server.db import db_registry from letta.services.file_manager import FileManager from letta.services.helpers.agent_manager_helper import validate_agent_exists_async from letta.settings import DatabaseChoice, settings -from letta.utils import enforce_types +from letta.utils import enforce_types, fire_and_forget logger = get_logger(__name__) @@ -101,7 +101,7 @@ class MessageManager: args = json.loads(tool_call.function.arguments) actual_message = args.get(DEFAULT_MESSAGE_TOOL_KWARG, "") - return json.dumps({"thinking": content_str, "message": actual_message}) + return json.dumps({"thinking": content_str, "content": actual_message}) except (json.JSONDecodeError, KeyError): # fallback if parsing fails pass @@ -324,6 +324,7 @@ class MessageManager: pydantic_msgs: List of Pydantic message models to create actor: User performing the action embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer + strict_mode: If True, wait for embedding to complete; if False, run in background Returns: List of created Pydantic message models @@ -363,59 +364,80 @@ class MessageManager: await session.commit() # embed messages in turbopuffer if enabled and embedding_config provided - from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages + from letta.helpers.tpuf_client import should_use_tpuf_for_messages if should_use_tpuf_for_messages() and embedding_config and result: - try: - # extract agent_id from the first message (all should have same agent_id) - agent_id = result[0].agent_id - if agent_id: - # extract text content from each message - message_texts = [] - message_ids = [] - roles = [] - created_ats = [] - # combine assistant+tool messages before embedding - combined_messages = self._combine_assistant_tool_messages(result) - - for msg in combined_messages: - text = self._extract_message_text(msg).strip() - if text: # only embed messages with text content (role filtering is handled in _extract_message_text) - message_texts.append(text) - message_ids.append(msg.id) - roles.append(msg.role) - created_ats.append(msg.created_at) - - if message_texts: - # generate embeddings using provided config - from letta.llm_api.llm_client import LLMClient - - embedding_client = LLMClient.create( - provider_type=embedding_config.embedding_endpoint_type, - actor=actor, - ) - embeddings = await embedding_client.request_embeddings(message_texts, embedding_config) - - # insert to turbopuffer - tpuf_client = TurbopufferClient() - await tpuf_client.insert_messages( - agent_id=agent_id, - message_texts=message_texts, - embeddings=embeddings, - message_ids=message_ids, - organization_id=actor.organization_id, - roles=roles, - created_ats=created_ats, - ) - logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}") - except Exception as e: - logger.error(f"Failed to embed messages in Turbopuffer: {e}") - + # extract agent_id from the first message (all should have same agent_id) + agent_id = result[0].agent_id + if agent_id: if strict_mode: - raise # Re-raise the exception in strict mode + # wait for embedding to complete + await self._embed_messages_background(result, embedding_config, actor, agent_id) + else: + # fire and forget - run embedding in background + fire_and_forget( + self._embed_messages_background(result, embedding_config, actor, agent_id), + task_name=f"embed_messages_for_agent_{agent_id}", + ) return result + async def _embed_messages_background( + self, messages: List[PydanticMessage], embedding_config: EmbeddingConfig, actor: PydanticUser, agent_id: str + ) -> None: + """Background task to embed and store messages in Turbopuffer. + + Args: + messages: List of messages to embed + embedding_config: Embedding configuration + actor: User performing the action + agent_id: Agent ID for the messages + """ + try: + from letta.helpers.tpuf_client import TurbopufferClient + from letta.llm_api.llm_client import LLMClient + + # extract text content from each message + message_texts = [] + message_ids = [] + roles = [] + created_ats = [] + + # combine assistant+tool messages before embedding + combined_messages = self._combine_assistant_tool_messages(messages) + + for msg in combined_messages: + text = self._extract_message_text(msg).strip() + if text: # only embed messages with text content (role filtering is handled in _extract_message_text) + message_texts.append(text) + message_ids.append(msg.id) + roles.append(msg.role) + created_ats.append(msg.created_at) + + if message_texts: + # generate embeddings using provided config + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings(message_texts, embedding_config) + + # insert to turbopuffer + tpuf_client = TurbopufferClient() + await tpuf_client.insert_messages( + agent_id=agent_id, + message_texts=message_texts, + embeddings=embeddings, + message_ids=message_ids, + organization_id=actor.organization_id, + roles=roles, + created_ats=created_ats, + ) + logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}") + except Exception as e: + logger.error(f"Failed to embed messages in Turbopuffer for agent {agent_id}: {e}") + # don't re-raise the exception in background mode - just log it + @enforce_types @trace_method def update_message_by_letta_message( @@ -525,6 +547,13 @@ class MessageManager: """ Updates an existing record in the database with values from the provided record object. Async version of the function above. + + Args: + message_id: ID of the message to update + message_update: Update data for the message + actor: User performing the action + embedding_config: Optional embedding configuration for Turbopuffer + strict_mode: If True, wait for embedding update to complete; if False, run in background """ async with db_registry.async_session() as session: # Fetch existing message from database @@ -540,49 +569,68 @@ class MessageManager: await session.commit() # update message in turbopuffer if enabled (delete and re-insert) - from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages + from letta.helpers.tpuf_client import should_use_tpuf_for_messages if should_use_tpuf_for_messages() and embedding_config and pydantic_message.agent_id: - try: - # extract text content from updated message - text = self._extract_message_text(pydantic_message) + # extract text content from updated message + text = self._extract_message_text(pydantic_message) - # only update in turbopuffer if there's text content (role filtering is handled in _extract_message_text) - if text: - tpuf_client = TurbopufferClient() - - # delete old message from turbopuffer - await tpuf_client.delete_messages( - agent_id=pydantic_message.agent_id, organization_id=actor.organization_id, message_ids=[message_id] - ) - - # generate new embedding - from letta.llm_api.llm_client import LLMClient - - embedding_client = LLMClient.create( - provider_type=embedding_config.embedding_endpoint_type, - actor=actor, - ) - embeddings = await embedding_client.request_embeddings([text], embedding_config) - - # re-insert with updated content - await tpuf_client.insert_messages( - agent_id=pydantic_message.agent_id, - message_texts=[text], - embeddings=embeddings, - message_ids=[message_id], - organization_id=actor.organization_id, - roles=[pydantic_message.role], - created_ats=[pydantic_message.created_at], - ) - logger.info(f"Successfully updated message {message_id} in Turbopuffer") - except Exception as e: - logger.error(f"Failed to update message in Turbopuffer: {e}") + # only update in turbopuffer if there's text content + if text: if strict_mode: - raise # Re-raise the exception in strict mode + # wait for embedding update to complete + await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor) + else: + # fire and forget - run embedding update in background + fire_and_forget( + self._update_message_embedding_background(pydantic_message, text, embedding_config, actor), + task_name=f"update_message_embedding_{message_id}", + ) return pydantic_message + async def _update_message_embedding_background( + self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser + ) -> None: + """Background task to update a message's embedding in Turbopuffer. + + Args: + message: The updated message + text: Extracted text content from the message + embedding_config: Embedding configuration + actor: User performing the action + """ + try: + from letta.helpers.tpuf_client import TurbopufferClient + from letta.llm_api.llm_client import LLMClient + + tpuf_client = TurbopufferClient() + + # delete old message from turbopuffer + await tpuf_client.delete_messages(agent_id=message.agent_id, organization_id=actor.organization_id, message_ids=[message.id]) + + # generate new embedding + embedding_client = LLMClient.create( + provider_type=embedding_config.embedding_endpoint_type, + actor=actor, + ) + embeddings = await embedding_client.request_embeddings([text], embedding_config) + + # re-insert with updated content + await tpuf_client.insert_messages( + agent_id=message.agent_id, + message_texts=[text], + embeddings=embeddings, + message_ids=[message.id], + organization_id=actor.organization_id, + roles=[message.role], + created_ats=[message.created_at], + ) + logger.info(f"Successfully updated message {message.id} in Turbopuffer") + except Exception as e: + logger.error(f"Failed to update message {message.id} in Turbopuffer: {e}") + # don't re-raise the exception in background mode - just log it + def _update_message_by_id_impl( self, message_id: str, message_update: MessageUpdate, actor: PydanticUser, message: MessageModel ) -> MessageModel: diff --git a/letta/utils.py b/letta/utils.py index d5aafc24..581b469e 100644 --- a/letta/utils.py +++ b/letta/utils.py @@ -17,7 +17,7 @@ from contextlib import contextmanager from datetime import datetime, timezone from functools import wraps from logging import Logger -from typing import Any, Coroutine, Optional, Union, _GenericAlias, get_args, get_origin, get_type_hints +from typing import Any, Callable, Coroutine, Optional, Union, _GenericAlias, get_args, get_origin, get_type_hints from urllib.parse import urljoin, urlparse import demjson3 as demjson @@ -1271,3 +1271,36 @@ def truncate_file_visible_content(visible_content: str, is_open: bool, per_file_ visible_content += truncated_warning return visible_content + + +def fire_and_forget(coro, task_name: Optional[str] = None, error_callback: Optional[Callable[[Exception], None]] = None) -> asyncio.Task: + """ + Execute an async coroutine in the background without waiting for completion. + + Args: + coro: The coroutine to execute + task_name: Optional name for logging purposes + error_callback: Optional callback to execute if the task fails + + Returns: + The created asyncio Task object + """ + import traceback + + task = asyncio.create_task(coro) + + def callback(t): + try: + t.result() # this re-raises exceptions from the task + except Exception as e: + task_desc = f"Background task {task_name}" if task_name else "Background task" + logger.error(f"{task_desc} failed: {str(e)}\n{traceback.format_exc()}") + + if error_callback: + try: + error_callback(e) + except Exception as callback_error: + logger.error(f"Error callback failed: {callback_error}") + + task.add_done_callback(callback) + return task diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 62951564..3dc59b16 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -1,3 +1,4 @@ +import asyncio import uuid from datetime import datetime, timezone @@ -1739,6 +1740,193 @@ class TestTurbopufferMessagesIntegration: # Clean up remaining message (use strict_mode=False since turbopuffer might be mocked) await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=False) + async def wait_for_embedding( + self, agent_id: str, message_id: str, organization_id: str, max_wait: float = 10.0, poll_interval: float = 0.5 + ) -> bool: + """Poll Turbopuffer directly to check if a message has been embedded. + + Args: + agent_id: Agent ID for the message + message_id: ID of the message to find + organization_id: Organization ID + max_wait: Maximum time to wait in seconds + poll_interval: Time between polls in seconds + + Returns: + True if message was found in Turbopuffer within timeout, False otherwise + """ + import asyncio + + from letta.helpers.tpuf_client import TurbopufferClient + + client = TurbopufferClient() + start_time = asyncio.get_event_loop().time() + + while asyncio.get_event_loop().time() - start_time < max_wait: + try: + # Query Turbopuffer directly using timestamp mode to get all messages + results = await client.query_messages( + agent_id=agent_id, + organization_id=organization_id, + search_mode="timestamp", + top_k=100, # Get more messages to ensure we find it + ) + + # Check if our message ID is in the results + if any(msg["id"] == message_id for msg, _, _ in results): + return True + + except Exception as e: + # Log but don't fail - Turbopuffer might still be processing + pass + + await asyncio.sleep(poll_interval) + + return False + + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") + async def test_message_creation_background_mode(self, server, default_user, sarah_agent, enable_message_embedding): + """Test that messages are embedded in background when strict_mode=False""" + embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai") + + # Create message in background mode + messages = await server.message_manager.create_many_messages_async( + pydantic_msgs=[ + PydanticMessage( + role=MessageRole.user, + content=[TextContent(text="Background test message about Python programming")], + agent_id=sarah_agent.id, + ) + ], + actor=default_user, + embedding_config=embedding_config, + strict_mode=False, # Background mode + ) + + assert len(messages) == 1 + message_id = messages[0].id + + # Message should be in PostgreSQL immediately + sql_message = await server.message_manager.get_message_by_id_async(message_id, default_user) + assert sql_message is not None + assert sql_message.id == message_id + + # Poll for embedding completion by querying Turbopuffer directly + embedded = await self.wait_for_embedding( + agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5 + ) + assert embedded, "Message was not embedded in Turbopuffer within timeout" + + # Now verify it's also searchable through the search API + search_results = await server.message_manager.search_messages_async( + agent_id=sarah_agent.id, + actor=default_user, + query_text="Python programming", + search_mode="fts", + limit=10, + embedding_config=embedding_config, + ) + assert len(search_results) > 0 + assert any(msg.id == message_id for msg, _ in search_results) + + # Clean up + await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True) + + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") + async def test_message_update_background_mode(self, server, default_user, sarah_agent, enable_message_embedding): + """Test that message updates work in background mode""" + from letta.schemas.message import MessageUpdate + + embedding_config = sarah_agent.embedding_config or EmbeddingConfig.default_config(provider="openai") + + # Create initial message with strict_mode=True to ensure it's embedded + messages = await server.message_manager.create_many_messages_async( + pydantic_msgs=[ + PydanticMessage( + role=MessageRole.user, + content=[TextContent(text="Original content about databases")], + agent_id=sarah_agent.id, + ) + ], + actor=default_user, + embedding_config=embedding_config, + strict_mode=True, # Ensure initial embedding + ) + + assert len(messages) == 1 + message_id = messages[0].id + + # Verify initial content is searchable + initial_results = await server.message_manager.search_messages_async( + agent_id=sarah_agent.id, + actor=default_user, + query_text="databases", + search_mode="fts", + limit=10, + embedding_config=embedding_config, + ) + assert any(msg.id == message_id for msg, _ in initial_results) + + # Update message in background mode + updated_message = await server.message_manager.update_message_by_id_async( + message_id=message_id, + message_update=MessageUpdate(content="Updated content about machine learning"), + actor=default_user, + embedding_config=embedding_config, + strict_mode=False, # Background mode + ) + + assert updated_message.id == message_id + + # PostgreSQL should be updated immediately + sql_message = await server.message_manager.get_message_by_id_async(message_id, default_user) + assert "machine learning" in server.message_manager._extract_message_text(sql_message) + + # Wait a bit for the background update to process + await asyncio.sleep(1.0) + + # Poll for the update to be reflected in Turbopuffer + # We check by searching for the new content + embedded = await self.wait_for_embedding( + agent_id=sarah_agent.id, message_id=message_id, organization_id=default_user.organization_id, max_wait=10.0, poll_interval=0.5 + ) + assert embedded, "Updated message was not re-embedded within timeout" + + # Now verify the new content is searchable + new_results = await server.message_manager.search_messages_async( + agent_id=sarah_agent.id, + actor=default_user, + query_text="machine learning", + search_mode="fts", + limit=10, + embedding_config=embedding_config, + ) + assert any(msg.id == message_id for msg, _ in new_results) + + # Old content should eventually no longer be searchable + # (may take a moment for the delete to process) + await asyncio.sleep(2.0) + old_results = await server.message_manager.search_messages_async( + agent_id=sarah_agent.id, + actor=default_user, + query_text="databases", + search_mode="fts", + limit=10, + embedding_config=embedding_config, + ) + # The message shouldn't match the old search term anymore + if len(old_results) > 0: + # If we find results, verify our message doesn't contain the old content + for msg, _ in old_results: + if msg.id == message_id: + text = server.message_manager._extract_message_text(msg) + assert "databases" not in text.lower() + + # Clean up + await server.message_manager.delete_messages_by_ids_async([message_id], default_user, strict_mode=True) + @pytest.mark.asyncio @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") async def test_message_date_filtering_with_real_tpuf(self, enable_message_embedding):