feat: Write tests for search messages [LET-4212] (#4447)

* Adjust naming

* Add testing and improve message search

* Adjust comments

* Change query text to query

* Fern autogen
This commit is contained in:
Matthew Zhou
2025-09-05 17:52:13 -07:00
committed by GitHub
parent cb7296c81d
commit af302e8ac8
10 changed files with 2678 additions and 2659 deletions

1
.gitignore vendored
View File

@@ -6,6 +6,7 @@ openapi_letta.json
openapi_openai.json
CLAUDE.md
AGENTS.md
### Eclipse ###
.metadata

View File

@@ -75,17 +75,22 @@ class TurbopufferClient:
return await self.archive_manager.get_or_set_vector_db_namespace_async(archive_id)
@trace_method
async def _get_message_namespace_name(self, agent_id: str, organization_id: str) -> str:
async def _get_message_namespace_name(self, organization_id: str) -> str:
"""Get namespace name for messages (org-scoped).
Args:
agent_id: Agent ID (stored for future sharding)
organization_id: Organization ID for namespace generation
Returns:
The org-scoped namespace name for messages
"""
return await self.agent_manager.get_or_set_vector_db_namespace_async(agent_id, organization_id)
environment = settings.environment
if environment:
namespace_name = f"messages_{organization_id}_{environment.lower()}"
else:
namespace_name = f"messages_{organization_id}"
return namespace_name
@trace_method
async def insert_archival_memories(
@@ -236,7 +241,7 @@ class TurbopufferClient:
# generate embeddings using the default config
embeddings = await self._generate_embeddings(message_texts, actor)
namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
namespace_name = await self._get_message_namespace_name(organization_id)
# validation checks
if not message_ids:
@@ -606,7 +611,7 @@ class TurbopufferClient:
# Fallback to retrieving most recent messages when no search query is provided
search_mode = "timestamp"
namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
namespace_name = await self._get_message_namespace_name(organization_id)
# build agent_id filter
agent_filter = ("agent_id", "Eq", agent_id)
@@ -744,7 +749,7 @@ class TurbopufferClient:
embeddings = await self._generate_embeddings([query_text], actor)
query_embedding = embeddings[0]
# namespace is org-scoped
namespace_name = f"letta_messages_{organization_id}"
namespace_name = await self._get_message_namespace_name(organization_id)
# build filters
all_filters = []
@@ -1054,7 +1059,7 @@ class TurbopufferClient:
if not message_ids:
return True
namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
namespace_name = await self._get_message_namespace_name(organization_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:
@@ -1072,7 +1077,7 @@ class TurbopufferClient:
"""Delete all messages for an agent from Turbopuffer."""
from turbopuffer import AsyncTurbopuffer
namespace_name = await self._get_message_namespace_name(agent_id, organization_id)
namespace_name = await self._get_message_namespace_name(organization_id)
try:
async with AsyncTurbopuffer(api_key=self.api_key, region=self.region) as client:

View File

@@ -1192,7 +1192,7 @@ class ToolReturn(BaseModel):
class MessageSearchRequest(BaseModel):
"""Request model for searching messages across the organization"""
query_text: Optional[str] = Field(None, description="Text query for full-text search")
query: Optional[str] = Field(None, description="Text query for full-text search")
search_mode: Literal["vector", "fts", "hybrid"] = Field("hybrid", description="Search mode to use")
roles: Optional[List[MessageRole]] = Field(None, description="Filter messages by role")
project_id: Optional[str] = Field(None, description="Filter messages by project ID")
@@ -1204,9 +1204,8 @@ class MessageSearchRequest(BaseModel):
class MessageSearchResult(BaseModel):
"""Result from a message search operation with scoring details."""
message: Message = Field(..., description="The message content and metadata")
fts_score: Optional[float] = Field(None, description="Full-text search (BM25) score if FTS was used")
embedded_text: str = Field(..., description="The embedded content (LLM-friendly)")
message: Message = Field(..., description="The raw message object")
fts_rank: Optional[int] = Field(None, description="Full-text search rank position if FTS was used")
vector_score: Optional[float] = Field(None, description="Vector similarity score if vector search was used")
vector_rank: Optional[int] = Field(None, description="Vector search rank position if vector search was used")
rrf_score: float = Field(..., description="Reciprocal Rank Fusion combined score")

View File

@@ -1505,10 +1505,9 @@ async def search_messages(
actor_id: str | None = Header(None, alias="user_id"),
):
"""
Search messages across the entire organization with optional project filtering.
Returns messages with FTS/vector ranks and total RRF score.
Search messages across the entire organization with optional project filtering. Returns messages with FTS/vector ranks and total RRF score.
Requires message embedding and Turbopuffer to be enabled.
This is a cloud-only feature.
"""
actor = await server.user_manager.get_actor_or_default_async(actor_id=actor_id)
@@ -1521,7 +1520,7 @@ async def search_messages(
try:
results = await server.message_manager.search_messages_org_async(
actor=actor,
query_text=request.query_text,
query_text=request.query,
search_mode=request.search_mode,
roles=request.roles,
project_id=request.project_id,

View File

@@ -3716,45 +3716,3 @@ class AgentManager:
num_archival_memories=num_archival_memories,
num_messages=num_messages,
)
async def get_or_set_vector_db_namespace_async(
    self,
    agent_id: str,
    organization_id: str,
) -> str:
    """Return the org-scoped vector DB namespace for an agent, persisting it on first use.

    Args:
        agent_id: Agent ID whose stored namespace is checked/updated
        organization_id: Organization ID used to derive the namespace name

    Returns:
        The org-scoped namespace name
    """
    from sqlalchemy import update

    from letta.settings import settings

    async with db_registry.async_session() as session:
        # Reuse a previously persisted namespace when one exists.
        existing = (
            await session.execute(select(AgentModel._vector_db_namespace).where(AgentModel.id == agent_id))
        ).fetchone()
        if existing and existing[0]:
            return existing[0]
        # TODO: In the future, we might use agent_id for sharding the namespace
        # For now, all messages in an org share the same namespace
        env = settings.environment
        # Environment suffix (lowercased) keeps e.g. staging/prod namespaces separate.
        namespace_name = f"messages_{organization_id}_{env.lower()}" if env else f"messages_{organization_id}"
        # Persist on the agent row (keeps agent-level tracking for future sharding).
        await session.execute(update(AgentModel).where(AgentModel.id == agent_id).values(_vector_db_namespace=namespace_name))
        await session.commit()
        return namespace_name

View File

@@ -1260,19 +1260,23 @@ class MessageManager:
return []
# create message mapping
message_ids = [msg_dict["id"] for msg_dict, _, _ in results]
message_ids = []
embedded_text = {}
for msg_dict, _, _ in results:
message_ids.append(msg_dict["id"])
embedded_text[msg_dict["id"]] = msg_dict["text"]
messages = await self.get_messages_by_ids_async(message_ids=message_ids, actor=actor)
message_mapping = {message.id: message for message in messages}
# create search results using list comprehension
return [
MessageSearchResult(
message=message_mapping.get(msg_dict["id"]),
fts_score=metadata.get("fts_score"),
embedded_text=embedded_text[msg_id],
message=message_mapping[msg_id],
fts_rank=metadata.get("fts_rank"),
vector_score=metadata.get("vector_score"),
vector_rank=metadata.get("vector_rank"),
rrf_score=rrf_score,
)
for msg_dict, rrf_score, metadata in results
if (msg_id := msg_dict.get("id")) in message_mapping
]

View File

@@ -45,7 +45,7 @@ dependencies = [
"llama-index>=0.12.2",
"llama-index-embeddings-openai>=0.3.1",
"anthropic>=0.49.0",
"letta-client==0.1.314",
"letta-client>=0.1.314",
"openai>=1.99.9",
"opentelemetry-api==1.30.0",
"opentelemetry-sdk==1.30.0",

View File

@@ -1937,7 +1937,7 @@ class TestNamespaceTracking:
async def test_agent_namespace_tracking(self, server, default_user, sarah_agent, enable_message_embedding):
"""Test that agent message namespaces are properly tracked in database"""
# Get namespace - should be generated and stored
namespace = await server.agent_manager.get_or_set_vector_db_namespace_async(sarah_agent.id, default_user.organization_id)
namespace = await server.agent_manager.get_or_set_vector_db_namespace_async(default_user.organization_id)
# Should have messages_org_ prefix and environment suffix
expected_prefix = "messages_"
@@ -1947,7 +1947,7 @@ class TestNamespaceTracking:
assert settings.environment.lower() in namespace
# Call again - should return same namespace from database
namespace2 = await server.agent_manager.get_or_set_vector_db_namespace_async(sarah_agent.id, default_user.organization_id)
namespace2 = await server.agent_manager.get_or_set_vector_db_namespace_async(default_user.organization_id)
assert namespace == namespace2
@pytest.mark.asyncio

View File

@@ -54,6 +54,7 @@ def client() -> LettaSDKClient:
thread = threading.Thread(target=run_server, daemon=True)
thread.start()
time.sleep(5)
print("Running client tests with server:", server_url)
client = LettaSDKClient(base_url=server_url, token=None, timeout=300.0)
yield client
@@ -105,6 +106,63 @@ def fibonacci_tool(client: LettaSDKClient):
client.tools.delete(tool.id)
def test_messages_search(client: LettaSDKClient, agent: AgentState):
    """Exercise org-wide message search with query and filters.

    Skips when Turbopuffer/OpenAI are not configured or unavailable in this environment.
    """
    from letta.settings import model_settings, settings

    # Require TPUF + OpenAI to be configured; otherwise this is a cloud-only feature
    if not getattr(settings, "tpuf_api_key", None) or not getattr(model_settings, "openai_api_key", None):
        pytest.skip("Message search requires Turbopuffer and OpenAI; skipping.")

    # Snapshot global settings so they are restored even if the test fails.
    original_use_tpuf = settings.use_tpuf
    original_embed_all = settings.embed_all_messages
    try:
        # Enable TPUF + message embedding for this test run
        settings.use_tpuf = True
        settings.embed_all_messages = True

        # Unique token ensures the query matches only messages created by this test.
        unique_term = f"kitten-cats-{uuid.uuid4().hex[:8]}"

        # Create a couple of messages to search over
        client.agents.messages.create(
            agent_id=agent.id,
            messages=[MessageCreate(role="user", content=f"I love {unique_term} dearly")],
        )
        client.agents.messages.create(
            agent_id=agent.id,
            messages=[MessageCreate(role="user", content=f"Recorded preference: {unique_term}")],
        )

        # Allow brief time for background indexing (if enabled)
        time.sleep(2)

        # Call the SDK using the OpenAPI fields
        results = client.agents.messages.search(
            query=unique_term,
            search_mode="hybrid",
            roles=["user"],
            project_id=agent.project_id,
            limit=10,
            start_date=None,
            end_date=None,
        )

        # Validate shape of response
        assert isinstance(results, list) and len(results) >= 1
        top = results[0]
        assert getattr(top, "message", None) is not None
        assert top.message.role == "user"  # role filter applied
        assert hasattr(top, "rrf_score") and top.rrf_score is not None
    finally:
        # Restore global settings regardless of outcome.
        settings.use_tpuf = original_use_tpuf
        settings.embed_all_messages = original_embed_all
@pytest.fixture(scope="function")
def preferences_tool(client: LettaSDKClient):
"""Fixture providing user preferences tool."""

5183
uv.lock generated

File diff suppressed because it is too large Load Diff