From a7aec0c37a784984f8f9f281700d842b5cf04360 Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Casaj=C3=BAs?= Date: Tue, 23 Jul 2024 16:17:23 +0200 Subject: [PATCH] Move set default domain for alias to an external function (#2158) * Move set default alias to a separate method to reuse it * Add tests * Find domains by domain not by id * Revert models and setting changes * Remove non required function --- app/alias_suffix.py | 10 +- app/dashboard/views/setting.py | 41 +----- app/models.py | 13 +- app/user_settings.py | 47 +++++++ tests/api/test_setting.py | 9 ++ tests/user_settings/__init__.py | 0 .../test_set_default_alias_domain.py | 128 ++++++++++++++++++ 7 files changed, 204 insertions(+), 44 deletions(-) create mode 100644 app/user_settings.py create mode 100644 tests/user_settings/__init__.py create mode 100644 tests/user_settings/test_set_default_alias_domain.py diff --git a/app/alias_suffix.py b/app/alias_suffix.py index 8cfc98aa..fbcbff20 100644 --- a/app/alias_suffix.py +++ b/app/alias_suffix.py @@ -64,8 +64,12 @@ def verify_prefix_suffix( # SimpleLogin domain case: # 1) alias_suffix must start with "." and # 2) alias_domain_prefix must come from the word list + available_sl_domains = [ + sl_domain.domain + for sl_domain in user.get_sl_domains(alias_options=alias_options) + ] if ( - alias_domain in user.available_sl_domains(alias_options=alias_options) + alias_domain in available_sl_domains and alias_domain not in user_custom_domains # when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty and not config.DISABLE_ALIAS_SUFFIX @@ -80,9 +84,7 @@ def verify_prefix_suffix( LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) return False - if alias_domain not in user.available_sl_domains( - alias_options=alias_options - ): + if alias_domain not in available_sl_domains: LOG.e("wrong alias suffix %s, user %s", alias_suffix, user) return False diff --git a/app/dashboard/views/setting.py b/app/dashboard/views/setting.py index f07ebb17..1b52109e 100644 --- a/app/dashboard/views/setting.py +++ b/app/dashboard/views/setting.py @@ -14,7 +14,7 @@ from flask_wtf import FlaskForm from flask_wtf.file import FileField from wtforms import StringField, validators -from app import s3 +from app import s3, user_settings from app.config import ( FIRST_ALIAS_DOMAIN, ALIAS_RANDOM_SUFFIX_LENGTH, @@ -31,12 +31,10 @@ from app.models import ( PlanEnum, File, EmailChange, - CustomDomain, AliasGeneratorEnum, AliasSuffixEnum, ManualSubscription, SenderFormatEnum, - SLDomain, CoinbaseSubscription, AppleSubscription, PartnerUser, @@ -166,38 +164,11 @@ def setting(): return redirect(url_for("dashboard.setting")) elif request.form.get("form-name") == "change-random-alias-default-domain": default_domain = request.form.get("random-alias-default-domain") - - if default_domain: - sl_domain: SLDomain = SLDomain.get_by(domain=default_domain) - if sl_domain: - if sl_domain.premium_only and not current_user.is_premium(): - flash("You cannot use this domain", "error") - return redirect(url_for("dashboard.setting")) - - current_user.default_alias_public_domain_id = sl_domain.id - current_user.default_alias_custom_domain_id = None - else: - custom_domain = CustomDomain.get_by(domain=default_domain) - if custom_domain: - # sanity check - if ( - custom_domain.user_id != current_user.id - or not custom_domain.verified - ): - LOG.w( - "%s cannot use domain %s", current_user, custom_domain - ) - flash(f"Domain {default_domain} can't be used", "error") - return redirect(request.url) - else: - current_user.default_alias_custom_domain_id = ( - custom_domain.id - ) - current_user.default_alias_public_domain_id = None - - else: - current_user.default_alias_custom_domain_id = None - current_user.default_alias_public_domain_id = None + try: + user_settings.set_default_alias_id(current_user, default_domain) + except user_settings.CannotSetAlias as e: + flash(e.msg, "error") + return redirect(url_for("dashboard.setting")) Session.commit() flash("Your preference has been updated", "success") diff --git a/app/models.py b/app/models.py index 4701d483..e5297aa7 100644 --- a/app/models.py +++ b/app/models.py @@ -985,8 +985,8 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): - the domain """ res = [] - for domain in self.available_sl_domains(alias_options=alias_options): - res.append((True, domain)) + for domain in self.get_sl_domains(alias_options=alias_options): + res.append((True, domain.domain)) for custom_domain in self.verified_custom_domains(): res.append((False, custom_domain.domain)) @@ -1128,7 +1128,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): - Verified custom domains """ - domains = self.available_sl_domains(alias_options=alias_options) + domains = [ + sl_domain.domain + for sl_domain in self.get_sl_domains(alias_options=alias_options) + ] for custom_domain in self.verified_custom_domains(): domains.append(custom_domain.domain) @@ -2483,7 +2486,7 @@ class CustomDomain(Base, ModelMixin): return sorted(self._auto_create_rules, key=lambda rule: rule.order) def __repr__(self): - return f"" + return f"" class AutoCreateRule(Base, ModelMixin): @@ -3114,7 +3117,7 @@ class SLDomain(Base, ModelMixin): ) def __repr__(self): - return f"" class Monitoring(Base, ModelMixin): diff --git a/app/user_settings.py b/app/user_settings.py new file mode 100644 index 00000000..8e97d582 --- /dev/null +++ b/app/user_settings.py @@ -0,0 +1,47 @@ +from typing import Optional + +from app.db import Session +from app.log import LOG +from app.models import User, SLDomain, CustomDomain + + +class CannotSetAlias(Exception): + def __init__(self, msg: str): + self.msg = msg + + +def set_default_alias_id(user: User, domain_name: Optional[str]): + if domain_name is None: + LOG.i(f"User {user} has set no domain as default domain") + user.default_alias_public_domain_id = None + user.default_alias_custom_domain_id = None + Session.flush() + return + sl_domain: SLDomain = SLDomain.get_by(domain=domain_name) + if sl_domain: + if sl_domain.hidden: + LOG.i(f"User {user} has tried to set up a hidden domain as default domain") + raise CannotSetAlias("Domain does not exist") + if sl_domain.premium_only and not user.is_premium(): + LOG.i(f"User {user} has tried to set up a premium domain as default domain") + raise CannotSetAlias("You cannot use this domain") + LOG.i(f"User {user} has set public {sl_domain} as default domain") + user.default_alias_public_domain_id = sl_domain.id + user.default_alias_custom_domain_id = None + Session.flush() + return + custom_domain = CustomDomain.get_by(domain=domain_name) + if not custom_domain: + LOG.i( + f"User {user} has tried to set up an non existing domain as default domain" + ) + raise CannotSetAlias("Domain does not exist or it hasn't been verified") + if custom_domain.user_id != user.id or not custom_domain.verified: + LOG.i( + f"User {user} has tried to set domain {custom_domain} as default domain that does not belong to the user or that is not verified" + ) + raise CannotSetAlias("Domain does not exist or it hasn't been verified") + LOG.i(f"User {user} has set custom {custom_domain} as default domain") + user.default_alias_public_domain_id = None + user.default_alias_custom_domain_id = custom_domain.id + Session.flush() diff --git a/tests/api/test_setting.py b/tests/api/test_setting.py index 199a8409..698596ee 100644 --- a/tests/api/test_setting.py +++ b/tests/api/test_setting.py @@ -44,6 +44,9 @@ def test_update_settings_alias_generator(flask_client): def test_update_settings_random_alias_default_domain(flask_client): user = login(flask_client) + custom_domain = CustomDomain.create( + domain=random_domain(), verified=True, user_id=user.id, flush=True + ) assert user.default_random_alias_domain() == "sl.local" r = flask_client.patch( @@ -57,6 +60,12 @@ def test_update_settings_random_alias_default_domain(flask_client): assert r.status_code == 200 assert user.default_random_alias_domain() == "d1.test" + r = flask_client.patch( + "/api/setting", json={"random_alias_default_domain": custom_domain.domain} + ) + assert r.status_code == 200 + assert user.default_random_alias_domain() == custom_domain.domain + def test_update_settings_sender_format(flask_client): user = login(flask_client) diff --git a/tests/user_settings/__init__.py b/tests/user_settings/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/user_settings/test_set_default_alias_domain.py b/tests/user_settings/test_set_default_alias_domain.py new file mode 100644 index 00000000..043be863 --- /dev/null +++ b/tests/user_settings/test_set_default_alias_domain.py @@ -0,0 +1,128 @@ +import pytest + +from app import user_settings +from app.db import Session +from app.models import User, CustomDomain, SLDomain +from tests.utils import random_token, create_new_user + +user_id: int = 0 +custom_domain_name: str = "" +sl_domain_name: str = "" + + +def setup_module(): + global user_id, custom_domain_name, sl_domain_name + user = create_new_user() + user.trial_end = None + user_id = user.id + custom_domain_name = CustomDomain.create( + user_id=user_id, + catch_all=True, + domain=random_token() + ".com", + verified=True, + flush=True, + ).domain + sl_domain_name = SLDomain.create( + domain=random_token() + ".com", + premium_only=False, + flush=True, + order=5, + hidden=False, + ).domain + + +def test_set_default_no_domain(): + user = User.get(user_id) + user.default_alias_public_domain_id = SLDomain.get_by(domain=sl_domain_name).id + user.default_alias_private_domain_id = CustomDomain.get_by( + domain=custom_domain_name + ).id + Session.flush() + user_settings.set_default_alias_id(user, None) + assert user.default_alias_public_domain_id is None + assert user.default_alias_custom_domain_id is None + + +def test_set_premium_sl_domain_with_non_premium_user(): + user = User.get(user_id) + user.lifetime = False + domain = SLDomain.get_by(domain=sl_domain_name) + domain.premium_only = True + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_id(user, sl_domain_name) + + +def test_set_hidden_sl_domain(): + user = User.get(user_id) + domain = SLDomain.get_by(domain=sl_domain_name) + domain.hidden = True + domain.premium_only = False + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_id(user, sl_domain_name) + + +def test_set_sl_domain(): + user = User.get(user_id) + user.lifetime = False + domain = SLDomain.get_by(domain=sl_domain_name) + domain.hidden = False + domain.premium_only = False + Session.flush() + user_settings.set_default_alias_id(user, sl_domain_name) + assert user.default_alias_public_domain_id == domain.id + assert user.default_alias_custom_domain_id is None + + +def test_set_sl_premium_domain(): + user = User.get(user_id) + user.lifetime = True + domain = SLDomain.get_by(domain=sl_domain_name) + domain.hidden = False + domain.premium_only = True + Session.flush() + user_settings.set_default_alias_id(user, sl_domain_name) + assert user.default_alias_public_domain_id == domain.id + assert user.default_alias_custom_domain_id is None + + +def test_set_other_user_custom_domain(): + user = User.get(user_id) + user.lifetime = True + other_user_domain_name = CustomDomain.create( + user_id=create_new_user().id, + catch_all=True, + domain=random_token() + ".com", + verified=True, + ).domain + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_id(user, other_user_domain_name) + + +def test_set_unverified_custom_domain(): + user = User.get(user_id) + user.lifetime = True + domain = CustomDomain.get_by(domain=custom_domain_name) + domain.verified = False + Session.flush() + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_id(user, custom_domain_name) + + +def test_set_custom_domain(): + user = User.get(user_id) + user.lifetime = True + domain = CustomDomain.get_by(domain=custom_domain_name) + domain.verified = True + Session.flush() + user_settings.set_default_alias_id(user, custom_domain_name) + assert user.default_alias_public_domain_id is None + assert user.default_alias_custom_domain_id == domain.id + + +def test_set_invalid_custom_domain(): + user = User.get(user_id) + with pytest.raises(user_settings.CannotSetAlias): + user_settings.set_default_alias_id(user, "invalid_nop" + random_token())