diff --git a/memgpt/client/client.py b/memgpt/client/client.py
index c7472622..f788374f 100644
--- a/memgpt/client/client.py
+++ b/memgpt/client/client.py
@@ -562,10 +562,27 @@ class RESTClient(AbstractClient):
         response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
         assert response.status_code == 200, f"Failed to delete source: {response.text}"
 
-    def get_job_status(self, job_id: str):
-        response = requests.get(f"{self.base_url}/api/sources/status/{str(job_id)}", headers=self.headers)
+    def get_job(self, job_id: str) -> Job:
+        """Fetch a single job by ID; raises ValueError on a non-200 response."""
+        response = requests.get(f"{self.base_url}/api/jobs/{job_id}", headers=self.headers)
+        if response.status_code != 200:
+            raise ValueError(f"Failed to get job: {response.text}")
         return Job(**response.json())
 
+    def list_jobs(self):
+        """List the caller's jobs; raises ValueError on a non-200 response."""
+        response = requests.get(f"{self.base_url}/api/jobs", headers=self.headers)
+        if response.status_code != 200:
+            raise ValueError(f"Failed to list jobs: {response.text}")
+        return [Job(**job) for job in response.json()]
+
+    def list_active_jobs(self):
+        """List the caller's active jobs; raises ValueError on a non-200 response."""
+        response = requests.get(f"{self.base_url}/api/jobs/active", headers=self.headers)
+        if response.status_code != 200:
+            raise ValueError(f"Failed to list active jobs: {response.text}")
+        return [Job(**job) for job in response.json()]
+
     def load_file_into_source(self, filename: str, source_id: str, blocking=True):
         """Load {filename} and insert into source"""
         files = {"file": open(filename, "rb")}
@@ -579,7 +596,7 @@ class RESTClient(AbstractClient):
         if blocking:
             # wait until job is completed
             while True:
-                job = self.get_job_status(job.id)
+                job = self.get_job(job.id)
                 if job.status == JobStatus.completed:
                     break
                 elif job.status == JobStatus.failed:
@@ -1176,6 +1193,18 @@ class LocalClient(AbstractClient):
         self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id)
         return job
 
+    def get_job(self, job_id: str) -> Job:
+        """Return the job with the given ID, as reported by the server."""
+        return self.server.get_job(job_id=job_id)
+
+    def list_jobs(self):
+        """Return the server's job list for this client's user."""
+        return self.server.list_jobs(user_id=self.user_id)
+
+    def list_active_jobs(self):
+        """Return the server's active-job list for this client's user."""
+        return self.server.list_active_jobs(user_id=self.user_id)
+
     def create_source(self, name: str) -> Source:
         request = SourceCreate(name=name)
         return self.server.create_source(request=request, user_id=self.user_id)
diff --git a/memgpt/server/rest_api/jobs/index.py b/memgpt/server/rest_api/jobs/index.py
new file mode 100644
index 00000000..0a2cb083
--- /dev/null
+++ b/memgpt/server/rest_api/jobs/index.py
@@ -0,0 +1,43 @@
+from functools import partial
+from typing import List
+
+from fastapi import APIRouter, Depends
+
+from memgpt.schemas.job import Job
+from memgpt.server.rest_api.auth_token import get_current_user
+from memgpt.server.rest_api.interface import QueuingInterface
+from memgpt.server.server import SyncServer
+
+router = APIRouter()
+
+
+def setup_jobs_index_router(server: SyncServer, interface: QueuingInterface, password: str):
+    get_current_user_with_server = partial(partial(get_current_user, server), password)
+
+    @router.get("/jobs", tags=["jobs"], response_model=List[Job])
+    async def list_jobs(
+        user_id: str = Depends(get_current_user_with_server),
+    ):
+        interface.clear()
+
+        # TODO: add filtering by status
+        return server.list_jobs(user_id=user_id)
+
+    @router.get("/jobs/active", tags=["jobs"], response_model=List[Job])
+    async def list_active_jobs(
+        user_id: str = Depends(get_current_user_with_server),
+    ):
+        interface.clear()
+        return server.list_active_jobs(user_id=user_id)
+
+    @router.get("/jobs/{job_id}", tags=["jobs"], response_model=Job)
+    async def get_job(
+        job_id: str,
+        user_id: str = Depends(get_current_user_with_server),
+    ):
+        # NOTE(review): user_id is resolved but not passed to the lookup below;
+        # confirm whether a job owned by another user should be rejected here.
+        interface.clear()
+        return server.get_job(job_id=job_id)
+
+    return router
diff --git a/memgpt/server/rest_api/server.py b/memgpt/server/rest_api/server.py
index 2fc4bf5a..7e224d77 100644
--- a/memgpt/server/rest_api/server.py
+++ b/memgpt/server/rest_api/server.py
@@ -22,6 +22,7 @@ from memgpt.server.rest_api.auth.index import setup_auth_router
 from memgpt.server.rest_api.block.index import setup_block_index_router
 from memgpt.server.rest_api.config.index import setup_config_index_router
 from memgpt.server.rest_api.interface import StreamingServerInterface
