Show the default domain for creating aliases even if it's not requested by a partner (#1754)

* Show the default domain in the suffixes even if it's not allowed

* Simplify logic

* Reformat

* Simplified logic

* Remove unused function

* Added test to validate suffixes

* Ensure we catch prefixes in test

---------

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Adrià Casajús 2023-05-29 16:40:04 +02:00 committed by GitHub
parent e43a2dd34d
commit 07bb658310
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
3 changed files with 128 additions and 23 deletions

View File

@ -6,7 +6,7 @@ from typing import Optional
import itsdangerous import itsdangerous
from app import config from app import config
from app.log import LOG from app.log import LOG
from app.models import User, AliasOptions from app.models import User, AliasOptions, SLDomain
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET) signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@ -105,10 +105,7 @@ def get_alias_suffixes(
for custom_domain in user_custom_domains: for custom_domain in user_custom_domains:
if custom_domain.random_prefix_generation: if custom_domain.random_prefix_generation:
suffix = ( suffix = (
"." f".{user.get_random_alias_suffix(custom_domain)}@{custom_domain.domain}"
+ user.get_random_alias_suffix(custom_domain)
+ "@"
+ custom_domain.domain
) )
alias_suffix = AliasSuffix( alias_suffix = AliasSuffix(
is_custom=True, is_custom=True,
@ -123,7 +120,7 @@ def get_alias_suffixes(
else: else:
alias_suffixes.append(alias_suffix) alias_suffixes.append(alias_suffix)
suffix = "@" + custom_domain.domain suffix = f"@{custom_domain.domain}"
alias_suffix = AliasSuffix( alias_suffix = AliasSuffix(
is_custom=True, is_custom=True,
suffix=suffix, suffix=suffix,
@ -144,16 +141,13 @@ def get_alias_suffixes(
alias_suffixes.append(alias_suffix) alias_suffixes.append(alias_suffix)
# then SimpleLogin domain # then SimpleLogin domain
for sl_domain in user.get_sl_domains(alias_options=alias_options): sl_domains = user.get_sl_domains(alias_options=alias_options)
suffix = ( default_domain_found = False
( for sl_domain in sl_domains:
"" prefix = (
if config.DISABLE_ALIAS_SUFFIX "" if config.DISABLE_ALIAS_SUFFIX else f".{user.get_random_alias_suffix()}"
else "." + user.get_random_alias_suffix()
)
+ "@"
+ sl_domain.domain
) )
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix( alias_suffix = AliasSuffix(
is_custom=False, is_custom=False,
suffix=suffix, suffix=suffix,
@ -162,11 +156,35 @@ def get_alias_suffixes(
domain=sl_domain.domain, domain=sl_domain.domain,
mx_verified=True, mx_verified=True,
) )
# No default or this is not the default
# put the default domain to top if (
if user.default_alias_public_domain_id == sl_domain.id: user.default_alias_public_domain_id is None
alias_suffixes.insert(0, alias_suffix) or user.default_alias_public_domain_id != sl_domain.id
else: ):
alias_suffixes.append(alias_suffix) alias_suffixes.append(alias_suffix)
# If no default domain mark it as found
default_domain_found = user.default_alias_public_domain_id is None
else:
default_domain_found = True
alias_suffixes.insert(0, alias_suffix)
if not default_domain_found:
sl_domain = SLDomain.get(user.default_alias_public_domain_id)
if sl_domain:
prefix = (
""
if config.DISABLE_ALIAS_SUFFIX
else f".{user.get_random_alias_suffix()}"
)
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=False,
suffix=suffix,
signed_suffix=signer.sign(suffix).decode(),
is_premium=sl_domain.premium_only,
domain=sl_domain.domain,
mx_verified=True,
)
alias_suffixes.insert(0, alias_suffix)
return alias_suffixes return alias_suffixes

View File

@ -0,0 +1,87 @@
import re
from app.alias_suffix import get_alias_suffixes
from app.db import Session
from app.models import SLDomain, PartnerUser, AliasOptions, CustomDomain
from app.proton.utils import get_proton_partner
from init_app import add_sl_domains
from tests.utils import create_new_user, random_token
def setup_module():
Session.query(SLDomain).delete()
SLDomain.create(
domain="hidden", premium_only=False, flush=True, order=5, hidden=True
)
SLDomain.create(domain="free_non_partner", premium_only=False, flush=True, order=4)
SLDomain.create(
domain="premium_non_partner", premium_only=True, flush=True, order=3
)
SLDomain.create(
domain="free_partner",
premium_only=False,
flush=True,
partner_id=get_proton_partner().id,
order=2,
)
SLDomain.create(
domain="premium_partner",
premium_only=True,
flush=True,
partner_id=get_proton_partner().id,
order=1,
)
Session.commit()
def teardown_module():
Session.query(SLDomain).delete()
add_sl_domains()
def test_get_default_domain_even_if_is_not_allowed():
user = create_new_user()
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
user.trial_end = None
default_domain = SLDomain.filter_by(
hidden=False, partner_id=None, premium_only=False
).first()
user.default_alias_public_domain_id = default_domain.id
Session.flush()
options = AliasOptions(
show_sl_domains=False, show_partner_domains=get_proton_partner()
)
suffixes = get_alias_suffixes(user, alias_options=options)
assert suffixes[0].domain == default_domain.domain
def test_suffixes_are_valid():
user = create_new_user()
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
CustomDomain.create(
user_id=user.id, domain=f"{random_token(10)}.com", verified=True
)
user.trial_end = None
Session.flush()
options = AliasOptions(
show_sl_domains=True, show_partner_domains=get_proton_partner()
)
alias_suffixes = get_alias_suffixes(user, alias_options=options)
valid_re = re.compile(r"^(\.[\w_]+)?@[\.\w]+$")
has_prefix = 0
for suffix in alias_suffixes:
match = valid_re.match(suffix.suffix)
assert match is not None
if len(match.groups()) >= 1:
has_prefix += 1
assert has_prefix > 0

View File

@ -11,7 +11,7 @@ from app.models import (
CoinbaseSubscription, CoinbaseSubscription,
ManualSubscription, ManualSubscription,
) )
from tests.utils import create_new_user from tests.utils import create_new_user, random_token
from app.subscription_webhook import execute_subscription_webhook from app.subscription_webhook import execute_subscription_webhook
@ -57,7 +57,7 @@ def test_webhook_with_subscription():
user_id=user.id, user_id=user.id,
cancel_url="", cancel_url="",
update_url="", update_url="",
subscription_id="", subscription_id=random_token(10),
event_time=arrow.now(), event_time=arrow.now(),
next_bill_date=end_at.date(), next_bill_date=end_at.date(),
plan="yearly", plan="yearly",
@ -76,7 +76,7 @@ def test_webhook_with_apple_subscription():
user_id=user.id, user_id=user.id,
receipt_data=arrow.now().date().strftime("%Y-%m-%d"), receipt_data=arrow.now().date().strftime("%Y-%m-%d"),
expires_date=end_at.date().strftime("%Y-%m-%d"), expires_date=end_at.date().strftime("%Y-%m-%d"),
original_transaction_id="", original_transaction_id=random_token(10),
plan="yearly", plan="yearly",
product_id="", product_id="",
flush=True, flush=True,