diff --git a/app/alias_suffix.py b/app/alias_suffix.py index 83d6392f..8f010fa8 100644 --- a/app/alias_suffix.py +++ b/app/alias_suffix.py @@ -6,7 +6,7 @@ from typing import Optional import itsdangerous from app import config from app.log import LOG -from app.models import User, AliasOptions +from app.models import User, AliasOptions, SLDomain signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET) @@ -105,10 +105,7 @@ def get_alias_suffixes( for custom_domain in user_custom_domains: if custom_domain.random_prefix_generation: suffix = ( - "." - + user.get_random_alias_suffix(custom_domain) - + "@" - + custom_domain.domain + f".{user.get_random_alias_suffix(custom_domain)}@{custom_domain.domain}" ) alias_suffix = AliasSuffix( is_custom=True, @@ -123,7 +120,7 @@ def get_alias_suffixes( else: alias_suffixes.append(alias_suffix) - suffix = "@" + custom_domain.domain + suffix = f"@{custom_domain.domain}" alias_suffix = AliasSuffix( is_custom=True, suffix=suffix, @@ -144,16 +141,13 @@ def get_alias_suffixes( alias_suffixes.append(alias_suffix) # then SimpleLogin domain - for sl_domain in user.get_sl_domains(alias_options=alias_options): - suffix = ( - ( - "" - if config.DISABLE_ALIAS_SUFFIX - else "." + user.get_random_alias_suffix() - ) - + "@" - + sl_domain.domain + sl_domains = user.get_sl_domains(alias_options=alias_options) + default_domain_found = False + for sl_domain in sl_domains: + prefix = ( + "" if config.DISABLE_ALIAS_SUFFIX else f".{user.get_random_alias_suffix()}" ) + suffix = f"{prefix}@{sl_domain.domain}" alias_suffix = AliasSuffix( is_custom=False, suffix=suffix, @@ -162,11 +156,35 @@ def get_alias_suffixes( domain=sl_domain.domain, mx_verified=True, ) - - # put the default domain to top - if user.default_alias_public_domain_id == sl_domain.id: - alias_suffixes.insert(0, alias_suffix) - else: + # No default or this is not the default + if ( + user.default_alias_public_domain_id is None + or user.default_alias_public_domain_id != sl_domain.id + ): alias_suffixes.append(alias_suffix) + # If no default domain mark it as found + default_domain_found = user.default_alias_public_domain_id is None + else: + default_domain_found = True + alias_suffixes.insert(0, alias_suffix) + + if not default_domain_found: + sl_domain = SLDomain.get(user.default_alias_public_domain_id) + if sl_domain: + prefix = ( + "" + if config.DISABLE_ALIAS_SUFFIX + else f".{user.get_random_alias_suffix()}" + ) + suffix = f"{prefix}@{sl_domain.domain}" + alias_suffix = AliasSuffix( + is_custom=False, + suffix=suffix, + signed_suffix=signer.sign(suffix).decode(), + is_premium=sl_domain.premium_only, + domain=sl_domain.domain, + mx_verified=True, + ) + alias_suffixes.insert(0, alias_suffix) return alias_suffixes diff --git a/tests/test_alias_suffixes.py b/tests/test_alias_suffixes.py new file mode 100644 index 00000000..4ab4d09f --- /dev/null +++ b/tests/test_alias_suffixes.py @@ -0,0 +1,87 @@ +import re + +from app.alias_suffix import get_alias_suffixes +from app.db import Session +from app.models import SLDomain, PartnerUser, AliasOptions, CustomDomain +from app.proton.utils import get_proton_partner +from init_app import add_sl_domains +from tests.utils import create_new_user, random_token + + +def setup_module(): + Session.query(SLDomain).delete() + SLDomain.create( + domain="hidden", premium_only=False, flush=True, order=5, hidden=True + ) + SLDomain.create(domain="free_non_partner", premium_only=False, flush=True, order=4) + SLDomain.create( + domain="premium_non_partner", premium_only=True, flush=True, order=3 + ) + SLDomain.create( + domain="free_partner", + premium_only=False, + flush=True, + partner_id=get_proton_partner().id, + order=2, + ) + SLDomain.create( + domain="premium_partner", + premium_only=True, + flush=True, + partner_id=get_proton_partner().id, + order=1, + ) + Session.commit() + + +def teardown_module(): + Session.query(SLDomain).delete() + add_sl_domains() + + +def test_get_default_domain_even_if_is_not_allowed(): + user = create_new_user() + PartnerUser.create( + partner_id=get_proton_partner().id, + user_id=user.id, + external_user_id=random_token(10), + flush=True, + ) + user.trial_end = None + default_domain = SLDomain.filter_by( + hidden=False, partner_id=None, premium_only=False + ).first() + user.default_alias_public_domain_id = default_domain.id + Session.flush() + options = AliasOptions( + show_sl_domains=False, show_partner_domains=get_proton_partner() + ) + suffixes = get_alias_suffixes(user, alias_options=options) + assert suffixes[0].domain == default_domain.domain + + +def test_suffixes_are_valid(): + user = create_new_user() + PartnerUser.create( + partner_id=get_proton_partner().id, + user_id=user.id, + external_user_id=random_token(10), + flush=True, + ) + CustomDomain.create( + user_id=user.id, domain=f"{random_token(10)}.com", verified=True + ) + user.trial_end = None + Session.flush() + options = AliasOptions( + show_sl_domains=True, show_partner_domains=get_proton_partner() + ) + alias_suffixes = get_alias_suffixes(user, alias_options=options) + valid_re = re.compile(r"^(\.[\w_]+)?@[\.\w]+$") + has_prefix = 0 + for suffix in alias_suffixes: + match = valid_re.match(suffix.suffix) + assert match is not None + if len(match.groups()) >= 1: + has_prefix += 1 + assert has_prefix > 0 diff --git a/tests/test_subscription_webhook.py b/tests/test_subscription_webhook.py index 6b09200c..7e7ed074 100644 --- a/tests/test_subscription_webhook.py +++ b/tests/test_subscription_webhook.py @@ -11,7 +11,7 @@ from app.models import ( CoinbaseSubscription, ManualSubscription, ) -from tests.utils import create_new_user +from tests.utils import create_new_user, random_token from app.subscription_webhook import execute_subscription_webhook @@ -57,7 +57,7 @@ def test_webhook_with_subscription(): user_id=user.id, cancel_url="", update_url="", - subscription_id="", + subscription_id=random_token(10), event_time=arrow.now(), next_bill_date=end_at.date(), plan="yearly", @@ -76,7 +76,7 @@ def test_webhook_with_apple_subscription(): user_id=user.id, receipt_data=arrow.now().date().strftime("%Y-%m-%d"), expires_date=end_at.date().strftime("%Y-%m-%d"), - original_transaction_id="", + original_transaction_id=random_token(10), plan="yearly", product_id="", flush=True,