feat: Enable adding files (#1864)
Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
@@ -12,12 +12,13 @@ from letta.client.client import LocalClient, RESTClient
|
||||
from letta.constants import DEFAULT_PRESET
|
||||
from letta.schemas.agent import AgentState
|
||||
from letta.schemas.embedding_config import EmbeddingConfig
|
||||
from letta.schemas.enums import JobStatus, MessageStreamStatus
|
||||
from letta.schemas.enums import MessageStreamStatus
|
||||
from letta.schemas.letta_message import FunctionCallMessage, InternalMonologue
|
||||
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
|
||||
from letta.schemas.llm_config import LLMConfig
|
||||
from letta.schemas.message import Message
|
||||
from letta.schemas.usage import LettaUsageStatistics
|
||||
from tests.helpers.client_helper import upload_file_using_client
|
||||
|
||||
# from tests.utils import create_config
|
||||
|
||||
@@ -298,6 +299,70 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# print("CONFIG", config_response)
|
||||
|
||||
|
||||
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# clear sources
|
||||
for source in client.list_sources():
|
||||
client.delete_source(source.id)
|
||||
|
||||
# clear jobs
|
||||
for job in client.list_jobs():
|
||||
client.delete_job(job.id)
|
||||
|
||||
# create a source
|
||||
source = client.create_source(name="test_source")
|
||||
|
||||
# load files into sources
|
||||
file_a = "tests/data/memgpt_paper.pdf"
|
||||
file_b = "tests/data/test.txt"
|
||||
upload_file_using_client(client, source, file_a)
|
||||
upload_file_using_client(client, source, file_b)
|
||||
|
||||
# Get the first file
|
||||
files_a = client.list_files_from_source(source.id, limit=1)
|
||||
assert len(files_a) == 1
|
||||
assert files_a[0].source_id == source.id
|
||||
|
||||
# Use the cursor from response_a to get the remaining file
|
||||
files_b = client.list_files_from_source(source.id, limit=1, cursor=files_a[-1].id)
|
||||
assert len(files_b) == 1
|
||||
assert files_b[0].source_id == source.id
|
||||
|
||||
# Check files are different to ensure the cursor works
|
||||
assert files_a[0].file_name != files_b[0].file_name
|
||||
|
||||
# Use the cursor from response_b to list files, should be empty
|
||||
files = client.list_files_from_source(source.id, limit=1, cursor=files_b[-1].id)
|
||||
assert len(files) == 0 # Should be empty
|
||||
|
||||
|
||||
def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
# clear sources
|
||||
for source in client.list_sources():
|
||||
client.delete_source(source.id)
|
||||
|
||||
# clear jobs
|
||||
for job in client.list_jobs():
|
||||
client.delete_job(job.id)
|
||||
|
||||
# create a source
|
||||
source = client.create_source(name="test_source")
|
||||
|
||||
# load a file into a source (non-blocking job)
|
||||
filename = "tests/data/memgpt_paper.pdf"
|
||||
upload_file_using_client(client, source, filename)
|
||||
|
||||
# Get the files
|
||||
files = client.list_files_from_source(source.id)
|
||||
assert len(files) == 1 # Should be condensed to one document
|
||||
|
||||
# Get the memgpt paper
|
||||
file = files[0]
|
||||
assert file.file_name == "memgpt_paper.pdf"
|
||||
assert file.source_id == source.id
|
||||
|
||||
|
||||
def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
# _reset_config()
|
||||
|
||||
@@ -305,6 +370,10 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
for source in client.list_sources():
|
||||
client.delete_source(source.id)
|
||||
|
||||
# clear jobs
|
||||
for job in client.list_jobs():
|
||||
client.delete_job(job.id)
|
||||
|
||||
# list sources
|
||||
sources = client.list_sources()
|
||||
print("listed sources", sources)
|
||||
@@ -343,28 +412,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
# load a file into a source (non-blocking job)
|
||||
filename = "tests/data/memgpt_paper.pdf"
|
||||
upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
|
||||
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
|
||||
|
||||
# view active jobs
|
||||
active_jobs = client.list_active_jobs()
|
||||
jobs = client.list_jobs()
|
||||
print(jobs)
|
||||
assert upload_job.id in [j.id for j in jobs]
|
||||
assert len(active_jobs) == 1
|
||||
assert active_jobs[0].metadata_["source_id"] == source.id
|
||||
|
||||
# wait for job to finish (with timeout)
|
||||
timeout = 120
|
||||
start_time = time.time()
|
||||
while True:
|
||||
status = client.get_job(upload_job.id).status
|
||||
print(status)
|
||||
if status == JobStatus.completed:
|
||||
break
|
||||
time.sleep(1)
|
||||
if time.time() - start_time > timeout:
|
||||
raise ValueError("Job did not finish in time")
|
||||
upload_job = upload_file_using_client(client, source, filename)
|
||||
job = client.get_job(upload_job.id)
|
||||
created_passages = job.metadata_["num_passages"]
|
||||
|
||||
|
||||
Reference in New Issue
Block a user