fix: Fix security vuln with file upload (#2067)
This commit is contained in:
3
.github/workflows/tests.yml
vendored
3
.github/workflows/tests.yml
vendored
@@ -29,6 +29,7 @@ jobs:
|
||||
- "test_o1_agent.py"
|
||||
- "test_tool_rule_solver.py"
|
||||
- "test_agent_tool_graph.py"
|
||||
- "test_utils.py"
|
||||
services:
|
||||
qdrant:
|
||||
image: qdrant/qdrant
|
||||
@@ -131,4 +132,4 @@ jobs:
|
||||
LETTA_SERVER_PASS: test_server_token
|
||||
PYTHONPATH: ${{ github.workspace }}:${{ env.PYTHONPATH }}
|
||||
run: |
|
||||
poetry run pytest -s -vv -k "not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
|
||||
poetry run pytest -s -vv -k "not test_utils.py and not integration_test_summarizer.py and not test_agent_tool_graph.py and not test_tool_rule_solver.py and not test_local_client.py and not test_o1_agent.py and not test_cli.py and not test_tools.py and not test_concurrent_connections.py and not test_quickstart and not test_model_letta_perfomance and not test_storage and not test_server and not test_openai_client and not test_providers and not test_client.py" tests
|
||||
|
||||
@@ -380,9 +380,10 @@ class PostgresStorageConnector(SQLStorageConnector):
|
||||
else:
|
||||
raise ValueError(f"Table type {table_type} not implemented")
|
||||
|
||||
for c in self.db_model.__table__.columns:
|
||||
if c.name == "embedding":
|
||||
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
|
||||
if settings.pg_uri:
|
||||
for c in self.db_model.__table__.columns:
|
||||
if c.name == "embedding":
|
||||
assert isinstance(c.type, Vector), f"Embedding column must be of type Vector, got {c.type}"
|
||||
|
||||
from letta.server.server import db_context
|
||||
|
||||
|
||||
@@ -158,3 +158,6 @@ RETRIEVAL_QUERY_DEFAULT_PAGE_SIZE = 5
|
||||
# TODO Is this config or constant?
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT: int = 2000
|
||||
CORE_MEMORY_HUMAN_CHAR_LIMIT: int = 2000
|
||||
|
||||
MAX_FILENAME_LENGTH = 255
|
||||
RESERVED_FILENAMES = {"CON", "PRN", "AUX", "NUL", "COM1", "COM2", "LPT1", "LPT2"}
|
||||
|
||||
@@ -77,6 +77,6 @@ class LettaBase(BaseModel):
|
||||
"""
|
||||
_ = values # for SCA
|
||||
if isinstance(v, UUID):
|
||||
logger.warning("Bare UUIDs are deprecated, please use the full prefixed id!")
|
||||
logger.warning(f"Bare UUIDs are deprecated, please use the full prefixed id ({cls.__id_prefix__})!")
|
||||
return f"{cls.__id_prefix__}-{v}"
|
||||
return v
|
||||
|
||||
@@ -18,6 +18,7 @@ from letta.schemas.passage import Passage
|
||||
from letta.schemas.source import Source, SourceCreate, SourceUpdate
|
||||
from letta.server.rest_api.utils import get_letta_server
|
||||
from letta.server.server import SyncServer
|
||||
from letta.utils import sanitize_filename
|
||||
|
||||
# These can be forward refs, but because FastAPI needs them at runtime they must be imported normally
|
||||
|
||||
@@ -170,7 +171,7 @@ def upload_file_to_source(
|
||||
server.ms.create_job(job)
|
||||
|
||||
# create background task
|
||||
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, job_id=job.id, file=file, bytes=bytes)
|
||||
background_tasks.add_task(load_file_to_source_async, server, source_id=source.id, file=file, job_id=job.id, bytes=bytes)
|
||||
|
||||
# return job information
|
||||
job = server.ms.get_job(job_id=job_id)
|
||||
@@ -227,10 +228,15 @@ def delete_file_from_source(
|
||||
|
||||
|
||||
def load_file_to_source_async(server: SyncServer, source_id: str, job_id: str, file: UploadFile, bytes: bytes):
    """
    Write an uploaded file to a temporary location and load it into a source.

    The client-supplied filename is never used directly as a path component: it is
    first passed through `sanitize_filename`, so the file is always written inside
    the temporary directory.

    Parameters:
        server (SyncServer): Server whose `load_file_to_source` is invoked.
        source_id (str): ID of the source the file is loaded into.
        job_id (str): ID of the job associated with this upload.
        file (UploadFile): The uploaded file; only its `filename` is read here.
        bytes (bytes): The raw content of the uploaded file.

    Raises:
        ValueError: Propagated from `sanitize_filename` when the filename is rejected.
    """
    # The upload may carry no filename at all; fall back to an empty name so the
    # sanitizer still produces a unique, safe filename instead of failing on None.
    sanitized_filename = sanitize_filename(file.filename or "")

    # Create a temporary directory (deleted after the context manager exits)
    with tempfile.TemporaryDirectory() as tmpdirname:
        file_path = os.path.join(tmpdirname, sanitized_filename)

        # Write the file to the sanitized path
        with open(file_path, "wb") as buffer:
            buffer.write(bytes)

        # Pass the file to load_file_to_source while the temporary directory still exists
        server.load_file_to_source(source_id, file_path, job_id)
|
||||
|
||||
@@ -21,6 +21,7 @@ from urllib.parse import urljoin, urlparse
|
||||
import demjson3 as demjson
|
||||
import pytz
|
||||
import tiktoken
|
||||
from pathvalidate import sanitize_filename as pathvalidate_sanitize_filename
|
||||
|
||||
import letta
|
||||
from letta.constants import (
|
||||
@@ -29,6 +30,7 @@ from letta.constants import (
|
||||
CORE_MEMORY_PERSONA_CHAR_LIMIT,
|
||||
FUNCTION_RETURN_CHAR_LIMIT,
|
||||
LETTA_DIR,
|
||||
MAX_FILENAME_LENGTH,
|
||||
TOOL_CALL_ID_MAX_LEN,
|
||||
)
|
||||
from letta.schemas.openai.chat_completion_response import ChatCompletionResponse
|
||||
@@ -1071,3 +1073,40 @@ def json_dumps(data, indent=2):
|
||||
|
||||
def json_loads(data):
|
||||
return json.loads(data, strict=False)
|
||||
|
||||
|
||||
def sanitize_filename(filename: str) -> str:
    """
    Sanitize the given filename to prevent directory traversal and invalid characters
    while ensuring it fits within the maximum length allowed by the filesystem.

    Directory components are dropped, the base name and the extension are each cleaned
    with pathvalidate, and a random UUID hex string is inserted between them so that
    repeated uploads of the same name do not collide.

    Parameters:
        filename (str): The user-provided filename.

    Returns:
        str: A filename of the form `<base>_<32 hex chars><ext>` that is at most
        MAX_FILENAME_LENGTH characters long.

    Raises:
        ValueError: If the sanitized base name starts with '.'.
    """
    # Extract the base filename to avoid directory components
    filename = os.path.basename(filename)

    # Split the base and extension
    base, ext = os.path.splitext(filename)

    # External sanitization library; the extension is user-controlled too, so it
    # gets the same treatment as the base name.
    base = pathvalidate_sanitize_filename(base)
    if ext:
        ext = pathvalidate_sanitize_filename(ext)

    # Cannot start with a period
    if base.startswith("."):
        raise ValueError(f"Invalid filename - derived file name {base} cannot start with '.'")

    # NOTE(review): lengths below are counted in characters; confirm whether the
    # limit should be enforced on the encoded byte length instead.
    suffix_length = 33  # 32 for UUID + 1 for `_`

    # Bound the extension first so the space left for the base can never go negative
    max_ext_length = MAX_FILENAME_LENGTH - suffix_length
    if len(ext) > max_ext_length:
        ext = ext[:max_ext_length]

    # Truncate the base name to fit within the maximum allowed length
    max_base_length = MAX_FILENAME_LENGTH - len(ext) - suffix_length
    if len(base) > max_base_length:
        base = base[:max_base_length]

    # Append a unique UUID suffix for uniqueness
    unique_suffix = uuid.uuid4().hex

    return f"{base}_{unique_suffix}{ext}"
|
||||
|
||||
18
poetry.lock
generated
18
poetry.lock
generated
@@ -4905,6 +4905,22 @@ files = [
|
||||
{file = "pathspec-0.12.1.tar.gz", hash = "sha256:a482d51503a1ab33b1c67a6c3813a26953dbdc71c31dacaef9a838c4e29f5712"},
|
||||
]
|
||||
|
||||
[[package]]
|
||||
name = "pathvalidate"
|
||||
version = "3.2.1"
|
||||
description = "pathvalidate is a Python library to sanitize/validate a string such as filenames/file-paths/etc."
|
||||
optional = false
|
||||
python-versions = ">=3.7"
|
||||
files = [
|
||||
{file = "pathvalidate-3.2.1-py3-none-any.whl", hash = "sha256:9a6255eb8f63c9e2135b9be97a5ce08f10230128c4ae7b3e935378b82b22c4c9"},
|
||||
{file = "pathvalidate-3.2.1.tar.gz", hash = "sha256:f5d07b1e2374187040612a1fcd2bcb2919f8db180df254c9581bb90bf903377d"},
|
||||
]
|
||||
|
||||
[package.extras]
|
||||
docs = ["Sphinx (>=2.4)", "sphinx-rtd-theme (>=1.2.2)", "urllib3 (<2)"]
|
||||
readme = ["path (>=13,<17)", "readmemaker (>=1.1.0)"]
|
||||
test = ["Faker (>=1.0.8)", "allpairspy (>=2)", "click (>=6.2)", "pytest (>=6.0.1)", "pytest-md-report (>=0.6.2)"]
|
||||
|
||||
[[package]]
|
||||
name = "pexpect"
|
||||
version = "4.9.0"
|
||||
@@ -8494,4 +8510,4 @@ tests = ["wikipedia"]
|
||||
[metadata]
|
||||
lock-version = "2.0"
|
||||
python-versions = "<3.13,>=3.10"
|
||||
content-hash = "570c482aed9ff66761ac47b8b7e1ca06525d4e5084791723380101217e163500"
|
||||
content-hash = "5aef7fe9900da5d0fefbb0ce4f4f65b565f1967826f840138cfdd59444fd7330"
|
||||
|
||||
@@ -80,6 +80,7 @@ alembic = "^1.13.3"
|
||||
pyhumps = "^3.8.0"
|
||||
psycopg2 = "^2.9.10"
|
||||
psycopg2-binary = "^2.9.10"
|
||||
pathvalidate = "^3.2.1"
|
||||
|
||||
[tool.poetry.extras]
|
||||
#local = ["llama-index-embeddings-huggingface"]
|
||||
|
||||
@@ -20,7 +20,7 @@ def upload_file_using_client(client: Union[LocalClient, RESTClient], source: Sou
|
||||
assert active_jobs[0].metadata_["source_id"] == source.id
|
||||
|
||||
# wait for job to finish (with timeout)
|
||||
timeout = 120
|
||||
timeout = 240
|
||||
start_time = time.time()
|
||||
while True:
|
||||
status = client.get_job(upload_job.id).status
|
||||
|
||||
@@ -1,4 +1,5 @@
|
||||
import os
|
||||
import re
|
||||
import threading
|
||||
import time
|
||||
import uuid
|
||||
@@ -435,7 +436,10 @@ def test_load_file(client: Union[LocalClient, RESTClient], agent: AgentState):
|
||||
|
||||
# Get the memgpt paper
|
||||
file = files[0]
|
||||
assert file.file_name == "memgpt_paper.pdf"
|
||||
# Assert the filename matches the pattern
|
||||
pattern = re.compile(r"^memgpt_paper_[a-f0-9]{32}\.pdf$")
|
||||
assert pattern.match(file.file_name), f"Filename '{file.file_name}' does not match expected pattern."
|
||||
|
||||
assert file.source_id == source.id
|
||||
|
||||
|
||||
|
||||
66
tests/test_utils.py
Normal file
66
tests/test_utils.py
Normal file
@@ -0,0 +1,66 @@
|
||||
import pytest
|
||||
|
||||
from letta.constants import MAX_FILENAME_LENGTH
|
||||
from letta.utils import sanitize_filename
|
||||
|
||||
|
||||
def test_valid_filename():
    """An ordinary name keeps its stem and extension around the unique suffix."""
    result = sanitize_filename("valid_filename.txt")

    assert result.startswith("valid_filename_")
    assert result.endswith(".txt")
|
||||
|
||||
|
||||
def test_filename_with_special_characters():
    """Directory parts and invalid characters are removed; the valid remainder is kept."""
    raw_name = "invalid:/<>?*ƒfilename.txt"

    result = sanitize_filename(raw_name)

    assert result.startswith("ƒfilename_")
    assert result.endswith(".txt")
|
||||
|
||||
|
||||
def test_null_byte_in_filename():
    """An embedded NUL byte is removed from the sanitized name."""
    result = sanitize_filename("valid\0filename.txt")

    assert "\0" not in result
    assert result.startswith("validfilename_")
    assert result.endswith(".txt")
|
||||
|
||||
|
||||
def test_path_traversal_characters():
    """A traversal path is reduced to its final component and stays within the length limit."""
    traversal_path = "../../etc/passwd"

    result = sanitize_filename(traversal_path)

    assert result.startswith("passwd_")
    assert len(result) <= MAX_FILENAME_LENGTH
|
||||
|
||||
|
||||
def test_empty_filename():
    """An empty name yields a result that starts with the `_` separator."""
    result = sanitize_filename("")

    assert result.startswith("_")
|
||||
|
||||
|
||||
def test_dot_as_filename():
    """A bare '.' is rejected with a ValueError."""
    dot_name = "."

    with pytest.raises(ValueError, match="Invalid filename"):
        sanitize_filename(dot_name)
|
||||
|
||||
|
||||
def test_dotdot_as_filename():
    """A bare '..' is rejected with a ValueError."""
    dotdot_name = ".."

    with pytest.raises(ValueError, match="Invalid filename"):
        sanitize_filename(dotdot_name)
|
||||
|
||||
|
||||
def test_long_filename():
    """An over-long stem is truncated to fit the limit while the extension is kept."""
    long_stem = "a" * (MAX_FILENAME_LENGTH + 10)

    result = sanitize_filename(long_stem + ".txt")

    assert len(result) <= MAX_FILENAME_LENGTH
    assert result.endswith(".txt")
|
||||
|
||||
|
||||
def test_unique_filenames():
    """Sanitizing the same name twice gives two different results with the same stem and extension."""
    first, second = (sanitize_filename("duplicate.txt") for _ in range(2))

    assert first != second

    # Same order of checks as before: both prefixes first, then both extensions
    for candidate in (first, second):
        assert candidate.startswith("duplicate_")
    for candidate in (first, second):
        assert candidate.endswith(".txt")
|
||||
Reference in New Issue
Block a user