Separate code for proton callback handler (#1040)
* Separate code for proton callback handler * Upgrade migration * Use simple_login endpoint from Proton API * Remove unused classes * Rename Dto class to Data * Push rename * Moved link to PartnerUser to allow subscriptions to depend only on it * Fix test * PR comments * Add unique user_id constraint to PartnerUser * Added more logs * Added more logs Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
parent
faf67ff338
commit
c0a4c44e94
|
@ -0,0 +1,257 @@
|
|||
from abc import ABC, abstractmethod
|
||||
from arrow import Arrow
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from app.db import Session
|
||||
from app.errors import AccountAlreadyLinkedToAnotherPartnerException
|
||||
from app.log import LOG
|
||||
from app.models import PartnerSubscription, Partner, PartnerUser, User
|
||||
from app.utils import random_string
|
||||
|
||||
|
||||
class SLPlanType(Enum):
|
||||
Free = 1
|
||||
Premium = 2
|
||||
|
||||
|
||||
@dataclass
|
||||
class SLPlan:
|
||||
type: SLPlanType
|
||||
expiration: Optional[Arrow]
|
||||
|
||||
|
||||
@dataclass
|
||||
class PartnerLinkRequest:
|
||||
name: str
|
||||
email: str
|
||||
external_user_id: str
|
||||
plan: SLPlan
|
||||
|
||||
|
||||
@dataclass
|
||||
class LinkResult:
|
||||
user: User
|
||||
strategy: str
|
||||
|
||||
|
||||
def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
|
||||
sub = PartnerSubscription.get_by(partner_user_id=partner_user.id)
|
||||
if plan.type == SLPlanType.Free:
|
||||
if sub is not None:
|
||||
LOG.i(
|
||||
f"Deleting partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
|
||||
)
|
||||
PartnerSubscription.delete(sub.id)
|
||||
else:
|
||||
if sub is None:
|
||||
LOG.i(
|
||||
f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
|
||||
)
|
||||
PartnerSubscription.create(
|
||||
partner_user_id=partner_user.id,
|
||||
end_at=plan.expiration,
|
||||
)
|
||||
else:
|
||||
LOG.i(
|
||||
f"Updating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
|
||||
)
|
||||
sub.end_at = plan.expiration
|
||||
Session.commit()
|
||||
|
||||
|
||||
def set_plan_for_user(user: User, plan: SLPlan, partner: Partner):
|
||||
partner_user = PartnerUser.get_by(partner_id=partner.id, user_id=user.id)
|
||||
if partner_user is None:
|
||||
return
|
||||
return set_plan_for_partner_user(partner_user, plan)
|
||||
|
||||
|
||||
def ensure_partner_user_exists_for_user(
|
||||
link_request: PartnerLinkRequest, sl_user: User, partner: Partner
|
||||
) -> PartnerUser:
|
||||
# Find partner_user by user_id
|
||||
res = PartnerUser.get_by(user_id=sl_user.id, partner_id=partner.id)
|
||||
if not res:
|
||||
res = PartnerUser.create(
|
||||
user_id=sl_user.id,
|
||||
partner_id=partner.id,
|
||||
partner_email=link_request.email,
|
||||
external_user_id=link_request.external_user_id,
|
||||
)
|
||||
Session.commit()
|
||||
LOG.i(
|
||||
f"Created new partner_user for partner:{partner.id} user:{sl_user.id} external_user_id:{link_request.external_user_id}. PartnerUser.id is {res.id}"
|
||||
)
|
||||
return res
|
||||
|
||||
|
||||
class ClientMergeStrategy(ABC):
|
||||
def __init__(
|
||||
self,
|
||||
link_request: PartnerLinkRequest,
|
||||
user: Optional[User],
|
||||
partner: Partner,
|
||||
):
|
||||
if self.__class__ == ClientMergeStrategy:
|
||||
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
|
||||
self.link_request = link_request
|
||||
self.user = user
|
||||
self.partner = partner
|
||||
|
||||
@abstractmethod
|
||||
def process(self) -> LinkResult:
|
||||
pass
|
||||
|
||||
|
||||
class NewUserStrategy(ClientMergeStrategy):
|
||||
def process(self) -> LinkResult:
|
||||
# Will create a new SL User with a random password
|
||||
new_user = User.create(
|
||||
email=self.link_request.email,
|
||||
name=self.link_request.name,
|
||||
password=random_string(20),
|
||||
)
|
||||
partner_user = PartnerUser.create(
|
||||
user_id=new_user.id,
|
||||
partner_id=self.partner.id,
|
||||
external_user_id=self.link_request.external_user_id,
|
||||
partner_email=self.link_request.email,
|
||||
)
|
||||
LOG.i(
|
||||
f"Created new user for login request for partner:{self.partner.id} external_user_id:{self.link_request.external_user_id}. New user {new_user.id} partner_user:{partner_user.id}"
|
||||
)
|
||||
set_plan_for_partner_user(
|
||||
partner_user,
|
||||
self.link_request.plan,
|
||||
)
|
||||
Session.commit()
|
||||
|
||||
return LinkResult(
|
||||
user=new_user,
|
||||
strategy=self.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
class ExistingUnlinedUserStrategy(ClientMergeStrategy):
|
||||
def process(self) -> LinkResult:
|
||||
|
||||
partner_user = ensure_partner_user_exists_for_user(
|
||||
self.link_request, self.user, self.partner
|
||||
)
|
||||
set_plan_for_partner_user(partner_user, self.link_request.plan)
|
||||
|
||||
return LinkResult(
|
||||
user=self.user,
|
||||
strategy=self.__class__.__name__,
|
||||
)
|
||||
|
||||
|
||||
class LinkedWithAnotherPartnerUserStrategy(ClientMergeStrategy):
|
||||
def process(self) -> LinkResult:
|
||||
raise AccountAlreadyLinkedToAnotherPartnerException()
|
||||
|
||||
|
||||
def get_login_strategy(
|
||||
link_request: PartnerLinkRequest, user: Optional[User], partner: Partner
|
||||
) -> ClientMergeStrategy:
|
||||
if user is None:
|
||||
# We couldn't find any SimpleLogin user with the requested e-mail
|
||||
return NewUserStrategy(link_request, user, partner)
|
||||
# Check if user is already linked with another partner_user
|
||||
other_partner_user = PartnerUser.get_by(partner_id=partner.id, user_id=user.id)
|
||||
if other_partner_user is not None:
|
||||
return LinkedWithAnotherPartnerUserStrategy(link_request, user, partner)
|
||||
# There is a SimpleLogin user with the partner_user's e-mail
|
||||
return ExistingUnlinedUserStrategy(link_request, user, partner)
|
||||
|
||||
|
||||
def process_login_case(
|
||||
link_request: PartnerLinkRequest, partner: Partner
|
||||
) -> LinkResult:
|
||||
# Try to find a SimpleLogin user registered with that partner user id
|
||||
partner_user = PartnerUser.get_by(
|
||||
partner_id=partner.id, external_user_id=link_request.external_user_id
|
||||
)
|
||||
if partner_user is None:
|
||||
# We didn't find any SimpleLogin user registered with that partner user id
|
||||
# Try to find it using the partner's e-mail address
|
||||
user = User.get_by(email=link_request.email)
|
||||
return get_login_strategy(link_request, user, partner).process()
|
||||
else:
|
||||
# We found the SL user registered with that partner user id
|
||||
# We're done
|
||||
set_plan_for_partner_user(partner_user, link_request.plan)
|
||||
# It's the same user. No need to do anything
|
||||
return LinkResult(
|
||||
user=partner_user.user,
|
||||
strategy="Link",
|
||||
)
|
||||
|
||||
|
||||
def link_user(
|
||||
link_request: PartnerLinkRequest, current_user: User, partner: Partner
|
||||
) -> LinkResult:
|
||||
partner_user = ensure_partner_user_exists_for_user(
|
||||
link_request, current_user, partner
|
||||
)
|
||||
set_plan_for_partner_user(partner_user, link_request.plan)
|
||||
|
||||
Session.commit()
|
||||
return LinkResult(
|
||||
user=current_user,
|
||||
strategy="Link",
|
||||
)
|
||||
|
||||
|
||||
def switch_already_linked_user(
|
||||
link_request: PartnerLinkRequest, partner_user: PartnerUser, current_user: User
|
||||
):
|
||||
# Find if the user has another link and unlink it
|
||||
other_partner_user = PartnerUser.get_by(
|
||||
user_id=current_user.id,
|
||||
partner_id=partner_user.partner_id,
|
||||
)
|
||||
if other_partner_user is not None:
|
||||
LOG.i(
|
||||
f"Deleting previous partner_user:{other_partner_user.id} from user:{current_user.id}"
|
||||
)
|
||||
PartnerUser.delete(other_partner_user.id)
|
||||
LOG.i(f"Linking partner_user:{partner_user.id} to user:{current_user.id}")
|
||||
# Link this partner_user to the current user
|
||||
partner_user.user_id = current_user.id
|
||||
# Set plan
|
||||
set_plan_for_partner_user(partner_user, link_request.plan)
|
||||
Session.commit()
|
||||
return LinkResult(
|
||||
user=current_user,
|
||||
strategy="Link",
|
||||
)
|
||||
|
||||
|
||||
def process_link_case(
|
||||
link_request: PartnerLinkRequest,
|
||||
current_user: User,
|
||||
partner: Partner,
|
||||
) -> LinkResult:
|
||||
# Try to find a SimpleLogin user linked with this Partner account
|
||||
partner_user = PartnerUser.get_by(
|
||||
partner_id=partner.id, external_user_id=link_request.external_user_id
|
||||
)
|
||||
if partner_user is None:
|
||||
# There is no SL user linked with the partner. Proceed with linking
|
||||
return link_user(link_request, current_user, partner)
|
||||
|
||||
# There is a SL user registered with the partner. Check if is the current one
|
||||
if partner_user.id == current_user.id:
|
||||
# Update plan
|
||||
set_plan_for_partner_user(partner_user, link_request.plan)
|
||||
# It's the same user. No need to do anything
|
||||
return LinkResult(
|
||||
user=current_user,
|
||||
strategy="Link",
|
||||
)
|
||||
else:
|
||||
|
||||
return switch_already_linked_user(link_request, partner_user, current_user)
|
|
@ -120,7 +120,7 @@ def proton_callback():
|
|||
return redirect(url_for("auth.login"))
|
||||
|
||||
if res.redirect:
|
||||
return redirect(res.redirect)
|
||||
return after_login(res.user, res.redirect)
|
||||
|
||||
next_url = session.get("oauth_next")
|
||||
return after_login(res.user, next_url)
|
||||
|
|
|
@ -446,8 +446,6 @@ def cancel_email_change():
|
|||
@dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"])
|
||||
@login_required
|
||||
def unlink_proton_account():
|
||||
current_user.partner_id = None
|
||||
current_user.partner_user_id = None
|
||||
partner_user = PartnerUser.get_by(
|
||||
user_id=current_user.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
|
|
|
@ -93,3 +93,18 @@ class ErrContactAlreadyExists(SLException):
|
|||
|
||||
def error_for_user(self) -> str:
|
||||
return f"{self.contact.website_email} is already added"
|
||||
|
||||
|
||||
class LinkException(SLException):
|
||||
def __init__(self, message: str):
|
||||
self.message = message
|
||||
|
||||
|
||||
class AccountAlreadyLinkedToAnotherPartnerException(LinkException):
|
||||
def __init__(self):
|
||||
super().__init__("This account is already linked to another partner")
|
||||
|
||||
|
||||
class AccountAlreadyLinkedToAnotherUserException(LinkException):
|
||||
def __init__(self):
|
||||
super().__init__("This account is linked to another user")
|
||||
|
|
|
@ -66,6 +66,7 @@ from app.utils import (
|
|||
Base = declarative_base()
|
||||
|
||||
PADDLE_SUBSCRIPTION_GRACE_DAYS = 14
|
||||
_PARTNER_SUBSCRIPTION_GRACE_DAYS = 14
|
||||
|
||||
|
||||
class TSVector(sa.types.TypeDecorator):
|
||||
|
@ -492,15 +493,6 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
sa.Boolean, default=False, nullable=False, server_default="1"
|
||||
)
|
||||
|
||||
partner_id = sa.Column(sa.BigInteger, unique=False, nullable=True)
|
||||
partner_user_id = sa.Column(sa.String(128), unique=False, nullable=True)
|
||||
|
||||
__table_args__ = (
|
||||
sa.UniqueConstraint(
|
||||
"partner_id", "partner_user_id", name="uq_partner_id_partner_user_id"
|
||||
),
|
||||
)
|
||||
|
||||
# bitwise flags. Allow for future expansion
|
||||
flags = sa.Column(
|
||||
sa.BigInteger,
|
||||
|
@ -613,6 +605,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
if coinbase_subscription and coinbase_subscription.is_active():
|
||||
return True
|
||||
|
||||
partner_sub: PartnerSubscription = PartnerSubscription.find_by_user_id(self.id)
|
||||
if partner_sub and partner_sub.is_active():
|
||||
return True
|
||||
|
||||
return False
|
||||
|
||||
def is_paid(self) -> bool:
|
||||
|
@ -3138,18 +3134,60 @@ class PartnerUser(Base, ModelMixin):
|
|||
|
||||
user_id = sa.Column(
|
||||
sa.ForeignKey("users.id", ondelete="cascade"),
|
||||
unique=False,
|
||||
unique=True,
|
||||
nullable=False,
|
||||
index=True,
|
||||
)
|
||||
partner_id = sa.Column(
|
||||
sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True
|
||||
)
|
||||
external_user_id = sa.Column(sa.String(128), unique=False, nullable=True)
|
||||
partner_email = sa.Column(sa.String(255), unique=False, nullable=True)
|
||||
|
||||
user = orm.relationship(User, foreign_keys=[user_id])
|
||||
|
||||
__table_args__ = (
|
||||
sa.UniqueConstraint("user_id", "partner_id", name="uq_user_id_partner_id"),
|
||||
sa.UniqueConstraint(
|
||||
"partner_id", "external_user_id", name="uq_partner_id_external_user_id"
|
||||
),
|
||||
)
|
||||
|
||||
|
||||
class PartnerSubscription(Base, ModelMixin):
|
||||
"""
|
||||
For users who have a subscription via a partner
|
||||
"""
|
||||
|
||||
__tablename__ = "partner_subscription"
|
||||
|
||||
partner_user_id = sa.Column(
|
||||
sa.ForeignKey(PartnerUser.id, ondelete="cascade"), nullable=False, unique=True
|
||||
)
|
||||
|
||||
# when the partner subscription ends
|
||||
end_at = sa.Column(ArrowType, nullable=False)
|
||||
|
||||
partner_user = orm.relationship(PartnerUser)
|
||||
|
||||
@classmethod
|
||||
def find_by_user_id(cls, user_id: int) -> Optional[PartnerSubscription]:
|
||||
res = (
|
||||
Session.query(PartnerSubscription, PartnerUser)
|
||||
.filter(
|
||||
and_(
|
||||
PartnerUser.user_id == user_id,
|
||||
PartnerSubscription.partner_user_id == PartnerUser.id,
|
||||
)
|
||||
)
|
||||
.first()
|
||||
)
|
||||
if res:
|
||||
subscription, partner_user = res
|
||||
return subscription
|
||||
return None
|
||||
|
||||
def is_active(self):
|
||||
return self.end_at > arrow.now().shift(days=-_PARTNER_SUBSCRIPTION_GRACE_DAYS)
|
||||
|
||||
|
||||
# endregion
|
||||
|
|
|
@ -1,14 +1,17 @@
|
|||
import enum
|
||||
from abc import ABC, abstractmethod
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from flask import url_for
|
||||
from typing import Optional
|
||||
|
||||
from app.db import Session
|
||||
from app.errors import ProtonPartnerNotSetUp
|
||||
from app.models import User, PartnerUser, Partner
|
||||
from app.errors import LinkException, ProtonPartnerNotSetUp
|
||||
from app.models import User, Partner
|
||||
from app.proton.proton_client import ProtonClient, ProtonUser
|
||||
from app.utils import random_string
|
||||
from app.account_linking import (
|
||||
process_login_case,
|
||||
process_link_case,
|
||||
PartnerLinkRequest,
|
||||
)
|
||||
|
||||
PROTON_PARTNER_NAME = "Proton"
|
||||
_PROTON_PARTNER: Optional[Partner] = None
|
||||
|
@ -25,7 +28,7 @@ def get_proton_partner() -> Partner:
|
|||
return _PROTON_PARTNER
|
||||
|
||||
|
||||
class Action(enum.Enum):
|
||||
class Action(Enum):
|
||||
Login = 1
|
||||
Link = 2
|
||||
|
||||
|
@ -39,208 +42,81 @@ class ProtonCallbackResult:
|
|||
user: Optional[User]
|
||||
|
||||
|
||||
def ensure_partner_user_exists(
|
||||
proton_user: ProtonUser, sl_user: User, partner: Partner
|
||||
):
|
||||
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=partner.id):
|
||||
PartnerUser.create(
|
||||
user_id=sl_user.id,
|
||||
partner_id=partner.id,
|
||||
partner_email=proton_user.email,
|
||||
)
|
||||
Session.commit()
|
||||
|
||||
|
||||
class ClientMergeStrategy(ABC):
|
||||
def __init__(
|
||||
self, proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
|
||||
):
|
||||
if self.__class__ == ClientMergeStrategy:
|
||||
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
|
||||
self.proton_user = proton_user
|
||||
self.sl_user = sl_user
|
||||
self.partner = partner
|
||||
|
||||
@abstractmethod
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
pass
|
||||
|
||||
|
||||
class UnexistantSlClientStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
# Will create a new SL User with a random password
|
||||
new_user = User.create(
|
||||
email=self.proton_user.email,
|
||||
name=self.proton_user.name,
|
||||
partner_user_id=self.proton_user.id,
|
||||
partner_id=self.partner.id,
|
||||
password=random_string(20),
|
||||
)
|
||||
PartnerUser.create(
|
||||
user_id=new_user.id,
|
||||
partner_id=self.partner.id,
|
||||
partner_email=self.proton_user.email,
|
||||
)
|
||||
# TODO: Adjust plans
|
||||
Session.commit()
|
||||
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=new_user,
|
||||
)
|
||||
|
||||
|
||||
class ExistingSlClientStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
ensure_partner_user_exists(self.proton_user, self.sl_user, self.partner)
|
||||
# TODO: Adjust plans
|
||||
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=self.sl_user,
|
||||
)
|
||||
|
||||
|
||||
class ExistingSlUserLinkedWithDifferentProtonAccountStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=True,
|
||||
flash_message="This Proton account is already linked to another account",
|
||||
flash_category="error",
|
||||
user=None,
|
||||
redirect=None,
|
||||
)
|
||||
|
||||
|
||||
class AlreadyLinkedUserStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=self.sl_user,
|
||||
)
|
||||
|
||||
|
||||
def get_login_strategy(
|
||||
proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
|
||||
) -> ClientMergeStrategy:
|
||||
if sl_user is None:
|
||||
# We couldn't find any SimpleLogin user with the requested e-mail
|
||||
return UnexistantSlClientStrategy(proton_user, sl_user, partner)
|
||||
# There is a SimpleLogin user with the proton_user's e-mail
|
||||
# Try to find if it has been registered via a partner
|
||||
if sl_user.partner_id is None:
|
||||
# It has not been registered via a Partner
|
||||
return ExistingSlClientStrategy(proton_user, sl_user, partner)
|
||||
# It has been registered via a partner
|
||||
# Check if the partner_user_id matches
|
||||
if sl_user.partner_user_id != proton_user.id:
|
||||
# It doesn't match. That means that the SimpleLogin user has a different Proton account linked
|
||||
return ExistingSlUserLinkedWithDifferentProtonAccountStrategy(
|
||||
proton_user, sl_user, partner
|
||||
)
|
||||
# This case means that the sl_user is already linked, so nothing to do
|
||||
return AlreadyLinkedUserStrategy(proton_user, sl_user, partner)
|
||||
|
||||
|
||||
def process_login_case(
|
||||
proton_user: ProtonUser, partner: Partner
|
||||
) -> ProtonCallbackResult:
|
||||
# Try to find a SimpleLogin user registered with that proton user id
|
||||
sl_user_with_external_id = User.get_by(
|
||||
partner_id=partner.id, partner_user_id=proton_user.id
|
||||
)
|
||||
if sl_user_with_external_id is None:
|
||||
# We didn't find any SimpleLogin user registered with that proton user id
|
||||
# Try to find it using the proton's e-mail address
|
||||
sl_user = User.get_by(email=proton_user.email)
|
||||
return get_login_strategy(proton_user, sl_user, partner).process()
|
||||
else:
|
||||
# We found the SL user registered with that proton user id
|
||||
# We're done
|
||||
return AlreadyLinkedUserStrategy(
|
||||
proton_user, sl_user_with_external_id, partner
|
||||
).process()
|
||||
|
||||
|
||||
def link_user(
|
||||
proton_user: ProtonUser, current_user: User, partner: Partner
|
||||
) -> ProtonCallbackResult:
|
||||
current_user.partner_user_id = proton_user.id
|
||||
current_user.partner_id = partner.id
|
||||
|
||||
ensure_partner_user_exists(proton_user, current_user, partner)
|
||||
|
||||
Session.commit()
|
||||
def generate_account_not_allowed_to_log_in() -> ProtonCallbackResult:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
redirect=url_for("dashboard.setting"),
|
||||
flash_category="success",
|
||||
flash_message="Account successfully linked",
|
||||
user=current_user,
|
||||
redirect_to_login=True,
|
||||
flash_message="This account is not allowed to log in with Proton. Please convert your account to a full Proton account",
|
||||
flash_category="error",
|
||||
redirect=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
|
||||
def process_link_case(
|
||||
proton_user: ProtonUser,
|
||||
current_user: User,
|
||||
partner: Partner,
|
||||
) -> ProtonCallbackResult:
|
||||
# Try to find a SimpleLogin user linked with this Proton account
|
||||
sl_user_linked_to_proton_account = User.get_by(
|
||||
partner_id=partner.id, partner_user_id=proton_user.id
|
||||
)
|
||||
if sl_user_linked_to_proton_account is None:
|
||||
# There is no SL user linked with the proton email. Proceed with linking
|
||||
return link_user(proton_user, current_user, partner)
|
||||
else:
|
||||
# There is a SL user registered with the proton email. Check if is the current one
|
||||
if sl_user_linked_to_proton_account.id == current_user.id:
|
||||
# It's the same user. No need to do anything
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
redirect=url_for("dashboard.setting"),
|
||||
flash_category="success",
|
||||
flash_message="Account successfully linked",
|
||||
user=current_user,
|
||||
)
|
||||
else:
|
||||
# It's a different user. Unlink the other account and link the current one
|
||||
sl_user_linked_to_proton_account.partner_id = None
|
||||
sl_user_linked_to_proton_account.partner_user_id = None
|
||||
other_partner_user = PartnerUser.get_by(
|
||||
user_id=sl_user_linked_to_proton_account.id,
|
||||
partner_id=partner.id,
|
||||
)
|
||||
if other_partner_user is not None:
|
||||
PartnerUser.delete(other_partner_user.id)
|
||||
|
||||
return link_user(proton_user, current_user, partner)
|
||||
|
||||
|
||||
class ProtonCallbackHandler:
|
||||
def __init__(self, proton_client: ProtonClient):
|
||||
self.proton_client = proton_client
|
||||
|
||||
def handle_login(self, partner: Partner) -> ProtonCallbackResult:
|
||||
return process_login_case(self.__get_proton_user(), partner)
|
||||
try:
|
||||
user = self.__get_partner_user()
|
||||
if user is None:
|
||||
return generate_account_not_allowed_to_log_in()
|
||||
res = process_login_case(user, partner)
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=None,
|
||||
flash_category=None,
|
||||
redirect=None,
|
||||
user=res.user,
|
||||
)
|
||||
except LinkException as e:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=True,
|
||||
flash_message=e.message,
|
||||
flash_category="error",
|
||||
redirect=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
def handle_link(
|
||||
self, current_user: Optional[User], partner: Partner
|
||||
) -> ProtonCallbackResult:
|
||||
if current_user is None:
|
||||
raise Exception("Cannot link account with current_user being None")
|
||||
return process_link_case(self.__get_proton_user(), current_user, partner)
|
||||
try:
|
||||
user = self.__get_partner_user()
|
||||
if user is None:
|
||||
return generate_account_not_allowed_to_log_in()
|
||||
res = process_link_case(user, current_user, partner)
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message="Account successfully linked",
|
||||
flash_category="success",
|
||||
redirect=url_for("dashboard.setting"),
|
||||
user=res.user,
|
||||
)
|
||||
except LinkException as e:
|
||||
return ProtonCallbackResult(
|
||||
redirect_to_login=False,
|
||||
flash_message=e.message,
|
||||
flash_category="error",
|
||||
redirect=None,
|
||||
user=None,
|
||||
)
|
||||
|
||||
def __get_proton_user(self) -> ProtonUser:
|
||||
def __get_partner_user(self) -> Optional[PartnerLinkRequest]:
|
||||
proton_user = self.__get_proton_user()
|
||||
if proton_user is None:
|
||||
return None
|
||||
return PartnerLinkRequest(
|
||||
email=proton_user.email,
|
||||
external_user_id=proton_user.id,
|
||||
name=proton_user.name,
|
||||
plan=proton_user.plan,
|
||||
)
|
||||
|
||||
def __get_proton_user(self) -> Optional[ProtonUser]:
|
||||
user = self.proton_client.get_user()
|
||||
plan = self.proton_client.get_plan()
|
||||
return ProtonUser(email=user.email, plan=plan, name=user.name, id=user.id)
|
||||
if user is None:
|
||||
return None
|
||||
return ProtonUser(email=user.email, plan=user.plan, name=user.name, id=user.id)
|
||||
|
|
|
@ -1,78 +1,37 @@
|
|||
import dataclasses
|
||||
from abc import ABC, abstractmethod
|
||||
from enum import Enum
|
||||
from dataclasses import dataclass
|
||||
from http import HTTPStatus
|
||||
from requests import Response, Session
|
||||
from typing import Optional
|
||||
|
||||
from app.account_linking import SLPlan, SLPlanType
|
||||
from app.log import LOG
|
||||
|
||||
_APP_VERSION = "OauthClient_1.0.0"
|
||||
|
||||
PROTON_ERROR_CODE_NOT_EXISTS = 2501
|
||||
|
||||
|
||||
class ProtonPlan(Enum):
|
||||
Free = 0
|
||||
Professional = 1
|
||||
Visionary = 2
|
||||
|
||||
def name(self):
|
||||
if self == self.Free:
|
||||
return "Free"
|
||||
elif self == self.Professional:
|
||||
return "Professional"
|
||||
elif self == self.Visionary:
|
||||
return "Visionary"
|
||||
else:
|
||||
raise Exception("Unknown plan")
|
||||
PLAN_FREE = 1
|
||||
PLAN_PREMIUM = 2
|
||||
|
||||
|
||||
def plan_from_name(name: str) -> ProtonPlan:
|
||||
name_lower = name.lower()
|
||||
if name_lower == "free":
|
||||
return ProtonPlan.Free
|
||||
elif name_lower == "professional":
|
||||
return ProtonPlan.Professional
|
||||
elif name_lower == "visionary":
|
||||
return ProtonPlan.Visionary
|
||||
else:
|
||||
raise Exception(f"Unknown plan [{name}]")
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class UserInformation:
|
||||
email: str
|
||||
name: str
|
||||
id: str
|
||||
plan: SLPlan
|
||||
|
||||
|
||||
class AuthorizeResponse:
|
||||
def __init__(self, code: str, has_accepted: bool):
|
||||
self.code = code
|
||||
self.has_accepted = has_accepted
|
||||
|
||||
def __str__(self):
|
||||
return f"[code={self.code}] [has_accepted={self.has_accepted}]"
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
class SessionResponse:
|
||||
state: str
|
||||
expires_in: int
|
||||
token_type: str
|
||||
refresh_token: str
|
||||
access_token: str
|
||||
session_id: str
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class ProtonUser:
|
||||
id: str
|
||||
name: str
|
||||
email: str
|
||||
plan: ProtonPlan
|
||||
plan: SLPlan
|
||||
|
||||
|
||||
@dataclasses.dataclass
|
||||
@dataclass
|
||||
class AccessCredentials:
|
||||
access_token: str
|
||||
session_id: str
|
||||
|
@ -98,15 +57,7 @@ def convert_access_token(access_token_response: str) -> AccessCredentials:
|
|||
|
||||
class ProtonClient(ABC):
|
||||
@abstractmethod
|
||||
def get_user(self) -> UserInformation:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_organization(self) -> dict:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_plan(self) -> ProtonPlan:
|
||||
def get_user(self) -> Optional[UserInformation]:
|
||||
pass
|
||||
|
||||
|
||||
|
@ -135,31 +86,27 @@ class HttpProtonClient(ProtonClient):
|
|||
client.headers.update(headers)
|
||||
self.client = client
|
||||
|
||||
def get_user(self) -> UserInformation:
|
||||
info = self.__get("/users")["User"]
|
||||
def get_user(self) -> Optional[UserInformation]:
|
||||
info = self.__get("/simple_login/v1/subscription")["Subscription"]
|
||||
if not info["IsAllowed"]:
|
||||
LOG.debug("Account is not allowed to log into SL")
|
||||
return None
|
||||
|
||||
plan_value = info["Plan"]
|
||||
if plan_value == PLAN_FREE:
|
||||
plan = SLPlan(type=SLPlanType.Free, expiration=None)
|
||||
elif plan_value == PLAN_PREMIUM:
|
||||
plan = SLPlan(type=SLPlanType.Premium, expiration=info["PlanExpiration"])
|
||||
else:
|
||||
raise Exception(f"Invalid value for plan: {plan_value}")
|
||||
|
||||
return UserInformation(
|
||||
email=info.get("Email"), name=info.get("Name"), id=info.get("ID")
|
||||
email=info.get("Email"),
|
||||
name=info.get("DisplayName"),
|
||||
id=info.get("UserID"),
|
||||
plan=plan,
|
||||
)
|
||||
|
||||
def get_organization(self) -> dict:
|
||||
return self.__get("/code/v4/organizations")["Organization"]
|
||||
|
||||
def get_plan(self) -> ProtonPlan:
|
||||
url = f"{self.base_url}/core/v4/organizations"
|
||||
res = self.client.get(url)
|
||||
|
||||
status = res.status_code
|
||||
if status == HTTPStatus.UNPROCESSABLE_ENTITY:
|
||||
as_json = res.json()
|
||||
error_code = as_json.get("Code")
|
||||
if error_code == PROTON_ERROR_CODE_NOT_EXISTS:
|
||||
return ProtonPlan.Free
|
||||
|
||||
org = self.__validate_response(res).get("Organization")
|
||||
if org is None:
|
||||
return ProtonPlan.Free
|
||||
return plan_from_name(org["PlanName"])
|
||||
|
||||
def __get(self, route: str) -> dict:
|
||||
url = f"{self.base_url}{route}"
|
||||
res = self.client.get(url)
|
||||
|
@ -172,4 +119,10 @@ class HttpProtonClient(ProtonClient):
|
|||
raise Exception(
|
||||
f"Unexpected status code. Wanted 200 and got {status}: " + res.text
|
||||
)
|
||||
return res.json()
|
||||
as_json = res.json()
|
||||
res_code = as_json.get("Code")
|
||||
if not res_code or res_code != 1000:
|
||||
raise Exception(
|
||||
f"Unexpected response code. Wanted 1000 and got {res_code}: " + res.text
|
||||
)
|
||||
return as_json
|
||||
|
|
|
@ -0,0 +1,54 @@
|
|||
"""partner_user and partner_subscription
|
||||
|
||||
Revision ID: 82d3c7109ffb
|
||||
Revises: 2b1d3cd93e4b
|
||||
Create Date: 2022-06-09 08:25:09.078840
|
||||
|
||||
"""
|
||||
import sqlalchemy_utils
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '82d3c7109ffb'
|
||||
down_revision = '2b1d3cd93e4b'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.create_table('partner_subscription',
|
||||
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
|
||||
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
|
||||
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
|
||||
sa.Column('partner_user_id', sa.Integer(), nullable=False),
|
||||
sa.Column('end_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
|
||||
sa.ForeignKeyConstraint(['partner_user_id'], ['partner_user.id'], ondelete='cascade'),
|
||||
sa.PrimaryKeyConstraint('id'),
|
||||
sa.UniqueConstraint('partner_user_id')
|
||||
)
|
||||
op.add_column('partner_user', sa.Column('external_user_id', sa.String(length=128), nullable=True))
|
||||
op.create_unique_constraint('uq_partner_id_external_user_id', 'partner_user', ['partner_id', 'external_user_id'])
|
||||
op.drop_index('ix_partner_user_user_id', table_name='partner_user')
|
||||
op.create_index(op.f('ix_partner_user_user_id'), 'partner_user', ['user_id'], unique=True)
|
||||
op.drop_constraint('uq_user_id_partner_id', 'partner_user', type_='unique')
|
||||
op.drop_constraint('uq_partner_id_partner_user_id', 'users', type_='unique')
|
||||
op.drop_column('users', 'partner_id')
|
||||
op.drop_column('users', 'partner_user_id')
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('users', sa.Column('partner_user_id', sa.VARCHAR(length=128), autoincrement=False, nullable=True))
|
||||
op.add_column('users', sa.Column('partner_id', sa.BIGINT(), autoincrement=False, nullable=True))
|
||||
op.create_unique_constraint('uq_partner_id_partner_user_id', 'users', ['partner_id', 'partner_user_id'])
|
||||
op.create_unique_constraint('uq_user_id_partner_id', 'partner_user', ['user_id', 'partner_id'])
|
||||
op.drop_index(op.f('ix_partner_user_user_id'), table_name='partner_user')
|
||||
op.create_index('ix_partner_user_user_id', 'partner_user', ['user_id'], unique=False)
|
||||
op.drop_constraint('uq_partner_id_external_user_id', 'partner_user', type_='unique')
|
||||
op.drop_column('partner_user', 'external_user_id')
|
||||
op.drop_table('partner_subscription')
|
||||
# ### end Alembic commands ###
|
|
@ -0,0 +1,37 @@
|
|||
from arrow import Arrow
|
||||
from app.models import Partner, PartnerUser, PartnerSubscription
|
||||
from app.utils import random_string
|
||||
from tests.utils import create_new_user, random_email
|
||||
|
||||
|
||||
def test_generate_partner_subscription(flask_client):
|
||||
partner = Partner.create(
|
||||
name=random_string(10),
|
||||
contact_email=random_email(),
|
||||
commit=True,
|
||||
)
|
||||
user = create_new_user()
|
||||
partner_user = PartnerUser.create(
|
||||
user_id=user.id,
|
||||
partner_id=partner.id,
|
||||
partner_email=random_email(),
|
||||
commit=True,
|
||||
)
|
||||
|
||||
subs = PartnerSubscription.create(
|
||||
partner_user_id=partner_user.id,
|
||||
end_at=Arrow.utcnow().shift(hours=1),
|
||||
commit=True,
|
||||
)
|
||||
|
||||
retrieved_subscription = PartnerSubscription.find_by_user_id(user.id)
|
||||
|
||||
assert retrieved_subscription is not None
|
||||
assert retrieved_subscription.id == subs.id
|
||||
|
||||
assert user.lifetime_or_active_subscription() is True
|
||||
|
||||
|
||||
def test_partner_subscription_for_not_partner_subscription_user(flask_client):
|
||||
unexistant_subscription = PartnerSubscription.find_by_user_id(999999)
|
||||
assert unexistant_subscription is None
|
|
@ -1,95 +1,55 @@
|
|||
import pytest
|
||||
|
||||
from app.db import Session
|
||||
from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan
|
||||
from arrow import Arrow
|
||||
from app.account_linking import (
|
||||
SLPlan,
|
||||
SLPlanType,
|
||||
)
|
||||
from app.proton.proton_client import ProtonClient, UserInformation
|
||||
from app.proton.proton_callback_handler import (
|
||||
ProtonCallbackHandler,
|
||||
get_proton_partner,
|
||||
get_login_strategy,
|
||||
process_link_case,
|
||||
ProtonUser,
|
||||
UnexistantSlClientStrategy,
|
||||
ExistingSlClientStrategy,
|
||||
AlreadyLinkedUserStrategy,
|
||||
ExistingSlUserLinkedWithDifferentProtonAccountStrategy,
|
||||
ClientMergeStrategy,
|
||||
generate_account_not_allowed_to_log_in,
|
||||
)
|
||||
from app.models import User, PartnerUser
|
||||
from app.utils import random_string
|
||||
from typing import Optional
|
||||
from tests.utils import random_email
|
||||
|
||||
|
||||
class MockProtonClient(ProtonClient):
|
||||
def __init__(self, user: UserInformation, plan: ProtonPlan, organization: dict):
|
||||
def __init__(self, user: Optional[UserInformation]):
|
||||
self.user = user
|
||||
self.plan = plan
|
||||
self.organization = organization
|
||||
|
||||
def get_organization(self) -> dict:
|
||||
return self.organization
|
||||
|
||||
def get_user(self) -> UserInformation:
|
||||
def get_user(self) -> Optional[UserInformation]:
|
||||
return self.user
|
||||
|
||||
def get_plan(self) -> ProtonPlan:
|
||||
return self.plan
|
||||
|
||||
|
||||
def random_email() -> str:
|
||||
return "{rand}@{rand}.com".format(rand=random_string(20))
|
||||
|
||||
|
||||
def random_proton_user(
|
||||
user_id: str = None,
|
||||
name: str = None,
|
||||
email: str = None,
|
||||
plan: ProtonPlan = None,
|
||||
) -> ProtonUser:
|
||||
user_id = user_id if user_id is not None else random_string()
|
||||
name = name if name is not None else random_string()
|
||||
email = (
|
||||
email
|
||||
if email is not None
|
||||
else "{rand}@{rand}.com".format(rand=random_string(20))
|
||||
)
|
||||
plan = plan if plan is not None else ProtonPlan.Free
|
||||
return ProtonUser(id=user_id, name=name, email=email, plan=plan)
|
||||
|
||||
|
||||
def create_user(email: str = None) -> User:
|
||||
email = email if email is not None else random_email()
|
||||
user = User.create(email=email)
|
||||
Session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def create_user_for_partner(partner_user_id: str, email: str = None) -> User:
|
||||
email = email if email is not None else random_email()
|
||||
user = User.create(email=email)
|
||||
user.partner_id = get_proton_partner().id
|
||||
user.partner_user_id = partner_user_id
|
||||
|
||||
PartnerUser.create(
|
||||
user_id=user.id, partner_id=get_proton_partner().id, partner_email=email
|
||||
)
|
||||
Session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def test_proton_callback_handler_unexistant_sl_user():
|
||||
email = random_email()
|
||||
name = random_string()
|
||||
external_id = random_string()
|
||||
user = UserInformation(email=email, name=name, id=external_id)
|
||||
mock_client = MockProtonClient(
|
||||
user=user, plan=ProtonPlan.Professional, organization={}
|
||||
handler = ProtonCallbackHandler(
|
||||
MockProtonClient(
|
||||
user=UserInformation(
|
||||
email=email,
|
||||
name=name,
|
||||
id=external_id,
|
||||
plan=SLPlan(
|
||||
type=SLPlanType.Premium, expiration=Arrow.utcnow().shift(hours=2)
|
||||
),
|
||||
)
|
||||
)
|
||||
)
|
||||
handler = ProtonCallbackHandler(mock_client)
|
||||
res = handler.handle_login(get_proton_partner())
|
||||
|
||||
assert res.user is not None
|
||||
assert res.user.email == email
|
||||
assert res.user.name == name
|
||||
assert res.user.partner_user_id == external_id
|
||||
|
||||
partner_user = PartnerUser.get_by(
|
||||
partner_id=get_proton_partner().id, user_id=res.user.id
|
||||
)
|
||||
assert partner_user is not None
|
||||
assert partner_user.external_user_id == external_id
|
||||
|
||||
|
||||
def test_proton_callback_handler_existant_sl_user():
|
||||
|
@ -97,11 +57,13 @@ def test_proton_callback_handler_existant_sl_user():
|
|||
sl_user = User.create(email, commit=True)
|
||||
|
||||
external_id = random_string()
|
||||
user = UserInformation(email=email, name=random_string(), id=external_id)
|
||||
mock_client = MockProtonClient(
|
||||
user=user, plan=ProtonPlan.Professional, organization={}
|
||||
user = UserInformation(
|
||||
email=email,
|
||||
name=random_string(),
|
||||
id=external_id,
|
||||
plan=SLPlan(type=SLPlanType.Premium, expiration=Arrow.utcnow().shift(hours=2)),
|
||||
)
|
||||
handler = ProtonCallbackHandler(mock_client)
|
||||
handler = ProtonCallbackHandler(MockProtonClient(user=user))
|
||||
res = handler.handle_login(get_proton_partner())
|
||||
|
||||
assert res.user is not None
|
||||
|
@ -112,178 +74,18 @@ def test_proton_callback_handler_existant_sl_user():
|
|||
assert sa.partner_email == user.email
|
||||
|
||||
|
||||
def test_get_strategy_unexistant_sl_user():
|
||||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(),
|
||||
sl_user=None,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, UnexistantSlClientStrategy)
|
||||
def test_proton_callback_handler_none_user_login():
|
||||
handler = ProtonCallbackHandler(MockProtonClient(user=None))
|
||||
res = handler.handle_login(get_proton_partner())
|
||||
|
||||
expected = generate_account_not_allowed_to_log_in()
|
||||
assert res == expected
|
||||
|
||||
|
||||
def test_get_strategy_existing_sl_user():
|
||||
email = random_email()
|
||||
sl_user = User.create(email, commit=True)
|
||||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(email=email),
|
||||
sl_user=sl_user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, ExistingSlClientStrategy)
|
||||
def test_proton_callback_handler_none_user_link():
|
||||
sl_user = User.create(random_email(), commit=True)
|
||||
handler = ProtonCallbackHandler(MockProtonClient(user=None))
|
||||
res = handler.handle_link(sl_user, get_proton_partner())
|
||||
|
||||
|
||||
def test_get_strategy_already_linked_user():
|
||||
email = random_email()
|
||||
proton_user_id = random_string()
|
||||
sl_user = create_user_for_partner(proton_user_id, email=email)
|
||||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(user_id=proton_user_id, email=email),
|
||||
sl_user=sl_user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, AlreadyLinkedUserStrategy)
|
||||
|
||||
|
||||
def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
|
||||
# In this scenario we have
|
||||
# - ProtonUser1 (ID1, email1@proton)
|
||||
# - ProtonUser2 (ID2, email2@proton)
|
||||
# - SimpleLoginUser1 registered with email1@proton, but linked to account ID2
|
||||
# We will try to log in with email1@proton
|
||||
email1 = random_email()
|
||||
email2 = random_email()
|
||||
proton_user_id_1 = random_string()
|
||||
proton_user_id_2 = random_string()
|
||||
|
||||
proton_user_1 = random_proton_user(user_id=proton_user_id_1, email=email1)
|
||||
proton_user_2 = random_proton_user(user_id=proton_user_id_2, email=email2)
|
||||
|
||||
sl_user = create_user_for_partner(proton_user_2.id, email=proton_user_1.email)
|
||||
strategy = get_login_strategy(
|
||||
proton_user=proton_user_1,
|
||||
sl_user=sl_user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy)
|
||||
|
||||
|
||||
##
|
||||
# LINK
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_same_address(flask_client):
|
||||
# This is the most basic scenario
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (email1@proton)
|
||||
# - SimpleLoginUser registered with email1@proton
|
||||
# We will try to link both accounts
|
||||
|
||||
email = random_email()
|
||||
proton_user_id = random_string()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=email)
|
||||
sl_user = create_user(email)
|
||||
|
||||
res = process_link_case(proton_user, sl_user, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user = User.get(sl_user.id)
|
||||
assert updated_user.partner_id == get_proton_partner().id
|
||||
assert updated_user.partner_user_id == proton_user_id
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_different_address(flask_client):
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (foo@proton)
|
||||
# - SimpleLoginUser (bar@somethingelse)
|
||||
# We will try to link both accounts
|
||||
proton_user_id = random_string()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
|
||||
sl_user = create_user()
|
||||
|
||||
res = process_link_case(proton_user, sl_user, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user = User.get(sl_user.id)
|
||||
assert updated_user.partner_id == get_proton_partner().id
|
||||
assert updated_user.partner_user_id == proton_user_id
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
|
||||
flask_client,
|
||||
):
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (foo@proton)
|
||||
# - SimpleLoginUser1 (foo@proton)
|
||||
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
|
||||
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
|
||||
proton_user_id = random_string()
|
||||
proton_email = random_email()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=proton_email)
|
||||
sl_user_1 = create_user(proton_email)
|
||||
sl_user_2 = create_user_for_partner(
|
||||
proton_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(proton_user, sl_user_1, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user_1 = User.get(sl_user_1.id)
|
||||
assert updated_user_1.partner_id == get_proton_partner().id
|
||||
assert updated_user_1.partner_user_id == proton_user_id
|
||||
|
||||
updated_user_2 = User.get(sl_user_2.id)
|
||||
assert updated_user_2.partner_id is None
|
||||
assert updated_user_2.partner_user_id is None
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
|
||||
flask_client,
|
||||
):
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (foo@proton)
|
||||
# - SimpleLoginUser1 (bar@somethingelse)
|
||||
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
|
||||
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
|
||||
proton_user_id = random_string()
|
||||
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
|
||||
sl_user_1 = create_user(random_email())
|
||||
sl_user_2 = create_user_for_partner(
|
||||
proton_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(proton_user, sl_user_1, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user_1 = User.get(sl_user_1.id)
|
||||
assert updated_user_1.partner_id == get_proton_partner().id
|
||||
assert updated_user_1.partner_user_id == proton_user_id
|
||||
partner_user_1 = PartnerUser.get_by(
|
||||
user_id=sl_user_1.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
assert partner_user_1 is not None
|
||||
assert partner_user_1.partner_email == proton_user.email
|
||||
|
||||
updated_user_2 = User.get(sl_user_2.id)
|
||||
assert updated_user_2.partner_id is None
|
||||
assert updated_user_2.partner_user_id is None
|
||||
partner_user_2 = PartnerUser.get_by(
|
||||
user_id=sl_user_2.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
assert partner_user_2 is None
|
||||
|
||||
|
||||
def test_cannot_create_instance_of_base_strategy():
|
||||
with pytest.raises(Exception):
|
||||
ClientMergeStrategy(random_proton_user(), None, get_proton_partner())
|
||||
expected = generate_account_not_allowed_to_log_in()
|
||||
assert res == expected
|
||||
|
|
|
@ -0,0 +1,239 @@
|
|||
import pytest
|
||||
from arrow import Arrow
|
||||
|
||||
from app.account_linking import (
|
||||
process_link_case,
|
||||
get_login_strategy,
|
||||
NewUserStrategy,
|
||||
ExistingUnlinedUserStrategy,
|
||||
LinkedWithAnotherPartnerUserStrategy,
|
||||
SLPlan,
|
||||
SLPlanType,
|
||||
PartnerLinkRequest,
|
||||
ClientMergeStrategy,
|
||||
)
|
||||
from app.proton.proton_callback_handler import get_proton_partner
|
||||
from app.db import Session
|
||||
from app.models import PartnerUser, User
|
||||
from app.utils import random_string
|
||||
|
||||
from tests.utils import random_email
|
||||
|
||||
|
||||
def random_link_request(
|
||||
external_user_id: str = None,
|
||||
name: str = None,
|
||||
email: str = None,
|
||||
plan: SLPlan = None,
|
||||
) -> PartnerLinkRequest:
|
||||
external_user_id = (
|
||||
external_user_id if external_user_id is not None else random_string()
|
||||
)
|
||||
name = name if name is not None else random_string()
|
||||
email = email if email is not None else random_email()
|
||||
plan = plan if plan is not None else SLPlanType.Free
|
||||
return PartnerLinkRequest(
|
||||
name=name,
|
||||
email=email,
|
||||
external_user_id=external_user_id,
|
||||
plan=SLPlan(type=plan, expiration=Arrow.utcnow().shift(hours=2)),
|
||||
)
|
||||
|
||||
|
||||
def create_user(email: str = None) -> User:
|
||||
email = email if email is not None else random_email()
|
||||
user = User.create(email=email)
|
||||
Session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def create_user_for_partner(external_user_id: str, email: str = None) -> User:
|
||||
email = email if email is not None else random_email()
|
||||
user = User.create(email=email)
|
||||
|
||||
PartnerUser.create(
|
||||
user_id=user.id,
|
||||
partner_id=get_proton_partner().id,
|
||||
partner_email=email,
|
||||
external_user_id=external_user_id,
|
||||
)
|
||||
Session.commit()
|
||||
return user
|
||||
|
||||
|
||||
def test_get_strategy_unexistant_sl_user():
|
||||
strategy = get_login_strategy(
|
||||
link_request=random_link_request(),
|
||||
user=None,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, NewUserStrategy)
|
||||
|
||||
|
||||
def test_get_strategy_existing_sl_user():
|
||||
email = random_email()
|
||||
user = User.create(email, commit=True)
|
||||
strategy = get_login_strategy(
|
||||
link_request=random_link_request(email=email),
|
||||
user=user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, ExistingUnlinedUserStrategy)
|
||||
|
||||
|
||||
def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
|
||||
# In this scenario we have
|
||||
# - PartnerUser1 (ID1, email1@proton)
|
||||
# - PartnerUser2 (ID2, email2@proton)
|
||||
# - SimpleLoginUser1 registered with email1@proton, but linked to account ID2
|
||||
# We will try to log in with email1@proton
|
||||
email1 = random_email()
|
||||
email2 = random_email()
|
||||
partner_user_id_1 = random_string()
|
||||
partner_user_id_2 = random_string()
|
||||
|
||||
link_request_1 = random_link_request(
|
||||
external_user_id=partner_user_id_1, email=email1
|
||||
)
|
||||
link_request_2 = random_link_request(
|
||||
external_user_id=partner_user_id_2, email=email2
|
||||
)
|
||||
|
||||
user = create_user_for_partner(
|
||||
link_request_2.external_user_id, email=link_request_1.email
|
||||
)
|
||||
strategy = get_login_strategy(
|
||||
link_request=link_request_1,
|
||||
user=user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, LinkedWithAnotherPartnerUserStrategy)
|
||||
|
||||
|
||||
##
|
||||
# LINK
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_same_address(flask_client):
|
||||
# This is the most basic scenario
|
||||
# In this scenario we have:
|
||||
# - PartnerUser (email1@partner)
|
||||
# - SimpleLoginUser registered with email1@proton
|
||||
# We will try to link both accounts
|
||||
|
||||
email = random_email()
|
||||
partner_user_id = random_string()
|
||||
link_request = random_link_request(external_user_id=partner_user_id, email=email)
|
||||
user = create_user(email)
|
||||
|
||||
res = process_link_case(link_request, user, get_proton_partner())
|
||||
assert res is not None
|
||||
assert res.user is not None
|
||||
assert res.user.id == user.id
|
||||
assert res.user.email == email
|
||||
assert res.strategy == "Link"
|
||||
|
||||
partner_user = PartnerUser.get_by(
|
||||
partner_id=get_proton_partner().id, user_id=user.id
|
||||
)
|
||||
assert partner_user.partner_id == get_proton_partner().id
|
||||
assert partner_user.external_user_id == partner_user_id
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_different_address(flask_client):
|
||||
# In this scenario we have:
|
||||
# - ProtonUser (foo@proton)
|
||||
# - SimpleLoginUser (bar@somethingelse)
|
||||
# We will try to link both accounts
|
||||
partner_user_id = random_string()
|
||||
link_request = random_link_request(
|
||||
external_user_id=partner_user_id, email=random_email()
|
||||
)
|
||||
user = create_user()
|
||||
|
||||
res = process_link_case(link_request, user, get_proton_partner())
|
||||
assert res.user.id == user.id
|
||||
assert res.user.email == user.email
|
||||
assert res.strategy == "Link"
|
||||
|
||||
partner_user = PartnerUser.get_by(
|
||||
partner_id=get_proton_partner().id, user_id=user.id
|
||||
)
|
||||
assert partner_user.partner_id == get_proton_partner().id
|
||||
assert partner_user.external_user_id == partner_user_id
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
|
||||
flask_client,
|
||||
):
|
||||
# In this scenario we have:
|
||||
# - PartnerUser (foo@partner)
|
||||
# - SimpleLoginUser1 (foo@partner)
|
||||
# - SimpleLoginUser2 (other@somethingelse) linked with foo@partner
|
||||
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@partner
|
||||
partner_user_id = random_string()
|
||||
partner_email = random_email()
|
||||
link_request = random_link_request(
|
||||
external_user_id=partner_user_id, email=partner_email
|
||||
)
|
||||
sl_user_1 = create_user(partner_email)
|
||||
sl_user_2 = create_user_for_partner(
|
||||
partner_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(link_request, sl_user_1, get_proton_partner())
|
||||
assert res.user.id == sl_user_1.id
|
||||
assert res.user.email == partner_email
|
||||
assert res.strategy == "Link"
|
||||
|
||||
partner_user = PartnerUser.get_by(
|
||||
partner_id=get_proton_partner().id, user_id=sl_user_1.id
|
||||
)
|
||||
assert partner_user.partner_id == get_proton_partner().id
|
||||
assert partner_user.external_user_id == partner_user_id
|
||||
|
||||
partner_user = PartnerUser.get_by(
|
||||
partner_id=get_proton_partner().id, user_id=sl_user_2.id
|
||||
)
|
||||
assert partner_user is None
|
||||
|
||||
|
||||
def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
|
||||
flask_client,
|
||||
):
|
||||
# In this scenario we have:
|
||||
# - PartnerUser (foo@partner)
|
||||
# - SimpleLoginUser1 (bar@somethingelse)
|
||||
# - SimpleLoginUser2 (other@somethingelse) linked with foo@partner
|
||||
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@partner
|
||||
partner_user_id = random_string()
|
||||
link_request = random_link_request(
|
||||
external_user_id=partner_user_id, email=random_email()
|
||||
)
|
||||
sl_user_1 = create_user(random_email())
|
||||
sl_user_2 = create_user_for_partner(
|
||||
partner_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(link_request, sl_user_1, get_proton_partner())
|
||||
assert res.user.id == sl_user_1.id
|
||||
assert res.user.email == sl_user_1.email
|
||||
assert res.strategy == "Link"
|
||||
|
||||
partner_user_1 = PartnerUser.get_by(
|
||||
user_id=sl_user_1.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
assert partner_user_1 is not None
|
||||
assert partner_user_1.partner_email == sl_user_2.email
|
||||
assert partner_user_1.partner_id == get_proton_partner().id
|
||||
assert partner_user_1.external_user_id == partner_user_id
|
||||
|
||||
partner_user_2 = PartnerUser.get_by(
|
||||
user_id=sl_user_2.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
assert partner_user_2 is None
|
||||
|
||||
|
||||
def test_cannot_create_instance_of_base_strategy():
|
||||
with pytest.raises(Exception):
|
||||
ClientMergeStrategy(random_link_request(), None, get_proton_partner())
|
|
@ -10,6 +10,7 @@ import jinja2
|
|||
from flask import url_for
|
||||
|
||||
from app.models import User
|
||||
from app.utils import random_string
|
||||
|
||||
|
||||
def create_new_user() -> User:
|
||||
|
@ -66,3 +67,7 @@ def load_eml_file(
|
|||
template_values = {}
|
||||
rendered = template.render(**template_values)
|
||||
return email.message_from_string(rendered)
|
||||
|
||||
|
||||
def random_email() -> str:
|
||||
return "{rand}@{rand}.com".format(rand=random_string(20))
|
||||
|
|
Loading…
Reference in New Issue