feat: Enable adding files (#1864)

Co-authored-by: Matt Zhou <mattzhou@Matts-MacBook-Pro.local>
This commit is contained in:
Matthew Zhou
2024-10-14 10:22:45 -07:00
committed by GitHub
parent 9a44cc3df7
commit 93aacc087e
26 changed files with 565 additions and 223 deletions

View File

@@ -12,12 +12,13 @@ from letta.client.client import LocalClient, RESTClient
from letta.constants import DEFAULT_PRESET
from letta.schemas.agent import AgentState
from letta.schemas.embedding_config import EmbeddingConfig
from letta.schemas.enums import JobStatus, MessageStreamStatus
from letta.schemas.enums import MessageStreamStatus
from letta.schemas.letta_message import FunctionCallMessage, InternalMonologue
from letta.schemas.letta_response import LettaResponse, LettaStreamingResponse
from letta.schemas.llm_config import LLMConfig
from letta.schemas.message import Message
from letta.schemas.usage import LettaUsageStatistics
from tests.helpers.client_helper import upload_file_using_client
# from tests.utils import create_config
@@ -298,6 +299,70 @@ def test_config(client: Union[LocalClient, RESTClient], agent: AgentState):
# print("CONFIG", config_response)
def test_list_files_pagination(client: Union[LocalClient, RESTClient], agent: AgentState):
# clear sources
for source in client.list_sources():
client.delete_source(source.id)
# clear jobs
for job in client.list_jobs():
client.delete_job(job.id)
# create a source
source = client.create_source(name="test_source")
# load files into sources
file_a = "tests/data/memgpt_paper.pdf"
file_b = "tests/data/test.txt"
upload_file_using_client(client, source, file_a)
upload_file_using_client(client, source, file_b)
# Get the first file
files_a = client.list_files_from_source(source.id, limit=1)
assert len(files_a) == 1
assert files_a[0].source_id == source.id
# Use the cursor from response_a to get the remaining file
files_b = client.list_files_from_source(source.id, limit=1, cursor=files_a[-1].id)
assert len(files_b) == 1
assert files_b[0].source_id == source.id
# Check files are different to ensure the cursor works
assert files_a[0].file_name != files_b[0].file_name
# Use the cursor from response_b to list files, should be empty
files = client.list_files_from_source(source.id, limit=1, cursor=files_b[-1].id)
assert len(files) == 0 # Should be empty
def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
# clear sources
for source in client.list_sources():
client.delete_source(source.id)
# clear jobs
for job in client.list_jobs():
client.delete_job(job.id)
# create a source
source = client.create_source(name="test_source")
# load a file into a source (non-blocking job)
filename = "tests/data/memgpt_paper.pdf"
upload_file_using_client(client, source, filename)
# Get the files
files = client.list_files_from_source(source.id)
assert len(files) == 1 # Should be condensed to one document
# Get the memgpt paper
file = files[0]
assert file.file_name == "memgpt_paper.pdf"
assert file.source_id == source.id
def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# _reset_config()
@@ -305,6 +370,10 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
for source in client.list_sources():
client.delete_source(source.id)
# clear jobs
for job in client.list_jobs():
client.delete_job(job.id)
# list sources
sources = client.list_sources()
print("listed sources", sources)
@@ -343,28 +412,7 @@ def test_sources(client: Union[LocalClient, RESTClient], agent: AgentState):
# load a file into a source (non-blocking job)
filename = "tests/data/memgpt_paper.pdf"
upload_job = client.load_file_into_source(filename=filename, source_id=source.id, blocking=False)
print("Upload job", upload_job, upload_job.status, upload_job.metadata_)
# view active jobs
active_jobs = client.list_active_jobs()
jobs = client.list_jobs()
print(jobs)
assert upload_job.id in [j.id for j in jobs]
assert len(active_jobs) == 1
assert active_jobs[0].metadata_["source_id"] == source.id
# wait for job to finish (with timeout)
timeout = 120
start_time = time.time()
while True:
status = client.get_job(upload_job.id).status
print(status)
if status == JobStatus.completed:
break
time.sleep(1)
if time.time() - start_time > timeout:
raise ValueError("Job did not finish in time")
upload_job = upload_file_using_client(client, source, filename)
job = client.get_job(upload_job.id)
created_passages = job.metadata_["num_passages"]