diff --git a/memgpt/client/client.py b/memgpt/client/client.py index c238a021..ea48b0ff 100644 --- a/memgpt/client/client.py +++ b/memgpt/client/client.py @@ -5,7 +5,7 @@ import uuid from typing import Dict, List, Union, Optional, Tuple from memgpt.data_types import AgentState, User, Preset, LLMConfig, EmbeddingConfig, Source -from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel +from memgpt.models.pydantic_models import HumanModel, PersonaModel, PresetModel, SourceModel from memgpt.cli.cli import QuickstartChoice from memgpt.cli.cli import set_config_with_dict, quickstart as quickstart_func, str_to_quickstart_choice from memgpt.config import MemGPTConfig @@ -31,6 +31,7 @@ from memgpt.server.rest_api.personas.index import ListPersonasResponse from memgpt.server.rest_api.tools.index import ListToolsResponse, CreateToolResponse from memgpt.server.rest_api.models.index import ListModelsResponse from memgpt.server.rest_api.presets.index import CreatePresetResponse, CreatePresetsRequest, ListPresetsResponse +from memgpt.server.rest_api.sources.index import ListSourcesResponse, UploadFileToSourceResponse def create_client(base_url: Optional[str] = None, token: Optional[str] = None): @@ -437,7 +438,7 @@ class RESTClient(AbstractClient): """List loaded sources""" response = requests.get(f"{self.base_url}/api/sources", headers=self.headers) response_json = response.json() - return response_json + return ListSourcesResponse(**response_json) def delete_source(self, source_id: uuid.UUID): """Delete a source and associated data (including attached to agents)""" @@ -448,7 +449,7 @@ class RESTClient(AbstractClient): """Load {filename} and insert into source""" files = {"file": open(filename, "rb")} response = requests.post(f"{self.base_url}/api/sources/{source_id}/upload", files=files, headers=self.headers) - return response.json() + return UploadFileToSourceResponse(**response.json()) def create_source(self, name: str) -> Source: """Create a new source""" @@ -456,13 +457,14 @@ class RESTClient(AbstractClient): response = requests.post(f"{self.base_url}/api/sources", json=payload, headers=self.headers) response_json = response.json() print("CREATE SOURCE", response_json, response.text) + response_obj = SourceModel(**response_json) return Source( - id=uuid.UUID(response_json["id"]), - name=response_json["name"], - user_id=uuid.UUID(response_json["user_id"]), - created_at=datetime.datetime.fromtimestamp(response_json["created_at"]), - embedding_dim=response_json["embedding_config"]["embedding_dim"], - embedding_model=response_json["embedding_config"]["embedding_model"], + id=uuid.UUID(response_obj.id), + name=response_obj.name, + user_id=uuid.UUID(response_obj.user_id), + created_at=response_obj.created_at, + embedding_dim=response_obj.embedding_config["embedding_dim"], + embedding_model=response_obj.embedding_config["embedding_model"], ) def attach_source_to_agent(self, source_id: uuid.UUID, agent_id: uuid.UUID): @@ -470,14 +472,12 @@ class RESTClient(AbstractClient): params = {"agent_id": agent_id} response = requests.post(f"{self.base_url}/api/sources/{source_id}/attach", params=params, headers=self.headers) assert response.status_code == 200, f"Failed to attach source to agent: {response.text}" - return response.json() def detach_source(self, source_id: uuid.UUID, agent_id: uuid.UUID): """Detach a source from an agent""" params = {"agent_id": str(agent_id)} response = requests.post(f"{self.base_url}/api/sources/{source_id}/detach", params=params, headers=self.headers) assert response.status_code == 200, f"Failed to detach source from agent: {response.text}" - return response.json() # server configuration commands diff --git a/memgpt/server/rest_api/sources/index.py b/memgpt/server/rest_api/sources/index.py index 5c7ae9de..c5ce8126 100644 --- a/memgpt/server/rest_api/sources/index.py +++ b/memgpt/server/rest_api/sources/index.py @@ -39,10 +39,6 @@ class CreateSourceRequest(BaseModel): description: Optional[str] = Field(None, description="The description of the source.") -class CreateSourceResponse(BaseModel): - source: SourceModel = Field(..., description="The newly created source.") - - class UploadFileToSourceRequest(BaseModel): file: UploadFile = Field(..., description="The file to upload.") @@ -128,7 +124,8 @@ def setup_sources_index_router(server: SyncServer, interface: QueuingInterface, interface.clear() assert isinstance(agent_id, uuid.UUID), f"Expected agent_id to be a UUID, got {agent_id}" assert isinstance(user_id, uuid.UUID), f"Expected user_id to be a UUID, got {user_id}" - source = server.attach_source_to_agent(source_id=source_id, agent_id=agent_id, user_id=user_id) + source = server.ms.get_source(source_id=source_id, user_id=user_id) + source = server.attach_source_to_agent(source_name=source.name, agent_id=agent_id, user_id=user_id) return SourceModel( name=source.name, description=None, # TODO: actually store descriptions diff --git a/memgpt/server/server.py b/memgpt/server/server.py index 5e17dd66..023a9bd4 100644 --- a/memgpt/server/server.py +++ b/memgpt/server/server.py @@ -1360,8 +1360,21 @@ class SyncServer(LockingServer): sources_with_metadata = [] for source in sources: - passages = self.list_data_source_passages(user_id=user_id, source_id=source.id) - documents = self.list_data_source_documents(user_id=user_id, source_id=source.id) + # count number of passages + passage_conn = StorageConnector.get_storage_connector(TableType.PASSAGES, self.config, user_id=user_id) + num_passages = passage_conn.size({"data_source": source.name}) + print(passage_conn.get_all()) + print( + "NUMBER PASSAGES", + num_passages, + ) + + # TODO: add when documents table implemented + ## count number of documents + # document_conn = StorageConnector.get_storage_connector(TableType.DOCUMENTS, self.config, user_id=user_id) + # num_documents = document_conn.size({"data_source": source.name}) + num_documents = 0 + agent_ids = self.ms.list_attached_agents(source_id=source.id) # add the agent name information attached_agents = [ @@ -1374,8 +1387,8 @@ class SyncServer(LockingServer): # Overwrite metadata field, should be empty anyways source.metadata_ = dict( - num_documents=len(passages), - num_passages=len(documents), + num_documents=num_documents, + num_passages=num_passages, attached_agents=attached_agents, ) diff --git a/tests/test_client.py b/tests/test_client.py index 25f7d109..40c3d772 100644 --- a/tests/test_client.py +++ b/tests/test_client.py @@ -253,6 +253,7 @@ def test_sources(client, agent): # list sources sources = client.list_sources() print("listed sources", sources) + assert len(sources.sources) == 0 # create a source source = client.create_source(name="test_source") @@ -260,7 +261,9 @@ def test_sources(client, agent): # list sources sources = client.list_sources() print("listed sources", sources) - assert len(sources) == 1 + assert len(sources.sources) == 1 + assert sources.sources[0].metadata_["num_passages"] == 0 + assert sources.sources[0].metadata_["num_documents"] == 0 # check agent archival memory size archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory @@ -269,18 +272,25 @@ def test_sources(client, agent): # load a file into a source filename = "CONTRIBUTING.md" - num_passages = 20 response = client.load_file_into_source(filename=filename, source_id=source.id) - print(response) + + # TODO: make sure things run in the right order + archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory + assert len(archival_memories) == 0 # attach a source - # TODO: make sure things run in the right order client.attach_source_to_agent(source_id=source.id, agent_id=agent.id) # list archival memory archival_memories = client.get_agent_archival_memory(agent_id=agent.id).archival_memory - print(archival_memories) - assert len(archival_memories) == num_passages + # print(archival_memories) + assert len(archival_memories) == 20 or len(archival_memories) == 21 + + # check number of passages + sources = client.list_sources() + assert sources.sources[0].metadata_["num_passages"] > 0 + assert sources.sources[0].metadata_["num_documents"] == 0 # TODO: fix this once document store added + print(sources) # detach the source # TODO: add when implemented