Show the default domain for creating aliases even if it's not requested by a partner (#1754)

* Show the default domain in the suffixes even if it's not allowed

* Simplify logic

* Reformat

* Simplified logic

* Remove unused function

* Added test to validate suffixes

* Ensure we catch prefixes in test

---------

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Adrià Casajús 2023-05-29 16:40:04 +02:00 committed by GitHub
parent e43a2dd34d
commit 07bb658310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 128 additions and 23 deletions

View File

@ -6,7 +6,7 @@ from typing import Optional
import itsdangerous
from app import config
from app.log import LOG
from app.models import User, AliasOptions
from app.models import User, AliasOptions, SLDomain
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@ -105,10 +105,7 @@ def get_alias_suffixes(
for custom_domain in user_custom_domains:
if custom_domain.random_prefix_generation:
suffix = (
"."
+ user.get_random_alias_suffix(custom_domain)
+ "@"
+ custom_domain.domain
f".{user.get_random_alias_suffix(custom_domain)}@{custom_domain.domain}"
)
alias_suffix = AliasSuffix(
is_custom=True,
@ -123,7 +120,7 @@ def get_alias_suffixes(
else:
alias_suffixes.append(alias_suffix)
suffix = "@" + custom_domain.domain
suffix = f"@{custom_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=True,
suffix=suffix,
@ -144,16 +141,13 @@ def get_alias_suffixes(
alias_suffixes.append(alias_suffix)
# then SimpleLogin domain
for sl_domain in user.get_sl_domains(alias_options=alias_options):
suffix = (
(
""
if config.DISABLE_ALIAS_SUFFIX
else "." + user.get_random_alias_suffix()
)
+ "@"
+ sl_domain.domain
sl_domains = user.get_sl_domains(alias_options=alias_options)
default_domain_found = False
for sl_domain in sl_domains:
prefix = (
"" if config.DISABLE_ALIAS_SUFFIX else f".{user.get_random_alias_suffix()}"
)
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=False,
suffix=suffix,
@ -162,11 +156,35 @@ def get_alias_suffixes(
domain=sl_domain.domain,
mx_verified=True,
)
# put the default domain to top
if user.default_alias_public_domain_id == sl_domain.id:
alias_suffixes.insert(0, alias_suffix)
else:
# No default or this is not the default
if (
user.default_alias_public_domain_id is None
or user.default_alias_public_domain_id != sl_domain.id
):
alias_suffixes.append(alias_suffix)
# If no default domain mark it as found
default_domain_found = user.default_alias_public_domain_id is None
else:
default_domain_found = True
alias_suffixes.insert(0, alias_suffix)
if not default_domain_found:
sl_domain = SLDomain.get(user.default_alias_public_domain_id)
if sl_domain:
prefix = (
""
if config.DISABLE_ALIAS_SUFFIX
else f".{user.get_random_alias_suffix()}"
)
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=False,
suffix=suffix,
signed_suffix=signer.sign(suffix).decode(),
is_premium=sl_domain.premium_only,
domain=sl_domain.domain,
mx_verified=True,
)
alias_suffixes.insert(0, alias_suffix)
return alias_suffixes

View File

@ -0,0 +1,87 @@
import re
from app.alias_suffix import get_alias_suffixes
from app.db import Session
from app.models import SLDomain, PartnerUser, AliasOptions, CustomDomain
from app.proton.utils import get_proton_partner
from init_app import add_sl_domains
from tests.utils import create_new_user, random_token
def setup_module():
Session.query(SLDomain).delete()
SLDomain.create(
domain="hidden", premium_only=False, flush=True, order=5, hidden=True
)
SLDomain.create(domain="free_non_partner", premium_only=False, flush=True, order=4)
SLDomain.create(
domain="premium_non_partner", premium_only=True, flush=True, order=3
)
SLDomain.create(
domain="free_partner",
premium_only=False,
flush=True,
partner_id=get_proton_partner().id,
order=2,
)
SLDomain.create(
domain="premium_partner",
premium_only=True,
flush=True,
partner_id=get_proton_partner().id,
order=1,
)
Session.commit()
def teardown_module():
Session.query(SLDomain).delete()
add_sl_domains()
def test_get_default_domain_even_if_is_not_allowed():
user = create_new_user()
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
user.trial_end = None
default_domain = SLDomain.filter_by(
hidden=False, partner_id=None, premium_only=False
).first()
user.default_alias_public_domain_id = default_domain.id
Session.flush()
options = AliasOptions(
show_sl_domains=False, show_partner_domains=get_proton_partner()
)
suffixes = get_alias_suffixes(user, alias_options=options)
assert suffixes[0].domain == default_domain.domain
def test_suffixes_are_valid():
user = create_new_user()
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
CustomDomain.create(
user_id=user.id, domain=f"{random_token(10)}.com", verified=True
)
user.trial_end = None
Session.flush()
options = AliasOptions(
show_sl_domains=True, show_partner_domains=get_proton_partner()
)
alias_suffixes = get_alias_suffixes(user, alias_options=options)
valid_re = re.compile(r"^(\.[\w_]+)?@[\.\w]+$")
has_prefix = 0
for suffix in alias_suffixes:
match = valid_re.match(suffix.suffix)
assert match is not None
if len(match.groups()) >= 1:
has_prefix += 1
assert has_prefix > 0

View File

@ -11,7 +11,7 @@ from app.models import (
CoinbaseSubscription,
ManualSubscription,
)
from tests.utils import create_new_user
from tests.utils import create_new_user, random_token
from app.subscription_webhook import execute_subscription_webhook
@ -57,7 +57,7 @@ def test_webhook_with_subscription():
user_id=user.id,
cancel_url="",
update_url="",
subscription_id="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=end_at.date(),
plan="yearly",
@ -76,7 +76,7 @@ def test_webhook_with_apple_subscription():
user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id="",
original_transaction_id=random_token(10),
plan="yearly",
product_id="",
flush=True,