Fixes for connect with proton on mobile (#1230)

* Fixes for connect with proton on mobile

* Added a test

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Carlos Quintana 2022-08-12 13:17:21 +02:00 committed by GitHub
parent 7476bdde4b
commit 7eb44a5947
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
7 changed files with 80 additions and 22 deletions

View File

@ -12,6 +12,7 @@ from app.models import (
SenderFormatEnum, SenderFormatEnum,
AliasSuffixEnum, AliasSuffixEnum,
) )
from app.proton.utils import perform_proton_account_unlink
def setting_to_dict(user: User): def setting_to_dict(user: User):
@ -137,3 +138,11 @@ def get_available_domains_for_random_alias_v2():
] ]
return jsonify(ret) return jsonify(ret)
@api_bp.route("/setting/unlink_proton_account", methods=["DELETE"])
@require_api_auth
def unlink_proton_account():
user = g.user
perform_proton_account_unlink(user)
return jsonify({"ok": True})

View File

@ -3,13 +3,23 @@ from io import BytesIO
from flask import jsonify, g, request, make_response from flask import jsonify, g, request, make_response
from flask_login import logout_user from flask_login import logout_user
from typing import Optional
from app import s3 from app import s3
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import SESSION_COOKIE_NAME from app.config import SESSION_COOKIE_NAME
from app.db import Session from app.db import Session
from app.models import ApiKey, File, User from app.models import ApiKey, File, PartnerUser, User
from app.utils import random_string from app.utils import random_string
from app.proton.utils import get_proton_partner
def get_connected_proton_address(user: User) -> Optional[str]:
proton_partner = get_proton_partner()
partner_user = PartnerUser.get_by(user_id=user.id, partner_id=proton_partner.id)
if partner_user is None:
return None
return partner_user.partner_email
def user_to_dict(user: User) -> dict: def user_to_dict(user: User) -> dict:
@ -19,6 +29,7 @@ def user_to_dict(user: User) -> dict:
"email": user.email, "email": user.email,
"in_trial": user.in_trial(), "in_trial": user.in_trial(),
"max_alias_free_plan": user.max_alias_for_free_account(), "max_alias_free_plan": user.max_alias_for_free_account(),
"connected_proton_address": get_connected_proton_address(user),
} }
if user.profile_picture_id: if user.profile_picture_id:
@ -41,6 +52,7 @@ def user_info():
- email - email
- in_trial - in_trial
- max_alias_free - max_alias_free
- is_connected_with_proton
""" """
user = g.user user = g.user

View File

@ -1,18 +1,14 @@
import arrow import arrow
from flask import redirect, url_for, request, flash from flask import redirect, url_for, request, flash
from flask_login import current_user, login_user from flask_login import login_user
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.log import LOG
from app.models import ApiToCookieToken from app.models import ApiToCookieToken
from app.utils import sanitize_next_url from app.utils import sanitize_next_url
@auth_bp.route("/api_to_cookie", methods=["GET"]) @auth_bp.route("/api_to_cookie", methods=["GET"])
def api_to_cookie(): def api_to_cookie():
if current_user.is_authenticated:
LOG.d("user is already authenticated, redirect to dashboard")
return redirect(url_for("dashboard.index"))
code = request.args.get("token") code = request.args.get("token")
if not code: if not code:
flash("Missing token", "error") flash("Missing token", "error")
@ -23,8 +19,9 @@ def api_to_cookie():
flash("Missing token", "error") flash("Missing token", "error")
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
login_user(token.user) user = token.user
ApiToCookieToken.delete(token.id, commit=True) ApiToCookieToken.delete(token.id, commit=True)
login_user(user)
next_url = sanitize_next_url(request.args.get("next")) next_url = sanitize_next_url(request.args.get("next"))
if next_url: if next_url:

View File

