diff --git a/app/api/views/setting.py b/app/api/views/setting.py index 58e314d6..7f04a7b5 100644 --- a/app/api/views/setting.py +++ b/app/api/views/setting.py @@ -12,6 +12,7 @@ from app.models import ( SenderFormatEnum, AliasSuffixEnum, ) +from app.proton.utils import perform_proton_account_unlink def setting_to_dict(user: User): @@ -137,3 +138,11 @@ def get_available_domains_for_random_alias_v2(): ] return jsonify(ret) + + +@api_bp.route("/setting/unlink_proton_account", methods=["DELETE"]) +@require_api_auth +def unlink_proton_account(): + user = g.user + perform_proton_account_unlink(user) + return jsonify({"ok": True}) diff --git a/app/api/views/user_info.py b/app/api/views/user_info.py index 3a8ce34c..e97600bc 100644 --- a/app/api/views/user_info.py +++ b/app/api/views/user_info.py @@ -3,13 +3,23 @@ from io import BytesIO from flask import jsonify, g, request, make_response from flask_login import logout_user +from typing import Optional from app import s3 from app.api.base import api_bp, require_api_auth from app.config import SESSION_COOKIE_NAME from app.db import Session -from app.models import ApiKey, File, User +from app.models import ApiKey, File, PartnerUser, User from app.utils import random_string +from app.proton.utils import get_proton_partner + + +def get_connected_proton_address(user: User) -> Optional[str]: + proton_partner = get_proton_partner() + partner_user = PartnerUser.get_by(user_id=user.id, partner_id=proton_partner.id) + if partner_user is None: + return None + return partner_user.partner_email def user_to_dict(user: User) -> dict: @@ -19,6 +29,7 @@ def user_to_dict(user: User) -> dict: "email": user.email, "in_trial": user.in_trial(), "max_alias_free_plan": user.max_alias_for_free_account(), + "connected_proton_address": get_connected_proton_address(user), } if user.profile_picture_id: @@ -41,6 +52,7 @@ def user_info(): - email - in_trial - max_alias_free + - is_connected_with_proton """ user = g.user diff --git a/app/auth/views/api_to_cookie.py b/app/auth/views/api_to_cookie.py index 70164132..abd26170 100644 --- a/app/auth/views/api_to_cookie.py +++ b/app/auth/views/api_to_cookie.py @@ -1,18 +1,14 @@ import arrow from flask import redirect, url_for, request, flash -from flask_login import current_user, login_user +from flask_login import login_user from app.auth.base import auth_bp -from app.log import LOG from app.models import ApiToCookieToken from app.utils import sanitize_next_url @auth_bp.route("/api_to_cookie", methods=["GET"]) def api_to_cookie(): - if current_user.is_authenticated: - LOG.d("user is already authenticated, redirect to dashboard") - return redirect(url_for("dashboard.index")) code = request.args.get("token") if not code: flash("Missing token", "error") @@ -23,8 +19,9 @@ def api_to_cookie(): flash("Missing token", "error") return redirect(url_for("auth.login")) - login_user(token.user) + user = token.user ApiToCookieToken.delete(token.id, commit=True) + login_user(user) next_url = sanitize_next_url(request.args.get("next")) if next_url: diff --git a/app/auth/views/proton.py b/app/auth/views/proton.py index 52fa148a..cedcd5eb 100644 --- a/app/auth/views/proton.py +++ b/app/auth/views/proton.py @@ -51,6 +51,8 @@ def extract_action() -> Action: if action is not None: if action == "link": return Action.Link + elif action == "login": + return Action.Login else: raise Exception(f"Unknown action: {action}") return Action.Login @@ -70,6 +72,10 @@ def proton_login(): if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: return redirect(url_for("auth.login")) + action = extract_action() + if action == Action.Link and not current_user.is_authenticated: + return redirect(url_for("auth.login")) + next_url = sanitize_next_url(request.args.get("next")) if next_url: session["oauth_next"] = next_url @@ -93,7 +99,7 @@ def proton_login(): # State is used to prevent CSRF, keep this for later. session[SESSION_STATE_KEY] = state - session[SESSION_ACTION_KEY] = extract_action().value + session[SESSION_ACTION_KEY] = action.value return redirect(authorization_url) @@ -168,7 +174,7 @@ def proton_callback(): if session.get("oauth_mode", "session") == "apikey": apikey = get_api_key_for_user(res.user) scheme = oauth_scheme or DEFAULT_SCHEME - return redirect(f"{scheme}:///login_callback?apikey={apikey}") + return redirect(f"{scheme}:///login?apikey={apikey}") if res.redirect_to_login: return redirect(url_for("auth.login")) diff --git a/app/dashboard/views/setting.py b/app/dashboard/views/setting.py index b26c15d5..d08375b0 100644 --- a/app/dashboard/views/setting.py +++ b/app/dashboard/views/setting.py @@ -12,7 +12,6 @@ from flask import ( from flask_login import login_required, current_user from flask_wtf import FlaskForm from flask_wtf.file import FileField -from newrelic import agent from wtforms import StringField, validators from wtforms.fields.html5 import EmailField @@ -53,7 +52,7 @@ from app.models import ( PartnerSubscription, UnsubscribeBehaviourEnum, ) -from app.proton.utils import get_proton_partner +from app.proton.utils import get_proton_partner, perform_proton_account_unlink from app.utils import random_string, sanitize_email @@ -481,13 +480,6 @@ def cancel_email_change(): @dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"]) @login_required def unlink_proton_account(): - proton_partner = get_proton_partner() - partner_user = PartnerUser.get_by( - user_id=current_user.id, partner_id=proton_partner.id - ) - if partner_user is not None: - PartnerUser.delete(partner_user.id) - Session.commit() + perform_proton_account_unlink(current_user) flash("Your Proton account has been unlinked", "success") - agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name}) return redirect(url_for("dashboard.setting")) diff --git a/app/proton/utils.py b/app/proton/utils.py index 7b02e832..ed18ba4e 100644 --- a/app/proton/utils.py +++ b/app/proton/utils.py @@ -1,8 +1,9 @@ +from newrelic import agent from typing import Optional from app.db import Session from app.errors import ProtonPartnerNotSetUp -from app.models import Partner +from app.models import Partner, PartnerUser, User PROTON_PARTNER_NAME = "Proton" _PROTON_PARTNER: Optional[Partner] = None @@ -21,3 +22,14 @@ def get_proton_partner() -> Partner: def is_proton_partner(partner: Partner) -> bool: return partner.name == PROTON_PARTNER_NAME + + +def perform_proton_account_unlink(current_user: User): + proton_partner = get_proton_partner() + partner_user = PartnerUser.get_by( + user_id=current_user.id, partner_id=proton_partner.id + ) + if partner_user is not None: + PartnerUser.delete(partner_user.id) + Session.commit() + agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name}) diff --git a/tests/api/test_user_info.py b/tests/api/test_user_info.py index f7728af1..8a06a169 100644 --- a/tests/api/test_user_info.py +++ b/tests/api/test_user_info.py @@ -1,9 +1,10 @@ from flask import url_for from app import config -from app.models import User +from app.models import User, PartnerUser +from app.proton.utils import get_proton_partner from tests.api.utils import get_new_user_and_api_key -from tests.utils import login +from tests.utils import login, random_token, random_email def test_user_in_trial(flask_client): @@ -21,6 +22,35 @@ def test_user_in_trial(flask_client): "in_trial": True, "profile_picture_url": None, "max_alias_free_plan": config.MAX_NB_EMAIL_FREE_PLAN, + "connected_proton_address": None, + } + + +def test_user_linked_to_proton(flask_client): + user, api_key = get_new_user_and_api_key() + partner = get_proton_partner() + partner_email = random_email() + PartnerUser.create( + user_id=user.id, + partner_id=partner.id, + external_user_id=random_token(), + partner_email=partner_email, + commit=True, + ) + + r = flask_client.get( + url_for("api.user_info"), headers={"Authentication": api_key.code} + ) + + assert r.status_code == 200 + assert r.json == { + "is_premium": True, + "name": "Test User", + "email": user.email, + "in_trial": True, + "profile_picture_url": None, + "max_alias_free_plan": config.MAX_NB_EMAIL_FREE_PLAN, + "connected_proton_address": partner_email, }