diff --git a/app/auth/views/proton.py b/app/auth/views/proton.py index f8b28b17..cb9fc03d 100644 --- a/app/auth/views/proton.py +++ b/app/auth/views/proton.py @@ -31,6 +31,9 @@ _token_url = PROTON_BASE_URL + "/oauth/token" # when served behind nginx, the redirect_uri is localhost... and not the real url _redirect_uri = URL + "/auth/proton/callback" +SESSION_ACTION_KEY = "oauth_action" +SESSION_STATE_KEY = "oauth_state" + def extract_action() -> Action: action = request.args.get("action") @@ -43,7 +46,7 @@ def extract_action() -> Action: def get_action_from_state() -> Action: - oauth_action = session["oauth_action"] + oauth_action = session[SESSION_ACTION_KEY] if oauth_action == Action.Login.value: return Action.Login elif oauth_action == Action.Link.value: @@ -65,13 +68,16 @@ def proton_login(): authorization_url, state = proton.authorization_url(_authorization_base_url) # State is used to prevent CSRF, keep this for later. - session["oauth_state"] = state - session["oauth_action"] = extract_action().value + session[SESSION_STATE_KEY] = state + session[SESSION_ACTION_KEY] = extract_action().value return redirect(authorization_url) @auth_bp.route("/proton/callback") def proton_callback(): + if SESSION_STATE_KEY not in session or SESSION_STATE_KEY not in session: + flash("Invalid state, please retry", "error") + return redirect(url_for("auth.login")) if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: return redirect(url_for("auth.login")) @@ -82,7 +88,7 @@ def proton_callback(): proton = OAuth2Session( PROTON_CLIENT_ID, - state=session["oauth_state"], + state=session[SESSION_STATE_KEY], redirect_uri=_redirect_uri, )