feat: Add project_id to message filtering/inserting to Turbopuffer [LET-4252] (#4466)

* Add project_id

* Fern autogen
This commit is contained in:
Matthew Zhou
2025-09-08 14:35:15 -07:00
committed by GitHub
parent 74e08f038e
commit acaf820009
9 changed files with 1906 additions and 1850 deletions

View File

@@ -498,6 +498,7 @@ class LettaAgent(BaseAgent):
initial_messages,
actor=self.actor,
project_id=agent_state.project_id,
template_id=agent_state.template_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
@@ -828,6 +829,7 @@ class LettaAgent(BaseAgent):
initial_messages,
actor=self.actor,
project_id=agent_state.project_id,
template_id=agent_state.template_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
@@ -1269,6 +1271,7 @@ class LettaAgent(BaseAgent):
initial_messages,
actor=self.actor,
project_id=agent_state.project_id,
template_id=agent_state.template_id,
)
elif step_progression <= StepProgression.LOGGED_TRACE:
if stop_reason is None:
@@ -1676,7 +1679,7 @@ class LettaAgent(BaseAgent):
)
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
)
return persisted_messages, continue_stepping, stop_reason
@@ -1787,7 +1790,7 @@ class LettaAgent(BaseAgent):
messages_to_persist = (initial_messages or []) + tool_call_messages
persisted_messages = await self.message_manager.create_many_messages_async(
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
)
if run_id:

View File

@@ -220,6 +220,7 @@ class TurbopufferClient:
roles: List[MessageRole],
created_ats: List[datetime],
project_id: Optional[str] = None,
template_id: Optional[str] = None,
) -> bool:
"""Insert messages into Turbopuffer.
@@ -232,6 +233,7 @@ class TurbopufferClient:
roles: List of message roles corresponding to each message
created_ats: List of creation timestamps for each message
project_id: Optional project ID for all messages
template_id: Optional template ID for all messages
Returns:
True if successful
@@ -262,6 +264,7 @@ class TurbopufferClient:
message_roles = []
created_at_timestamps = []
project_ids = []
template_ids = []
for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
message_id = message_ids[idx]
@@ -283,6 +286,7 @@ class TurbopufferClient:
message_roles.append(role.value)
created_at_timestamps.append(timestamp)
project_ids.append(project_id)
template_ids.append(template_id)
# build column-based upsert data
upsert_columns = {
@@ -299,6 +303,10 @@ class TurbopufferClient:
if project_id is not None:
upsert_columns["project_id"] = project_ids
# only include template_id if it's provided
if template_id is not None:
upsert_columns["template_id"] = template_ids
try:
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
@@ -573,6 +581,7 @@ class TurbopufferClient:
top_k: int = 10,
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
@@ -589,6 +598,7 @@ class TurbopufferClient:
top_k: Number of results to return
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
@@ -644,12 +654,19 @@ class TurbopufferClient:
if project_id:
project_filter = ("project_id", "Eq", project_id)
# build template_id filter if provided
template_filter = None
if template_id:
template_filter = ("template_id", "Eq", template_id)
# combine all filters
all_filters = [agent_filter] # always include agent_id filter
if role_filter:
all_filters.append(role_filter)
if project_filter:
all_filters.append(project_filter)
if template_filter:
all_filters.append(template_filter)
if date_filters:
all_filters.extend(date_filters)
@@ -717,6 +734,7 @@ class TurbopufferClient:
top_k: int = 10,
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
vector_weight: float = 0.5,
fts_weight: float = 0.5,
start_date: Optional[datetime] = None,
@@ -732,6 +750,7 @@ class TurbopufferClient:
top_k: Number of results to return
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
start_date: Optional datetime to filter messages created after this date
@@ -766,6 +785,10 @@ class TurbopufferClient:
if project_id:
all_filters.append(("project_id", "Eq", project_id))
# template filter
if template_id:
all_filters.append(("template_id", "Eq", template_id))
# date filters
if start_date:
all_filters.append(("created_at", "Gte", start_date))

View File

@@ -1196,6 +1196,7 @@ class MessageSearchRequest(BaseModel):
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
roles: Optional[List[MessageRole]] = Field(None, description="Filter messages by role")
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
template_id: Optional[str] = Field(None, description="Filter messages by template ID")
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")

View File

@@ -1524,7 +1524,7 @@ async def search_messages(
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Search messages across the entire organization with optional project filtering. Returns messages with FTS/vector ranks and total RRF score.
Search messages across the entire organization with optional project and template filtering. Returns messages with FTS/vector ranks and total RRF score.
This is a cloud-only feature.
"""
@@ -1543,6 +1543,7 @@ async def search_messages(
search_mode=request.search_mode,
roles=request.roles,
project_id=request.project_id,
template_id=request.template_id,
limit=request.limit,
start_date=request.start_date,
end_date=request.end_date,

View File

@@ -719,7 +719,9 @@ class AgentManager:
# Only create messages if we initialized with messages
if not _init_with_no_messages:
await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor, project_id=result.project_id)
await self.message_manager.create_many_messages_async(
pydantic_msgs=init_messages, actor=actor, project_id=result.project_id, template_id=result.template_id
)
return result
@enforce_types
@@ -1886,7 +1888,9 @@ class AgentManager:
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
) -> PydanticAgentState:
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
messages = await self.message_manager.create_many_messages_async(messages, actor=actor, project_id=agent.project_id)
messages = await self.message_manager.create_many_messages_async(
messages, actor=actor, project_id=agent.project_id, template_id=agent.template_id
)
message_ids = agent.message_ids or []
message_ids += [m.id for m in messages]
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)

View File

