mirror of
https://github.com/simple-login/app.git
synced 2024-11-13 07:31:12 +01:00
Merge pull request #894 from cquintana92/feature/add-login-with-proton
Add login with proton
This commit is contained in:
commit
a92981c52d
24 changed files with 1142 additions and 31 deletions
|
@ -9,6 +9,7 @@ from .views import (
|
|||
github,
|
||||
google,
|
||||
facebook,
|
||||
proton,
|
||||
change_email,
|
||||
mfa,
|
||||
fido,
|
||||
|
|
|
@ -5,6 +5,7 @@ from wtforms import StringField, validators
|
|||
|
||||
from app.auth.base import auth_bp
|
||||
from app.auth.views.login_utils import after_login
|
||||
from app.config import CONNECT_WITH_PROTON
|
||||
from app.events.auth_event import LoginEvent
|
||||
from app.extensions import limiter
|
||||
from app.log import LOG
|
||||
|
@ -67,4 +68,5 @@ def login():
|
|||
form=form,
|
||||
next_url=next_url,
|
||||
show_resend_activation=show_resend_activation,
|
||||
connect_with_proton=CONNECT_WITH_PROTON,
|
||||
)
|
||||
|
|
114
app/auth/views/proton.py
Normal file
114
app/auth/views/proton.py
Normal file
|
@ -0,0 +1,114 @@
|
|||
from flask import request, session, redirect, flash, url_for
|
||||
from flask_limiter.util import get_remote_address
|
||||
from flask_login import current_user
|
||||
from requests_oauthlib import OAuth2Session
|
||||
|
||||
from app.auth.base import auth_bp
|
||||
from app.auth.views.login_utils import after_login
|
||||
from app.config import (
|
||||
PROTON_BASE_URL,
|
||||
PROTON_CLIENT_ID,
|
||||
PROTON_CLIENT_SECRET,
|
||||
PROTON_VALIDATE_CERTS,
|
||||
URL,
|
||||
)
|
||||
from app.proton.proton_client import HttpProtonClient, convert_access_token
|
||||
from app.proton.proton_callback_handler import ProtonCallbackHandler, Action
|
||||
from app.utils import encode_url, sanitize_next_url
|
||||
|
||||
_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
|
||||
_token_url = PROTON_BASE_URL + "/oauth/token"
|
||||
|
||||
# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri
|
||||
# when served behind nginx, the redirect_uri is localhost... and not the real url
|
||||
_redirect_uri = URL + "/auth/proton/callback"
|
||||
|
||||
|
||||
def extract_action() -> Action:
|
||||
action = request.args.get("action")
|
||||
if action is not None:
|
||||
if action == "link":
|
||||
return Action.Link
|
||||
else:
|
||||
raise Exception(f"Unknown action: {action}")
|
||||
return Action.Login
|
||||
|
||||
|
||||
def get_action_from_state() -> Action:
|
||||
oauth_action = session["oauth_action"]
|
||||
if oauth_action == Action.Login.value:
|
||||
return Action.Login
|
||||
elif oauth_action == Action.Link.value:
|
||||
return Action.Link
|
||||
raise Exception(f"Unknown action in state: {oauth_action}")
|
||||
|
||||
|
||||
@auth_bp.route("/proton/login")
|
||||
def proton_login():
|
||||
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
next_url = sanitize_next_url(request.args.get("next"))
|
||||
if next_url:
|
||||
redirect_uri = _redirect_uri + "?next=" + encode_url(next_url)
|
||||
else:
|
||||
redirect_uri = _redirect_uri
|
||||
|
||||
proton = OAuth2Session(PROTON_CLIENT_ID, redirect_uri=redirect_uri)
|
||||
authorization_url, state = proton.authorization_url(_authorization_base_url)
|
||||
|
||||
# State is used to prevent CSRF, keep this for later.
|
||||
session["oauth_state"] = state
|
||||
session["oauth_action"] = extract_action().value
|
||||
return redirect(authorization_url)
|
||||
|
||||
|
||||
@auth_bp.route("/proton/callback")
|
||||
def proton_callback():
|
||||
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
# user clicks on cancel
|
||||
if "error" in request.args:
|
||||
flash("Please use another sign in method then", "warning")
|
||||
return redirect("/")
|
||||
|
||||
proton = OAuth2Session(
|
||||
PROTON_CLIENT_ID,
|
||||
state=session["oauth_state"],
|
||||
redirect_uri=_redirect_uri,
|
||||
)
|
||||
token = proton.fetch_token(
|
||||
_token_url,
|
||||
client_secret=PROTON_CLIENT_SECRET,
|
||||
authorization_response=request.url,
|
||||
verify=PROTON_VALIDATE_CERTS,
|
||||
method="GET",
|
||||
include_client_id=True,
|
||||
)
|
||||
credentials = convert_access_token(token["access_token"])
|
||||
action = get_action_from_state()
|
||||
|
||||
proton_client = HttpProtonClient(
|
||||
PROTON_BASE_URL, credentials, get_remote_address(), verify=PROTON_VALIDATE_CERTS
|
||||
)
|
||||
handler = ProtonCallbackHandler(proton_client)
|
||||
|
||||
if action == Action.Login:
|
||||
res = handler.handle_login()
|
||||
elif action == Action.Link:
|
||||
res = handler.handle_link(current_user)
|
||||
else:
|
||||
raise Exception(f"Unknown Action: {action.name}")
|
||||
|
||||
if res.flash_message is not None:
|
||||
flash(res.flash_message, res.flash_category)
|
||||
|
||||
if res.redirect_to_login:
|
||||
return redirect(url_for("auth.login"))
|
||||
|
||||
if res.redirect:
|
||||
return redirect(res.redirect)
|
||||
|
||||
next_url = request.args.get("next") if request.args else None
|
||||
return after_login(res.user, next_url)
|
|
@ -6,6 +6,7 @@ from wtforms import StringField, validators
|
|||
|
||||
from app import email_utils, config
|
||||
from app.auth.base import auth_bp
|
||||
from app.config import CONNECT_WITH_PROTON
|
||||
from app.auth.views.login_utils import get_referral
|
||||
from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY
|
||||
from app.db import Session
|
||||
|
@ -102,6 +103,7 @@ def register():
|
|||
form=form,
|
||||
next_url=next_url,
|
||||
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
|
||||
connect_with_proton=CONNECT_WITH_PROTON,
|
||||
)
|
||||
|
||||
|
||||
|
|
|
@ -84,7 +84,6 @@ BOUNCE_PREFIX_FOR_REPLY_PHASE = (
|
|||
os.environ.get("BOUNCE_PREFIX_FOR_REPLY_PHASE") or "bounce_reply"
|
||||
)
|
||||
|
||||
|
||||
# VERP for transactional email: mail_from set to BOUNCE_PREFIX + email_log.id + BOUNCE_SUFFIX
|
||||
TRANSACTIONAL_BOUNCE_PREFIX = (
|
||||
os.environ.get("TRANSACTIONAL_BOUNCE_PREFIX") or "transactional+"
|
||||
|
@ -159,7 +158,6 @@ if "DKIM_PRIVATE_KEY_PATH" in os.environ:
|
|||
with open(DKIM_PRIVATE_KEY_PATH) as f:
|
||||
DKIM_PRIVATE_KEY = f.read()
|
||||
|
||||
|
||||
# Database
|
||||
DB_URI = os.environ["DB_URI"]
|
||||
|
||||
|
@ -240,6 +238,14 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
|
|||
FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID")
|
||||
FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET")
|
||||
|
||||
PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID")
|
||||
PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET")
|
||||
PROTON_BASE_URL = os.environ.get(
|
||||
"PROTON_BASE_URL", "https://account.protonmail.com/api"
|
||||
)
|
||||
PROTON_VALIDATE_CERTS = "PROTON_VALIDATE_CERTS" in os.environ
|
||||
CONNECT_WITH_PROTON = "CONNECT_WITH_PROTON" in os.environ
|
||||
|
||||
# in seconds
|
||||
AVATAR_URL_EXPIRATION = 3600 * 24 * 7 # 1h*24h/d*7d=1week
|
||||
|
||||
|
@ -287,7 +293,6 @@ STATUS_PAGE_URL = os.environ.get("STATUS_PAGE_URL") or "https://status.simplelog
|
|||
# Loading PGP keys when mail_handler runs. To be used locally when init_app is not called.
|
||||
LOAD_PGP_EMAIL_HANDLER = "LOAD_PGP_EMAIL_HANDLER" in os.environ
|
||||
|
||||
|
||||
# Used when querying info on Apple API
|
||||
# for iOS App
|
||||
APPLE_API_SECRET = os.environ.get("APPLE_API_SECRET")
|
||||
|
|
|
@ -11,6 +11,7 @@ from flask import (
|
|||
from flask_login import login_required, current_user
|
||||
from flask_wtf import FlaskForm
|
||||
from flask_wtf.file import FileField
|
||||
from typing import Optional
|
||||
from wtforms import StringField, validators
|
||||
from wtforms.fields.html5 import EmailField
|
||||
|
||||
|
@ -19,6 +20,7 @@ from app.config import (
|
|||
URL,
|
||||
FIRST_ALIAS_DOMAIN,
|
||||
ALIAS_RANDOM_SUFFIX_LENGTH,
|
||||
CONNECT_WITH_PROTON,
|
||||
)
|
||||
from app.dashboard.base import dashboard_bp
|
||||
from app.db import Session
|
||||
|
@ -43,7 +45,9 @@ from app.models import (
|
|||
SLDomain,
|
||||
CoinbaseSubscription,
|
||||
AppleSubscription,
|
||||
PartnerUser,
|
||||
)
|
||||
from app.proton.proton_callback_handler import get_proton_partner_id
|
||||
from app.utils import random_string, sanitize_email
|
||||
|
||||
|
||||
|
@ -62,6 +66,21 @@ class PromoCodeForm(FlaskForm):
|
|||
code = StringField("Name", validators=[validators.DataRequired()])
|
||||
|
||||
|
||||
def get_proton_linked_account() -> Optional[str]:
|
||||
# Check if the current user has a partner_id
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
if current_user.partner_id != proton_partner_id:
|
||||
return None
|
||||
|
||||
# It has. Retrieve the information for the PartnerUser
|
||||
proton_linked_account = PartnerUser.get_by(
|
||||
user_id=current_user.id, partner_id=proton_partner_id
|
||||
)
|
||||
if proton_linked_account is None:
|
||||
return None
|
||||
return proton_linked_account.partner_email
|
||||
|
||||
|
||||
@dashboard_bp.route("/setting", methods=["GET", "POST"])
|
||||
@login_required
|
||||
def setting():
|
||||
|
@ -332,6 +351,7 @@ def setting():
|
|||
manual_sub = ManualSubscription.get_by(user_id=current_user.id)
|
||||
apple_sub = AppleSubscription.get_by(user_id=current_user.id)
|
||||
coinbase_sub = CoinbaseSubscription.get_by(user_id=current_user.id)
|
||||
proton_linked_account = get_proton_linked_account()
|
||||
|
||||
return render_template(
|
||||
"dashboard/setting.html",
|
||||
|
@ -348,6 +368,8 @@ def setting():
|
|||
coinbase_sub=coinbase_sub,
|
||||
FIRST_ALIAS_DOMAIN=FIRST_ALIAS_DOMAIN,
|
||||
ALIAS_RAND_SUFFIX_LENGTH=ALIAS_RANDOM_SUFFIX_LENGTH,
|
||||
connect_with_proton=CONNECT_WITH_PROTON,
|
||||
proton_linked_account=proton_linked_account,
|
||||
)
|
||||
|
||||
|
||||
|
@ -409,3 +431,18 @@ def cancel_email_change():
|
|||
"You have no pending email change. Redirect back to Setting page", "warning"
|
||||
)
|
||||
return redirect(url_for("dashboard.setting"))
|
||||
|
||||
|
||||
@dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"])
|
||||
@login_required
|
||||
def unlink_proton_account():
|
||||
current_user.partner_id = None
|
||||
current_user.partner_user_id = None
|
||||
partner_user = PartnerUser.get_by(
|
||||
user_id=current_user.id, partner_id=get_proton_partner_id()
|
||||
)
|
||||
if partner_user is not None:
|
||||
PartnerUser.delete(partner_user.id)
|
||||
Session.commit()
|
||||
flash("Your Proton account has been unlinked", "success")
|
||||
return redirect(url_for("dashboard.setting"))
|
||||
|
|
|
@ -479,6 +479,15 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
sa.Boolean, default=False, nullable=False, server_default="1"
|
||||
)
|
||||
|
||||
partner_id = sa.Column(sa.BigInteger, unique=False, nullable=True)
|
||||
partner_user_id = sa.Column(sa.String(128), unique=False, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
sa.UniqueConstraint(
|
||||
"partner_id", "partner_user_id", name="uq_partner_id_partner_user_id"
|
||||
),
|
||||
)
|
||||
|
||||
@property
|
||||
def directory_quota(self):
|
||||
return min(
|
||||
|
@ -3037,3 +3046,42 @@ class ProviderComplaint(Base, ModelMixin):
|
|||
|
||||
user = orm.relationship(User, foreign_keys=[user_id])
|
||||
refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id])
|
||||
|
||||
|
||||
class Partner(Base, ModelMixin):
|
||||
__tablename__ = "partner"
|
||||
|
||||
name = sa.Column(sa.String(128), unique=True, nullable=False)
|
||||
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
|
||||
|
||||
|
||||
class PartnerApiToken(Base, ModelMixin):
|
||||
__tablename__ = "partner_api_token"
|
||||
|
||||
token = sa.Column(sa.String(32), unique=True, nullable=False, index=True)
|
||||
partner_id = sa.Column(
|
||||
sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True
|
||||
)
|
||||
expiration_time = sa.Column(ArrowType, unique=False, nullable=True)
|
||||
|
||||
|
||||
class PartnerUser(Base, ModelMixin):
|
||||
__tablename__ = "partner_user"
|
||||
|
||||
user_id = sa.Column(
|
||||
sa.ForeignKey("users.id", ondelete="cascade"),
|
||||
unique=False,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
partner_id = sa.Column(
|
||||
sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True
|
||||
)
|
||||
partner_email = sa.Column(sa.String(255), unique=False, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
sa.UniqueConstraint("user_id", "partner_id", name="uq_user_id_partner_id"),
|
||||
)
|
||||
|
||||
|
||||
# endregion
|
||||
|
|
0
app/proton/__init__.py
Normal file
0
app/proton/__init__.py
Normal file
237
app/proton/proton_callback_handler.py
Normal file
237
app/proton/proton_callback_handler.py
Normal file
|
@ -0,0 +1,237 @@
|
|||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from flask import url_for
|
||||
from typing import Optional
|
||||
|
||||
from app.db import Session
|
||||
from app.models import User, PartnerUser, Partner
|
||||
from app.proton.proton_client import ProtonClient, ProtonUser
|
||||
from app.utils import random_string
|
||||
|
||||
PROTON_PARTNER_NAME = "Proton"
|
||||
_PROTON_PARTNER_ID: Optional[int] = None
|
||||
|
||||
|
||||
def get_proton_partner_id() -> int:
|
||||
global _PROTON_PARTNER_ID
|
||||
if _PROTON_PARTNER_ID is None:
|
||||
partner = Partner.get_by(name=PROTON_PARTNER_NAME)
|
||||
if partner is None:
|
||||
raise Exception("Could not find Proton Partner instance")
|
||||
_PROTON_PARTNER_ID = partner.id
|
||||
|
||||
return _PROTON_PARTNER_ID
|
||||
|
||||
|
||||
class Action(enum.Enum):
|
||||
Login = 1
|
||||
Link = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class ProtonCallbackResult:
|
||||
redirect_to_login: bool
|
||||
flash_message: Optional[str]
|
||||
flash_category: Optional[str]
|
||||
redirect: Optional[str]
|
||||
user: Optional[User]
|
||||
|
||||
|
||||
def ensure_partner_user_exists(proton_user: ProtonUser, sl_user: User):
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=proton_partner_id):
|
||||
PartnerUser.create(
|
||||
user_id=sl_user.id,
|
||||
partner_id=proton_partner_id,
|
||||
partner_email=proton_user.email,
|
||||
)
|
||||
Session.commit()
|
||||
|
||||
|
||||
class ClientMergeStrategy(ABC):
|
||||
def __init__(self, proton_user: ProtonUser, sl_user: Optional[User]):
|
||||
if self.__class__ == ClientMergeStrategy:
|
||||
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
|
||||
self.proton_user = proton_user
|
||||
self.sl_user = sl_user
|
||||
|
||||
@abstractmethod
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
pass
|
||||
|
||||
|
||||
class UnexistantSlClientStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
# Will create a new SL User with a random password
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
new_user = User.create(
|
||||
email=self.proton_user.email,
|
||||
name=self.proton_user.name,
|
||||
partner_user_id=self.proton_user.id,
|
||||
partner_id=proton_partner_id,
|
||||
password=random_string(20),
|
||||
)
|
||||
PartnerUser.create(
|
||||
user_id=new_user.id,
|
||||
partner_id=proton_partner_id,
|
||||
partner_email=self.proton_user.email,
|
||||
)
|
||||
# TODO: Adjust plans
|
||||
Session.commit()
|
||||
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=new_user,
|
||||
)
|
||||
|
||||
|
||||
class ExistingSlClientStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
ensure_partner_user_exists(self.proton_user, self.sl_user)
|
||||
# TODO: Adjust plans
|
||||
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=self.sl_user,
|
||||
)
|
||||
|
||||
|
||||
class ExistingSlUserLinkedWithDifferentProtonAccountStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=True,
|
||||
flash_message="This Proton account is already linked to another account",
|
||||
flash_category="error",
|
||||
user=None,
|
||||
redirect=None,
|
||||
)
|
||||
|
||||
|
||||
class AlreadyLinkedUserStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=self.sl_user,
|
||||
)
|
||||
|
||||
|
||||
def get_login_strategy(
|
||||
proton_user: ProtonUser, sl_user: Optional[User]
|
||||
) -> ClientMergeStrategy:
|
||||
if sl_user is None:
|
||||
# We couldn't find any SimpleLogin user with the requested e-mail
|
||||
return UnexistantSlClientStrategy(proton_user, sl_user)
|
||||
# There is a SimpleLogin user with the proton_user's e-mail
|
||||
# Try to find if it has been registered via a partner
|
||||
if sl_user.partner_id is None:
|
||||
# It has not been registered via a Partner
|
||||
return ExistingSlClientStrategy(proton_user, sl_user)
|
||||
# It has been registered via a partner
|
||||
# Check if the partner_user_id matches
|
||||
if sl_user.partner_user_id != proton_user.id:
|
||||
# It doesn't match. That means that the SimpleLogin user has a different Proton account linked
|
||||
return ExistingSlUserLinkedWithDifferentProtonAccountStrategy(
|
||||
proton_user, sl_user
|
||||
)
|
||||
# This case means that the sl_user is already linked, so nothing to do
|
||||
return AlreadyLinkedUserStrategy(proton_user, sl_user)
|
||||
|
||||
|
||||
def process_login_case(proton_user: ProtonUser) -> ProtonCallbackResult:
|
||||
# Try to find a SimpleLogin user registered with that proton user id
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
sl_user_with_external_id = User.get_by(
|
||||
partner_id=proton_partner_id, partner_user_id=proton_user.id
|
||||
)
|
||||
if sl_user_with_external_id is None:
|
||||
# We didn't find any SimpleLogin user registered with that proton user id
|
||||
# Try to find it using the proton's e-mail address
|
||||
sl_user = User.get_by(email=proton_user.email)
|
||||
return get_login_strategy(proton_user, sl_user).process()
|
||||
else:
|
||||
# We found the SL user registered with that proton user id
|
||||
# We're done
|
||||
return AlreadyLinkedUserStrategy(
|
||||
proton_user, sl_user_with_external_id
|
||||
).process()
|
||||
|
||||
|
||||
def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResult:
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
current_user.partner_user_id = proton_user.id
|
||||
current_user.partner_id = proton_partner_id
|
||||
|
||||
ensure_partner_user_exists(proton_user, current_user)
|
||||
|
||||
Session.commit()
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
redirect=url_for("dashboard.setting"),
|
||||
flash_category="success",
|
||||
flash_message="Account successfully linked",
|
||||
user=current_user,
|
||||
)
|
||||
|
||||
|
||||
def process_link_case(
|
||||
proton_user: ProtonUser, current_user: User
|
||||
) -> ProtonCallbackResult:
|
||||
# Try to find a SimpleLogin user linked with this Proton account
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
sl_user_linked_to_proton_account = User.get_by(
|
||||
partner_id=proton_partner_id, partner_user_id=proton_user.id
|
||||
)
|
||||
if sl_user_linked_to_proton_account is None:
|
||||
# There is no SL user linked with the proton email. Proceed with linking
|
||||
return link_user(proton_user, current_user)
|
||||
else:
|
||||
# There is a SL user registered with the proton email. Check if is the current one
|
||||
if sl_user_linked_to_proton_account.id == current_user.id:
|
||||
# It's the same user. No need to do anything
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
redirect=url_for("dashboard.setting"),
|
||||
flash_category="success",
|
||||
flash_message="Account successfully linked",
|
||||
user=current_user,
|
||||
)
|
||||
else:
|
||||
# It's a different user. Unlink the other account and link the current one
|
||||
sl_user_linked_to_proton_account.partner_id = None
|
||||
sl_user_linked_to_proton_account.partner_user_id = None
|
||||
other_partner_user = PartnerUser.get_by(
|
||||
user_id=sl_user_linked_to_proton_account.id,
|
||||
partner_id=proton_partner_id,
|
||||
)
|
||||
if other_partner_user is not None:
|
||||
PartnerUser.delete(other_partner_user.id)
|
||||
|
||||
return link_user(proton_user, current_user)
|
||||
|
||||
|
||||
class ProtonCallbackHandler:
|
||||
def __init__(self, proton_client: ProtonClient):
|
||||
self.proton_client = proton_client
|
||||
|
||||
def handle_login(self) -> ProtonCallbackResult:
|
||||
return process_login_case(self.__get_proton_user())
|
||||
|
||||
def handle_link(self, current_user: Optional[User]) -> ProtonCallbackResult:
|
||||
if current_user is None:
|
||||
raise Exception("Cannot link account with current_user being None")
|
||||
return process_link_case(self.__get_proton_user(), current_user)
|
||||
|
||||
def __get_proton_user(self) -> ProtonUser:
|
||||
user = self.proton_client.get_user()
|
||||
plan = self.proton_client.get_plan()
|
||||
return ProtonUser(email=user.email, plan=plan, name=user.name, id=user.id)
|
175
app/proton/proton_client.py
Normal file
175
app/proton/proton_client.py
Normal file
|
@ -0,0 +1,175 @@
|
|||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from http import HTTPStatus
|
||||
from requests import Response, Session
|
||||
from typing import Optional
|
||||
|
||||
_APP_VERSION = "OauthClient_1.0.0"
|
||||
|
||||
PROTON_ERROR_CODE_NOT_EXISTS = 2501
|
||||
|
||||
|
||||
class ProtonPlan(Enum):
|
||||
Free = 0
|
||||
Professional = 1
|
||||
Visionary = 2
|
||||
|
||||
def name(self):
|
||||
if self == self.Free:
|
||||
return "Free"
|
||||
elif self == self.Professional:
|
||||
return "Professional"
|
||||
elif self == self.Visionary:
|
||||
return "Visionary"
|
||||
else:
|
||||
raise Exception("Unknown plan")
|
||||
|
||||
|
||||
def plan_from_name(name: str) -> ProtonPlan:
|
||||
name_lower = name.lower()
|
||||
if name_lower == "free":
|
||||
return ProtonPlan.Free
|
||||
elif name_lower == "professional":
|
||||
return ProtonPlan.Professional
|
||||
elif name_lower == "visionary":
|
||||
return ProtonPlan.Visionary
|
||||
else:
|
||||
raise Exception(f"Unknown plan [{name}]")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class UserInformation:
|
||||
email: str
|
||||
name: str
|
||||
id: str
|
||||
|
||||
|
||||
class AuthorizeResponse:
|
||||
def __init__(self, code: str, has_accepted: bool):
|
||||
self.code = code
|
||||
self.has_accepted = has_accepted
|
||||
|
||||
def __str__(self):
|
||||
return f"[code={self.code}] [has_accepted={self.has_accepted}]"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SessionResponse:
|
||||
state: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
access_token: str
|
||||
session_id: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class ProtonUser:
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
plan: ProtonPlan
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class AccessCredentials:
|
||||
access_token: str
|
||||
session_id: str
|
||||
|
||||
|
||||
def convert_access_token(access_token_response: str) -> AccessCredentials:
|
||||
"""
|
||||
The Access token response contains both the Proton Session ID and the Access Token.
|
||||
The Session ID is necessary in order to use the Proton API. However, the OAuth response does not allow us to return
|
||||
extra content.
|
||||
This method takes the Access token response and extracts the session ID and the access token.
|
||||
"""
|
||||
parts = access_token_response.split("-")
|
||||
if len(parts) != 3:
|
||||
raise Exception("Invalid access token response")
|
||||
if parts[0] != "pt":
|
||||
raise Exception("Invalid access token response format")
|
||||
return AccessCredentials(
|
||||
session_id=parts[1],
|
||||
access_token=parts[2],
|
||||
)
|
||||
|
||||
|
||||
class ProtonClient(ABC):
|
||||
@abstractmethod
|
||||
def get_user(self) -> UserInformation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_organization(self) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_plan(self) -> ProtonPlan:
|
||||
pass
|
||||
|
||||
|
||||
class HttpProtonClient(ProtonClient):
|
||||
def __init__(
|
||||
self,
|
||||
base_url: str,
|
||||
credentials: AccessCredentials,
|
||||
original_ip: Optional[str],
|
||||
verify: bool = True,
|
||||
):
|
||||
self.base_url = base_url
|
||||
self.access_token = credentials.access_token
|
||||
client = Session()
|
||||
client.verify = verify
|
||||
headers = {
|
||||
"x-pm-appversion": _APP_VERSION,
|
||||
"x-pm-apiversion": "3",
|
||||
"x-pm-uid": credentials.session_id,
|
||||
"authorization": f"Bearer {credentials.access_token}",
|
||||
"accept": "application/vnd.protonmail.v1+json",
|
||||
"user-agent": "ProtonOauthClient",
|
||||
}
|
||||
if original_ip is not None:
|
||||
headers["x-forwarded-for"] = original_ip
|
||||
client.headers.update(headers)
|
||||
self.client = client
|
||||
|
||||
def get_user(self) -> UserInformation:
|
||||
info = self.__get("/users")["User"]
|
||||
return UserInformation(
|
||||
email=info.get("Email"), name=info.get("Name"), id=info.get("ID")
|
||||
)
|
||||
|
||||
def get_organization(self) -> dict:
|
||||
return self.__get("/code/v4/organizations")["Organization"]
|
||||
|
||||
def get_plan(self) -> ProtonPlan:
|
||||
url = f"{self.base_url}/core/v4/organizations"
|
||||
res = self.client.get(url)
|
||||
|
||||
status = res.status_code
|
||||
if status == HTTPStatus.UNPROCESSABLE_ENTITY:
|
||||
as_json = res.json()
|
||||
error_code = as_json.get("Code")
|
||||
if error_code == PROTON_ERROR_CODE_NOT_EXISTS:
|
||||
return ProtonPlan.Free
|
||||
|
||||
org = self.__validate_response(res).get("Organization")
|
||||
if org is None:
|
||||
return ProtonPlan.Free
|
||||
return plan_from_name(org["PlanName"])
|
||||
|
||||
def __get(self, route: str) -> dict:
|
||||
url = f"{self.base_url}{route}"
|
||||
res = self.client.get(url)
|
||||
return self.__validate_response(res)
|
||||
|
||||
@staticmethod
|
||||
def __validate_response(res: Response) -> dict:
|
||||
status = res.status_code
|
||||
if status != HTTPStatus.OK:
|
||||
raise Exception(
|
||||
f"Unexpected status code. Wanted 200 and got {status}: " + res.text
|
||||
)
|
||||
return res.json()
|
|
@ -108,6 +108,13 @@ WORDS_FILE_PATH=local_data/test_words.txt
|
|||
# FACEBOOK_CLIENT_ID=to_fill
|
||||
# FACEBOOK_CLIENT_SECRET=to_fill
|
||||
|
||||
# Login with Proton
|
||||
# PROTON_CLIENT_ID=to_fill
|
||||
# PROTON_CLIENT_SECRET=to_fill
|
||||
# PROTON_BASE_URL=to_fill
|
||||
# PROTON_VALIDATE_CERTS=true
|
||||
# CONNECT_WITH_PROTON=true
|
||||
|
||||
# Flask profiler
|
||||
# FLASK_PROFILER_PATH=/tmp/flask-profiler.sql
|
||||
# FLASK_PROFILER_PASSWORD=password
|
||||
|
|
13
init_app.py
13
init_app.py
|
@ -4,8 +4,9 @@ from app.config import (
|
|||
)
|
||||
from app.db import Session
|
||||
from app.log import LOG
|
||||
from app.models import Mailbox, Contact, SLDomain
|
||||
from app.models import Mailbox, Contact, SLDomain, Partner
|
||||
from app.pgp_utils import load_public_key
|
||||
from app.proton.proton_callback_handler import PROTON_PARTNER_NAME
|
||||
from server import create_light_app
|
||||
|
||||
|
||||
|
@ -53,6 +54,16 @@ def add_sl_domains():
|
|||
Session.commit()
|
||||
|
||||
|
||||
def add_proton_partner():
|
||||
proton_partner = Partner.get_by(name=PROTON_PARTNER_NAME)
|
||||
if not proton_partner:
|
||||
Partner.create(
|
||||
name=PROTON_PARTNER_NAME,
|
||||
contact_email="simplelogin@protonmail.com",
|
||||
)
|
||||
Session.commit()
|
||||
|
||||
|
||||
if __name__ == "__main__":
|
||||
# wrap in an app context to benefit from app setup like database cleanup, sentry integration, etc
|
||||
with create_light_app().app_context():
|
||||
|
|
|
@ -0,0 +1,76 @@
|
|||
"""Add partner tables
|
||||
|
||||
Revision ID: e866ad0e78e1
|
||||
Revises: 0aaad1740797
|
||||
Create Date: 2022-05-05 12:10:01.229457
|
||||
|
||||
"""
|
||||
import sqlalchemy_utils
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = 'e866ad0e78e1'
|
||||
down_revision = '0aaad1740797'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('partner',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
|
||||
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
|
||||
sa.Column('name', sa.String(length=128), nullable=False),
|
||||
sa.Column('contact_email', sa.String(length=128), nullable=False),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('contact_email'),
|
||||
sa.UniqueConstraint('name')
|
||||
)
|
||||
op.create_table('partner_api_token',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
|
||||
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
|
||||
sa.Column('token', sa.String(length=32), nullable=False),
|
||||
sa.Column('partner_id', sa.Integer(), nullable=False),
|
||||
sa.Column('expiration_time', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
|
||||
sa.ForeignKeyConstraint(['partner_id'], ['partner.id'], ondelete='cascade'),
|
||||
sa.PrimaryKeyConstraint('id')
|
||||
)
|
||||
op.create_index(op.f('ix_partner_api_token_partner_id'), 'partner_api_token', ['partner_id'], unique=False)
|
||||
op.create_index(op.f('ix_partner_api_token_token'), 'partner_api_token', ['token'], unique=True)
|
||||
op.create_table('partner_user',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
|
||||
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
|
||||
sa.Column('user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('partner_id', sa.Integer(), nullable=False),
|
||||
sa.Column('partner_email', sa.String(length=255), nullable=True),
|
||||
sa.ForeignKeyConstraint(['partner_id'], ['partner.id'], ondelete='cascade'),
|
||||
sa.ForeignKeyConstraint(['user_id'], ['users.id'], ondelete='cascade'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('user_id', 'partner_id', name='uq_user_id_partner_id')
|
||||
)
|
||||
op.create_index(op.f('ix_partner_user_partner_id'), 'partner_user', ['partner_id'], unique=False)
|
||||
op.create_index(op.f('ix_partner_user_user_id'), 'partner_user', ['user_id'], unique=False)
|
||||
op.add_column('users', sa.Column('partner_id', sa.BigInteger(), nullable=True))
|
||||
op.add_column('users', sa.Column('partner_user_id', sa.String(length=128), nullable=True))
|
||||
op.create_unique_constraint('uq_partner_id_partner_user_id', 'users', ['partner_id', 'partner_user_id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint('uq_partner_id_partner_user_id', 'users', type_='unique')
|
||||
op.drop_column('users', 'partner_user_id')
|
||||
op.drop_column('users', 'partner_id')
|
||||
op.drop_index(op.f('ix_partner_user_user_id'), table_name='partner_user')
|
||||
op.drop_index(op.f('ix_partner_user_partner_id'), table_name='partner_user')
|
||||
op.drop_table('partner_user')
|
||||
op.drop_index(op.f('ix_partner_api_token_token'), table_name='partner_api_token')
|
||||
op.drop_index(op.f('ix_partner_api_token_partner_id'), table_name='partner_api_token')
|
||||
op.drop_table('partner_api_token')
|
||||
op.drop_table('partner')
|
||||
# ### end Alembic commands ###
|
8
static/style.css
vendored
8
static/style.css
vendored
|
@ -186,4 +186,10 @@ textarea.parsley-error {
|
|||
#help-menu-item {
|
||||
display: none;
|
||||
}
|
||||
}
|
||||
}
|
||||
|
||||
.proton-button {
|
||||
border-color:#6d4aff;
|
||||
background-color:white;
|
||||
color:#6d4aff;
|
||||
}
|
||||
|
|
|
@ -13,39 +13,46 @@
|
|||
</div>
|
||||
{% endif %}
|
||||
|
||||
<form class="card" style="border-radius: 2%" method="post">
|
||||
{{ form.csrf_token }}
|
||||
<div class="card" style="border-radius: 2%">
|
||||
<div class="card-body p-6">
|
||||
<h1 class="card-title">Welcome back!</h1>
|
||||
<div class="form-group">
|
||||
<label class="form-label">Email address</label>
|
||||
{{ form.email(class="form-control", type="email", autofocus="true") }}
|
||||
{{ render_field_errors(form.email) }}
|
||||
</div>
|
||||
<form method="post">
|
||||
{{ form.csrf_token }}
|
||||
|
||||
<div class="form-group">
|
||||
<label class="form-label">
|
||||
Password
|
||||
</label>
|
||||
{{ form.password(class="form-control", type="password") }}
|
||||
{{ render_field_errors(form.password) }}
|
||||
<div class="text-muted">
|
||||
<a href="{{ url_for('auth.forgot_password') }}" class="small">
|
||||
I forgot my password
|
||||
</a>
|
||||
<div class="form-group">
|
||||
<label class="form-label">Email address</label>
|
||||
{{ form.email(class="form-control", type="email", autofocus="true") }}
|
||||
{{ render_field_errors(form.email) }}
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-footer">
|
||||
<button type="submit" class="btn btn-primary btn-block">Log in</button>
|
||||
</div>
|
||||
<div class="form-group">
|
||||
<label class="form-label">
|
||||
Password
|
||||
</label>
|
||||
{{ form.password(class="form-control", type="password") }}
|
||||
{{ render_field_errors(form.password) }}
|
||||
<div class="text-muted">
|
||||
<a href="{{ url_for('auth.forgot_password') }}" class="small">
|
||||
I forgot my password
|
||||
</a>
|
||||
</div>
|
||||
</div>
|
||||
|
||||
<div class="form-footer">
|
||||
<button type="submit" class="btn btn-primary btn-block">Log in</button>
|
||||
</div>
|
||||
|
||||
</form>
|
||||
{% if connect_with_proton %}
|
||||
<div class="text-center my-2 text-gray"><span>or</span></div>
|
||||
<a class="btn btn-primary btn-block mt-2 proton-button" href="{{ url_for("auth.proton_login") }}">Log in with Proton</a>
|
||||
{% endif %}
|
||||
</div>
|
||||
</form>
|
||||
|
||||
</div>
|
||||
|
||||
<div class="text-center text-muted mt-2">
|
||||
Don't have an account yet? <a href="{{ url_for('auth.register') }}">Sign up</a>
|
||||
</div>
|
||||
|
||||
{% endblock %}
|
||||
{% endblock %}
|
||||
|
|
|
@ -48,6 +48,11 @@
|
|||
<div class="mt-2">
|
||||
<button type="submit" class="btn btn-primary btn-block">Create Account</button>
|
||||
</div>
|
||||
|
||||
{% if connect_with_proton %}
|
||||
<div class="text-center my-2 text-gray"><span>or</span></div>
|
||||
<a class="btn btn-primary btn-block mt-2 proton-button" href="{{ url_for("auth.proton_login") }}">Sign up with Proton</a>
|
||||
{% endif %}
|
||||
</div>
|
||||
</form>
|
||||
<div class="text-center text-muted mb-6">
|
||||
|
|
|
@ -208,6 +208,32 @@
|
|||
</div>
|
||||
<!-- END Change email -->
|
||||
|
||||
<!-- Connect with Proton -->
|
||||
{% if connect_with_proton %}
|
||||
<div class="card">
|
||||
<div class="card-body">
|
||||
<div class="card-title">
|
||||
Connect with Proton
|
||||
</div>
|
||||
{% if proton_linked_account != None %}
|
||||
<div class="mb-3">
|
||||
You have linked your Proton account: {{ proton_linked_account }} <br>
|
||||
</div>
|
||||
<a
|
||||
class="btn btn-primary mt-2 proton-button"
|
||||
href="{{ url_for('dashboard.unlink_proton_account') }}"
|
||||
>Unlink account</a>
|
||||
{% else %}
|
||||
<div class="mb-3">
|
||||
You can connect your Proton account with your SimpleLogin one. <br>
|
||||
</div>
|
||||
<a class="btn btn-primary mt-2 proton-button" href="{{ url_for("auth.proton_login", action="link") }}">Connect with Proton</a>
|
||||
{% endif %}
|
||||
</div>
|
||||
</div>
|
||||
{% endif %}
|
||||
<!-- END Connect with Proton -->
|
||||
|
||||
<!-- Change password -->
|
||||
<div class="card" id="change_password">
|
||||
<div class="card-body">
|
||||
|
@ -539,7 +565,7 @@
|
|||
|
||||
<div class="form-check">
|
||||
<input type="checkbox" id="include-sender-header" name="enable"
|
||||
{% if current_user.include_header_email_header %} checked {% endif %} class="form-check-input">
|
||||
{% if current_user.include_header_email_header %} checked {% endif %} class="form-check-input">
|
||||
<label for="include-sender-header">Include sender address in email headers</label>
|
||||
</div>
|
||||
|
||||
|
|
23
tests/auth/test_proton.py
Normal file
23
tests/auth/test_proton.py
Normal file
|
@ -0,0 +1,23 @@
|
|||
from flask import url_for
|
||||
from urllib.parse import parse_qs
|
||||
from urllib3.util import parse_url
|
||||
|
||||
from app.config import URL, PROTON_CLIENT_ID
|
||||
|
||||
|
||||
def test_login_with_proton(flask_client):
|
||||
r = flask_client.get(
|
||||
url_for("auth.proton_login"),
|
||||
follow_redirects=False,
|
||||
)
|
||||
location = r.headers.get("Location")
|
||||
assert location is not None
|
||||
|
||||
parsed = parse_url(location)
|
||||
query = parse_qs(parsed.query)
|
||||
|
||||
expected_redirect_url = f"{URL}/auth/proton/callback"
|
||||
|
||||
assert "code" == query["response_type"][0]
|
||||
assert PROTON_CLIENT_ID == query["client_id"][0]
|
||||
assert expected_redirect_url == query["redirect_uri"][0]
|
|
@ -16,7 +16,7 @@ from psycopg2.errorcodes import DEPENDENT_OBJECTS_STILL_EXIST
|
|||
import pytest
|
||||
|
||||
from server import create_app
|
||||
from init_app import add_sl_domains
|
||||
from init_app import add_sl_domains, add_proton_partner
|
||||
|
||||
app = create_app()
|
||||
app.config["TESTING"] = True
|
||||
|
@ -34,6 +34,7 @@ with engine.connect() as conn:
|
|||
conn.execute("Rollback")
|
||||
|
||||
add_sl_domains()
|
||||
add_proton_partner()
|
||||
|
||||
|
||||
@pytest.fixture
|
||||
|
|
0
tests/proton/__init__.py
Normal file
0
tests/proton/__init__.py
Normal file
285
tests/proton/test_proton_callback_handler.py
Normal file
285
tests/proton/test_proton_callback_handler.py
Normal file
|
@ -0,0 +1,285 @@
|
|||
import pytest
|
||||
|
||||
from app.db import Session
|
||||
from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan
|
||||
from app.proton.proton_callback_handler import (
|
||||
ProtonCallbackHandler,
|
||||
get_proton_partner_id,
|
||||
get_login_strategy,
|
||||
process_link_case,
|
||||
ProtonUser,
|
||||
UnexistantSlClientStrategy,
|
||||
ExistingSlClientStrategy,
|
||||
AlreadyLinkedUserStrategy,
|
||||
ExistingSlUserLinkedWithDifferentProtonAccountStrategy,
|
||||
ClientMergeStrategy,
|
||||
)
|
||||
from app.models import User, PartnerUser
|
||||
from app.utils import random_string
|
||||
|
||||
|
||||
class MockProtonClient(ProtonClient):
|
||||
def __init__(self, user: UserInformation, plan: ProtonPlan, organization: dict):
|
||||
self.user = user
|
||||
self.plan = plan
|
||||
self.organization = organization
|
||||
|
||||
def get_organization(self) -> dict:
|
||||
return self.organization
|
||||
|
||||
def get_user(self) -> UserInformation:
|
||||
return self.user
|
||||
|
||||
def get_plan(self) -> ProtonPlan:
|
||||
return self.plan
|
||||
|
||||
|
||||
def random_email() -> str:
|
||||
return "{rand}@{rand}.com".format(rand=random_string(20))
|
||||
|
||||
|
||||
def random_proton_user(
|
||||
user_id: str = None,
|
||||
name: str = None,
|
||||
email: str = None,
|
||||
plan: ProtonPlan = None,
|
||||
) -> ProtonUser:
|
||||
user_id = user_id if user_id is not None else random_string()
|
||||
name = name if name is not None else random_string()
|
||||
email = (
|
||||
email
|
||||
if email is not None
|
||||
else "{rand}@{rand}.com".format(rand=random_string(20))
|
||||
)
|
||||
plan = plan if plan is not None else ProtonPlan.Free
|
||||
return ProtonUser(id=user_id, name=name, email=email, plan=plan)
|
||||
|
||||
|
||||
def create_user(email: str = None) -> User:
|
||||
email = email if email is not None else random_email()
|
||||
user = User.create(email=email)
|
||||
Session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def create_user_for_partner(partner_user_id: str, email: str = None) -> User:
|
||||
email = email if email is not None else random_email()
|
||||
user = User.create(email=email)
|
||||
user.partner_id = get_proton_partner_id()
|
||||
user.partner_user_id = partner_user_id
|
||||
|
||||
PartnerUser.create(
|
||||
user_id=user.id, partner_id=get_proton_partner_id(), partner_email=email
|
||||
)
|
||||
Session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def test_proton_callback_handler_unexistant_sl_user():
|
||||
email = random_email()
|
||||
name = random_string()
|
||||
external_id = random_string()
|
||||
user = UserInformation(email=email, name=name, id=external_id)
|
||||
mock_client = MockProtonClient(
|
||||
user=user, plan=ProtonPlan.Professional, organization={}
|
||||
)
|
||||
handler = ProtonCallbackHandler(mock_client)
|
||||
res = handler.handle_login()
|
||||
|
||||
assert res.user is not None
|
||||
assert res.user.email == email
|
||||
assert res.user.name == name
|
||||
assert res.user.partner_user_id == external_id
|
||||
|
||||
|
||||
def test_proton_callback_handler_existant_sl_user():
|
||||
email = random_email()
|
||||
sl_user = User.create(email, commit=True)
|
||||
|
||||
external_id = random_string()
|
||||
user = UserInformation(email=email, name=random_string(), id=external_id)
|
||||
mock_client = MockProtonClient(
|
||||
user=user, plan=ProtonPlan.Professional, organization={}
|
||||
)
|
||||
handler = ProtonCallbackHandler(mock_client)
|
||||
res = handler.handle_login()
|
||||
|
||||
assert res.user is not None
|
||||
assert res.user.id == sl_user.id
|
||||
|
||||
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner_id())
|
||||
assert sa is not None
|
||||
assert sa.partner_email == user.email
|
||||
|
||||
|
||||
def test_get_strategy_unexistant_sl_user():
|
||||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(),
|
||||
sl_user=None,
|
||||
)
|
||||
assert isinstance(strategy, UnexistantSlClientStrategy)
|
||||
|
||||
|
||||
def test_get_strategy_existing_sl_user():
|
||||
email = random_email()
|
||||
sl_user = User.create(email, commit=True)
|
||||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(email=email),
|
||||
sl_user=sl_user,
|
||||
)
|
||||
assert isinstance(strategy, ExistingSlClientStrategy)
|
||||
|
||||
|
||||
def test_get_strategy_already_linked_user():
|
||||
email = random_email()
|
||||
proton_user_id = random_string()
|
||||
sl_user = create_user_for_partner(proton_user_id, email=email)
|
||||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(user_id=proton_user_id, email=email),
|
||||
sl_user=sl_user,
|
||||
)
|
||||
assert isinstance(strategy, AlreadyLinkedUserStrategy)
|
||||
|
||||
|
||||
def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
|
||||
# In this scenario we have
|
||||
# - ProtonUser1 (ID1, email1@proton)
|
||||
# - ProtonUser2 (ID2, email2@proton)
|
||||
# - SimpleLoginUser1 registered with email1@proton, but linked to account ID2
|
||||
# We will try to log in with email1@proton
|
||||
email1 = random_email()
|
||||
email2 = random_email()
|
||||
proton_user_id_1 = random_string()
|
||||
proton_user_id_2 = random_string()
|
||||
|
||||
proton_user_1 = random_proton_user(user_id=proton_user_id_1, email=email1)
|
||||
proton_user_2 = random_proton_user(user_id=proton_user_id_2, email=email2)
|
||||
|
||||
sl_user = create_user_for_partner(proton_user_2.id, email=proton_user_1.email)
|
||||
strategy = get_login_strategy(
|
||||
proton_user=proton_user_1,
|
||||
sl_user=sl_user,
|
||||
)
|
||||
assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy)
|
||||
|
||||
|
||||
##
|
||||
# LINK
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_same_address(flask_client):
|
||||
# This is the most basic scenario
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (email1@proton)
|
||||
# - SimpleLoginUser registered with email1@proton
|
||||
# We will try to link both accounts
|
||||
|
||||
email = random_email()
|
||||
proton_user_id = random_string()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=email)
|
||||
sl_user = create_user(email)
|
||||
|
||||
res = process_link_case(proton_user, sl_user)
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user = User.get(sl_user.id)
|
||||
assert updated_user.partner_id == get_proton_partner_id()
|
||||
assert updated_user.partner_user_id == proton_user_id
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_different_address(flask_client):
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (foo@proton)
|
||||
# - SimpleLoginUser (bar@somethingelse)
|
||||
# We will try to link both accounts
|
||||
proton_user_id = random_string()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
|
||||
sl_user = create_user()
|
||||
|
||||
res = process_link_case(proton_user, sl_user)
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user = User.get(sl_user.id)
|
||||
assert updated_user.partner_id == get_proton_partner_id()
|
||||
assert updated_user.partner_user_id == proton_user_id
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
|
||||
flask_client,
|
||||
):
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (foo@proton)
|
||||
# - SimpleLoginUser1 (foo@proton)
|
||||
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
|
||||
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
|
||||
proton_user_id = random_string()
|
||||
proton_email = random_email()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=proton_email)
|
||||
sl_user_1 = create_user(proton_email)
|
||||
sl_user_2 = create_user_for_partner(
|
||||
proton_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(proton_user, sl_user_1)
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user_1 = User.get(sl_user_1.id)
|
||||
assert updated_user_1.partner_id == get_proton_partner_id()
|
||||
assert updated_user_1.partner_user_id == proton_user_id
|
||||
|
||||
updated_user_2 = User.get(sl_user_2.id)
|
||||
assert updated_user_2.partner_id is None
|
||||
assert updated_user_2.partner_user_id is None
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
|
||||
flask_client,
|
||||
):
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (foo@proton)
|
||||
# - SimpleLoginUser1 (bar@somethingelse)
|
||||
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
|
||||
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
|
||||
proton_user_id = random_string()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
|
||||
sl_user_1 = create_user(random_email())
|
||||
sl_user_2 = create_user_for_partner(
|
||||
proton_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(proton_user, sl_user_1)
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user_1 = User.get(sl_user_1.id)
|
||||
assert updated_user_1.partner_id == get_proton_partner_id()
|
||||
assert updated_user_1.partner_user_id == proton_user_id
|
||||
partner_user_1 = PartnerUser.get_by(
|
||||
user_id=sl_user_1.id, partner_id=get_proton_partner_id()
|
||||
)
|
||||
assert partner_user_1 is not None
|
||||
assert partner_user_1.partner_email == proton_user.email
|
||||
|
||||
updated_user_2 = User.get(sl_user_2.id)
|
||||
assert updated_user_2.partner_id is None
|
||||
assert updated_user_2.partner_user_id is None
|
||||
partner_user_2 = PartnerUser.get_by(
|
||||
user_id=sl_user_2.id, partner_id=get_proton_partner_id()
|
||||
)
|
||||
assert partner_user_2 is None
|
||||
|
||||
|
||||
def test_cannot_create_instance_of_base_strategy():
|
||||
with pytest.raises(Exception):
|
||||
ClientMergeStrategy(random_proton_user(), None)
|
21
tests/proton/test_proton_client.py
Normal file
21
tests/proton/test_proton_client.py
Normal file
|
@ -0,0 +1,21 @@
|
|||
import pytest
|
||||
|
||||
from app.proton import proton_client
|
||||
|
||||
|
||||
def test_convert_access_token_valid():
|
||||
res = proton_client.convert_access_token("pt-abc-123")
|
||||
assert res.session_id == "abc"
|
||||
assert res.access_token == "123"
|
||||
|
||||
|
||||
def test_convert_access_token_not_containing_pt():
|
||||
with pytest.raises(Exception):
|
||||
proton_client.convert_access_token("pb-abc-123")
|
||||
|
||||
|
||||
def test_convert_access_token_not_containing_invalid_length():
|
||||
cases = ["pt-abc-too-long", "pt-short"]
|
||||
for case in cases:
|
||||
with pytest.raises(Exception):
|
||||
proton_client.convert_access_token(case)
|
|
@ -54,4 +54,9 @@ PGP_SENDER_PRIVATE_KEY_PATH=local_data/private-pgp.asc
|
|||
ALIAS_AUTOMATIC_DISABLE=true
|
||||
ALLOWED_REDIRECT_DOMAINS=["test.simplelogin.local"]
|
||||
|
||||
DMARC_CHECK_ENABLED=true
|
||||
|
||||
DMARC_CHECK_ENABLED=true
|
||||
|
||||
PROTON_CLIENT_ID=to_fill
|
||||
PROTON_CLIENT_SECRET=to_fill
|
||||
PROTON_BASE_URL=https://localhost/api
|
||||
|
|
|
@ -1,4 +1,5 @@
|
|||
from typing import List
|
||||
from urllib.parse import parse_qs
|
||||
|
||||
import pytest
|
||||
|
||||
|
@ -40,3 +41,19 @@ def generate_sanitize_url_cases() -> List:
|
|||
def test_sanitize_url(url, expected):
|
||||
sanitized = sanitize_next_url(url)
|
||||
assert expected == sanitized
|
||||
|
||||
|
||||
def test_parse_querystring():
|
||||
cases = [
|
||||
{"input": "", "expected": {}},
|
||||
{"input": "a=b", "expected": {"a": ["b"]}},
|
||||
{"input": "a=b&c=d", "expected": {"a": ["b"], "c": ["d"]}},
|
||||
{"input": "a=b&a=c", "expected": {"a": ["b", "c"]}},
|
||||
]
|
||||
|
||||
for case in cases:
|
||||
expected = case["expected"]
|
||||
res = parse_qs(case["input"])
|
||||
assert len(res) == len(expected)
|
||||
for k, v in expected.items():
|
||||
assert res[k] == v
|
||||
|
|
Loading…
Reference in a new issue