feat: Add list_agents_for_block endpoint (#759)

This commit is contained in:
Matthew Zhou
2025-01-23 15:28:08 -10:00
committed by GitHub
parent 2233535335
commit 2cb12b821b
4 changed files with 182 additions and 0 deletions

View File

@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState
from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.server.rest_api.utils import get_letta_server
from letta.server.server import SyncServer
@@ -73,3 +74,21 @@ def retrieve_block(
return block
except NoResultFound:
raise HTTPException(status_code=404, detail="Block not found")
@router.get("/{block_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_block")
def list_agents_for_block(
block_id: str,
server: SyncServer = Depends(get_letta_server),
user_id: Optional[str] = Header(None, alias="user_id"),
):
"""
Retrieves all agents associated with the specified block.
Raises a 404 if the block does not exist.
"""
actor = server.user_manager.get_user_or_default(user_id=user_id)
try:
agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor)
return agents
except NoResultFound:
raise HTTPException(status_code=404, detail=f"Block with id={block_id} not found")

View File

@@ -3,6 +3,7 @@ from typing import List, Optional
from letta.orm.block import Block as BlockModel
from letta.orm.errors import NoResultFound
from letta.schemas.agent import AgentState as PydanticAgentState
from letta.schemas.block import Block
from letta.schemas.block import Block as PydanticBlock
from letta.schemas.block import BlockUpdate, Human, Persona
@@ -114,3 +115,15 @@ class BlockManager:
text = open(human_file, "r", encoding="utf-8").read()
name = os.path.basename(human_file).replace(".txt", "")
self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor)
@enforce_types
def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
"""
Retrieve all agents associated with a given block.
"""
with self.session_maker() as session:
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
agents_orm = block.agents
agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
return agents_pydantic

View File

@@ -2104,6 +2104,26 @@ def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, defau
assert not (block.id in [b.id for b in agent_state.memory.blocks])
def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user):
# Create and delete a block
block = server.block_manager.create_or_update_block(PydanticBlock(label="alien", value="Sample content"), actor=default_user)
sarah_agent = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user)
charles_agent = server.agent_manager.attach_block(agent_id=charles_agent.id, block_id=block.id, actor=default_user)
# Check that block has been attached to both
assert block.id in [b.id for b in sarah_agent.memory.blocks]
assert block.id in [b.id for b in charles_agent.memory.blocks]
# Get the agents for that block
agent_states = server.block_manager.get_agents_for_block(block_id=block.id, actor=default_user)
assert len(agent_states) == 2
# Check both agents are in the list
agent_state_ids = [a.id for a in agent_states]
assert sarah_agent.id in agent_state_ids
assert charles_agent.id in agent_state_ids
# ======================================================================================================================
# SourceManager Tests - Sources
# ======================================================================================================================

View File

@@ -6,6 +6,7 @@ from composio.client.collections import ActionModel, ActionParametersModel, Acti
from fastapi.testclient import TestClient
from letta.orm.errors import NoResultFound
from letta.schemas.block import Block, BlockUpdate, CreateBlock
from letta.schemas.message import UserMessage
from letta.schemas.tool import ToolCreate, ToolUpdate
from letta.server.rest_api.app import app
@@ -461,3 +462,132 @@ def test_get_tags_with_search(client, mock_sync_server):
mock_sync_server.agent_manager.list_tags.assert_called_once_with(
actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text="user"
)
# ======================================================================================================================
# Blocks Routes Tests
# ======================================================================================================================
def test_list_blocks(client, mock_sync_server):
"""
Test the GET /v1/blocks endpoint to list blocks.
"""
# Arrange: mock return from block_manager
mock_block = Block(label="human", value="Hi", is_template=True)
mock_sync_server.block_manager.get_blocks.return_value = [mock_block]
# Act
response = client.get("/v1/blocks", headers={"user_id": "test_user"})
# Assert
assert response.status_code == 200
data = response.json()
assert len(data) == 1
assert data[0]["id"] == mock_block.id
mock_sync_server.block_manager.get_blocks.assert_called_once_with(
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
label=None,
is_template=True,
template_name=None,
)
def test_create_block(client, mock_sync_server):
"""
Test the POST /v1/blocks endpoint to create a block.
"""
new_block = CreateBlock(label="system", value="Some system text")
returned_block = Block(**new_block.model_dump())
mock_sync_server.block_manager.create_or_update_block.return_value = returned_block
response = client.post("/v1/blocks", json=new_block.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
data = response.json()
assert data["id"] == returned_block.id
mock_sync_server.block_manager.create_or_update_block.assert_called_once()
def test_modify_block(client, mock_sync_server):
"""
Test the PATCH /v1/blocks/{block_id} endpoint to update a block.
"""
block_update = BlockUpdate(value="Updated text", description="New description")
updated_block = Block(label="human", value="Updated text", description="New description")
mock_sync_server.block_manager.update_block.return_value = updated_block
response = client.patch(f"/v1/blocks/{updated_block.id}", json=block_update.model_dump(), headers={"user_id": "test_user"})
assert response.status_code == 200
data = response.json()
assert data["value"] == "Updated text"
assert data["description"] == "New description"
mock_sync_server.block_manager.update_block.assert_called_once_with(
block_id=updated_block.id,
block_update=block_update,
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
)
def test_delete_block(client, mock_sync_server):
"""
Test the DELETE /v1/blocks/{block_id} endpoint.
"""
deleted_block = Block(label="persona", value="Deleted text")
mock_sync_server.block_manager.delete_block.return_value = deleted_block
response = client.delete(f"/v1/blocks/{deleted_block.id}", headers={"user_id": "test_user"})
assert response.status_code == 200
data = response.json()
assert data["id"] == deleted_block.id
mock_sync_server.block_manager.delete_block.assert_called_once_with(
block_id=deleted_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
def test_retrieve_block(client, mock_sync_server):
"""
Test the GET /v1/blocks/{block_id} endpoint.
"""
existing_block = Block(label="human", value="Hello")
mock_sync_server.block_manager.get_block_by_id.return_value = existing_block
response = client.get(f"/v1/blocks/{existing_block.id}", headers={"user_id": "test_user"})
assert response.status_code == 200
data = response.json()
assert data["id"] == existing_block.id
mock_sync_server.block_manager.get_block_by_id.assert_called_once_with(
block_id=existing_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
)
def test_retrieve_block_404(client, mock_sync_server):
"""
Test that retrieving a non-existent block returns 404.
"""
mock_sync_server.block_manager.get_block_by_id.return_value = None
response = client.get("/v1/blocks/block-999", headers={"user_id": "test_user"})
assert response.status_code == 404
assert "Block not found" in response.json()["detail"]
def test_list_agents_for_block(client, mock_sync_server):
"""
Test the GET /v1/blocks/{block_id}/agents endpoint.
"""
mock_sync_server.block_manager.get_agents_for_block.return_value = []
response = client.get("/v1/blocks/block-abc/agents", headers={"user_id": "test_user"})
assert response.status_code == 200
data = response.json()
assert len(data) == 0
mock_sync_server.block_manager.get_agents_for_block.assert_called_once_with(
block_id="block-abc",
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
)