From 7f16fd9a3e6e7145813c1cf67b6f5ad9c5e0ea8c Mon Sep 17 00:00:00 2001 From: Sarah Wooders Date: Tue, 27 Aug 2024 13:32:21 -0700 Subject: [PATCH] feat: add endpoint to list agent attached sources --- memgpt/client/client.py | 5 ++++- memgpt/server/rest_api/agents/index.py | 12 ++++++++++++ tests/test_client.py | 5 +++++ 3 files changed, 21 insertions(+), 1 deletion(-) diff --git a/memgpt/client/client.py b/memgpt/client/client.py index a3fa5077..3e247e91 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -586,7 +586,10 @@ class RESTClient(AbstractClient): return Source(**response_json) def list_attached_sources(self, agent_id: str) -> List[Source]: - raise NotImplementedError + response = requests.get(f"{self.base_url}/api/agents/{agent_id}/sources", headers=self.headers) + if response.status_code != 200: + raise ValueError(f"Failed to list attached sources: {response.text}") + return [Source(**source) for source in response.json()] def update_source(self, source_id: str, name: Optional[str] = None) -> Source: request = SourceUpdate(id=source_id, name=name) diff --git a/memgpt/server/rest_api/agents/index.py b/memgpt/server/rest_api/agents/index.py index 9491f0a1..bfbcfb81 100644 --- a/memgpt/server/rest_api/agents/index.py +++ b/memgpt/server/rest_api/agents/index.py @@ -4,6 +4,7 @@ from typing import List from fastapi import APIRouter, Body, Depends, HTTPException from memgpt.schemas.agent import AgentState, CreateAgent, UpdateAgentState +from memgpt.schemas.source import Source from memgpt.server.rest_api.auth_token import get_current_user from memgpt.server.rest_api.interface import QueuingInterface from memgpt.server.server import SyncServer @@ -91,4 +92,15 @@ def setup_agents_index_router(server: SyncServer, interface: QueuingInterface, p except Exception as e: raise HTTPException(status_code=500, detail=f"{e}") + @router.get("/agents/{agent_id}/sources", tags=["agents"], response_model=List[Source]) + def get_agent_sources( + agent_id: str, + user_id: str = Depends(get_current_user_with_server), + ): + """ + Get the sources associated with an agent. + """ + interface.clear() + return server.list_attached_sources(agent_id) + return router diff --git a/tests/test_client.py b/tests/test_client.py index f703087e..cfd882b7 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -281,6 +281,11 @@ def test_sources(client, agent): # attach a source client.attach_source_to_agent(source_id=source.id, agent_id=agent.id) + # list attached sources + attached_sources = client.list_attached_sources(agent_id=agent.id) + print("attached sources", attached_sources) + assert source.id in [s.id for s in attached_sources], f"Attached sources: {attached_sources}" + # list archival memory archival_memories = client.get_archival_memory(agent_id=agent.id) # print(archival_memories)