feat: cutover repo to 1.0 sdk client LET-6256 (#6361)
feat: cutover repo to 1.0 sdk client
This commit is contained in:
@@ -1,11 +1,42 @@
|
||||
from conftest import create_test_module
|
||||
|
||||
AGENTS_CREATE_PARAMS = [
|
||||
("caren_agent", {"name": "caren", "model": "openai/gpt-4o-mini", "embedding": "openai/text-embedding-3-small"}, {}, None),
|
||||
(
|
||||
"caren_agent",
|
||||
{"name": "caren", "model": "openai/gpt-4o-mini", "embedding": "openai/text-embedding-3-small"},
|
||||
{
|
||||
# Verify model_settings is populated with config values
|
||||
# Note: The 'model' field itself is separate from model_settings
|
||||
"model_settings": {
|
||||
"max_output_tokens": 4096,
|
||||
"parallel_tool_calls": False,
|
||||
"provider_type": "openai",
|
||||
"temperature": 0.7,
|
||||
"reasoning": {"reasoning_effort": "minimal"},
|
||||
"response_format": None,
|
||||
}
|
||||
},
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
AGENTS_MODIFY_PARAMS = [
|
||||
("caren_agent", {"name": "caren_updated"}, {}, None),
|
||||
AGENTS_UPDATE_PARAMS = [
|
||||
(
|
||||
"caren_agent",
|
||||
{"name": "caren_updated"},
|
||||
{
|
||||
# After updating just the name, model_settings should still be present
|
||||
"model_settings": {
|
||||
"max_output_tokens": 4096,
|
||||
"parallel_tool_calls": False,
|
||||
"provider_type": "openai",
|
||||
"temperature": 0.7,
|
||||
"reasoning": {"reasoning_effort": "minimal"},
|
||||
"response_format": None,
|
||||
}
|
||||
},
|
||||
None,
|
||||
),
|
||||
]
|
||||
|
||||
AGENTS_LIST_PARAMS = [
|
||||
@@ -19,7 +50,7 @@ globals().update(
|
||||
resource_name="agents",
|
||||
id_param_name="agent_id",
|
||||
create_params=AGENTS_CREATE_PARAMS,
|
||||
modify_params=AGENTS_MODIFY_PARAMS,
|
||||
update_params=AGENTS_UPDATE_PARAMS,
|
||||
list_params=AGENTS_LIST_PARAMS,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -1,5 +1,5 @@
|
||||
from conftest import create_test_module
|
||||
from letta_client.errors import UnprocessableEntityError
|
||||
from letta_client import UnprocessableEntityError
|
||||
|
||||
from letta.constants import CORE_MEMORY_HUMAN_CHAR_LIMIT, CORE_MEMORY_PERSONA_CHAR_LIMIT
|
||||
|
||||
@@ -8,7 +8,7 @@ BLOCKS_CREATE_PARAMS = [
|
||||
("persona_block", {"label": "persona", "value": "test1"}, {"limit": CORE_MEMORY_PERSONA_CHAR_LIMIT}, None),
|
||||
]
|
||||
|
||||
BLOCKS_MODIFY_PARAMS = [
|
||||
BLOCKS_UPDATE_PARAMS = [
|
||||
("human_block", {"value": "test2"}, {}, None),
|
||||
("persona_block", {"value": "testing testing testing", "limit": 10}, {}, UnprocessableEntityError),
|
||||
]
|
||||
@@ -25,7 +25,7 @@ globals().update(
|
||||
resource_name="blocks",
|
||||
id_param_name="block_id",
|
||||
create_params=BLOCKS_CREATE_PARAMS,
|
||||
modify_params=BLOCKS_MODIFY_PARAMS,
|
||||
update_params=BLOCKS_UPDATE_PARAMS,
|
||||
list_params=BLOCKS_LIST_PARAMS,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -48,14 +48,12 @@ def server_url() -> str:
|
||||
|
||||
# This fixture creates a client for each test module
|
||||
@pytest.fixture(scope="session")
|
||||
def client(server_url):
|
||||
print("Running client tests with server:", server_url)
|
||||
|
||||
# Overide the base_url if the LETTA_API_URL is set
|
||||
api_url = os.getenv("LETTA_API_URL")
|
||||
base_url = api_url if api_url else server_url
|
||||
# create the Letta client
|
||||
yield Letta(base_url=base_url, token=None, timeout=300.0)
|
||||
def client(server_url: str) -> Letta:
|
||||
"""
|
||||
Creates and returns a synchronous Letta REST client for testing.
|
||||
"""
|
||||
client_instance = Letta(base_url=server_url)
|
||||
yield client_instance
|
||||
|
||||
|
||||
def skip_test_if_not_implemented(handler, resource_name, test_name):
|
||||
@@ -68,7 +66,7 @@ def create_test_module(
|
||||
id_param_name: str,
|
||||
create_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
|
||||
upsert_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
|
||||
modify_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
|
||||
update_params: List[Tuple[str, Dict[str, Any], Dict[str, Any], Optional[Exception]]] = [],
|
||||
list_params: List[Tuple[Dict[str, Any], int]] = [],
|
||||
) -> Dict[str, Any]:
|
||||
"""Create a test module for a resource.
|
||||
@@ -80,7 +78,7 @@ def create_test_module(
|
||||
resource_name: Name of the resource (e.g., "blocks", "tools")
|
||||
id_param_name: Name of the ID parameter (e.g., "block_id", "tool_id")
|
||||
create_params: List of (name, params, expected_error) tuples for create tests
|
||||
modify_params: List of (name, params, expected_error) tuples for modify tests
|
||||
update_params: List of (name, params, expected_error) tuples for update tests
|
||||
list_params: List of (query_params, expected_count) tuples for list tests
|
||||
|
||||
Returns:
|
||||
@@ -138,11 +136,7 @@ def create_test_module(
|
||||
expected_values = processed_params | processed_extra_expected
|
||||
for key, value in expected_values.items():
|
||||
if hasattr(item, key):
|
||||
if key == "model" or key == "embedding":
|
||||
# NOTE: add back these tests after v1 migration
|
||||
continue
|
||||
print(f"item.{key}: {getattr(item, key)}")
|
||||
assert custom_model_dump(getattr(item, key)) == value, f"For key {key}, expected {value}, but got {getattr(item, key)}"
|
||||
assert custom_model_dump(getattr(item, key)) == value
|
||||
|
||||
@pytest.mark.order(1)
|
||||
def test_retrieve(handler):
|
||||
@@ -180,9 +174,9 @@ def create_test_module(
|
||||
assert custom_model_dump(getattr(item, key)) == value
|
||||
|
||||
@pytest.mark.order(3)
|
||||
def test_modify(handler, caren_agent, name, params, extra_expected_values, expected_error):
|
||||
"""Test modifying a resource."""
|
||||
skip_test_if_not_implemented(handler, resource_name, "modify")
|
||||
def test_update(handler, caren_agent, name, params, extra_expected_values, expected_error):
|
||||
"""Test updating a resource."""
|
||||
skip_test_if_not_implemented(handler, resource_name, "update")
|
||||
if name not in test_item_ids:
|
||||
pytest.skip(f"Item '{name}' not found in test_items")
|
||||
|
||||
@@ -192,7 +186,7 @@ def create_test_module(
|
||||
processed_extra_expected = preprocess_params(extra_expected_values, caren_agent)
|
||||
|
||||
try:
|
||||
item = handler.modify(**processed_params)
|
||||
item = handler.update(**processed_params)
|
||||
except Exception as e:
|
||||
if expected_error is not None:
|
||||
assert isinstance(e, expected_error), f"Expected error with type {expected_error}, but got {type(e)}: {e}"
|
||||
@@ -254,7 +248,7 @@ def create_test_module(
|
||||
"test_create": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", create_params)(test_create),
|
||||
"test_retrieve": test_retrieve,
|
||||
"test_upsert": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", upsert_params)(test_upsert),
|
||||
"test_modify": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", modify_params)(test_modify),
|
||||
"test_update": pytest.mark.parametrize("name, params, extra_expected_values, expected_error", update_params)(test_update),
|
||||
"test_delete": test_delete,
|
||||
"test_list": pytest.mark.parametrize("query_params, count", list_params)(test_list),
|
||||
}
|
||||
|
||||
@@ -10,7 +10,7 @@ GROUPS_CREATE_PARAMS = [
|
||||
),
|
||||
]
|
||||
|
||||
GROUPS_MODIFY_PARAMS = [
|
||||
GROUPS_UPDATE_PARAMS = [
|
||||
(
|
||||
"round_robin_group",
|
||||
{"manager_config": {"manager_type": "round_robin", "max_turns": 10}},
|
||||
@@ -30,7 +30,7 @@ globals().update(
|
||||
resource_name="groups",
|
||||
id_param_name="group_id",
|
||||
create_params=GROUPS_CREATE_PARAMS,
|
||||
modify_params=GROUPS_MODIFY_PARAMS,
|
||||
update_params=GROUPS_UPDATE_PARAMS,
|
||||
list_params=GROUPS_LIST_PARAMS,
|
||||
)
|
||||
)
|
||||
|
||||
@@ -5,7 +5,7 @@ IDENTITIES_CREATE_PARAMS = [
|
||||
("caren2", {"identifier_key": "456", "name": "caren", "identity_type": "user"}, {}, None),
|
||||
]
|
||||
|
||||
IDENTITIES_MODIFY_PARAMS = [
|
||||
IDENTITIES_UPDATE_PARAMS = [
|
||||
("caren1", {"properties": [{"key": "email", "value": "caren@letta.com", "type": "string"}]}, {}, None),
|
||||
("caren2", {"properties": [{"key": "email", "value": "caren@gmail.com", "type": "string"}]}, {}, None),
|
||||
]
|
||||
@@ -37,7 +37,7 @@ globals().update(
|
||||
id_param_name="identity_id",
|
||||
create_params=IDENTITIES_CREATE_PARAMS,
|
||||
upsert_params=IDENTITIES_UPSERT_PARAMS,
|
||||
modify_params=IDENTITIES_MODIFY_PARAMS,
|
||||
update_params=IDENTITIES_UPDATE_PARAMS,
|
||||
list_params=IDENTITIES_LIST_PARAMS,
|
||||
)
|
||||
)
|
||||
|
||||
1173
tests/sdk/mcp_servers_test.py
Normal file
1173
tests/sdk/mcp_servers_test.py
Normal file
File diff suppressed because it is too large
Load Diff
185
tests/sdk/mock_mcp_server.py
Executable file
185
tests/sdk/mock_mcp_server.py
Executable file
@@ -0,0 +1,185 @@
|
||||
#!/usr/bin/env python3
|
||||
"""
|
||||
Mock MCP server for testing.
|
||||
Implements a simple stdio-based MCP server with various test tools using FastMCP.
|
||||
"""
|
||||
|
||||
import argparse
|
||||
import json
|
||||
import logging
|
||||
import sys
|
||||
from typing import Any, Dict, List, Optional
|
||||
|
||||
try:
|
||||
from mcp.server.fastmcp import FastMCP
|
||||
from pydantic import BaseModel, Field
|
||||
except ImportError as e:
|
||||
print(f"Error importing required modules: {e}", file=sys.stderr)
|
||||
print("Please ensure mcp and pydantic are installed", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
# Parse command line arguments
|
||||
parser = argparse.ArgumentParser(description="Mock MCP server for testing")
|
||||
parser.add_argument("--no-tools", action="store_true", help="Start server with no tools")
|
||||
args = parser.parse_args()
|
||||
|
||||
# Configure logging to stderr (not stdout for STDIO servers)
|
||||
logging.basicConfig(level=logging.INFO)
|
||||
|
||||
# Initialize FastMCP server
|
||||
mcp = FastMCP("mock-mcp-server")
|
||||
|
||||
|
||||
# Pydantic models for complex tools
|
||||
class Address(BaseModel):
|
||||
"""An address with street, city, and zip code."""
|
||||
|
||||
street: Optional[str] = Field(None, description="Street address")
|
||||
city: Optional[str] = Field(None, description="City name")
|
||||
zip: Optional[str] = Field(None, description="ZIP code")
|
||||
|
||||
|
||||
class Instantiation(BaseModel):
|
||||
"""Instantiation object with optional node identifiers."""
|
||||
|
||||
doid: Optional[str] = Field(None, description="DOID identifier")
|
||||
nodeFamilyId: Optional[int] = Field(None, description="Node family ID")
|
||||
|
||||
|
||||
class InstantiationData(BaseModel):
|
||||
"""Instantiation data with abstract and multiplicity flags."""
|
||||
|
||||
isAbstract: Optional[bool] = Field(None, description="Whether the instantiation is abstract")
|
||||
isMultiplicity: Optional[bool] = Field(None, description="Whether the instantiation has multiplicity")
|
||||
instantiations: Optional[List[Instantiation]] = Field(None, description="List of instantiations")
|
||||
|
||||
|
||||
# Only register tools if --no-tools flag is not set
|
||||
if not args.no_tools:
|
||||
# Simple tools
|
||||
@mcp.tool()
|
||||
async def echo(message: str) -> str:
|
||||
"""Echo back a message.
|
||||
|
||||
Args:
|
||||
message: The message to echo
|
||||
"""
|
||||
return f"Echo: {message}"
|
||||
|
||||
@mcp.tool()
|
||||
async def add(a: float, b: float) -> str:
|
||||
"""Add two numbers.
|
||||
|
||||
Args:
|
||||
a: First number
|
||||
b: Second number
|
||||
"""
|
||||
return f"Result: {a + b}"
|
||||
|
||||
@mcp.tool()
|
||||
async def multiply(a: float, b: float) -> str:
|
||||
"""Multiply two numbers.
|
||||
|
||||
Args:
|
||||
a: First number
|
||||
b: Second number
|
||||
"""
|
||||
return f"Result: {a * b}"
|
||||
|
||||
@mcp.tool()
|
||||
async def reverse_string(text: str) -> str:
|
||||
"""Reverse a string.
|
||||
|
||||
Args:
|
||||
text: The text to reverse
|
||||
"""
|
||||
return f"Reversed: {text[::-1]}"
|
||||
|
||||
# Complex tools
|
||||
@mcp.tool()
|
||||
async def create_person(name: str, age: Optional[int] = None, email: Optional[str] = None, address: Optional[Address] = None) -> str:
|
||||
"""Create a person object with details.
|
||||
|
||||
Args:
|
||||
name: Person's name
|
||||
age: Person's age
|
||||
email: Person's email
|
||||
address: Person's address
|
||||
"""
|
||||
person_data = {"name": name}
|
||||
if age is not None:
|
||||
person_data["age"] = age
|
||||
if email is not None:
|
||||
person_data["email"] = email
|
||||
if address is not None:
|
||||
person_data["address"] = address.model_dump(exclude_none=True)
|
||||
|
||||
return f"Created person: {json.dumps(person_data)}"
|
||||
|
||||
@mcp.tool()
|
||||
async def manage_tasks(action: str, task: Optional[str] = None) -> str:
|
||||
"""Manage a list of tasks.
|
||||
|
||||
Args:
|
||||
action: The action to perform (add, remove, list)
|
||||
task: The task to add or remove
|
||||
"""
|
||||
if action == "add":
|
||||
return f"Added task: {task}"
|
||||
elif action == "remove":
|
||||
return f"Removed task: {task}"
|
||||
else:
|
||||
return "Listed tasks: []"
|
||||
|
||||
@mcp.tool()
|
||||
async def search_with_filters(query: str, filters: Optional[Dict[str, Any]] = None) -> str:
|
||||
"""Search with various filters.
|
||||
|
||||
Args:
|
||||
query: Search query
|
||||
filters: Optional filters dictionary
|
||||
"""
|
||||
return f"Search results for '{query}' with filters {filters}"
|
||||
|
||||
@mcp.tool()
|
||||
async def process_nested_data(data: Dict[str, Any]) -> str:
|
||||
"""Process deeply nested data structures.
|
||||
|
||||
Args:
|
||||
data: The nested data to process
|
||||
"""
|
||||
return f"Processed nested data: {json.dumps(data)}"
|
||||
|
||||
@mcp.tool()
|
||||
async def get_parameter_type_description(
|
||||
preset: str, connected_service_descriptor: Optional[str] = None, instantiation_data: Optional[InstantiationData] = None
|
||||
) -> str:
|
||||
"""Get parameter type description with complex schema.
|
||||
|
||||
Args:
|
||||
preset: Preset configuration (a, b, c)
|
||||
connected_service_descriptor: Service descriptor
|
||||
instantiation_data: Instantiation data with nested structure
|
||||
"""
|
||||
result = f"Preset: {preset}"
|
||||
if connected_service_descriptor:
|
||||
result += f", Service: {connected_service_descriptor}"
|
||||
if instantiation_data:
|
||||
result += f", Instantiation data: {json.dumps(instantiation_data.model_dump(exclude_none=True))}"
|
||||
return result
|
||||
|
||||
|
||||
def main():
|
||||
"""Run the MCP server using stdio transport."""
|
||||
try:
|
||||
mcp.run(transport="stdio")
|
||||
except KeyboardInterrupt:
|
||||
# Clean exit on Ctrl+C
|
||||
sys.exit(0)
|
||||
except Exception as e:
|
||||
print(f"Server error: {e}", file=sys.stderr)
|
||||
sys.exit(1)
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
main()
|
||||
506
tests/sdk/search_test.py
Normal file
506
tests/sdk/search_test.py
Normal file
@@ -0,0 +1,506 @@
|
||||
"""
|
||||
End-to-end tests for passage and message search endpoints using the SDK client.
|
||||
|
||||
These tests verify that the /v1/passages/search and /v1/messages/search endpoints work correctly
|
||||
with Turbopuffer integration, including vector search, FTS, hybrid search, filtering, and pagination.
|
||||
"""
|
||||
|
||||
import time
|
||||
import uuid
|
||||
from datetime import datetime, timedelta, timezone
|
||||
from typing import Any
|
||||
|
||||
import pytest
|
||||
from letta_client import Letta
|
||||
from letta_client.types import CreateBlockParam, MessageCreateParam
|
||||
|
||||
from letta.config import LettaConfig
|
||||
from letta.server.rest_api.routers.v1.passages import PassageSearchResult
|
||||
from letta.server.server import SyncServer
|
||||
from letta.settings import settings
|
||||
|
||||
|
||||
def cleanup_agent_with_messages(client: Letta, agent_id: str):
|
||||
"""
|
||||
Helper function to properly clean up an agent by first deleting all its messages
|
||||
from Turbopuffer before deleting the agent itself.
|
||||
|
||||
Args:
|
||||
client: Letta SDK client
|
||||
agent_id: ID of the agent to clean up
|
||||
"""
|
||||
try:
|
||||
# First, delete all messages for this agent from Turbopuffer
|
||||
# This ensures no orphaned messages remain in Turbopuffer
|
||||
try:
|
||||
import asyncio
|
||||
|
||||
from letta.helpers.tpuf_client import TurbopufferClient, should_use_tpuf_for_messages
|
||||
|
||||
if should_use_tpuf_for_messages():
|
||||
tpuf_client = TurbopufferClient()
|
||||
# Delete all messages for this agent from Turbopuffer
|
||||
asyncio.run(tpuf_client.delete_all_messages(agent_id))
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to clean up Turbopuffer messages for agent {agent_id}: {e}")
|
||||
|
||||
# Now delete the agent itself (which will delete SQL messages via cascade)
|
||||
client.agents.delete(agent_id=agent_id)
|
||||
except Exception as e:
|
||||
print(f"Warning: Failed to clean up agent {agent_id}: {e}")
|
||||
|
||||
|
||||
@pytest.fixture(scope="module")
|
||||
def server():
|
||||
"""Server fixture for testing"""
|
||||
config = LettaConfig.load()
|
||||
config.save()
|
||||
server = SyncServer(init_with_default_org_and_user=False)
|
||||
return server
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_turbopuffer():
|
||||
"""Temporarily enable Turbopuffer for testing"""
|
||||
original_use_tpuf = settings.use_tpuf
|
||||
original_api_key = settings.tpuf_api_key
|
||||
original_environment = settings.environment
|
||||
|
||||
# Enable Turbopuffer with test key
|
||||
settings.use_tpuf = True
|
||||
if not settings.tpuf_api_key:
|
||||
settings.tpuf_api_key = original_api_key
|
||||
settings.environment = "DEV"
|
||||
|
||||
yield
|
||||
|
||||
# Restore original values
|
||||
settings.use_tpuf = original_use_tpuf
|
||||
settings.tpuf_api_key = original_api_key
|
||||
settings.environment = original_environment
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def enable_message_embedding():
|
||||
"""Enable both Turbopuffer and message embedding"""
|
||||
original_use_tpuf = settings.use_tpuf
|
||||
original_api_key = settings.tpuf_api_key
|
||||
original_embed_messages = settings.embed_all_messages
|
||||
original_environment = settings.environment
|
||||
|
||||
settings.use_tpuf = True
|
||||
settings.tpuf_api_key = settings.tpuf_api_key or "test-key"
|
||||
settings.embed_all_messages = True
|
||||
settings.environment = "DEV"
|
||||
|
||||
yield
|
||||
|
||||
settings.use_tpuf = original_use_tpuf
|
||||
settings.tpuf_api_key = original_api_key
|
||||
settings.embed_all_messages = original_embed_messages
|
||||
settings.environment = original_environment
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
def disable_turbopuffer():
|
||||
"""Ensure Turbopuffer is disabled for testing"""
|
||||
original_use_tpuf = settings.use_tpuf
|
||||
original_embed_messages = settings.embed_all_messages
|
||||
|
||||
settings.use_tpuf = False
|
||||
settings.embed_all_messages = False
|
||||
|
||||
yield
|
||||
|
||||
settings.use_tpuf = original_use_tpuf
|
||||
settings.embed_all_messages = original_embed_messages
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
def test_passage_search_basic(client: Letta, enable_turbopuffer):
|
||||
"""Test basic passage search functionality through the SDK"""
|
||||
# Create an agent
|
||||
agent = client.agents.create(
|
||||
name=f"test_passage_search_{uuid.uuid4()}",
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
try:
|
||||
# Create an archive and attach to agent
|
||||
archive = client.archives.create(name=f"test_archive_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")
|
||||
|
||||
try:
|
||||
# Attach archive to agent
|
||||
client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)
|
||||
|
||||
# Insert some passages
|
||||
test_passages = [
|
||||
"Python is a popular programming language for data science and machine learning.",
|
||||
"JavaScript is widely used for web development and frontend applications.",
|
||||
"Turbopuffer is a vector database optimized for performance and scalability.",
|
||||
]
|
||||
|
||||
for passage_text in test_passages:
|
||||
client.archives.passages.create(archive_id=archive.id, text=passage_text)
|
||||
|
||||
# Wait for indexing
|
||||
time.sleep(2)
|
||||
|
||||
# Test search by agent_id
|
||||
results = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "python programming",
|
||||
"agent_id": agent.id,
|
||||
"limit": 10,
|
||||
},
|
||||
)
|
||||
|
||||
assert len(results) > 0, "Should find at least one passage"
|
||||
assert any("Python" in result.passage.text for result in results), "Should find Python-related passage"
|
||||
|
||||
# Verify result structure
|
||||
for result in results:
|
||||
assert hasattr(result, "passage"), "Result should have passage field"
|
||||
assert hasattr(result, "score"), "Result should have score field"
|
||||
assert hasattr(result, "metadata"), "Result should have metadata field"
|
||||
assert isinstance(result.score, float), "Score should be a float"
|
||||
|
||||
# Test search by archive_id
|
||||
archive_results = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "vector database",
|
||||
"archive_id": archive.id,
|
||||
"limit": 10,
|
||||
},
|
||||
)
|
||||
|
||||
assert len(archive_results) > 0, "Should find passages in archive"
|
||||
assert any("Turbopuffer" in result.passage.text or "vector" in result.passage.text for result in archive_results), (
|
||||
"Should find vector-related passage"
|
||||
)
|
||||
|
||||
finally:
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Clean up agent
|
||||
cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
def test_passage_search_with_tags(client: Letta, enable_turbopuffer):
|
||||
"""Test passage search with tag filtering"""
|
||||
# Create an agent
|
||||
agent = client.agents.create(
|
||||
name=f"test_passage_tags_{uuid.uuid4()}",
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
try:
|
||||
# Create an archive
|
||||
archive = client.archives.create(name=f"test_archive_tags_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")
|
||||
|
||||
try:
|
||||
# Attach archive to agent
|
||||
client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)
|
||||
|
||||
# Insert passages with tags (if supported)
|
||||
# Note: Tag support may depend on the SDK version
|
||||
test_passages = [
|
||||
"Python tutorial for beginners",
|
||||
"Advanced Python techniques",
|
||||
"JavaScript basics",
|
||||
]
|
||||
|
||||
for passage_text in test_passages:
|
||||
client.archives.passages.create(archive_id=archive.id, text=passage_text)
|
||||
|
||||
# Wait for indexing
|
||||
time.sleep(2)
|
||||
|
||||
# Test basic search without tags first
|
||||
results = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "programming tutorial",
|
||||
"agent_id": agent.id,
|
||||
"limit": 10,
|
||||
},
|
||||
)
|
||||
|
||||
assert len(results) > 0, "Should find passages"
|
||||
|
||||
# Test with tag filtering if supported
|
||||
# Note: The SDK may not expose tag parameters directly, so this test verifies basic functionality
|
||||
# The backend will handle tag filtering when available
|
||||
|
||||
finally:
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Clean up agent
|
||||
cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
def test_passage_search_with_date_filters(client: Letta, enable_turbopuffer):
|
||||
"""Test passage search with date range filtering"""
|
||||
# Create an agent
|
||||
agent = client.agents.create(
|
||||
name=f"test_passage_dates_{uuid.uuid4()}",
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
try:
|
||||
# Create an archive
|
||||
archive = client.archives.create(name=f"test_archive_dates_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")
|
||||
|
||||
try:
|
||||
# Attach archive to agent
|
||||
client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)
|
||||
|
||||
# Insert passages at different times
|
||||
client.archives.passages.create(archive_id=archive.id, text="Recent passage about AI trends")
|
||||
|
||||
# Wait a bit before creating another
|
||||
time.sleep(1)
|
||||
|
||||
client.archives.passages.create(archive_id=archive.id, text="Another passage about machine learning")
|
||||
|
||||
# Wait for indexing
|
||||
time.sleep(2)
|
||||
|
||||
# Test search with date range
|
||||
now = datetime.now(timezone.utc)
|
||||
start_date = now - timedelta(hours=1)
|
||||
|
||||
results = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "AI machine learning",
|
||||
"agent_id": agent.id,
|
||||
"limit": 10,
|
||||
"start_date": start_date.isoformat(),
|
||||
},
|
||||
)
|
||||
|
||||
assert len(results) > 0, "Should find recent passages"
|
||||
|
||||
# Verify all results are within date range
|
||||
for result in results:
|
||||
passage_date = result.passage.created_at
|
||||
if passage_date:
|
||||
# Convert to datetime if it's a string
|
||||
if isinstance(passage_date, str):
|
||||
passage_date = datetime.fromisoformat(passage_date.replace("Z", "+00:00"))
|
||||
assert passage_date >= start_date, "Passage should be after start_date"
|
||||
|
||||
finally:
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Clean up agent
|
||||
cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
def test_message_search_basic(client: Letta, enable_message_embedding):
|
||||
"""Test basic message search functionality through the SDK"""
|
||||
# Create an agent
|
||||
agent = client.agents.create(
|
||||
name=f"test_message_search_{uuid.uuid4()}",
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="helpful assistant")],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
try:
|
||||
# Send messages to the agent
|
||||
test_messages = [
|
||||
"What is the capital of Mozambique?",
|
||||
]
|
||||
|
||||
for msg_text in test_messages:
|
||||
client.agents.messages.create(agent_id=agent.id, messages=[MessageCreateParam(role="user", content=msg_text)])
|
||||
|
||||
# Wait for messages to be indexed and database transactions to complete
|
||||
# Extra time needed for async embedding and database commits
|
||||
time.sleep(6)
|
||||
|
||||
# Test FTS search for messages
|
||||
results = client.messages.search(query="capital of Mozambique", search_mode="fts", limit=10)
|
||||
|
||||
assert len(results) > 0, "Should find at least one message"
|
||||
|
||||
finally:
|
||||
# Clean up agent
|
||||
cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
def test_passage_search_pagination(client: Letta, enable_turbopuffer):
|
||||
"""Test passage search pagination"""
|
||||
# Create an agent
|
||||
agent = client.agents.create(
|
||||
name=f"test_passage_pagination_{uuid.uuid4()}",
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="test assistant")],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
try:
|
||||
# Create an archive
|
||||
archive = client.archives.create(name=f"test_archive_pagination_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")
|
||||
|
||||
try:
|
||||
# Attach archive to agent
|
||||
client.agents.archives.attach(agent_id=agent.id, archive_id=archive.id)
|
||||
|
||||
# Insert many passages
|
||||
for i in range(10):
|
||||
client.archives.passages.create(archive_id=archive.id, text=f"Test passage number {i} about programming")
|
||||
|
||||
# Wait for indexing
|
||||
time.sleep(2)
|
||||
|
||||
# Test with different limit values
|
||||
results_limit_3 = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "programming",
|
||||
"agent_id": agent.id,
|
||||
"limit": 3,
|
||||
},
|
||||
)
|
||||
|
||||
assert len(results_limit_3) == 3, "Should respect limit parameter"
|
||||
|
||||
results_limit_5 = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "programming",
|
||||
"agent_id": agent.id,
|
||||
"limit": 5,
|
||||
},
|
||||
)
|
||||
|
||||
assert len(results_limit_5) == 5, "Should return 5 results"
|
||||
|
||||
results_all = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "programming",
|
||||
"agent_id": agent.id,
|
||||
"limit": 20,
|
||||
},
|
||||
)
|
||||
|
||||
assert len(results_all) >= 10, "Should return all matching passages"
|
||||
|
||||
finally:
|
||||
# Clean up archive
|
||||
try:
|
||||
client.archives.delete(archive_id=archive.id)
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Clean up agent
|
||||
cleanup_agent_with_messages(client, agent.id)
|
||||
|
||||
|
||||
@pytest.mark.skipif(not settings.tpuf_api_key, reason="Turbopuffer API key not configured")
|
||||
def test_passage_search_org_wide(client: Letta, enable_turbopuffer):
|
||||
"""Test organization-wide passage search (without agent_id or archive_id)"""
|
||||
# Create multiple agents with archives
|
||||
agent1 = client.agents.create(
|
||||
name=f"test_org_search_agent1_{uuid.uuid4()}",
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="test assistant 1")],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
agent2 = client.agents.create(
|
||||
name=f"test_org_search_agent2_{uuid.uuid4()}",
|
||||
memory_blocks=[CreateBlockParam(label="persona", value="test assistant 2")],
|
||||
model="openai/gpt-4o-mini",
|
||||
embedding="openai/text-embedding-3-small",
|
||||
)
|
||||
|
||||
try:
|
||||
# Create archives for both agents
|
||||
archive1 = client.archives.create(name=f"test_archive_org1_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")
|
||||
archive2 = client.archives.create(name=f"test_archive_org2_{uuid.uuid4()}", embedding="openai/text-embedding-3-small")
|
||||
|
||||
try:
|
||||
# Attach archives
|
||||
client.agents.archives.attach(agent_id=agent1.id, archive_id=archive1.id)
|
||||
client.agents.archives.attach(agent_id=agent2.id, archive_id=archive2.id)
|
||||
|
||||
# Insert passages in both archives
|
||||
client.archives.passages.create(archive_id=archive1.id, text="Unique passage in agent1 about quantum computing")
|
||||
|
||||
client.archives.passages.create(archive_id=archive2.id, text="Unique passage in agent2 about blockchain technology")
|
||||
|
||||
# Wait for indexing
|
||||
time.sleep(2)
|
||||
|
||||
# Test org-wide search (no agent_id or archive_id)
|
||||
results = client.post(
|
||||
"/v1/passages/search",
|
||||
cast_to=list[PassageSearchResult],
|
||||
body={
|
||||
"query": "unique passage",
|
||||
"limit": 20,
|
||||
},
|
||||
)
|
||||
|
||||
# Should find passages from both agents
|
||||
assert len(results) >= 2, "Should find passages from multiple agents"
|
||||
|
||||
found_texts = [result.passage.text for result in results]
|
||||
assert any("quantum computing" in text for text in found_texts), "Should find agent1 passage"
|
||||
assert any("blockchain" in text for text in found_texts), "Should find agent2 passage"
|
||||
|
||||
finally:
|
||||
# Clean up archives
|
||||
try:
|
||||
client.archives.delete(archive_id=archive1.id)
|
||||
except:
|
||||
pass
|
||||
try:
|
||||
client.archives.delete(archive_id=archive2.id)
|
||||
except:
|
||||
pass
|
||||
|
||||
finally:
|
||||
# Clean up agents
|
||||
cleanup_agent_with_messages(client, agent1.id)
|
||||
cleanup_agent_with_messages(client, agent2.id)
|
||||
@@ -44,14 +44,14 @@ TOOLS_UPSERT_PARAMS = [
|
||||
("unfriendly_func", {"source_code": UNFRIENDLY_FUNC_SOURCE_CODE_V2}, {}, None),
|
||||
]
|
||||
|
||||
TOOLS_MODIFY_PARAMS = [
|
||||
TOOLS_UPDATE_PARAMS = [
|
||||
("friendly_func", {"tags": ["sdk_test"]}, {}, None),
|
||||
("unfriendly_func", {"return_char_limit": 300}, {}, None),
|
||||
]
|
||||
|
||||
TOOLS_LIST_PARAMS = [
|
||||
({}, 2),
|
||||
({"name": ["friendly_func"]}, 1),
|
||||
({"name": "friendly_func"}, 1),
|
||||
]
|
||||
|
||||
# Create all test module components at once
|
||||
@@ -61,7 +61,7 @@ globals().update(
|
||||
id_param_name="tool_id",
|
||||
create_params=TOOLS_CREATE_PARAMS,
|
||||
upsert_params=TOOLS_UPSERT_PARAMS,
|
||||
modify_params=TOOLS_MODIFY_PARAMS,
|
||||
update_params=TOOLS_UPDATE_PARAMS,
|
||||
list_params=TOOLS_LIST_PARAMS,
|
||||
)
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user