Separate code for proton callback handler (#1040)

* Separate code for proton callback handler

* Upgrade migration

* Use simple_login endpoint from Proton API

* Remove unused classes

* Rename Dto class to Data

* Push rename

* Moved link to PartnerUser to allow subscriptions to depend only on it

* Fix test

* PR comments

* Add unique user_id constraint to PartnerUser

* Added more logs

* Added more logs

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Carlos Quintana 2022-06-09 10:19:49 +02:00 committed by GitHub
parent faf67ff338
commit c0a4c44e94
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
12 changed files with 810 additions and 536 deletions

257
app/account_linking.py Normal file
View File

@ -0,0 +1,257 @@
from abc import ABC, abstractmethod
from arrow import Arrow
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from app.db import Session
from app.errors import AccountAlreadyLinkedToAnotherPartnerException
from app.log import LOG
from app.models import PartnerSubscription, Partner, PartnerUser, User
from app.utils import random_string
class SLPlanType(Enum):
Free = 1
Premium = 2
@dataclass
class SLPlan:
type: SLPlanType
expiration: Optional[Arrow]
@dataclass
class PartnerLinkRequest:
name: str
email: str
external_user_id: str
plan: SLPlan
@dataclass
class LinkResult:
user: User
strategy: str
def set_plan_for_partner_user(partner_user: PartnerUser, plan: SLPlan):
sub = PartnerSubscription.get_by(partner_user_id=partner_user.id)
if plan.type == SLPlanType.Free:
if sub is not None:
LOG.i(
f"Deleting partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
)
PartnerSubscription.delete(sub.id)
else:
if sub is None:
LOG.i(
f"Creating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
)
PartnerSubscription.create(
partner_user_id=partner_user.id,
end_at=plan.expiration,
)
else:
LOG.i(
f"Updating partner_subscription [user_id={partner_user.user_id}] [partner_id={partner_user.partner_id}]"
)
sub.end_at = plan.expiration
Session.commit()
def set_plan_for_user(user: User, plan: SLPlan, partner: Partner):
partner_user = PartnerUser.get_by(partner_id=partner.id, user_id=user.id)
if partner_user is None:
return
return set_plan_for_partner_user(partner_user, plan)
def ensure_partner_user_exists_for_user(
link_request: PartnerLinkRequest, sl_user: User, partner: Partner
) -> PartnerUser:
# Find partner_user by user_id
res = PartnerUser.get_by(user_id=sl_user.id, partner_id=partner.id)
if not res:
res = PartnerUser.create(
user_id=sl_user.id,
partner_id=partner.id,
partner_email=link_request.email,
external_user_id=link_request.external_user_id,
)
Session.commit()
LOG.i(
f"Created new partner_user for partner:{partner.id} user:{sl_user.id} external_user_id:{link_request.external_user_id}. PartnerUser.id is {res.id}"
)
return res
class ClientMergeStrategy(ABC):
def __init__(
self,
link_request: PartnerLinkRequest,
user: Optional[User],
partner: Partner,
):
if self.__class__ == ClientMergeStrategy:
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
self.link_request = link_request
self.user = user
self.partner = partner
@abstractmethod
def process(self) -> LinkResult:
pass
class NewUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult:
# Will create a new SL User with a random password
new_user = User.create(
email=self.link_request.email,
name=self.link_request.name,
password=random_string(20),
)
partner_user = PartnerUser.create(
user_id=new_user.id,
partner_id=self.partner.id,
external_user_id=self.link_request.external_user_id,
partner_email=self.link_request.email,
)
LOG.i(
f"Created new user for login request for partner:{self.partner.id} external_user_id:{self.link_request.external_user_id}. New user {new_user.id} partner_user:{partner_user.id}"
)
set_plan_for_partner_user(
partner_user,
self.link_request.plan,
)
Session.commit()
return LinkResult(
user=new_user,
strategy=self.__class__.__name__,
)
class ExistingUnlinedUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult:
partner_user = ensure_partner_user_exists_for_user(
self.link_request, self.user, self.partner
)
set_plan_for_partner_user(partner_user, self.link_request.plan)
return LinkResult(
user=self.user,
strategy=self.__class__.__name__,
)
class LinkedWithAnotherPartnerUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult:
raise AccountAlreadyLinkedToAnotherPartnerException()
def get_login_strategy(
link_request: PartnerLinkRequest, user: Optional[User], partner: Partner
) -> ClientMergeStrategy:
if user is None:
# We couldn't find any SimpleLogin user with the requested e-mail
return NewUserStrategy(link_request, user, partner)
# Check if user is already linked with another partner_user
other_partner_user = PartnerUser.get_by(partner_id=partner.id, user_id=user.id)
if other_partner_user is not None:
return LinkedWithAnotherPartnerUserStrategy(link_request, user, partner)
# There is a SimpleLogin user with the partner_user's e-mail
return ExistingUnlinedUserStrategy(link_request, user, partner)
def process_login_case(
link_request: PartnerLinkRequest, partner: Partner
) -> LinkResult:
# Try to find a SimpleLogin user registered with that partner user id
partner_user = PartnerUser.get_by(
partner_id=partner.id, external_user_id=link_request.external_user_id
)
if partner_user is None:
# We didn't find any SimpleLogin user registered with that partner user id
# Try to find it using the partner's e-mail address
user = User.get_by(email=link_request.email)
return get_login_strategy(link_request, user, partner).process()
else:
# We found the SL user registered with that partner user id
# We're done
set_plan_for_partner_user(partner_user, link_request.plan)
# It's the same user. No need to do anything
return LinkResult(
user=partner_user.user,
strategy="Link",
)
def link_user(
link_request: PartnerLinkRequest, current_user: User, partner: Partner
) -> LinkResult:
partner_user = ensure_partner_user_exists_for_user(
link_request, current_user, partner
)
set_plan_for_partner_user(partner_user, link_request.plan)
Session.commit()
return LinkResult(
user=current_user,
strategy="Link",
)
def switch_already_linked_user(
link_request: PartnerLinkRequest, partner_user: PartnerUser, current_user: User
):
# Find if the user has another link and unlink it
other_partner_user = PartnerUser.get_by(
user_id=current_user.id,
partner_id=partner_user.partner_id,
)
if other_partner_user is not None:
LOG.i(
f"Deleting previous partner_user:{other_partner_user.id} from user:{current_user.id}"
)
PartnerUser.delete(other_partner_user.id)
LOG.i(f"Linking partner_user:{partner_user.id} to user:{current_user.id}")
# Link this partner_user to the current user
partner_user.user_id = current_user.id
# Set plan
set_plan_for_partner_user(partner_user, link_request.plan)
Session.commit()
return LinkResult(
user=current_user,
strategy="Link",
)
def process_link_case(
link_request: PartnerLinkRequest,
current_user: User,
partner: Partner,
) -> LinkResult:
# Try to find a SimpleLogin user linked with this Partner account
partner_user = PartnerUser.get_by(
partner_id=partner.id, external_user_id=link_request.external_user_id
)
if partner_user is None:
# There is no SL user linked with the partner. Proceed with linking
return link_user(link_request, current_user, partner)
# There is a SL user registered with the partner. Check if is the current one
if partner_user.id == current_user.id:
# Update plan
set_plan_for_partner_user(partner_user, link_request.plan)
# It's the same user. No need to do anything
return LinkResult(
user=current_user,
strategy="Link",
)
else:
return switch_already_linked_user(link_request, partner_user, current_user)

View File

@ -120,7 +120,7 @@ def proton_callback():
return redirect(url_for("auth.login"))
if res.redirect:
return redirect(res.redirect)
return after_login(res.user, res.redirect)
next_url = session.get("oauth_next")
return after_login(res.user, next_url)

View File

@ -446,8 +446,6 @@ def cancel_email_change():
@dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"])
@login_required
def unlink_proton_account():
current_user.partner_id = None
current_user.partner_user_id = None
partner_user = PartnerUser.get_by(
user_id=current_user.id, partner_id=get_proton_partner().id
)

View File

@ -93,3 +93,18 @@ class ErrContactAlreadyExists(SLException):
def error_for_user(self) -> str:
return f"{self.contact.website_email} is already added"
class LinkException(SLException):
def __init__(self, message: str):
self.message = message
class AccountAlreadyLinkedToAnotherPartnerException(LinkException):
def __init__(self):
super().__init__("This account is already linked to another partner")
class AccountAlreadyLinkedToAnotherUserException(LinkException):
def __init__(self):
super().__init__("This account is linked to another user")

View File

@ -66,6 +66,7 @@ from app.utils import (
Base = declarative_base()
PADDLE_SUBSCRIPTION_GRACE_DAYS = 14
_PARTNER_SUBSCRIPTION_GRACE_DAYS = 14
class TSVector(sa.types.TypeDecorator):
@ -492,15 +493,6 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
sa.Boolean, default=False, nullable=False, server_default="1"
)
partner_id = sa.Column(sa.BigInteger, unique=False, nullable=True)
partner_user_id = sa.Column(sa.String(128), unique=False, nullable=True)
__table_args__ = (
sa.UniqueConstraint(
"partner_id", "partner_user_id", name="uq_partner_id_partner_user_id"
),
)
# bitwise flags. Allow for future expansion
flags = sa.Column(
sa.BigInteger,
@ -613,6 +605,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
if coinbase_subscription and coinbase_subscription.is_active():
return True
partner_sub: PartnerSubscription = PartnerSubscription.find_by_user_id(self.id)
if partner_sub and partner_sub.is_active():
return True
return False
def is_paid(self) -> bool:
@ -3138,18 +3134,60 @@ class PartnerUser(Base, ModelMixin):
user_id = sa.Column(
sa.ForeignKey("users.id", ondelete="cascade"),
unique=False,
unique=True,
nullable=False,
index=True,
)
partner_id = sa.Column(
sa.ForeignKey("partner.id", ondelete="cascade"), nullable=False, index=True
)
external_user_id = sa.Column(sa.String(128), unique=False, nullable=True)
partner_email = sa.Column(sa.String(255), unique=False, nullable=True)
user = orm.relationship(User, foreign_keys=[user_id])
__table_args__ = (
sa.UniqueConstraint("user_id", "partner_id", name="uq_user_id_partner_id"),
sa.UniqueConstraint(
"partner_id", "external_user_id", name="uq_partner_id_external_user_id"
),
)
class PartnerSubscription(Base, ModelMixin):
"""
For users who have a subscription via a partner
"""
__tablename__ = "partner_subscription"
partner_user_id = sa.Column(
sa.ForeignKey(PartnerUser.id, ondelete="cascade"), nullable=False, unique=True
)
# when the partner subscription ends
end_at = sa.Column(ArrowType, nullable=False)
partner_user = orm.relationship(PartnerUser)
@classmethod
def find_by_user_id(cls, user_id: int) -> Optional[PartnerSubscription]:
res = (
Session.query(PartnerSubscription, PartnerUser)
.filter(
and_(
PartnerUser.user_id == user_id,
PartnerSubscription.partner_user_id == PartnerUser.id,
)
)
.first()
)
if res:
subscription, partner_user = res
return subscription
return None
def is_active(self):
return self.end_at > arrow.now().shift(days=-_PARTNER_SUBSCRIPTION_GRACE_DAYS)
# endregion

View File

@ -1,14 +1,17 @@
import enum
from abc import ABC, abstractmethod
from dataclasses import dataclass
from enum import Enum
from flask import url_for
from typing import Optional
from app.db import Session
from app.errors import ProtonPartnerNotSetUp
from app.models import User, PartnerUser, Partner
from app.errors import LinkException, ProtonPartnerNotSetUp
from app.models import User, Partner
from app.proton.proton_client import ProtonClient, ProtonUser
from app.utils import random_string
from app.account_linking import (
process_login_case,
process_link_case,
PartnerLinkRequest,
)
PROTON_PARTNER_NAME = "Proton"
_PROTON_PARTNER: Optional[Partner] = None
@ -25,7 +28,7 @@ def get_proton_partner() -> Partner:
return _PROTON_PARTNER
class Action(enum.Enum):
class Action(Enum):
Login = 1
Link = 2
@ -39,208 +42,81 @@ class ProtonCallbackResult:
user: Optional[User]
def ensure_partner_user_exists(
proton_user: ProtonUser, sl_user: User, partner: Partner
):
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=partner.id):
PartnerUser.create(
user_id=sl_user.id,
partner_id=partner.id,
partner_email=proton_user.email,
)
Session.commit()
class ClientMergeStrategy(ABC):
def __init__(
self, proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
):
if self.__class__ == ClientMergeStrategy:
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
self.proton_user = proton_user
self.sl_user = sl_user
self.partner = partner
@abstractmethod
def process(self) -> ProtonCallbackResult:
pass
class UnexistantSlClientStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
# Will create a new SL User with a random password
new_user = User.create(
email=self.proton_user.email,
name=self.proton_user.name,
partner_user_id=self.proton_user.id,
partner_id=self.partner.id,
password=random_string(20),
)
PartnerUser.create(
user_id=new_user.id,
partner_id=self.partner.id,
partner_email=self.proton_user.email,
)
# TODO: Adjust plans
Session.commit()
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
flash_category=None,
redirect=None,
user=new_user,
)
class ExistingSlClientStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
ensure_partner_user_exists(self.proton_user, self.sl_user, self.partner)
# TODO: Adjust plans
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
flash_category=None,
redirect=None,
user=self.sl_user,
)
class ExistingSlUserLinkedWithDifferentProtonAccountStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
return ProtonCallbackResult(
redirect_to_login=True,
flash_message="This Proton account is already linked to another account",
flash_category="error",
user=None,
redirect=None,
)
class AlreadyLinkedUserStrategy(ClientMergeStrategy):
def process(self) -> ProtonCallbackResult:
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
flash_category=None,
redirect=None,
user=self.sl_user,
)
def get_login_strategy(
proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
) -> ClientMergeStrategy:
if sl_user is None:
# We couldn't find any SimpleLogin user with the requested e-mail
return UnexistantSlClientStrategy(proton_user, sl_user, partner)
# There is a SimpleLogin user with the proton_user's e-mail
# Try to find if it has been registered via a partner
if sl_user.partner_id is None:
# It has not been registered via a Partner
return ExistingSlClientStrategy(proton_user, sl_user, partner)
# It has been registered via a partner
# Check if the partner_user_id matches
if sl_user.partner_user_id != proton_user.id:
# It doesn't match. That means that the SimpleLogin user has a different Proton account linked
return ExistingSlUserLinkedWithDifferentProtonAccountStrategy(
proton_user, sl_user, partner
)
# This case means that the sl_user is already linked, so nothing to do
return AlreadyLinkedUserStrategy(proton_user, sl_user, partner)
def process_login_case(
proton_user: ProtonUser, partner: Partner
) -> ProtonCallbackResult:
# Try to find a SimpleLogin user registered with that proton user id
sl_user_with_external_id = User.get_by(
partner_id=partner.id, partner_user_id=proton_user.id
)
if sl_user_with_external_id is None:
# We didn't find any SimpleLogin user registered with that proton user id
# Try to find it using the proton's e-mail address
sl_user = User.get_by(email=proton_user.email)
return get_login_strategy(proton_user, sl_user, partner).process()
else:
# We found the SL user registered with that proton user id
# We're done
return AlreadyLinkedUserStrategy(
proton_user, sl_user_with_external_id, partner
).process()
def link_user(
proton_user: ProtonUser, current_user: User, partner: Partner
) -> ProtonCallbackResult:
current_user.partner_user_id = proton_user.id
current_user.partner_id = partner.id
ensure_partner_user_exists(proton_user, current_user, partner)
Session.commit()
def generate_account_not_allowed_to_log_in() -> ProtonCallbackResult:
return ProtonCallbackResult(
redirect_to_login=False,
redirect=url_for("dashboard.setting"),
flash_category="success",
flash_message="Account successfully linked",
user=current_user,
redirect_to_login=True,
flash_message="This account is not allowed to log in with Proton. Please convert your account to a full Proton account",
flash_category="error",
redirect=None,
user=None,
)
def process_link_case(
proton_user: ProtonUser,
current_user: User,
partner: Partner,
) -> ProtonCallbackResult:
# Try to find a SimpleLogin user linked with this Proton account
sl_user_linked_to_proton_account = User.get_by(
partner_id=partner.id, partner_user_id=proton_user.id
)
if sl_user_linked_to_proton_account is None:
# There is no SL user linked with the proton email. Proceed with linking
return link_user(proton_user, current_user, partner)
else:
# There is a SL user registered with the proton email. Check if is the current one
if sl_user_linked_to_proton_account.id == current_user.id:
# It's the same user. No need to do anything
return ProtonCallbackResult(
redirect_to_login=False,
redirect=url_for("dashboard.setting"),
flash_category="success",
flash_message="Account successfully linked",
user=current_user,
)
else:
# It's a different user. Unlink the other account and link the current one
sl_user_linked_to_proton_account.partner_id = None
sl_user_linked_to_proton_account.partner_user_id = None
other_partner_user = PartnerUser.get_by(
user_id=sl_user_linked_to_proton_account.id,
partner_id=partner.id,
)
if other_partner_user is not None:
PartnerUser.delete(other_partner_user.id)
return link_user(proton_user, current_user, partner)
class ProtonCallbackHandler:
def __init__(self, proton_client: ProtonClient):
self.proton_client = proton_client
def handle_login(self, partner: Partner) -> ProtonCallbackResult:
return process_login_case(self.__get_proton_user(), partner)
try:
user = self.__get_partner_user()
if user is None:
return generate_account_not_allowed_to_log_in()
res = process_login_case(user, partner)
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=None,
flash_category=None,
redirect=None,
user=res.user,
)
except LinkException as e:
return ProtonCallbackResult(
redirect_to_login=True,
flash_message=e.message,
flash_category="error",
redirect=None,
user=None,
)
def handle_link(
self, current_user: Optional[User], partner: Partner
) -> ProtonCallbackResult:
if current_user is None:
raise Exception("Cannot link account with current_user being None")
return process_link_case(self.__get_proton_user(), current_user, partner)
try:
user = self.__get_partner_user()
if user is None:
return generate_account_not_allowed_to_log_in()
res = process_link_case(user, current_user, partner)
return ProtonCallbackResult(
redirect_to_login=False,
flash_message="Account successfully linked",
flash_category="success",
redirect=url_for("dashboard.setting"),
user=res.user,
)
except LinkException as e:
return ProtonCallbackResult(
redirect_to_login=False,
flash_message=e.message,
flash_category="error",
redirect=None,
user=None,
)
def __get_proton_user(self) -> ProtonUser:
def __get_partner_user(self) -> Optional[PartnerLinkRequest]:
proton_user = self.__get_proton_user()
if proton_user is None:
return None
return PartnerLinkRequest(
email=proton_user.email,
external_user_id=proton_user.id,
name=proton_user.name,
plan=proton_user.plan,
)
def __get_proton_user(self) -> Optional[ProtonUser]:
user = self.proton_client.get_user()
plan = self.proton_client.get_plan()
return ProtonUser(email=user.email, plan=plan, name=user.name, id=user.id)
if user is None:
return None
return ProtonUser(email=user.email, plan=user.plan, name=user.name, id=user.id)

