feat: Add project id to message schema [LET-4166] (#4433)

* Add project id

* Propagate through update message by id async

* Add project id testing
This commit is contained in:
Matthew Zhou
2025-09-04 16:50:41 -07:00
committed by GitHub
parent 415ae5a928
commit cbf2e09e13
9 changed files with 185 additions and 15 deletions

View File

@@ -175,7 +175,11 @@ class BaseAgent(ABC):
# [DB Call] Update Messages
new_system_message = await self.message_manager.update_message_by_id_async(
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
curr_system_message.id,
message_update=MessageUpdate(content=new_system_message_str),
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
return [new_system_message] + in_context_messages[1:]

View File

@@ -118,6 +118,7 @@ async def _prepare_in_context_messages_async(
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
actor=actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
return current_in_context_messages, new_in_context_messages

View File

@@ -495,7 +495,10 @@ class LettaAgent(BaseAgent):
message.is_err = True
message.step_id = effective_step_id
await self.message_manager.create_many_messages_async(
initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config
initial_messages,
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
@@ -823,7 +826,10 @@ class LettaAgent(BaseAgent):
message.is_err = True
message.step_id = effective_step_id
await self.message_manager.create_many_messages_async(
initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config
initial_messages,
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
@@ -1259,7 +1265,10 @@ class LettaAgent(BaseAgent):
message.is_err = True
message.step_id = effective_step_id
await self.message_manager.create_many_messages_async(
initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config
initial_messages,
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
@@ -1667,7 +1676,7 @@ class LettaAgent(BaseAgent):
)
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
)
return persisted_messages, continue_stepping, stop_reason
@@ -1779,7 +1788,7 @@ class LettaAgent(BaseAgent):
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
)
if run_id:

View File

@@ -182,6 +182,7 @@ class TurbopufferClient:
organization_id: str,
roles: List[MessageRole],
created_ats: List[datetime],
project_id: Optional[str] = None,
) -> bool:
"""Insert messages into Turbopuffer.
@@ -193,6 +194,7 @@ class TurbopufferClient:
organization_id: Organization ID for the messages
roles: List of message roles corresponding to each message
created_ats: List of creation timestamps for each message
project_id: Optional project ID for all messages
Returns:
True if successful
@@ -221,6 +223,7 @@ class TurbopufferClient:
agent_ids = []
message_roles = []
created_at_timestamps = []
project_ids = []
for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
message_id = message_ids[idx]
@@ -241,6 +244,7 @@ class TurbopufferClient:
agent_ids.append(agent_id)
message_roles.append(role.value)
created_at_timestamps.append(timestamp)
project_ids.append(project_id)
# build column-based upsert data
upsert_columns = {
@@ -253,6 +257,10 @@ class TurbopufferClient:
"created_at": created_at_timestamps,
}
# only include project_id if it's provided
if project_id is not None:
upsert_columns["project_id"] = project_ids
try:
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
@@ -520,6 +528,7 @@ class TurbopufferClient:
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
top_k: int = 10,
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
@@ -535,6 +544,7 @@ class TurbopufferClient:
search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
top_k: Number of results to return
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
@@ -579,10 +589,17 @@ class TurbopufferClient:
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
date_filters.append(("created_at", "Lte", end_date))
# build project_id filter if provided
project_filter = None
if project_id:
project_filter = ("project_id", "Eq", project_id)
# combine all filters
all_filters = [agent_filter] # always include agent_id filter
if role_filter:
all_filters.append(role_filter)
if project_filter:
all_filters.append(project_filter)
if date_filters:
all_filters.extend(date_filters)

View File

@@ -720,7 +720,7 @@ class AgentManager:
# Only create messages if we initialized with messages
if not _init_with_no_messages:
await self.message_manager.create_many_messages_async(
pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config
pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config, project_id=result.project_id
)
return result
@@ -1834,6 +1834,8 @@ class AgentManager:
message_id=curr_system_message.id,
message_update=MessageUpdate(**temp_message.model_dump()),
actor=actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
else:
curr_system_message = temp_message
@@ -1887,7 +1889,9 @@ class AgentManager:
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
) -> PydanticAgentState:
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
messages = await self.message_manager.create_many_messages_async(messages, actor=actor, embedding_config=agent.embedding_config)
messages = await self.message_manager.create_many_messages_async(
messages, actor=actor, embedding_config=agent.embedding_config, project_id=agent.project_id
)
message_ids = agent.message_ids or []
message_ids += [m.id for m in messages]
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)

View File

@@ -675,7 +675,12 @@ class AgentSerializationManager:
# Map file ID to the generated database ID immediately
message_file_to_db_ids[message_schema.id] = message_obj.id
created_messages = await self.message_manager.create_many_messages_async(pydantic_msgs=messages, actor=actor)
created_messages = await self.message_manager.create_many_messages_async(
pydantic_msgs=messages,
actor=actor,
embedding_config=created_agent.embedding_config,
project_id=created_agent.project_id,
)
imported_count += len(created_messages)
# Remap in_context_message_ids from file IDs to database IDs

View File

@@ -316,6 +316,7 @@ class MessageManager:
actor: PydanticUser,
embedding_config: Optional[EmbeddingConfig] = None,
strict_mode: bool = False,
project_id: Optional[str] = None,
) -> List[PydanticMessage]:
"""
Create multiple messages in a single database transaction asynchronously.
@@ -325,6 +326,7 @@ class MessageManager:
actor: User performing the action
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
strict_mode: If True, wait for embedding to complete; if False, run in background
project_id: Optional project ID for the messages (for Turbopuffer indexing)
Returns:
List of created Pydantic message models
@@ -372,18 +374,23 @@ class MessageManager:
if agent_id:
if strict_mode:
# wait for embedding to complete
await self._embed_messages_background(result, embedding_config, actor, agent_id)
await self._embed_messages_background(result, embedding_config, actor, agent_id, project_id)
else:
# fire and forget - run embedding in background
fire_and_forget(
self._embed_messages_background(result, embedding_config, actor, agent_id),
self._embed_messages_background(result, embedding_config, actor, agent_id, project_id),
task_name=f"embed_messages_for_agent_{agent_id}",
)
return result
async def _embed_messages_background(
self, messages: List[PydanticMessage], embedding_config: EmbeddingConfig, actor: PydanticUser, agent_id: str
self,
messages: List[PydanticMessage],
embedding_config: EmbeddingConfig,
actor: PydanticUser,
agent_id: str,
project_id: Optional[str] = None,
) -> None:
"""Background task to embed and store messages in Turbopuffer.
@@ -392,6 +399,7 @@ class MessageManager:
embedding_config: Embedding configuration
actor: User performing the action
agent_id: Agent ID for the messages
project_id: Optional project ID for the messages
"""
try:
from letta.helpers.tpuf_client import TurbopufferClient
@@ -432,6 +440,7 @@ class MessageManager:
organization_id=actor.organization_id,
roles=roles,
created_ats=created_ats,
project_id=project_id,
)
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
except Exception as e:
@@ -543,6 +552,7 @@ class MessageManager:
actor: PydanticUser,
embedding_config: Optional[EmbeddingConfig] = None,
strict_mode: bool = False,
project_id: Optional[str] = None,
) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
@@ -554,6 +564,7 @@ class MessageManager:
actor: User performing the action
embedding_config: Optional embedding configuration for Turbopuffer
strict_mode: If True, wait for embedding update to complete; if False, run in background
project_id: Optional project ID for the message (for Turbopuffer indexing)
"""
async with db_registry.async_session() as session:
# Fetch existing message from database
@@ -579,18 +590,18 @@ class MessageManager:
if text:
if strict_mode:
# wait for embedding update to complete
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor)
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id)
else:
# fire and forget - run embedding update in background
fire_and_forget(
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor),
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id),
task_name=f"update_message_embedding_{message_id}",
)
return pydantic_message
async def _update_message_embedding_background(
self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser
self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser, project_id: Optional[str] = None
) -> None:
"""Background task to update a message's embedding in Turbopuffer.
@@ -599,6 +610,7 @@ class MessageManager:
text: Extracted text content from the message
embedding_config: Embedding configuration
actor: User performing the action
project_id: Optional project ID for the message
"""
try:
from letta.helpers.tpuf_client import TurbopufferClient
@@ -625,6 +637,7 @@ class MessageManager:
organization_id=actor.organization_id,
roles=[message.role],
created_ats=[message.created_at],
project_id=project_id,
)
logger.info(f"Successfully updated message {message.id} in Turbopuffer")
except Exception as e:

View File

@@ -195,6 +195,8 @@ class Summarizer:
await self.message_manager.create_many_messages_async(
pydantic_msgs=[summary_message_obj],
actor=self.actor,
embedding_config=agent_state.embedding_config,
project_id=agent_state.project_id,
)
updated_in_context_messages = all_in_context_messages[assistant_message_index:]

View File

@@ -2097,3 +2097,118 @@ class TestNamespaceTracking:
finally:
settings.environment = original_env
@pytest.mark.asyncio
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
"""Test that project_id filtering works correctly in query_messages"""
from letta.schemas.letta_message_content import TextContent
# Create two project IDs
project_a_id = str(uuid.uuid4())
project_b_id = str(uuid.uuid4())
# Create messages with different project IDs
message_a = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.user,
content=[TextContent(text="Message for project A about Python")],
)
message_b = PydanticMessage(
agent_id=sarah_agent.id,
role=MessageRole.user,
content=[TextContent(text="Message for project B about JavaScript")],
)
# Insert messages with their respective project IDs
tpuf_client = TurbopufferClient()
# Generate embeddings
from letta.llm_api.llm_client import LLMClient
embedding_client = LLMClient.create(
provider_type=sarah_agent.embedding_config.embedding_endpoint_type,
actor=default_user,
)
embeddings = await embedding_client.request_embeddings(
[message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config
)
# Insert message A with project_a_id
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_a.content[0].text],
embeddings=[embeddings[0]],
message_ids=[message_a.id],
organization_id=default_user.organization_id,
roles=[message_a.role],
created_ats=[message_a.created_at],
project_id=project_a_id,
)
# Insert message B with project_b_id
await tpuf_client.insert_messages(
agent_id=sarah_agent.id,
message_texts=[message_b.content[0].text],
embeddings=[embeddings[1]],
message_ids=[message_b.id],
organization_id=default_user.organization_id,
roles=[message_b.role],
created_ats=[message_b.created_at],
project_id=project_b_id,
)
# Poll for message A with project_a_id filter
max_retries = 10
for i in range(max_retries):
results_a = await tpuf_client.query_messages(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp", # Simple timestamp retrieval
top_k=10,
project_id=project_a_id,
)
if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id:
break
await asyncio.sleep(0.5)
else:
pytest.fail(f"Message A not found after {max_retries} retries")
assert "Python" in results_a[0][0]["text"]
# Poll for message B with project_b_id filter
for i in range(max_retries):
results_b = await tpuf_client.query_messages(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
project_id=project_b_id,
)
if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id:
break
await asyncio.sleep(0.5)
else:
pytest.fail(f"Message B not found after {max_retries} retries")
assert "JavaScript" in results_b[0][0]["text"]
# Query without project filter - should find both
results_all = await tpuf_client.query_messages(
agent_id=sarah_agent.id,
organization_id=default_user.organization_id,
search_mode="timestamp",
top_k=10,
project_id=None, # No filter
)
assert len(results_all) >= 2 # May have other messages from setup
message_ids = [r[0]["id"] for r in results_all]
assert message_a.id in message_ids
assert message_b.id in message_ids
# Clean up
await tpuf_client.delete_messages(
agent_id=sarah_agent.id, organization_id=default_user.organization_id, message_ids=[message_a.id, message_b.id]
)