From b6f79ea3a6ab9443c606902f476cf9082ec67576 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Casaj=C3=BAs?= Date: Thu, 6 Apr 2023 11:07:13 +0200 Subject: [PATCH] Refactor alias options and add it to more methods (#1681) MIME-Version: 1.0 Content-Type: text/plain; charset=UTF-8 Content-Transfer-Encoding: 8bit Co-authored-by: Adrià Casajús --- app/alias_suffix.py | 18 +++++++++-------- app/models.py | 40 ++++++++++++++++++++++++++----------- tests/models/test_user.py | 8 +------- tests/test_domains.py | 42 ++++++++++++++++++++++++++++++--------- 4 files changed, 72 insertions(+), 36 deletions(-) diff --git a/app/alias_suffix.py b/app/alias_suffix.py index f5e5b03d..af5e2562 100644 --- a/app/alias_suffix.py +++ b/app/alias_suffix.py @@ -6,7 +6,7 @@ from typing import Optional import itsdangerous from app import config from app.log import LOG -from app.models import User, Partner +from app.models import User, AliasOptions signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET) @@ -42,7 +42,9 @@ def check_suffix_signature(signed_suffix: str) -> Optional[str]: return None -def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool: +def verify_prefix_suffix( + user: User, alias_prefix, alias_suffix, alias_options: Optional[AliasOptions] = None +) -> bool: """verify if user could create an alias with the given prefix and suffix""" if not alias_prefix or not alias_suffix: # should be caught on frontend return False @@ -63,7 +65,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool: # 1) alias_suffix must start with "." and # 2) alias_domain_prefix must come from the word list if ( - alias_domain in user.available_sl_domains() + alias_domain in user.available_sl_domains(alias_options=alias_options) and alias_domain not in user_custom_domains # when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty and not config.DISABLE_ALIAS_SUFFIX @@ -79,7 +81,9 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool: LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) return False - if alias_domain not in user.available_sl_domains(): + if alias_domain not in user.available_sl_domains( + alias_options=alias_options + ): LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) return False @@ -87,9 +91,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool: def get_alias_suffixes( - user: User, - show_domains_for_partner: Optional[Partner] = None, - show_sl_domains: bool = True, + user: User, alias_options: Optional[AliasOptions] = None ) -> [AliasSuffix]: """ Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up. @@ -142,7 +144,7 @@ def get_alias_suffixes( alias_suffixes.append(alias_suffix) # then SimpleLogin domain - for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains): + for sl_domain in user.get_sl_domains(alias_options=alias_options): suffix = ( ( "" diff --git a/app/models.py b/app/models.py index e4021dcc..e5b07ce2 100644 --- a/app/models.py +++ b/app/models.py @@ -1,6 +1,7 @@ from __future__ import annotations import base64 +import dataclasses import enum import hashlib import hmac @@ -273,6 +274,12 @@ class IntEnumType(sa.types.TypeDecorator): return self._enum_type(enum_value) +@dataclasses.dataclass +class AliasOptions: + show_sl_domains: bool = True + show_partner_domains: Optional[Partner] = None + + class Hibp(Base, ModelMixin): __tablename__ = "hibp" name = sa.Column(sa.String(), nullable=False, unique=True, index=True) @@ -867,14 +874,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): def custom_domains(self): return CustomDomain.filter_by(user_id=self.id, verified=True).all() - def available_domains_for_random_alias(self) -> List[Tuple[bool, str]]: + def available_domains_for_random_alias( + self, alias_options: Optional[AliasOptions] = None + ) -> List[Tuple[bool, str]]: """Return available domains for user to create random aliases Each result record contains: - whether the domain belongs to SimpleLogin - the domain """ res = [] - for domain in self.available_sl_domains(): + for domain in self.available_sl_domains(alias_options=alias_options): res.append((True, domain)) for custom_domain in self.verified_custom_domains(): @@ -959,32 +968,37 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): return None, "", False - def available_sl_domains(self) -> [str]: + def available_sl_domains( + self, alias_options: Optional[AliasOptions] = None + ) -> [str]: """ Return all SimpleLogin domains that user can use when creating a new alias, including: - SimpleLogin public domains, available for all users (ALIAS_DOMAIN) - SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN) """ - return [sl_domain.domain for sl_domain in self.get_sl_domains()] + return [ + sl_domain.domain + for sl_domain in self.get_sl_domains(alias_options=alias_options) + ] def get_sl_domains( - self, - show_domains_for_partner: Optional[Partner] = None, - show_sl_domains: bool = True, + self, alias_options: Optional[AliasOptions] = None ) -> list["SLDomain"]: + if alias_options is None: + alias_options = AliasOptions() conditions = [SLDomain.hidden == False] # noqa: E712 if not self.is_premium(): conditions.append(SLDomain.premium_only == False) # noqa: E712 partner_domain_cond = [] # noqa:E711 - if show_domains_for_partner is not None: + if alias_options.show_partner_domains is not None: partner_user = PartnerUser.filter_by( - user_id=self.id, partner_id=show_domains_for_partner.id + user_id=self.id, partner_id=alias_options.show_partner_domains.id ).first() if partner_user is not None: partner_domain_cond.append( SLDomain.partner_id == partner_user.partner_id ) - if show_sl_domains: + if alias_options.show_sl_domains: partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711 if len(partner_domain_cond) == 1: conditions.append(partner_domain_cond[0]) @@ -993,14 +1007,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order) return query.all() - def available_alias_domains(self) -> [str]: + def available_alias_domains( + self, alias_options: Optional[AliasOptions] = None + ) -> [str]: """return all domains that user can use when creating a new alias, including: - SimpleLogin public domains, available for all users (ALIAS_DOMAIN) - SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN) - Verified custom domains """ - domains = self.available_sl_domains() + domains = self.available_sl_domains(alias_options=alias_options) for custom_domain in self.verified_custom_domains(): domains.append(custom_domain.domain) diff --git a/tests/models/test_user.py b/tests/models/test_user.py index 3e39804f..d6293b6b 100644 --- a/tests/models/test_user.py +++ b/tests/models/test_user.py @@ -1,13 +1,7 @@ from app import config from app.db import Session from app.models import User, Job -from tests.utils import create_new_user, random_email - - -def test_available_sl_domains(flask_client): - user = create_new_user() - - assert set(user.available_sl_domains()) == {"d1.test", "d2.test", "sl.local"} +from tests.utils import random_email def test_create_from_partner(flask_client): diff --git a/tests/test_domains.py b/tests/test_domains.py index 07e73cc4..124b3559 100644 --- a/tests/test_domains.py +++ b/tests/test_domains.py @@ -1,5 +1,5 @@ from app.db import Session -from app.models import SLDomain, PartnerUser +from app.models import SLDomain, PartnerUser, AliasOptions from app.proton.utils import get_proton_partner from init_app import add_sl_domains from tests.utils import create_new_user, random_token @@ -43,12 +43,14 @@ def test_get_non_partner_domains(): assert len(domains) == 2 assert domains[0].domain == "premium_non_partner" assert domains[1].domain == "free_non_partner" + assert [d.domain for d in domains] == user.available_sl_domains() # Free user.trial_end = None Session.flush() domains = user.get_sl_domains() assert len(domains) == 1 assert domains[0].domain == "free_non_partner" + assert [d.domain for d in domains] == user.available_sl_domains() def test_get_free_with_partner_domains(): @@ -64,17 +66,28 @@ def test_get_free_with_partner_domains(): # Default assert len(domains) == 1 assert domains[0].domain == "free_non_partner" + assert [d.domain for d in domains] == user.available_sl_domains() # Show partner domains - domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner()) + options = AliasOptions( + show_sl_domains=True, show_partner_domains=get_proton_partner() + ) + domains = user.get_sl_domains(alias_options=options) assert len(domains) == 2 assert domains[0].domain == "free_partner" assert domains[1].domain == "free_non_partner" - # Only partner domains - domains = user.get_sl_domains( - show_domains_for_partner=get_proton_partner(), show_sl_domains=False + assert [d.domain for d in domains] == user.available_sl_domains( + alias_options=options ) + # Only partner domains + options = AliasOptions( + show_sl_domains=False, show_partner_domains=get_proton_partner() + ) + domains = user.get_sl_domains(alias_options=options) assert len(domains) == 1 assert domains[0].domain == "free_partner" + assert [d.domain for d in domains] == user.available_sl_domains( + alias_options=options + ) def test_get_premium_with_partner_domains(): @@ -90,17 +103,28 @@ def test_get_premium_with_partner_domains(): assert len(domains) == 2 assert domains[0].domain == "premium_non_partner" assert domains[1].domain == "free_non_partner" + assert [d.domain for d in domains] == user.available_sl_domains() # Show partner domains - domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner()) + options = AliasOptions( + show_sl_domains=True, show_partner_domains=get_proton_partner() + ) + domains = user.get_sl_domains(alias_options=options) assert len(domains) == 4 assert domains[0].domain == "premium_partner" assert domains[1].domain == "free_partner" assert domains[2].domain == "premium_non_partner" assert domains[3].domain == "free_non_partner" - # Only partner domains - domains = user.get_sl_domains( - show_domains_for_partner=get_proton_partner(), show_sl_domains=False + assert [d.domain for d in domains] == user.available_sl_domains( + alias_options=options ) + # Only partner domains + options = AliasOptions( + show_sl_domains=False, show_partner_domains=get_proton_partner() + ) + domains = user.get_sl_domains(alias_options=options) assert len(domains) == 2 assert domains[0].domain == "premium_partner" assert domains[1].domain == "free_partner" + assert [d.domain for d in domains] == user.available_sl_domains( + alias_options=options + )