From 8d4683e59ec5955783d77705ba624256e805cd16 Mon Sep 17 00:00:00 2001 From: Carlos Quintana Date: Mon, 14 Mar 2022 09:33:31 +0100 Subject: [PATCH] Add login with proton --- app/auth/__init__.py | 1 + app/auth/views/login.py | 2 + app/auth/views/proton.py | 114 +++++++ app/auth/views/register.py | 2 + app/config.py | 11 +- app/dashboard/views/setting.py | 37 +++ app/models.py | 48 +++ app/proton/__init__.py | 0 app/proton/proton_callback_handler.py | 237 +++++++++++++++ app/proton/proton_client.py | 175 +++++++++++ example.env | 7 + init_app.py | 13 +- ..._050512_e866ad0e78e1_add_partner_tables.py | 76 +++++ static/style.css | 8 +- templates/auth/login.html | 53 ++-- templates/auth/register.html | 5 + templates/dashboard/setting.html | 28 +- tests/auth/test_proton.py | 23 ++ tests/conftest.py | 3 +- tests/proton/__init__.py | 0 tests/proton/test_proton_callback_handler.py | 285 ++++++++++++++++++ tests/proton/test_proton_client.py | 21 ++ tests/test.env | 7 +- tests/test_utils.py | 17 ++ 24 files changed, 1142 insertions(+), 31 deletions(-) create mode 100644 app/auth/views/proton.py create mode 100644 app/proton/__init__.py create mode 100644 app/proton/proton_callback_handler.py create mode 100644 app/proton/proton_client.py create mode 100644 migrations/versions/2022_050512_e866ad0e78e1_add_partner_tables.py create mode 100644 tests/auth/test_proton.py create mode 100644 tests/proton/__init__.py create mode 100644 tests/proton/test_proton_callback_handler.py create mode 100644 tests/proton/test_proton_client.py diff --git a/app/auth/__init__.py b/app/auth/__init__.py index c41bb56a..e91fd8cb 100644 --- a/app/auth/__init__.py +++ b/app/auth/__init__.py @@ -9,6 +9,7 @@ from .views import ( github, google, facebook, + proton, change_email, mfa, fido, diff --git a/app/auth/views/login.py b/app/auth/views/login.py index a08ca774..c74bf76a 100644 --- a/app/auth/views/login.py +++ b/app/auth/views/login.py @@ -5,6 +5,7 @@ from wtforms import StringField, validators from app.auth.base import auth_bp from app.auth.views.login_utils import after_login +from app.config import CONNECT_WITH_PROTON from app.events.auth_event import LoginEvent from app.extensions import limiter from app.log import LOG @@ -67,4 +68,5 @@ def login(): form=form, next_url=next_url, show_resend_activation=show_resend_activation, + connect_with_proton=CONNECT_WITH_PROTON, ) diff --git a/app/auth/views/proton.py b/app/auth/views/proton.py new file mode 100644 index 00000000..4e40f354 --- /dev/null +++ b/app/auth/views/proton.py @@ -0,0 +1,114 @@ +from flask import request, session, redirect, flash, url_for +from flask_limiter.util import get_remote_address +from flask_login import current_user +from requests_oauthlib import OAuth2Session + +from app.auth.base import auth_bp +from app.auth.views.login_utils import after_login +from app.config import ( + PROTON_BASE_URL, + PROTON_CLIENT_ID, + PROTON_CLIENT_SECRET, + PROTON_VALIDATE_CERTS, + URL, +) +from app.proton.proton_client import HttpProtonClient, convert_access_token +from app.proton.proton_callback_handler import ProtonCallbackHandler, Action +from app.utils import encode_url, sanitize_next_url + +_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize" +_token_url = PROTON_BASE_URL + "/oauth/token" + +# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri +# when served behind nginx, the redirect_uri is localhost... and not the real url +_redirect_uri = URL + "/auth/proton/callback" + + +def extract_action() -> Action: + action = request.args.get("action") + if action is not None: + if action == "link": + return Action.Link + else: + raise Exception(f"Unknown action: {action}") + return Action.Login + + +def get_action_from_state() -> Action: + oauth_action = session["oauth_action"] + if oauth_action == Action.Login.value: + return Action.Login + elif oauth_action == Action.Link.value: + return Action.Link + raise Exception(f"Unknown action in state: {oauth_action}") + + +@auth_bp.route("/proton/login") +def proton_login(): + if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: + return redirect(url_for("auth.login")) + + next_url = sanitize_next_url(request.args.get("next")) + if next_url: + redirect_uri = _redirect_uri + "?next=" + encode_url(next_url) + else: + redirect_uri = _redirect_uri + + proton = OAuth2Session(PROTON_CLIENT_ID, redirect_uri=redirect_uri) + authorization_url, state = proton.authorization_url(_authorization_base_url) + + # State is used to prevent CSRF, keep this for later. + session["oauth_state"] = state + session["oauth_action"] = extract_action().value + return redirect(authorization_url) + + +@auth_bp.route("/proton/callback") +def proton_callback(): + if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: + return redirect(url_for("auth.login")) + + # user clicks on cancel + if "error" in request.args: + flash("Please use another sign in method then", "warning") + return redirect("/") + + proton = OAuth2Session( + PROTON_CLIENT_ID, + state=session["oauth_state"], + redirect_uri=_redirect_uri, + ) + token = proton.fetch_token( + _token_url, + client_secret=PROTON_CLIENT_SECRET, + authorization_response=request.url, + verify=PROTON_VALIDATE_CERTS, + method="GET", + include_client_id=True, + ) + credentials = convert_access_token(token["access_token"]) + action = get_action_from_state() + + proton_client = HttpProtonClient( + PROTON_BASE_URL, credentials, get_remote_address(), verify=PROTON_VALIDATE_CERTS + ) + handler = ProtonCallbackHandler(proton_client) + + if action == Action.Login: + res = handler.handle_login() + elif action == Action.Link: + res = handler.handle_link(current_user) + else: + raise Exception(f"Unknown Action: {action.name}") + + if res.flash_message is not None: + flash(res.flash_message, res.flash_category) + + if res.redirect_to_login: + return redirect(url_for("auth.login")) + + if res.redirect: + return redirect(res.redirect) + + next_url = request.args.get("next") if request.args else None + return after_login(res.user, next_url) diff --git a/app/auth/views/register.py b/app/auth/views/register.py index 60e6b01d..379e1576 100644 --- a/app/auth/views/register.py +++ b/app/auth/views/register.py @@ -6,6 +6,7 @@ from wtforms import StringField, validators from app import email_utils, config from app.auth.base import auth_bp +from app.config import CONNECT_WITH_PROTON from app.auth.views.login_utils import get_referral from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY from app.db import Session @@ -102,6 +103,7 @@ def register(): form=form, next_url=next_url, HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY, + connect_with_proton=CONNECT_WITH_PROTON, ) diff --git a/app/config.py b/app/config.py index fa041a1d..72eec144 100644 --- a/app/config.py +++ b/app/config.py @@ -84,7 +84,6 @@ BOUNCE_PREFIX_FOR_REPLY_PHASE = ( os.environ.get("BOUNCE_PREFIX_FOR_REPLY_PHASE") or "bounce_reply" ) - # VERP for transactional email: mail_from set to BOUNCE_PREFIX + email_log.id + BOUNCE_SUFFIX TRANSACTIONAL_BOUNCE_PREFIX = ( os.environ.get("TRANSACTIONAL_BOUNCE_PREFIX") or "transactional+" @@ -159,7 +158,6 @@ if "DKIM_PRIVATE_KEY_PATH" in os.environ: with open(DKIM_PRIVATE_KEY_PATH) as f: DKIM_PRIVATE_KEY = f.read() - # Database DB_URI = os.environ["DB_URI"] @@ -240,6 +238,14 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET") FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID") FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET") +PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID") +PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET") +PROTON_BASE_URL = os.environ.get( + "PROTON_BASE_URL", "https://account.protonmail.com/api" +) +PROTON_VALIDATE_CERTS = "PROTON_VALIDATE_CERTS" in os.environ +CONNECT_WITH_PROTON = "CONNECT_WITH_PROTON" in os.environ + # in seconds AVATAR_URL_EXPIRATION = 3600 * 24 * 7 # 1h*24h/d*7d=1week @@ -287,7 +293,6 @@ STATUS_PAGE_URL = os.environ.get("STATUS_PAGE_URL") or "https://status.simplelog # Loading PGP keys when mail_handler runs. To be used locally when init_app is not called. LOAD_PGP_EMAIL_HANDLER = "LOAD_PGP_EMAIL_HANDLER" in os.environ - # Used when querying info on Apple API # for iOS App APPLE_API_SECRET = os.environ.get("APPLE_API_SECRET") diff --git a/app/dashboard/views/setting.py b/app/dashboard/views/setting.py index b9af1e0b..f5b382d1 100644 --- a/app/dashboard/views/setting.py +++ b/app/dashboard/views/setting.py @@ -11,6 +11,7 @@ from flask import ( from flask_login import login_required, current_user from flask_wtf import FlaskForm from flask_wtf.file import FileField +from typing import Optional from wtforms import StringField, validators from wtforms.fields.html5 import EmailField @@ -19,6 +20,7 @@ from app.config import ( URL, FIRST_ALIAS_DOMAIN, ALIAS_RANDOM_SUFFIX_LENGTH, + CONNECT_WITH_PROTON, ) from app.dashboard.base import dashboard_bp from app.db import Session @@ -43,7 +45,9 @@ from app.models import ( SLDomain, CoinbaseSubscription, AppleSubscription, + PartnerUser, ) +from app.proton.proton_callback_handler import get_proton_partner_id from app.utils import random_string, sanitize_email @@ -62,6 +66,21 @@ class PromoCodeForm(FlaskForm): code = StringField("Name", validators=[validators.DataRequired()]) +def get_proton_linked_account() -> Optional[str]: + # Check if the current user has a partner_id + proton_partner_id = get_proton_partner_id() + if current_user.partner_id != proton_partner_id: + return None + + # It has. Retrieve the information for the PartnerUser + proton_linked_account = PartnerUser.get_by( + user_id=current_user.id, partner_id=proton_partner_id + ) + if proton_linked_account is None: + return None + return proton_linked_account.partner_email + + @dashboard_bp.route("/setting", methods=["GET", "POST"]) @login_required def setting(): @@ -332,6 +351,7 @@ def setting(): manual_sub = ManualSubscription.get_by(user_id=current_user.id) apple_sub = AppleSubscription.get_by(user_id=current_user.id) coinbase_sub = CoinbaseSubscription.get_by(user_id=current_user.id) + proton_linked_account = get_proton_linked_account() return render_template( "dashboard/setting.html", @@ -348,6 +368,8 @@ def setting(): coinbase_sub=coinbase_sub, FIRST_ALIAS_DOMAIN=FIRST_ALIAS_DOMAIN, ALIAS_RAND_SUFFIX_LENGTH=ALIAS_RANDOM_SUFFIX_LENGTH, + connect_with_proton=CONNECT_WITH_PROTON, + proton_linked_account=proton_linked_account, ) @@ -409,3 +431,18 @@ def cancel_email_change(): "You have no pending email change. Redirect back to Setting page", "warning" ) return redirect(url_for("dashboard.setting")) + + +@dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"]) +@login_required +def unlink_proton_account(): + current_user.partner_id = None + current_user.partner_user_id = None + partner_user = PartnerUser.get_by( + user_id=current_user.id, partner_id=get_proton_partner_id() + ) + if partner_user is not None: + PartnerUser.delete(partner_user.id) + Session.commit() + flash("Your Proton account has been unlinked", "success") + return redirect(url_for("dashboard.setting")) diff --git a/app/models.py b/app/models.py index 575dc373..6146bbc3 100644 --- a/app/models.py +++ b/app/models.py @@ -479,6 +479,15 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): sa.Boolean, default=False, nullable=False, server_default="1" ) + partner_id = sa.Column(sa.BigInteger, unique=False, nullable=True) + partner_user_id = sa.Column(sa.String(128), unique=False, nullable=True) + + __table_args__ = ( + sa.UniqueConstraint( + "partner_id", "partner_user_id", name="uq_partner_id_partner_user_id" + ), + ) + @property def directory_quota(self): return min( @@ -3037,3 +3046,42 @@ class ProviderComplaint(Base, ModelMixin): user = orm.relationship(User, foreign_keys=[user_id]) refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id]) + + +class Partner(Base, ModelMixin): + __tablename__ = "partner" + + name = sa.Column(sa.String(128), unique=True, nullable=False) + contact_email = sa.Column(sa.String(128), unique=True, nullable=False) + + +class PartnerApiToken(Base, ModelMixin): + __tablename__ = "partner_api_token" + + token = sa.Column(sa.String(32), unique=True, nullable=False, index=True) + partner_id = sa.Column( + sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True + ) + expiration_time = sa.Column(ArrowType, unique=False, nullable=True) + + +class PartnerUser(Base, ModelMixin): + __tablename__ = "partner_user" + + user_id = sa.Column( + sa.ForeignKey("users.id", ondelete="cascade"), + unique=False, + nullable=False, + index=True, + ) + partner_id = sa.Column( + sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True + ) + partner_email = sa.Column(sa.String(255), unique=False, nullable=True) + + __table_args__ = ( + sa.UniqueConstraint("user_id", "partner_id", name="uq_user_id_partner_id"), + ) + + +# endregion diff --git a/app/proton/__init__.py b/app/proton/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/app/proton/proton_callback_handler.py b/app/proton/proton_callback_handler.py new file mode 100644 index 00000000..abdecadf --- /dev/null +++ b/app/proton/proton_callback_handler.py @@ -0,0 +1,237 @@ +import enum +from abc import ABC, abstractmethod +from dataclasses import dataclass +from flask import url_for +from typing import Optional + +from app.db import Session +from app.models import User, PartnerUser, Partner +from app.proton.proton_client import ProtonClient, ProtonUser +from app.utils import random_string + +PROTON_PARTNER_NAME = "Proton" +_PROTON_PARTNER_ID: Optional[int] = None + + +def get_proton_partner_id() -> int: + global _PROTON_PARTNER_ID + if _PROTON_PARTNER_ID is None: + partner = Partner.get_by(name=PROTON_PARTNER_NAME) + if partner is None: + raise Exception("Could not find Proton Partner instance") + _PROTON_PARTNER_ID = partner.id + + return _PROTON_PARTNER_ID + + +class Action(enum.Enum): + Login = 1 + Link = 2 + + +@dataclass +class ProtonCallbackResult: + redirect_to_login: bool + flash_message: Optional[str] + flash_category: Optional[str] + redirect: Optional[str] + user: Optional[User] + + +def ensure_partner_user_exists(proton_user: ProtonUser, sl_user: User): + proton_partner_id = get_proton_partner_id() + if not PartnerUser.get_by(user_id=sl_user.id, partner_id=proton_partner_id): + PartnerUser.create( + user_id=sl_user.id, + partner_id=proton_partner_id, + partner_email=proton_user.email, + ) + Session.commit() + + +class ClientMergeStrategy(ABC): + def __init__(self, proton_user: ProtonUser, sl_user: Optional[User]): + if self.__class__ == ClientMergeStrategy: + raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy") + self.proton_user = proton_user + self.sl_user = sl_user + + @abstractmethod + def process(self) -> ProtonCallbackResult: + pass + + +class UnexistantSlClientStrategy(ClientMergeStrategy): + def process(self) -> ProtonCallbackResult: + # Will create a new SL User with a random password + proton_partner_id = get_proton_partner_id() + new_user = User.create( + email=self.proton_user.email, + name=self.proton_user.name, + partner_user_id=self.proton_user.id, + partner_id=proton_partner_id, + password=random_string(20), + ) + PartnerUser.create( + user_id=new_user.id, + partner_id=proton_partner_id, + partner_email=self.proton_user.email, + ) + # TODO: Adjust plans + Session.commit() + + return ProtonCallbackResult( + redirect_to_login=False, + flash_message=None, + flash_category=None, + redirect=None, + user=new_user, + ) + + +class ExistingSlClientStrategy(ClientMergeStrategy): + def process(self) -> ProtonCallbackResult: + ensure_partner_user_exists(self.proton_user, self.sl_user) + # TODO: Adjust plans + + return ProtonCallbackResult( + redirect_to_login=False, + flash_message=None, + flash_category=None, + redirect=None, + user=self.sl_user, + ) + + +class ExistingSlUserLinkedWithDifferentProtonAccountStrategy(ClientMergeStrategy): + def process(self) -> ProtonCallbackResult: + return ProtonCallbackResult( + redirect_to_login=True, + flash_message="This Proton account is already linked to another account", + flash_category="error", + user=None, + redirect=None, + ) + + +class AlreadyLinkedUserStrategy(ClientMergeStrategy): + def process(self) -> ProtonCallbackResult: + return ProtonCallbackResult( + redirect_to_login=False, + flash_message=None, + flash_category=None, + redirect=None, + user=self.sl_user, + ) + + +def get_login_strategy( + proton_user: ProtonUser, sl_user: Optional[User] +) -> ClientMergeStrategy: + if sl_user is None: + # We couldn't find any SimpleLogin user with the requested e-mail + return UnexistantSlClientStrategy(proton_user, sl_user) + # There is a SimpleLogin user with the proton_user's e-mail + # Try to find if it has been registered via a partner + if sl_user.partner_id is None: + # It has not been registered via a Partner + return ExistingSlClientStrategy(proton_user, sl_user) + # It has been registered via a partner + # Check if the partner_user_id matches + if sl_user.partner_user_id != proton_user.id: + # It doesn't match. That means that the SimpleLogin user has a different Proton account linked + return ExistingSlUserLinkedWithDifferentProtonAccountStrategy( + proton_user, sl_user + ) + # This case means that the sl_user is already linked, so nothing to do + return AlreadyLinkedUserStrategy(proton_user, sl_user) + + +def process_login_case(proton_user: ProtonUser) -> ProtonCallbackResult: + # Try to find a SimpleLogin user registered with that proton user id + proton_partner_id = get_proton_partner_id() + sl_user_with_external_id = User.get_by( + partner_id=proton_partner_id, partner_user_id=proton_user.id + ) + if sl_user_with_external_id is None: + # We didn't find any SimpleLogin user registered with that proton user id + # Try to find it using the proton's e-mail address + sl_user = User.get_by(email=proton_user.email) + return get_login_strategy(proton_user, sl_user).process() + else: + # We found the SL user registered with that proton user id + # We're done + return AlreadyLinkedUserStrategy( + proton_user, sl_user_with_external_id + ).process() + + +def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResult: + proton_partner_id = get_proton_partner_id() + current_user.partner_user_id = proton_user.id + current_user.partner_id = proton_partner_id + + ensure_partner_user_exists(proton_user, current_user) + + Session.commit() + return ProtonCallbackResult( + redirect_to_login=False, + redirect=url_for("dashboard.setting"), + flash_category="success", + flash_message="Account successfully linked", + user=current_user, + ) + + +def process_link_case( + proton_user: ProtonUser, current_user: User +) -> ProtonCallbackResult: + # Try to find a SimpleLogin user linked with this Proton account + proton_partner_id = get_proton_partner_id() + sl_user_linked_to_proton_account = User.get_by( + partner_id=proton_partner_id, partner_user_id=proton_user.id + ) + if sl_user_linked_to_proton_account is None: + # There is no SL user linked with the proton email. Proceed with linking + return link_user(proton_user, current_user) + else: + # There is a SL user registered with the proton email. Check if is the current one + if sl_user_linked_to_proton_account.id == current_user.id: + # It's the same user. No need to do anything + return ProtonCallbackResult( + redirect_to_login=False, + redirect=url_for("dashboard.setting"), + flash_category="success", + flash_message="Account successfully linked", + user=current_user, + ) + else: + # It's a different user. Unlink the other account and link the current one + sl_user_linked_to_proton_account.partner_id = None + sl_user_linked_to_proton_account.partner_user_id = None + other_partner_user = PartnerUser.get_by( + user_id=sl_user_linked_to_proton_account.id, + partner_id=proton_partner_id, + ) + if other_partner_user is not None: + PartnerUser.delete(other_partner_user.id) + + return link_user(proton_user, current_user) + + +class ProtonCallbackHandler: + def __init__(self, proton_client: ProtonClient): + self.proton_client = proton_client + + def handle_login(self) -> ProtonCallbackResult: + return process_login_case(self.__get_proton_user()) + + def handle_link(self, current_user: Optional[User]) -> ProtonCallbackResult: + if current_user is None: + raise Exception("Cannot link account with current_user being None") + return process_link_case(self.__get_proton_user(), current_user) + + def __get_proton_user(self) -> ProtonUser: + user = self.proton_client.get_user() + plan = self.proton_client.get_plan() + return ProtonUser(email=user.email, plan=plan, name=user.name, id=user.id) diff --git a/app/proton/proton_client.py b/app/proton/proton_client.py new file mode 100644 index 00000000..3cd1d86a --- /dev/null +++ b/app/proton/proton_client.py @@ -0,0 +1,175 @@ +import dataclasses +from abc import ABC, abstractmethod +from enum import Enum +from http import HTTPStatus +from requests import Response, Session +from typing import Optional + +_APP_VERSION = "OauthClient_1.0.0" + +PROTON_ERROR_CODE_NOT_EXISTS = 2501 + + +class ProtonPlan(Enum): + Free = 0 + Professional = 1 + Visionary = 2 + + def name(self): + if self == self.Free: + return "Free" + elif self == self.Professional: + return "Professional" + elif self == self.Visionary: + return "Visionary" + else: + raise Exception("Unknown plan") + + +def plan_from_name(name: str) -> ProtonPlan: + name_lower = name.lower() + if name_lower == "free": + return ProtonPlan.Free + elif name_lower == "professional": + return ProtonPlan.Professional + elif name_lower == "visionary": + return ProtonPlan.Visionary + else: + raise Exception(f"Unknown plan [{name}]") + + +@dataclasses.dataclass +class UserInformation: + email: str + name: str + id: str + + +class AuthorizeResponse: + def __init__(self, code: str, has_accepted: bool): + self.code = code + self.has_accepted = has_accepted + + def __str__(self): + return f"[code={self.code}] [has_accepted={self.has_accepted}]" + + +@dataclasses.dataclass +class SessionResponse: + state: str + expires_in: int + token_type: str + refresh_token: str + access_token: str + session_id: str + + +@dataclasses.dataclass +class ProtonUser: + id: str + name: str + email: str + plan: ProtonPlan + + +@dataclasses.dataclass +class AccessCredentials: + access_token: str + session_id: str + + +def convert_access_token(access_token_response: str) -> AccessCredentials: + """ + The Access token response contains both the Proton Session ID and the Access Token. + The Session ID is necessary in order to use the Proton API. However, the OAuth response does not allow us to return + extra content. + This method takes the Access token response and extracts the session ID and the access token. + """ + parts = access_token_response.split("-") + if len(parts) != 3: + raise Exception("Invalid access token response") + if parts[0] != "pt": + raise Exception("Invalid access token response format") + return AccessCredentials( + session_id=parts[1], + access_token=parts[2], + ) + + +class ProtonClient(ABC): + @abstractmethod + def get_user(self) -> UserInformation: + pass + + @abstractmethod + def get_organization(self) -> dict: + pass + + @abstractmethod + def get_plan(self) -> ProtonPlan: + pass + + +class HttpProtonClient(ProtonClient): + def __init__( + self, + base_url: str, + credentials: AccessCredentials, + original_ip: Optional[str], + verify: bool = True, + ): + self.base_url = base_url + self.access_token = credentials.access_token + client = Session() + client.verify = verify + headers = { + "x-pm-appversion": _APP_VERSION, + "x-pm-apiversion": "3", + "x-pm-uid": credentials.session_id, + "authorization": f"Bearer {credentials.access_token}", + "accept": "application/vnd.protonmail.v1+json", + "user-agent": "ProtonOauthClient", + } + if original_ip is not None: + headers["x-forwarded-for"] = original_ip + client.headers.update(headers) + self.client = client + + def get_user(self) -> UserInformation: + info = self.__get("/users")["User"] + return UserInformation( + email=info.get("Email"), name=info.get("Name"), id=info.get("ID") + ) + + def get_organization(self) -> dict: + return self.__get("/code/v4/organizations")["Organization"] + + def get_plan(self) -> ProtonPlan: + url = f"{self.base_url}/core/v4/organizations" + res = self.client.get(url) + + status = res.status_code + if status == HTTPStatus.UNPROCESSABLE_ENTITY: + as_json = res.json() + error_code = as_json.get("Code") + if error_code == PROTON_ERROR_CODE_NOT_EXISTS: + return ProtonPlan.Free + + org = self.__validate_response(res).get("Organization") + if org is None: + return ProtonPlan.Free + return plan_from_name(org["PlanName"]) + + def __get(self, route: str) -> dict: + url = f"{self.base_url}{route}" + res = self.client.get(url) + return self.__validate_response(res) + + @staticmethod + def __validate_response(res: Response) -> dict: + status = res.status_code + if status != HTTPStatus.OK: + raise Exception( + f"Unexpected status code. Wanted 200 and got {status}: " + res.text + ) + return res.json() diff --git a/example.env b/example.env index 31429f75..93fb5796 100644 --- a/example.env +++ b/example.env @@ -108,6 +108,13 @@ WORDS_FILE_PATH=local_data/test_words.txt # FACEBOOK_CLIENT_ID=to_fill # FACEBOOK_CLIENT_SECRET=to_fill +# Login with Proton +# PROTON_CLIENT_ID=to_fill +# PROTON_CLIENT_SECRET=to_fill +# PROTON_BASE_URL=to_fill +# PROTON_VALIDATE_CERTS=true +# CONNECT_WITH_PROTON=true + # Flask profiler # FLASK_PROFILER_PATH=/tmp/flask-profiler.sql # FLASK_PROFILER_PASSWORD=password diff --git a/init_app.py b/init_app.py index ef2914fa..331b8f38 100644 --- a/init_app.py +++ b/init_app.py @@ -4,8 +4,9 @@ from app.config import ( ) from app.db import Session from app.log import LOG -from app.models import Mailbox, Contact, SLDomain +from app.models import Mailbox, Contact, SLDomain, Partner from app.pgp_utils import load_public_key +from app.proton.proton_callback_handler import PROTON_PARTNER_NAME from server import create_light_app @@ -53,6 +54,16 @@ def add_sl_domains(): Session.commit() +def add_proton_partner(): + proton_partner = Partner.get_by(name=PROTON_PARTNER_NAME) + if not proton_partner: + Partner.create( + name=PROTON_PARTNER_NAME, + contact_email="simplelogin@protonmail.com", + ) + Session.commit() + + if __name__ == "__main__": # wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc with create_light_app().app_context(): diff --git a/migrations/versions/2022_050512_e866ad0e78e1_add_partner_tables.py b/migrations/versions/2022_050512_e866ad0e78e1_add_partner_tables.py new file mode 100644 index 00000000..cd29323b --- /dev/null +++ b/migrations/versions/2022_050512_e866ad0e78e1_add_partner_tables.py @@ -0,0 +1,76 @@ +"""Add partner tables + +Revision ID: e866ad0e78e1 +Revises: 0aaad1740797 +Create Date: 2022-05-05 12:10:01.229457 + +""" +import sqlalchemy_utils +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'e866ad0e78e1' +down_revision = '0aaad1740797' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.create_table('partner', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False), + sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True), + sa.Column('name', sa.String(length=128), nullable=False), + sa.Column('contact_email', sa.String(length=128), nullable=False), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('contact_email'), + sa.UniqueConstraint('name') + ) + op.create_table('partner_api_token', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False), + sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True), + sa.Column('token', sa.String(length=32), nullable=False), + sa.Column('partner_id', sa.Integer(), nullable=False), + sa.Column('expiration_time', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True), + sa.ForeignKeyConstraint(['partner_id'], ['partner.id'], ondelete='cascade'), + sa.PrimaryKeyConstraint('id') + ) + op.create_index(op.f('ix_partner_api_token_partner_id'), 'partner_api_token', ['partner_id'], unique=False) + op.create_index(op.f('ix_partner_api_token_token'), 'partner_api_token', ['token'], unique=True) + op.create_table('partner_user', + sa.Column('id', sa.Integer(), autoincrement=True, nullable=False), + sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False), + sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True), + sa.Column('user_id', sa.Integer(), nullable=False), + sa.Column('partner_id', sa.Integer(), nullable=False), + sa.Column('partner_email', sa.String(length=255), nullable=True), + sa.ForeignKeyConstraint(['partner_id'], ['partner.id'], ondelete='cascade'), + sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='cascade'), + sa.PrimaryKeyConstraint('id'), + sa.UniqueConstraint('user_id', 'partner_id', name='uq_user_id_partner_id') + ) + op.create_index(op.f('ix_partner_user_partner_id'), 'partner_user', ['partner_id'], unique=False) + op.create_index(op.f('ix_partner_user_user_id'), 'partner_user', ['user_id'], unique=False) + op.add_column('users', sa.Column('partner_id', sa.BigInteger(), nullable=True)) + op.add_column('users', sa.Column('partner_user_id', sa.String(length=128), nullable=True)) + op.create_unique_constraint('uq_partner_id_partner_user_id', 'users', ['partner_id', 'partner_user_id']) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint('uq_partner_id_partner_user_id', 'users', type_='unique') + op.drop_column('users', 'partner_user_id') + op.drop_column('users', 'partner_id') + op.drop_index(op.f('ix_partner_user_user_id'), table_name='partner_user') + op.drop_index(op.f('ix_partner_user_partner_id'), table_name='partner_user') + op.drop_table('partner_user') + op.drop_index(op.f('ix_partner_api_token_token'), table_name='partner_api_token') + op.drop_index(op.f('ix_partner_api_token_partner_id'), table_name='partner_api_token') + op.drop_table('partner_api_token') + op.drop_table('partner') + # ### end Alembic commands ### diff --git a/static/style.css b/static/style.css index b125cd29..889801ca 100644 --- a/static/style.css +++ b/static/style.css @@ -186,4 +186,10 @@ textarea.parsley-error { #help-menu-item { display: none; } -} \ No newline at end of file +} + +.proton-button { + border-color:#6d4aff; + background-color:white; + color:#6d4aff; +} diff --git a/templates/auth/login.html b/templates/auth/login.html index b1474186..0d70e92d 100644 --- a/templates/auth/login.html +++ b/templates/auth/login.html @@ -13,39 +13,46 @@ {% endif %} -
- {{ form.csrf_token }} +

