feat: add jobs routes to REST API (#1699)

This commit is contained in:
Sarah Wooders
2024-08-30 17:07:30 -07:00
committed by GitHub
parent 512315800c
commit 07d74937cb
6 changed files with 101 additions and 9 deletions

View File

@@ -562,10 +562,20 @@ class RESTClient(AbstractClient):
response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
assert response.status_code == 200, f"Failed to delete source: {response.text}"
def get_job_status(self, job_id: str):
response = requests.get(f"{self.base_url}/api/sources/status/{str(job_id)}", headers=self.headers)
def get_job(self, job_id: str) -> Job:
    """Fetch a single job from the REST server.

    Args:
        job_id: Identifier of the job to look up.

    Returns:
        A ``Job`` built from the JSON body of the response.

    Raises:
        ValueError: If the server answers with a status code other than 200.
    """
    url = f"{self.base_url}/api/jobs/{job_id}"
    response = requests.get(url, headers=self.headers)
    if response.status_code == 200:
        return Job(**response.json())
    # Surface the server's own message so the caller can see why the lookup failed.
    raise ValueError(f"Failed to get job: {response.text}")
def list_jobs(self):
    """List the jobs returned by the server's ``/api/jobs`` endpoint.

    Returns:
        A list of ``Job`` objects, one per entry in the JSON response.

    Raises:
        ValueError: If the server answers with a status code other than 200.
    """
    response = requests.get(f"{self.base_url}/api/jobs", headers=self.headers)
    # Fail with the server's message rather than trying to parse an error
    # payload as a list of jobs (mirrors the check done when getting one job).
    if response.status_code != 200:
        raise ValueError(f"Failed to list jobs: {response.text}")
    return [Job(**job) for job in response.json()]
def list_active_jobs(self):
    """List the jobs returned by the server's ``/api/jobs/active`` endpoint.

    Returns:
        A list of ``Job`` objects, one per entry in the JSON response.

    Raises:
        ValueError: If the server answers with a status code other than 200.
    """
    response = requests.get(f"{self.base_url}/api/jobs/active", headers=self.headers)
    # Fail with the server's message rather than trying to parse an error
    # payload as a list of jobs (mirrors the check done when getting one job).
    if response.status_code != 200:
        raise ValueError(f"Failed to list active jobs: {response.text}")
    return [Job(**job) for job in response.json()]
def load_file_into_source(self, filename: str, source_id: str, blocking=True):
"""Load {filename} and insert into source"""
files = {"file": open(filename, "rb")}
@@ -579,7 +589,7 @@ class RESTClient(AbstractClient):
if blocking:
# wait until job is completed
while True:
job = self.get_job_status(job.id)
job = self.get_job(job.id)
if job.status == JobStatus.completed:
break
elif job.status == JobStatus.failed:
@@ -1176,6 +1186,15 @@ class LocalClient(AbstractClient):
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id)
return job
def get_job(self, job_id: str):
    """Look up a job by id by delegating to the in-process server."""
    job = self.server.get_job(job_id=job_id)
    return job
def list_jobs(self):
    """Return the server's job list for this client's ``user_id``."""
    fetch_jobs = self.server.list_jobs
    return fetch_jobs(user_id=self.user_id)
def list_active_jobs(self):
    """Return the server's active-job list for this client's ``user_id``."""
    fetch_active = self.server.list_active_jobs
    return fetch_active(user_id=self.user_id)
def create_source(self, name: str) -> Source:
    """Create a source called ``name`` on the server for this client's ``user_id``.

    Args:
        name: Name given to the new source.

    Returns:
        Whatever the server's ``create_source`` call returns.
    """
    source_request = SourceCreate(name=name)
    created = self.server.create_source(request=source_request, user_id=self.user_id)
    return created

View File

@@ -0,0 +1,41 @@
from functools import partial
from typing import List

from fastapi import APIRouter, Depends, HTTPException

from memgpt.schemas.job import Job
from memgpt.server.rest_api.auth_token import get_current_user
from memgpt.server.rest_api.interface import QueuingInterface
from memgpt.server.server import SyncServer
router = APIRouter()