@@ -679,6 +679,7 @@ class AgentSerializationManager:
pydantic_msgs=messages,
actor=actor,
project_id=created_agent.project_id,
template_id=created_agent.template_id,
)
imported_count += len(created_messages)

View File

@@ -315,6 +315,7 @@ class MessageManager:
actor: PydanticUser,
strict_mode: bool = False,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
) -> List[PydanticMessage]:
"""
Create multiple messages in a single database transaction asynchronously.
@@ -324,6 +325,7 @@ class MessageManager:
actor: User performing the action
strict_mode: If True, wait for embedding to complete; if False, run in background
project_id: Optional project ID for the messages (for Turbopuffer indexing)
template_id: Optional template ID for the messages (for Turbopuffer indexing)
Returns:
List of created Pydantic message models
@@ -371,11 +373,11 @@ class MessageManager:
if agent_id:
if strict_mode:
# wait for embedding to complete
await self._embed_messages_background(result, actor, agent_id, project_id)
await self._embed_messages_background(result, actor, agent_id, project_id, template_id)
else:
# fire and forget - run embedding in background
fire_and_forget(
self._embed_messages_background(result, actor, agent_id, project_id),
self._embed_messages_background(result, actor, agent_id, project_id, template_id),
task_name=f"embed_messages_for_agent_{agent_id}",
)
@@ -387,6 +389,7 @@ class MessageManager:
actor: PydanticUser,
agent_id: str,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
) -> None:
"""Background task to embed and store messages in Turbopuffer.
@@ -395,6 +398,7 @@ class MessageManager:
actor: User performing the action
agent_id: Agent ID for the messages
project_id: Optional project ID for the messages
template_id: Optional template ID for the messages
"""
try:
from letta.helpers.tpuf_client import TurbopufferClient
@@ -428,6 +432,7 @@ class MessageManager:
roles=roles,
created_ats=created_ats,
project_id=project_id,
template_id=template_id,
)
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
except Exception as e:
@@ -539,6 +544,7 @@ class MessageManager:
actor: PydanticUser,
strict_mode: bool = False,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
) -> PydanticMessage:
"""
Updates an existing record in the database with values from the provided record object.
@@ -550,6 +556,7 @@ class MessageManager:
actor: User performing the action
strict_mode: If True, wait for embedding update to complete; if False, run in background
project_id: Optional project ID for the message (for Turbopuffer indexing)
template_id: Optional template ID for the message (for Turbopuffer indexing)
"""
async with db_registry.async_session() as session:
# Fetch existing message from database
@@ -575,18 +582,18 @@ class MessageManager:
if text:
if strict_mode:
# wait for embedding update to complete
await self._update_message_embedding_background(pydantic_message, text, actor, project_id)
await self._update_message_embedding_background(pydantic_message, text, actor, project_id, template_id)
else:
# fire and forget - run embedding update in background
fire_and_forget(
self._update_message_embedding_background(pydantic_message, text, actor, project_id),
self._update_message_embedding_background(pydantic_message, text, actor, project_id, template_id),
task_name=f"update_message_embedding_{message_id}",
)
return pydantic_message
async def _update_message_embedding_background(
self, message: PydanticMessage, text: str, actor: PydanticUser, project_id: Optional[str] = None
self, message: PydanticMessage, text: str, actor: PydanticUser, project_id: Optional[str] = None, template_id: Optional[str] = None
) -> None:
"""Background task to update a message's embedding in Turbopuffer.
@@ -595,6 +602,7 @@ class MessageManager:
text: Extracted text content from the message
actor: User performing the action
project_id: Optional project ID for the message
template_id: Optional template ID for the message
"""
try:
from letta.helpers.tpuf_client import TurbopufferClient
@@ -614,6 +622,7 @@ class MessageManager:
roles=[message.role],
created_ats=[message.created_at],
project_id=project_id,
template_id=template_id,
)
logger.info(f"Successfully updated message {message.id} in Turbopuffer")
except Exception as e:
@@ -1097,6 +1106,8 @@ class MessageManager:
query_text: Optional[str] = None,
search_mode: str = "hybrid",
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
limit: int = 50,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
@@ -1110,6 +1121,8 @@ class MessageManager:
query_text: Text query (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid")
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
limit: Maximum number of results to return
start_date: Optional filter for messages created after this date
end_date: Optional filter for messages created on or before this date (inclusive)
@@ -1132,6 +1145,8 @@ class MessageManager:
search_mode=search_mode,
top_k=limit,
roles=roles,
project_id=project_id,
template_id=template_id,
start_date=start_date,
end_date=end_date,
)
@@ -1211,6 +1226,7 @@ class MessageManager:
search_mode: str = "hybrid",
roles: Optional[List[MessageRole]] = None,
project_id: Optional[str] = None,
template_id: Optional[str] = None,
limit: int = 50,
start_date: Optional[datetime] = None,
end_date: Optional[datetime] = None,
@@ -1224,6 +1240,7 @@ class MessageManager:
search_mode: "vector", "fts", or "hybrid" (default: "hybrid")
roles: Optional list of message roles to filter by
project_id: Optional project ID to filter messages by
template_id: Optional template ID to filter messages by
limit: Maximum number of results to return
start_date: Optional filter for messages created after this date
end_date: Optional filter for messages created on or before this date (inclusive)
@@ -1251,6 +1268,7 @@ class MessageManager:
top_k=limit,
roles=roles,
project_id=project_id,
template_id=template_id,
start_date=start_date,
end_date=end_date,
)

View File

@@ -196,6 +196,7 @@ class Summarizer:
pydantic_msgs=[summary_message_obj],
actor=self.actor,
project_id=agent_state.project_id,
template_id=agent_state.template_id,
)
updated_in_context_messages = all_in_context_messages[assistant_message_index:]

File diff suppressed because it is too large Load Diff