From 065cc3db9200a74f4896c95c7b02cc65c0a29187 Mon Sep 17 00:00:00 2001 From: Carlos Quintana <74399022+cquintana92@users.noreply.github.com> Date: Tue, 17 Sep 2024 10:30:55 +0200 Subject: [PATCH] chore: refactor create custom domain (#2221) * fix: scripts/new-migration to use poetry again * chore: add migration to add custom_domain.partner_id * chore: refactor create_custom_domain * chore: allow to specify partner_id to custom_domain * refactor: can_use_domain return cause * refactor: remove intermediate result class --- app/custom_domain_utils.py | 128 +++++++++++++++ app/dashboard/views/custom_domain.py | 91 ++--------- app/models.py | 10 +- ...5_2441b7ff5da9_custom_domain_partner_id.py | 30 ++++ scripts/new-migration.sh | 4 +- templates/dashboard/custom_domain.html | 2 +- tests/test_custom_domain_utils.py | 149 ++++++++++++++++++ 7 files changed, 331 insertions(+), 83 deletions(-) create mode 100644 migrations/versions/2024_091315_2441b7ff5da9_custom_domain_partner_id.py create mode 100644 tests/test_custom_domain_utils.py diff --git a/app/custom_domain_utils.py b/app/custom_domain_utils.py index e69de29b..9275622d 100644 --- a/app/custom_domain_utils.py +++ b/app/custom_domain_utils.py @@ -0,0 +1,128 @@ +import re + +from dataclasses import dataclass +from enum import Enum +from typing import Optional + +from app.db import Session +from app.email_utils import get_email_domain_part +from app.log import LOG +from app.models import User, CustomDomain, SLDomain, Mailbox + +_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(? str: + if self == CannotUseDomainReason.InvalidDomain: + return "This is not a valid domain" + elif self == CannotUseDomainReason.BuiltinDomain: + return "A custom domain cannot be a built-in domain." + elif self == CannotUseDomainReason.DomainAlreadyUsed: + return f"{domain} already used" + elif self == CannotUseDomainReason.DomainPartOfUserEmail: + return "You cannot add a domain that you are currently using for your personal email. Please change your personal email to your real email" + elif self == CannotUseDomainReason.DomainUserInMailbox: + return f"{domain} already used in a SimpleLogin mailbox" + else: + raise Exception("Invalid CannotUseDomainReason") + + +def is_valid_domain(domain: str) -> bool: + """ + Checks that a domain is valid according to RFC 1035 + """ + if len(domain) > 255: + return False + if domain.endswith("."): + domain = domain[:-1] # Strip the trailing dot + labels = domain.split(".") + if not labels: + return False + for label in labels: + if not _ALLOWED_DOMAIN_REGEX.match(label): + return False + return True + + +def sanitize_domain(domain: str) -> str: + new_domain = domain.lower().strip() + if new_domain.startswith("http://"): + new_domain = new_domain[len("http://") :] + + if new_domain.startswith("https://"): + new_domain = new_domain[len("https://") :] + + return new_domain + + +def can_domain_be_used(user: User, domain: str) -> Optional[CannotUseDomainReason]: + if not is_valid_domain(domain): + return CannotUseDomainReason.InvalidDomain + elif SLDomain.get_by(domain=domain): + return CannotUseDomainReason.BuiltinDomain + elif CustomDomain.get_by(domain=domain): + return CannotUseDomainReason.DomainAlreadyUsed + elif get_email_domain_part(user.email) == domain: + return CannotUseDomainReason.DomainPartOfUserEmail + elif Mailbox.filter( + Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{domain}") + ).first(): + return CannotUseDomainReason.DomainUserInMailbox + else: + return None + + +def create_custom_domain( + user: User, domain: str, partner_id: Optional[int] = None +) -> CreateCustomDomainResult: + if not user.is_premium(): + return CreateCustomDomainResult( + message="Only premium plan can add custom domain", + message_category="warning", + ) + + new_domain = sanitize_domain(domain) + domain_forbidden_cause = can_domain_be_used(user, new_domain) + if domain_forbidden_cause: + return CreateCustomDomainResult( + message=domain_forbidden_cause.message(new_domain), message_category="error" + ) + + new_custom_domain = CustomDomain.create(domain=new_domain, user_id=user.id) + + # new domain has ownership verified if its parent has the ownership verified + for root_cd in user.custom_domains: + if new_domain.endswith("." + root_cd.domain) and root_cd.ownership_verified: + LOG.i( + "%s ownership verified thanks to %s", + new_custom_domain, + root_cd, + ) + new_custom_domain.ownership_verified = True + + # Add the partner_id in case it's passed + if partner_id is not None: + new_custom_domain.partner_id = partner_id + + Session.commit() + + return CreateCustomDomainResult( + success=True, + instance=new_custom_domain, + ) diff --git a/app/dashboard/views/custom_domain.py b/app/dashboard/views/custom_domain.py index 875b5dbf..b410b306 100644 --- a/app/dashboard/views/custom_domain.py +++ b/app/dashboard/views/custom_domain.py @@ -5,11 +5,9 @@ from wtforms import StringField, validators from app import parallel_limiter from app.config import EMAIL_SERVERS_WITH_PRIORITY +from app.custom_domain_utils import create_custom_domain from app.dashboard.base import dashboard_bp -from app.db import Session -from app.email_utils import get_email_domain_part -from app.log import LOG -from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain +from app.models import CustomDomain class NewCustomDomainForm(FlaskForm): @@ -25,11 +23,8 @@ def custom_domain(): custom_domains = CustomDomain.filter_by( user_id=current_user.id, is_sl_subdomain=False ).all() - mailboxes = current_user.mailboxes() new_custom_domain_form = NewCustomDomainForm() - errors = {} - if request.method == "POST": if request.form.get("form-name") == "create": if not current_user.is_premium(): @@ -37,87 +32,25 @@ def custom_domain(): return redirect(url_for("dashboard.custom_domain")) if new_custom_domain_form.validate(): - new_domain = new_custom_domain_form.domain.data.lower().strip() - - if new_domain.startswith("http://"): - new_domain = new_domain[len("http://") :] - - if new_domain.startswith("https://"): - new_domain = new_domain[len("https://") :] - - if SLDomain.get_by(domain=new_domain): - flash("A custom domain cannot be a built-in domain.", "error") - elif CustomDomain.get_by(domain=new_domain): - flash(f"{new_domain} already used", "error") - elif get_email_domain_part(current_user.email) == new_domain: - flash( - "You cannot add a domain that you are currently using for your personal email. " - "Please change your personal email to your real email", - "error", - ) - elif Mailbox.filter( - Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{new_domain}") - ).first(): - flash( - f"{new_domain} already used in a SimpleLogin mailbox", "error" - ) - else: - new_custom_domain = CustomDomain.create( - domain=new_domain, user_id=current_user.id - ) - # new domain has ownership verified if its parent has the ownership verified - for root_cd in current_user.custom_domains: - if ( - new_domain.endswith("." + root_cd.domain) - and root_cd.ownership_verified - ): - LOG.i( - "%s ownership verified thanks to %s", - new_custom_domain, - root_cd, - ) - new_custom_domain.ownership_verified = True - - Session.commit() - - mailbox_ids = request.form.getlist("mailbox_ids") - if mailbox_ids: - # check if mailbox is not tempered with - mailboxes = [] - for mailbox_id in mailbox_ids: - mailbox = Mailbox.get(mailbox_id) - if ( - not mailbox - or mailbox.user_id != current_user.id - or not mailbox.verified - ): - flash("Something went wrong, please retry", "warning") - return redirect(url_for("dashboard.custom_domain")) - mailboxes.append(mailbox) - - for mailbox in mailboxes: - DomainMailbox.create( - domain_id=new_custom_domain.id, mailbox_id=mailbox.id - ) - - Session.commit() - - flash( - f"New domain {new_custom_domain.domain} is created", "success" - ) - + res = create_custom_domain( + user=current_user, domain=new_custom_domain_form.domain.data + ) + if res.success: + flash(f"New domain {res.instance.domain} is created", "success") return redirect( url_for( "dashboard.domain_detail_dns", - custom_domain_id=new_custom_domain.id, + custom_domain_id=res.instance.id, ) ) + else: + flash(res.message, res.message_category) + if res.redirect: + return redirect(url_for(res.redirect)) return render_template( "dashboard/custom_domain.html", custom_domains=custom_domains, new_custom_domain_form=new_custom_domain_form, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, - errors=errors, - mailboxes=mailboxes, ) diff --git a/app/models.py b/app/models.py index f2a3f9fc..f26cac29 100644 --- a/app/models.py +++ b/app/models.py @@ -973,7 +973,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): def has_custom_domain(self): return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0 - def custom_domains(self): + def custom_domains(self) -> List["CustomDomain"]: return CustomDomain.filter_by(user_id=self.id, verified=True).all() def available_domains_for_random_alias( @@ -2419,6 +2419,14 @@ class CustomDomain(Base, ModelMixin): sa.Boolean, nullable=False, default=False, server_default="0" ) + partner_id = sa.Column( + sa.Integer, + sa.ForeignKey("partner.id"), + nullable=True, + default=None, + server_default=None, + ) + __table_args__ = ( Index( "ix_unique_domain", # Index name diff --git a/migrations/versions/2024_091315_2441b7ff5da9_custom_domain_partner_id.py b/migrations/versions/2024_091315_2441b7ff5da9_custom_domain_partner_id.py new file mode 100644 index 00000000..ba33bea0 --- /dev/null +++ b/migrations/versions/2024_091315_2441b7ff5da9_custom_domain_partner_id.py @@ -0,0 +1,30 @@ +"""Custom Domain partner id + +Revision ID: 2441b7ff5da9 +Revises: 1c14339aae90 +Create Date: 2024-09-13 15:43:02.425964 + +""" +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '2441b7ff5da9' +down_revision = '1c14339aae90' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('custom_domain', sa.Column('partner_id', sa.Integer(), nullable=True, default=None, server_default=None)) + op.create_foreign_key(None, 'custom_domain', 'partner', ['partner_id'], ['id']) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'custom_domain', type_='foreignkey') + op.drop_column('custom_domain', 'partner_id') + # ### end Alembic commands ### diff --git a/scripts/new-migration.sh b/scripts/new-migration.sh index 54568579..da11a756 100755 --- a/scripts/new-migration.sh +++ b/scripts/new-migration.sh @@ -12,10 +12,10 @@ docker run -p 25432:5432 --name ${container_name} -e POSTGRES_PASSWORD=postgres sleep 3 # upgrade the DB to the latest stage and -env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic upgrade head +env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic upgrade head # generate the migration script. -env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl rye run alembic revision --autogenerate $@ +env DB_URI=postgresql://postgres:postgres@127.0.0.1:25432/sl poetry run alembic revision --autogenerate $@ # remove the db docker rm -f ${container_name} diff --git a/templates/dashboard/custom_domain.html b/templates/dashboard/custom_domain.html index 62a12b0e..2982867d 100644 --- a/templates/dashboard/custom_domain.html +++ b/templates/dashboard/custom_domain.html @@ -94,4 +94,4 @@ {% endblock %} -{% block script %}{% endblock %} + diff --git a/tests/test_custom_domain_utils.py b/tests/test_custom_domain_utils.py new file mode 100644 index 00000000..af6ab70e --- /dev/null +++ b/tests/test_custom_domain_utils.py @@ -0,0 +1,149 @@ +from typing import Optional + +from app import config +from app.config import ALIAS_DOMAINS +from app.custom_domain_utils import ( + can_domain_be_used, + create_custom_domain, + is_valid_domain, + sanitize_domain, + CannotUseDomainReason, +) +from app.db import Session +from app.models import User, CustomDomain, Mailbox +from tests.utils import get_proton_partner +from tests.utils import create_new_user, random_string, random_domain + +user: Optional[User] = None + + +def setup_module(): + global user + config.SKIP_MX_LOOKUP_ON_CHECK = True + user = create_new_user() + user.trial_end = None + user.lifetime = True + Session.commit() + + +# is_valid_domain +def test_is_valid_domain(): + assert is_valid_domain("example.com") is True + assert is_valid_domain("sub.example.com") is True + assert is_valid_domain("ex-ample.com") is True + + assert is_valid_domain("-example.com") is False + assert is_valid_domain("example-.com") is False + assert is_valid_domain("exa_mple.com") is False + assert is_valid_domain("example..com") is False + assert is_valid_domain("") is False + assert is_valid_domain("a" * 64 + ".com") is False + assert is_valid_domain("a" * 63 + ".com") is True + assert is_valid_domain("example.com.") is True + assert is_valid_domain(".example.com") is False + assert is_valid_domain("example..com") is False + assert is_valid_domain("example.com-") is False + + +# can_domain_be_used +def test_can_domain_be_used(): + domain = f"{random_string(10)}.com" + res = can_domain_be_used(user, domain) + assert res is None + + +def test_can_domain_be_used_existing_domain(): + domain = random_domain() + CustomDomain.create(user_id=user.id, domain=domain, commit=True) + res = can_domain_be_used(user, domain) + assert res is CannotUseDomainReason.DomainAlreadyUsed + + +def test_can_domain_be_used_sl_domain(): + domain = ALIAS_DOMAINS[0] + res = can_domain_be_used(user, domain) + assert res is CannotUseDomainReason.BuiltinDomain + + +def test_can_domain_be_used_domain_of_user_email(): + domain = user.email.split("@")[1] + res = can_domain_be_used(user, domain) + assert res is CannotUseDomainReason.DomainPartOfUserEmail + + +def test_can_domain_be_used_domain_of_existing_mailbox(): + domain = random_domain() + Mailbox.create(user_id=user.id, email=f"email@{domain}", verified=True, commit=True) + res = can_domain_be_used(user, domain) + assert res is CannotUseDomainReason.DomainUserInMailbox + + +def test_can_domain_be_used_invalid_domain(): + domain = f"{random_string(10)}@lol.com" + res = can_domain_be_used(user, domain) + assert res is CannotUseDomainReason.InvalidDomain + + +# sanitize_domain +def test_can_sanitize_domain_empty(): + assert sanitize_domain("") == "" + + +def test_can_sanitize_domain_starting_with_http(): + domain = "test.domain" + assert sanitize_domain(f"http://{domain}") == domain + + +def test_can_sanitize_domain_starting_with_https(): + domain = "test.domain" + assert sanitize_domain(f"https://{domain}") == domain + + +def test_can_sanitize_domain_correct_domain(): + domain = "test.domain" + assert sanitize_domain(domain) == domain + + +# create_custom_domain +def test_can_create_custom_domain(): + domain = random_domain() + res = create_custom_domain(user=user, domain=domain) + assert res.success is True + assert res.redirect is None + assert res.message == "" + assert res.message_category == "" + assert res.instance is not None + + assert res.instance.domain == domain + assert res.instance.user_id == user.id + + +def test_can_create_custom_domain_validates_if_parent_is_validated(): + root_domain = random_domain() + subdomain = f"{random_string(10)}.{root_domain}" + + # Create custom domain with the root domain + CustomDomain.create( + user_id=user.id, + domain=root_domain, + verified=True, + ownership_verified=True, + commit=True, + ) + + # Create custom domain with subdomain. Should automatically be verified + res = create_custom_domain(user=user, domain=subdomain) + assert res.success is True + assert res.instance.domain == subdomain + assert res.instance.user_id == user.id + assert res.instance.ownership_verified is True + + +def test_creates_custom_domain_with_partner_id(): + domain = random_domain() + proton_partner = get_proton_partner() + res = create_custom_domain(user=user, domain=domain, partner_id=proton_partner.id) + assert res.success is True + assert res.instance.domain == domain + assert res.instance.user_id == user.id + assert res.instance.partner_id == proton_partner.id