def setup_jobs_index_router(server: SyncServer, interface: QueuingInterface, password: str):
    """Register the job routes on the module-level ``router`` and return it.

    Routes added (all tagged ``jobs``):
        GET /jobs           -- jobs for the authenticated user
        GET /jobs/active    -- active jobs for the authenticated user
        GET /jobs/{job_id}  -- a single job by id

    Args:
        server: Server instance the handlers delegate to.
        interface: Interface that is cleared at the start of every request.
        password: Bound into the ``get_current_user`` dependency.

    Returns:
        The module-level ``APIRouter`` with the job routes attached.
    """
    # Bind server and password once; FastAPI supplies the remaining arguments per request.
    get_current_user_with_server = partial(get_current_user, server, password)

    @router.get("/jobs", tags=["jobs"], response_model=List[Job])
    async def list_jobs(
        user_id: str = Depends(get_current_user_with_server),
    ):
        """List the jobs belonging to the authenticated user."""
        interface.clear()

        # TODO: add filtering by status
        return server.list_jobs(user_id=user_id)

    # Declared before "/jobs/{job_id}" on purpose: FastAPI matches path
    # operations in declaration order, so "active" must not be captured as a job_id.
    @router.get("/jobs/active", tags=["jobs"], response_model=List[Job])
    async def list_active_jobs(
        user_id: str = Depends(get_current_user_with_server),
    ):
        """List the active jobs belonging to the authenticated user."""
        interface.clear()

        return server.list_active_jobs(user_id=user_id)

    @router.get("/jobs/{job_id}", tags=["jobs"], response_model=Job)
    async def get_job(
        job_id: str,
        user_id: str = Depends(get_current_user_with_server),
    ):
        """Return one job by id, or respond with 404 when there is no such job."""
        interface.clear()

        # NOTE(review): user_id is only used to authenticate the request here; the
        # job is not checked against the caller -- confirm whether ownership
        # should be enforced.
        job = server.get_job(job_id=job_id)
        # NOTE(review): assumes server.get_job yields None for an unknown id -- confirm.
        # Without this guard a None result would fail response_model validation
        # instead of producing a clear "not found" response.
        if job is None:
            raise HTTPException(status_code=404, detail=f"Job {job_id} not found")
        return job

    return router

View File

@@ -22,6 +22,7 @@ from memgpt.server.rest_api.auth.index import setup_auth_router
from memgpt.server.rest_api.block.index import setup_block_index_router
from memgpt.server.rest_api.config.index import setup_config_index_router
from memgpt.server.rest_api.interface import StreamingServerInterface
from memgpt.server.rest_api.jobs.index import setup_jobs_index_router
from memgpt.server.rest_api.models.index import setup_models_index_router
from memgpt.server.rest_api.openai_assistants.assistants import (
setup_openai_assistant_router,
@@ -95,6 +96,7 @@ app.include_router(setup_agents_index_router(server, interface, password), prefi
# Each setup_* call builds a router from the shared (server, interface, password)
# triple; every router is mounted on the app under API_PREFIX.
app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_jobs_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)

View File

@@ -1401,6 +1401,11 @@ class SyncServer(Server):
"""List all jobs for a user"""
return self.ms.list_jobs(user_id=user_id)
def list_active_jobs(self, user_id: str) -> List[Job]:
    """Return the user's jobs whose status is ``created`` or ``running``.

    Args:
        user_id: User whose jobs are read from ``self.ms``.

    Returns:
        The matching jobs, in the order ``self.ms.list_jobs`` returned them.
    """
    all_jobs = self.ms.list_jobs(user_id=user_id)
    active = []
    for candidate in all_jobs:
        # A job counts as active until it leaves the created/running states.
        if candidate.status in [JobStatus.created, JobStatus.running]:
            active.append(candidate)
    return active
def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Job:
# update job

BIN
tests/data/memgpt_paper.pdf Normal file

Binary file not shown.

View File

@@ -8,6 +8,7 @@ from dotenv import load_dotenv
from memgpt import Admin, create_client
from memgpt.constants import DEFAULT_PRESET
from memgpt.schemas.enums import JobStatus
from memgpt.schemas.message import Message
from memgpt.schemas.usage import MemGPTUsageStatistics
@@ -238,8 +239,9 @@ def test_config(client, agent):
def test_sources(client, agent):
# _reset_config()
if not hasattr(client, "base_url"):
pytest.skip("Skipping test_sources because base_url is None")
# clear sources
for source in client.list_sources():
client.delete_source(source.id)
# list sources
sources = client.list_sources()
@@ -277,11 +279,34 @@ def test_sources(client, agent):
print(archival_memories)
assert len(archival_memories) == 0
# load a file into a source
filename = "CONTRIBUTING.md"
upload_job = client.load_file_into_source(filename=filename, source_id=source.id)
# load a file into a source (non-blocking job)
filename = "tests/data/memgpt_paper.pdf"
upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
# view active jobs
active_jobs = client.list_active_jobs()
jobs = client.list_jobs()
print(jobs)
assert upload_job.id in [j.id for j in jobs]
assert len(active_jobs) == 1
# wait for job to finish (with timeout)
timeout = 60
start_time = time.time()
while True:
status = client.get_job(upload_job.id).status
print(status)
if status == JobStatus.completed:
break
time.sleep(1)
if time.time() - start_time > timeout:
raise ValueError("Job did not finish in time")
job = client.get_job(upload_job.id)
created_passages = job.metadata_["num_passages"]
# TODO: add test for blocking job
# TODO: make sure things run in the right order
archival_memories = client.get_archival_memory(agent_id=agent.id)
assert len(archival_memories) == 0
@@ -297,7 +322,7 @@ def test_sources(client, agent):
# list archival memory
archival_memories = client.get_archival_memory(agent_id=agent.id)
# print(archival_memories)
assert len(archival_memories) == 20 or len(archival_memories) == 21
assert len(archival_memories) == created_passages
# check number of passages
sources = client.list_sources()