Support next with Proton Link (#1226)
* Support next with Proton Link * Add support for double next * Fix bug on account relink
This commit is contained in:
parent
3a75686898
commit
596dd0b1ee
|
@ -276,7 +276,7 @@ def process_link_case(
|
|||
return link_user(link_request, current_user, partner)
|
||||
|
||||
# There is a SL user registered with the partner. Check if is the current one
|
||||
if partner_user.id == current_user.id:
|
||||
if partner_user.user_id == current_user.id:
|
||||
# Update plan
|
||||
set_plan_for_partner_user(partner_user, link_request.plan)
|
||||
# It's the same user. No need to do anything
|
||||
|
@ -285,5 +285,4 @@ def process_link_case(
|
|||
strategy="Link",
|
||||
)
|
||||
else:
|
||||
|
||||
return switch_already_linked_user(link_request, partner_user, current_user)
|
||||
|
|
|
@ -23,7 +23,7 @@ from app.proton.proton_callback_handler import (
|
|||
Action,
|
||||
)
|
||||
from app.proton.utils import get_proton_partner
|
||||
from app.utils import sanitize_next_url
|
||||
from app.utils import sanitize_next_url, sanitize_scheme
|
||||
|
||||
_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
|
||||
_token_url = PROTON_BASE_URL + "/oauth/token"
|
||||
|
@ -34,6 +34,7 @@ _redirect_uri = URL + "/auth/proton/callback"
|
|||
|
||||
SESSION_ACTION_KEY = "oauth_action"
|
||||
SESSION_STATE_KEY = "oauth_state"
|
||||
DEFAULT_SCHEME = "auth.simplelogin"
|
||||
|
||||
|
||||
def get_api_key_for_user(user: User) -> str:
|
||||
|
@ -75,6 +76,12 @@ def proton_login():
|
|||
elif "oauth_next" in session:
|
||||
del session["oauth_next"]
|
||||
|
||||
scheme = sanitize_scheme(request.args.get("scheme"))
|
||||
if scheme:
|
||||
session["oauth_scheme"] = scheme
|
||||
elif "oauth_scheme" in session:
|
||||
del session["oauth_scheme"]
|
||||
|
||||
mode = request.args.get("mode", "session")
|
||||
if mode == "apikey":
|
||||
session["oauth_mode"] = "apikey"
|
||||
|
@ -146,6 +153,7 @@ def proton_callback():
|
|||
handler = ProtonCallbackHandler(proton_client)
|
||||
proton_partner = get_proton_partner()
|
||||
|
||||
next_url = session.get("oauth_next")
|
||||
if action == Action.Login:
|
||||
res = handler.handle_login(proton_partner)
|
||||
elif action == Action.Link:
|
||||
|
@ -156,15 +164,17 @@ def proton_callback():
|
|||
if res.flash_message is not None:
|
||||
flash(res.flash_message, res.flash_category)
|
||||
|
||||
oauth_scheme = session.get("oauth_scheme")
|
||||
if session.get("oauth_mode", "session") == "apikey":
|
||||
apikey = get_api_key_for_user(res.user)
|
||||
return redirect(f"auth.simplelogin://callback?apikey={apikey}")
|
||||
scheme = oauth_scheme or DEFAULT_SCHEME
|
||||
return redirect(f"{scheme}:///login_callback?apikey={apikey}")
|
||||
|
||||
if res.redirect_to_login:
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
if res.redirect:
|
||||
return after_login(res.user, res.redirect, login_from_proton=True)
|
||||
if next_url and next_url[0] == "/" and oauth_scheme:
|
||||
next_url = f"{oauth_scheme}://{next_url}"
|
||||
|
||||
next_url = session.get("oauth_next")
|
||||
return after_login(res.user, next_url, login_from_proton=True)
|
||||
redirect_url = next_url or res.redirect
|
||||
return after_login(res.user, redirect_url, login_from_proton=True)
|
||||
|
|
|
@ -64,7 +64,9 @@ class ProtonCallbackHandler:
|
|||
)
|
||||
|
||||
def handle_link(
|
||||
self, current_user: Optional[User], partner: Partner
|
||||
self,
|
||||
current_user: Optional[User],
|
||||
partner: Partner,
|
||||
) -> ProtonCallbackResult:
|
||||
if current_user is None:
|
||||
raise Exception("Cannot link account with current_user being None")
|
||||
|
|
14
app/utils.py
14
app/utils.py
|
@ -1,3 +1,4 @@
|
|||
import re
|
||||
import secrets
|
||||
import string
|
||||
import time
|
||||
|
@ -88,6 +89,8 @@ class NextUrlSanitizer:
|
|||
else:
|
||||
return None
|
||||
if result.path and result.path[0] == "/" and not result.path.startswith("//"):
|
||||
if result.query:
|
||||
return f"{result.path}?{result.query}"
|
||||
return result.path
|
||||
|
||||
return None
|
||||
|
@ -97,6 +100,17 @@ def sanitize_next_url(url: Optional[str]) -> Optional[str]:
|
|||
return NextUrlSanitizer.sanitize(url, ALLOWED_REDIRECT_DOMAINS)
|
||||
|
||||
|
||||
def sanitize_scheme(scheme: Optional[str]) -> Optional[str]:
|
||||
if not scheme:
|
||||
return None
|
||||
if scheme in ["http", "https"]:
|
||||
return None
|
||||
scheme_regex = re.compile("^[a-z.]+$")
|
||||
if scheme_regex.match(scheme):
|
||||
return scheme
|
||||
return None
|
||||
|
||||
|
||||
def query2str(query):
|
||||
"""Useful utility method to print out a SQLAlchemy query"""
|
||||
return query.statement.compile(compile_kwargs={"literal_binds": True})
|
||||
|
|
Loading…
Reference in New Issue