From 989a577db608345e49a81f4fdc900a33d828a8e2 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Casaj=C3=BAs?= Date: Wed, 13 Sep 2023 18:12:47 +0200 Subject: [PATCH] Allow to get premium partner domains without premium sl domains (#1880) * Allow to get premium partner domains without premium sl domains * Set condition on domains --- app/models.py | 39 +++++++++++++++++++++++---------------- tests/test_domains.py | 28 ++++++++++++++++++++++++++++ 2 files changed, 51 insertions(+), 16 deletions(-) diff --git a/app/models.py b/app/models.py index 61935be7..94c4ad44 100644 --- a/app/models.py +++ b/app/models.py @@ -280,6 +280,7 @@ class IntEnumType(sa.types.TypeDecorator): class AliasOptions: show_sl_domains: bool = True show_partner_domains: Optional[Partner] = None + show_partner_premium: Optional[bool] = None class Hibp(Base, ModelMixin): @@ -1038,29 +1039,35 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): ) -> list["SLDomain"]: if alias_options is None: alias_options = AliasOptions() - conditions = [SLDomain.hidden == False] # noqa: E712 - if not self.is_premium(): - conditions.append(SLDomain.premium_only == False) # noqa: E712 - partner_domain_cond = [] # noqa:E711 + top_conds = [SLDomain.hidden == False] # noqa: E712 + or_conds = [] # noqa:E711 if self.default_alias_public_domain_id is not None: - partner_domain_cond.append( - SLDomain.id == self.default_alias_public_domain_id - ) + default_domain_conds = [SLDomain.id == self.default_alias_public_domain_id] + if not self.is_premium(): + default_domain_conds.append( + SLDomain.premium_only == False # noqa: E712 + ) + or_conds.append(and_(*default_domain_conds).self_group()) if alias_options.show_partner_domains is not None: partner_user = PartnerUser.filter_by( user_id=self.id, partner_id=alias_options.show_partner_domains.id ).first() if partner_user is not None: - partner_domain_cond.append( - SLDomain.partner_id == partner_user.partner_id - ) + partner_domain_cond = [SLDomain.partner_id == partner_user.partner_id] + if alias_options.show_partner_premium is None: + alias_options.show_partner_premium = self.is_premium() + if not alias_options.show_partner_premium: + partner_domain_cond.append( + SLDomain.premium_only == False # noqa: E712 + ) + or_conds.append(and_(*partner_domain_cond).self_group()) if alias_options.show_sl_domains: - partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711 - if len(partner_domain_cond) == 1: - conditions.append(partner_domain_cond[0]) - else: - conditions.append(or_(*partner_domain_cond)) - query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order) + sl_conds = [SLDomain.partner_id == None] # noqa: E711 + if not self.is_premium(): + sl_conds.append(SLDomain.premium_only == False) # noqa: E712 + or_conds.append(and_(*sl_conds).self_group()) + top_conds.append(or_(*or_conds)) + query = Session.query(SLDomain).filter(*top_conds).order_by(SLDomain.order) return query.all() def available_alias_domains( diff --git a/tests/test_domains.py b/tests/test_domains.py index 298363a2..5783aa03 100644 --- a/tests/test_domains.py +++ b/tests/test_domains.py @@ -199,3 +199,31 @@ def test_get_free_partner_and_hidden_default_domain(): assert [d.domain for d in domains] == user.available_sl_domains( alias_options=options ) + + +def test_get_free_partner_and_premium_partner(): + user = create_new_user() + user.trial_end = None + PartnerUser.create( + partner_id=get_proton_partner().id, + user_id=user.id, + external_user_id=random_token(10), + flush=True, + ) + user.default_alias_public_domain_id = ( + SLDomain.filter_by(hidden=False, premium_only=False).first().id + ) + Session.flush() + options = AliasOptions( + show_sl_domains=False, + show_partner_domains=get_proton_partner(), + show_partner_premium=True, + ) + domains = user.get_sl_domains(alias_options=options) + assert len(domains) == 3 + assert domains[0].domain == "premium_partner" + assert domains[1].domain == "free_partner" + assert domains[2].domain == "free_non_partner" + assert [d.domain for d in domains] == user.available_sl_domains( + alias_options=options + )