Welcome back!

-
- - {{ form.email(class="form-control", type="email", autofocus="true") }} - {{ render_field_errors(form.email) }} -
+ + {{ form.csrf_token }} -
- - {{ form.password(class="form-control", type="password") }} - {{ render_field_errors(form.password) }} -
- - I forgot my password - +
+ + {{ form.email(class="form-control", type="email", autofocus="true") }} + {{ render_field_errors(form.email) }}
-
- +
+ + {{ form.password(class="form-control", type="password") }} + {{ render_field_errors(form.password) }} + +
+ + + {% if connect_with_proton %} +
or
+ Log in with Proton + {% endif %}
- + +
Don't have an account yet? Sign up
-{% endblock %} \ No newline at end of file +{% endblock %} diff --git a/templates/auth/register.html b/templates/auth/register.html index 0a6f369e..80d94172 100644 --- a/templates/auth/register.html +++ b/templates/auth/register.html @@ -48,6 +48,11 @@
+ + {% if connect_with_proton %} +
or
+ Sign up with Proton + {% endif %}
diff --git a/templates/dashboard/setting.html b/templates/dashboard/setting.html index 16ac1050..5eda0caa 100644 --- a/templates/dashboard/setting.html +++ b/templates/dashboard/setting.html @@ -208,6 +208,32 @@
+ + {% if connect_with_proton %} +
+
+
+ Connect with Proton +
+ {% if proton_linked_account != None %} +
+ You have linked your Proton account: {{ proton_linked_account }}
+
+ Unlink account + {% else %} +
+ You can connect your Proton account with your SimpleLogin one.
+
+ Connect with Proton + {% endif %} +
+
+ {% endif %} + +
@@ -539,7 +565,7 @@
+ {% if current_user.include_header_email_header %} checked {% endif %} class="form-check-input">
diff --git a/tests/auth/test_proton.py b/tests/auth/test_proton.py new file mode 100644 index 00000000..7ae63b45 --- /dev/null +++ b/tests/auth/test_proton.py @@ -0,0 +1,23 @@ +from flask import url_for +from urllib.parse import parse_qs +from urllib3.util import parse_url + +from app.config import URL, PROTON_CLIENT_ID + + +def test_login_with_proton(flask_client): + r = flask_client.get( + url_for("auth.proton_login"), + follow_redirects=False, + ) + location = r.headers.get("Location") + assert location is not None + + parsed = parse_url(location) + query = parse_qs(parsed.query) + + expected_redirect_url = f"{URL}/auth/proton/callback" + + assert "code" == query["response_type"][0] + assert PROTON_CLIENT_ID == query["client_id"][0] + assert expected_redirect_url == query["redirect_uri"][0] diff --git a/tests/conftest.py b/tests/conftest.py index a8c5ce69..e6b92fc3 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -16,7 +16,7 @@ from psycopg2.errorcodes import DEPENDENT_OBJECTS_STILL_EXIST import pytest from server import create_app -from init_app import add_sl_domains +from init_app import add_sl_domains, add_proton_partner app = create_app() app.config["TESTING"] = True @@ -34,6 +34,7 @@ with engine.connect() as conn: conn.execute("Rollback") add_sl_domains() +add_proton_partner() @pytest.fixture diff --git a/tests/proton/__init__.py b/tests/proton/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/proton/test_proton_callback_handler.py b/tests/proton/test_proton_callback_handler.py new file mode 100644 index 00000000..a5278b83 --- /dev/null +++ b/tests/proton/test_proton_callback_handler.py @@ -0,0 +1,285 @@ +import pytest + +from app.db import Session +from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan +from app.proton.proton_callback_handler import ( + ProtonCallbackHandler, + get_proton_partner_id, + get_login_strategy, + process_link_case, + ProtonUser, + UnexistantSlClientStrategy, + ExistingSlClientStrategy, + AlreadyLinkedUserStrategy, + ExistingSlUserLinkedWithDifferentProtonAccountStrategy, + ClientMergeStrategy, +) +from app.models import User, PartnerUser +from app.utils import random_string + + +class MockProtonClient(ProtonClient): + def __init__(self, user: UserInformation, plan: ProtonPlan, organization: dict): + self.user = user + self.plan = plan + self.organization = organization + + def get_organization(self) -> dict: + return self.organization + + def get_user(self) -> UserInformation: + return self.user + + def get_plan(self) -> ProtonPlan: + return self.plan + + +def random_email() -> str: + return "{rand}@{rand}.com".format(rand=random_string(20)) + + +def random_proton_user( + user_id: str = None, + name: str = None, + email: str = None, + plan: ProtonPlan = None, +) -> ProtonUser: + user_id = user_id if user_id is not None else random_string() + name = name if name is not None else random_string() + email = ( + email + if email is not None + else "{rand}@{rand}.com".format(rand=random_string(20)) + ) + plan = plan if plan is not None else ProtonPlan.Free + return ProtonUser(id=user_id, name=name, email=email, plan=plan) + + +def create_user(email: str = None) -> User: + email = email if email is not None else random_email() + user = User.create(email=email) + Session.commit() + return user + + +def create_user_for_partner(partner_user_id: str, email: str = None) -> User: + email = email if email is not None else random_email() + user = User.create(email=email) + user.partner_id = get_proton_partner_id() + user.partner_user_id = partner_user_id + + PartnerUser.create( + user_id=user.id, partner_id=get_proton_partner_id(), partner_email=email + ) + Session.commit() + return user + + +def test_proton_callback_handler_unexistant_sl_user(): + email = random_email() + name = random_string() + external_id = random_string() + user = UserInformation(email=email, name=name, id=external_id) + mock_client = MockProtonClient( + user=user, plan=ProtonPlan.Professional, organization={} + ) + handler = ProtonCallbackHandler(mock_client) + res = handler.handle_login() + + assert res.user is not None + assert res.user.email == email + assert res.user.name == name + assert res.user.partner_user_id == external_id + + +def test_proton_callback_handler_existant_sl_user(): + email = random_email() + sl_user = User.create(email, commit=True) + + external_id = random_string() + user = UserInformation(email=email, name=random_string(), id=external_id) + mock_client = MockProtonClient( + user=user, plan=ProtonPlan.Professional, organization={} + ) + handler = ProtonCallbackHandler(mock_client) + res = handler.handle_login() + + assert res.user is not None + assert res.user.id == sl_user.id + + sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner_id()) + assert sa is not None + assert sa.partner_email == user.email + + +def test_get_strategy_unexistant_sl_user(): + strategy = get_login_strategy( + proton_user=random_proton_user(), + sl_user=None, + ) + assert isinstance(strategy, UnexistantSlClientStrategy) + + +def test_get_strategy_existing_sl_user(): + email = random_email() + sl_user = User.create(email, commit=True) + strategy = get_login_strategy( + proton_user=random_proton_user(email=email), + sl_user=sl_user, + ) + assert isinstance(strategy, ExistingSlClientStrategy) + + +def test_get_strategy_already_linked_user(): + email = random_email() + proton_user_id = random_string() + sl_user = create_user_for_partner(proton_user_id, email=email) + strategy = get_login_strategy( + proton_user=random_proton_user(user_id=proton_user_id, email=email), + sl_user=sl_user, + ) + assert isinstance(strategy, AlreadyLinkedUserStrategy) + + +def test_get_strategy_existing_sl_user_linked_with_different_proton_account(): + # In this scenario we have + # - ProtonUser1 (ID1, email1@proton) + # - ProtonUser2 (ID2, email2@proton) + # - SimpleLoginUser1 registered with email1@proton, but linked to account ID2 + # We will try to log in with email1@proton + email1 = random_email() + email2 = random_email() + proton_user_id_1 = random_string() + proton_user_id_2 = random_string() + + proton_user_1 = random_proton_user(user_id=proton_user_id_1, email=email1) + proton_user_2 = random_proton_user(user_id=proton_user_id_2, email=email2) + + sl_user = create_user_for_partner(proton_user_2.id, email=proton_user_1.email) + strategy = get_login_strategy( + proton_user=proton_user_1, + sl_user=sl_user, + ) + assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy) + + +## +# LINK + + +def test_link_account_with_proton_account_same_address(flask_client): + # This is the most basic scenario + # In this scenario we have: + # - ProtonUser (email1@proton) + # - SimpleLoginUser registered with email1@proton + # We will try to link both accounts + + email = random_email() + proton_user_id = random_string() + proton_user = random_proton_user(user_id=proton_user_id, email=email) + sl_user = create_user(email) + + res = process_link_case(proton_user, sl_user) + assert res.redirect_to_login is False + assert res.redirect is not None + assert res.flash_category == "success" + assert res.flash_message is not None + + updated_user = User.get(sl_user.id) + assert updated_user.partner_id == get_proton_partner_id() + assert updated_user.partner_user_id == proton_user_id + + +def test_link_account_with_proton_account_different_address(flask_client): + # In this scenario we have: + # - ProtonUser (foo@proton) + # - SimpleLoginUser (bar@somethingelse) + # We will try to link both accounts + proton_user_id = random_string() + proton_user = random_proton_user(user_id=proton_user_id, email=random_email()) + sl_user = create_user() + + res = process_link_case(proton_user, sl_user) + assert res.redirect_to_login is False + assert res.redirect is not None + assert res.flash_category == "success" + assert res.flash_message is not None + + updated_user = User.get(sl_user.id) + assert updated_user.partner_id == get_proton_partner_id() + assert updated_user.partner_user_id == proton_user_id + + +def test_link_account_with_proton_account_same_address_but_linked_to_other_user( + flask_client, +): + # In this scenario we have: + # - ProtonUser (foo@proton) + # - SimpleLoginUser1 (foo@proton) + # - SimpleLoginUser2 (other@somethingelse) linked with foo@proton + # We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton + proton_user_id = random_string() + proton_email = random_email() + proton_user = random_proton_user(user_id=proton_user_id, email=proton_email) + sl_user_1 = create_user(proton_email) + sl_user_2 = create_user_for_partner( + proton_user_id, email=random_email() + ) # User already linked with the proton account + + res = process_link_case(proton_user, sl_user_1) + assert res.redirect_to_login is False + assert res.redirect is not None + assert res.flash_category == "success" + assert res.flash_message is not None + + updated_user_1 = User.get(sl_user_1.id) + assert updated_user_1.partner_id == get_proton_partner_id() + assert updated_user_1.partner_user_id == proton_user_id + + updated_user_2 = User.get(sl_user_2.id) + assert updated_user_2.partner_id is None + assert updated_user_2.partner_user_id is None + + +def test_link_account_with_proton_account_different_address_and_linked_to_other_user( + flask_client, +): + # In this scenario we have: + # - ProtonUser (foo@proton) + # - SimpleLoginUser1 (bar@somethingelse) + # - SimpleLoginUser2 (other@somethingelse) linked with foo@proton + # We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton + proton_user_id = random_string() + proton_user = random_proton_user(user_id=proton_user_id, email=random_email()) + sl_user_1 = create_user(random_email()) + sl_user_2 = create_user_for_partner( + proton_user_id, email=random_email() + ) # User already linked with the proton account + + res = process_link_case(proton_user, sl_user_1) + assert res.redirect_to_login is False + assert res.redirect is not None + assert res.flash_category == "success" + assert res.flash_message is not None + + updated_user_1 = User.get(sl_user_1.id) + assert updated_user_1.partner_id == get_proton_partner_id() + assert updated_user_1.partner_user_id == proton_user_id + partner_user_1 = PartnerUser.get_by( + user_id=sl_user_1.id, partner_id=get_proton_partner_id() + ) + assert partner_user_1 is not None + assert partner_user_1.partner_email == proton_user.email + + updated_user_2 = User.get(sl_user_2.id) + assert updated_user_2.partner_id is None + assert updated_user_2.partner_user_id is None + partner_user_2 = PartnerUser.get_by( + user_id=sl_user_2.id, partner_id=get_proton_partner_id() + ) + assert partner_user_2 is None + + +def test_cannot_create_instance_of_base_strategy(): + with pytest.raises(Exception): + ClientMergeStrategy(random_proton_user(), None) diff --git a/tests/proton/test_proton_client.py b/tests/proton/test_proton_client.py new file mode 100644 index 00000000..3879e760 --- /dev/null +++ b/tests/proton/test_proton_client.py @@ -0,0 +1,21 @@ +import pytest + +from app.proton import proton_client + + +def test_convert_access_token_valid(): + res = proton_client.convert_access_token("pt-abc-123") + assert res.session_id == "abc" + assert res.access_token == "123" + + +def test_convert_access_token_not_containing_pt(): + with pytest.raises(Exception): + proton_client.convert_access_token("pb-abc-123") + + +def test_convert_access_token_not_containing_invalid_length(): + cases = ["pt-abc-too-long", "pt-short"] + for case in cases: + with pytest.raises(Exception): + proton_client.convert_access_token(case) diff --git a/tests/test.env b/tests/test.env index 5f6fc935..8189945f 100644 --- a/tests/test.env +++ b/tests/test.env @@ -54,4 +54,9 @@ PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc ALIAS_AUTOMATIC_DISABLE=true ALLOWED_REDIRECT_DOMAINS=["test.simplelogin.local"] -DMARC_CHECK_ENABLED=true \ No newline at end of file + +DMARC_CHECK_ENABLED=true + +PROTON_CLIENT_ID=to_fill +PROTON_CLIENT_SECRET=to_fill +PROTON_BASE_URL=https://localhost/api diff --git a/tests/test_utils.py b/tests/test_utils.py index ef7605d4..52984373 100644 --- a/tests/test_utils.py +++ b/tests/test_utils.py @@ -1,4 +1,5 @@ from typing import List +from urllib.parse import parse_qs import pytest @@ -40,3 +41,19 @@ def generate_sanitize_url_cases() -> List: def test_sanitize_url(url, expected): sanitized = sanitize_next_url(url) assert expected == sanitized + + +def test_parse_querystring(): + cases = [ + {"input": "", "expected": {}}, + {"input": "a=b", "expected": {"a": ["b"]}}, + {"input": "a=b&c=d", "expected": {"a": ["b"], "c": ["d"]}}, + {"input": "a=b&a=c", "expected": {"a": ["b", "c"]}}, + ] + + for case in cases: + expected = case["expected"] + res = parse_qs(case["input"]) + assert len(res) == len(expected) + for k, v in expected.items(): + assert res[k] == v