diff --git a/app/dashboard/views/setting.py b/app/dashboard/views/setting.py index c08ab145..28ea408d 100644 --- a/app/dashboard/views/setting.py +++ b/app/dashboard/views/setting.py @@ -29,6 +29,7 @@ from app.email_utils import ( personal_email_already_used, ) from app.errors import ProtonPartnerNotSetUp +from app.image_validation import detect_image_format, ImageFormat from app.jobs.export_user_data_job import ExportUserDataJob from app.log import LOG from app.models import ( @@ -181,12 +182,18 @@ def setting(): profile_updated = True if form.profile_picture.data: + image_contents = form.profile_picture.data.read() + if detect_image_format(image_contents) == ImageFormat.Unknown: + flash( + "This image format is not supported", + "error", + ) + return redirect(url_for("dashboard.setting")) + file_path = random_string(30) file = File.create(user_id=current_user.id, path=file_path) - s3.upload_from_bytesio( - file_path, BytesIO(form.profile_picture.data.read()) - ) + s3.upload_from_bytesio(file_path, BytesIO(image_contents)) Session.flush() LOG.d("upload file %s to s3", file) diff --git a/app/image_validation.py b/app/image_validation.py new file mode 100644 index 00000000..6c5bbb2a --- /dev/null +++ b/app/image_validation.py @@ -0,0 +1,28 @@ +from enum import Enum + + +class ImageFormat(Enum): + Png = 1 + Jpg = 2 + Webp = 3 + Svg = 4 + Unknown = 9 + + +magic_numbers = { + ImageFormat.Png: bytes([0x89, 0x50, 0x4E, 0x47, 0x0D, 0x0A, 0x1A, 0x0A]), + ImageFormat.Jpg: bytes([0xFF, 0xD8, 0xFF, 0xE0]), + ImageFormat.Webp: bytes([0x52, 0x49, 0x46, 0x46]), + ImageFormat.Svg: bytes([0x3C, 0x3F, 0x78, 0x6D, 0x6C]), # ImageFormat: + # Detect image based on magic number + for fmt, header in magic_numbers.items(): + if image.startswith(header): + return fmt + # Detect if is svg + + # We don't know the type + return ImageFormat.Unknown diff --git a/tests/data/1px.jpg b/tests/data/1px.jpg new file mode 100644 index 00000000..1cda9a53 Binary files /dev/null and b/tests/data/1px.jpg differ diff --git a/tests/data/1px.webp b/tests/data/1px.webp new file mode 100644 index 00000000..c78718cd Binary files /dev/null and b/tests/data/1px.webp differ diff --git a/tests/test_image_validation.py b/tests/test_image_validation.py new file mode 100644 index 00000000..51e5c78e --- /dev/null +++ b/tests/test_image_validation.py @@ -0,0 +1,47 @@ +from app.image_validation import ImageFormat, detect_image_format +from pathlib import Path + + +def get_path_to_static_dir() -> Path: + this_path = Path(__file__) + repo_root_path = this_path.parent.parent + return repo_root_path.joinpath("static") + + +def read_static_file_contents(filename: str) -> bytes: + image_path = get_path_to_static_dir().joinpath(filename) + with open(image_path.as_posix(), "rb") as f: + return f.read() + + +def read_test_data_file_contents(filename: str) -> bytes: + this_path = Path(__file__) + test_data_path = this_path.parent.joinpath("data") + file_path = test_data_path.joinpath(filename) + with open(file_path.as_posix(), "rb") as f: + return f.read() + + +def test_non_image_file_returns_unknown(): + contents = read_static_file_contents("local-storage-polyfill.js") + assert detect_image_format(contents) is ImageFormat.Unknown + + +def test_png_file_is_detected(): + contents = read_static_file_contents("logo.png") + assert detect_image_format(contents) is ImageFormat.Png + + +def test_jpg_file_is_detected(): + contents = read_test_data_file_contents("1px.jpg") + assert detect_image_format(contents) is ImageFormat.Jpg + + +def test_webp_file_is_detected(): + contents = read_test_data_file_contents("1px.webp") + assert detect_image_format(contents) is ImageFormat.Webp + + +def test_svg_file_is_detected(): + contents = read_static_file_contents("icon.svg") + assert detect_image_format(contents) is ImageFormat.Svg