feat: Add project id to message schema [LET-4166] (#4433)
* Add project id * Propogate through update message by id async * Add project id testing
This commit is contained in:
@@ -175,7 +175,11 @@ class BaseAgent(ABC):
|
||||
|
||||
# [DB Call] Update Messages
|
||||
new_system_message = await self.message_manager.update_message_by_id_async(
|
||||
curr_system_message.id, message_update=MessageUpdate(content=new_system_message_str), actor=self.actor
|
||||
curr_system_message.id,
|
||||
message_update=MessageUpdate(content=new_system_message_str),
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
return [new_system_message] + in_context_messages[1:]
|
||||
|
||||
|
||||
@@ -118,6 +118,7 @@ async def _prepare_in_context_messages_async(
|
||||
create_input_messages(input_messages=input_messages, agent_id=agent_state.id, timezone=agent_state.timezone, actor=actor),
|
||||
actor=actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
|
||||
return current_in_context_messages, new_in_context_messages
|
||||
|
||||
@@ -495,7 +495,10 @@ class LettaAgent(BaseAgent):
|
||||
message.is_err = True
|
||||
message.step_id = effective_step_id
|
||||
await self.message_manager.create_many_messages_async(
|
||||
initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
@@ -823,7 +826,10 @@ class LettaAgent(BaseAgent):
|
||||
message.is_err = True
|
||||
message.step_id = effective_step_id
|
||||
await self.message_manager.create_many_messages_async(
|
||||
initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
@@ -1259,7 +1265,10 @@ class LettaAgent(BaseAgent):
|
||||
message.is_err = True
|
||||
message.step_id = effective_step_id
|
||||
await self.message_manager.create_many_messages_async(
|
||||
initial_messages, actor=self.actor, embedding_config=agent_state.embedding_config
|
||||
initial_messages,
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
elif step_progression <= StepProgression.LOGGED_TRACE:
|
||||
if stop_reason is None:
|
||||
@@ -1667,7 +1676,7 @@ class LettaAgent(BaseAgent):
|
||||
)
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config
|
||||
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
|
||||
)
|
||||
return persisted_messages, continue_stepping, stop_reason
|
||||
|
||||
@@ -1779,7 +1788,7 @@ class LettaAgent(BaseAgent):
|
||||
messages_to_persist = (initial_messages or []) + tool_call_messages
|
||||
|
||||
persisted_messages = await self.message_manager.create_many_messages_async(
|
||||
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config
|
||||
messages_to_persist, actor=self.actor, embedding_config=agent_state.embedding_config, project_id=agent_state.project_id
|
||||
)
|
||||
|
||||
if run_id:
|
||||
|
||||
@@ -182,6 +182,7 @@ class TurbopufferClient:
|
||||
organization_id: str,
|
||||
roles: List[MessageRole],
|
||||
created_ats: List[datetime],
|
||||
project_id: Optional[str] = None,
|
||||
) -> bool:
|
||||
"""Insert messages into Turbopuffer.
|
||||
|
||||
@@ -193,6 +194,7 @@ class TurbopufferClient:
|
||||
organization_id: Organization ID for the messages
|
||||
roles: List of message roles corresponding to each message
|
||||
created_ats: List of creation timestamps for each message
|
||||
project_id: Optional project ID for all messages
|
||||
|
||||
Returns:
|
||||
True if successful
|
||||
@@ -221,6 +223,7 @@ class TurbopufferClient:
|
||||
agent_ids = []
|
||||
message_roles = []
|
||||
created_at_timestamps = []
|
||||
project_ids = []
|
||||
|
||||
for idx, (text, embedding, role, created_at) in enumerate(zip(message_texts, embeddings, roles, created_ats)):
|
||||
message_id = message_ids[idx]
|
||||
@@ -241,6 +244,7 @@ class TurbopufferClient:
|
||||
agent_ids.append(agent_id)
|
||||
message_roles.append(role.value)
|
||||
created_at_timestamps.append(timestamp)
|
||||
project_ids.append(project_id)
|
||||
|
||||
# build column-based upsert data
|
||||
upsert_columns = {
|
||||
@@ -253,6 +257,10 @@ class TurbopufferClient:
|
||||
"created_at": created_at_timestamps,
|
||||
}
|
||||
|
||||
# only include project_id if it's provided
|
||||
if project_id is not None:
|
||||
upsert_columns["project_id"] = project_ids
|
||||
|
||||
try:
|
||||
# Use AsyncTurbopuffer as a context manager for proper resource cleanup
|
||||
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
|
||||
@@ -520,6 +528,7 @@ class TurbopufferClient:
|
||||
search_mode: str = "vector", # "vector", "fts", "hybrid", "timestamp"
|
||||
top_k: int = 10,
|
||||
roles: Optional[List[MessageRole]] = None,
|
||||
project_id: Optional[str] = None,
|
||||
vector_weight: float = 0.5,
|
||||
fts_weight: float = 0.5,
|
||||
start_date: Optional[datetime] = None,
|
||||
@@ -535,6 +544,7 @@ class TurbopufferClient:
|
||||
search_mode: Search mode - "vector", "fts", "hybrid", or "timestamp" (default: "vector")
|
||||
top_k: Number of results to return
|
||||
roles: Optional list of message roles to filter by
|
||||
project_id: Optional project ID to filter messages by
|
||||
vector_weight: Weight for vector search results in hybrid mode (default: 0.5)
|
||||
fts_weight: Weight for FTS results in hybrid mode (default: 0.5)
|
||||
start_date: Optional datetime to filter messages created after this date
|
||||
@@ -579,10 +589,17 @@ class TurbopufferClient:
|
||||
end_date = end_date + timedelta(days=1) - timedelta(microseconds=1)
|
||||
date_filters.append(("created_at", "Lte", end_date))
|
||||
|
||||
# build project_id filter if provided
|
||||
project_filter = None
|
||||
if project_id:
|
||||
project_filter = ("project_id", "Eq", project_id)
|
||||
|
||||
# combine all filters
|
||||
all_filters = [agent_filter] # always include agent_id filter
|
||||
if role_filter:
|
||||
all_filters.append(role_filter)
|
||||
if project_filter:
|
||||
all_filters.append(project_filter)
|
||||
if date_filters:
|
||||
all_filters.extend(date_filters)
|
||||
|
||||
|
||||
@@ -720,7 +720,7 @@ class AgentManager:
|
||||
# Only create messages if we initialized with messages
|
||||
if not _init_with_no_messages:
|
||||
await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config
|
||||
pydantic_msgs=init_messages, actor=actor, embedding_config=result.embedding_config, project_id=result.project_id
|
||||
)
|
||||
return result
|
||||
|
||||
@@ -1834,6 +1834,8 @@ class AgentManager:
|
||||
message_id=curr_system_message.id,
|
||||
message_update=MessageUpdate(**temp_message.model_dump()),
|
||||
actor=actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
else:
|
||||
curr_system_message = temp_message
|
||||
@@ -1887,7 +1889,9 @@ class AgentManager:
|
||||
self, messages: List[PydanticMessage], agent_id: str, actor: PydanticUser
|
||||
) -> PydanticAgentState:
|
||||
agent = await self.get_agent_by_id_async(agent_id=agent_id, actor=actor)
|
||||
messages = await self.message_manager.create_many_messages_async(messages, actor=actor, embedding_config=agent.embedding_config)
|
||||
messages = await self.message_manager.create_many_messages_async(
|
||||
messages, actor=actor, embedding_config=agent.embedding_config, project_id=agent.project_id
|
||||
)
|
||||
message_ids = agent.message_ids or []
|
||||
message_ids += [m.id for m in messages]
|
||||
return await self.set_in_context_messages_async(agent_id=agent_id, message_ids=message_ids, actor=actor)
|
||||
|
||||
@@ -675,7 +675,12 @@ class AgentSerializationManager:
|
||||
# Map file ID to the generated database ID immediately
|
||||
message_file_to_db_ids[message_schema.id] = message_obj.id
|
||||
|
||||
created_messages = await self.message_manager.create_many_messages_async(pydantic_msgs=messages, actor=actor)
|
||||
created_messages = await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=messages,
|
||||
actor=actor,
|
||||
embedding_config=created_agent.embedding_config,
|
||||
project_id=created_agent.project_id,
|
||||
)
|
||||
imported_count += len(created_messages)
|
||||
|
||||
# Remap in_context_message_ids from file IDs to database IDs
|
||||
|
||||
@@ -316,6 +316,7 @@ class MessageManager:
|
||||
actor: PydanticUser,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
strict_mode: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
) -> List[PydanticMessage]:
|
||||
"""
|
||||
Create multiple messages in a single database transaction asynchronously.
|
||||
@@ -325,6 +326,7 @@ class MessageManager:
|
||||
actor: User performing the action
|
||||
embedding_config: Optional embedding configuration to enable message embedding in Turbopuffer
|
||||
strict_mode: If True, wait for embedding to complete; if False, run in background
|
||||
project_id: Optional project ID for the messages (for Turbopuffer indexing)
|
||||
|
||||
Returns:
|
||||
List of created Pydantic message models
|
||||
@@ -372,18 +374,23 @@ class MessageManager:
|
||||
if agent_id:
|
||||
if strict_mode:
|
||||
# wait for embedding to complete
|
||||
await self._embed_messages_background(result, embedding_config, actor, agent_id)
|
||||
await self._embed_messages_background(result, embedding_config, actor, agent_id, project_id)
|
||||
else:
|
||||
# fire and forget - run embedding in background
|
||||
fire_and_forget(
|
||||
self._embed_messages_background(result, embedding_config, actor, agent_id),
|
||||
self._embed_messages_background(result, embedding_config, actor, agent_id, project_id),
|
||||
task_name=f"embed_messages_for_agent_{agent_id}",
|
||||
)
|
||||
|
||||
return result
|
||||
|
||||
async def _embed_messages_background(
|
||||
self, messages: List[PydanticMessage], embedding_config: EmbeddingConfig, actor: PydanticUser, agent_id: str
|
||||
self,
|
||||
messages: List[PydanticMessage],
|
||||
embedding_config: EmbeddingConfig,
|
||||
actor: PydanticUser,
|
||||
agent_id: str,
|
||||
project_id: Optional[str] = None,
|
||||
) -> None:
|
||||
"""Background task to embed and store messages in Turbopuffer.
|
||||
|
||||
@@ -392,6 +399,7 @@ class MessageManager:
|
||||
embedding_config: Embedding configuration
|
||||
actor: User performing the action
|
||||
agent_id: Agent ID for the messages
|
||||
project_id: Optional project ID for the messages
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
@@ -432,6 +440,7 @@ class MessageManager:
|
||||
organization_id=actor.organization_id,
|
||||
roles=roles,
|
||||
created_ats=created_ats,
|
||||
project_id=project_id,
|
||||
)
|
||||
logger.info(f"Successfully embedded {len(message_texts)} messages for agent {agent_id}")
|
||||
except Exception as e:
|
||||
@@ -543,6 +552,7 @@ class MessageManager:
|
||||
actor: PydanticUser,
|
||||
embedding_config: Optional[EmbeddingConfig] = None,
|
||||
strict_mode: bool = False,
|
||||
project_id: Optional[str] = None,
|
||||
) -> PydanticMessage:
|
||||
"""
|
||||
Updates an existing record in the database with values from the provided record object.
|
||||
@@ -554,6 +564,7 @@ class MessageManager:
|
||||
actor: User performing the action
|
||||
embedding_config: Optional embedding configuration for Turbopuffer
|
||||
strict_mode: If True, wait for embedding update to complete; if False, run in background
|
||||
project_id: Optional project ID for the message (for Turbopuffer indexing)
|
||||
"""
|
||||
async with db_registry.async_session() as session:
|
||||
# Fetch existing message from database
|
||||
@@ -579,18 +590,18 @@ class MessageManager:
|
||||
if text:
|
||||
if strict_mode:
|
||||
# wait for embedding update to complete
|
||||
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor)
|
||||
await self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id)
|
||||
else:
|
||||
# fire and forget - run embedding update in background
|
||||
fire_and_forget(
|
||||
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor),
|
||||
self._update_message_embedding_background(pydantic_message, text, embedding_config, actor, project_id),
|
||||
task_name=f"update_message_embedding_{message_id}",
|
||||
)
|
||||
|
||||
return pydantic_message
|
||||
|
||||
async def _update_message_embedding_background(
|
||||
self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser
|
||||
self, message: PydanticMessage, text: str, embedding_config: EmbeddingConfig, actor: PydanticUser, project_id: Optional[str] = None
|
||||
) -> None:
|
||||
"""Background task to update a message's embedding in Turbopuffer.
|
||||
|
||||
@@ -599,6 +610,7 @@ class MessageManager:
|
||||
text: Extracted text content from the message
|
||||
embedding_config: Embedding configuration
|
||||
actor: User performing the action
|
||||
project_id: Optional project ID for the message
|
||||
"""
|
||||
try:
|
||||
from letta.helpers.tpuf_client import TurbopufferClient
|
||||
@@ -625,6 +637,7 @@ class MessageManager:
|
||||
organization_id=actor.organization_id,
|
||||
roles=[message.role],
|
||||
created_ats=[message.created_at],
|
||||
project_id=project_id,
|
||||
)
|
||||
logger.info(f"Successfully updated message {message.id} in Turbopuffer")
|
||||
except Exception as e:
|
||||
|
||||
@@ -195,6 +195,8 @@ class Summarizer:
|
||||
await self.message_manager.create_many_messages_async(
|
||||
pydantic_msgs=[summary_message_obj],
|
||||
actor=self.actor,
|
||||
embedding_config=agent_state.embedding_config,
|
||||
project_id=agent_state.project_id,
|
||||
)
|
||||
|
||||
updated_in_context_messages = all_in_context_messages[assistant_message_index:]
|
||||
|
||||
@@ -2097,3 +2097,118 @@ class TestNamespaceTracking:
|
||||
|
||||
finally:
|
||||
settings.environment = original_env
|
||||
|
||||
@pytest.mark.asyncio
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
async def test_message_project_id_filtering(self, server, sarah_agent, default_user, enable_turbopuffer, enable_message_embedding):
|
||||
"""Test that project_id filtering works correctly in query_messages"""
|
||||
from letta.schemas.letta_message_content import TextContent
|
||||
|
||||
# Create two project IDs
|
||||
project_a_id = str(uuid.uuid4())
|
||||
project_b_id = str(uuid.uuid4())
|
||||
|
||||
# Create messages with different project IDs
|
||||
message_a = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message for project A about Python")],
|
||||
)
|
||||
|
||||
message_b = PydanticMessage(
|
||||
agent_id=sarah_agent.id,
|
||||
role=MessageRole.user,
|
||||
content=[TextContent(text="Message for project B about JavaScript")],
|
||||
)
|
||||
|
||||
# Insert messages with their respective project IDs
|
||||
tpuf_client = TurbopufferClient()
|
||||
|
||||
# Generate embeddings
|
||||
from letta.llm_api.llm_client import LLMClient
|
||||
|
||||
embedding_client = LLMClient.create(
|
||||
provider_type=sarah_agent.embedding_config.embedding_endpoint_type,
|
||||
actor=default_user,
|
||||
)
|
||||
embeddings = await embedding_client.request_embeddings(
|
||||
[message_a.content[0].text, message_b.content[0].text], sarah_agent.embedding_config
|
||||
)
|
||||
|
||||
# Insert message A with project_a_id
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_a.content[0].text],
|
||||
embeddings=[embeddings[0]],
|
||||
message_ids=[message_a.id],
|
||||
organization_id=default_user.organization_id,
|
||||
roles=[message_a.role],
|
||||
created_ats=[message_a.created_at],
|
||||
project_id=project_a_id,
|
||||
)
|
||||
|
||||
# Insert message B with project_b_id
|
||||
await tpuf_client.insert_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
message_texts=[message_b.content[0].text],
|
||||
embeddings=[embeddings[1]],
|
||||
message_ids=[message_b.id],
|
||||
organization_id=default_user.organization_id,
|
||||
roles=[message_b.role],
|
||||
created_ats=[message_b.created_at],
|
||||
project_id=project_b_id,
|
||||
)
|
||||
|
||||
# Poll for message A with project_a_id filter
|
||||
max_retries = 10
|
||||
for i in range(max_retries):
|
||||
results_a = await tpuf_client.query_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp", # Simple timestamp retrieval
|
||||
top_k=10,
|
||||
project_id=project_a_id,
|
||||
)
|
||||
if len(results_a) == 1 and results_a[0][0]["id"] == message_a.id:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
pytest.fail(f"Message A not found after {max_retries} retries")
|
||||
|
||||
assert "Python" in results_a[0][0]["text"]
|
||||
|
||||
# Poll for message B with project_b_id filter
|
||||
for i in range(max_retries):
|
||||
results_b = await tpuf_client.query_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
project_id=project_b_id,
|
||||
)
|
||||
if len(results_b) == 1 and results_b[0][0]["id"] == message_b.id:
|
||||
break
|
||||
await asyncio.sleep(0.5)
|
||||
else:
|
||||
pytest.fail(f"Message B not found after {max_retries} retries")
|
||||
|
||||
assert "JavaScript" in results_b[0][0]["text"]
|
||||
|
||||
# Query without project filter - should find both
|
||||
results_all = await tpuf_client.query_messages(
|
||||
agent_id=sarah_agent.id,
|
||||
organization_id=default_user.organization_id,
|
||||
search_mode="timestamp",
|
||||
top_k=10,
|
||||
project_id=None, # No filter
|
||||
)
|
||||
|
||||
assert len(results_all) >= 2 # May have other messages from setup
|
||||
message_ids = [r[0]["id"] for r in results_all]
|
||||
assert message_a.id in message_ids
|
||||
assert message_b.id in message_ids
|
||||
|
||||
# Clean up
|
||||
await tpuf_client.delete_messages(
|
||||
agent_id=sarah_agent.id, organization_id=default_user.organization_id, message_ids=[message_a.id, message_b.id]
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user