diff --git a/letta/client/client.py b/letta/client/client.py index 8e9d24bb..ec62bcf9 100644 --- a/letta/client/client.py +++ b/letta/client/client.py @@ -436,13 +436,22 @@ class RESTClient(AbstractClient): self._default_llm_config = default_llm_config self._default_embedding_config = default_embedding_config - def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]: - params = {} + def list_agents( + self, tags: Optional[List[str]] = None, query_text: Optional[str] = None, limit: int = 50, cursor: Optional[str] = None + ) -> List[AgentState]: + params = {"limit": limit} if tags: params["tags"] = tags params["match_all_tags"] = False + if query_text: + params["query_text"] = query_text + + if cursor: + params["cursor"] = cursor + response = requests.get(f"{self.base_url}/{self.api_prefix}/agents", headers=self.headers, params=params) + print(f"\nLIST RESPONSE\n{response.json()}\n") return [AgentState(**agent) for agent in response.json()] def agent_exists(self, agent_id: str) -> bool: @@ -2210,10 +2219,12 @@ class LocalClient(AbstractClient): self.organization = self.server.get_organization_or_default(self.org_id) # agents - def list_agents(self, tags: Optional[List[str]] = None) -> List[AgentState]: + def list_agents( + self, query_text: Optional[str] = None, tags: Optional[List[str]] = None, limit: int = 100, cursor: Optional[str] = None + ) -> List[AgentState]: self.interface.clear() - return self.server.agent_manager.list_agents(actor=self.user, tags=tags) + return self.server.agent_manager.list_agents(actor=self.user, tags=tags, query_text=query_text, limit=limit, cursor=cursor) def agent_exists(self, agent_id: Optional[str] = None, agent_name: Optional[str] = None) -> bool: """ diff --git a/letta/orm/sqlalchemy_base.py b/letta/orm/sqlalchemy_base.py index 9dfee6d3..05ada679 100644 --- a/letta/orm/sqlalchemy_base.py +++ b/letta/orm/sqlalchemy_base.py @@ -163,7 +163,11 @@ class SqlalchemyBase(CommonSqlalchemyMetaMixins, Base): # Text search if query_text: - query = query.filter(func.lower(cls.text).contains(func.lower(query_text))) + if hasattr(cls, "text"): + query = query.filter(func.lower(cls.text).contains(func.lower(query_text))) + elif hasattr(cls, "name"): + # Special case for Agent model - search across name + query = query.filter(func.lower(cls.name).contains(func.lower(query_text))) # Embedding search (for Passages) is_ordered = False diff --git a/letta/server/rest_api/routers/v1/agents.py b/letta/server/rest_api/routers/v1/agents.py index 9a7d6dae..558f59fc 100644 --- a/letta/server/rest_api/routers/v1/agents.py +++ b/letta/server/rest_api/routers/v1/agents.py @@ -47,9 +47,9 @@ def list_agents( ), server: "SyncServer" = Depends(get_letta_server), user_id: Optional[str] = Header(None, alias="user_id"), - cursor: Optional[int] = Query(None, description="Cursor for pagination"), + cursor: Optional[str] = Query(None, description="Cursor for pagination"), limit: Optional[int] = Query(None, description="Limit for pagination"), - # Extract user_id from header, default to None if not present + query_text: Optional[str] = Query(None, description="Search agents by name"), ): """ List all agents associated with a given user. @@ -64,6 +64,7 @@ def list_agents( "tags": tags, "match_all_tags": match_all_tags, "name": name, + "query_text": query_text, }.items() if value is not None } diff --git a/letta/services/agent_manager.py b/letta/services/agent_manager.py index a9c67d43..3ba29ba9 100644 --- a/letta/services/agent_manager.py +++ b/letta/services/agent_manager.py @@ -268,6 +268,7 @@ class AgentManager: match_all_tags: bool = False, cursor: Optional[str] = None, limit: Optional[int] = 50, + query_text: Optional[str] = None, **kwargs, ) -> List[PydanticAgentState]: """ @@ -281,6 +282,7 @@ class AgentManager: cursor=cursor, limit=limit, organization_id=actor.organization_id if actor else None, + query_text=query_text, **kwargs, ) diff --git a/tests/test_client.py b/tests/test_client.py index 00c7d495..8aeef23a 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -80,6 +80,28 @@ def agent(client: Union[LocalClient, RESTClient]): client.delete_agent(agent_state.id) +# Fixture for test agent +@pytest.fixture +def search_agent_one(client: Union[LocalClient, RESTClient]): + agent_state = client.create_agent(name="Search Agent One") + + yield agent_state + + # delete agent + client.delete_agent(agent_state.id) + + +# Fixture for test agent +@pytest.fixture +def search_agent_two(client: Union[LocalClient, RESTClient]): + agent_state = client.create_agent(name="Search Agent Two") + + yield agent_state + + # delete agent + client.delete_agent(agent_state.id) + + @pytest.fixture(autouse=True) def clear_tables(): """Clear the sandbox tables before each test.""" @@ -560,17 +582,50 @@ def test_send_message_async(client: Union[LocalClient, RESTClient], agent: Agent assert usage.total_tokens == usage.completion_tokens + usage.prompt_tokens +# ========================================== +# TESTS FOR AGENT LISTING +# ========================================== + + +def test_agent_listing(client: Union[LocalClient, RESTClient], agent, search_agent_one, search_agent_two): + """Test listing agents with pagination and query text filtering.""" + # Test query text filtering + search_results = client.list_agents(query_text="search agent") + assert len(search_results) == 2 + search_agent_ids = {agent.id for agent in search_results} + assert search_agent_one.id in search_agent_ids + assert search_agent_two.id in search_agent_ids + assert agent.id not in search_agent_ids + + different_results = client.list_agents(query_text="client") + assert len(different_results) == 1 + assert different_results[0].id == agent.id + + # Test pagination + first_page = client.list_agents(query_text="search agent", limit=1) + assert len(first_page) == 1 + first_agent = first_page[0] + + second_page = client.list_agents(query_text="search agent", cursor=first_agent.id, limit=1) # Use agent ID as cursor + assert len(second_page) == 1 + assert second_page[0].id != first_agent.id + + # Verify we got both search agents with no duplicates + all_ids = {first_page[0].id, second_page[0].id} + assert len(all_ids) == 2 + assert all_ids == {search_agent_one.id, search_agent_two.id} + + # Test listing without any filters + all_agents = client.list_agents() + assert len(all_agents) == 3 + assert all(agent.id in {a.id for a in all_agents} for agent in [search_agent_one, search_agent_two, agent]) + + def test_agent_creation(client: Union[LocalClient, RESTClient]): """Test that block IDs are properly attached when creating an agent.""" if not isinstance(client, RESTClient): pytest.skip("This test only runs when the server is enabled") - offline_memory_agent_system = """ - You are a helpful agent. You will be provided with a list of memory blocks and a user preferences block. - You should use the memory blocks to remember information about the user and their preferences. - You should also use the user preferences block to remember information about the user's preferences. - """ - from letta import BasicBlockMemory # Create a test block that will represent user preferences @@ -623,3 +678,5 @@ def test_agent_creation(client: Union[LocalClient, RESTClient]): assert len(agent_tools) == 2 tool_ids = {tool1.id, tool2.id} assert all(tool.id in tool_ids for tool in agent_tools) + + client.delete_agent(agent_id=agent.id) diff --git a/tests/test_managers.py b/tests/test_managers.py index 734db4cb..08e2b1e3 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -914,6 +914,70 @@ def test_list_agents_by_tags_pagination(server: SyncServer, default_user, defaul assert agent2.id in all_ids +def test_list_agents_query_text_pagination(server: SyncServer, default_user, default_organization): + """Test listing agents with query text filtering and pagination.""" + # Create test agents with specific names and descriptions + agent1 = server.agent_manager.create_agent( + agent_create=CreateAgent( + name="Search Agent One", + memory_blocks=[], + description="This is a search agent for testing", + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + + agent2 = server.agent_manager.create_agent( + agent_create=CreateAgent( + name="Search Agent Two", + memory_blocks=[], + description="Another search agent for testing", + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + + agent3 = server.agent_manager.create_agent( + agent_create=CreateAgent( + name="Different Agent", + memory_blocks=[], + description="This is a different agent", + llm_config=LLMConfig.default_config("gpt-4"), + embedding_config=EmbeddingConfig.default_config(provider="openai"), + ), + actor=default_user, + ) + + # Test query text filtering + search_results = server.agent_manager.list_agents(actor=default_user, query_text="search agent") + assert len(search_results) == 2 + search_agent_ids = {agent.id for agent in search_results} + assert agent1.id in search_agent_ids + assert agent2.id in search_agent_ids + assert agent3.id not in search_agent_ids + + different_results = server.agent_manager.list_agents(actor=default_user, query_text="different agent") + assert len(different_results) == 1 + assert different_results[0].id == agent3.id + + # Test pagination with query text + first_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", limit=1) + assert len(first_page) == 1 + first_agent_id = first_page[0].id + + # Get second page using cursor + second_page = server.agent_manager.list_agents(actor=default_user, query_text="search agent", cursor=first_agent_id, limit=1) + assert len(second_page) == 1 + assert second_page[0].id != first_agent_id + + # Verify we got both search agents with no duplicates + all_ids = {first_page[0].id, second_page[0].id} + assert len(all_ids) == 2 + assert all_ids == {agent1.id, agent2.id} + + # ====================================================================================================================== # AgentManager Tests - Messages Relationship # ======================================================================================================================