mirror of
https://github.com/simple-login/app.git
synced 2024-11-10 21:27:10 +01:00
Move set default domain for alias to an external function (#2158)
* Move set default alias to a separate method to reuse it * Add tests * Find domains by domain not by id * Revert models and setting changes * Remove non required function
This commit is contained in:
parent
71ce0f6253
commit
a7aec0c37a
@ -64,8 +64,12 @@ def verify_prefix_suffix(
|
|||||||
# SimpleLogin domain case:
|
# SimpleLogin domain case:
|
||||||
# 1) alias_suffix must start with "." and
|
# 1) alias_suffix must start with "." and
|
||||||
# 2) alias_domain_prefix must come from the word list
|
# 2) alias_domain_prefix must come from the word list
|
||||||
|
available_sl_domains = [
|
||||||
|
sl_domain.domain
|
||||||
|
for sl_domain in user.get_sl_domains(alias_options=alias_options)
|
||||||
|
]
|
||||||
if (
|
if (
|
||||||
alias_domain in user.available_sl_domains(alias_options=alias_options)
|
alias_domain in available_sl_domains
|
||||||
and alias_domain not in user_custom_domains
|
and alias_domain not in user_custom_domains
|
||||||
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
|
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
|
||||||
and not config.DISABLE_ALIAS_SUFFIX
|
and not config.DISABLE_ALIAS_SUFFIX
|
||||||
@ -80,9 +84,7 @@ def verify_prefix_suffix(
|
|||||||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
if alias_domain not in user.available_sl_domains(
|
if alias_domain not in available_sl_domains:
|
||||||
alias_options=alias_options
|
|
||||||
):
|
|
||||||
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
|
||||||
return False
|
return False
|
||||||
|
|
||||||
|
@ -14,7 +14,7 @@ from flask_wtf import FlaskForm
|
|||||||
from flask_wtf.file import FileField
|
from flask_wtf.file import FileField
|
||||||
from wtforms import StringField, validators
|
from wtforms import StringField, validators
|
||||||
|
|
||||||
from app import s3
|
from app import s3, user_settings
|
||||||
from app.config import (
|
from app.config import (
|
||||||
FIRST_ALIAS_DOMAIN,
|
FIRST_ALIAS_DOMAIN,
|
||||||
ALIAS_RANDOM_SUFFIX_LENGTH,
|
ALIAS_RANDOM_SUFFIX_LENGTH,
|
||||||
@ -31,12 +31,10 @@ from app.models import (
|
|||||||
PlanEnum,
|
PlanEnum,
|
||||||
File,
|
File,
|
||||||
EmailChange,
|
EmailChange,
|
||||||
CustomDomain,
|
|
||||||
AliasGeneratorEnum,
|
AliasGeneratorEnum,
|
||||||
AliasSuffixEnum,
|
AliasSuffixEnum,
|
||||||
ManualSubscription,
|
ManualSubscription,
|
||||||
SenderFormatEnum,
|
SenderFormatEnum,
|
||||||
SLDomain,
|
|
||||||
CoinbaseSubscription,
|
CoinbaseSubscription,
|
||||||
AppleSubscription,
|
AppleSubscription,
|
||||||
PartnerUser,
|
PartnerUser,
|
||||||
@ -166,38 +164,11 @@ def setting():
|
|||||||
return redirect(url_for("dashboard.setting"))
|
return redirect(url_for("dashboard.setting"))
|
||||||
elif request.form.get("form-name") == "change-random-alias-default-domain":
|
elif request.form.get("form-name") == "change-random-alias-default-domain":
|
||||||
default_domain = request.form.get("random-alias-default-domain")
|
default_domain = request.form.get("random-alias-default-domain")
|
||||||
|
try:
|
||||||
if default_domain:
|
user_settings.set_default_alias_id(current_user, default_domain)
|
||||||
sl_domain: SLDomain = SLDomain.get_by(domain=default_domain)
|
except user_settings.CannotSetAlias as e:
|
||||||
if sl_domain:
|
flash(e.msg, "error")
|
||||||
if sl_domain.premium_only and not current_user.is_premium():
|
return redirect(url_for("dashboard.setting"))
|
||||||
flash("You cannot use this domain", "error")
|
|
||||||
return redirect(url_for("dashboard.setting"))
|
|
||||||
|
|
||||||
current_user.default_alias_public_domain_id = sl_domain.id
|
|
||||||
current_user.default_alias_custom_domain_id = None
|
|
||||||
else:
|
|
||||||
custom_domain = CustomDomain.get_by(domain=default_domain)
|
|
||||||
if custom_domain:
|
|
||||||
# sanity check
|
|
||||||
if (
|
|
||||||
custom_domain.user_id != current_user.id
|
|
||||||
or not custom_domain.verified
|
|
||||||
):
|
|
||||||
LOG.w(
|
|
||||||
"%s cannot use domain %s", current_user, custom_domain
|
|
||||||
)
|
|
||||||
flash(f"Domain {default_domain} can't be used", "error")
|
|
||||||
return redirect(request.url)
|
|
||||||
else:
|
|
||||||
current_user.default_alias_custom_domain_id = (
|
|
||||||
custom_domain.id
|
|
||||||
)
|
|
||||||
current_user.default_alias_public_domain_id = None
|
|
||||||
|
|
||||||
else:
|
|
||||||
current_user.default_alias_custom_domain_id = None
|
|
||||||
current_user.default_alias_public_domain_id = None
|
|
||||||
|
|
||||||
Session.commit()
|
Session.commit()
|
||||||
flash("Your preference has been updated", "success")
|
flash("Your preference has been updated", "success")
|
||||||
|
@ -985,8 +985,8 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||||||
- the domain
|
- the domain
|
||||||
"""
|
"""
|
||||||
res = []
|
res = []
|
||||||
for domain in self.available_sl_domains(alias_options=alias_options):
|
for domain in self.get_sl_domains(alias_options=alias_options):
|
||||||
res.append((True, domain))
|
res.append((True, domain.domain))
|
||||||
|
|
||||||
for custom_domain in self.verified_custom_domains():
|
for custom_domain in self.verified_custom_domains():
|
||||||
res.append((False, custom_domain.domain))
|
res.append((False, custom_domain.domain))
|
||||||
@ -1128,7 +1128,10 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
|||||||
- Verified custom domains
|
- Verified custom domains
|
||||||
|
|
||||||
"""
|
"""
|
||||||
domains = self.available_sl_domains(alias_options=alias_options)
|
domains = [
|
||||||
|
sl_domain.domain
|
||||||
|
for sl_domain in self.get_sl_domains(alias_options=alias_options)
|
||||||
|
]
|
||||||
|
|
||||||
for custom_domain in self.verified_custom_domains():
|
for custom_domain in self.verified_custom_domains():
|
||||||
domains.append(custom_domain.domain)
|
domains.append(custom_domain.domain)
|
||||||
@ -2483,7 +2486,7 @@ class CustomDomain(Base, ModelMixin):
|
|||||||
return sorted(self._auto_create_rules, key=lambda rule: rule.order)
|
return sorted(self._auto_create_rules, key=lambda rule: rule.order)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<Custom Domain {self.domain}>"
|
return f"<Custom Domain {self.id} {self.domain}>"
|
||||||
|
|
||||||
|
|
||||||
class AutoCreateRule(Base, ModelMixin):
|
class AutoCreateRule(Base, ModelMixin):
|
||||||
@ -3114,7 +3117,7 @@ class SLDomain(Base, ModelMixin):
|
|||||||
)
|
)
|
||||||
|
|
||||||
def __repr__(self):
|
def __repr__(self):
|
||||||
return f"<SLDomain {self.domain} {'Premium' if self.premium_only else 'Free'}"
|
return f"<SLDomain {self.id} {self.domain} {'Premium' if self.premium_only else 'Free'}>"
|
||||||
|
|
||||||
|
|
||||||
class Monitoring(Base, ModelMixin):
|
class Monitoring(Base, ModelMixin):
|
||||||
|
47
app/user_settings.py
Normal file
47
app/user_settings.py
Normal file
@ -0,0 +1,47 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.db import Session
|
||||||
|
from app.log import LOG
|
||||||
|
from app.models import User, SLDomain, CustomDomain
|
||||||
|
|
||||||
|
|
||||||
|
class CannotSetAlias(Exception):
|
||||||
|
def __init__(self, msg: str):
|
||||||
|
self.msg = msg
|
||||||
|
|
||||||
|
|
||||||
|
def set_default_alias_id(user: User, domain_name: Optional[str]):
|
||||||
|
if domain_name is None:
|
||||||
|
LOG.i(f"User {user} has set no domain as default domain")
|
||||||
|
user.default_alias_public_domain_id = None
|
||||||
|
user.default_alias_custom_domain_id = None
|
||||||
|
Session.flush()
|
||||||
|
return
|
||||||
|
sl_domain: SLDomain = SLDomain.get_by(domain=domain_name)
|
||||||
|
if sl_domain:
|
||||||
|
if sl_domain.hidden:
|
||||||
|
LOG.i(f"User {user} has tried to set up a hidden domain as default domain")
|
||||||
|
raise CannotSetAlias("Domain does not exist")
|
||||||
|
if sl_domain.premium_only and not user.is_premium():
|
||||||
|
LOG.i(f"User {user} has tried to set up a premium domain as default domain")
|
||||||
|
raise CannotSetAlias("You cannot use this domain")
|
||||||
|
LOG.i(f"User {user} has set public {sl_domain} as default domain")
|
||||||
|
user.default_alias_public_domain_id = sl_domain.id
|
||||||
|
user.default_alias_custom_domain_id = None
|
||||||
|
Session.flush()
|
||||||
|
return
|
||||||
|
custom_domain = CustomDomain.get_by(domain=domain_name)
|
||||||
|
if not custom_domain:
|
||||||
|
LOG.i(
|
||||||
|
f"User {user} has tried to set up an non existing domain as default domain"
|
||||||
|
)
|
||||||
|
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
|
||||||
|
if custom_domain.user_id != user.id or not custom_domain.verified:
|
||||||
|
LOG.i(
|
||||||
|
f"User {user} has tried to set domain {custom_domain} as default domain that does not belong to the user or that is not verified"
|
||||||
|
)
|
||||||
|
raise CannotSetAlias("Domain does not exist or it hasn't been verified")
|
||||||
|
LOG.i(f"User {user} has set custom {custom_domain} as default domain")
|
||||||
|
user.default_alias_public_domain_id = None
|
||||||
|
user.default_alias_custom_domain_id = custom_domain.id
|
||||||
|
Session.flush()
|
@ -44,6 +44,9 @@ def test_update_settings_alias_generator(flask_client):
|
|||||||
|
|
||||||
def test_update_settings_random_alias_default_domain(flask_client):
|
def test_update_settings_random_alias_default_domain(flask_client):
|
||||||
user = login(flask_client)
|
user = login(flask_client)
|
||||||
|
custom_domain = CustomDomain.create(
|
||||||
|
domain=random_domain(), verified=True, user_id=user.id, flush=True
|
||||||
|
)
|
||||||
assert user.default_random_alias_domain() == "sl.local"
|
assert user.default_random_alias_domain() == "sl.local"
|
||||||
|
|
||||||
r = flask_client.patch(
|
r = flask_client.patch(
|
||||||
@ -57,6 +60,12 @@ def test_update_settings_random_alias_default_domain(flask_client):
|
|||||||
assert r.status_code == 200
|
assert r.status_code == 200
|
||||||
assert user.default_random_alias_domain() == "d1.test"
|
assert user.default_random_alias_domain() == "d1.test"
|
||||||
|
|
||||||
|
r = flask_client.patch(
|
||||||
|
"/api/setting", json={"random_alias_default_domain": custom_domain.domain}
|
||||||
|
)
|
||||||
|
assert r.status_code == 200
|
||||||
|
assert user.default_random_alias_domain() == custom_domain.domain
|
||||||
|
|
||||||
|
|
||||||
def test_update_settings_sender_format(flask_client):
|
def test_update_settings_sender_format(flask_client):
|
||||||
user = login(flask_client)
|
user = login(flask_client)
|
||||||
|
0
tests/user_settings/__init__.py
Normal file
0
tests/user_settings/__init__.py
Normal file
128
tests/user_settings/test_set_default_alias_domain.py
Normal file
128
tests/user_settings/test_set_default_alias_domain.py
Normal file
@ -0,0 +1,128 @@
|
|||||||
|
import pytest
|
||||||
|
|
||||||
|
from app import user_settings
|
||||||
|
from app.db import Session
|
||||||
|
from app.models import User, CustomDomain, SLDomain
|
||||||
|
from tests.utils import random_token, create_new_user
|
||||||
|
|
||||||
|
user_id: int = 0
|
||||||
|
custom_domain_name: str = ""
|
||||||
|
sl_domain_name: str = ""
|
||||||
|
|
||||||
|
|
||||||
|
def setup_module():
|
||||||
|
global user_id, custom_domain_name, sl_domain_name
|
||||||
|
user = create_new_user()
|
||||||
|
user.trial_end = None
|
||||||
|
user_id = user.id
|
||||||
|
custom_domain_name = CustomDomain.create(
|
||||||
|
user_id=user_id,
|
||||||
|
catch_all=True,
|
||||||
|
domain=random_token() + ".com",
|
||||||
|
verified=True,
|
||||||
|
flush=True,
|
||||||
|
).domain
|
||||||
|
sl_domain_name = SLDomain.create(
|
||||||
|
domain=random_token() + ".com",
|
||||||
|
premium_only=False,
|
||||||
|
flush=True,
|
||||||
|
order=5,
|
||||||
|
hidden=False,
|
||||||
|
).domain
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_default_no_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.default_alias_public_domain_id = SLDomain.get_by(domain=sl_domain_name).id
|
||||||
|
user.default_alias_private_domain_id = CustomDomain.get_by(
|
||||||
|
domain=custom_domain_name
|
||||||
|
).id
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_id(user, None)
|
||||||
|
assert user.default_alias_public_domain_id is None
|
||||||
|
assert user.default_alias_custom_domain_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_premium_sl_domain_with_non_premium_user():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = False
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.premium_only = True
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_hidden_sl_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.hidden = True
|
||||||
|
domain.premium_only = False
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_sl_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = False
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.hidden = False
|
||||||
|
domain.premium_only = False
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||||
|
assert user.default_alias_public_domain_id == domain.id
|
||||||
|
assert user.default_alias_custom_domain_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_sl_premium_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
domain = SLDomain.get_by(domain=sl_domain_name)
|
||||||
|
domain.hidden = False
|
||||||
|
domain.premium_only = True
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_id(user, sl_domain_name)
|
||||||
|
assert user.default_alias_public_domain_id == domain.id
|
||||||
|
assert user.default_alias_custom_domain_id is None
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_other_user_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
other_user_domain_name = CustomDomain.create(
|
||||||
|
user_id=create_new_user().id,
|
||||||
|
catch_all=True,
|
||||||
|
domain=random_token() + ".com",
|
||||||
|
verified=True,
|
||||||
|
).domain
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_id(user, other_user_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_unverified_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
domain = CustomDomain.get_by(domain=custom_domain_name)
|
||||||
|
domain.verified = False
|
||||||
|
Session.flush()
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_id(user, custom_domain_name)
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
user.lifetime = True
|
||||||
|
domain = CustomDomain.get_by(domain=custom_domain_name)
|
||||||
|
domain.verified = True
|
||||||
|
Session.flush()
|
||||||
|
user_settings.set_default_alias_id(user, custom_domain_name)
|
||||||
|
assert user.default_alias_public_domain_id is None
|
||||||
|
assert user.default_alias_custom_domain_id == domain.id
|
||||||
|
|
||||||
|
|
||||||
|
def test_set_invalid_custom_domain():
|
||||||
|
user = User.get(user_id)
|
||||||
|
with pytest.raises(user_settings.CannotSetAlias):
|
||||||
|
user_settings.set_default_alias_id(user, "invalid_nop" + random_token())
|
Loading…
Reference in New Issue
Block a user