Refactor alias options and add it to more methods (#1681)

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Adrià Casajús 2023-04-06 11:07:13 +02:00 committed by GitHub
parent 43b91cd197
commit b6f79ea3a6
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 72 additions and 36 deletions

View File

@ -6,7 +6,7 @@ from typing import Optional
import itsdangerous
from app import config
from app.log import LOG
from app.models import User, Partner
from app.models import User, AliasOptions
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@ -42,7 +42,9 @@ def check_suffix_signature(signed_suffix: str) -> Optional[str]:
return None
def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
def verify_prefix_suffix(
user: User, alias_prefix, alias_suffix, alias_options: Optional[AliasOptions] = None
) -> bool:
"""verify if user could create an alias with the given prefix and suffix"""
if not alias_prefix or not alias_suffix: # should be caught on frontend
return False
@ -63,7 +65,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
# 1) alias_suffix must start with "." and
# 2) alias_domain_prefix must come from the word list
if (
alias_domain in user.available_sl_domains()
alias_domain in user.available_sl_domains(alias_options=alias_options)
and alias_domain not in user_custom_domains
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
and not config.DISABLE_ALIAS_SUFFIX
@ -79,7 +81,9 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
if alias_domain not in user.available_sl_domains():
if alias_domain not in user.available_sl_domains(
alias_options=alias_options
):
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
@ -87,9 +91,7 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
def get_alias_suffixes(
user: User,
show_domains_for_partner: Optional[Partner] = None,
show_sl_domains: bool = True,
user: User, alias_options: Optional[AliasOptions] = None
) -> [AliasSuffix]:
"""
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
@ -142,7 +144,7 @@ def get_alias_suffixes(
alias_suffixes.append(alias_suffix)
# then SimpleLogin domain
for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains):
for sl_domain in user.get_sl_domains(alias_options=alias_options):
suffix = (
(
""

View File

@ -1,6 +1,7 @@
from __future__ import annotations
import base64
import dataclasses
import enum
import hashlib
import hmac
@ -273,6 +274,12 @@ class IntEnumType(sa.types.TypeDecorator):
return self._enum_type(enum_value)
@dataclasses.dataclass
class AliasOptions:
show_sl_domains: bool = True
show_partner_domains: Optional[Partner] = None
class Hibp(Base, ModelMixin):
__tablename__ = "hibp"
name = sa.Column(sa.String(), nullable=False, unique=True, index=True)
@ -867,14 +874,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def custom_domains(self):
return CustomDomain.filter_by(user_id=self.id, verified=True).all()
def available_domains_for_random_alias(self) -> List[Tuple[bool, str]]:
def available_domains_for_random_alias(
self, alias_options: Optional[AliasOptions] = None
) -> List[Tuple[bool, str]]:
"""Return available domains for user to create random aliases
Each result record contains:
- whether the domain belongs to SimpleLogin
- the domain
"""
res = []
for domain in self.available_sl_domains():
for domain in self.available_sl_domains(alias_options=alias_options):
res.append((True, domain))
for custom_domain in self.verified_custom_domains():
@ -959,32 +968,37 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
return None, "", False
def available_sl_domains(self) -> [str]:
def available_sl_domains(
self, alias_options: Optional[AliasOptions] = None
) -> [str]:
"""
Return all SimpleLogin domains that user can use when creating a new alias, including:
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
"""
return [sl_domain.domain for sl_domain in self.get_sl_domains()]
return [
sl_domain.domain
for sl_domain in self.get_sl_domains(alias_options=alias_options)
]
def get_sl_domains(
self,
show_domains_for_partner: Optional[Partner] = None,
show_sl_domains: bool = True,
self, alias_options: Optional[AliasOptions] = None
) -> list["SLDomain"]:
if alias_options is None:
alias_options = AliasOptions()
conditions = [SLDomain.hidden == False] # noqa: E712
if not self.is_premium():
conditions.append(SLDomain.premium_only == False) # noqa: E712
partner_domain_cond = [] # noqa:E711
if show_domains_for_partner is not None:
if alias_options.show_partner_domains is not None:
partner_user = PartnerUser.filter_by(
user_id=self.id, partner_id=show_domains_for_partner.id
user_id=self.id, partner_id=alias_options.show_partner_domains.id
).first()
if partner_user is not None:
partner_domain_cond.append(
SLDomain.partner_id == partner_user.partner_id
)
if show_sl_domains:
if alias_options.show_sl_domains:
partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711
if len(partner_domain_cond) == 1:
conditions.append(partner_domain_cond[0])
@ -993,14 +1007,16 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order)
return query.all()
def available_alias_domains(self) -> [str]:
def available_alias_domains(
self, alias_options: Optional[AliasOptions] = None
) -> [str]:
"""return all domains that user can use when creating a new alias, including:
- SimpleLogin public domains, available for all users (ALIAS_DOMAIN)
- SimpleLogin premium domains, only available for Premium accounts (PREMIUM_ALIAS_DOMAIN)
- Verified custom domains
"""
domains = self.available_sl_domains()
domains = self.available_sl_domains(alias_options=alias_options)
for custom_domain in self.verified_custom_domains():
domains.append(custom_domain.domain)

View File

@ -1,13 +1,7 @@
from app import config
from app.db import Session
from app.models import User, Job
from tests.utils import create_new_user, random_email
def test_available_sl_domains(flask_client):
user = create_new_user()
assert set(user.available_sl_domains()) == {"d1.test", "d2.test", "sl.local"}
from tests.utils import random_email
def test_create_from_partner(flask_client):

View File

@ -1,5 +1,5 @@
from app.db import Session
from app.models import SLDomain, PartnerUser
from app.models import SLDomain, PartnerUser, AliasOptions
from app.proton.utils import get_proton_partner
from init_app import add_sl_domains
from tests.utils import create_new_user, random_token
@ -43,12 +43,14 @@ def test_get_non_partner_domains():
assert len(domains) == 2
assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
# Free
user.trial_end = None
Session.flush()
domains = user.get_sl_domains()
assert len(domains) == 1
assert domains[0].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
def test_get_free_with_partner_domains():
@ -64,17 +66,28 @@ def test_get_free_with_partner_domains():
# Default
assert len(domains) == 1
assert domains[0].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
# Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
options = AliasOptions(
show_sl_domains=True, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 2
assert domains[0].domain == "free_partner"
assert domains[1].domain == "free_non_partner"
# Only partner domains
domains = user.get_sl_domains(
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
assert [d.domain for d in domains] == user.available_sl_domains(
alias_options=options
)
# Only partner domains
options = AliasOptions(
show_sl_domains=False, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 1
assert domains[0].domain == "free_partner"
assert [d.domain for d in domains] == user.available_sl_domains(
alias_options=options
)
def test_get_premium_with_partner_domains():
@ -90,17 +103,28 @@ def test_get_premium_with_partner_domains():
assert len(domains) == 2
assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner"
assert [d.domain for d in domains] == user.available_sl_domains()
# Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
options = AliasOptions(
show_sl_domains=True, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 4
assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner"
assert domains[2].domain == "premium_non_partner"
assert domains[3].domain == "free_non_partner"
# Only partner domains
domains = user.get_sl_domains(
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
assert [d.domain for d in domains] == user.available_sl_domains(
alias_options=options
)
# Only partner domains
options = AliasOptions(
show_sl_domains=False, show_partner_domains=get_proton_partner()
)
domains = user.get_sl_domains(alias_options=options)
assert len(domains) == 2
assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner"
assert [d.domain for d in domains] == user.available_sl_domains(
alias_options=options
)