diff --git a/app/models.py b/app/models.py index 76a71b58..7a1ad0a8 100644 --- a/app/models.py +++ b/app/models.py @@ -580,19 +580,6 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): Session.flush() user.default_mailbox_id = mb.id - # create a first alias mail to show user how to use when they login - alias = Alias.create_new( - user, - prefix="simplelogin-newsletter", - mailbox_id=mb.id, - note="This is your first alias. It's used to receive SimpleLogin communications " - "like new features announcements, newsletters.", - ) - Session.flush() - - user.newsletter_alias_id = alias.id - Session.flush() - # generate an alternative_id if needed if "alternative_id" not in kwargs: user.alternative_id = str(uuid.uuid4()) @@ -611,6 +598,19 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): Session.flush() return user + # create a first alias mail to show user how to use when they login + alias = Alias.create_new( + user, + prefix="simplelogin-newsletter", + mailbox_id=mb.id, + note="This is your first alias. It's used to receive SimpleLogin communications " + "like new features announcements, newsletters.", + ) + Session.flush() + + user.newsletter_alias_id = alias.id + Session.flush() + if config.DISABLE_ONBOARDING: LOG.d("Disable onboarding emails") return user @@ -636,7 +636,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): return user def get_active_subscription( - self, + self, include_partner_subscription: bool = True ) -> Optional[ Union[ Subscription @@ -664,19 +664,24 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): if coinbase_subscription and coinbase_subscription.is_active(): return coinbase_subscription - partner_sub: PartnerSubscription = PartnerSubscription.find_by_user_id(self.id) - if partner_sub and partner_sub.is_active(): - return partner_sub + if include_partner_subscription: + partner_sub: PartnerSubscription = PartnerSubscription.find_by_user_id( + self.id + ) + if partner_sub and partner_sub.is_active(): + return partner_sub return None # region Billing - def lifetime_or_active_subscription(self) -> bool: + def lifetime_or_active_subscription( + self, include_partner_subscription: bool = True + ) -> bool: """True if user has lifetime licence or active subscription""" if self.lifetime: return True - return self.get_active_subscription() is not None + return self.get_active_subscription(include_partner_subscription) is not None def is_paid(self) -> bool: """same as _lifetime_or_active_subscription but not include free manual subscription""" @@ -705,14 +710,14 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): return True - def is_premium(self) -> bool: + def is_premium(self, include_partner_subscription: bool = True) -> bool: """ user is premium if they: - have a lifetime deal or - in trial period or - active subscription """ - if self.lifetime_or_active_subscription(): + if self.lifetime_or_active_subscription(include_partner_subscription): return True if self.trial_end and arrow.now() < self.trial_end: diff --git a/tests/auth/test_reset_password.py b/tests/auth/test_reset_password.py index 20cd695b..adabc58f 100644 --- a/tests/auth/test_reset_password.py +++ b/tests/auth/test_reset_password.py @@ -5,7 +5,7 @@ from app.models import User, ResetPasswordCode from tests.utils import create_new_user, random_token -def test_reset_password(flask_client): +def test_successful_reset_password(flask_client): user = create_new_user() original_pass_hash = user.password user_id = user.id @@ -19,8 +19,8 @@ def test_reset_password(flask_client): data={"password": "1231idsfjaads"}, ) - assert r.status_code == 200 + assert r.status_code == 302 - assert len(ResetPasswordCode.get_by(user_id=user_id).all()) == 0 + assert ResetPasswordCode.get_by(user_id=user_id) is None user = User.get(user_id) assert user.password != original_pass_hash diff --git a/tests/models/test_user.py b/tests/models/test_user.py index d6293b6b..211cb3f8 100644 --- a/tests/models/test_user.py +++ b/tests/models/test_user.py @@ -1,7 +1,9 @@ +import arrow from app import config from app.db import Session -from app.models import User, Job -from tests.utils import random_email +from app.models import User, Job, PartnerSubscription, PartnerUser, ManualSubscription +from app.proton.utils import get_proton_partner +from tests.utils import random_email, random_token def test_create_from_partner(flask_client): @@ -11,6 +13,7 @@ def test_create_from_partner(flask_client): ) assert user.notification is False assert user.trial_end is None + assert user.newsletter_alias_id is None job = Session.query(Job).order_by(Job.id.desc()).first() assert job is not None assert job.name == config.JOB_SEND_PROTON_WELCOME_1 @@ -23,3 +26,23 @@ def test_user_created_by_partner(flask_client): regular_user = User.create(email=random_email()) assert regular_user.created_by_partner is False + + +def test_user_is_premium(flask_client): + user = User.create(email=random_email(), from_partner=True) + assert not user.is_premium() + partner_user = PartnerUser.create( + user_id=user.id, + partner_id=get_proton_partner().id, + partner_email=user.email, + external_user_id=random_token(), + flush=True, + ) + ps = PartnerSubscription.create( + partner_user_id=partner_user.id, end_at=arrow.now().shift(years=1), flush=True + ) + assert user.is_premium() + assert not user.is_premium(include_partner_subscription=False) + ManualSubscription.create(user_id=user.id, end_at=ps.end_at) + assert user.is_premium() + assert user.is_premium(include_partner_subscription=False)