+from memgpt.server.rest_api.jobs.index import setup_jobs_index_router
 from memgpt.server.rest_api.models.index import setup_models_index_router
 from memgpt.server.rest_api.openai_assistants.assistants import (
     setup_openai_assistant_router,
@@ -95,6 +96,7 @@ app.include_router(setup_agents_index_router(server, interface, password), prefi
 app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
 app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
 app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX)
+app.include_router(setup_jobs_index_router(server, interface, password), prefix=API_PREFIX)
 app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
 app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
 app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
diff --git a/memgpt/server/server.py b/memgpt/server/server.py
index 067c2974..4610128f 100644
--- a/memgpt/server/server.py
+++ b/memgpt/server/server.py
@@ -1401,6 +1401,11 @@ class SyncServer(Server):
         """List all jobs for a user"""
         return self.ms.list_jobs(user_id=user_id)
 
+    def list_active_jobs(self, user_id: str) -> List[Job]:
+        """List all active jobs for a user"""
+        jobs = self.ms.list_jobs(user_id=user_id)
+        return [job for job in jobs if job.status in [JobStatus.created, JobStatus.running]]
+
     def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Job:
 
         # update job
diff --git a/tests/data/memgpt_paper.pdf b/tests/data/memgpt_paper.pdf
new file mode 100644
index 00000000..d2c8bd78
Binary files /dev/null and b/tests/data/memgpt_paper.pdf differ
diff --git a/tests/test_client.py b/tests/test_client.py
index d5178209..c3f5ee7d 100644
--- a/tests/test_client.py
+++ b/tests/test_client.py
@@ -8,6 +8,7 @@ from dotenv import load_dotenv
 
 from memgpt import Admin, create_client
 from memgpt.constants import DEFAULT_PRESET
+from memgpt.schemas.enums import JobStatus
 from memgpt.schemas.message import Message
 from memgpt.schemas.usage import MemGPTUsageStatistics
 
@@ -238,8 +239,9 @@ def test_config(client, agent):
 def test_sources(client, agent):
     # _reset_config()
 
-    if not hasattr(client, "base_url"):
-        pytest.skip("Skipping test_sources because base_url is None")
+    # clear sources
+    for source in client.list_sources():
+        client.delete_source(source.id)
 
     # list sources
     sources = client.list_sources()
@@ -277,11 +279,37 @@ def test_sources(client, agent):
     print(archival_memories)
     assert len(archival_memories) == 0
 
-    # load a file into a source
-    filename = "CONTRIBUTING.md"
-    upload_job = client.load_file_into_source(filename=filename, source_id=source.id)
+    # load a file into a source (non-blocking job)
+    filename = "tests/data/memgpt_paper.pdf"
+    upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
     print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
 
+    # view active jobs
+    active_jobs = client.list_active_jobs()
+    jobs = client.list_jobs()
+    print(jobs)
+    assert upload_job.id in [j.id for j in jobs]
+    assert len(active_jobs) == 1
+
+    # wait for job to finish (with timeout)
+    timeout = 60
+    start_time = time.time()
+    while True:
+        status = client.get_job(upload_job.id).status
+        print(status)
+        if status == JobStatus.completed:
+            break
+        # fail fast: a failed upload would otherwise be polled until the timeout
+        if status == JobStatus.failed:
+            raise ValueError(f"Job failed: {upload_job.id}")
+        time.sleep(1)
+        if time.time() - start_time > timeout:
+            raise ValueError("Job did not finish in time")
+    job = client.get_job(upload_job.id)
+    created_passages = job.metadata_["num_passages"]
+
+    # TODO: add test for blocking job
+
     # TODO: make sure things run in the right order
     archival_memories = client.get_archival_memory(agent_id=agent.id)
     assert len(archival_memories) == 0
@@ -297,7 +325,7 @@ def test_sources(client, agent):
     # list archival memory
     archival_memories = client.get_archival_memory(agent_id=agent.id)
     # print(archival_memories)
-    assert len(archival_memories) == 20 or len(archival_memories) == 21
+    assert len(archival_memories) == created_passages
 
     # check number of passages
     sources = client.list_sources()