mirror of
https://github.com/simple-login/app.git
synced 2024-11-13 07:31:12 +01:00
Move set default domain for alias to an external function (#2158)
* Move set default alias to a separate method to reuse it * Add tests * Find domains by domain not by id * Revert models and setting changes * Remove non required function
This commit is contained in:
parent
71ce0f6253
commit
a7aec0c37a
7 changed files with 204 additions and 44 deletions
|
@ -64,8 +64,12 @@ def verify_prefix_suffix(
|
|||
# SimpleLogin domain case:
|
||||
# 1) alias_suffix must start with "." and
|
||||
# 2) alias_domain_prefix must come from the word list
|
||||
available_sl_domains = [
|
||||
sl_domain.domain
|
||||
for sl_domain in user.get_sl_domains(alias_options=alias_options)
|
||||
]
|
||||
if (
|
||||
alias_domain in user.available_sl_domains(alias_options=alias_options)
|
||||
alias_domain in available_sl_domains
|
||||
and alias_domain not in user_custom_domains
|
||||
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
|
||||
and not config.DISABLE_ALIAS_SUFFIX
|
||||
|
@ -80,9 +84,7 @@ def verify_prefix_suffix(
|
|||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||
return False
|
||||
|
||||
if alias_domain not in user.available_sl_domains(
|
||||
alias_options=alias_options
|
||||
):
|
||||
if alias_domain not in available_sl_domains:
|
||||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||
return False
|
||||
|
||||
|
|
|
@ -14,7 +14,7 @@ from flask_wtf import FlaskForm
|
|||
from flask_wtf.file import FileField
|
||||
from wtforms import StringField, validators
|
||||
|
||||
from app import s3
|
||||
from app import s3, user_settings
|
||||
from app.config import (
|
||||
FIRST_ALIAS_DOMAIN,
|
||||
ALIAS_RANDOM_SUFFIX_LENGTH,
|
||||
|
@ -31,12 +31,10 @@ from app.models import (
|
|||
PlanEnum,
|
||||
File,
|
||||
EmailChange,
|
||||
CustomDomain,
|
||||
AliasGeneratorEnum,
|
||||
AliasSuffixEnum,
|
||||
ManualSubscription,
|
||||
SenderFormatEnum,
|
||||
SLDomain,
|
||||
CoinbaseSubscription,
|
||||
AppleSubscription,
|
||||
PartnerUser,
|
||||
|
@ -166,38 +164,11 @@ def setting():
|
|||
return redirect(url_for("dashboard.setting"))
|
||||
elif request.form.get("form-name") == "change-random-alias-default-domain":
|
||||
default_domain = request.form.get("random-alias-default-domain")
|
||||
|
||||
if default_domain:
|
||||
sl_domain: SLDomain = SLDomain.get_by(domain=default_domain)
|
||||
if sl_domain:
|
||||
if sl_domain.premium_only and not current_user.is_premium():
|
||||
flash("You cannot use this domain", "error")
|
||||
return redirect(url_for("dashboard.setting"))
|
||||
|
||||
current_user.default_alias_public_domain_id = sl_domain.id
|
||||
current_user.default_alias_custom_domain_id = None
|
||||
else:
|
||||
custom_domain = CustomDomain.get_by(domain=default_domain)
|
||||
if custom_domain:
|
||||
# sanity check
|
||||
if (
|
||||
custom_domain.user_id != current_user.id
|
||||
or not custom_domain.verified
|
||||
):
|
||||
LOG.w(
|
||||
"%s cannot use domain %s", current_user, custom_domain
|
||||
)
|
||||
flash(f"Domain {default_domain} can't be used", "error")
|
||||
return redirect(request.url)
|
||||
else:
|
||||
current_user.default_alias_custom_domain_id = (
|
||||
custom_domain.id
|
||||
)
|
||||
current_user.default_alias_public_domain_id = None
|
||||
|
||||
else:
|
||||
current_user.default_alias_custom_domain_id = None
|
||||
current_user.default_alias_public_domain_id = None
|
||||
try:
|
||||
user_settings.set_default_alias_id(current_user, default_domain)
|
||||
except user_settings.CannotSetAlias as e:
|
||||
flash(e.msg, "error")
|
||||
return redirect(url_for("dashboard.setting"))
|
||||
|
||||
Session.commit()
|
||||
flash("Your preference has been updated", "success")
|
||||
|
|
|
@ -985,8 +985,8 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
- the domain
|
||||
"""
|
||||
res = []
|
||||
for domain in self.available_sl_domains(alias_options=alias_options):
|
||||
res.append((True, domain))
|
||||
for domain in self.get_sl_domains(alias_options=alias_options):
|
||||
res.append((True, domain.domain))
|
||||
|
||||
for custom_domain in self.verified_custom_domains():
|
||||
res.append((False, custom_domain.domain))
|
||||
|
@ -1128,7 +1128,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||
- Verified custom domains
|
||||
|
||||
"""
|
||||
domains = self.available_sl_domains(alias_options=alias_options)
|
||||
domains = [
|
||||
sl_domain.domain
|
||||
for sl_domain in self.get_sl_domains(alias_options=alias_options)
|
||||
]
|
||||
|
||||
for custom_domain in self.verified_custom_domains():
|
||||
domains.append(custom_domain.domain)
|
||||
|
@ -2483,7 +2486,7 @@ class CustomDomain(Base, ModelMixin):
|
|||
return sorted(self._auto_create_rules, key=lambda rule: rule.order)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<Custom Domain {self.domain}>"
|
||||
return f"<Custom Domain {self.id} {self.domain}>"
|
||||
|
||||
|
||||
class AutoCreateRule(Base, ModelMixin):
|
||||
|
@ -3114,7 +3117,7 @@ class SLDomain(Base, ModelMixin):
|
|||
)
|
||||
|
||||
def __repr__(self):
|
||||
return f"<SLDomain {self.domain} {'Premium' if self.premium_only else 'Free'}"
|
||||
return f"<SLDomain {self.id} {self.domain} {'Premium' if self.premium_only else 'Free'}>"
|
||||
|
||||
|
||||
class Monitoring(Base, ModelMixin):
|
||||
|
|
47
app/user_settings.py
Normal file
47
app/user_settings.py
Normal file
|
@ -0,0 +1,47 @@
|
|||
from typing import Optional
|
||||
|
||||
from app.db import Session
|
||||
from app.log import LOG
|
||||
from app.models import User, SLDomain, CustomDomain
|
||||
|
||||
|
||||
class CannotSetAlias(Exception):
|
||||
def __init__(self, msg: str):
|
||||
self.msg = msg
|
||||
|
||||
|
||||
def set_default_alias_id(user: User, domain_name: Optional[str]):
|
||||
if domain_name is None:
|
||||
LOG.i(f"User {user} has set no domain as default domain")
|
||||
user.default_alias_public_domain_id = None
|
||||
user.default_alias_custom_domain_id = None
|
||||
Session.flush()
|
||||
return
|
||||
sl_domain: SLDomain = SLDomain.get_by(domain=domain_name)
|
||||
if sl_domain:
|
||||
if sl_domain.hidden:
|
||||
LOG.i(f"User {user} has tried to set up a hidden domain as default domain")
|
||||
raise CannotSetAlias("Domain does not exist")
|
||||
if sl_domain.premium_only and not user.is_premium():
|
||||
LOG.i(f"User {user} has tried to set up a premium domain as default domain")
|
||||
raise CannotSetAlias("You cannot use this domain")
|
||||
LOG.i(f"User {user} has set public {sl_domain} as default domain")
|
||||
user.default_alias_public_domain_id = sl_domain.id
|
||||
user.default_alias_custom_domain_id = None
|
||||
Session.flush()
|
||||
return
|
||||
custom_domain = CustomDomain.get_by(domain=domain_name)
|
||||
if not custom_domain:
|
||||
LOG.i(
|
||||
f"User {user} has tried to set up an non existing domain as default domain"
|
||||
)
|
||||
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
|
||||
if custom_domain.user_id != user.id or not custom_domain.verified:
|
||||
LOG.i(
|
||||
f"User {user} has tried to set domain {custom_domain} as default domain that does not belong to the user or that is not verified"
|
||||
)
|
||||
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
|
||||
LOG.i(f"User {user} has set custom {custom_domain} as default domain")
|
||||
user.default_alias_public_domain_id = None
|
||||
user.default_alias_custom_domain_id = custom_domain.id
|
||||
Session.flush()
|
|
@ -44,6 +44,9 @@ def test_update_settings_alias_generator(flask_client):
|
|||
|
||||
def test_update_settings_random_alias_default_domain(flask_client):
|
||||
user = login(flask_client)
|
||||
custom_domain = CustomDomain.create(
|
||||
domain=random_domain(), verified=True, user_id=user.id, flush=True
|
||||
)
|
||||
assert user.default_random_alias_domain() == "sl.local"
|
||||
|
||||
r = flask_client.patch(
|
||||
|
@ -57,6 +60,12 @@ def test_update_settings_random_alias_default_domain(flask_client):
|
|||
assert r.status_code == 200
|
||||
assert user.default_random_alias_domain() == "d1.test"
|
||||
|
||||
r = flask_client.patch(
|
||||
"/api/setting", json={"random_alias_default_domain": custom_domain.domain}
|
||||
)
|
||||
assert r.status_code == 200
|
||||
assert user.default_random_alias_domain() == custom_domain.domain
|
||||
|
||||
|
||||
def test_update_settings_sender_format(flask_client):
|
||||
user = login(flask_client)
|
||||
|
|
0
tests/user_settings/__init__.py
Normal file
0
tests/user_settings/__init__.py
Normal file
128
tests/user_settings/test_set_default_alias_domain.py
Normal file
128
tests/user_settings/test_set_default_alias_domain.py
Normal file
|
@ -0,0 +1,128 @@
|
|||
import pytest
|
||||
|
||||
from app import user_settings
|
||||
from app.db import Session
|
||||
from app.models import User, CustomDomain, SLDomain
|
||||
from tests.utils import random_token, create_new_user
|
||||
|
||||
user_id: int = 0
|
||||
custom_domain_name: str = ""
|
||||
sl_domain_name: str = ""
|
||||
|
||||
|
||||
def setup_module():
|
||||
global user_id, custom_domain_name, sl_domain_name
|
||||
user = create_new_user()
|
||||
user.trial_end = None
|
||||
user_id = user.id
|
||||
custom_domain_name = CustomDomain.create(
|
||||
user_id=user_id,
|
||||
catch_all=True,
|
||||
domain=random_token() + ".com",
|
||||
verified=True,
|
||||
flush=True,
|
||||
).domain
|
||||
sl_domain_name = SLDomain.create(
|
||||
domain=random_token() + ".com",
|
||||
premium_only=False,
|
||||
flush=True,
|
||||
order=5,
|
||||
hidden=False,
|
||||
).domain
|
||||
|
||||
|
||||
def test_set_default_no_domain():
|
||||
user = User.get(user_id)
|
||||
user.default_alias_public_domain_id = SLDomain.get_by(domain=sl_domain_name).id
|
||||
user.default_alias_private_domain_id = CustomDomain.get_by(
|
||||
domain=custom_domain_name
|
||||
).id
|
||||
Session.flush()
|
||||
user_settings.set_default_alias_id(user, None)
|
||||
assert user.default_alias_public_domain_id is None
|
||||
assert user.default_alias_custom_domain_id is None
|
||||
|
||||
|
||||
def test_set_premium_sl_domain_with_non_premium_user():
|
||||
user = User.get(user_id)
|
||||
user.lifetime = False
|
||||
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||
domain.premium_only = True
|
||||
Session.flush()
|
||||
with pytest.raises(user_settings.CannotSetAlias):
|
||||
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||
|
||||
|
||||
def test_set_hidden_sl_domain():
|
||||
user = User.get(user_id)
|
||||
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||
domain.hidden = True
|
||||
domain.premium_only = False
|
||||
Session.flush()
|
||||
with pytest.raises(user_settings.CannotSetAlias):
|
||||
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||
|
||||
|
||||
def test_set_sl_domain():
|
||||
user = User.get(user_id)
|
||||
user.lifetime = False
|
||||
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||
domain.hidden = False
|
||||
domain.premium_only = False
|
||||
Session.flush()
|
||||
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||
assert user.default_alias_public_domain_id == domain.id
|
||||
assert user.default_alias_custom_domain_id is None
|
||||
|
||||
|
||||
def test_set_sl_premium_domain():
|
||||
user = User.get(user_id)
|
||||
user.lifetime = True
|
||||
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||
domain.hidden = False
|
||||
domain.premium_only = True
|
||||
Session.flush()
|
||||
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||
assert user.default_alias_public_domain_id == domain.id
|
||||
assert user.default_alias_custom_domain_id is None
|
||||
|
||||
|
||||
def test_set_other_user_custom_domain():
|
||||
user = User.get(user_id)
|
||||
user.lifetime = True
|
||||
other_user_domain_name = CustomDomain.create(
|
||||
user_id=create_new_user().id,
|
||||
catch_all=True,
|
||||
domain=random_token() + ".com",
|
||||
verified=True,
|
||||
).domain
|
||||
Session.flush()
|
||||
with pytest.raises(user_settings.CannotSetAlias):
|
||||
user_settings.set_default_alias_id(user, other_user_domain_name)
|
||||
|
||||
|
||||
def test_set_unverified_custom_domain():
|
||||
user = User.get(user_id)
|
||||
user.lifetime = True
|
||||
domain = CustomDomain.get_by(domain=custom_domain_name)
|
||||
domain.verified = False
|
||||
Session.flush()
|
||||
with pytest.raises(user_settings.CannotSetAlias):
|
||||
user_settings.set_default_alias_id(user, custom_domain_name)
|
||||
|
||||
|
||||
def test_set_custom_domain():
|
||||
user = User.get(user_id)
|
||||
user.lifetime = True
|
||||
domain = CustomDomain.get_by(domain=custom_domain_name)
|
||||
domain.verified = True
|
||||
Session.flush()
|
||||
user_settings.set_default_alias_id(user, custom_domain_name)
|
||||
assert user.default_alias_public_domain_id is None
|
||||
assert user.default_alias_custom_domain_id == domain.id
|
||||
|
||||
|
||||
def test_set_invalid_custom_domain():
|
||||
user = User.get(user_id)
|
||||
with pytest.raises(user_settings.CannotSetAlias):
|
||||
user_settings.set_default_alias_id(user, "invalid_nop" + random_token())
|
Loading…
Reference in a new issue