From a608503df66e1278ca483f6a06e6a886bae201a1 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Daniel=20M=C3=BChlbachler-Pietrzykowski?= Date: Wed, 13 Mar 2024 14:30:00 +0100 Subject: [PATCH] feat: add generic OIDC connect (#2046) --- app/auth/__init__.py | 2 + app/auth/views/login.py | 4 +- app/auth/views/oidc.py | 131 ++++++++++++ app/auth/views/register.py | 4 +- app/config.py | 11 +- app/dashboard/views/enter_sudo.py | 12 +- example.env | 10 + templates/auth/login.html | 12 +- templates/auth/register.html | 12 +- templates/dashboard/enter_sudo.html | 15 +- tests/auth/test_oidc.py | 304 ++++++++++++++++++++++++++++ tests/test.env | 10 + 12 files changed, 517 insertions(+), 10 deletions(-) create mode 100644 app/auth/views/oidc.py create mode 100644 tests/auth/test_oidc.py diff --git a/app/auth/__init__.py b/app/auth/__init__.py index e8adcc6c..7a6cdeeb 100644 --- a/app/auth/__init__.py +++ b/app/auth/__init__.py @@ -16,6 +16,7 @@ from .views import ( social, recovery, api_to_cookie, + oidc, ) __all__ = [ @@ -36,4 +37,5 @@ __all__ = [ "social", "recovery", "api_to_cookie", + "oidc", ] diff --git a/app/auth/views/login.py b/app/auth/views/login.py index 56b2ac36..261c1a98 100644 --- a/app/auth/views/login.py +++ b/app/auth/views/login.py @@ -5,7 +5,7 @@ from wtforms import StringField, validators from app.auth.base import auth_bp from app.auth.views.login_utils import after_login -from app.config import CONNECT_WITH_PROTON +from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON, OIDC_CLIENT_ID from app.events.auth_event import LoginEvent from app.extensions import limiter from app.log import LOG @@ -77,4 +77,6 @@ def login(): next_url=next_url, show_resend_activation=show_resend_activation, connect_with_proton=CONNECT_WITH_PROTON, + connect_with_oidc=OIDC_CLIENT_ID is not None, + connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON, ) diff --git a/app/auth/views/oidc.py b/app/auth/views/oidc.py new file mode 100644 index 00000000..12c4e491 --- /dev/null +++ b/app/auth/views/oidc.py @@ -0,0 +1,131 @@ +from flask import request, session, redirect, flash, url_for +from requests_oauthlib import OAuth2Session + +from app import config +from app.auth.base import auth_bp +from app.auth.views.login_utils import after_login +from app.config import ( + URL, + OIDC_AUTHORIZATION_URL, + OIDC_USER_INFO_URL, + OIDC_TOKEN_URL, + OIDC_SCOPES, + OIDC_NAME_FIELD, +) +from app.db import Session +from app.email_utils import send_welcome_email +from app.log import LOG +from app.models import User, SocialAuth +from app.utils import encode_url, sanitize_email, sanitize_next_url + + +# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri +# when served behind nginx, the redirect_uri is localhost... and not the real url +_redirect_uri = URL + "/auth/oidc/callback" + +SESSION_STATE_KEY = "oauth_state" + + +@auth_bp.route("/oidc/login") +def oidc_login(): + if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None: + return redirect(url_for("auth.login")) + + next_url = sanitize_next_url(request.args.get("next")) + if next_url: + redirect_uri = _redirect_uri + "?next=" + encode_url(next_url) + else: + redirect_uri = _redirect_uri + + oidc = OAuth2Session( + config.OIDC_CLIENT_ID, scope=[OIDC_SCOPES], redirect_uri=redirect_uri + ) + authorization_url, state = oidc.authorization_url(OIDC_AUTHORIZATION_URL) + + # State is used to prevent CSRF, keep this for later. + session[SESSION_STATE_KEY] = state + return redirect(authorization_url) + + +@auth_bp.route("/oidc/callback") +def oidc_callback(): + if SESSION_STATE_KEY not in session: + flash("Invalid state, please retry", "error") + return redirect(url_for("auth.login")) + if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None: + return redirect(url_for("auth.login")) + + # user clicks on cancel + if "error" in request.args: + flash("Please use another sign in method then", "warning") + return redirect("/") + + oidc = OAuth2Session( + config.OIDC_CLIENT_ID, + state=session[SESSION_STATE_KEY], + scope=[OIDC_SCOPES], + redirect_uri=_redirect_uri, + ) + oidc.fetch_token( + OIDC_TOKEN_URL, + client_secret=config.OIDC_CLIENT_SECRET, + authorization_response=request.url, + ) + + oidc_user_data = oidc.get(OIDC_USER_INFO_URL) + if oidc_user_data.status_code != 200: + LOG.e( + f"cannot get oidc user data {oidc_user_data.status_code} {oidc_user_data.text}" + ) + flash( + "Cannot get user data from OIDC, please use another way to login/sign up", + "error", + ) + return redirect(url_for("auth.login")) + oidc_user_data = oidc_user_data.json() + + email = oidc_user_data.get("email") + + if not email: + LOG.e(f"cannot get email for OIDC user {oidc_user_data} {email}") + flash( + "Cannot get a valid email from OIDC, please another way to login/sign up", + "error", + ) + return redirect(url_for("auth.login")) + + email = sanitize_email(email) + user = User.get_by(email=email) + + if not user and config.DISABLE_REGISTRATION: + flash( + "Sorry you cannot sign up via the OIDC provider. Please sign-up first with your email.", + "error", + ) + return redirect(url_for("auth.register")) + elif not user: + user = create_user(email, oidc_user_data) + + if not SocialAuth.get_by(user_id=user.id, social="oidc"): + SocialAuth.create(user_id=user.id, social="oidc") + Session.commit() + + # The activation link contains the original page, for ex authorize page + next_url = sanitize_next_url(request.args.get("next")) if request.args else None + + return after_login(user, next_url) + + +def create_user(email, oidc_user_data): + new_user = User.create( + email=email, + name=oidc_user_data.get(OIDC_NAME_FIELD), + password="", + activated=True, + ) + LOG.i(f"Created new user for login request from OIDC. New user {new_user.id}") + Session.commit() + + send_welcome_email(new_user) + + return new_user diff --git a/app/auth/views/register.py b/app/auth/views/register.py index f40a98a2..75053039 100644 --- a/app/auth/views/register.py +++ b/app/auth/views/register.py @@ -6,7 +6,7 @@ from wtforms import StringField, validators from app import email_utils, config from app.auth.base import auth_bp -from app.config import CONNECT_WITH_PROTON +from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON from app.auth.views.login_utils import get_referral from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY from app.db import Session @@ -109,6 +109,8 @@ def register(): next_url=next_url, HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY, connect_with_proton=CONNECT_WITH_PROTON, + connect_with_oidc=config.OIDC_CLIENT_ID is not None, + connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON, ) diff --git a/app/config.py b/app/config.py index e8824f5a..5f33d0c5 100644 --- a/app/config.py +++ b/app/config.py @@ -234,7 +234,7 @@ else: print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME) -# Github, Google, Facebook client id and secrets +# Github, Google, Facebook, OIDC client id and secrets GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID") GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET") @@ -244,6 +244,15 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET") FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID") FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET") +CONNECT_WITH_OIDC_ICON = os.environ.get("CONNECT_WITH_OIDC_ICON") +OIDC_AUTHORIZATION_URL = os.environ.get("OIDC_AUTHORIZATION_URL") +OIDC_USER_INFO_URL = os.environ.get("OIDC_USER_INFO_URL") +OIDC_TOKEN_URL = os.environ.get("OIDC_TOKEN_URL") +OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID") +OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET") +OIDC_SCOPES = os.environ.get("OIDC_SCOPES") +OIDC_NAME_FIELD = os.environ.get("OIDC_NAME_FIELD", "name") + PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID") PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET") PROTON_BASE_URL = os.environ.get( diff --git a/app/dashboard/views/enter_sudo.py b/app/dashboard/views/enter_sudo.py index d32deb87..3910873e 100644 --- a/app/dashboard/views/enter_sudo.py +++ b/app/dashboard/views/enter_sudo.py @@ -6,11 +6,11 @@ from flask_login import login_required, current_user from flask_wtf import FlaskForm from wtforms import PasswordField, validators -from app.config import CONNECT_WITH_PROTON +from app.config import CONNECT_WITH_PROTON, OIDC_CLIENT_ID, CONNECT_WITH_OIDC_ICON from app.dashboard.base import dashboard_bp from app.extensions import limiter from app.log import LOG -from app.models import PartnerUser +from app.models import PartnerUser, SocialAuth from app.proton.utils import get_proton_partner from app.utils import sanitize_next_url @@ -51,11 +51,19 @@ def enter_sudo(): if not partner_user or partner_user.partner_id != get_proton_partner().id: proton_enabled = False + oidc_enabled = OIDC_CLIENT_ID is not None + if oidc_enabled: + oidc_enabled = ( + SocialAuth.get_by(user_id=current_user.id, social="oidc") is not None + ) + return render_template( "dashboard/enter_sudo.html", password_check_form=password_check_form, next=request.args.get("next"), connect_with_proton=proton_enabled, + connect_with_oidc=oidc_enabled, + connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON, ) diff --git a/example.env b/example.env index d2e22b63..4ee09519 100644 --- a/example.env +++ b/example.env @@ -116,6 +116,16 @@ WORDS_FILE_PATH=local_data/test_words.txt # CONNECT_WITH_PROTON=true # CONNECT_WITH_PROTON_COOKIE_NAME=to_fill +# Login with OIDC +# CONNECT_WITH_OIDC_ICON=fa-github +# OIDC_AUTHORIZATION_URL=to_fill +# OIDC_USER_INFO_URL=to_fill +# OIDC_TOKEN_URL=to_fill +# OIDC_SCOPES=openid email profile +# OIDC_NAME_FIELD=name +# OIDC_CLIENT_ID=to_fill +# OIDC_CLIENT_SECRET=to_fill + # Flask profiler # FLASK_PROFILER_PATH=/tmp/flask-profiler.sql # FLASK_PROFILER_PASSWORD=password diff --git a/templates/auth/login.html b/templates/auth/login.html index 0451fd39..898e6d05 100644 --- a/templates/auth/login.html +++ b/templates/auth/login.html @@ -38,11 +38,21 @@ or + href="{{ url_for('auth.proton_login', next=next_url) }}"> Log in with Proton {% endif %} + {% if connect_with_oidc %} + +
+ or +
+ + Log in with SSO + + {% endif %}
diff --git a/templates/auth/register.html b/templates/auth/register.html index 6be4f4cf..c8400e1c 100644 --- a/templates/auth/register.html +++ b/templates/auth/register.html @@ -50,11 +50,21 @@ or
+ href="{{ url_for('auth.proton_login', next=next_url) }}"> Sign up with Proton {% endif %} + {% if connect_with_oidc %} + +
+ or +
+ + Sign up with SSO + + {% endif %}
diff --git a/templates/dashboard/enter_sudo.html b/templates/dashboard/enter_sudo.html index ed1f895d..3175f375 100644 --- a/templates/dashboard/enter_sudo.html +++ b/templates/dashboard/enter_sudo.html @@ -22,11 +22,20 @@

Alternatively you can use your Proton credentials to ensure it's you.

+ href="{{ url_for('auth.proton_login', next=next) }}"> Authenticate with Proton {% endif %} + {% if connect_with_oidc %} + +
+

Alternatively you can use your SSO credentials to ensure it's you.

+ + Authenticate with SSO + + {% endif %} +
- -{% endblock %} + {% endblock %} diff --git a/tests/auth/test_oidc.py b/tests/auth/test_oidc.py new file mode 100644 index 00000000..e35bb5e4 --- /dev/null +++ b/tests/auth/test_oidc.py @@ -0,0 +1,304 @@ +from app import config +from flask import url_for +from urllib.parse import parse_qs +from urllib3.util import parse_url +from app.auth.views.oidc import create_user +from app.utils import random_string +from unittest.mock import patch +from app.models import User + +from app.config import URL, OIDC_CLIENT_ID + + +def test_oidc_login(flask_client): + r = flask_client.get( + url_for("auth.oidc_login"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + query = parse_qs(parsed.query) + + expected_redirect_url = f"{URL}/auth/oidc/callback" + + assert "code" == query["response_type"][0] + assert OIDC_CLIENT_ID == query["client_id"][0] + assert expected_redirect_url == query["redirect_uri"][0] + + +def test_oidc_login_no_client_id(flask_client): + config.OIDC_CLIENT_ID = None + + r = flask_client.get( + url_for("auth.oidc_login"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_ID = "to_fill" + + +def test_oidc_login_no_client_secret(flask_client): + config.OIDC_CLIENT_SECRET = None + + r = flask_client.get( + url_for("auth.oidc_login"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_SECRET = "to_fill" + + +def test_oidc_callback_no_oauth_state(flask_client): + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is None + + +def test_oidc_callback_no_client_id(flask_client): + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + config.OIDC_CLIENT_ID = None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_ID = "to_fill" + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +def test_oidc_callback_no_client_secret(flask_client): + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + config.OIDC_CLIENT_SECRET = None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + + config.OIDC_CLIENT_SECRET = "to_fill" + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_invalid_user(mock_get, mock_fetch_token, flask_client): + mock_get.return_value = MockResponse(400, {}) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_no_email(mock_get, mock_fetch_token, flask_client): + mock_get.return_value = MockResponse(200, {}) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/login" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_disabled_registration(mock_get, mock_fetch_token, flask_client): + config.DISABLE_REGISTRATION = True + email = random_string() + mock_get.return_value = MockResponse(200, {"email": email}) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/auth/register" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + config.DISABLE_REGISTRATION = False + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_registration(mock_get, mock_fetch_token, flask_client): + email = random_string() + mock_get.return_value = MockResponse( + 200, + { + "email": email, + config.OIDC_NAME_FIELD: "name", + }, + ) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + user = User.get_by(email=email) + assert user is None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/dashboard/" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + user = User.get_by(email=email) + assert user is not None + assert user.email == email + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +@patch("requests_oauthlib.OAuth2Session.fetch_token") +@patch("requests_oauthlib.OAuth2Session.get") +def test_oidc_callback_login(mock_get, mock_fetch_token, flask_client): + email = random_string() + mock_get.return_value = MockResponse( + 200, + { + "email": email, + }, + ) + with flask_client.session_transaction() as session: + session["oauth_state"] = "state" + + user = User.create( + email=email, + name="name", + password="", + activated=True, + ) + user = User.get_by(email=email) + assert user is not None + + r = flask_client.get( + url_for("auth.oidc_callback"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + + expected_redirect_url = "/dashboard/" + + assert expected_redirect_url == parsed.path + assert mock_get.called + + with flask_client.session_transaction() as session: + session["oauth_state"] = None + + +def test_create_user(): + email = random_string() + user = create_user( + email, + { + config.OIDC_NAME_FIELD: "name", + }, + ) + assert user.email == email + assert user.name == "name" + assert user.activated + + +class MockResponse: + def __init__(self, status_code, json_data): + self.status_code = status_code + self.json_data = json_data + self.text = "error" + + def json(self): + return self.json_data diff --git a/tests/test.env b/tests/test.env index 86d383ae..49941bee 100644 --- a/tests/test.env +++ b/tests/test.env @@ -49,6 +49,16 @@ GOOGLE_CLIENT_SECRET=to_fill FACEBOOK_CLIENT_ID=to_fill FACEBOOK_CLIENT_SECRET=to_fill +# Login with OIDC +CONNECT_WITH_OIDC_ICON=fa-github +OIDC_AUTHORIZATION_URL=to_fill +OIDC_USER_INFO_URL=to_fill +OIDC_TOKEN_URL=to_fill +OIDC_SCOPES=openid email profile +OIDC_NAME_FIELD=name +OIDC_CLIENT_ID=to_fill +OIDC_CLIENT_SECRET=to_fill + PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc ALIAS_AUTOMATIC_DISABLE=true