mirror of
https://github.com/simple-login/app.git
synced 2024-09-21 01:11:29 +02:00
chore: refactor create custom domain (#2221)
* fix: scripts/new-migration to use poetry again * chore: add migration to add custom_domain.partner_id * chore: refactor create_custom_domain * chore: allow to specify partner_id to custom_domain * refactor: can_use_domain return cause * refactor: remove intermediate result class
This commit is contained in:
parent
647c569f99
commit
065cc3db92
@ -0,0 +1,128 @@
|
||||
import re
|
||||
|
||||
from dataclasses import dataclass
|
||||
from enum import Enum
|
||||
from typing import Optional
|
||||
|
||||
from app.db import Session
|
||||
from app.email_utils import get_email_domain_part
|
||||
from app.log import LOG
|
||||
from app.models import User, CustomDomain, SLDomain, Mailbox
|
||||
|
||||
_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$")
|
||||
|
||||
|
||||
@dataclass
|
||||
class CreateCustomDomainResult:
|
||||
message: str = ""
|
||||
message_category: str = ""
|
||||
success: bool = False
|
||||
instance: Optional[CustomDomain] = None
|
||||
redirect: Optional[str] = None
|
||||
|
||||
|
||||
class CannotUseDomainReason(Enum):
|
||||
InvalidDomain = 1
|
||||
BuiltinDomain = 2
|
||||
DomainAlreadyUsed = 3
|
||||
DomainPartOfUserEmail = 4
|
||||
DomainUserInMailbox = 5
|
||||
|
||||
def message(self, domain: str) -> str:
|
||||
if self == CannotUseDomainReason.InvalidDomain:
|
||||
return "This is not a valid domain"
|
||||
elif self == CannotUseDomainReason.BuiltinDomain:
|
||||
return "A custom domain cannot be a built-in domain."
|
||||
elif self == CannotUseDomainReason.DomainAlreadyUsed:
|
||||
return f"{domain} already used"
|
||||
elif self == CannotUseDomainReason.DomainPartOfUserEmail:
|
||||
return "You cannot add a domain that you are currently using for your personal email. Please change your personal email to your real email"
|
||||
elif self == CannotUseDomainReason.DomainUserInMailbox:
|
||||
return f"{domain} already used in a SimpleLogin mailbox"
|
||||
else:
|
||||
raise Exception("Invalid CannotUseDomainReason")
|
||||
|
||||
|
||||
def is_valid_domain(domain: str) -> bool:
|
||||
"""
|
||||
Checks that a domain is valid according to RFC 1035
|
||||
"""
|
||||
if len(domain) > 255:
|
||||
return False
|
||||
if domain.endswith("."):
|
||||
domain = domain[:-1] # Strip the trailing dot
|
||||
labels = domain.split(".")
|
||||
if not labels:
|
||||
return False
|
||||
for label in labels:
|
||||
if not _ALLOWED_DOMAIN_REGEX.match(label):
|
||||
return False
|
||||
return True
|
||||
|
||||
|
||||
def sanitize_domain(domain: str) -> str:
|
||||
new_domain = domain.lower().strip()
|
||||
if new_domain.startswith("http://"):
|
||||
new_domain = new_domain[len("http://") :]
|
||||
|
||||
if new_domain.startswith("https://"):
|
||||
new_domain = new_domain[len("https://") :]
|
||||
|
||||
return new_domain
|
||||
|
||||
|
||||
def can_domain_be_used(user: User, domain: str) -> Optional[CannotUseDomainReason]:
|
||||
if not is_valid_domain(domain):
|
||||
return CannotUseDomainReason.InvalidDomain
|
||||
elif SLDomain.get_by(domain=domain):
|
||||
return CannotUseDomainReason.BuiltinDomain
|
||||
elif CustomDomain.get_by(domain=domain):
|
||||
return CannotUseDomainReason.DomainAlreadyUsed
|
||||
elif get_email_domain_part(user.email) == domain:
|
||||
return CannotUseDomainReason.DomainPartOfUserEmail
|
||||
elif Mailbox.filter(
|
||||
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{domain}")
|
||||
).first():
|
||||
return CannotUseDomainReason.DomainUserInMailbox
|
||||
else:
|
||||
return None
|
||||
|
||||
|
||||
def create_custom_domain(
|
||||
user: User, domain: str, partner_id: Optional[int] = None
|
||||
) -> CreateCustomDomainResult:
|
||||
if not user.is_premium():
|
||||
return CreateCustomDomainResult(
|
||||
message="Only premium plan can add custom domain",
|
||||
message_category="warning",
|
||||
)
|
||||
|
||||
new_domain = sanitize_domain(domain)
|
||||
domain_forbidden_cause = can_domain_be_used(user, new_domain)
|
||||
if domain_forbidden_cause:
|
||||
return CreateCustomDomainResult(
|
||||
message=domain_forbidden_cause.message(new_domain), message_category="error"
|
||||
)
|
||||
|
||||
new_custom_domain = CustomDomain.create(domain=new_domain, user_id=user.id)
|
||||
|
||||
# new domain has ownership verified if its parent has the ownership verified
|
||||
for root_cd in user.custom_domains:
|
||||
if new_domain.endswith("." + root_cd.domain) and root_cd.ownership_verified:
|
||||
LOG.i(
|
||||
"%s ownership verified thanks to %s",
|
||||
new_custom_domain,
|
||||
root_cd,
|
||||
)
|
||||
new_custom_domain.ownership_verified = True
|
||||
|
||||
# Add the partner_id in case it's passed
|
||||
if partner_id is not None:
|
||||
new_custom_domain.partner_id = partner_id
|
||||
|
||||
Session.commit()
|
||||
|
||||
return CreateCustomDomainResult(
|
||||
success=True,
|
||||
instance=new_custom_domain,
|
||||
)
|
@ -5,11 +5,9 @@ from wtforms import StringField, validators
|
||||
|
||||
from app import parallel_limiter
|
||||
from app.config import EMAIL_SERVERS_WITH_PRIORITY
|
||||
from app.custom_domain_utils import create_custom_domain
|
||||
from app.dashboard.base import dashboard_bp
|
||||
from app.db import Session
|
||||
from app.email_utils import get_email_domain_part
|
||||
from app.log import LOG
|
||||
from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain
|
||||
from app.models import CustomDomain
|
||||
|
||||
|
||||
class NewCustomDomainForm(FlaskForm):
|
||||
@ -25,11 +23,8 @@ def custom_domain():
|
||||
custom_domains = CustomDomain.filter_by(
|
||||
user_id=current_user.id, is_sl_subdomain=False
|
||||
).all()
|
||||
mailboxes = current_user.mailboxes()
|
||||
new_custom_domain_form = NewCustomDomainForm()
|
||||
|
||||
errors = {}
|
||||
|
||||
if request.method == "POST":
|
||||
if request.form.get("form-name") == "create":
|
||||
if not current_user.is_premium():
|
||||
@ -37,87 +32,25 @@ def custom_domain():
|
||||
return redirect(url_for("dashboard.custom_domain"))
|
||||
|
||||
if new_custom_domain_form.validate():
|
||||
new_domain = new_custom_domain_form.domain.data.lower().strip()
|
||||
|
||||
if new_domain.startswith("http://"):
|
||||
new_domain = new_domain[len("http://") :]
|
||||
|
||||
if new_domain.startswith("https://"):
|
||||
new_domain = new_domain[len("https://") :]
|
||||
|
||||
if SLDomain.get_by(domain=new_domain):
|
||||
flash("A custom domain cannot be a built-in domain.", "error")
|
||||
elif CustomDomain.get_by(domain=new_domain):
|
||||
flash(f"{new_domain} already used", "error")
|
||||
elif get_email_domain_part(current_user.email) == new_domain:
|
||||
flash(
|
||||
"You cannot add a domain that you are currently using for your personal email. "
|
||||
"Please change your personal email to your real email",
|
||||
"error",
|
||||
res = create_custom_domain(
|
||||
user=current_user, domain=new_custom_domain_form.domain.data
|
||||
)
|
||||
elif Mailbox.filter(
|
||||
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{new_domain}")
|
||||
).first():
|
||||
flash(
|
||||
f"{new_domain} already used in a SimpleLogin mailbox", "error"
|
||||
)
|
||||
else:
|
||||
new_custom_domain = CustomDomain.create(
|
||||
domain=new_domain, user_id=current_user.id
|
||||
)
|
||||
# new domain has ownership verified if its parent has the ownership verified
|
||||
for root_cd in current_user.custom_domains:
|
||||
if (
|
||||
new_domain.endswith("." + root_cd.domain)
|
||||
and root_cd.ownership_verified
|
||||
):
|
||||
LOG.i(
|
||||
"%s ownership verified thanks to %s",
|
||||
new_custom_domain,
|
||||
root_cd,
|
||||
)
|
||||
new_custom_domain.ownership_verified = True
|
||||
|
||||
Session.commit()
|
||||
|
||||
mailbox_ids = request.form.getlist("mailbox_ids")
|
||||
if mailbox_ids:
|
||||
# check if mailbox is not tempered with
|
||||
mailboxes = []
|
||||
for mailbox_id in mailbox_ids:
|
||||
mailbox = Mailbox.get(mailbox_id)
|
||||
if (
|
||||
not mailbox
|
||||
or mailbox.user_id != current_user.id
|
||||
or not mailbox.verified
|
||||
):
|
||||
flash("Something went wrong, please retry", "warning")
|
||||
return redirect(url_for("dashboard.custom_domain"))
|
||||
mailboxes.append(mailbox)
|
||||
|
||||
for mailbox in mailboxes:
|
||||
DomainMailbox.create(
|
||||
domain_id=new_custom_domain.id, mailbox_id=mailbox.id
|
||||
)
|
||||
|
||||
Session.commit()
|
||||
|
||||
flash(
|
||||
f"New domain {new_custom_domain.domain} is created", "success"
|
||||
)
|
||||
|
||||
if res.success:
|
||||
flash(f"New domain {res.instance.domain} is created", "success")
|
||||
return redirect(
|
||||
url_for(
|
||||
"dashboard.domain_detail_dns",
|
||||
custom_domain_id=new_custom_domain.id,
|
||||
custom_domain_id=res.instance.id,
|
||||
)
|
||||
)
|
||||
else:
|
||||
flash(res.message, res.message_category)
|
||||
if res.redirect:
|
||||
return redirect(url_for(res.redirect))
|
||||
|
||||
return render_template(
|
||||
"dashboard/custom_domain.html",
|
||||
custom_domains=custom_domains,
|
||||
new_custom_domain_form=new_custom_domain_form,
|
||||
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
|
||||
errors=errors,
|
||||
mailboxes=mailboxes,
|
||||
)
|
||||
|
@ -973,7 +973,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
|
||||
def has_custom_domain(self):
|
||||
return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0
|
||||
|
||||
def custom_domains(self):
|
||||
def custom_domains(self) -> List["CustomDomain"]:
|
||||
return CustomDomain.filter_by(user_id=self.id, verified=True).all()
|
||||
|
||||
def available_domains_for_random_alias(
|
||||
@ -2419,6 +2419,14 @@ class CustomDomain(Base, ModelMixin):
|
||||
sa.Boolean, nullable=False, default=False, server_default="0"
|
||||
)
|
||||
|
||||
partner_id = sa.Column(
|
||||
sa.Integer,
|
||||
sa.ForeignKey("partner.id"),
|
||||
nullable=True,
|
||||
default=None,
|
||||
server_default=None,
|
||||
)
|
||||
|
||||
__table_args__ = (
|
||||
Index(
|
||||
"ix_unique_domain", # Index name
|
||||
|
@ -0,0 +1,30 @@
|
||||
"""Custom Domain partner id
|
||||
|
||||
Revision ID: 2441b7ff5da9
|
||||
Revises: 1c14339aae90
|
||||
Create Date: 2024-09-13 15:43:02.425964
|
||||
|
||||
"""
|
||||
from alembic import op
|
||||
import sqlalchemy as sa
|
||||
|
||||
|
||||
# revision identifiers, used by Alembic.
|
||||
revision = '2441b7ff5da9'
|
||||
down_revision = '1c14339aae90'
|
||||
branch_labels = None
|
||||
depends_on = None
|
||||
|
||||
|
||||
def upgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.add_column('custom_domain', sa.Column('partner_id', sa.Integer(), nullable=True, default=None, server_default=None))
|
||||
op.create_foreign_key(None, 'custom_domain', 'partner', ['partner_id'], ['id'])
|
||||
# ### end Alembic commands ###
|
||||
|
||||
|
||||
def downgrade():
|
||||
# ### commands auto generated by Alembic - please adjust! ###
|
||||
op.drop_constraint(None, 'custom_domain', type_='foreignkey')
|
||||
op.drop_column('custom_domain', 'partner_id')
|
||||
# ### end Alembic commands ###
|
@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres
|
||||
sleep 3
|
||||
|
||||
# upgrade the DB to the latest stage and
|
||||
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head
|
||||
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head
|
||||
|
||||
# generate the migration script.
|
||||
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@
|
||||
env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@
|
||||
|
||||
# remove the db
|
||||
docker rm -f ${container_name}
|
||||
|
@ -94,4 +94,4 @@
|
||||
</div>
|
||||
</div>
|
||||
{% endblock %}
|
||||
{% block script %}<script>$('.mailbox-select').multipleSelect();</script>{% endblock %}
|
||||
|
||||
|
149
tests/test_custom_domain_utils.py
Normal file
149
tests/test_custom_domain_utils.py
Normal file
@ -0,0 +1,149 @@
|
||||
from typing import Optional
|
||||
|
||||
from app import config
|
||||
from app.config import ALIAS_DOMAINS
|
||||
from app.custom_domain_utils import (
|
||||
can_domain_be_used,
|
||||
create_custom_domain,
|
||||
is_valid_domain,
|
||||
sanitize_domain,
|
||||
CannotUseDomainReason,
|
||||
)
|
||||
from app.db import Session
|
||||
from app.models import User, CustomDomain, Mailbox
|
||||
from tests.utils import get_proton_partner
|
||||
from tests.utils import create_new_user, random_string, random_domain
|
||||
|
||||
user: Optional[User] = None
|
||||
|
||||
|
||||
def setup_module():
|
||||
global user
|
||||
config.SKIP_MX_LOOKUP_ON_CHECK = True
|
||||
user = create_new_user()
|
||||
user.trial_end = None
|
||||
user.lifetime = True
|
||||
Session.commit()
|
||||
|
||||
|
||||
# is_valid_domain
|
||||
def test_is_valid_domain():
|
||||
assert is_valid_domain("example.com") is True
|
||||
assert is_valid_domain("sub.example.com") is True
|
||||
assert is_valid_domain("ex-ample.com") is True
|
||||
|
||||
assert is_valid_domain("-example.com") is False
|
||||
assert is_valid_domain("example-.com") is False
|
||||
assert is_valid_domain("exa_mple.com") is False
|
||||
assert is_valid_domain("example..com") is False
|
||||
assert is_valid_domain("") is False
|
||||
assert is_valid_domain("a" * 64 + ".com") is False
|
||||
assert is_valid_domain("a" * 63 + ".com") is True
|
||||
assert is_valid_domain("example.com.") is True
|
||||
assert is_valid_domain(".example.com") is False
|
||||
assert is_valid_domain("example..com") is False
|
||||
assert is_valid_domain("example.com-") is False
|
||||
|
||||
|
||||
# can_domain_be_used
|
||||
def test_can_domain_be_used():
|
||||
domain = f"{random_string(10)}.com"
|
||||
res = can_domain_be_used(user, domain)
|
||||
assert res is None
|
||||
|
||||
|
||||
def test_can_domain_be_used_existing_domain():
|
||||
domain = random_domain()
|
||||
CustomDomain.create(user_id=user.id, domain=domain, commit=True)
|
||||
res = can_domain_be_used(user, domain)
|
||||
assert res is CannotUseDomainReason.DomainAlreadyUsed
|
||||
|
||||
|
||||
def test_can_domain_be_used_sl_domain():
|
||||
domain = ALIAS_DOMAINS[0]
|
||||
res = can_domain_be_used(user, domain)
|
||||
assert res is CannotUseDomainReason.BuiltinDomain
|
||||
|
||||
|
||||
def test_can_domain_be_used_domain_of_user_email():
|
||||
domain = user.email.split("@")[1]
|
||||
res = can_domain_be_used(user, domain)
|
||||
assert res is CannotUseDomainReason.DomainPartOfUserEmail
|
||||
|
||||
|
||||
def test_can_domain_be_used_domain_of_existing_mailbox():
|
||||
domain = random_domain()
|
||||
Mailbox.create(user_id=user.id, email=f"email@{domain}", verified=True, commit=True)
|
||||
res = can_domain_be_used(user, domain)
|
||||
assert res is CannotUseDomainReason.DomainUserInMailbox
|
||||
|
||||
|
||||
def test_can_domain_be_used_invalid_domain():
|
||||
domain = f"{random_string(10)}@lol.com"
|
||||
res = can_domain_be_used(user, domain)
|
||||
assert res is CannotUseDomainReason.InvalidDomain
|
||||
|
||||
|
||||
# sanitize_domain
|
||||
def test_can_sanitize_domain_empty():
|
||||
assert sanitize_domain("") == ""
|
||||
|
||||
|
||||
def test_can_sanitize_domain_starting_with_http():
|
||||
domain = "test.domain"
|
||||
assert sanitize_domain(f"http://{domain}") == domain
|
||||
|
||||
|
||||
def test_can_sanitize_domain_starting_with_https():
|
||||
domain = "test.domain"
|
||||
assert sanitize_domain(f"https://{domain}") == domain
|
||||
|
||||
|
||||
def test_can_sanitize_domain_correct_domain():
|
||||
domain = "test.domain"
|
||||
assert sanitize_domain(domain) == domain
|
||||
|
||||
|
||||
# create_custom_domain
|
||||
def test_can_create_custom_domain():
|
||||
domain = random_domain()
|
||||
res = create_custom_domain(user=user, domain=domain)
|
||||
assert res.success is True
|
||||
assert res.redirect is None
|
||||
assert res.message == ""
|
||||
assert res.message_category == ""
|
||||
assert res.instance is not None
|
||||
|
||||
assert res.instance.domain == domain
|
||||
assert res.instance.user_id == user.id
|
||||
|
||||
|
||||
def test_can_create_custom_domain_validates_if_parent_is_validated():
|
||||
root_domain = random_domain()
|
||||
subdomain = f"{random_string(10)}.{root_domain}"
|
||||
|
||||
# Create custom domain with the root domain
|
||||
CustomDomain.create(
|
||||
user_id=user.id,
|
||||
domain=root_domain,
|
||||
verified=True,
|
||||
ownership_verified=True,
|
||||
commit=True,
|
||||
)
|
||||
|
||||
# Create custom domain with subdomain. Should automatically be verified
|
||||
res = create_custom_domain(user=user, domain=subdomain)
|
||||
assert res.success is True
|
||||
assert res.instance.domain == subdomain
|
||||
assert res.instance.user_id == user.id
|
||||
assert res.instance.ownership_verified is True
|
||||
|
||||
|
||||
def test_creates_custom_domain_with_partner_id():
|
||||
domain = random_domain()
|
||||
proton_partner = get_proton_partner()
|
||||
res = create_custom_domain(user=user, domain=domain, partner_id=proton_partner.id)
|
||||
assert res.success is True
|
||||
assert res.instance.domain == domain
|
||||
assert res.instance.user_id == user.id
|
||||
assert res.instance.partner_id == proton_partner.id
|
Loading…
Reference in New Issue
Block a user