From f6708dd0b631da32f19ffcde6f91f5ed9b867c3d Mon Sep 17 00:00:00 2001 From: Carlos Quintana <74399022+cquintana92@users.noreply.github.com> Date: Tue, 17 Sep 2024 16:15:10 +0200 Subject: [PATCH] chore: refactor dns to improve testability (#2224) * chore: refactor DNS client to its own class * chore: adapt code calling DNS and add tests to it * chore: refactor old dkim check not to clear flag --- app/custom_domain_validation.py | 50 +++- app/dns_utils.py | 214 ++++++++------ templates/dashboard/domain_detail/dns.html | 2 +- tests/test_custom_domain_validation.py | 325 +++++++++++++++++++++ tests/test_dns_utils.py | 22 +- 5 files changed, 497 insertions(+), 116 deletions(-) create mode 100644 tests/test_custom_domain_validation.py diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 05b6af37..6887e81a 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -2,11 +2,9 @@ from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN from app.constants import DMARC_RECORD from app.db import Session from app.dns_utils import ( - get_cname_record, - get_mx_domains, - get_txt_record, + DNSClient, is_mx_equivalent, - get_spf_domain, + get_network_dns_client, ) from app.models import CustomDomain from dataclasses import dataclass @@ -19,10 +17,13 @@ class DomainValidationResult: class CustomDomainValidation: - def __init__(self, dkim_domain: str): + def __init__( + self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client() + ): self.dkim_domain = dkim_domain + self._dns_client = dns_client self._dkim_records = { - (f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}") + f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}" for key in ("dkim", "dkim02", "dkim03") } @@ -38,15 +39,31 @@ class CustomDomainValidation: Check if dkim records are properly set for this custom domain. Returns empty list if all records are ok. Other-wise return the records that aren't properly configured """ + correct_records = {} invalid_records = {} - for prefix, expected_record in self.get_dkim_records(): + expected_records = self.get_dkim_records() + for prefix, expected_record in expected_records.items(): custom_record = f"{prefix}.{custom_domain.domain}" - dkim_record = get_cname_record(custom_record) - if dkim_record != expected_record: + dkim_record = self._dns_client.get_cname_record(custom_record) + if dkim_record == expected_record: + correct_records[prefix] = custom_record + else: invalid_records[custom_record] = dkim_record or "empty" - # HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES + + # HACK + # As initially we only had one dkim record, we want to allow users that had only the original dkim record and + # the domain validated to continue seeing it as validated (although showing them the missing records). + # However, if not even the original dkim record is right, even if the domain was dkim_verified in the past, + # we will remove the dkim_verified flag. + # This is done in order to give users with the old dkim config (only one) to update their CNAMEs if custom_domain.dkim_verified: - return invalid_records + # Check if at least the original dkim is there + if correct_records.get("dkim._domainkey") is not None: + # Original dkim record is there. Return the missing records (if any) and don't clear the flag + return invalid_records + + # Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the + # rest of the code path, returning the invalid records and clearing the flag custom_domain.dkim_verified = len(invalid_records) == 0 Session.commit() return invalid_records @@ -57,7 +74,7 @@ class CustomDomainValidation: """ Check if the custom_domain has added the ownership verification records """ - txt_records = get_txt_record(custom_domain.domain) + txt_records = self._dns_client.get_txt_record(custom_domain.domain) if custom_domain.get_ownership_dns_txt_value() in txt_records: custom_domain.ownership_verified = True @@ -69,7 +86,7 @@ class CustomDomainValidation: def validate_mx_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: - mx_domains = get_mx_domains(custom_domain.domain) + mx_domains = self._dns_client.get_mx_domains(custom_domain.domain) if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY): return DomainValidationResult( @@ -84,7 +101,7 @@ class CustomDomainValidation: def validate_spf_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: - spf_domains = get_spf_domain(custom_domain.domain) + spf_domains = self._dns_client.get_spf_domain(custom_domain.domain) if EMAIL_DOMAIN in spf_domains: custom_domain.spf_verified = True Session.commit() @@ -93,13 +110,14 @@ class CustomDomainValidation: custom_domain.spf_verified = False Session.commit() return DomainValidationResult( - success=False, errors=get_txt_record(custom_domain.domain) + success=False, + errors=self._dns_client.get_txt_record(custom_domain.domain), ) def validate_dmarc_records( self, custom_domain: CustomDomain ) -> DomainValidationResult: - txt_records = get_txt_record("_dmarc." + custom_domain.domain) + txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain) if DMARC_RECORD in txt_records: custom_domain.dmarc_verified = True Session.commit() diff --git a/app/dns_utils.py b/app/dns_utils.py index 429d0aa2..2ce69934 100644 --- a/app/dns_utils.py +++ b/app/dns_utils.py @@ -1,100 +1,13 @@ -from app import config -from typing import Optional, List, Tuple +from abc import ABC, abstractmethod +from typing import List, Tuple, Optional import dns.resolver - -def _get_dns_resolver(): - my_resolver = dns.resolver.Resolver() - my_resolver.nameservers = config.NAMESERVERS - - return my_resolver - - -def get_ns(hostname) -> [str]: - try: - answers = _get_dns_resolver().resolve(hostname, "NS", search=True) - except Exception: - return [] - return [a.to_text() for a in answers] - - -def get_cname_record(hostname) -> Optional[str]: - """Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end""" - try: - answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True) - except Exception: - return None - - for a in answers: - ret = a.to_text() - return ret[:-1] - - return None - - -def get_mx_domains(hostname) -> [(int, str)]: - """return list of (priority, domain name) sorted by priority (lowest priority first) - domain name ends with a "." at the end. - """ - try: - answers = _get_dns_resolver().resolve(hostname, "MX", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: - record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' - parts = record.split(" ") - - ret.append((int(parts[0]), parts[1])) - - return sorted(ret, key=lambda prio_domain: prio_domain[0]) - +from app.config import NAMESERVERS _include_spf = "include:" -def get_spf_domain(hostname) -> [str]: - """return all domains listed in *include:*""" - try: - answers = _get_dns_resolver().resolve(hostname, "TXT", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: # type: dns.rdtypes.ANY.TXT.TXT - for record in a.strings: - record = record.decode() # record is bytes - - if record.startswith("v=spf1"): - parts = record.split(" ") - for part in parts: - if part.startswith(_include_spf): - ret.append(part[part.find(_include_spf) + len(_include_spf) :]) - - return ret - - -def get_txt_record(hostname) -> [str]: - try: - answers = _get_dns_resolver().resolve(hostname, "TXT", search=True) - except Exception: - return [] - - ret = [] - - for a in answers: # type: dns.rdtypes.ANY.TXT.TXT - for record in a.strings: - record = record.decode() # record is bytes - - ret.append(record) - - return ret - - def is_mx_equivalent( mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] ) -> bool: @@ -105,16 +18,127 @@ def is_mx_equivalent( The priority order is taken into account but not the priority number. For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] """ - mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0]) - ref_mx_domains = sorted( - ref_mx_domains, key=lambda priority_domain: priority_domain[0] - ) + mx_domains = sorted(mx_domains, key=lambda x: x[0]) + ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) if len(mx_domains) < len(ref_mx_domains): return False - for i in range(0, len(ref_mx_domains)): + for i in range(len(ref_mx_domains)): if mx_domains[i][1] != ref_mx_domains[i][1]: return False return True + + +class DNSClient(ABC): + @abstractmethod + def get_cname_record(self, hostname: str) -> Optional[str]: + pass + + @abstractmethod + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + pass + + def get_spf_domain(self, hostname: str) -> List[str]: + """ + return all domains listed in *include:* + """ + try: + records = self.get_txt_record(hostname) + ret = [] + for record in records: + if record.startswith("v=spf1"): + parts = record.split(" ") + for part in parts: + if part.startswith(_include_spf): + ret.append( + part[part.find(_include_spf) + len(_include_spf) :] + ) + return ret + except Exception: + return [] + + @abstractmethod + def get_txt_record(self, hostname: str) -> List[str]: + pass + + +class NetworkDNSClient(DNSClient): + def __init__(self, nameservers: List[str]): + self._resolver = dns.resolver.Resolver() + self._resolver.nameservers = nameservers + + def get_cname_record(self, hostname: str) -> Optional[str]: + """ + Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end + """ + try: + answers = self._resolver.resolve(hostname, "CNAME", search=True) + for a in answers: + ret = a.to_text() + return ret[:-1] + except Exception: + return None + + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + """ + return list of (priority, domain name) sorted by priority (lowest priority first) + domain name ends with a "." at the end. + """ + try: + answers = self._resolver.resolve(hostname, "MX", search=True) + ret = [] + for a in answers: + record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.' + parts = record.split(" ") + ret.append((int(parts[0]), parts[1])) + return sorted(ret, key=lambda x: x[0]) + except Exception: + return [] + + def get_txt_record(self, hostname: str) -> List[str]: + try: + answers = self._resolver.resolve(hostname, "TXT", search=True) + ret = [] + for a in answers: # type: dns.rdtypes.ANY.TXT.TXT + for record in a.strings: + ret.append(record.decode()) + return ret + except Exception: + return [] + + +class InMemoryDNSClient(DNSClient): + def __init__(self): + self.cname_records: dict[str, Optional[str]] = {} + self.mx_records: dict[str, List[Tuple[int, str]]] = {} + self.spf_records: dict[str, List[str]] = {} + self.txt_records: dict[str, List[str]] = {} + + def set_cname_record(self, hostname: str, cname: str): + self.cname_records[hostname] = cname + + def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]): + self.mx_records[hostname] = mx_list + + def set_txt_record(self, hostname: str, txt_list: List[str]): + self.txt_records[hostname] = txt_list + + def get_cname_record(self, hostname: str) -> Optional[str]: + return self.cname_records.get(hostname) + + def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]: + mx_list = self.mx_records.get(hostname, []) + return sorted(mx_list, key=lambda x: x[0]) + + def get_txt_record(self, hostname: str) -> List[str]: + return self.txt_records.get(hostname, []) + + +def get_network_dns_client() -> NetworkDNSClient: + return NetworkDNSClient(NAMESERVERS) + + +def get_mx_domains(hostname: str) -> [(int, str)]: + return get_network_dns_client().get_mx_domains(hostname) diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html index 15ef346f..810aa302 100644 --- a/templates/dashboard/domain_detail/dns.html +++ b/templates/dashboard/domain_detail/dns.html @@ -237,7 +237,7 @@ folder.