feat: Add list_agents_for_block endpoint (#759)
This commit is contained in:
@@ -3,6 +3,7 @@ from typing import TYPE_CHECKING, List, Optional
|
||||
from fastapi import APIRouter, Body, Depends, Header, HTTPException, Query
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
@@ -73,3 +74,21 @@ def retrieve_block(
|
||||
return block
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail="Block not found")
|
||||
|
||||
|
||||
@router.get("/{block_id}/agents", response_model=List[AgentState], operation_id="list_agents_for_block")
|
||||
def list_agents_for_block(
|
||||
block_id: str,
|
||||
server: SyncServer = Depends(get_letta_server),
|
||||
user_id: Optional[str] = Header(None, alias="user_id"),
|
||||
):
|
||||
"""
|
||||
Retrieves all agents associated with the specified block.
|
||||
Raises a 404 if the block does not exist.
|
||||
"""
|
||||
actor = server.user_manager.get_user_or_default(user_id=user_id)
|
||||
try:
|
||||
agents = server.block_manager.get_agents_for_block(block_id=block_id, actor=actor)
|
||||
return agents
|
||||
except NoResultFound:
|
||||
raise HTTPException(status_code=404, detail=f"Block with id={block_id} not found")
|
||||
|
||||
@@ -3,6 +3,7 @@ from typing import List, Optional
|
||||
|
||||
from letta.orm.block import Block as BlockModel
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.agent import AgentState as PydanticAgentState
|
||||
from letta.schemas.block import Block
|
||||
from letta.schemas.block import Block as PydanticBlock
|
||||
from letta.schemas.block import BlockUpdate, Human, Persona
|
||||
@@ -114,3 +115,15 @@ class BlockManager:
|
||||
text = open(human_file, "r", encoding="utf-8").read()
|
||||
name = os.path.basename(human_file).replace(".txt", "")
|
||||
self.create_or_update_block(Human(template_name=name, value=text, is_template=True), actor=actor)
|
||||
|
||||
@enforce_types
|
||||
def get_agents_for_block(self, block_id: str, actor: PydanticUser) -> List[PydanticAgentState]:
|
||||
"""
|
||||
Retrieve all agents associated with a given block.
|
||||
"""
|
||||
with self.session_maker() as session:
|
||||
block = BlockModel.read(db_session=session, identifier=block_id, actor=actor)
|
||||
agents_orm = block.agents
|
||||
agents_pydantic = [agent.to_pydantic() for agent in agents_orm]
|
||||
|
||||
return agents_pydantic
|
||||
|
||||
@@ -2104,6 +2104,26 @@ def test_delete_block_detaches_from_agent(server: SyncServer, sarah_agent, defau
|
||||
assert not (block.id in [b.id for b in agent_state.memory.blocks])
|
||||
|
||||
|
||||
def test_get_agents_for_block(server: SyncServer, sarah_agent, charles_agent, default_user):
|
||||
# Create and delete a block
|
||||
block = server.block_manager.create_or_update_block(PydanticBlock(label="alien", value="Sample content"), actor=default_user)
|
||||
sarah_agent = server.agent_manager.attach_block(agent_id=sarah_agent.id, block_id=block.id, actor=default_user)
|
||||
charles_agent = server.agent_manager.attach_block(agent_id=charles_agent.id, block_id=block.id, actor=default_user)
|
||||
|
||||
# Check that block has been attached to both
|
||||
assert block.id in [b.id for b in sarah_agent.memory.blocks]
|
||||
assert block.id in [b.id for b in charles_agent.memory.blocks]
|
||||
|
||||
# Get the agents for that block
|
||||
agent_states = server.block_manager.get_agents_for_block(block_id=block.id, actor=default_user)
|
||||
assert len(agent_states) == 2
|
||||
|
||||
# Check both agents are in the list
|
||||
agent_state_ids = [a.id for a in agent_states]
|
||||
assert sarah_agent.id in agent_state_ids
|
||||
assert charles_agent.id in agent_state_ids
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# SourceManager Tests - Sources
|
||||
# ======================================================================================================================
|
||||
|
||||
@@ -6,6 +6,7 @@ from composio.client.collections import ActionModel, ActionParametersModel, Acti
|
||||
from fastapi.testclient import TestClient
|
||||
|
||||
from letta.orm.errors import NoResultFound
|
||||
from letta.schemas.block import Block, BlockUpdate, CreateBlock
|
||||
from letta.schemas.message import UserMessage
|
||||
from letta.schemas.tool import ToolCreate, ToolUpdate
|
||||
from letta.server.rest_api.app import app
|
||||
@@ -461,3 +462,132 @@ def test_get_tags_with_search(client, mock_sync_server):
|
||||
mock_sync_server.agent_manager.list_tags.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value, after=None, limit=50, query_text="user"
|
||||
)
|
||||
|
||||
|
||||
# ======================================================================================================================
|
||||
# Blocks Routes Tests
|
||||
# ======================================================================================================================
|
||||
|
||||
|
||||
def test_list_blocks(client, mock_sync_server):
|
||||
"""
|
||||
Test the GET /v1/blocks endpoint to list blocks.
|
||||
"""
|
||||
# Arrange: mock return from block_manager
|
||||
mock_block = Block(label="human", value="Hi", is_template=True)
|
||||
mock_sync_server.block_manager.get_blocks.return_value = [mock_block]
|
||||
|
||||
# Act
|
||||
response = client.get("/v1/blocks", headers={"user_id": "test_user"})
|
||||
|
||||
# Assert
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 1
|
||||
assert data[0]["id"] == mock_block.id
|
||||
mock_sync_server.block_manager.get_blocks.assert_called_once_with(
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
label=None,
|
||||
is_template=True,
|
||||
template_name=None,
|
||||
)
|
||||
|
||||
|
||||
def test_create_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the POST /v1/blocks endpoint to create a block.
|
||||
"""
|
||||
new_block = CreateBlock(label="system", value="Some system text")
|
||||
returned_block = Block(**new_block.model_dump())
|
||||
|
||||
mock_sync_server.block_manager.create_or_update_block.return_value = returned_block
|
||||
|
||||
response = client.post("/v1/blocks", json=new_block.model_dump(), headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == returned_block.id
|
||||
|
||||
mock_sync_server.block_manager.create_or_update_block.assert_called_once()
|
||||
|
||||
|
||||
def test_modify_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the PATCH /v1/blocks/{block_id} endpoint to update a block.
|
||||
"""
|
||||
block_update = BlockUpdate(value="Updated text", description="New description")
|
||||
updated_block = Block(label="human", value="Updated text", description="New description")
|
||||
mock_sync_server.block_manager.update_block.return_value = updated_block
|
||||
|
||||
response = client.patch(f"/v1/blocks/{updated_block.id}", json=block_update.model_dump(), headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["value"] == "Updated text"
|
||||
assert data["description"] == "New description"
|
||||
|
||||
mock_sync_server.block_manager.update_block.assert_called_once_with(
|
||||
block_id=updated_block.id,
|
||||
block_update=block_update,
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
)
|
||||
|
||||
|
||||
def test_delete_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the DELETE /v1/blocks/{block_id} endpoint.
|
||||
"""
|
||||
deleted_block = Block(label="persona", value="Deleted text")
|
||||
mock_sync_server.block_manager.delete_block.return_value = deleted_block
|
||||
|
||||
response = client.delete(f"/v1/blocks/{deleted_block.id}", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == deleted_block.id
|
||||
|
||||
mock_sync_server.block_manager.delete_block.assert_called_once_with(
|
||||
block_id=deleted_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_retrieve_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the GET /v1/blocks/{block_id} endpoint.
|
||||
"""
|
||||
existing_block = Block(label="human", value="Hello")
|
||||
mock_sync_server.block_manager.get_block_by_id.return_value = existing_block
|
||||
|
||||
response = client.get(f"/v1/blocks/{existing_block.id}", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert data["id"] == existing_block.id
|
||||
|
||||
mock_sync_server.block_manager.get_block_by_id.assert_called_once_with(
|
||||
block_id=existing_block.id, actor=mock_sync_server.user_manager.get_user_or_default.return_value
|
||||
)
|
||||
|
||||
|
||||
def test_retrieve_block_404(client, mock_sync_server):
|
||||
"""
|
||||
Test that retrieving a non-existent block returns 404.
|
||||
"""
|
||||
mock_sync_server.block_manager.get_block_by_id.return_value = None
|
||||
|
||||
response = client.get("/v1/blocks/block-999", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 404
|
||||
assert "Block not found" in response.json()["detail"]
|
||||
|
||||
|
||||
def test_list_agents_for_block(client, mock_sync_server):
|
||||
"""
|
||||
Test the GET /v1/blocks/{block_id}/agents endpoint.
|
||||
"""
|
||||
mock_sync_server.block_manager.get_agents_for_block.return_value = []
|
||||
|
||||
response = client.get("/v1/blocks/block-abc/agents", headers={"user_id": "test_user"})
|
||||
assert response.status_code == 200
|
||||
data = response.json()
|
||||
assert len(data) == 0
|
||||
|
||||
mock_sync_server.block_manager.get_agents_for_block.assert_called_once_with(
|
||||
block_id="block-abc",
|
||||
actor=mock_sync_server.user_manager.get_user_or_default.return_value,
|
||||
)
|
||||
|
||||
Reference in New Issue
Block a user