View File

@ -1,78 +1,37 @@
import dataclasses
from abc import ABC, abstractmethod
from enum import Enum
from dataclasses import dataclass
from http import HTTPStatus
from requests import Response, Session
from typing import Optional
from app.account_linking import SLPlan, SLPlanType
from app.log import LOG
_APP_VERSION = "OauthClient_1.0.0"
PROTON_ERROR_CODE_NOT_EXISTS = 2501
class ProtonPlan(Enum):
Free = 0
Professional = 1
Visionary = 2
def name(self):
if self == self.Free:
return "Free"
elif self == self.Professional:
return "Professional"
elif self == self.Visionary:
return "Visionary"
else:
raise Exception("Unknown plan")
PLAN_FREE = 1
PLAN_PREMIUM = 2
def plan_from_name(name: str) -> ProtonPlan:
name_lower = name.lower()
if name_lower == "free":
return ProtonPlan.Free
elif name_lower == "professional":
return ProtonPlan.Professional
elif name_lower == "visionary":
return ProtonPlan.Visionary
else:
raise Exception(f"Unknown plan [{name}]")
@dataclasses.dataclass
@dataclass
class UserInformation:
email: str
name: str
id: str
plan: SLPlan
class AuthorizeResponse:
def __init__(self, code: str, has_accepted: bool):
self.code = code
self.has_accepted = has_accepted
def __str__(self):
return f"[code={self.code}] [has_accepted={self.has_accepted}]"
@dataclasses.dataclass
class SessionResponse:
state: str
expires_in: int
token_type: str
refresh_token: str
access_token: str
session_id: str
@dataclasses.dataclass
@dataclass
class ProtonUser:
id: str
name: str
email: str
plan: ProtonPlan
plan: SLPlan
@dataclasses.dataclass
@dataclass
class AccessCredentials:
access_token: str
session_id: str
@ -98,15 +57,7 @@ def convert_access_token(access_token_response: str) -> AccessCredentials:
class ProtonClient(ABC):
@abstractmethod
def get_user(self) -> UserInformation:
pass
@abstractmethod
def get_organization(self) -> dict:
pass
@abstractmethod
def get_plan(self) -> ProtonPlan:
def get_user(self) -> Optional[UserInformation]:
pass
@ -135,31 +86,27 @@ class HttpProtonClient(ProtonClient):
client.headers.update(headers)
self.client = client
def get_user(self) -> UserInformation:
info = self.__get("/users")["User"]
def get_user(self) -> Optional[UserInformation]:
info = self.__get("/simple_login/v1/subscription")["Subscription"]
if not info["IsAllowed"]:
LOG.debug("Account is not allowed to log into SL")
return None
plan_value = info["Plan"]
if plan_value == PLAN_FREE:
plan = SLPlan(type=SLPlanType.Free, expiration=None)
elif plan_value == PLAN_PREMIUM:
plan = SLPlan(type=SLPlanType.Premium, expiration=info["PlanExpiration"])
else:
raise Exception(f"Invalid value for plan: {plan_value}")
return UserInformation(
email=info.get("Email"), name=info.get("Name"), id=info.get("ID")
email=info.get("Email"),
name=info.get("DisplayName"),
id=info.get("UserID"),
plan=plan,
)
def get_organization(self) -> dict:
return self.__get("/code/v4/organizations")["Organization"]
def get_plan(self) -> ProtonPlan:
url = f"{self.base_url}/core/v4/organizations"
res = self.client.get(url)
status = res.status_code
if status == HTTPStatus.UNPROCESSABLE_ENTITY:
as_json = res.json()
error_code = as_json.get("Code")
if error_code == PROTON_ERROR_CODE_NOT_EXISTS:
return ProtonPlan.Free
org = self.__validate_response(res).get("Organization")
if org is None:
return ProtonPlan.Free
return plan_from_name(org["PlanName"])
def __get(self, route: str) -> dict:
url = f"{self.base_url}{route}"
res = self.client.get(url)
@ -172,4 +119,10 @@ class HttpProtonClient(ProtonClient):
raise Exception(
f"Unexpected status code. Wanted 200 and got {status}: " + res.text
)
return res.json()
as_json = res.json()
res_code = as_json.get("Code")
if not res_code or res_code != 1000:
raise Exception(
f"Unexpected response code. Wanted 1000 and got {res_code}: " + res.text
)
return as_json

