diff --git a/app/utils.py b/app/utils.py index b98d6a87..83ee0403 100644 --- a/app/utils.py +++ b/app/utils.py @@ -80,10 +80,11 @@ class NextUrlSanitizer: def sanitize(url: Optional[str], allowed_domains: List[str]) -> Optional[str]: if not url: return None - result = urllib.parse.urlparse(url) + replaced = url.replace("\\", "/") + result = urllib.parse.urlparse(replaced) if result.hostname: if result.hostname in allowed_domains: - return url + return replaced else: return None if result.path and result.path[0] == "/": diff --git a/tests/test_utils.py b/tests/test_utils.py index 52984373..013ea681 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from typing import List +from typing import List, Optional from urllib.parse import parse_qs import pytest @@ -34,11 +34,12 @@ def generate_sanitize_url_cases() -> List: cases.append([f"https://{domain}/sub", f"https://{domain}/sub"]) cases.append([domain, None]) cases.append([f"//{domain}", f"//{domain}"]) + cases.append([f"https://google.com\\@{domain}/haha", None]) return cases @pytest.mark.parametrize("url,expected", generate_sanitize_url_cases()) -def test_sanitize_url(url, expected): +def test_sanitize_url(url: str, expected: Optional[str]): sanitized = sanitize_next_url(url) assert expected == sanitized