feat: add text search for agent names (#662)
Co-authored-by: Mindy Long <mindy@letta.com>
This commit is contained in:
@@ -436,13 +436,22 @@ class RESTClient(AbstractClient):
|
||||
self._default_llm_config = default_llm_config
|
||||
self._default_embedding_config = default_embedding_config
|
||||
|
||||
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
|
||||
params = {}
|
||||
def list_agents(
|
||||
self, tags: Optional[List[str]] = None, query_text: Optional[str] = None, limit: int = 50, cursor: Optional[str] = None
|
||||
) -> List[AgentState]:
|
||||
params = {"limit": limit}
|
||||
if tags:
|
||||
params["tags"] = tags
|
||||
params["match_all_tags"] = False
|
||||
|
||||
if query_text:
|
||||
params["query_text"] = query_text
|
||||
|
||||
if cursor:
|
||||
params["cursor"] = cursor
|
||||
|
||||
response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params)
|
||||
print(f"\nLIST RESPONSE\n{response.json()}\n")
|
||||
return [AgentState(**agent) for agent in response.json()]
|
||||
|
||||
def agent_exists(self, agent_id: str) -> bool:
|
||||
@@ -2210,10 +2219,12 @@ class LocalClient(AbstractClient):
|
||||
self.organization = self.server.get_organization_or_default(self.org_id)
|
||||
|
||||
# agents
|
||||
def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]:
|
||||
def list_agents(
|
||||
self, query_text: Optional[str] = None, tags: Optional[List[str]] = None, limit: int = 100, cursor: Optional[str] = None
|
||||
) -> List[AgentState]:
|
||||
self.interface.clear()
|
||||
|
||||
return self.server.agent_manager.list_agents(actor=self.user, tags=tags)
|
||||
return self.server.agent_manager.list_agents(actor=self.user, tags=tags, query_text=query_text, limit=limit, cursor=cursor)
|
||||
|
||||
def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool:
|
||||
"""
|
||||
|
||||
@@ -163,7 +163,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base):
|
||||
|
||||
# Text search
|
||||
if query_text:
|
||||
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
||||
if hasattr(cls, "text"):
|
||||
query = query.filter(func.lower(cls.text).contains(func.lower(query_text)))
|
||||
elif hasattr(cls, "name"):
|
||||
# Special case for Agent model - search across name
|
||||
query = query.filter(func.lower(cls.name).contains(func.lower(query_text)))
|
||||
|
||||
# Embedding search (for Passages)
|
||||
is_ordered = False
|
||||
|
||||
@@ -47,9 +47,9 @@ def list_agents(
|
||||
),
|
||||
server: "SyncServer" = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
cursor: Optional[int] = Query(None, description="Cursor for pagination"),
|
||||
cursor: Optional[str] = Query(None, description="Cursor for pagination"),
|
||||
limit: Optional[int] = Query(None, description="Limit for pagination"),
|
||||
# Extract user_id from header, default to None if not present
|
||||
query_text: Optional[str] = Query(None, description="Search agents by name"),
|
||||
):
|
||||
"""
|
||||
List all agents associated with a given user.
|
||||
@@ -64,6 +64,7 @@ def list_agents(
|
||||
"tags": tags,
|
||||
"match_all_tags": match_all_tags,
|
||||
"name": name,
|
||||
"query_text": query_text,
|
||||
}.items()
|
||||
if value is not None
|
||||
}
|
||||
|
||||
@@ -268,6 +268,7 @@ class AgentManager:
|
||||
match_all_tags: bool = False,
|
||||
cursor: Optional[str] = None,
|
||||
limit: Optional[int] = 50,
|
||||
query_text: Optional[str] = None,
|
||||
**kwargs,
|
||||
) -> List[PydanticAgentState]:
|
||||
"""
|
||||
@@ -281,6 +282,7 @@ class AgentManager:
|
||||
cursor=cursor,
|
||||
limit=limit,
|
||||
organization_id=actor.organization_id if actor else None,
|
||||
query_text=query_text,
|
||||
**kwargs,
|
||||
)
|
||||
|
||||
|
||||
@@ -80,6 +80,28 @@ def agent(client: Union[LocalClient, RESTClient]):
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture
|
||||
def search_agent_one(client: Union[LocalClient, RESTClient]):
|
||||
agent_state = client.create_agent(name="Search Agent One")
|
||||
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
# Fixture for test agent
|
||||
@pytest.fixture
|
||||
def search_agent_two(client: Union[LocalClient, RESTClient]):
|
||||
agent_state = client.create_agent(name="Search Agent Two")
|
||||
|
||||
yield agent_state
|
||||
|
||||
# delete agent
|
||||
client.delete_agent(agent_state.id)
|
||||
|
||||
|
||||
@pytest.fixture(autouse=True)
|
||||
def clear_tables():
|
||||
"""Clear the sandbox tables before each test."""
|
||||
@@ -560,17 +582,50 @@ def test_send_message_async(client: Union[LocalClient, RESTClient], agent: Agent
|
||||
assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens
|
||||
|
||||
|
||||
# ==========================================
|
||||
# TESTS FOR AGENT LISTING
|
||||
# ==========================================
|
||||
|
||||
|
||||
def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two):
|
||||
"""Test listing agents with pagination and query text filtering."""
|
||||
# Test query text filtering
|
||||
search_results = client.list_agents(query_text="search agent")
|
||||
assert len(search_results) == 2
|
||||
search_agent_ids = {agent.id for agent in search_results}
|
||||
assert search_agent_one.id in search_agent_ids
|
||||
assert search_agent_two.id in search_agent_ids
|
||||
assert agent.id not in search_agent_ids
|
||||
|
||||
different_results = client.list_agents(query_text="client")
|
||||
assert len(different_results) == 1
|
||||
assert different_results[0].id == agent.id
|
||||
|
||||
# Test pagination
|
||||
first_page = client.list_agents(query_text="search agent", limit=1)
|
||||
assert len(first_page) == 1
|
||||
first_agent = first_page[0]
|
||||
|
||||
second_page = client.list_agents(query_text="search agent", cursor=first_agent.id, limit=1) # Use agent ID as cursor
|
||||
assert len(second_page) == 1
|
||||
assert second_page[0].id != first_agent.id
|
||||
|
||||
# Verify we got both search agents with no duplicates
|
||||
all_ids = {first_page[0].id, second_page[0].id}
|
||||
assert len(all_ids) == 2
|
||||
assert all_ids == {search_agent_one.id, search_agent_two.id}
|
||||
|
||||
# Test listing without any filters
|
||||
all_agents = client.list_agents()
|
||||
assert len(all_agents) == 3
|
||||
assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent])
|
||||
|
||||
|
||||
def test_agent_creation(client: Union[LocalClient, RESTClient]):
|
||||
"""Test that block IDs are properly attached when creating an agent."""
|
||||
if not isinstance(client, RESTClient):
|
||||
pytest.skip("This test only runs when the server is enabled")
|
||||
|
||||
offline_memory_agent_system = """
|
||||
You are a helpful agent. You will be provided with a list of memory blocks and a user preferences block.
|
||||
You should use the memory blocks to remember information about the user and their preferences.
|
||||
You should also use the user preferences block to remember information about the user's preferences.
|
||||
"""
|
||||
|
||||
from letta import BasicBlockMemory
|
||||
|
||||
# Create a test block that will represent user preferences
|
||||
@@ -623,3 +678,5 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]):
|
||||
assert len(agent_tools) == 2
|
||||
tool_ids = {tool1.id, tool2.id}
|
||||
assert all(tool.id in tool_ids for tool in agent_tools)
|
||||
|
||||
client.delete_agent(agent_id=agent.id)
|
||||
|
||||
@@ -914,6 +914,70 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul
|
||||
assert agent2.id in all_ids
|
||||
|
||||
|
||||
def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization):
|
||||
"""Test listing agents with query text filtering and pagination."""
|
||||
# Create test agents with specific names and descriptions
|
||||
agent1 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="Search Agent One",
|
||||
memory_blocks=[],
|
||||
description="This is a search agent for testing",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent2 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="Search Agent Two",
|
||||
memory_blocks=[],
|
||||
description="Another search agent for testing",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
agent3 = server.agent_manager.create_agent(
|
||||
agent_create=CreateAgent(
|
||||
name="Different Agent",
|
||||
memory_blocks=[],
|
||||
description="This is a different agent",
|
||||
llm_config=LLMConfig.default_config("gpt-4"),
|
||||
embedding_config=EmbeddingConfig.default_config(provider="openai"),
|
||||
),
|
||||
actor=default_user,
|
||||
)
|
||||
|
||||
# Test query text filtering
|
||||
search_results = server.agent_manager.list_agents(actor=default_user, query_text="search agent")
|
||||
assert len(search_results) == 2
|
||||
search_agent_ids = {agent.id for agent in search_results}
|
||||
assert agent1.id in search_agent_ids
|
||||
assert agent2.id in search_agent_ids
|
||||
assert agent3.id not in search_agent_ids
|
||||
|
||||
different_results = server.agent_manager.list_agents(actor=default_user, query_text="different agent")
|
||||
assert len(different_results) == 1
|
||||
assert different_results[0].id == agent3.id
|
||||
|
||||
# Test pagination with query text
|
||||
first_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", limit=1)
|
||||
assert len(first_page) == 1
|
||||
first_agent_id = first_page[0].id
|
||||
|
||||
# Get second page using cursor
|
||||
second_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", cursor=first_agent_id, limit=1)
|
||||
assert len(second_page) == 1
|
||||
assert second_page[0].id != first_agent_id
|
||||
|
||||
# Verify we got both search agents with no duplicates
|
||||
all_ids = {first_page[0].id, second_page[0].id}
|
||||
assert len(all_ids) == 2
|
||||
assert all_ids == {agent1.id, agent2.id}
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# AgentManager Tests - Messages Relationship
|
||||
# ======================================================================================================================
|
||||
|
||||
Reference in New Issue
Block a user