View File

@ -0,0 +1,54 @@
"""partner_user and partner_subscription
Revision ID: 82d3c7109ffb
Revises: 2b1d3cd93e4b
Create Date: 2022-06-09 08:25:09.078840
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '82d3c7109ffb'
down_revision = '2b1d3cd93e4b'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('partner_subscription',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('partner_user_id', sa.Integer(), nullable=False),
sa.Column('end_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.ForeignKeyConstraint(['partner_user_id'], ['partner_user.id'], ondelete='cascade'),
sa.PrimaryKeyConstraint('id'),
sa.UniqueConstraint('partner_user_id')
)
op.add_column('partner_user', sa.Column('external_user_id', sa.String(length=128), nullable=True))
op.create_unique_constraint('uq_partner_id_external_user_id', 'partner_user', ['partner_id', 'external_user_id'])
op.drop_index('ix_partner_user_user_id', table_name='partner_user')
op.create_index(op.f('ix_partner_user_user_id'), 'partner_user', ['user_id'], unique=True)
op.drop_constraint('uq_user_id_partner_id', 'partner_user', type_='unique')
op.drop_constraint('uq_partner_id_partner_user_id', 'users', type_='unique')
op.drop_column('users', 'partner_id')
op.drop_column('users', 'partner_user_id')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('users', sa.Column('partner_user_id', sa.VARCHAR(length=128), autoincrement=False, nullable=True))
op.add_column('users', sa.Column('partner_id', sa.BIGINT(), autoincrement=False, nullable=True))
op.create_unique_constraint('uq_partner_id_partner_user_id', 'users', ['partner_id', 'partner_user_id'])
op.create_unique_constraint('uq_user_id_partner_id', 'partner_user', ['user_id', 'partner_id'])
op.drop_index(op.f('ix_partner_user_user_id'), table_name='partner_user')
op.create_index('ix_partner_user_user_id', 'partner_user', ['user_id'], unique=False)
op.drop_constraint('uq_partner_id_external_user_id', 'partner_user', type_='unique')
op.drop_column('partner_user', 'external_user_id')
op.drop_table('partner_subscription')
# ### end Alembic commands ###

View File

@ -0,0 +1,37 @@
from arrow import Arrow
from app.models import Partner, PartnerUser, PartnerSubscription
from app.utils import random_string
from tests.utils import create_new_user, random_email
def test_generate_partner_subscription(flask_client):
partner = Partner.create(
name=random_string(10),
contact_email=random_email(),
commit=True,
)
user = create_new_user()
partner_user = PartnerUser.create(
user_id=user.id,
partner_id=partner.id,
partner_email=random_email(),
commit=True,
)
subs = PartnerSubscription.create(
partner_user_id=partner_user.id,
end_at=Arrow.utcnow().shift(hours=1),
commit=True,
)
retrieved_subscription = PartnerSubscription.find_by_user_id(user.id)
assert retrieved_subscription is not None
assert retrieved_subscription.id == subs.id
assert user.lifetime_or_active_subscription() is True
def test_partner_subscription_for_not_partner_subscription_user(flask_client):
unexistant_subscription = PartnerSubscription.find_by_user_id(999999)
assert unexistant_subscription is None

View File

@ -1,95 +1,55 @@
import pytest
from app.db import Session
from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan
from arrow import Arrow
from app.account_linking import (
SLPlan,
SLPlanType,
)
from app.proton.proton_client import ProtonClient, UserInformation
from app.proton.proton_callback_handler import (
ProtonCallbackHandler,
get_proton_partner,
get_login_strategy,
process_link_case,
ProtonUser,
UnexistantSlClientStrategy,
ExistingSlClientStrategy,
AlreadyLinkedUserStrategy,
ExistingSlUserLinkedWithDifferentProtonAccountStrategy,
ClientMergeStrategy,
generate_account_not_allowed_to_log_in,
)
from app.models import User, PartnerUser
from app.utils import random_string
from typing import Optional
from tests.utils import random_email
class MockProtonClient(ProtonClient):
def __init__(self, user: UserInformation, plan: ProtonPlan, organization: dict):
def __init__(self, user: Optional[UserInformation]):
self.user = user
self.plan = plan
self.organization = organization
def get_organization(self) -> dict:
return self.organization
def get_user(self) -> UserInformation:
def get_user(self) -> Optional[UserInformation]:
return self.user
def get_plan(self) -> ProtonPlan:
return self.plan
def random_email() -> str:
return "{rand}@{rand}.com".format(rand=random_string(20))
def random_proton_user(
user_id: str = None,
name: str = None,
email: str = None,
plan: ProtonPlan = None,
) -> ProtonUser:
user_id = user_id if user_id is not None else random_string()
name = name if name is not None else random_string()
email = (
email
if email is not None
else "{rand}@{rand}.com".format(rand=random_string(20))
)
plan = plan if plan is not None else ProtonPlan.Free
return ProtonUser(id=user_id, name=name, email=email, plan=plan)
def create_user(email: str = None) -> User:
email = email if email is not None else random_email()
user = User.create(email=email)
Session.commit()
return user
def create_user_for_partner(partner_user_id: str, email: str = None) -> User:
email = email if email is not None else random_email()
user = User.create(email=email)
user.partner_id = get_proton_partner().id
user.partner_user_id = partner_user_id
PartnerUser.create(
user_id=user.id, partner_id=get_proton_partner().id, partner_email=email
)
Session.commit()
return user
def test_proton_callback_handler_unexistant_sl_user():
email = random_email()
name = random_string()
external_id = random_string()
user = UserInformation(email=email, name=name, id=external_id)
mock_client = MockProtonClient(
user=user, plan=ProtonPlan.Professional, organization={}
handler = ProtonCallbackHandler(
MockProtonClient(
user=UserInformation(
email=email,
name=name,
id=external_id,
plan=SLPlan(
type=SLPlanType.Premium, expiration=Arrow.utcnow().shift(hours=2)
),
)
)
)
handler = ProtonCallbackHandler(mock_client)
res = handler.handle_login(get_proton_partner())
assert res.user is not None
assert res.user.email == email
assert res.user.name == name
assert res.user.partner_user_id == external_id
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=res.user.id
)
assert partner_user is not None
assert partner_user.external_user_id == external_id
def test_proton_callback_handler_existant_sl_user():
@ -97,11 +57,13 @@ def test_proton_callback_handler_existant_sl_user():
sl_user = User.create(email, commit=True)
external_id = random_string()
user = UserInformation(email=email, name=random_string(), id=external_id)
mock_client = MockProtonClient(
user=user, plan=ProtonPlan.Professional, organization={}
user = UserInformation(
email=email,
name=random_string(),
id=external_id,
plan=SLPlan(type=SLPlanType.Premium, expiration=Arrow.utcnow().shift(hours=2)),
)
handler = ProtonCallbackHandler(mock_client)
handler = ProtonCallbackHandler(MockProtonClient(user=user))
res = handler.handle_login(get_proton_partner())
assert res.user is not None
@ -112,178 +74,18 @@ def test_proton_callback_handler_existant_sl_user():
assert sa.partner_email == user.email
def test_get_strategy_unexistant_sl_user():
strategy = get_login_strategy(
proton_user=random_proton_user(),
sl_user=None,
partner=get_proton_partner(),
)
assert isinstance(strategy, UnexistantSlClientStrategy)
def test_proton_callback_handler_none_user_login():
handler = ProtonCallbackHandler(MockProtonClient(user=None))
res = handler.handle_login(get_proton_partner())
expected = generate_account_not_allowed_to_log_in()
assert res == expected
def test_get_strategy_existing_sl_user():
email = random_email()
sl_user = User.create(email, commit=True)
strategy = get_login_strategy(
proton_user=random_proton_user(email=email),
sl_user=sl_user,
partner=get_proton_partner(),
)
assert isinstance(strategy, ExistingSlClientStrategy)
def test_proton_callback_handler_none_user_link():
sl_user = User.create(random_email(), commit=True)
handler = ProtonCallbackHandler(MockProtonClient(user=None))
res = handler.handle_link(sl_user, get_proton_partner())
def test_get_strategy_already_linked_user():
email = random_email()
proton_user_id = random_string()
sl_user = create_user_for_partner(proton_user_id, email=email)
strategy = get_login_strategy(
proton_user=random_proton_user(user_id=proton_user_id, email=email),
sl_user=sl_user,
partner=get_proton_partner(),
)
assert isinstance(strategy, AlreadyLinkedUserStrategy)
def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
# In this scenario we have
# - ProtonUser1 (ID1, email1@proton)
# - ProtonUser2 (ID2, email2@proton)
# - SimpleLoginUser1 registered with email1@proton, but linked to account ID2
# We will try to log in with email1@proton
email1 = random_email()
email2 = random_email()
proton_user_id_1 = random_string()
proton_user_id_2 = random_string()
proton_user_1 = random_proton_user(user_id=proton_user_id_1, email=email1)
proton_user_2 = random_proton_user(user_id=proton_user_id_2, email=email2)
sl_user = create_user_for_partner(proton_user_2.id, email=proton_user_1.email)
strategy = get_login_strategy(
proton_user=proton_user_1,
sl_user=sl_user,
partner=get_proton_partner(),
)
assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy)
##
# LINK
def test_link_account_with_proton_account_same_address(flask_client):
# This is the most basic scenario
# In this scenario we have:
# - ProtonUser (email1@proton)
# - SimpleLoginUser registered with email1@proton
# We will try to link both accounts
email = random_email()
proton_user_id = random_string()
proton_user = random_proton_user(user_id=proton_user_id, email=email)
sl_user = create_user(email)
res = process_link_case(proton_user, sl_user, get_proton_partner())
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user = User.get(sl_user.id)
assert updated_user.partner_id == get_proton_partner().id
assert updated_user.partner_user_id == proton_user_id
def test_link_account_with_proton_account_different_address(flask_client):
# In this scenario we have:
# - ProtonUser (foo@proton)
# - SimpleLoginUser (bar@somethingelse)
# We will try to link both accounts
proton_user_id = random_string()
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
sl_user = create_user()
res = process_link_case(proton_user, sl_user, get_proton_partner())
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user = User.get(sl_user.id)
assert updated_user.partner_id == get_proton_partner().id
assert updated_user.partner_user_id == proton_user_id
def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
flask_client,
):
# In this scenario we have:
# - ProtonUser (foo@proton)
# - SimpleLoginUser1 (foo@proton)
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
proton_user_id = random_string()
proton_email = random_email()
proton_user = random_proton_user(user_id=proton_user_id, email=proton_email)
sl_user_1 = create_user(proton_email)
sl_user_2 = create_user_for_partner(
proton_user_id, email=random_email()
) # User already linked with the proton account
res = process_link_case(proton_user, sl_user_1, get_proton_partner())
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user_1 = User.get(sl_user_1.id)
assert updated_user_1.partner_id == get_proton_partner().id
assert updated_user_1.partner_user_id == proton_user_id
updated_user_2 = User.get(sl_user_2.id)
assert updated_user_2.partner_id is None
assert updated_user_2.partner_user_id is None
def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
flask_client,
):
# In this scenario we have:
# - ProtonUser (foo@proton)
# - SimpleLoginUser1 (bar@somethingelse)
# - SimpleLoginUser2 (other@somethingelse) linked with foo@proton
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@proton
proton_user_id = random_string()
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
sl_user_1 = create_user(random_email())
sl_user_2 = create_user_for_partner(
proton_user_id, email=random_email()
) # User already linked with the proton account
res = process_link_case(proton_user, sl_user_1, get_proton_partner())
assert res.redirect_to_login is False
assert res.redirect is not None
assert res.flash_category == "success"
assert res.flash_message is not None
updated_user_1 = User.get(sl_user_1.id)
assert updated_user_1.partner_id == get_proton_partner().id
assert updated_user_1.partner_user_id == proton_user_id
partner_user_1 = PartnerUser.get_by(
user_id=sl_user_1.id, partner_id=get_proton_partner().id
)
assert partner_user_1 is not None
assert partner_user_1.partner_email == proton_user.email
updated_user_2 = User.get(sl_user_2.id)
assert updated_user_2.partner_id is None
assert updated_user_2.partner_user_id is None
partner_user_2 = PartnerUser.get_by(
user_id=sl_user_2.id, partner_id=get_proton_partner().id
)
assert partner_user_2 is None
def test_cannot_create_instance_of_base_strategy():
with pytest.raises(Exception):
ClientMergeStrategy(random_proton_user(), None, get_proton_partner())
expected = generate_account_not_allowed_to_log_in()
assert res == expected

View File

@ -0,0 +1,239 @@
import pytest
from arrow import Arrow
from app.account_linking import (
process_link_case,
get_login_strategy,
NewUserStrategy,
ExistingUnlinedUserStrategy,
LinkedWithAnotherPartnerUserStrategy,
SLPlan,
SLPlanType,
PartnerLinkRequest,
ClientMergeStrategy,
)
from app.proton.proton_callback_handler import get_proton_partner
from app.db import Session
from app.models import PartnerUser, User
from app.utils import random_string
from tests.utils import random_email
def random_link_request(
external_user_id: str = None,
name: str = None,
email: str = None,
plan: SLPlan = None,
) -> PartnerLinkRequest:
external_user_id = (
external_user_id if external_user_id is not None else random_string()
)
name = name if name is not None else random_string()
email = email if email is not None else random_email()
plan = plan if plan is not None else SLPlanType.Free
return PartnerLinkRequest(
name=name,
email=email,
external_user_id=external_user_id,
plan=SLPlan(type=plan, expiration=Arrow.utcnow().shift(hours=2)),
)
def create_user(email: str = None) -> User:
email = email if email is not None else random_email()
user = User.create(email=email)
Session.commit()
return user
def create_user_for_partner(external_user_id: str, email: str = None) -> User:
email = email if email is not None else random_email()
user = User.create(email=email)
PartnerUser.create(
user_id=user.id,
partner_id=get_proton_partner().id,
partner_email=email,
external_user_id=external_user_id,
)
Session.commit()
return user
def test_get_strategy_unexistant_sl_user():
strategy = get_login_strategy(
link_request=random_link_request(),
user=None,
partner=get_proton_partner(),
)
assert isinstance(strategy, NewUserStrategy)
def test_get_strategy_existing_sl_user():
email = random_email()
user = User.create(email, commit=True)
strategy = get_login_strategy(
link_request=random_link_request(email=email),
user=user,
partner=get_proton_partner(),
)
assert isinstance(strategy, ExistingUnlinedUserStrategy)
def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
# In this scenario we have
# - PartnerUser1 (ID1, email1@proton)
# - PartnerUser2 (ID2, email2@proton)
# - SimpleLoginUser1 registered with email1@proton, but linked to account ID2
# We will try to log in with email1@proton
email1 = random_email()
email2 = random_email()
partner_user_id_1 = random_string()
partner_user_id_2 = random_string()
link_request_1 = random_link_request(
external_user_id=partner_user_id_1, email=email1
)
link_request_2 = random_link_request(
external_user_id=partner_user_id_2, email=email2
)
user = create_user_for_partner(
link_request_2.external_user_id, email=link_request_1.email
)
strategy = get_login_strategy(
link_request=link_request_1,
user=user,
partner=get_proton_partner(),
)
assert isinstance(strategy, LinkedWithAnotherPartnerUserStrategy)
##
# LINK
def test_link_account_with_proton_account_same_address(flask_client):
# This is the most basic scenario
# In this scenario we have:
# - PartnerUser (email1@partner)
# - SimpleLoginUser registered with email1@proton
# We will try to link both accounts
email = random_email()
partner_user_id = random_string()
link_request = random_link_request(external_user_id=partner_user_id, email=email)
user = create_user(email)
res = process_link_case(link_request, user, get_proton_partner())
assert res is not None
assert res.user is not None
assert res.user.id == user.id
assert res.user.email == email
assert res.strategy == "Link"
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=user.id
)
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
def test_link_account_with_proton_account_different_address(flask_client):
# In this scenario we have:
# - ProtonUser (foo@proton)
# - SimpleLoginUser (bar@somethingelse)
# We will try to link both accounts
partner_user_id = random_string()
link_request = random_link_request(
external_user_id=partner_user_id, email=random_email()
)
user = create_user()
res = process_link_case(link_request, user, get_proton_partner())
assert res.user.id == user.id
assert res.user.email == user.email
assert res.strategy == "Link"
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=user.id
)
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
flask_client,
):
# In this scenario we have:
# - PartnerUser (foo@partner)
# - SimpleLoginUser1 (foo@partner)
# - SimpleLoginUser2 (other@somethingelse) linked with foo@partner
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@partner
partner_user_id = random_string()
partner_email = random_email()
link_request = random_link_request(
external_user_id=partner_user_id, email=partner_email
)
sl_user_1 = create_user(partner_email)
sl_user_2 = create_user_for_partner(
partner_user_id, email=random_email()
) # User already linked with the proton account
res = process_link_case(link_request, sl_user_1, get_proton_partner())
assert res.user.id == sl_user_1.id
assert res.user.email == partner_email
assert res.strategy == "Link"
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_1.id
)
assert partner_user.partner_id == get_proton_partner().id
assert partner_user.external_user_id == partner_user_id
partner_user = PartnerUser.get_by(
partner_id=get_proton_partner().id, user_id=sl_user_2.id
)
assert partner_user is None
def test_link_account_with_proton_account_different_address_and_linked_to_other_user(
flask_client,
):
# In this scenario we have:
# - PartnerUser (foo@partner)
# - SimpleLoginUser1 (bar@somethingelse)
# - SimpleLoginUser2 (other@somethingelse) linked with foo@partner
# We will unlink SimpleLoginUser2 and link SimpleLoginUser1 with foo@partner
partner_user_id = random_string()
link_request = random_link_request(
external_user_id=partner_user_id, email=random_email()
)
sl_user_1 = create_user(random_email())
sl_user_2 = create_user_for_partner(
partner_user_id, email=random_email()
) # User already linked with the proton account
res = process_link_case(link_request, sl_user_1, get_proton_partner())
assert res.user.id == sl_user_1.id
assert res.user.email == sl_user_1.email
assert res.strategy == "Link"
partner_user_1 = PartnerUser.get_by(
user_id=sl_user_1.id, partner_id=get_proton_partner().id
)
assert partner_user_1 is not None
assert partner_user_1.partner_email == sl_user_2.email
assert partner_user_1.partner_id == get_proton_partner().id
assert partner_user_1.external_user_id == partner_user_id
partner_user_2 = PartnerUser.get_by(
user_id=sl_user_2.id, partner_id=get_proton_partner().id
)
assert partner_user_2 is None
def test_cannot_create_instance_of_base_strategy():
with pytest.raises(Exception):
ClientMergeStrategy(random_link_request(), None, get_proton_partner())

View File

@ -10,6 +10,7 @@ import jinja2
from flask import url_for
from app.models import User
from app.utils import random_string
def create_new_user() -> User:
@ -66,3 +67,7 @@ def load_eml_file(
template_values = {}
rendered = template.render(**template_values)
return email.message_from_string(rendered)
def random_email() -> str:
return "{rand}@{rand}.com".format(rand=random_string(20))