From 2cb12b821bdc50faef2437e08ed33e0bbd9e3fbb Mon Sep 17 00:00:00 2001 From: Matthew Zhou Date: Thu, 23 Jan 2025 15:28:08 -1000 Subject: [PATCH] feat: Add list_agents_for_block endpoint (#759) --- letta/server/rest_api/routers/v1/blocks.py | 19 +++ letta/services/block_manager.py | 13 +++ tests/test_managers.py | 20 ++++ tests/test_v1_routes.py | 130 +++++++++++++++++++++ 4 files changed, 182 insertions(+) diff --git a/letta/server/rest_api/routers/v1/blocks.py b/letta/server/rest_api/routers/v1/blocks.py index 2d261f39..8c5297d0 100644 --- a/letta/server/rest_api/routers/v1/blocks.py +++ b/letta/server/rest_api/routers/v1/blocks.py @@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query from letta.orm.errors import NoResultFound +from letta.schemas.agent import AgentState from letta.schemas.block import Block, BlockUpdate, CreateBlock from letta.server.rest_api.utils import get_letta_server from letta.server.server import SyncServer @@ -73,3 +74,21 @@ def retrieve_block( return block except NoResultFound: raise HTTPException(status_code=404, detail="Block not found") + + +@router.get("/{block_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_block") +def list_agents_for_block( + block_id: str, + server: SyncServer = Depends(get_letta_server), + user_id: Optional[str] = Header(None, alias="user_id"), +): + """ + Retrieves all agents associated with the specified block. + Raises a 404 if the block does not exist. + """ + actor = server.user_manager.get_user_or_default(user_id=user_id) + try: + agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor) + return agents + except NoResultFound: + raise HTTPException(status_code=404, detail=f"Block with id={block_id} not found") diff --git a/letta/services/block_manager.py b/letta/services/block_manager.py index 2d25093f..41275e1e 100644 --- a/letta/services/block_manager.py +++ b/letta/services/block_manager.py @@ -3,6 +3,7 @@ from typing import List, Optional from letta.orm.block import Block as BlockModel from letta.orm.errors import NoResultFound +from letta.schemas.agent import AgentState as PydanticAgentState from letta.schemas.block import Block from letta.schemas.block import Block as PydanticBlock from letta.schemas.block import BlockUpdate, Human, Persona @@ -114,3 +115,15 @@ class BlockManager: text = open(human_file, "r", encoding="utf-8").read() name = os.path.basename(human_file).replace(".txt", "") self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor) + + @enforce_types + def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]: + """ + Retrieve all agents associated with a given block. + """ + with self.session_maker() as session: + block = BlockModel.read(db_session=session, identifier=block_id, actor=actor) + agents_orm = block.agents + agents_pydantic = [agent.to_pydantic() for agent in agents_orm] + + return agents_pydantic diff --git a/tests/test_managers.py b/tests/test_managers.py index fe6280d5..3b43f0f7 100644 --- a/tests/test_managers.py +++ b/tests/test_managers.py @@ -2104,6 +2104,26 @@ def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, defau assert not (block.id in [b.id for b in agent_state.memory.blocks]) +def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user): + # Create and delete a block + block = server.block_manager.create_or_update_block(PydanticBlock(label="alien", value="Sample content"), actor=default_user) + sarah_agent = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user) + charles_agent = server.agent_manager.attach_block(agent_id=charles_agent.id, block_id=block.id, actor=default_user) + + # Check that block has been attached to both + assert block.id in [b.id for b in sarah_agent.memory.blocks] + assert block.id in [b.id for b in charles_agent.memory.blocks] + + # Get the agents for that block + agent_states = server.block_manager.get_agents_for_block(block_id=block.id, actor=default_user) + assert len(agent_states) == 2 + + # Check both agents are in the list + agent_state_ids = [a.id for a in agent_states] + assert sarah_agent.id in agent_state_ids + assert charles_agent.id in agent_state_ids + + # ====================================================================================================================== # SourceManager Tests - Sources # ====================================================================================================================== diff --git a/tests/test_v1_routes.py b/tests/test_v1_routes.py index f8ddde71..647f4517 100644 --- a/tests/test_v1_routes.py +++ b/tests/test_v1_routes.py @@ -6,6 +6,7 @@ from composio.client.collections import ActionModel, ActionParametersModel, Acti from fastapi.testclient import TestClient from letta.orm.errors import NoResultFound +from letta.schemas.block import Block, BlockUpdate, CreateBlock from letta.schemas.message import UserMessage from letta.schemas.tool import ToolCreate, ToolUpdate from letta.server.rest_api.app import app @@ -461,3 +462,132 @@ def test_get_tags_with_search(client, mock_sync_server): mock_sync_server.agent_manager.list_tags.assert_called_once_with( actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text="user" ) + + +# ====================================================================================================================== +# Blocks Routes Tests +# ====================================================================================================================== + + +def test_list_blocks(client, mock_sync_server): + """ + Test the GET /v1/blocks endpoint to list blocks. + """ + # Arrange: mock return from block_manager + mock_block = Block(label="human", value="Hi", is_template=True) + mock_sync_server.block_manager.get_blocks.return_value = [mock_block] + + # Act + response = client.get("/v1/blocks", headers={"user_id": "test_user"}) + + # Assert + assert response.status_code == 200 + data = response.json() + assert len(data) == 1 + assert data[0]["id"] == mock_block.id + mock_sync_server.block_manager.get_blocks.assert_called_once_with( + actor=mock_sync_server.user_manager.get_user_or_default.return_value, + label=None, + is_template=True, + template_name=None, + ) + + +def test_create_block(client, mock_sync_server): + """ + Test the POST /v1/blocks endpoint to create a block. + """ + new_block = CreateBlock(label="system", value="Some system text") + returned_block = Block(**new_block.model_dump()) + + mock_sync_server.block_manager.create_or_update_block.return_value = returned_block + + response = client.post("/v1/blocks", json=new_block.model_dump(), headers={"user_id": "test_user"}) + assert response.status_code == 200 + data = response.json() + assert data["id"] == returned_block.id + + mock_sync_server.block_manager.create_or_update_block.assert_called_once() + + +def test_modify_block(client, mock_sync_server): + """ + Test the PATCH /v1/blocks/{block_id} endpoint to update a block. + """ + block_update = BlockUpdate(value="Updated text", description="New description") + updated_block = Block(label="human", value="Updated text", description="New description") + mock_sync_server.block_manager.update_block.return_value = updated_block + + response = client.patch(f"/v1/blocks/{updated_block.id}", json=block_update.model_dump(), headers={"user_id": "test_user"}) + assert response.status_code == 200 + data = response.json() + assert data["value"] == "Updated text" + assert data["description"] == "New description" + + mock_sync_server.block_manager.update_block.assert_called_once_with( + block_id=updated_block.id, + block_update=block_update, + actor=mock_sync_server.user_manager.get_user_or_default.return_value, + ) + + +def test_delete_block(client, mock_sync_server): + """ + Test the DELETE /v1/blocks/{block_id} endpoint. + """ + deleted_block = Block(label="persona", value="Deleted text") + mock_sync_server.block_manager.delete_block.return_value = deleted_block + + response = client.delete(f"/v1/blocks/{deleted_block.id}", headers={"user_id": "test_user"}) + assert response.status_code == 200 + data = response.json() + assert data["id"] == deleted_block.id + + mock_sync_server.block_manager.delete_block.assert_called_once_with( + block_id=deleted_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value + ) + + +def test_retrieve_block(client, mock_sync_server): + """ + Test the GET /v1/blocks/{block_id} endpoint. + """ + existing_block = Block(label="human", value="Hello") + mock_sync_server.block_manager.get_block_by_id.return_value = existing_block + + response = client.get(f"/v1/blocks/{existing_block.id}", headers={"user_id": "test_user"}) + assert response.status_code == 200 + data = response.json() + assert data["id"] == existing_block.id + + mock_sync_server.block_manager.get_block_by_id.assert_called_once_with( + block_id=existing_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value + ) + + +def test_retrieve_block_404(client, mock_sync_server): + """ + Test that retrieving a non-existent block returns 404. + """ + mock_sync_server.block_manager.get_block_by_id.return_value = None + + response = client.get("/v1/blocks/block-999", headers={"user_id": "test_user"}) + assert response.status_code == 404 + assert "Block not found" in response.json()["detail"] + + +def test_list_agents_for_block(client, mock_sync_server): + """ + Test the GET /v1/blocks/{block_id}/agents endpoint. + """ + mock_sync_server.block_manager.get_agents_for_block.return_value = [] + + response = client.get("/v1/blocks/block-abc/agents", headers={"user_id": "test_user"}) + assert response.status_code == 200 + data = response.json() + assert len(data) == 0 + + mock_sync_server.block_manager.get_agents_for_block.assert_called_once_with( + block_id="block-abc", + actor=mock_sync_server.user_manager.get_user_or_default.return_value, + )