diff --git a/memgpt/client/client.py b/memgpt/client/client.py index a3fa5077..3e247e91 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -586,7 +586,10 @@ class RESTClient(AbstractClient): return Source(**response_json) def list_attached_sources(self, agent_id: str) -> List[Source]: - raise NotImplementedError + response = requests.get(f"{self.base_url}/api/agents/{agent_id}/sources", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to list attached sources: {response.text}") + return [Source(**source) for source in response.json()] def update_source(self, source_id: str, name: Optional[str] = None) -> Source: request = SourceUpdate(id=source_id, name=name) diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index 9491f0a1..bfbcfb81 100644 --- a/memgpt/server/rest_api/agents/index.py +++ b/memgpt/server/rest_api/agents/index.py @@ -4,6 +4,7 @@ from typing import List from fastapi import APIRouter, Body, Depends, HTTPException from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState +from memgpt.schemas.source import Source from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer @@ -91,4 +92,15 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") + @router.get("/agents/{agent_id}/sources", tags=["agents"], response_model=List[Source]) + def get_agent_sources( + agent_id: str, + user_id: str = Depends(get_current_user_with_server), + ): + """ + Get the sources associated with an agent. + """ + interface.clear() + return server.list_attached_sources(agent_id) + return router diff --git a/tests/test_client.py b/tests/test_client.py index fb168046..5da524f1 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -281,6 +281,11 @@ def test_sources(client, agent): # attach a source client.attach_source_to_agent(source_id=source.id, agent_id=agent.id) + # list attached sources + attached_sources = client.list_attached_sources(agent_id=agent.id) + print("attached sources", attached_sources) + assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}" + # list archival memory archival_memories = client.get_archival_memory(agent_id=agent.id) # print(archival_memories)