From f6708dd0b631da32f19ffcde6f91f5ed9b867c3d Mon Sep 17 00:00:00 2001 From: Carlos Quintana <74399022+cquintana92@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:15:10 +0200 Subject: [PATCH] chore: refactor dns to improve testability (#2224) * chore: refactor DNS client to its own class * chore: adapt code calling DNS and add tests to it * chore: refactor old dkim check not to clear flag --- app/custom_domain_validation.py | 50 +++- app/dns_utils.py | 214 ++++++++------ templates/dashboard/domain_detail/dns.html | 2 +- tests/test_custom_domain_validation.py | 325 +++++++++++++++++++++ tests/test_dns_utils.py | 22 +- 5 files changed, 497 insertions(+), 116 deletions(-) create mode 100644 tests/test_custom_domain_validation.py diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 05b6af37..6887e81a 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -2,11 +2,9 @@ from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN from app.constants import DMARC_RECORD from app.db import Session from app.dns_utils import ( - get_cname_record, - get_mx_domains, - get_txt_record, + DNSClient, is_mx_equivalent, - get_spf_domain, + get_network_dns_client, ) from app.models import CustomDomain from dataclasses import dataclass @@ -19,10 +17,13 @@ class DomainValidationResult: class CustomDomainValidation: - def __init__(self, dkim_domain: str): + def __init__( + self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client() + ): self.dkim_domain = dkim_domain + self._dns_client = dns_client self._dkim_records = { - (f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}") + f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}" for key in ("dkim", "dkim02", "dkim03") } @@ -38,15 +39,31 @@ class CustomDomainValidation: Check if dkim records are properly set for this custom domain. Returns empty list if all records are ok. Other-wise return the records that aren't properly configured """ + correct_records = {} invalid_records = {} - for prefix, expected_record in self.get_dkim_records(): + expected_records = self.get_dkim_records() + for prefix, expected_record in expected_records.items(): custom_record = f"{prefix}.{custom_domain.domain}" - dkim_record = get_cname_record(custom_record) - if dkim_record != expected_record: + dkim_record = self._dns_client.get_cname_record(custom_record) + if dkim_record == expected_record: + correct_records[prefix] = custom_record + else: invalid_records[custom_record] = dkim_record or "empty" - # HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES + + # HACK + # As initially we only had one dkim record, we want to allow users that had only the original dkim record and + # the domain validated to continue seeing it as validated (although showing them the missing records). + # However, if not even the original dkim record is right, even if the domain was dkim_verified in the past, + # we will remove the dkim_verified flag. + # This is done in order to give users with the old dkim config (only one) to update their CNAMEs if custom_domain.dkim_verified: - return invalid_records + # Check if at least the original dkim is there + if correct_records.get("dkim._domainkey") is not None: + # Original dkim record is there. Return the missing records (if any) and don't clear the flag + return invalid_records + + # Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the + # rest of the code path, returning the invalid records and clearing the flag custom_domain.dkim_verified = len(invalid_records) == 0 Session.commit() return invalid_records @@ -57,7 +74,7 @@ class CustomDomainValidation: """ Check if the custom_domain has added the ownership verification records """ - txt_records = get_txt_record(custom_domain.domain) + txt_records = self._dns_client.get_txt_record(custom_domain.domain) if custom_domain.get_ownership_dns_txt_value() in txt_records: custom_domain.ownership_verified = True @@ -69,7 +86,7 @@ class CustomDomainValidation: def validate_mx_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: - mx_domains = get_mx_domains(custom_domain.domain) + mx_domains = self._dns_client.get_mx_domains(custom_domain.domain) if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY): return DomainValidationResult( @@ -84,7 +101,7 @@ class CustomDomainValidation: def validate_spf_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: - spf_domains = get_spf_domain(custom_domain.domain) + spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) if EMAIL_DOMAIN in spf_domains: custom_domain.spf_verified = True Session.commit() @@ -93,13 +110,14 @@ class CustomDomainValidation: custom_domain.spf_verified = False Session.commit() return DomainValidationResult( - success=False, errors=get_txt_record(custom_domain.domain) + success=False, + errors=self._dns_client.get_txt_record(custom_domain.domain), ) def validate_dmarc_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: - txt_records = get_txt_record("_dmarc." + custom_domain.domain) + txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain) if DMARC_RECORD in txt_records: custom_domain.dmarc_verified = True Session.commit() diff --git a/app/dns_utils.py b/app/dns_utils.py index 429d0aa2..2ce69934 100644 --- a/app/dns_utils.py +++ b/app/dns_utils.py @@ -1,100 +1,13 @@ -from app import config -from typing import Optional, List, Tuple +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional import dns.resolver - -def _get_dns_resolver(): - my_resolver = dns.resolver.Resolver() - my_resolver.nameservers = config.NAMESERVERS - - return my_resolver - - -def get_ns(hostname) -> [str]: - try: - answers = _get_dns_resolver().resolve(hostname, "NS", search=True) - except Exception: - return [] - return [a.to_text() for a in answers] - - -def get_cname_record(hostname) -> Optional[str]: - """Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end""" - try: - answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True) - except Exception: - return None - - for a in answers: - ret = a.to_text() - return ret[:-1] - - return None - - -def get_mx_domains(hostname) -> [(int, str)]: - """return list of (priority, domain name) sorted by priority (lowest priority first) - domain name ends with a "." at the end. - """ - try: - answers = _get_dns_resolver().resolve(hostname, "MX", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: - record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' - parts = record.split(" ") - - ret.append((int(parts[0]), parts[1])) - - return sorted(ret, key=lambda prio_domain: prio_domain[0]) - +from app.config import NAMESERVERS _include_spf = "include:" -def get_spf_domain(hostname) -> [str]: - """return all domains listed in *include:*""" - try: - answers = _get_dns_resolver().resolve(hostname, "TXT", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: # type: dns.rdtypes.ANY.TXT.TXT - for record in a.strings: - record = record.decode() # record is bytes - - if record.startswith("v=spf1"): - parts = record.split(" ") - for part in parts: - if part.startswith(_include_spf): - ret.append(part[part.find(_include_spf) + len(_include_spf) :]) - - return ret - - -def get_txt_record(hostname) -> [str]: - try: - answers = _get_dns_resolver().resolve(hostname, "TXT", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: # type: dns.rdtypes.ANY.TXT.TXT - for record in a.strings: - record = record.decode() # record is bytes - - ret.append(record) - - return ret - - def is_mx_equivalent( mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] ) -> bool: @@ -105,16 +18,127 @@ def is_mx_equivalent( The priority order is taken into account but not the priority number. For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] """ - mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0]) - ref_mx_domains = sorted( - ref_mx_domains, key=lambda priority_domain: priority_domain[0] - ) + mx_domains = sorted(mx_domains, key=lambda x: x[0]) + ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) if len(mx_domains) < len(ref_mx_domains): return False - for i in range(0, len(ref_mx_domains)): + for i in range(len(ref_mx_domains)): if mx_domains[i][1] != ref_mx_domains[i][1]: return False return True + + +class DNSClient(ABC): + @abstractmethod + def get_cname_record(self, hostname: str) -> Optional[str]: + pass + + @abstractmethod + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + pass + + def get_spf_domain(self, hostname: str) -> List[str]: + """ + return all domains listed in *include:* + """ + try: + records = self.get_txt_record(hostname) + ret = [] + for record in records: + if record.startswith("v=spf1"): + parts = record.split(" ") + for part in parts: + if part.startswith(_include_spf): + ret.append( + part[part.find(_include_spf) + len(_include_spf) :] + ) + return ret + except Exception: + return [] + + @abstractmethod + def get_txt_record(self, hostname: str) -> List[str]: + pass + + +class NetworkDNSClient(DNSClient): + def __init__(self, nameservers: List[str]): + self._resolver = dns.resolver.Resolver() + self._resolver.nameservers = nameservers + + def get_cname_record(self, hostname: str) -> Optional[str]: + """ + Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end + """ + try: + answers = self._resolver.resolve(hostname, "CNAME", search=True) + for a in answers: + ret = a.to_text() + return ret[:-1] + except Exception: + return None + + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + """ + return list of (priority, domain name) sorted by priority (lowest priority first) + domain name ends with a "." at the end. + """ + try: + answers = self._resolver.resolve(hostname, "MX", search=True) + ret = [] + for a in answers: + record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' + parts = record.split(" ") + ret.append((int(parts[0]), parts[1])) + return sorted(ret, key=lambda x: x[0]) + except Exception: + return [] + + def get_txt_record(self, hostname: str) -> List[str]: + try: + answers = self._resolver.resolve(hostname, "TXT", search=True) + ret = [] + for a in answers: # type: dns.rdtypes.ANY.TXT.TXT + for record in a.strings: + ret.append(record.decode()) + return ret + except Exception: + return [] + + +class InMemoryDNSClient(DNSClient): + def __init__(self): + self.cname_records: dict[str, Optional[str]] = {} + self.mx_records: dict[str, List[Tuple[int, str]]] = {} + self.spf_records: dict[str, List[str]] = {} + self.txt_records: dict[str, List[str]] = {} + + def set_cname_record(self, hostname: str, cname: str): + self.cname_records[hostname] = cname + + def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): + self.mx_records[hostname] = mx_list + + def set_txt_record(self, hostname: str, txt_list: List[str]): + self.txt_records[hostname] = txt_list + + def get_cname_record(self, hostname: str) -> Optional[str]: + return self.cname_records.get(hostname) + + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + mx_list = self.mx_records.get(hostname, []) + return sorted(mx_list, key=lambda x: x[0]) + + def get_txt_record(self, hostname: str) -> List[str]: + return self.txt_records.get(hostname, []) + + +def get_network_dns_client() -> NetworkDNSClient: + return NetworkDNSClient(NAMESERVERS) + + +def get_mx_domains(hostname: str) -> [(int, str)]: + return get_network_dns_client().get_mx_domains(hostname) diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html index 15ef346f..810aa302 100644 --- a/templates/dashboard/domain_detail/dns.html +++ b/templates/dashboard/domain_detail/dns.html @@ -237,7 +237,7 @@ folder.
Add the following CNAME DNS records to your domain.
- {% for dkim_prefix, dkim_cname_value in dkim_records %} + {% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
Record: CNAME diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py new file mode 100644 index 00000000..7a3882bd --- /dev/null +++ b/tests/test_custom_domain_validation.py @@ -0,0 +1,325 @@ +from typing import Optional + +from app import config +from app.constants import DMARC_RECORD +from app.custom_domain_validation import CustomDomainValidation +from app.db import Session +from app.models import CustomDomain, User +from app.dns_utils import InMemoryDNSClient +from app.utils import random_string +from tests.utils import create_new_user, random_domain + +user: Optional[User] = None + + +def setup_module(): + global user + config.SKIP_MX_LOOKUP_ON_CHECK = True + user = create_new_user() + user.trial_end = None + user.lifetime = True + Session.commit() + + +def create_custom_domain(domain: str) -> CustomDomain: + return CustomDomain.create(user_id=user.id, domain=domain, commit=True) + + +def test_custom_domain_validation_get_dkim_records(): + domain = random_domain() + validator = CustomDomainValidation(domain) + records = validator.get_dkim_records() + + assert len(records) == 3 + assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}" + assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}" + assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}" + + +# validate_dkim_records +def test_custom_domain_validation_validate_dkim_records_empty_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_dkim_records(domain) + + assert len(res) == 3 + for record_value in res.values(): + assert record_value == "empty" + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dkim_verified is False + + +def test_custom_domain_validation_validate_dkim_records_wrong_records_failure(): + dkim_domain = random_domain() + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(dkim_domain, dns_client) + + user_domain = random_domain() + + # One domain right, two domains wrong + dns_client.set_cname_record( + f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}" + ) + dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong") + dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong") + + domain = create_custom_domain(user_domain) + res = validator.validate_dkim_records(domain) + + assert len(res) == 2 + for record_value in res.values(): + assert record_value == "wrong" + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dkim_verified is False + + +def test_custom_domain_validation_validate_dkim_records_success_with_old_system(): + dkim_domain = random_domain() + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(dkim_domain, dns_client) + + user_domain = random_domain() + + # One domain right, other domains missing + dns_client.set_cname_record( + f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}" + ) + + domain = create_custom_domain(user_domain) + + # DKIM is verified + domain.dkim_verified = True + Session.commit() + + res = validator.validate_dkim_records(domain) + assert len(res) == 2 + assert f"dkim02._domainkey.{user_domain}" in res + assert f"dkim03._domainkey.{user_domain}" in res + + # Flag is not cleared + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dkim_verified is True + + +def test_custom_domain_validation_validate_dkim_records_success(): + dkim_domain = random_domain() + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(dkim_domain, dns_client) + + user_domain = random_domain() + + # One domain right, two domains wrong + dns_client.set_cname_record( + f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}" + ) + dns_client.set_cname_record( + f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}" + ) + dns_client.set_cname_record( + f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}" + ) + + domain = create_custom_domain(user_domain) + res = validator.validate_dkim_records(domain) + assert len(res) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dkim_verified is True + + +# validate_ownership +def test_custom_domain_validation_validate_ownership_empty_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_domain_ownership(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.ownership_verified is False + + +def test_custom_domain_validation_validate_ownership_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_records = [random_string()] + dns_client.set_txt_record(domain.domain, wrong_records) + res = validator.validate_domain_ownership(domain) + + assert res.success is False + assert res.errors == wrong_records + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.ownership_verified is False + + +def test_custom_domain_validation_validate_ownership_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()]) + res = validator.validate_domain_ownership(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.ownership_verified is True + + +# validate_mx_records +def test_custom_domain_validation_validate_mx_records_empty_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_mx_records(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.verified is False + + +def test_custom_domain_validation_validate_mx_records_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_record_1 = random_string() + wrong_record_2 = random_string() + wrong_records = [(10, wrong_record_1), (20, wrong_record_2)] + dns_client.set_mx_records(domain.domain, wrong_records) + res = validator.validate_mx_records(domain) + + assert res.success is False + assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"] + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.verified is False + + +def test_custom_domain_validation_validate_mx_records_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY) + res = validator.validate_mx_records(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.verified is True + + +# validate_spf_records +def test_custom_domain_validation_validate_spf_records_empty_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_spf_records(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.spf_verified is False + + +def test_custom_domain_validation_validate_spf_records_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_records = [random_string()] + dns_client.set_txt_record(domain.domain, wrong_records) + res = validator.validate_spf_records(domain) + + assert res.success is False + assert res.errors == wrong_records + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.spf_verified is False + + +def test_custom_domain_validation_validate_spf_records_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"]) + res = validator.validate_spf_records(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.spf_verified is True + + +# validate_dmarc_records +def test_custom_domain_validation_validate_dmarc_records_empty_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + res = validator.validate_dmarc_records(domain) + + assert res.success is False + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dmarc_verified is False + + +def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + wrong_records = [random_string()] + dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records) + res = validator.validate_dmarc_records(domain) + + assert res.success is False + assert res.errors == wrong_records + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dmarc_verified is False + + +def test_custom_domain_validation_validate_dmarc_records_success(): + dns_client = InMemoryDNSClient() + validator = CustomDomainValidation(random_domain(), dns_client) + + domain = create_custom_domain(random_domain()) + + dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD]) + res = validator.validate_dmarc_records(domain) + + assert res.success is True + assert len(res.errors) == 0 + + db_domain = CustomDomain.get_by(id=domain.id) + assert db_domain.dmarc_verified is True diff --git a/tests/test_dns_utils.py b/tests/test_dns_utils.py index 46a6ccdd..374983c8 100644 --- a/tests/test_dns_utils.py +++ b/tests/test_dns_utils.py @@ -1,10 +1,12 @@ from app.dns_utils import ( get_mx_domains, - get_spf_domain, - get_txt_record, + get_network_dns_client, is_mx_equivalent, + InMemoryDNSClient, ) +from tests.utils import random_domain + # use our own domain for test _DOMAIN = "simplelogin.io" @@ -20,12 +22,12 @@ def test_get_mx_domains(): def test_get_spf_domain(): - r = get_spf_domain(_DOMAIN) + r = get_network_dns_client().get_spf_domain(_DOMAIN) assert r == ["simplelogin.co"] def test_get_txt_record(): - r = get_txt_record(_DOMAIN) + r = get_network_dns_client().get_txt_record(_DOMAIN) assert len(r) > 0 @@ -46,3 +48,15 @@ def test_is_mx_equivalent(): [(5, "domain1"), (10, "domain2")], [(10, "domain1"), (20, "domain2"), (20, "domain3")], ) + + +def test_get_spf_record(): + client = InMemoryDNSClient() + + sl_domain = random_domain() + domain = random_domain() + + spf_record = f"v=spf1 include:{sl_domain}" + client.set_txt_record(domain, [spf_record, "another record"]) + res = client.get_spf_domain(domain) + assert res == [sl_domain]