feat: Add project_id to message filtering/inserting to Turbopuffer [LET-4252] (#4466)
* Add project_id * Fern autogen
This commit is contained in:
@@ -498,6 +498,7 @@ class LettaAgent(BaseAgent):
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
project_id=agent_state.project_id,
|
||||
template_id=agent_state.template_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
@@ -828,6 +829,7 @@ class LettaAgent(BaseAgent):
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
project_id=agent_state.project_id,
|
||||
template_id=agent_state.template_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
@@ -1269,6 +1271,7 @@ class LettaAgent(BaseAgent):
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
project_id=agent_state.project_id,
|
||||
template_id=agent_state.template_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
@@ -1676,7 +1679,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
|
||||
messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
|
||||
)
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
@@ -1787,7 +1790,7 @@ class LettaAgent(BaseAgent):
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, project_id=agent_state.project_id
|
||||
messages_to_persist, actor=self.actor, project_id=agent_state.project_id, template_id=agent_state.template_id
|
||||
)
|
||||
|
||||
if run_id:
|
||||
|
||||
@@ -220,6 +220,7 @@ class TurbopufferClient:
|
||||
roles: List[MessageRole],
|
||||
created_ats: List[datetime],
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Insert messages into Turbopuffer.
|
||||
|
||||
@@ -232,6 +233,7 @@ class TurbopufferClient:
|
||||
roles: List of message roles corresponding to each message
|
||||
created_ats: List of creation timestamps for each message
|
||||
project_id: Optional project ID for all messages
|
||||
template_id: Optional template ID for all messages
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
@@ -262,6 +264,7 @@ class TurbopufferClient:
|
||||
message_roles = []
|
||||
created_at_timestamps = []
|
||||
project_ids = []
|
||||
template_ids = []
|
||||
|
||||
for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
|
||||
message_id = message_ids[idx]
|
||||
@@ -283,6 +286,7 @@ class TurbopufferClient:
|
||||
message_roles.append(role.value)
|
||||
created_at_timestamps.append(timestamp)
|
||||
project_ids.append(project_id)
|
||||
template_ids.append(template_id)
|
||||
|
||||
# build column-based upsert data
|
||||
upsert_columns = {
|
||||
@@ -299,6 +303,10 @@ class TurbopufferClient:
|
||||
if project_id is not None:
|
||||
upsert_columns["project_id"] = project_ids
|
||||
|
||||
# only include template_id if it's provided
|
||||
if template_id is not None:
|
||||
upsert_columns["template_id"] = template_ids
|
||||
|
||||
try:
|
||||
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
@@ -573,6 +581,7 @@ class TurbopufferClient:
|
||||
top_k: int = 10,
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -589,6 +598,7 @@ class TurbopufferClient:
|
||||
top_k: Number of results to return
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
@@ -644,12 +654,19 @@ class TurbopufferClient:
|
||||
if project_id:
|
||||
project_filter = ("project_id", "Eq", project_id)
|
||||
|
||||
# build template_id filter if provided
|
||||
template_filter = None
|
||||
if template_id:
|
||||
template_filter = ("template_id", "Eq", template_id)
|
||||
|
||||
# combine all filters
|
||||
all_filters = [agent_filter] # always include agent_id filter
|
||||
if role_filter:
|
||||
all_filters.append(role_filter)
|
||||
if project_filter:
|
||||
all_filters.append(project_filter)
|
||||
if template_filter:
|
||||
all_filters.append(template_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
@@ -717,6 +734,7 @@ class TurbopufferClient:
|
||||
top_k: int = 10,
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -732,6 +750,7 @@ class TurbopufferClient:
|
||||
top_k: Number of results to return
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
@@ -766,6 +785,10 @@ class TurbopufferClient:
|
||||
if project_id:
|
||||
all_filters.append(("project_id", "Eq", project_id))
|
||||
|
||||
# template filter
|
||||
if template_id:
|
||||
all_filters.append(("template_id", "Eq", template_id))
|
||||
|
||||
# date filters
|
||||
if start_date:
|
||||
all_filters.append(("created_at", "Gte", start_date))
|
||||
|
||||
@@ -1196,6 +1196,7 @@ class MessageSearchRequest(BaseModel):
|
||||
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
|
||||
roles: Optional[List[MessageRole]] = Field(None, description="Filter messages by role")
|
||||
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
|
||||
template_id: Optional[str] = Field(None, description="Filter messages by template ID")
|
||||
limit: int = Field(50, description="Maximum number of results to return", ge=1, le=100)
|
||||
start_date: Optional[datetime] = Field(None, description="Filter messages created after this date")
|
||||
end_date: Optional[datetime] = Field(None, description="Filter messages created on or before this date")
|
||||
|
||||
@@ -1524,7 +1524,7 @@ async def search_messages(
|
||||
actor_id: str | None = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Search messages across the entire organization with optional project filtering. Returns messages with FTS/vector ranks and total RRF score.
|
||||
Search messages across the entire organization with optional project and template filtering. Returns messages with FTS/vector ranks and total RRF score.
|
||||
|
||||
This is a cloud-only feature.
|
||||
"""
|
||||
@@ -1543,6 +1543,7 @@ async def search_messages(
|
||||
search_mode=request.search_mode,
|
||||
roles=request.roles,
|
||||
project_id=request.project_id,
|
||||
template_id=request.template_id,
|
||||
limit=request.limit,
|
||||
start_date=request.start_date,
|
||||
end_date=request.end_date,
|
||||
|
||||
@@ -719,7 +719,9 @@ class AgentManager:
|
||||
|
||||
# Only create messages if we initialized with messages
|
||||
if not _init_with_no_messages:
|
||||
await self.message_manager.create_many_messages_async(pydantic_msgs=init_messages, actor=actor, project_id=result.project_id)
|
||||
await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=init_messages, actor=actor, project_id=result.project_id, template_id=result.template_id
|
||||
)
|
||||
return result
|
||||
|
||||
@enforce_types
|
||||
@@ -1886,7 +1888,9 @@ class AgentManager:
|
||||
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
|
||||
) -> PydanticAgentState:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
messages = await self.message_manager.create_many_messages_async(messages, actor=actor, project_id=agent.project_id)
|
||||
messages = await self.message_manager.create_many_messages_async(
|
||||
messages, actor=actor, project_id=agent.project_id, template_id=agent.template_id
|
||||
)
|
||||
message_ids = agent.message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@@ -679,6 +679,7 @@ class AgentSerializationManager:
|
||||
pydantic_msgs=messages,
|
||||
actor=actor,
|
||||
project_id=created_agent.project_id,
|
||||
template_id=created_agent.template_id,
|
||||
)
|
||||
imported_count += len(created_messages)
|
||||
|
||||
|
||||
@@ -315,6 +315,7 @@ class MessageManager:
|
||||
actor: PydanticUser,
|
||||
strict_mode: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Create multiple messages in a single database transaction asynchronously.
|
||||
@@ -324,6 +325,7 @@ class MessageManager:
|
||||
actor: User performing the action
|
||||
strict_mode: If True, wait for embedding to complete; if False, run in background
|
||||
project_id: Optional project ID for the messages (for Turbopuffer indexing)
|
||||
template_id: Optional template ID for the messages (for Turbopuffer indexing)
|
||||
|
||||
Returns:
|
||||
List of created Pydantic message models
|
||||
@@ -371,11 +373,11 @@ class MessageManager:
|
||||
if agent_id:
|
||||
if strict_mode:
|
||||
# wait for embedding to complete
|
||||
await self._embed_messages_background(result, actor, agent_id, project_id)
|
||||
await self._embed_messages_background(result, actor, agent_id, project_id, template_id)
|
||||
else:
|
||||
# fire and forget - run embedding in background
|
||||
fire_and_forget(
|
||||
self._embed_messages_background(result, actor, agent_id, project_id),
|
||||
self._embed_messages_background(result, actor, agent_id, project_id, template_id),
|
||||
task_name=f"embed_messages_for_agent_{agent_id}",
|
||||
)
|
||||
|
||||
@@ -387,6 +389,7 @@ class MessageManager:
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Background task to embed and store messages in Turbopuffer.
|
||||
|
||||
@@ -395,6 +398,7 @@ class MessageManager:
|
||||
actor: User performing the action
|
||||
agent_id: Agent ID for the messages
|
||||
project_id: Optional project ID for the messages
|
||||
template_id: Optional template ID for the messages
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
@@ -428,6 +432,7 @@ class MessageManager:
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
)
|
||||
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
|
||||
except Exception as e:
|
||||
@@ -539,6 +544,7 @@ class MessageManager:
|
||||
actor: PydanticUser,
|
||||
strict_mode: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
) -> PydanticMessage:
|
||||
"""
|
||||
Updates an existing record in the database with values from the provided record object.
|
||||
@@ -550,6 +556,7 @@ class MessageManager:
|
||||
actor: User performing the action
|
||||
strict_mode: If True, wait for embedding update to complete; if False, run in background
|
||||
project_id: Optional project ID for the message (for Turbopuffer indexing)
|
||||
template_id: Optional template ID for the message (for Turbopuffer indexing)
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch existing message from database
|
||||
@@ -575,18 +582,18 @@ class MessageManager:
|
||||
if text:
|
||||
if strict_mode:
|
||||
# wait for embedding update to complete
|
||||
await self._update_message_embedding_background(pydantic_message, text, actor, project_id)
|
||||
await self._update_message_embedding_background(pydantic_message, text, actor, project_id, template_id)
|
||||
else:
|
||||
# fire and forget - run embedding update in background
|
||||
fire_and_forget(
|
||||
self._update_message_embedding_background(pydantic_message, text, actor, project_id),
|
||||
self._update_message_embedding_background(pydantic_message, text, actor, project_id, template_id),
|
||||
task_name=f"update_message_embedding_{message_id}",
|
||||
)
|
||||
|
||||
return pydantic_message
|
||||
|
||||
async def _update_message_embedding_background(
|
||||
self, message: PydanticMessage, text: str, actor: PydanticUser, project_id: Optional[str] = None
|
||||
self, message: PydanticMessage, text: str, actor: PydanticUser, project_id: Optional[str] = None, template_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""Background task to update a message's embedding in Turbopuffer.
|
||||
|
||||
@@ -595,6 +602,7 @@ class MessageManager:
|
||||
text: Extracted text content from the message
|
||||
actor: User performing the action
|
||||
project_id: Optional project ID for the message
|
||||
template_id: Optional template ID for the message
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
@@ -614,6 +622,7 @@ class MessageManager:
|
||||
roles=[message.role],
|
||||
created_ats=[message.created_at],
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
)
|
||||
logger.info(f"Successfully updated message {message.id} in Turbopuffer")
|
||||
except Exception as e:
|
||||
@@ -1097,6 +1106,8 @@ class MessageManager:
|
||||
query_text: Optional[str] = None,
|
||||
search_mode: str = "hybrid",
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
@@ -1110,6 +1121,8 @@ class MessageManager:
|
||||
query_text: Text query (used for embedding in vector/hybrid modes, and FTS in fts/hybrid modes)
|
||||
search_mode: "vector", "fts", "hybrid", or "timestamp" (default: "hybrid")
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional filter for messages created after this date
|
||||
end_date: Optional filter for messages created on or before this date (inclusive)
|
||||
@@ -1132,6 +1145,8 @@ class MessageManager:
|
||||
search_mode=search_mode,
|
||||
top_k=limit,
|
||||
roles=roles,
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
@@ -1211,6 +1226,7 @@ class MessageManager:
|
||||
search_mode: str = "hybrid",
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
template_id: Optional[str] = None,
|
||||
limit: int = 50,
|
||||
start_date: Optional[datetime] = None,
|
||||
end_date: Optional[datetime] = None,
|
||||
@@ -1224,6 +1240,7 @@ class MessageManager:
|
||||
search_mode: "vector", "fts", or "hybrid" (default: "hybrid")
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
template_id: Optional template ID to filter messages by
|
||||
limit: Maximum number of results to return
|
||||
start_date: Optional filter for messages created after this date
|
||||
end_date: Optional filter for messages created on or before this date (inclusive)
|
||||
@@ -1251,6 +1268,7 @@ class MessageManager:
|
||||
top_k=limit,
|
||||
roles=roles,
|
||||
project_id=project_id,
|
||||
template_id=template_id,
|
||||
start_date=start_date,
|
||||
end_date=end_date,
|
||||
)
|
||||
|
||||
@@ -196,6 +196,7 @@ class Summarizer:
|
||||
pydantic_msgs=[summary_message_obj],
|
||||
actor=self.actor,
|
||||
project_id=agent_state.project_id,
|
||||
template_id=agent_state.template_id,
|
||||
)
|
||||
|
||||
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
|
||||
|
||||
File diff suppressed because it is too large
Load Diff
Reference in New Issue
Block a user