diff --git a/app/auth/views/proton.py b/app/auth/views/proton.py index 45584027..e8983797 100644 --- a/app/auth/views/proton.py +++ b/app/auth/views/proton.py @@ -15,7 +15,7 @@ from app.config import ( ) from app.proton.proton_client import HttpProtonClient, convert_access_token from app.proton.proton_callback_handler import ProtonCallbackHandler, Action -from app.utils import encode_url, sanitize_next_url +from app.utils import sanitize_next_url _authorization_base_url = PROTON_BASE_URL + "/oauth/authorize" _token_url = PROTON_BASE_URL + "/oauth/token" @@ -51,11 +51,8 @@ def proton_login(): next_url = sanitize_next_url(request.args.get("next")) if next_url: - redirect_uri = _redirect_uri + "?next=" + encode_url(next_url) - else: - redirect_uri = _redirect_uri - - proton = OAuth2Session(PROTON_CLIENT_ID, redirect_uri=redirect_uri) + session["oauth_next"] = next_url + proton = OAuth2Session(PROTON_CLIENT_ID, redirect_uri=_redirect_uri) authorization_url, state = proton.authorization_url(_authorization_base_url) # State is used to prevent CSRF, keep this for later. @@ -120,5 +117,5 @@ def proton_callback(): if res.redirect: return redirect(res.redirect) - next_url = request.args.get("next") if request.args else None + next_url = session.get("oauth_next") return after_login(res.user, next_url) diff --git a/templates/auth/login.html b/templates/auth/login.html index 0d70e92d..b1f03811 100644 --- a/templates/auth/login.html +++ b/templates/auth/login.html @@ -45,7 +45,7 @@ {% if connect_with_proton %}
or
- Log in with Proton + Log in with Proton {% endif %}