diff --git a/app/custom_domain_utils.py b/app/custom_domain_utils.py index 66a7ded8..d54760b5 100644 --- a/app/custom_domain_utils.py +++ b/app/custom_domain_utils.py @@ -3,15 +3,16 @@ import re from dataclasses import dataclass from enum import Enum -from typing import Optional +from typing import List, Optional from app.config import JOB_DELETE_DOMAIN from app.db import Session from app.email_utils import get_email_domain_part from app.log import LOG -from app.models import User, CustomDomain, SLDomain, Mailbox, Job +from app.models import User, CustomDomain, SLDomain, Mailbox, Job, DomainMailbox _ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(? bool: """ Checks that a domain is valid according to RFC 1035 @@ -140,3 +155,40 @@ def delete_custom_domain(domain: CustomDomain): run_at=arrow.now(), commit=True, ) + + +def set_custom_domain_mailboxes( + user_id: int, custom_domain: CustomDomain, mailbox_ids: List[int] +) -> SetCustomDomainMailboxesResult: + if len(mailbox_ids) == 0: + return SetCustomDomainMailboxesResult( + success=False, reason=CannotSetCustomDomainMailboxesCause.NoMailboxes + ) + elif len(mailbox_ids) > _MAX_MAILBOXES_PER_DOMAIN: + return SetCustomDomainMailboxesResult( + success=False, reason=CannotSetCustomDomainMailboxesCause.TooManyMailboxes + ) + + mailboxes = ( + Session.query(Mailbox) + .filter( + Mailbox.id.in_(mailbox_ids), + Mailbox.user_id == user_id, + Mailbox.verified == True, # noqa: E712 + ) + .all() + ) + if len(mailboxes) != len(mailbox_ids): + return SetCustomDomainMailboxesResult( + success=False, reason=CannotSetCustomDomainMailboxesCause.InvalidMailbox + ) + + # first remove all existing domain-mailboxes links + DomainMailbox.filter_by(domain_id=custom_domain.id).delete() + Session.flush() + + for mailbox in mailboxes: + DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id) + + Session.commit() + return SetCustomDomainMailboxesResult(success=True) diff --git a/app/dashboard/views/domain_detail.py b/app/dashboard/views/domain_detail.py index 73909812..0911a748 100644 --- a/app/dashboard/views/domain_detail.py +++ b/app/dashboard/views/domain_detail.py @@ -7,7 +7,7 @@ from wtforms import StringField, validators, IntegerField from app.constants import DMARC_RECORD from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN -from app.custom_domain_utils import delete_custom_domain +from app.custom_domain_utils import delete_custom_domain, set_custom_domain_mailboxes from app.custom_domain_validation import CustomDomainValidation from app.dashboard.base import dashboard_bp from app.db import Session @@ -16,7 +16,6 @@ from app.models import ( Alias, DomainDeletedAlias, Mailbox, - DomainMailbox, AutoCreateRule, AutoCreateRuleMailbox, ) @@ -220,40 +219,16 @@ def domain_detail(custom_domain_id): ) elif request.form.get("form-name") == "update": mailbox_ids = request.form.getlist("mailbox_ids") - # check if mailbox is not tempered with - mailboxes = [] - for mailbox_id in mailbox_ids: - mailbox = Mailbox.get(mailbox_id) - if ( - not mailbox - or mailbox.user_id != current_user.id - or not mailbox.verified - ): - flash("Something went wrong, please retry", "warning") - return redirect( - url_for( - "dashboard.domain_detail", custom_domain_id=custom_domain.id - ) - ) - mailboxes.append(mailbox) + result = set_custom_domain_mailboxes( + user_id=current_user.id, + custom_domain=custom_domain, + mailbox_ids=mailbox_ids, + ) - if not mailboxes: - flash("You must select at least 1 mailbox", "warning") - return redirect( - url_for( - "dashboard.domain_detail", custom_domain_id=custom_domain.id - ) - ) - - # first remove all existing domain-mailboxes links - DomainMailbox.filter_by(domain_id=custom_domain.id).delete() - Session.flush() - - for mailbox in mailboxes: - DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id) - - Session.commit() - flash(f"{custom_domain.domain} mailboxes has been updated", "success") + if result.success: + flash(f"{custom_domain.domain} mailboxes has been updated", "success") + else: + flash(result.reason.value, "warning") return redirect( url_for("dashboard.domain_detail", custom_domain_id=custom_domain.id) diff --git a/tests/test_custom_domain_utils.py b/tests/test_custom_domain_utils.py index af6ab70e..ec60f896 100644 --- a/tests/test_custom_domain_utils.py +++ b/tests/test_custom_domain_utils.py @@ -7,11 +7,13 @@ from app.custom_domain_utils import ( create_custom_domain, is_valid_domain, sanitize_domain, + set_custom_domain_mailboxes, CannotUseDomainReason, + CannotSetCustomDomainMailboxesCause, ) from app.db import Session -from app.models import User, CustomDomain, Mailbox -from tests.utils import get_proton_partner +from app.models import User, CustomDomain, Mailbox, DomainMailbox +from tests.utils import get_proton_partner, random_email from tests.utils import create_new_user, random_string, random_domain user: Optional[User] = None @@ -147,3 +149,119 @@ def test_creates_custom_domain_with_partner_id(): assert res.instance.domain == domain assert res.instance.user_id == user.id assert res.instance.partner_id == proton_partner.id + + +# set_custom_domain_mailboxes +def test_set_custom_domain_mailboxes_empty_list(): + domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True) + res = set_custom_domain_mailboxes(user.id, domain, []) + assert res.success is False + assert res.reason == CannotSetCustomDomainMailboxesCause.NoMailboxes + + +def test_set_custom_domain_mailboxes_mailbox_from_another_user(): + other_user = create_new_user() + other_mailbox = Mailbox.create( + user_id=other_user.id, email=random_email(), verified=True + ) + domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True) + + res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id]) + assert res.success is False + assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox + + +def test_set_custom_domain_mailboxes_mailbox_from_current_user_and_another_user(): + other_user = create_new_user() + other_mailbox = Mailbox.create( + user_id=other_user.id, email=random_email(), verified=True + ) + domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True) + + res = set_custom_domain_mailboxes( + user.id, domain, [user.default_mailbox_id, other_mailbox.id] + ) + assert res.success is False + assert res.reason == CannotSetCustomDomainMailboxesCause.InvalidMailbox + + +def test_set_custom_domain_mailboxes_success(): + other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True) + domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True) + + res = set_custom_domain_mailboxes( + user.id, domain, [user.default_mailbox_id, other_mailbox.id] + ) + assert res.success is True + assert res.reason is None + + domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all() + assert len(domain_mailboxes) == 2 + assert domain_mailboxes[0].domain_id == domain.id + assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id + assert domain_mailboxes[1].domain_id == domain.id + assert domain_mailboxes[1].mailbox_id == other_mailbox.id + + +def test_set_custom_domain_mailboxes_set_twice(): + other_mailbox = Mailbox.create(user_id=user.id, email=random_email(), verified=True) + domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True) + + res = set_custom_domain_mailboxes( + user.id, domain, [user.default_mailbox_id, other_mailbox.id] + ) + assert res.success is True + assert res.reason is None + + res = set_custom_domain_mailboxes( + user.id, domain, [user.default_mailbox_id, other_mailbox.id] + ) + assert res.success is True + assert res.reason is None + + domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all() + assert len(domain_mailboxes) == 2 + assert domain_mailboxes[0].domain_id == domain.id + assert domain_mailboxes[0].mailbox_id == user.default_mailbox_id + assert domain_mailboxes[1].domain_id == domain.id + assert domain_mailboxes[1].mailbox_id == other_mailbox.id + + +def test_set_custom_domain_mailboxes_removes_old_association(): + domain = CustomDomain.create(user_id=user.id, domain=random_domain(), commit=True) + + res = set_custom_domain_mailboxes(user.id, domain, [user.default_mailbox_id]) + assert res.success is True + assert res.reason is None + + other_mailbox = Mailbox.create( + user_id=user.id, email=random_email(), verified=True, commit=True + ) + res = set_custom_domain_mailboxes(user.id, domain, [other_mailbox.id]) + assert res.success is True + assert res.reason is None + + domain_mailboxes = DomainMailbox.filter_by(domain_id=domain.id).all() + assert len(domain_mailboxes) == 1 + assert domain_mailboxes[0].domain_id == domain.id + assert domain_mailboxes[0].mailbox_id == other_mailbox.id + + +def test_set_custom_domain_mailboxes_with_unverified_mailbox(): + domain = CustomDomain.create(user_id=user.id, domain=random_domain()) + verified_mailbox = Mailbox.create( + user_id=user.id, + email=random_email(), + verified=True, + ) + unverified_mailbox = Mailbox.create( + user_id=user.id, + email=random_email(), + verified=False, + ) + + res = set_custom_domain_mailboxes( + user.id, domain, [verified_mailbox.id, unverified_mailbox.id] + ) + assert res.success is False + assert res.reason is CannotSetCustomDomainMailboxesCause.InvalidMailbox