From b5866fa779733b2fbd85bcec15282560740dce04 Mon Sep 17 00:00:00 2001 From: Carlos Quintana <74399022+cquintana92@users.noreply.github.com> Date: Wed, 18 Sep 2024 12:12:42 +0200 Subject: [PATCH] feat: allow to define partner records (#2225) --- app/config.py | 15 +++--- app/custom_domain_validation.py | 63 +++++++++++++++++----- app/dashboard/views/domain_detail.py | 5 +- app/models.py | 3 -- templates/dashboard/domain_detail/dns.html | 2 +- tests/test_custom_domain_validation.py | 55 ++++++++++++++++++- 6 files changed, 116 insertions(+), 27 deletions(-) diff --git a/app/config.py b/app/config.py index d53c502b..e55d1d8f 100644 --- a/app/config.py +++ b/app/config.py @@ -638,19 +638,22 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_ EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI) -def read_partner_domains() -> dict[int, str]: - partner_domains_dict = get_env_dict("PARTNER_DOMAINS") - if len(partner_domains_dict) == 0: +def read_partner_dict(var: str) -> dict[int, str]: + partner_value = get_env_dict(var) + if len(partner_value) == 0: return {} res: dict[int, str] = {} - for partner_id in partner_domains_dict.keys(): + for partner_id in partner_value.keys(): try: partner_id_int = int(partner_id.strip()) - res[partner_id_int] = partner_domains_dict[partner_id] + res[partner_id_int] = partner_value[partner_id] except ValueError: pass return res -PARTNER_DOMAINS: dict[int, str] = read_partner_domains() +PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS") +PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict( + "PARTNER_DOMAIN_VALIDATION_PREFIXES" +) diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 6887e81a..779f1c59 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -1,4 +1,12 @@ -from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN +from dataclasses import dataclass +from typing import Optional + +from app.config import ( + EMAIL_SERVERS_WITH_PRIORITY, + EMAIL_DOMAIN, + PARTNER_DOMAINS, + PARTNER_DOMAIN_VALIDATION_PREFIXES, +) from app.constants import DMARC_RECORD from app.db import Session from app.dns_utils import ( @@ -7,7 +15,6 @@ from app.dns_utils import ( get_network_dns_client, ) from app.models import CustomDomain -from dataclasses import dataclass @dataclass @@ -18,22 +25,47 @@ class DomainValidationResult: class CustomDomainValidation: def __init__( - self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client() + self, + dkim_domain: str, + dns_client: DNSClient = get_network_dns_client(), + partner_domains: Optional[dict[int, str]] = None, + partner_domains_validation_prefixes: Optional[dict[int, str]] = None, ): self.dkim_domain = dkim_domain self._dns_client = dns_client - self._dkim_records = { - f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}" + self._partner_domains = partner_domains or PARTNER_DOMAINS + self._partner_domain_validation_prefixes = ( + partner_domains_validation_prefixes or PARTNER_DOMAIN_VALIDATION_PREFIXES + ) + + def get_ownership_verification_record(self, domain: CustomDomain) -> str: + prefix = "sl-verification" + if ( + domain.partner_id is not None + and domain.partner_id in self._partner_domain_validation_prefixes + ): + prefix = self._partner_domain_validation_prefixes[domain.partner_id] + return f"{prefix}={domain.ownership_txt_token}" + + def get_dkim_records(self, domain: CustomDomain) -> {str: str}: + """ + Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not, + it will return the default ones or the partner ones. + """ + + # By default use the default domain + dkim_domain = self.dkim_domain + if domain.partner_id is not None: + # Domain is from a partner. Retrieve the partner config and use that domain if exists + partner_domain = self._partner_domains.get(domain.partner_id) + if partner_domain is not None: + dkim_domain = partner_domain + + return { + f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}" for key in ("dkim", "dkim02", "dkim03") } - def get_dkim_records(self) -> {str: str}: - """ - Get a list of dkim records to set up. It will be - - """ - return self._dkim_records - def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]: """ Check if dkim records are properly set for this custom domain. @@ -41,7 +73,7 @@ class CustomDomainValidation: """ correct_records = {} invalid_records = {} - expected_records = self.get_dkim_records() + expected_records = self.get_dkim_records(custom_domain) for prefix, expected_record in expected_records.items(): custom_record = f"{prefix}.{custom_domain.domain}" dkim_record = self._dns_client.get_cname_record(custom_record) @@ -75,8 +107,11 @@ class CustomDomainValidation: Check if the custom_domain has added the ownership verification records """ txt_records = self._dns_client.get_txt_record(custom_domain.domain) + expected_verification_record = self.get_ownership_verification_record( + custom_domain + ) - if custom_domain.get_ownership_dns_txt_value() in txt_records: + if expected_verification_record in txt_records: custom_domain.ownership_verified = True Session.commit() return DomainValidationResult(success=True, errors=[]) diff --git a/app/dashboard/views/domain_detail.py b/app/dashboard/views/domain_detail.py index 1dc7683d..9c714c9f 100644 --- a/app/dashboard/views/domain_detail.py +++ b/app/dashboard/views/domain_detail.py @@ -141,7 +141,10 @@ def domain_detail_dns(custom_domain_id): return render_template( "dashboard/domain_detail/dns.html", EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, - dkim_records=domain_validator.get_dkim_records(), + ownership_record=domain_validator.get_ownership_verification_record( + custom_domain + ), + dkim_records=domain_validator.get_dkim_records(custom_domain), dmarc_record=DMARC_RECORD, **locals(), ) diff --git a/app/models.py b/app/models.py index f26cac29..31ce24d8 100644 --- a/app/models.py +++ b/app/models.py @@ -2451,9 +2451,6 @@ class CustomDomain(Base, ModelMixin): def get_trash_url(self): return config.URL + f"/dashboard/domains/{self.id}/trash" - def get_ownership_dns_txt_value(self): - return f"sl-verification={self.ownership_txt_token}" - @classmethod def create(cls, **kwargs): domain = kwargs.get("domain") diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html index 810aa302..4058f5ea 100644 --- a/templates/dashboard/domain_detail/dns.html +++ b/templates/dashboard/domain_detail/dns.html @@ -38,7 +38,7 @@ Value: {{ custom_domain.get_ownership_dns_txt_value() }} + data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}
{{ csrf_form.csrf_token }} diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py index 7a3882bd..b6e4386d 100644 --- a/tests/test_custom_domain_validation.py +++ b/tests/test_custom_domain_validation.py @@ -6,6 +6,7 @@ from app.custom_domain_validation import CustomDomainValidation from app.db import Session from app.models import CustomDomain, User from app.dns_utils import InMemoryDNSClient +from app.proton.utils import get_proton_partner from app.utils import random_string from tests.utils import create_new_user, random_domain @@ -27,8 +28,9 @@ def create_custom_domain(domain: str) -> CustomDomain: def test_custom_domain_validation_get_dkim_records(): domain = random_domain() + custom_domain = create_custom_domain(domain) validator = CustomDomainValidation(domain) - records = validator.get_dkim_records() + records = validator.get_dkim_records(custom_domain) assert len(records) == 3 assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}" @@ -36,6 +38,26 @@ def test_custom_domain_validation_get_dkim_records(): assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}" +def test_custom_domain_validation_get_dkim_records_for_partner(): + domain = random_domain() + custom_domain = create_custom_domain(domain) + + partner_id = get_proton_partner().id + custom_domain.partner_id = partner_id + Session.commit() + + dkim_domain = random_domain() + validator = CustomDomainValidation( + domain, partner_domains={partner_id: dkim_domain} + ) + records = validator.get_dkim_records(custom_domain) + + assert len(records) == 3 + assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}" + assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}" + assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}" + + # validate_dkim_records def test_custom_domain_validation_validate_dkim_records_empty_records_failure(): dns_client = InMemoryDNSClient() @@ -169,7 +191,36 @@ def test_custom_domain_validation_validate_ownership_success(): domain = create_custom_domain(random_domain()) - dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()]) + dns_client.set_txt_record( + domain.domain, [validator.get_ownership_verification_record(domain)] + ) + res = validator.validate_domain_ownership(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.ownership_verified is True + + +def test_custom_domain_validation_validate_ownership_from_partner_success(): + dns_client = InMemoryDNSClient() + partner_id = get_proton_partner().id + + prefix = random_string() + validator = CustomDomainValidation( + random_domain(), + dns_client, + partner_domains_validation_prefixes={partner_id: prefix}, + ) + + domain = create_custom_domain(random_domain()) + domain.partner_id = partner_id + Session.commit() + + dns_client.set_txt_record( + domain.domain, [validator.get_ownership_verification_record(domain)] + ) res = validator.validate_domain_ownership(domain) assert res.success is True