mirror of
https://github.com/simple-login/app.git
synced 2024-11-10 21:27:10 +01:00
chore: refactor dns to improve testability (#2224)
* chore: refactor DNS client to its own class * chore: adapt code calling DNS and add tests to it * chore: refactor old dkim check not to clear flag
This commit is contained in:
parent
065cc3db92
commit
f6708dd0b6
@ -2,11 +2,9 @@ from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
|
|||||||
from app.constants import DMARC_RECORD
|
from app.constants import DMARC_RECORD
|
||||||
from app.db import Session
|
from app.db import Session
|
||||||
from app.dns_utils import (
|
from app.dns_utils import (
|
||||||
get_cname_record,
|
DNSClient,
|
||||||
get_mx_domains,
|
|
||||||
get_txt_record,
|
|
||||||
is_mx_equivalent,
|
is_mx_equivalent,
|
||||||
get_spf_domain,
|
get_network_dns_client,
|
||||||
)
|
)
|
||||||
from app.models import CustomDomain
|
from app.models import CustomDomain
|
||||||
from dataclasses import dataclass
|
from dataclasses import dataclass
|
||||||
@ -19,10 +17,13 @@ class DomainValidationResult:
|
|||||||
|
|
||||||
|
|
||||||
class CustomDomainValidation:
|
class CustomDomainValidation:
|
||||||
def __init__(self, dkim_domain: str):
|
def __init__(
|
||||||
|
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
|
||||||
|
):
|
||||||
self.dkim_domain = dkim_domain
|
self.dkim_domain = dkim_domain
|
||||||
|
self._dns_client = dns_client
|
||||||
self._dkim_records = {
|
self._dkim_records = {
|
||||||
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
|
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
|
||||||
for key in ("dkim", "dkim02", "dkim03")
|
for key in ("dkim", "dkim02", "dkim03")
|
||||||
}
|
}
|
||||||
|
|
||||||
@ -38,15 +39,31 @@ class CustomDomainValidation:
|
|||||||
Check if dkim records are properly set for this custom domain.
|
Check if dkim records are properly set for this custom domain.
|
||||||
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
|
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
|
||||||
"""
|
"""
|
||||||
|
correct_records = {}
|
||||||
invalid_records = {}
|
invalid_records = {}
|
||||||
for prefix, expected_record in self.get_dkim_records():
|
expected_records = self.get_dkim_records()
|
||||||
|
for prefix, expected_record in expected_records.items():
|
||||||
custom_record = f"{prefix}.{custom_domain.domain}"
|
custom_record = f"{prefix}.{custom_domain.domain}"
|
||||||
dkim_record = get_cname_record(custom_record)
|
dkim_record = self._dns_client.get_cname_record(custom_record)
|
||||||
if dkim_record != expected_record:
|
if dkim_record == expected_record:
|
||||||
|
correct_records[prefix] = custom_record
|
||||||
|
else:
|
||||||
invalid_records[custom_record] = dkim_record or "empty"
|
invalid_records[custom_record] = dkim_record or "empty"
|
||||||
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
|
|
||||||
|
# HACK
|
||||||
|
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
|
||||||
|
# the domain validated to continue seeing it as validated (although showing them the missing records).
|
||||||
|
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
|
||||||
|
# we will remove the dkim_verified flag.
|
||||||
|
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
|
||||||
if custom_domain.dkim_verified:
|
if custom_domain.dkim_verified:
|
||||||
return invalid_records
|
# Check if at least the original dkim is there
|
||||||
|
if correct_records.get("dkim._domainkey") is not None:
|
||||||
|
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
|
||||||
|
return invalid_records
|
||||||
|
|
||||||
|
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
|
||||||
|
# rest of the code path, returning the invalid records and clearing the flag
|
||||||
custom_domain.dkim_verified = len(invalid_records) == 0
|
custom_domain.dkim_verified = len(invalid_records) == 0
|
||||||
Session.commit()
|
Session.commit()
|
||||||
return invalid_records
|
return invalid_records
|
||||||
@ -57,7 +74,7 @@ class CustomDomainValidation:
|
|||||||
"""
|
"""
|
||||||
Check if the custom_domain has added the ownership verification records
|
Check if the custom_domain has added the ownership verification records
|
||||||
"""
|
"""
|
||||||
txt_records = get_txt_record(custom_domain.domain)
|
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
|
||||||
|
|
||||||
if custom_domain.get_ownership_dns_txt_value() in txt_records:
|
if custom_domain.get_ownership_dns_txt_value() in txt_records:
|
||||||
custom_domain.ownership_verified = True
|
custom_domain.ownership_verified = True
|
||||||
@ -69,7 +86,7 @@ class CustomDomainValidation:
|
|||||||
def validate_mx_records(
|
def validate_mx_records(
|
||||||
self, custom_domain: CustomDomain
|
self, custom_domain: CustomDomain
|
||||||
) -> DomainValidationResult:
|
) -> DomainValidationResult:
|
||||||
mx_domains = get_mx_domains(custom_domain.domain)
|
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
|
||||||
|
|
||||||
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
|
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
|
||||||
return DomainValidationResult(
|
return DomainValidationResult(
|
||||||
@ -84,7 +101,7 @@ class CustomDomainValidation:
|
|||||||
def validate_spf_records(
|
def validate_spf_records(
|
||||||
self, custom_domain: CustomDomain
|
self, custom_domain: CustomDomain
|
||||||
) -> DomainValidationResult:
|
) -> DomainValidationResult:
|
||||||
spf_domains = get_spf_domain(custom_domain.domain)
|
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
|
||||||
if EMAIL_DOMAIN in spf_domains:
|
if EMAIL_DOMAIN in spf_domains:
|
||||||
custom_domain.spf_verified = True
|
custom_domain.spf_verified = True
|
||||||
Session.commit()
|
Session.commit()
|
||||||
@ -93,13 +110,14 @@ class CustomDomainValidation:
|
|||||||
custom_domain.spf_verified = False
|
custom_domain.spf_verified = False
|
||||||
Session.commit()
|
Session.commit()
|
||||||
return DomainValidationResult(
|
return DomainValidationResult(
|
||||||
success=False, errors=get_txt_record(custom_domain.domain)
|
success=False,
|
||||||
|
errors=self._dns_client.get_txt_record(custom_domain.domain),
|
||||||
)
|
)
|
||||||
|
|
||||||
def validate_dmarc_records(
|
def validate_dmarc_records(
|
||||||
self, custom_domain: CustomDomain
|
self, custom_domain: CustomDomain
|
||||||
) -> DomainValidationResult:
|
) -> DomainValidationResult:
|
||||||
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
|
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
|
||||||
if DMARC_RECORD in txt_records:
|
if DMARC_RECORD in txt_records:
|
||||||
custom_domain.dmarc_verified = True
|
custom_domain.dmarc_verified = True
|
||||||
Session.commit()
|
Session.commit()
|
||||||
|
214
app/dns_utils.py
214
app/dns_utils.py
@ -1,100 +1,13 @@
|
|||||||
from app import config
|
from abc import ABC, abstractmethod
|
||||||
from typing import Optional, List, Tuple
|
from typing import List, Tuple, Optional
|
||||||
|
|
||||||
import dns.resolver
|
import dns.resolver
|
||||||
|
|
||||||
|
from app.config import NAMESERVERS
|
||||||
def _get_dns_resolver():
|
|
||||||
my_resolver = dns.resolver.Resolver()
|
|
||||||
my_resolver.nameservers = config.NAMESERVERS
|
|
||||||
|
|
||||||
return my_resolver
|
|
||||||
|
|
||||||
|
|
||||||
def get_ns(hostname) -> [str]:
|
|
||||||
try:
|
|
||||||
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
return [a.to_text() for a in answers]
|
|
||||||
|
|
||||||
|
|
||||||
def get_cname_record(hostname) -> Optional[str]:
|
|
||||||
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
|
|
||||||
try:
|
|
||||||
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
|
|
||||||
except Exception:
|
|
||||||
return None
|
|
||||||
|
|
||||||
for a in answers:
|
|
||||||
ret = a.to_text()
|
|
||||||
return ret[:-1]
|
|
||||||
|
|
||||||
return None
|
|
||||||
|
|
||||||
|
|
||||||
def get_mx_domains(hostname) -> [(int, str)]:
|
|
||||||
"""return list of (priority, domain name) sorted by priority (lowest priority first)
|
|
||||||
domain name ends with a "." at the end.
|
|
||||||
"""
|
|
||||||
try:
|
|
||||||
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
ret = []
|
|
||||||
|
|
||||||
for a in answers:
|
|
||||||
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
|
|
||||||
parts = record.split(" ")
|
|
||||||
|
|
||||||
ret.append((int(parts[0]), parts[1]))
|
|
||||||
|
|
||||||
return sorted(ret, key=lambda prio_domain: prio_domain[0])
|
|
||||||
|
|
||||||
|
|
||||||
_include_spf = "include:"
|
_include_spf = "include:"
|
||||||
|
|
||||||
|
|
||||||
def get_spf_domain(hostname) -> [str]:
|
|
||||||
"""return all domains listed in *include:*"""
|
|
||||||
try:
|
|
||||||
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
ret = []
|
|
||||||
|
|
||||||
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
|
||||||
for record in a.strings:
|
|
||||||
record = record.decode() # record is bytes
|
|
||||||
|
|
||||||
if record.startswith("v=spf1"):
|
|
||||||
parts = record.split(" ")
|
|
||||||
for part in parts:
|
|
||||||
if part.startswith(_include_spf):
|
|
||||||
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def get_txt_record(hostname) -> [str]:
|
|
||||||
try:
|
|
||||||
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
|
|
||||||
except Exception:
|
|
||||||
return []
|
|
||||||
|
|
||||||
ret = []
|
|
||||||
|
|
||||||
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
|
||||||
for record in a.strings:
|
|
||||||
record = record.decode() # record is bytes
|
|
||||||
|
|
||||||
ret.append(record)
|
|
||||||
|
|
||||||
return ret
|
|
||||||
|
|
||||||
|
|
||||||
def is_mx_equivalent(
|
def is_mx_equivalent(
|
||||||
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
|
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
|
||||||
) -> bool:
|
) -> bool:
|
||||||
@ -105,16 +18,127 @@ def is_mx_equivalent(
|
|||||||
The priority order is taken into account but not the priority number.
|
The priority order is taken into account but not the priority number.
|
||||||
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
|
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
|
||||||
"""
|
"""
|
||||||
mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0])
|
mx_domains = sorted(mx_domains, key=lambda x: x[0])
|
||||||
ref_mx_domains = sorted(
|
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
|
||||||
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
|
|
||||||
)
|
|
||||||
|
|
||||||
if len(mx_domains) < len(ref_mx_domains):
|
if len(mx_domains) < len(ref_mx_domains):
|
||||||
return False
|
return False
|
||||||
|
|
||||||
for i in range(0, len(ref_mx_domains)):
|
for i in range(len(ref_mx_domains)):
|
||||||
if mx_domains[i][1] != ref_mx_domains[i][1]:
|
if mx_domains[i][1] != ref_mx_domains[i][1]:
|
||||||
return False
|
return False
|
||||||
|
|
||||||
return True
|
return True
|
||||||
|
|
||||||
|
|
||||||
|
class DNSClient(ABC):
|
||||||
|
@abstractmethod
|
||||||
|
def get_cname_record(self, hostname: str) -> Optional[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def get_spf_domain(self, hostname: str) -> List[str]:
|
||||||
|
"""
|
||||||
|
return all domains listed in *include:*
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
records = self.get_txt_record(hostname)
|
||||||
|
ret = []
|
||||||
|
for record in records:
|
||||||
|
if record.startswith("v=spf1"):
|
||||||
|
parts = record.split(" ")
|
||||||
|
for part in parts:
|
||||||
|
if part.startswith(_include_spf):
|
||||||
|
ret.append(
|
||||||
|
part[part.find(_include_spf) + len(_include_spf) :]
|
||||||
|
)
|
||||||
|
return ret
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
@abstractmethod
|
||||||
|
def get_txt_record(self, hostname: str) -> List[str]:
|
||||||
|
pass
|
||||||
|
|
||||||
|
|
||||||
|
class NetworkDNSClient(DNSClient):
|
||||||
|
def __init__(self, nameservers: List[str]):
|
||||||
|
self._resolver = dns.resolver.Resolver()
|
||||||
|
self._resolver.nameservers = nameservers
|
||||||
|
|
||||||
|
def get_cname_record(self, hostname: str) -> Optional[str]:
|
||||||
|
"""
|
||||||
|
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
answers = self._resolver.resolve(hostname, "CNAME", search=True)
|
||||||
|
for a in answers:
|
||||||
|
ret = a.to_text()
|
||||||
|
return ret[:-1]
|
||||||
|
except Exception:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
|
||||||
|
"""
|
||||||
|
return list of (priority, domain name) sorted by priority (lowest priority first)
|
||||||
|
domain name ends with a "." at the end.
|
||||||
|
"""
|
||||||
|
try:
|
||||||
|
answers = self._resolver.resolve(hostname, "MX", search=True)
|
||||||
|
ret = []
|
||||||
|
for a in answers:
|
||||||
|
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
|
||||||
|
parts = record.split(" ")
|
||||||
|
ret.append((int(parts[0]), parts[1]))
|
||||||
|
return sorted(ret, key=lambda x: x[0])
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
def get_txt_record(self, hostname: str) -> List[str]:
|
||||||
|
try:
|
||||||
|
answers = self._resolver.resolve(hostname, "TXT", search=True)
|
||||||
|
ret = []
|
||||||
|
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
||||||
|
for record in a.strings:
|
||||||
|
ret.append(record.decode())
|
||||||
|
return ret
|
||||||
|
except Exception:
|
||||||
|
return []
|
||||||
|
|
||||||
|
|
||||||
|
class InMemoryDNSClient(DNSClient):
|
||||||
|
def __init__(self):
|
||||||
|
self.cname_records: dict[str, Optional[str]] = {}
|
||||||
|
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
|
||||||
|
self.spf_records: dict[str, List[str]] = {}
|
||||||
|
self.txt_records: dict[str, List[str]] = {}
|
||||||
|
|
||||||
|
def set_cname_record(self, hostname: str, cname: str):
|
||||||
|
self.cname_records[hostname] = cname
|
||||||
|
|
||||||
|
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
|
||||||
|
self.mx_records[hostname] = mx_list
|
||||||
|
|
||||||
|
def set_txt_record(self, hostname: str, txt_list: List[str]):
|
||||||
|
self.txt_records[hostname] = txt_list
|
||||||
|
|
||||||
|
def get_cname_record(self, hostname: str) -> Optional[str]:
|
||||||
|
return self.cname_records.get(hostname)
|
||||||
|
|
||||||
|
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
|
||||||
|
mx_list = self.mx_records.get(hostname, [])
|
||||||
|
return sorted(mx_list, key=lambda x: x[0])
|
||||||
|
|
||||||
|
def get_txt_record(self, hostname: str) -> List[str]:
|
||||||
|
return self.txt_records.get(hostname, [])
|
||||||
|
|
||||||
|
|
||||||
|
def get_network_dns_client() -> NetworkDNSClient:
|
||||||
|
return NetworkDNSClient(NAMESERVERS)
|
||||||
|
|
||||||
|
|
||||||
|
def get_mx_domains(hostname: str) -> [(int, str)]:
|
||||||
|
return get_network_dns_client().get_mx_domains(hostname)
|
||||||
|
@ -237,7 +237,7 @@
|
|||||||
folder.
|
folder.
|
||||||
</div>
|
</div>
|
||||||
<div class="mb-2">Add the following CNAME DNS records to your domain.</div>
|
<div class="mb-2">Add the following CNAME DNS records to your domain.</div>
|
||||||
{% for dkim_prefix, dkim_cname_value in dkim_records %}
|
{% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
|
||||||
|
|
||||||
<div class="mb-2 p-3 dns-record">
|
<div class="mb-2 p-3 dns-record">
|
||||||
Record: CNAME
|
Record: CNAME
|
||||||
|
325
tests/test_custom_domain_validation.py
Normal file
325
tests/test_custom_domain_validation.py
Normal file
@ -0,0 +1,325 @@
|
|||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app import config
|
||||||
|
from app.constants import DMARC_RECORD
|
||||||
|
from app.custom_domain_validation import CustomDomainValidation
|
||||||
|
from app.db import Session
|
||||||
|
from app.models import CustomDomain, User
|
||||||
|
from app.dns_utils import InMemoryDNSClient
|
||||||
|
from app.utils import random_string
|
||||||
|
from tests.utils import create_new_user, random_domain
|
||||||
|
|
||||||
|
user: Optional[User] = None
|
||||||
|
|
||||||
|
|
||||||
|
def setup_module():
|
||||||
|
global user
|
||||||
|
config.SKIP_MX_LOOKUP_ON_CHECK = True
|
||||||
|
user = create_new_user()
|
||||||
|
user.trial_end = None
|
||||||
|
user.lifetime = True
|
||||||
|
Session.commit()
|
||||||
|
|
||||||
|
|
||||||
|
def create_custom_domain(domain: str) -> CustomDomain:
|
||||||
|
return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_get_dkim_records():
|
||||||
|
domain = random_domain()
|
||||||
|
validator = CustomDomainValidation(domain)
|
||||||
|
records = validator.get_dkim_records()
|
||||||
|
|
||||||
|
assert len(records) == 3
|
||||||
|
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
|
||||||
|
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
|
||||||
|
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
|
||||||
|
|
||||||
|
|
||||||
|
# validate_dkim_records
|
||||||
|
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
res = validator.validate_dkim_records(domain)
|
||||||
|
|
||||||
|
assert len(res) == 3
|
||||||
|
for record_value in res.values():
|
||||||
|
assert record_value == "empty"
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.dkim_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
|
||||||
|
dkim_domain = random_domain()
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(dkim_domain, dns_client)
|
||||||
|
|
||||||
|
user_domain = random_domain()
|
||||||
|
|
||||||
|
# One domain right, two domains wrong
|
||||||
|
dns_client.set_cname_record(
|
||||||
|
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
|
||||||
|
)
|
||||||
|
dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
|
||||||
|
dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
|
||||||
|
|
||||||
|
domain = create_custom_domain(user_domain)
|
||||||
|
res = validator.validate_dkim_records(domain)
|
||||||
|
|
||||||
|
assert len(res) == 2
|
||||||
|
for record_value in res.values():
|
||||||
|
assert record_value == "wrong"
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.dkim_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
|
||||||
|
dkim_domain = random_domain()
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(dkim_domain, dns_client)
|
||||||
|
|
||||||
|
user_domain = random_domain()
|
||||||
|
|
||||||
|
# One domain right, other domains missing
|
||||||
|
dns_client.set_cname_record(
|
||||||
|
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
|
||||||
|
)
|
||||||
|
|
||||||
|
domain = create_custom_domain(user_domain)
|
||||||
|
|
||||||
|
# DKIM is verified
|
||||||
|
domain.dkim_verified = True
|
||||||
|
Session.commit()
|
||||||
|
|
||||||
|
res = validator.validate_dkim_records(domain)
|
||||||
|
assert len(res) == 2
|
||||||
|
assert f"dkim02._domainkey.{user_domain}" in res
|
||||||
|
assert f"dkim03._domainkey.{user_domain}" in res
|
||||||
|
|
||||||
|
# Flag is not cleared
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.dkim_verified is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_dkim_records_success():
|
||||||
|
dkim_domain = random_domain()
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(dkim_domain, dns_client)
|
||||||
|
|
||||||
|
user_domain = random_domain()
|
||||||
|
|
||||||
|
# One domain right, two domains wrong
|
||||||
|
dns_client.set_cname_record(
|
||||||
|
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
|
||||||
|
)
|
||||||
|
dns_client.set_cname_record(
|
||||||
|
f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
|
||||||
|
)
|
||||||
|
dns_client.set_cname_record(
|
||||||
|
f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
|
||||||
|
)
|
||||||
|
|
||||||
|
domain = create_custom_domain(user_domain)
|
||||||
|
res = validator.validate_dkim_records(domain)
|
||||||
|
assert len(res) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.dkim_verified is True
|
||||||
|
|
||||||
|
|
||||||
|
# validate_ownership
|
||||||
|
def test_custom_domain_validation_validate_ownership_empty_records_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
res = validator.validate_domain_ownership(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.ownership_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
wrong_records = [random_string()]
|
||||||
|
dns_client.set_txt_record(domain.domain, wrong_records)
|
||||||
|
res = validator.validate_domain_ownership(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert res.errors == wrong_records
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.ownership_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_ownership_success():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
|
||||||
|
res = validator.validate_domain_ownership(domain)
|
||||||
|
|
||||||
|
assert res.success is True
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.ownership_verified is True
|
||||||
|
|
||||||
|
|
||||||
|
# validate_mx_records
|
||||||
|
def test_custom_domain_validation_validate_mx_records_empty_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
res = validator.validate_mx_records(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
wrong_record_1 = random_string()
|
||||||
|
wrong_record_2 = random_string()
|
||||||
|
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
|
||||||
|
dns_client.set_mx_records(domain.domain, wrong_records)
|
||||||
|
res = validator.validate_mx_records(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_mx_records_success():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
|
||||||
|
res = validator.validate_mx_records(domain)
|
||||||
|
|
||||||
|
assert res.success is True
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.verified is True
|
||||||
|
|
||||||
|
|
||||||
|
# validate_spf_records
|
||||||
|
def test_custom_domain_validation_validate_spf_records_empty_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
res = validator.validate_spf_records(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.spf_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
wrong_records = [random_string()]
|
||||||
|
dns_client.set_txt_record(domain.domain, wrong_records)
|
||||||
|
res = validator.validate_spf_records(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert res.errors == wrong_records
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.spf_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_spf_records_success():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
|
||||||
|
res = validator.validate_spf_records(domain)
|
||||||
|
|
||||||
|
assert res.success is True
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.spf_verified is True
|
||||||
|
|
||||||
|
|
||||||
|
# validate_dmarc_records
|
||||||
|
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
res = validator.validate_dmarc_records(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.dmarc_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
wrong_records = [random_string()]
|
||||||
|
dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
|
||||||
|
res = validator.validate_dmarc_records(domain)
|
||||||
|
|
||||||
|
assert res.success is False
|
||||||
|
assert res.errors == wrong_records
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.dmarc_verified is False
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_dmarc_records_success():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
|
dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
|
||||||
|
res = validator.validate_dmarc_records(domain)
|
||||||
|
|
||||||
|
assert res.success is True
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.dmarc_verified is True
|
@ -1,10 +1,12 @@
|
|||||||
from app.dns_utils import (
|
from app.dns_utils import (
|
||||||
get_mx_domains,
|
get_mx_domains,
|
||||||
get_spf_domain,
|
get_network_dns_client,
|
||||||
get_txt_record,
|
|
||||||
is_mx_equivalent,
|
is_mx_equivalent,
|
||||||
|
InMemoryDNSClient,
|
||||||
)
|
)
|
||||||
|
|
||||||
|
from tests.utils import random_domain
|
||||||
|
|
||||||
# use our own domain for test
|
# use our own domain for test
|
||||||
_DOMAIN = "simplelogin.io"
|
_DOMAIN = "simplelogin.io"
|
||||||
|
|
||||||
@ -20,12 +22,12 @@ def test_get_mx_domains():
|
|||||||
|
|
||||||
|
|
||||||
def test_get_spf_domain():
|
def test_get_spf_domain():
|
||||||
r = get_spf_domain(_DOMAIN)
|
r = get_network_dns_client().get_spf_domain(_DOMAIN)
|
||||||
assert r == ["simplelogin.co"]
|
assert r == ["simplelogin.co"]
|
||||||
|
|
||||||
|
|
||||||
def test_get_txt_record():
|
def test_get_txt_record():
|
||||||
r = get_txt_record(_DOMAIN)
|
r = get_network_dns_client().get_txt_record(_DOMAIN)
|
||||||
assert len(r) > 0
|
assert len(r) > 0
|
||||||
|
|
||||||
|
|
||||||
@ -46,3 +48,15 @@ def test_is_mx_equivalent():
|
|||||||
[(5, "domain1"), (10, "domain2")],
|
[(5, "domain1"), (10, "domain2")],
|
||||||
[(10, "domain1"), (20, "domain2"), (20, "domain3")],
|
[(10, "domain1"), (20, "domain2"), (20, "domain3")],
|
||||||
)
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def test_get_spf_record():
|
||||||
|
client = InMemoryDNSClient()
|
||||||
|
|
||||||
|
sl_domain = random_domain()
|
||||||
|
domain = random_domain()
|
||||||
|
|
||||||
|
spf_record = f"v=spf1 include:{sl_domain}"
|
||||||
|
client.set_txt_record(domain, [spf_record, "another record"])
|
||||||
|
res = client.get_spf_domain(domain)
|
||||||
|
assert res == [sl_domain]
|
||||||
|
Loading…
Reference in New Issue
Block a user