feat: add generic OIDC connect (#2046)
This commit is contained in:
parent
0c3c6db2ab
commit
a608503df6
|
@ -16,6 +16,7 @@ from .views import (
|
||||||
social,
|
social,
|
||||||
recovery,
|
recovery,
|
||||||
api_to_cookie,
|
api_to_cookie,
|
||||||
|
oidc,
|
||||||
)
|
)
|
||||||
|
|
||||||
__all__ = [
|
__all__ = [
|
||||||
|
@ -36,4 +37,5 @@ __all__ = [
|
||||||
"social",
|
"social",
|
||||||
"recovery",
|
"recovery",
|
||||||
"api_to_cookie",
|
"api_to_cookie",
|
||||||
|
"oidc",
|
||||||
]
|
]
|
||||||
|
|
|
@ -5,7 +5,7 @@ from wtforms import StringField, validators
|
||||||
|
|
||||||
from app.auth.base import auth_bp
|
from app.auth.base import auth_bp
|
||||||
from app.auth.views.login_utils import after_login
|
from app.auth.views.login_utils import after_login
|
||||||
from app.config import CONNECT_WITH_PROTON
|
from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON, OIDC_CLIENT_ID
|
||||||
from app.events.auth_event import LoginEvent
|
from app.events.auth_event import LoginEvent
|
||||||
from app.extensions import limiter
|
from app.extensions import limiter
|
||||||
from app.log import LOG
|
from app.log import LOG
|
||||||
|
@ -77,4 +77,6 @@ def login():
|
||||||
next_url=next_url,
|
next_url=next_url,
|
||||||
show_resend_activation=show_resend_activation,
|
show_resend_activation=show_resend_activation,
|
||||||
connect_with_proton=CONNECT_WITH_PROTON,
|
connect_with_proton=CONNECT_WITH_PROTON,
|
||||||
|
connect_with_oidc=OIDC_CLIENT_ID is not None,
|
||||||
|
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
|
||||||
)
|
)
|
||||||
|
|
|
@ -0,0 +1,131 @@
|
||||||
|
from flask import request, session, redirect, flash, url_for
|
||||||
|
from requests_oauthlib import OAuth2Session
|
||||||
|
|
||||||
|
from app import config
|
||||||
|
from app.auth.base import auth_bp
|
||||||
|
from app.auth.views.login_utils import after_login
|
||||||
|
from app.config import (
|
||||||
|
URL,
|
||||||
|
OIDC_AUTHORIZATION_URL,
|
||||||
|
OIDC_USER_INFO_URL,
|
||||||
|
OIDC_TOKEN_URL,
|
||||||
|
OIDC_SCOPES,
|
||||||
|
OIDC_NAME_FIELD,
|
||||||
|
)
|
||||||
|
from app.db import Session
|
||||||
|
from app.email_utils import send_welcome_email
|
||||||
|
from app.log import LOG
|
||||||
|
from app.models import User, SocialAuth
|
||||||
|
from app.utils import encode_url, sanitize_email, sanitize_next_url
|
||||||
|
|
||||||
|
|
||||||
|
# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri
|
||||||
|
# when served behind nginx, the redirect_uri is localhost... and not the real url
|
||||||
|
_redirect_uri = URL + "/auth/oidc/callback"
|
||||||
|
|
||||||
|
SESSION_STATE_KEY = "oauth_state"
|
||||||
|
|
||||||
|
|
||||||
|
@auth_bp.route("/oidc/login")
|
||||||
|
def oidc_login():
|
||||||
|
if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None:
|
||||||
|
return redirect(url_for("auth.login"))
|
||||||
|
|
||||||
|
next_url = sanitize_next_url(request.args.get("next"))
|
||||||
|
if next_url:
|
||||||
|
redirect_uri = _redirect_uri + "?next=" + encode_url(next_url)
|
||||||
|
else:
|
||||||
|
redirect_uri = _redirect_uri
|
||||||
|
|
||||||
|
oidc = OAuth2Session(
|
||||||
|
config.OIDC_CLIENT_ID, scope=[OIDC_SCOPES], redirect_uri=redirect_uri
|
||||||
|
)
|
||||||
|
authorization_url, state = oidc.authorization_url(OIDC_AUTHORIZATION_URL)
|
||||||
|
|
||||||
|
# State is used to prevent CSRF, keep this for later.
|
||||||
|
session[SESSION_STATE_KEY] = state
|
||||||
|
return redirect(authorization_url)
|
||||||
|
|
||||||
|
|
||||||
|
@auth_bp.route("/oidc/callback")
|
||||||
|
def oidc_callback():
|
||||||
|
if SESSION_STATE_KEY not in session:
|
||||||
|
flash("Invalid state, please retry", "error")
|
||||||
|
return redirect(url_for("auth.login"))
|
||||||
|
if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None:
|
||||||
|
return redirect(url_for("auth.login"))
|
||||||
|
|
||||||
|
# user clicks on cancel
|
||||||
|
if "error" in request.args:
|
||||||
|
flash("Please use another sign in method then", "warning")
|
||||||
|
return redirect("/")
|
||||||
|
|
||||||
|
oidc = OAuth2Session(
|
||||||
|
config.OIDC_CLIENT_ID,
|
||||||
|
state=session[SESSION_STATE_KEY],
|
||||||
|
scope=[OIDC_SCOPES],
|
||||||
|
redirect_uri=_redirect_uri,
|
||||||
|
)
|
||||||
|
oidc.fetch_token(
|
||||||
|
OIDC_TOKEN_URL,
|
||||||
|
client_secret=config.OIDC_CLIENT_SECRET,
|
||||||
|
authorization_response=request.url,
|
||||||
|
)
|
||||||
|
|
||||||
|
oidc_user_data = oidc.get(OIDC_USER_INFO_URL)
|
||||||
|
if oidc_user_data.status_code != 200:
|
||||||
|
LOG.e(
|
||||||
|
f"cannot get oidc user data {oidc_user_data.status_code} {oidc_user_data.text}"
|
||||||
|
)
|
||||||
|
flash(
|
||||||
|
"Cannot get user data from OIDC, please use another way to login/sign up",
|
||||||
|
"error",
|
||||||
|
)
|
||||||
|
return redirect(url_for("auth.login"))
|
||||||
|
oidc_user_data = oidc_user_data.json()
|
||||||
|
|
||||||
|
email = oidc_user_data.get("email")
|
||||||
|
|
||||||
|
if not email:
|
||||||
|
LOG.e(f"cannot get email for OIDC user {oidc_user_data} {email}")
|
||||||
|
flash(
|
||||||
|
"Cannot get a valid email from OIDC, please another way to login/sign up",
|
||||||
|
"error",
|
||||||
|
)
|
||||||
|
return redirect(url_for("auth.login"))
|
||||||
|
|
||||||
|
email = sanitize_email(email)
|
||||||
|
user = User.get_by(email=email)
|
||||||
|
|
||||||
|
if not user and config.DISABLE_REGISTRATION:
|
||||||
|
flash(
|
||||||
|
"Sorry you cannot sign up via the OIDC provider. Please sign-up first with your email.",
|
||||||
|
"error",
|
||||||
|
)
|
||||||
|
return redirect(url_for("auth.register"))
|
||||||
|
elif not user:
|
||||||
|
user = create_user(email, oidc_user_data)
|
||||||
|
|
||||||
|
if not SocialAuth.get_by(user_id=user.id, social="oidc"):
|
||||||
|
SocialAuth.create(user_id=user.id, social="oidc")
|
||||||
|
Session.commit()
|
||||||
|
|
||||||
|
# The activation link contains the original page, for ex authorize page
|
||||||
|
next_url = sanitize_next_url(request.args.get("next")) if request.args else None
|
||||||
|
|
||||||
|
return after_login(user, next_url)
|
||||||
|
|
||||||
|
|
||||||
|
def create_user(email, oidc_user_data):
|
||||||
|
new_user = User.create(
|
||||||
|
email=email,
|
||||||
|
name=oidc_user_data.get(OIDC_NAME_FIELD),
|
||||||
|
password="",
|
||||||
|
activated=True,
|
||||||
|
)
|
||||||
|
LOG.i(f"Created new user for login request from OIDC. New user {new_user.id}")
|
||||||
|
Session.commit()
|
||||||
|
|
||||||
|
send_welcome_email(new_user)
|
||||||
|
|
||||||
|
return new_user
|
|
@ -6,7 +6,7 @@ from wtforms import StringField, validators
|
||||||
|
|
||||||
from app import email_utils, config
|
from app import email_utils, config
|
||||||
from app.auth.base import auth_bp
|
from app.auth.base import auth_bp
|
||||||
from app.config import CONNECT_WITH_PROTON
|
from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON
|
||||||
from app.auth.views.login_utils import get_referral
|
from app.auth.views.login_utils import get_referral
|
||||||
from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY
|
from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY
|
||||||
from app.db import Session
|
from app.db import Session
|
||||||
|
@ -109,6 +109,8 @@ def register():
|
||||||
next_url=next_url,
|
next_url=next_url,
|
||||||
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
|
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
|
||||||
connect_with_proton=CONNECT_WITH_PROTON,
|
connect_with_proton=CONNECT_WITH_PROTON,
|
||||||
|
connect_with_oidc=config.OIDC_CLIENT_ID is not None,
|
||||||
|
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
|
@ -234,7 +234,7 @@ else:
|
||||||
|
|
||||||
print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME)
|
print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME)
|
||||||
|
|
||||||
# Github, Google, Facebook client id and secrets
|
# Github, Google, Facebook, OIDC client id and secrets
|
||||||
GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID")
|
GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID")
|
||||||
GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET")
|
GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET")
|
||||||
|
|
||||||
|
@ -244,6 +244,15 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
|
||||||
FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID")
|
FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID")
|
||||||
FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET")
|
FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET")
|
||||||
|
|
||||||
|
CONNECT_WITH_OIDC_ICON = os.environ.get("CONNECT_WITH_OIDC_ICON")
|
||||||
|
OIDC_AUTHORIZATION_URL = os.environ.get("OIDC_AUTHORIZATION_URL")
|
||||||
|
OIDC_USER_INFO_URL = os.environ.get("OIDC_USER_INFO_URL")
|
||||||
|
OIDC_TOKEN_URL = os.environ.get("OIDC_TOKEN_URL")
|
||||||
|
OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID")
|
||||||
|
OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET")
|
||||||
|
OIDC_SCOPES = os.environ.get("OIDC_SCOPES")
|
||||||
|
OIDC_NAME_FIELD = os.environ.get("OIDC_NAME_FIELD", "name")
|
||||||
|
|
||||||
PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID")
|
PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID")
|
||||||
PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET")
|
PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET")
|
||||||
PROTON_BASE_URL = os.environ.get(
|
PROTON_BASE_URL = os.environ.get(
|
||||||
|
|
|
@ -6,11 +6,11 @@ from flask_login import login_required, current_user
|
||||||
from flask_wtf import FlaskForm
|
from flask_wtf import FlaskForm
|
||||||
from wtforms import PasswordField, validators
|
from wtforms import PasswordField, validators
|
||||||
|
|
||||||
from app.config import CONNECT_WITH_PROTON
|
from app.config import CONNECT_WITH_PROTON, OIDC_CLIENT_ID, CONNECT_WITH_OIDC_ICON
|
||||||
from app.dashboard.base import dashboard_bp
|
from app.dashboard.base import dashboard_bp
|
||||||
from app.extensions import limiter
|
from app.extensions import limiter
|
||||||
from app.log import LOG
|
from app.log import LOG
|
||||||
from app.models import PartnerUser
|
from app.models import PartnerUser, SocialAuth
|
||||||
from app.proton.utils import get_proton_partner
|
from app.proton.utils import get_proton_partner
|
||||||
from app.utils import sanitize_next_url
|
from app.utils import sanitize_next_url
|
||||||
|
|
||||||
|
@ -51,11 +51,19 @@ def enter_sudo():
|
||||||
if not partner_user or partner_user.partner_id != get_proton_partner().id:
|
if not partner_user or partner_user.partner_id != get_proton_partner().id:
|
||||||
proton_enabled = False
|
proton_enabled = False
|
||||||
|
|
||||||
|
oidc_enabled = OIDC_CLIENT_ID is not None
|
||||||
|
if oidc_enabled:
|
||||||
|
oidc_enabled = (
|
||||||
|
SocialAuth.get_by(user_id=current_user.id, social="oidc") is not None
|
||||||
|
)
|
||||||
|
|
||||||
return render_template(
|
return render_template(
|
||||||
"dashboard/enter_sudo.html",
|
"dashboard/enter_sudo.html",
|
||||||
password_check_form=password_check_form,
|
password_check_form=password_check_form,
|
||||||
next=request.args.get("next"),
|
next=request.args.get("next"),
|
||||||
connect_with_proton=proton_enabled,
|
connect_with_proton=proton_enabled,
|
||||||
|
connect_with_oidc=oidc_enabled,
|
||||||
|
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
|
10
example.env
10
example.env
|
@ -116,6 +116,16 @@ WORDS_FILE_PATH=local_data/test_words.txt
|
||||||
# CONNECT_WITH_PROTON=true
|
# CONNECT_WITH_PROTON=true
|
||||||
# CONNECT_WITH_PROTON_COOKIE_NAME=to_fill
|
# CONNECT_WITH_PROTON_COOKIE_NAME=to_fill
|
||||||
|
|
||||||
|
# Login with OIDC
|
||||||
|
# CONNECT_WITH_OIDC_ICON=fa-github
|
||||||
|
# OIDC_AUTHORIZATION_URL=to_fill
|
||||||
|
# OIDC_USER_INFO_URL=to_fill
|
||||||
|
# OIDC_TOKEN_URL=to_fill
|
||||||
|
# OIDC_SCOPES=openid email profile
|
||||||
|
# OIDC_NAME_FIELD=name
|
||||||
|
# OIDC_CLIENT_ID=to_fill
|
||||||
|
# OIDC_CLIENT_SECRET=to_fill
|
||||||
|
|
||||||
# Flask profiler
|
# Flask profiler
|
||||||
# FLASK_PROFILER_PATH=/tmp/flask-profiler.sql
|
# FLASK_PROFILER_PATH=/tmp/flask-profiler.sql
|
||||||
# FLASK_PROFILER_PASSWORD=password
|
# FLASK_PROFILER_PASSWORD=password
|
||||||
|
|
|
@ -38,11 +38,21 @@
|
||||||
<span>or</span>
|
<span>or</span>
|
||||||
</div>
|
</div>
|
||||||
<a class="btn btn-primary btn-block mt-2 proton-button"
|
<a class="btn btn-primary btn-block mt-2 proton-button"
|
||||||
href="{{ url_for("auth.proton_login", next=next_url) }}">
|
href="{{ url_for('auth.proton_login', next=next_url) }}">
|
||||||
<img class="mr-2" src="/static/images/proton.svg" />
|
<img class="mr-2" src="/static/images/proton.svg" />
|
||||||
Log in with Proton
|
Log in with Proton
|
||||||
</a>
|
</a>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{% if connect_with_oidc %}
|
||||||
|
|
||||||
|
<div class="text-center my-2 text-gray">
|
||||||
|
<span>or</span>
|
||||||
|
</div>
|
||||||
|
<a class="btn btn-primary btn-block mt-2 btn-social"
|
||||||
|
href="{{ url_for('auth.oidc_login', next=next_url) }}">
|
||||||
|
<i class="fa {{ connect_with_oidc_icon }}"></i> Log in with SSO
|
||||||
|
</a>
|
||||||
|
{% endif %}
|
||||||
</div>
|
</div>
|
||||||
</div>
|
</div>
|
||||||
<div class="text-center text-muted mt-2">
|
<div class="text-center text-muted mt-2">
|
||||||
|
|
|
@ -50,11 +50,21 @@
|
||||||
<span>or</span>
|
<span>or</span>
|
||||||
</div>
|
</div>
|
||||||
<a class="btn btn-primary btn-block mt-2 proton-button"
|
<a class="btn btn-primary btn-block mt-2 proton-button"
|
||||||
href="{{ url_for("auth.proton_login", next=next_url) }}">
|
href="{{ url_for('auth.proton_login', next=next_url) }}">
|
||||||
<img class="mr-2" src="/static/images/proton.svg" />
|
<img class="mr-2" src="/static/images/proton.svg" />
|
||||||
Sign up with Proton
|
Sign up with Proton
|
||||||
</a>
|
</a>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{% if connect_with_oidc %}
|
||||||
|
|
||||||
|
<div class="text-center my-2 text-gray">
|
||||||
|
<span>or</span>
|
||||||
|
</div>
|
||||||
|
<a class="btn btn-primary btn-block mt-2 btn-social"
|
||||||
|
href="{{ url_for('auth.oidc_login', next=next_url) }}">
|
||||||
|
<i class="fa {{ connect_with_oidc_icon }}"></i> Sign up with SSO
|
||||||
|
</a>
|
||||||
|
{% endif %}
|
||||||
</div>
|
</div>
|
||||||
</form>
|
</form>
|
||||||
<div class="text-center text-muted mb-6">
|
<div class="text-center text-muted mb-6">
|
||||||
|
|
|
@ -22,11 +22,20 @@
|
||||||
<p>Alternatively you can use your Proton credentials to ensure it's you.</p>
|
<p>Alternatively you can use your Proton credentials to ensure it's you.</p>
|
||||||
</div>
|
</div>
|
||||||
<a class="btn btn-primary btn-block mt-2 proton-button w-25"
|
<a class="btn btn-primary btn-block mt-2 proton-button w-25"
|
||||||
href="{{ url_for("auth.proton_login", next=next) }}">
|
href="{{ url_for('auth.proton_login', next=next) }}">
|
||||||
<img class="mr-2" src="/static/images/proton.svg" />
|
<img class="mr-2" src="/static/images/proton.svg" />
|
||||||
Authenticate with Proton
|
Authenticate with Proton
|
||||||
</a>
|
</a>
|
||||||
{% endif %}
|
{% endif %}
|
||||||
|
{% if connect_with_oidc %}
|
||||||
|
|
||||||
|
<div class="my-3">
|
||||||
|
<p>Alternatively you can use your SSO credentials to ensure it's you.</p>
|
||||||
|
<a class="btn btn-primary btn-block mt-2 btn-social w-25"
|
||||||
|
href="{{ url_for('auth.oidc_login', next=next) }}">
|
||||||
|
<i class="fa {{ connect_with_oidc_icon }}"></i> Authenticate with SSO
|
||||||
|
</a>
|
||||||
|
{% endif %}
|
||||||
|
</div>
|
||||||
</div>
|
</div>
|
||||||
</div>
|
{% endblock %}
|
||||||
{% endblock %}
|
|
||||||
|
|
|
@ -0,0 +1,304 @@
|
||||||
|
from app import config
|
||||||
|
from flask import url_for
|
||||||
|
from urllib.parse import parse_qs
|
||||||
|
from urllib3.util import parse_url
|
||||||
|
from app.auth.views.oidc import create_user
|
||||||
|
from app.utils import random_string
|
||||||
|
from unittest.mock import patch
|
||||||
|
from app.models import User
|
||||||
|
|
||||||
|
from app.config import URL, OIDC_CLIENT_ID
|
||||||
|
|
||||||
|
|
||||||
|
def test_oidc_login(flask_client):
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_login"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
query = parse_qs(parsed.query)
|
||||||
|
|
||||||
|
expected_redirect_url = f"{URL}/auth/oidc/callback"
|
||||||
|
|
||||||
|
assert "code" == query["response_type"][0]
|
||||||
|
assert OIDC_CLIENT_ID == query["client_id"][0]
|
||||||
|
assert expected_redirect_url == query["redirect_uri"][0]
|
||||||
|
|
||||||
|
|
||||||
|
def test_oidc_login_no_client_id(flask_client):
|
||||||
|
config.OIDC_CLIENT_ID = None
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_login"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/auth/login"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
|
||||||
|
config.OIDC_CLIENT_ID = "to_fill"
|
||||||
|
|
||||||
|
|
||||||
|
def test_oidc_login_no_client_secret(flask_client):
|
||||||
|
config.OIDC_CLIENT_SECRET = None
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_login"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/auth/login"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
|
||||||
|
config.OIDC_CLIENT_SECRET = "to_fill"
|
||||||
|
|
||||||
|
|
||||||
|
def test_oidc_callback_no_oauth_state(flask_client):
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_oidc_callback_no_client_id(flask_client):
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = "state"
|
||||||
|
config.OIDC_CLIENT_ID = None
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/auth/login"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
|
||||||
|
config.OIDC_CLIENT_ID = "to_fill"
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
|
||||||
|
def test_oidc_callback_no_client_secret(flask_client):
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = "state"
|
||||||
|
config.OIDC_CLIENT_SECRET = None
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/auth/login"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
|
||||||
|
config.OIDC_CLIENT_SECRET = "to_fill"
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.fetch_token")
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.get")
|
||||||
|
def test_oidc_callback_invalid_user(mock_get, mock_fetch_token, flask_client):
|
||||||
|
mock_get.return_value = MockResponse(400, {})
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = "state"
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/auth/login"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
assert mock_get.called
|
||||||
|
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.fetch_token")
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.get")
|
||||||
|
def test_oidc_callback_no_email(mock_get, mock_fetch_token, flask_client):
|
||||||
|
mock_get.return_value = MockResponse(200, {})
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = "state"
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/auth/login"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
assert mock_get.called
|
||||||
|
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.fetch_token")
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.get")
|
||||||
|
def test_oidc_callback_disabled_registration(mock_get, mock_fetch_token, flask_client):
|
||||||
|
config.DISABLE_REGISTRATION = True
|
||||||
|
email = random_string()
|
||||||
|
mock_get.return_value = MockResponse(200, {"email": email})
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = "state"
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/auth/register"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
assert mock_get.called
|
||||||
|
|
||||||
|
config.DISABLE_REGISTRATION = False
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.fetch_token")
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.get")
|
||||||
|
def test_oidc_callback_registration(mock_get, mock_fetch_token, flask_client):
|
||||||
|
email = random_string()
|
||||||
|
mock_get.return_value = MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"email": email,
|
||||||
|
config.OIDC_NAME_FIELD: "name",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = "state"
|
||||||
|
|
||||||
|
user = User.get_by(email=email)
|
||||||
|
assert user is None
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/dashboard/"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
assert mock_get.called
|
||||||
|
|
||||||
|
user = User.get_by(email=email)
|
||||||
|
assert user is not None
|
||||||
|
assert user.email == email
|
||||||
|
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.fetch_token")
|
||||||
|
@patch("requests_oauthlib.OAuth2Session.get")
|
||||||
|
def test_oidc_callback_login(mock_get, mock_fetch_token, flask_client):
|
||||||
|
email = random_string()
|
||||||
|
mock_get.return_value = MockResponse(
|
||||||
|
200,
|
||||||
|
{
|
||||||
|
"email": email,
|
||||||
|
},
|
||||||
|
)
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = "state"
|
||||||
|
|
||||||
|
user = User.create(
|
||||||
|
email=email,
|
||||||
|
name="name",
|
||||||
|
password="",
|
||||||
|
activated=True,
|
||||||
|
)
|
||||||
|
user = User.get_by(email=email)
|
||||||
|
assert user is not None
|
||||||
|
|
||||||
|
r = flask_client.get(
|
||||||
|
url_for("auth.oidc_callback"),
|
||||||
|
follow_redirects=False,
|
||||||
|
)
|
||||||
|
location = r.headers.get("Location")
|
||||||
|
assert location is not None
|
||||||
|
|
||||||
|
parsed = parse_url(location)
|
||||||
|
|
||||||
|
expected_redirect_url = "/dashboard/"
|
||||||
|
|
||||||
|
assert expected_redirect_url == parsed.path
|
||||||
|
assert mock_get.called
|
||||||
|
|
||||||
|
with flask_client.session_transaction() as session:
|
||||||
|
session["oauth_state"] = None
|
||||||
|
|
||||||
|
|
||||||
|
def test_create_user():
|
||||||
|
email = random_string()
|
||||||
|
user = create_user(
|
||||||
|
email,
|
||||||
|
{
|
||||||
|
config.OIDC_NAME_FIELD: "name",
|
||||||
|
},
|
||||||
|
)
|
||||||
|
assert user.email == email
|
||||||
|
assert user.name == "name"
|
||||||
|
assert user.activated
|
||||||
|
|
||||||
|
|
||||||
|
class MockResponse:
|
||||||
|
def __init__(self, status_code, json_data):
|
||||||
|
self.status_code = status_code
|
||||||
|
self.json_data = json_data
|
||||||
|
self.text = "error"
|
||||||
|
|
||||||
|
def json(self):
|
||||||
|
return self.json_data
|
|
@ -49,6 +49,16 @@ GOOGLE_CLIENT_SECRET=to_fill
|
||||||
FACEBOOK_CLIENT_ID=to_fill
|
FACEBOOK_CLIENT_ID=to_fill
|
||||||
FACEBOOK_CLIENT_SECRET=to_fill
|
FACEBOOK_CLIENT_SECRET=to_fill
|
||||||
|
|
||||||
|
# Login with OIDC
|
||||||
|
CONNECT_WITH_OIDC_ICON=fa-github
|
||||||
|
OIDC_AUTHORIZATION_URL=to_fill
|
||||||
|
OIDC_USER_INFO_URL=to_fill
|
||||||
|
OIDC_TOKEN_URL=to_fill
|
||||||
|
OIDC_SCOPES=openid email profile
|
||||||
|
OIDC_NAME_FIELD=name
|
||||||
|
OIDC_CLIENT_ID=to_fill
|
||||||
|
OIDC_CLIENT_SECRET=to_fill
|
||||||
|
|
||||||
PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc
|
PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc
|
||||||
|
|
||||||
ALIAS_AUTOMATIC_DISABLE=true
|
ALIAS_AUTOMATIC_DISABLE=true
|
||||||
|
|
Loading…
Reference in New Issue