feat: add endpoint to list agent attached sources

This commit is contained in:
Sarah Wooders
2024-08-27 13:32:21 -07:00
parent ce89defec7
commit 7f16fd9a3e
3 changed files with 21 additions and 1 deletions

View File

@@ -586,7 +586,10 @@ class RESTClient(AbstractClient):
return Source(**response_json)
def list_attached_sources(self, agent_id: str) -> List[Source]:
raise NotImplementedError
response = requests.get(f"{self.base_url}/api/agents/{agent_id}/sources", headers=self.headers)
if response.status_code != 200:
raise ValueError(f"Failed to list attached sources: {response.text}")
return [Source(**source) for source in response.json()]
def update_source(self, source_id: str, name: Optional[str] = None) -> Source:
request = SourceUpdate(id=source_id, name=name)

View File

@@ -4,6 +4,7 @@ from typing import List
from fastapi import APIRouter, Body, Depends, HTTPException
from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState
from memgpt.schemas.source import Source
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
@@ -91,4 +92,15 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p
except Exception as e:
raise HTTPException(status_code=500, detail=f"{e}")
@router.get("/agents/{agent_id}/sources", tags=["agents"], response_model=List[Source])
def get_agent_sources(
agent_id: str,
user_id: str = Depends(get_current_user_with_server),
):
"""
Get the sources associated with an agent.
"""
interface.clear()
return server.list_attached_sources(agent_id)
return router

View File

@@ -281,6 +281,11 @@ def test_sources(client, agent):
# attach a source
client.attach_source_to_agent(source_id=source.id, agent_id=agent.id)
# list attached sources
attached_sources = client.list_attached_sources(agent_id=agent.id)
print("attached sources", attached_sources)
assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}"
# list archival memory
archival_memories = client.get_archival_memory(agent_id=agent.id)
# print(archival_memories)