feat: add jobs routes to REST API (#1699)
This commit is contained in:
@@ -562,10 +562,20 @@ class RESTClient(AbstractClient):
|
||||
response = requests.delete(f"{self.base_url}/api/sources/{str(source_id)}", headers=self.headers)
|
||||
assert response.status_code == 200, f"Failed to delete source: {response.text}"
|
||||
|
||||
def get_job_status(self, job_id: str):
|
||||
response = requests.get(f"{self.base_url}/api/sources/status/{str(job_id)}", headers=self.headers)
|
||||
def get_job(self, job_id: str) -> Job:
|
||||
response = requests.get(f"{self.base_url}/api/jobs/{job_id}", headers=self.headers)
|
||||
if response.status_code != 200:
|
||||
raise ValueError(f"Failed to get job: {response.text}")
|
||||
return Job(**response.json())
|
||||
|
||||
def list_jobs(self):
|
||||
response = requests.get(f"{self.base_url}/api/jobs", headers=self.headers)
|
||||
return [Job(**job) for job in response.json()]
|
||||
|
||||
def list_active_jobs(self):
|
||||
response = requests.get(f"{self.base_url}/api/jobs/active", headers=self.headers)
|
||||
return [Job(**job) for job in response.json()]
|
||||
|
||||
def load_file_into_source(self, filename: str, source_id: str, blocking=True):
|
||||
"""Load {filename} and insert into source"""
|
||||
files = {"file": open(filename, "rb")}
|
||||
@@ -579,7 +589,7 @@ class RESTClient(AbstractClient):
|
||||
if blocking:
|
||||
# wait until job is completed
|
||||
while True:
|
||||
job = self.get_job_status(job.id)
|
||||
job = self.get_job(job.id)
|
||||
if job.status == JobStatus.completed:
|
||||
break
|
||||
elif job.status == JobStatus.failed:
|
||||
@@ -1176,6 +1186,15 @@ class LocalClient(AbstractClient):
|
||||
self.server.load_file_to_source(source_id=source_id, file_path=filename, job_id=job.id)
|
||||
return job
|
||||
|
||||
def get_job(self, job_id: str):
|
||||
return self.server.get_job(job_id=job_id)
|
||||
|
||||
def list_jobs(self):
|
||||
return self.server.list_jobs(user_id=self.user_id)
|
||||
|
||||
def list_active_jobs(self):
|
||||
return self.server.list_active_jobs(user_id=self.user_id)
|
||||
|
||||
def create_source(self, name: str) -> Source:
|
||||
request = SourceCreate(name=name)
|
||||
return self.server.create_source(request=request, user_id=self.user_id)
|
||||
|
||||
41
memgpt/server/rest_api/jobs/index.py
Normal file
41
memgpt/server/rest_api/jobs/index.py
Normal file
@@ -0,0 +1,41 @@
|
||||
from functools import partial
|
||||
from typing import List
|
||||
|
||||
from fastapi import APIRouter, Depends
|
||||
|
||||
from memgpt.schemas.job import Job
|
||||
from memgpt.server.rest_api.auth_token import get_current_user
|
||||
from memgpt.server.rest_api.interface import QueuingInterface
|
||||
from memgpt.server.server import SyncServer
|
||||
|
||||
router = APIRouter()
|
||||
|
||||
|
||||
def setup_jobs_index_router(server: SyncServer, interface: QueuingInterface, password: str):
|
||||
get_current_user_with_server = partial(partial(get_current_user, server), password)
|
||||
|
||||
@router.get("/jobs", tags=["jobs"], response_model=List[Job])
|
||||
async def list_jobs(
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
|
||||
# TODO: add filtering by status
|
||||
return server.list_jobs(user_id=user_id)
|
||||
|
||||
@router.get("/jobs/active", tags=["jobs"], response_model=List[Job])
|
||||
async def list_active_jobs(
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
return server.list_active_jobs(user_id=user_id)
|
||||
|
||||
@router.get("/jobs/{job_id}", tags=["jobs"], response_model=Job)
|
||||
async def get_job(
|
||||
job_id: str,
|
||||
user_id: str = Depends(get_current_user_with_server),
|
||||
):
|
||||
interface.clear()
|
||||
return server.get_job(job_id=job_id)
|
||||
|
||||
return router
|
||||
@@ -22,6 +22,7 @@ from memgpt.server.rest_api.auth.index import setup_auth_router
|
||||
from memgpt.server.rest_api.block.index import setup_block_index_router
|
||||
from memgpt.server.rest_api.config.index import setup_config_index_router
|
||||
from memgpt.server.rest_api.interface import StreamingServerInterface
|
||||
from memgpt.server.rest_api.jobs.index import setup_jobs_index_router
|
||||
from memgpt.server.rest_api.models.index import setup_models_index_router
|
||||
from memgpt.server.rest_api.openai_assistants.assistants import (
|
||||
setup_openai_assistant_router,
|
||||
@@ -95,6 +96,7 @@ app.include_router(setup_agents_index_router(server, interface, password), prefi
|
||||
app.include_router(setup_agents_memory_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_agents_message_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_block_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_jobs_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_models_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_user_tools_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
app.include_router(setup_sources_index_router(server, interface, password), prefix=API_PREFIX)
|
||||
|
||||
@@ -1401,6 +1401,11 @@ class SyncServer(Server):
|
||||
"""List all jobs for a user"""
|
||||
return self.ms.list_jobs(user_id=user_id)
|
||||
|
||||
def list_active_jobs(self, user_id: str) -> List[Job]:
|
||||
"""List all active jobs for a user"""
|
||||
jobs = self.ms.list_jobs(user_id=user_id)
|
||||
return [job for job in jobs if job.status in [JobStatus.created, JobStatus.running]]
|
||||
|
||||
def load_file_to_source(self, source_id: str, file_path: str, job_id: str) -> Job:
|
||||
|
||||
# update job
|
||||
|
||||
BIN
tests/data/memgpt_paper.pdf
Normal file
BIN
tests/data/memgpt_paper.pdf
Normal file
Binary file not shown.
@@ -8,6 +8,7 @@ from dotenv import load_dotenv
|
||||
|
||||
from memgpt import Admin, create_client
|
||||
from memgpt.constants import DEFAULT_PRESET
|
||||
from memgpt.schemas.enums import JobStatus
|
||||
from memgpt.schemas.message import Message
|
||||
from memgpt.schemas.usage import MemGPTUsageStatistics
|
||||
|
||||
@@ -238,8 +239,9 @@ def test_config(client, agent):
|
||||
def test_sources(client, agent):
|
||||
# _reset_config()
|
||||
|
||||
if not hasattr(client, "base_url"):
|
||||
pytest.skip("Skipping test_sources because base_url is None")
|
||||
# clear sources
|
||||
for source in client.list_sources():
|
||||
client.delete_source(source.id)
|
||||
|
||||
# list sources
|
||||
sources = client.list_sources()
|
||||
@@ -277,11 +279,34 @@ def test_sources(client, agent):
|
||||
print(archival_memories)
|
||||
assert len(archival_memories) == 0
|
||||
|
||||
# load a file into a source
|
||||
filename = "CONTRIBUTING.md"
|
||||
upload_job = client.load_file_into_source(filename=filename, source_id=source.id)
|
||||
# load a file into a source (non-blocking job)
|
||||
filename = "tests/data/memgpt_paper.pdf"
|
||||
upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
|
||||
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
|
||||
|
||||
# view active jobs
|
||||
active_jobs = client.list_active_jobs()
|
||||
jobs = client.list_jobs()
|
||||
print(jobs)
|
||||
assert upload_job.id in [j.id for j in jobs]
|
||||
assert len(active_jobs) == 1
|
||||
|
||||
# wait for job to finish (with timeout)
|
||||
timeout = 60
|
||||
start_time = time.time()
|
||||
while True:
|
||||
status = client.get_job(upload_job.id).status
|
||||
print(status)
|
||||
if status == JobStatus.completed:
|
||||
break
|
||||
time.sleep(1)
|
||||
if time.time() - start_time > timeout:
|
||||
raise ValueError("Job did not finish in time")
|
||||
job = client.get_job(upload_job.id)
|
||||
created_passages = job.metadata_["num_passages"]
|
||||
|
||||
# TODO: add test for blocking job
|
||||
|
||||
# TODO: make sure things run in the right order
|
||||
archival_memories = client.get_archival_memory(agent_id=agent.id)
|
||||
assert len(archival_memories) == 0
|
||||
@@ -297,7 +322,7 @@ def test_sources(client, agent):
|
||||
# list archival memory
|
||||
archival_memories = client.get_archival_memory(agent_id=agent.id)
|
||||
# print(archival_memories)
|
||||
assert len(archival_memories) == 20 or len(archival_memories) == 21
|
||||
assert len(archival_memories) == created_passages
|
||||
|
||||
# check number of passages
|
||||
sources = client.list_sources()
|
||||
|
||||
Reference in New Issue
Block a user