feat: add text search for agent names (#662)

Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
mlong93
2025-01-15 10:41:40 -08:00
committed by GitHub
parent a9c5866ae9
commit f33518ca66
6 changed files with 152 additions and 13 deletions

View File

@@ -436,13 +436,22 @@ class RESTClient(AbstractClient):
self._default_llm_config = default_llm_config
self._default_embedding_config = default_embedding_config
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
params = {}
def list_agents(
self, tags: Optional[List[str]] = None, query_text: Optional[str] = None, limit: int = 50, cursor: Optional[str] = None
) -> List[AgentState]:
params = {"limit": limit}
if tags:
params["tags"] = tags
params["match_all_tags"] = False
if query_text:
params["query_text"] = query_text
if cursor:
params["cursor"] = cursor
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
print(f"\nLIST RESPONSE\n{response.json()}\n")
return [AgentState(**agent) for agent in response.json()]
def agent_exists(self, agent_id: str) -> bool:
@@ -2210,10 +2219,12 @@ class LocalClient(AbstractClient):
self.organization = self.server.get_organization_or_default(self.org_id)
# agents
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
def list_agents(
self, query_text: Optional[str] = None, tags: Optional[List[str]] = None, limit: int = 100, cursor: Optional[str] = None
) -> List[AgentState]:
self.interface.clear()
return self.server.agent_manager.list_agents(actor=self.user, tags=tags)
return self.server.agent_manager.list_agents(actor=self.user, tags=tags, query_text=query_text, limit=limit, cursor=cursor)
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
"""

View File

@@ -163,7 +163,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
# Text search
if query_text:
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
if hasattr(cls, "text"):
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
elif hasattr(cls, "name"):
# Special case for Agent model - search across name
query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
# Embedding search (for Passages)
is_ordered = False

View File

@@ -47,9 +47,9 @@ def list_agents(
),
server: "SyncServer" = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
cursor: Optional[int] = Query(None, description="Cursor for pagination"),
cursor: Optional[str] = Query(None, description="Cursor for pagination"),
limit: Optional[int] = Query(None, description="Limit for pagination"),
# Extract user_id from header, default to None if not present
query_text: Optional[str] = Query(None, description="Search agents by name"),
):
"""
List all agents associated with a given user.
@@ -64,6 +64,7 @@ def list_agents(
"tags": tags,
"match_all_tags": match_all_tags,
"name": name,
"query_text": query_text,
}.items()
if value is not None
}

View File

@@ -268,6 +268,7 @@ class AgentManager:
match_all_tags: bool = False,
cursor: Optional[str] = None,
limit: Optional[int] = 50,
query_text: Optional[str] = None,
**kwargs,
) -> List[PydanticAgentState]:
"""
@@ -281,6 +282,7 @@ class AgentManager:
cursor=cursor,
limit=limit,
organization_id=actor.organization_id if actor else None,
query_text=query_text,
**kwargs,
)

View File

@@ -80,6 +80,28 @@ def agent(client: Union[LocalClient, RESTClient]):
client.delete_agent(agent_state.id)
# Fixture for test agent
@pytest.fixture
def search_agent_one(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name="Search Agent One")
yield agent_state
# delete agent
client.delete_agent(agent_state.id)
# Fixture for test agent
@pytest.fixture
def search_agent_two(client: Union[LocalClient, RESTClient]):
agent_state = client.create_agent(name="Search Agent Two")
yield agent_state
# delete agent
client.delete_agent(agent_state.id)
@pytest.fixture(autouse=True)
def clear_tables():
"""Clear the sandbox tables before each test."""
@@ -560,17 +582,50 @@ def test_send_message_async(client: Union[LocalClient, RESTClient], agent: Agent
assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens
# ==========================================
# TESTS FOR AGENT LISTING
# ==========================================
def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two):
"""Test listing agents with pagination and query text filtering."""
# Test query text filtering
search_results = client.list_agents(query_text="search agent")
assert len(search_results) == 2
search_agent_ids = {agent.id for agent in search_results}
assert search_agent_one.id in search_agent_ids
assert search_agent_two.id in search_agent_ids
assert agent.id not in search_agent_ids
different_results = client.list_agents(query_text="client")
assert len(different_results) == 1
assert different_results[0].id == agent.id
# Test pagination
first_page = client.list_agents(query_text="search agent", limit=1)
assert len(first_page) == 1
first_agent = first_page[0]
second_page = client.list_agents(query_text="search agent", cursor=first_agent.id, limit=1) # Use agent ID as cursor
assert len(second_page) == 1
assert second_page[0].id != first_agent.id
# Verify we got both search agents with no duplicates
all_ids = {first_page[0].id, second_page[0].id}
assert len(all_ids) == 2
assert all_ids == {search_agent_one.id, search_agent_two.id}
# Test listing without any filters
all_agents = client.list_agents()
assert len(all_agents) == 3
assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent])
def test_agent_creation(client: Union[LocalClient, RESTClient]):
"""Test that block IDs are properly attached when creating an agent."""
if not isinstance(client, RESTClient):
pytest.skip("This test only runs when the server is enabled")
offline_memory_agent_system = """
You are a helpful agent. You will be provided with a list of memory blocks and a user preferences block.
You should use the memory blocks to remember information about the user and their preferences.
You should also use the user preferences block to remember information about the user's preferences.
"""
from letta import BasicBlockMemory
# Create a test block that will represent user preferences
@@ -623,3 +678,5 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]):
assert len(agent_tools) == 2
tool_ids = {tool1.id, tool2.id}
assert all(tool.id in tool_ids for tool in agent_tools)
client.delete_agent(agent_id=agent.id)

View File

@@ -914,6 +914,70 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
assert agent2.id in all_ids
def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization):
"""Test listing agents with query text filtering and pagination."""
# Create test agents with specific names and descriptions
agent1 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="Search Agent One",
memory_blocks=[],
description="This is a search agent for testing",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=default_user,
)
agent2 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="Search Agent Two",
memory_blocks=[],
description="Another search agent for testing",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=default_user,
)
agent3 = server.agent_manager.create_agent(
agent_create=CreateAgent(
name="Different Agent",
memory_blocks=[],
description="This is a different agent",
llm_config=LLMConfig.default_config("gpt-4"),
embedding_config=EmbeddingConfig.default_config(provider="openai"),
),
actor=default_user,
)
# Test query text filtering
search_results = server.agent_manager.list_agents(actor=default_user, query_text="search agent")
assert len(search_results) == 2
search_agent_ids = {agent.id for agent in search_results}
assert agent1.id in search_agent_ids
assert agent2.id in search_agent_ids
assert agent3.id not in search_agent_ids
different_results = server.agent_manager.list_agents(actor=default_user, query_text="different agent")
assert len(different_results) == 1
assert different_results[0].id == agent3.id
# Test pagination with query text
first_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", limit=1)
assert len(first_page) == 1
first_agent_id = first_page[0].id
# Get second page using cursor
second_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", cursor=first_agent_id, limit=1)
assert len(second_page) == 1
assert second_page[0].id != first_agent_id
# Verify we got both search agents with no duplicates
all_ids = {first_page[0].id, second_page[0].id}
assert len(all_ids) == 2
assert all_ids == {agent1.id, agent2.id}
# ======================================================================================================================
# AgentManager Tests - Messages Relationship
# ======================================================================================================================