From cbf2e09e13d19c98e8a6b8c7c6db23c903ee57ed Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 4 Sep 2025 16:50:41 -0700 Subject: [PATCH] feat: Add project id to message schema [LET-4166] (#4433) * Add project id * Propogate through update message by id async * Add project id testing --- letta/agents/base_agent.py | 6 +- letta/agents/helpers.py | 1 + letta/agents/letta_agent.py | 19 ++- letta/helpers/tpuf_client.py | 17 +++ letta/services/agent_manager.py | 8 +- letta/services/agent_serialization_manager.py | 7 +- letta/services/message_manager.py | 25 +++- letta/services/summarizer/summarizer.py | 2 + tests/integration_test_turbopuffer.py | 115 ++++++++++++++++++ 9 files changed, 185 insertions(+), 15 deletions(-) diff --git a/letta/agents/base_agent.py b/letta/agents/base_agent.py index 6a03b216..4ada4a94 100644 --- a/letta/agents/base_agent.py +++ b/letta/agents/base_agent.py @@ -175,7 +175,11 @@ class BaseAgent(ABC): # [DB Call] Update Messages new_system_message = await self.message_manager.update_message_by_id_async( - curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor + curr_system_message.id, + message_update=MessageUpdate(content=new_system_message_str), + actor=self.actor, + embedding_config=agent_state.embedding_config, + project_id=agent_state.project_id, ) return [new_system_message] + in_context_messages[1:] diff --git a/letta/agents/helpers.py b/letta/agents/helpers.py index f4a58b65..011675eb 100644 --- a/letta/agents/helpers.py +++ b/letta/agents/helpers.py @@ -118,6 +118,7 @@ async def _prepare_in_context_messages_async( create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor), actor=actor, embedding_config=agent_state.embedding_config, + project_id=agent_state.project_id, ) return current_in_context_messages, new_in_context_messages diff --git a/letta/agents/letta_agent.py b/letta/agents/letta_agent.py index 76183c44..2d7baa8f 100644 --- a/letta/agents/letta_agent.py +++ b/letta/agents/letta_agent.py @@ -495,7 +495,10 @@ class LettaAgent(BaseAgent): message.is_err = True message.step_id = effective_step_id await self.message_manager.create_many_messages_async( - initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config + initial_messages, + actor=self.actor, + embedding_config=agent_state.embedding_config, + project_id=agent_state.project_id, ) elif step_progression <= StepProgression.LOGGED_TRACE: if stop_reason is None: @@ -823,7 +826,10 @@ class LettaAgent(BaseAgent): message.is_err = True message.step_id = effective_step_id await self.message_manager.create_many_messages_async( - initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config + initial_messages, + actor=self.actor, + embedding_config=agent_state.embedding_config, + project_id=agent_state.project_id, ) elif step_progression <= StepProgression.LOGGED_TRACE: if stop_reason is None: @@ -1259,7 +1265,10 @@ class LettaAgent(BaseAgent): message.is_err = True message.step_id = effective_step_id await self.message_manager.create_many_messages_async( - initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config + initial_messages, + actor=self.actor, + embedding_config=agent_state.embedding_config, + project_id=agent_state.project_id, ) elif step_progression <= StepProgression.LOGGED_TRACE: if stop_reason is None: @@ -1667,7 +1676,7 @@ class LettaAgent(BaseAgent): ) messages_to_persist = (initial_messages or []) + tool_call_messages persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config + messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id ) return persisted_messages, continue_stepping, stop_reason @@ -1779,7 +1788,7 @@ class LettaAgent(BaseAgent): messages_to_persist = (initial_messages or []) + tool_call_messages persisted_messages = await self.message_manager.create_many_messages_async( - messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config + messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id ) if run_id: diff --git a/letta/helpers/tpuf_client.py b/letta/helpers/tpuf_client.py index 0dc1a8be..2920ad62 100644 --- a/letta/helpers/tpuf_client.py +++ b/letta/helpers/tpuf_client.py @@ -182,6 +182,7 @@ class TurbopufferClient: organization_id: str, roles: List[MessageRole], created_ats: List[datetime], + project_id: Optional[str] = None, ) -> bool: """Insert messages into Turbopuffer. @@ -193,6 +194,7 @@ class TurbopufferClient: organization_id: Organization ID for the messages roles: List of message roles corresponding to each message created_ats: List of creation timestamps for each message + project_id: Optional project ID for all messages Returns: True if successful @@ -221,6 +223,7 @@ class TurbopufferClient: agent_ids = [] message_roles = [] created_at_timestamps = [] + project_ids = [] for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)): message_id = message_ids[idx] @@ -241,6 +244,7 @@ class TurbopufferClient: agent_ids.append(agent_id) message_roles.append(role.value) created_at_timestamps.append(timestamp) + project_ids.append(project_id) # build column-based upsert data upsert_columns = { @@ -253,6 +257,10 @@ class TurbopufferClient: "created_at": created_at_timestamps, } + # only include project_id if it's provided + if project_id is not None: + upsert_columns["project_id"] = project_ids + try: # Use AsyncTurbopuffer as a context manager for proper resource cleanup async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client: @@ -520,6 +528,7 @@ class TurbopufferClient: search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp" top_k: int = 10, roles: Optional[List[MessageRole]] = None, + project_id: Optional[str] = None, vector_weight: float = 0.5, fts_weight: float = 0.5, start_date: Optional[datetime] = None, @@ -535,6 +544,7 @@ class TurbopufferClient: search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector") top_k: Number of results to return roles: Optional list of message roles to filter by + project_id: Optional project ID to filter messages by vector_weight: Weight for vector search results in hybrid mode (default: 0.5) fts_weight: Weight for FTS results in hybrid mode (default: 0.5) start_date: Optional datetime to filter messages created after this date @@ -579,10 +589,17 @@ class TurbopufferClient: end_date = end_date + timedelta(days=1) - timedelta(microseconds=1) date_filters.append(("created_at", "Lte", end_date)) + # build project_id filter if provided + project_filter = None + if project_id: + project_filter = ("project_id", "Eq", project_id) + # combine all filters all_filters = [agent_filter] # always include agent_id filter if role_filter: all_filters.append(role_filter) + if project_filter: + all_filters.append(project_filter) if date_filters: all_filters.extend(date_filters) diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index 74d41a09..60272936 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -720,7 +720,7 @@ class AgentManager: # Only create messages if we initialized with messages if not _init_with_no_messages: await self.message_manager.create_many_messages_async( - pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config + pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config, project_id=result.project_id ) return result @@ -1834,6 +1834,8 @@ class AgentManager: message_id=curr_system_message.id, message_update=MessageUpdate(**temp_message.model_dump()), actor=actor, + embedding_config=agent_state.embedding_config, + project_id=agent_state.project_id, ) else: curr_system_message = temp_message @@ -1887,7 +1889,9 @@ class AgentManager: self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser ) -> PydanticAgentState: agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor) - messages = await self.message_manager.create_many_messages_async(messages, actor=actor, embedding_config=agent.embedding_config) + messages = await self.message_manager.create_many_messages_async( + messages, actor=actor, embedding_config=agent.embedding_config, project_id=agent.project_id + ) message_ids = agent.message_ids or [] message_ids += [m.id for m in messages] return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor) diff --git a/letta/services/agent_serialization_manager.py b/letta/services/agent_serialization_manager.py index a0cca9b8..0bdcf5c6 100644 --- a/letta/services/agent_serialization_manager.py +++ b/letta/services/agent_serialization_manager.py @@ -675,7 +675,12 @@ class AgentSerializationManager: # Map file ID to the generated database ID immediately message_file_to_db_ids[message_schema.id] = message_obj.id - created_messages = await self.message_manager.create_many_messages_async(pydantic_msgs=messages, actor=actor) + created_messages = await self.message_manager.create_many_messages_async( + pydantic_msgs=messages, + actor=actor, + embedding_config=created_agent.embedding_config, + project_id=created_agent.project_id, + ) imported_count += len(created_messages) # Remap in_context_message_ids from file IDs to database IDs diff --git a/letta/services/message_manager.py b/letta/services/message_manager.py index 13df051d..c5a977b8 100644 --- a/letta/services/message_manager.py +++ b/letta/services/message_manager.py @@ -316,6 +316,7 @@ class MessageManager: actor: PydanticUser, embedding_config: Optional[EmbeddingConfig] = None, strict_mode: bool = False, + project_id: Optional[str] = None, ) -> List[PydanticMessage]: """ Create multiple messages in a single database transaction asynchronously. @@ -325,6 +326,7 @@ class MessageManager: actor: User performing the action embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer strict_mode: If True, wait for embedding to complete; if False, run in background + project_id: Optional project ID for the messages (for Turbopuffer indexing) Returns: List of created Pydantic message models @@ -372,18 +374,23 @@ class MessageManager: if agent_id: if strict_mode: # wait for embedding to complete - await self._embed_messages_background(result, embedding_config, actor, agent_id) + await self._embed_messages_background(result, embedding_config, actor, agent_id, project_id) else: # fire and forget - run embedding in background fire_and_forget( - self._embed_messages_background(result, embedding_config, actor, agent_id), + self._embed_messages_background(result, embedding_config, actor, agent_id, project_id), task_name=f"embed_messages_for_agent_{agent_id}", ) return result async def _embed_messages_background( - self, messages: List[PydanticMessage], embedding_config: EmbeddingConfig, actor: PydanticUser, agent_id: str + self, + messages: List[PydanticMessage], + embedding_config: EmbeddingConfig, + actor: PydanticUser, + agent_id: str, + project_id: Optional[str] = None, ) -> None: """Background task to embed and store messages in Turbopuffer. @@ -392,6 +399,7 @@ class MessageManager: embedding_config: Embedding configuration actor: User performing the action agent_id: Agent ID for the messages + project_id: Optional project ID for the messages """ try: from letta.helpers.tpuf_client import TurbopufferClient @@ -432,6 +440,7 @@ class MessageManager: organization_id=actor.organization_id, roles=roles, created_ats=created_ats, + project_id=project_id, ) logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}") except Exception as e: @@ -543,6 +552,7 @@ class MessageManager: actor: PydanticUser, embedding_config: Optional[EmbeddingConfig] = None, strict_mode: bool = False, + project_id: Optional[str] = None, ) -> PydanticMessage: """ Updates an existing record in the database with values from the provided record object. @@ -554,6 +564,7 @@ class MessageManager: actor: User performing the action embedding_config: Optional embedding configuration for Turbopuffer strict_mode: If True, wait for embedding update to complete; if False, run in background + project_id: Optional project ID for the message (for Turbopuffer indexing) """ async with db_registry.async_session() as session: # Fetch existing message from database @@ -579,18 +590,18 @@ class MessageManager: if text: if strict_mode: # wait for embedding update to complete - await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor) + await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id) else: # fire and forget - run embedding update in background fire_and_forget( - self._update_message_embedding_background(pydantic_message, text, embedding_config, actor), + self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id), task_name=f"update_message_embedding_{message_id}", ) return pydantic_message async def _update_message_embedding_background( - self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser + self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser, project_id: Optional[str] = None ) -> None: """Background task to update a message's embedding in Turbopuffer. @@ -599,6 +610,7 @@ class MessageManager: text: Extracted text content from the message embedding_config: Embedding configuration actor: User performing the action + project_id: Optional project ID for the message """ try: from letta.helpers.tpuf_client import TurbopufferClient @@ -625,6 +637,7 @@ class MessageManager: organization_id=actor.organization_id, roles=[message.role], created_ats=[message.created_at], + project_id=project_id, ) logger.info(f"Successfully updated message {message.id} in Turbopuffer") except Exception as e: diff --git a/letta/services/summarizer/summarizer.py b/letta/services/summarizer/summarizer.py index 3e4d040a..fdcf0327 100644 --- a/letta/services/summarizer/summarizer.py +++ b/letta/services/summarizer/summarizer.py @@ -195,6 +195,8 @@ class Summarizer: await self.message_manager.create_many_messages_async( pydantic_msgs=[summary_message_obj], actor=self.actor, + embedding_config=agent_state.embedding_config, + project_id=agent_state.project_id, ) updated_in_context_messages = all_in_context_messages[assistant_message_index:] diff --git a/tests/integration_test_turbopuffer.py b/tests/integration_test_turbopuffer.py index 3dc59b16..fe88522a 100644 --- a/tests/integration_test_turbopuffer.py +++ b/tests/integration_test_turbopuffer.py @@ -2097,3 +2097,118 @@ class TestNamespaceTracking: finally: settings.environment = original_env + + @pytest.mark.asyncio + @pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured") + async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding): + """Test that project_id filtering works correctly in query_messages""" + from letta.schemas.letta_message_content import TextContent + + # Create two project IDs + project_a_id = str(uuid.uuid4()) + project_b_id = str(uuid.uuid4()) + + # Create messages with different project IDs + message_a = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Message for project A about Python")], + ) + + message_b = PydanticMessage( + agent_id=sarah_agent.id, + role=MessageRole.user, + content=[TextContent(text="Message for project B about JavaScript")], + ) + + # Insert messages with their respective project IDs + tpuf_client = TurbopufferClient() + + # Generate embeddings + from letta.llm_api.llm_client import LLMClient + + embedding_client = LLMClient.create( + provider_type=sarah_agent.embedding_config.embedding_endpoint_type, + actor=default_user, + ) + embeddings = await embedding_client.request_embeddings( + [message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config + ) + + # Insert message A with project_a_id + await tpuf_client.insert_messages( + agent_id=sarah_agent.id, + message_texts=[message_a.content[0].text], + embeddings=[embeddings[0]], + message_ids=[message_a.id], + organization_id=default_user.organization_id, + roles=[message_a.role], + created_ats=[message_a.created_at], + project_id=project_a_id, + ) + + # Insert message B with project_b_id + await tpuf_client.insert_messages( + agent_id=sarah_agent.id, + message_texts=[message_b.content[0].text], + embeddings=[embeddings[1]], + message_ids=[message_b.id], + organization_id=default_user.organization_id, + roles=[message_b.role], + created_ats=[message_b.created_at], + project_id=project_b_id, + ) + + # Poll for message A with project_a_id filter + max_retries = 10 + for i in range(max_retries): + results_a = await tpuf_client.query_messages( + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + search_mode="timestamp", # Simple timestamp retrieval + top_k=10, + project_id=project_a_id, + ) + if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id: + break + await asyncio.sleep(0.5) + else: + pytest.fail(f"Message A not found after {max_retries} retries") + + assert "Python" in results_a[0][0]["text"] + + # Poll for message B with project_b_id filter + for i in range(max_retries): + results_b = await tpuf_client.query_messages( + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + search_mode="timestamp", + top_k=10, + project_id=project_b_id, + ) + if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id: + break + await asyncio.sleep(0.5) + else: + pytest.fail(f"Message B not found after {max_retries} retries") + + assert "JavaScript" in results_b[0][0]["text"] + + # Query without project filter - should find both + results_all = await tpuf_client.query_messages( + agent_id=sarah_agent.id, + organization_id=default_user.organization_id, + search_mode="timestamp", + top_k=10, + project_id=None, # No filter + ) + + assert len(results_all) >= 2 # May have other messages from setup + message_ids = [r[0]["id"] for r in results_all] + assert message_a.id in message_ids + assert message_b.id in message_ids + + # Clean up + await tpuf_client.delete_messages( + agent_id=sarah_agent.id, organization_id=default_user.organization_id, message_ids=[message_a.id, message_b.id] + )