diff --git a/app/auth/views/activate.py b/app/auth/views/activate.py index 0f91b254..4c688ebe 100644 --- a/app/auth/views/activate.py +++ b/app/auth/views/activate.py @@ -7,6 +7,7 @@ from app.db import Session from app.extensions import limiter from app.log import LOG from app.models import ActivationCode +from app.utils import sanitize_next_url @auth_bp.route("/activate", methods=["GET", "POST"]) @@ -58,7 +59,7 @@ def activate(): # The activation link contains the original page, for ex authorize page if "next" in request.args: - next_url = request.args.get("next") + next_url = sanitize_next_url(request.args.get("next")) LOG.d("redirect user to %s", next_url) return redirect(next_url) else: diff --git a/app/auth/views/facebook.py b/app/auth/views/facebook.py index 9ea5eb57..8068e2eb 100644 --- a/app/auth/views/facebook.py +++ b/app/auth/views/facebook.py @@ -13,7 +13,7 @@ from app.db import Session from app.log import LOG from app.models import User, SocialAuth from .login_utils import after_login -from ...utils import sanitize_email +from ...utils import sanitize_email, sanitize_next_url _authorization_base_url = "https://www.facebook.com/dialog/oauth" _token_url = "https://graph.facebook.com/oauth/access_token" @@ -30,7 +30,7 @@ def facebook_login(): # to avoid flask-login displaying the login error message session.pop("_flashes", None) - next_url = request.args.get("next") + next_url = sanitize_next_url(request.args.get("next")) # Facebook does not allow to append param to redirect_uri # we need to pass the next url by session diff --git a/app/auth/views/login.py b/app/auth/views/login.py index f9c87f7d..c68ff887 100644 --- a/app/auth/views/login.py +++ b/app/auth/views/login.py @@ -8,7 +8,7 @@ from app.auth.views.login_utils import after_login from app.extensions import limiter from app.log import LOG from app.models import User -from app.utils import sanitize_email +from app.utils import sanitize_email, sanitize_next_url class LoginForm(FlaskForm): @@ -21,7 +21,7 @@ class LoginForm(FlaskForm): "10/minute", deduct_when=lambda r: hasattr(g, "deduct_limit") and g.deduct_limit ) def login(): - next_url = request.args.get("next") + next_url = sanitize_next_url(request.args.get("next")) if current_user.is_authenticated: if next_url: diff --git a/app/utils.py b/app/utils.py index a55e2dbd..722a39d0 100644 --- a/app/utils.py +++ b/app/utils.py @@ -3,6 +3,7 @@ import string import time import urllib.parse from functools import wraps +from typing import Optional from unidecode import unidecode @@ -74,6 +75,14 @@ def sanitize_email(email_address: str, not_lower=False) -> str: return email_address +def sanitize_next_url(url: Optional[str]) -> Optional[str]: + if url is None or len(url) == 0: + return None + if url[0] != "/": + return None + return url + + def query2str(query): """Useful utility method to print out a SQLAlchemy query""" return query.statement.compile(compile_kwargs={"literal_binds": True}) diff --git a/tests/test_utils.py b/tests/test_utils.py index d11438ab..24a10663 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,4 @@ -from app.utils import random_string, random_words +from app.utils import random_string, random_words, sanitize_next_url def test_random_words(): @@ -9,3 +9,18 @@ def test_random_words(): def test_random_string(): s = random_string() assert len(s) > 0 + + +def test_sanitize_url(): + cases = [ + {"url": None, "expected": None}, + {"url": "", "expected": None}, + {"url": "https://badzone.org", "expected": None}, + {"url": "/", "expected": "/"}, + {"url": "/auth", "expected": "/auth"}, + {"url": "/some/path", "expected": "/some/path"}, + ] + + for case in cases: + res = sanitize_next_url(case["url"]) + assert res == case["expected"]