@ -51,6 +51,8 @@ def extract_action() -> Action:
if action is not None: if action is not None:
if action == "link": if action == "link":
return Action.Link return Action.Link
elif action == "login":
return Action.Login
else: else:
raise Exception(f"Unknown action: {action}") raise Exception(f"Unknown action: {action}")
return Action.Login return Action.Login
@ -70,6 +72,10 @@ def proton_login():
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
action = extract_action()
if action == Action.Link and not current_user.is_authenticated:
return redirect(url_for("auth.login"))
next_url = sanitize_next_url(request.args.get("next")) next_url = sanitize_next_url(request.args.get("next"))
if next_url: if next_url:
session["oauth_next"] = next_url session["oauth_next"] = next_url
@ -93,7 +99,7 @@ def proton_login():
# State is used to prevent CSRF, keep this for later. # State is used to prevent CSRF, keep this for later.
session[SESSION_STATE_KEY] = state session[SESSION_STATE_KEY] = state
session[SESSION_ACTION_KEY] = extract_action().value session[SESSION_ACTION_KEY] = action.value
return redirect(authorization_url) return redirect(authorization_url)
@ -168,7 +174,7 @@ def proton_callback():
if session.get("oauth_mode", "session") == "apikey": if session.get("oauth_mode", "session") == "apikey":
apikey = get_api_key_for_user(res.user) apikey = get_api_key_for_user(res.user)
scheme = oauth_scheme or DEFAULT_SCHEME scheme = oauth_scheme or DEFAULT_SCHEME
return redirect(f"{scheme}:///login_callback?apikey={apikey}") return redirect(f"{scheme}:///login?apikey={apikey}")
if res.redirect_to_login: if res.redirect_to_login:
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))

View File

@ -12,7 +12,6 @@ from flask import (
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from flask_wtf.file import FileField from flask_wtf.file import FileField
from newrelic import agent
from wtforms import StringField, validators from wtforms import StringField, validators
from wtforms.fields.html5 import EmailField from wtforms.fields.html5 import EmailField
@ -53,7 +52,7 @@ from app.models import (
PartnerSubscription, PartnerSubscription,
UnsubscribeBehaviourEnum, UnsubscribeBehaviourEnum,
) )
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner, perform_proton_account_unlink
from app.utils import random_string, sanitize_email from app.utils import random_string, sanitize_email
@ -481,13 +480,6 @@ def cancel_email_change():
@dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"]) @dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"])
@login_required @login_required
def unlink_proton_account(): def unlink_proton_account():
proton_partner = get_proton_partner() perform_proton_account_unlink(current_user)
partner_user = PartnerUser.get_by(
user_id=current_user.id, partner_id=proton_partner.id
)
if partner_user is not None:
PartnerUser.delete(partner_user.id)
Session.commit()
flash("Your Proton account has been unlinked", "success") flash("Your Proton account has been unlinked", "success")
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))

View File

@ -1,8 +1,9 @@
from newrelic import agent
from typing import Optional from typing import Optional
from app.db import Session from app.db import Session
from app.errors import ProtonPartnerNotSetUp from app.errors import ProtonPartnerNotSetUp
from app.models import Partner from app.models import Partner, PartnerUser, User
PROTON_PARTNER_NAME = "Proton" PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None _PROTON_PARTNER: Optional[Partner] = None
@ -21,3 +22,14 @@ def get_proton_partner() -> Partner:
def is_proton_partner(partner: Partner) -> bool: def is_proton_partner(partner: Partner) -> bool:
return partner.name == PROTON_PARTNER_NAME return partner.name == PROTON_PARTNER_NAME
def perform_proton_account_unlink(current_user: User):
proton_partner = get_proton_partner()
partner_user = PartnerUser.get_by(
user_id=current_user.id, partner_id=proton_partner.id
)
if partner_user is not None:
PartnerUser.delete(partner_user.id)
Session.commit()
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})

View File

@ -1,9 +1,10 @@
from flask import url_for from flask import url_for
from app import config from app import config
from app.models import User from app.models import User, PartnerUser
from app.proton.utils import get_proton_partner
from tests.api.utils import get_new_user_and_api_key from tests.api.utils import get_new_user_and_api_key
from tests.utils import login from tests.utils import login, random_token, random_email
def test_user_in_trial(flask_client): def test_user_in_trial(flask_client):
@ -21,6 +22,35 @@ def test_user_in_trial(flask_client):
"in_trial": True, "in_trial": True,
"profile_picture_url": None, "profile_picture_url": None,
"max_alias_free_plan": config.MAX_NB_EMAIL_FREE_PLAN, "max_alias_free_plan": config.MAX_NB_EMAIL_FREE_PLAN,
"connected_proton_address": None,
}
def test_user_linked_to_proton(flask_client):
user, api_key = get_new_user_and_api_key()
partner = get_proton_partner()
partner_email = random_email()
PartnerUser.create(
user_id=user.id,
partner_id=partner.id,
external_user_id=random_token(),
partner_email=partner_email,
commit=True,
)
r = flask_client.get(
url_for("api.user_info"), headers={"Authentication": api_key.code}
)
assert r.status_code == 200
assert r.json == {
"is_premium": True,
"name": "Test User",
"email": user.email,
"in_trial": True,
"profile_picture_url": None,
"max_alias_free_plan": config.MAX_NB_EMAIL_FREE_PLAN,
"connected_proton_address": partner_email,
} }