mirror of
https://github.com/simple-login/app.git
synced 2024-09-21 01:11:29 +02:00
chore: refactor dns to improve testability (#2224)
* chore: refactor DNS client to its own class * chore: adapt code calling DNS and add tests to it * chore: refactor old dkim check not to clear flag
This commit is contained in:
parent
065cc3db92
commit
f6708dd0b6
@ -2,11 +2,9 @@ from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
|
||||
from app.constants import DMARC_RECORD
|
||||
from app.db import Session
|
||||
from app.dns_utils import (
|
||||
get_cname_record,
|
||||
get_mx_domains,
|
||||
get_txt_record,
|
||||
DNSClient,
|
||||
is_mx_equivalent,
|
||||
get_spf_domain,
|
||||
get_network_dns_client,
|
||||
)
|
||||
from app.models import CustomDomain
|
||||
from dataclasses import dataclass
|
||||
@ -19,10 +17,13 @@ class DomainValidationResult:
|
||||
|
||||
|
||||
class CustomDomainValidation:
|
||||
def __init__(self, dkim_domain: str):
|
||||
def __init__(
|
||||
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
|
||||
):
|
||||
self.dkim_domain = dkim_domain
|
||||
self._dns_client = dns_client
|
||||
self._dkim_records = {
|
||||
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
|
||||
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
|
||||
for key in ("dkim", "dkim02", "dkim03")
|
||||
}
|
||||
|
||||
@ -38,15 +39,31 @@ class CustomDomainValidation:
|
||||
Check if dkim records are properly set for this custom domain.
|
||||
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
|
||||
"""
|
||||
correct_records = {}
|
||||
invalid_records = {}
|
||||
for prefix, expected_record in self.get_dkim_records():
|
||||
expected_records = self.get_dkim_records()
|
||||
for prefix, expected_record in expected_records.items():
|
||||
custom_record = f"{prefix}.{custom_domain.domain}"
|
||||
dkim_record = get_cname_record(custom_record)
|
||||
if dkim_record != expected_record:
|
||||
dkim_record = self._dns_client.get_cname_record(custom_record)
|
||||
if dkim_record == expected_record:
|
||||
correct_records[prefix] = custom_record
|
||||
else:
|
||||
invalid_records[custom_record] = dkim_record or "empty"
|
||||
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
|
||||
|
||||
# HACK
|
||||
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
|
||||
# the domain validated to continue seeing it as validated (although showing them the missing records).
|
||||
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
|
||||
# we will remove the dkim_verified flag.
|
||||
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
|
||||
if custom_domain.dkim_verified:
|
||||
return invalid_records
|
||||
# Check if at least the original dkim is there
|
||||
if correct_records.get("dkim._domainkey") is not None:
|
||||
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
|
||||
return invalid_records
|
||||
|
||||
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
|
||||
# rest of the code path, returning the invalid records and clearing the flag
|
||||
custom_domain.dkim_verified = len(invalid_records) == 0
|
||||
Session.commit()
|
||||
return invalid_records
|
||||
@ -57,7 +74,7 @@ class CustomDomainValidation:
|
||||
"""
|
||||
Check if the custom_domain has added the ownership verification records
|
||||
"""
|
||||
txt_records = get_txt_record(custom_domain.domain)
|
||||
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
|
||||
|
||||
if custom_domain.get_ownership_dns_txt_value() in txt_records:
|
||||
custom_domain.ownership_verified = True
|
||||
@ -69,7 +86,7 @@ class CustomDomainValidation:
|
||||
def validate_mx_records(
|
||||
self, custom_domain: CustomDomain
|
||||
) -> DomainValidationResult:
|
||||
mx_domains = get_mx_domains(custom_domain.domain)
|
||||
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
|
||||
|
||||
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
|
||||
return DomainValidationResult(
|
||||
@ -84,7 +101,7 @@ class CustomDomainValidation:
|
||||
def validate_spf_records(
|
||||
self, custom_domain: CustomDomain
|
||||
) -> DomainValidationResult:
|
||||
spf_domains = get_spf_domain(custom_domain.domain)
|
||||
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
|
||||
if EMAIL_DOMAIN in spf_domains:
|
||||
custom_domain.spf_verified = True
|
||||
Session.commit()
|
||||
@ -93,13 +110,14 @@ class CustomDomainValidation:
|
||||
custom_domain.spf_verified = False
|
||||
Session.commit()
|
||||
return DomainValidationResult(
|
||||
success=False, errors=get_txt_record(custom_domain.domain)
|
||||
success=False,
|
||||
errors=self._dns_client.get_txt_record(custom_domain.domain),
|
||||
)
|
||||
|
||||
def validate_dmarc_records(
|
||||
self, custom_domain: CustomDomain
|
||||
) -> DomainValidationResult:
|
||||
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
|
||||
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
|
||||
if DMARC_RECORD in txt_records:
|
||||
custom_domain.dmarc_verified = True
|
||||
Session.commit()
|
||||
|
214
app/dns_utils.py
214
app/dns_utils.py
@ -1,100 +1,13 @@
|
||||
from app import config
|
||||
from typing import Optional, List, Tuple
|
||||
from abc import ABC, abstractmethod
|
||||
from typing import List, Tuple, Optional
|
||||
|
||||
import dns.resolver
|
||||
|
||||
|
||||
def _get_dns_resolver():
|
||||
my_resolver = dns.resolver.Resolver()
|
||||
my_resolver.nameservers = config.NAMESERVERS
|
||||
|
||||
return my_resolver
|
||||
|
||||
|
||||
def get_ns(hostname) -> [str]:
|
||||
try:
|
||||
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
|
||||
except Exception:
|
||||
return []
|
||||
return [a.to_text() for a in answers]
|
||||
|
||||
|
||||
def get_cname_record(hostname) -> Optional[str]:
|
||||
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
|
||||
try:
|
||||
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
for a in answers:
|
||||
ret = a.to_text()
|
||||
return ret[:-1]
|
||||
|
||||
return None
|
||||
|
||||
|
||||
def get_mx_domains(hostname) -> [(int, str)]:
|
||||
"""return list of (priority, domain name) sorted by priority (lowest priority first)
|
||||
domain name ends with a "." at the end.
|
||||
"""
|
||||
try:
|
||||
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
ret = []
|
||||
|
||||
for a in answers:
|
||||
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
|
||||
parts = record.split(" ")
|
||||
|
||||
ret.append((int(parts[0]), parts[1]))
|
||||
|
||||
return sorted(ret, key=lambda prio_domain: prio_domain[0])
|
||||
|
||||
from app.config import NAMESERVERS
|
||||
|
||||
_include_spf = "include:"
|
||||
|
||||
|
||||
def get_spf_domain(hostname) -> [str]:
|
||||
"""return all domains listed in *include:*"""
|
||||
try:
|
||||
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
ret = []
|
||||
|
||||
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
||||
for record in a.strings:
|
||||
record = record.decode() # record is bytes
|
||||
|
||||
if record.startswith("v=spf1"):
|
||||
parts = record.split(" ")
|
||||
for part in parts:
|
||||
if part.startswith(_include_spf):
|
||||
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def get_txt_record(hostname) -> [str]:
|
||||
try:
|
||||
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
ret = []
|
||||
|
||||
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
||||
for record in a.strings:
|
||||
record = record.decode() # record is bytes
|
||||
|
||||
ret.append(record)
|
||||
|
||||
return ret
|
||||
|
||||
|
||||
def is_mx_equivalent(
|
||||
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
|
||||
) -> bool:
|
||||
@ -105,16 +18,127 @@ def is_mx_equivalent(
|
||||
The priority order is taken into account but not the priority number.
|
||||
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
|
||||
"""
|
||||
mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0])
|
||||
ref_mx_domains = sorted(
|
||||
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
|
||||
)
|
||||
mx_domains = sorted(mx_domains, key=lambda x: x[0])
|
||||
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
|
||||
|
||||
if len(mx_domains) < len(ref_mx_domains):
|
||||
return False
|
||||
|
||||
for i in range(0, len(ref_mx_domains)):
|
||||
for i in range(len(ref_mx_domains)):
|
||||
if mx_domains[i][1] != ref_mx_domains[i][1]:
|
||||
return False
|
||||
|
||||
return True
|
||||
|
||||
|
||||
class DNSClient(ABC):
|
||||
@abstractmethod
|
||||
def get_cname_record(self, hostname: str) -> Optional[str]:
|
||||
pass
|
||||
|
||||
@abstractmethod
|
||||
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
|
||||
pass
|
||||
|
||||
def get_spf_domain(self, hostname: str) -> List[str]:
|
||||
"""
|
||||
return all domains listed in *include:*
|
||||
"""
|
||||
try:
|
||||
records = self.get_txt_record(hostname)
|
||||
ret = []
|
||||
for record in records:
|
||||
if record.startswith("v=spf1"):
|
||||
parts = record.split(" ")
|
||||
for part in parts:
|
||||
if part.startswith(_include_spf):
|
||||
ret.append(
|
||||
part[part.find(_include_spf) + len(_include_spf) :]
|
||||
)
|
||||
return ret
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
@abstractmethod
|
||||
def get_txt_record(self, hostname: str) -> List[str]:
|
||||
pass
|
||||
|
||||
|
||||
class NetworkDNSClient(DNSClient):
|
||||
def __init__(self, nameservers: List[str]):
|
||||
self._resolver = dns.resolver.Resolver()
|
||||
self._resolver.nameservers = nameservers
|
||||
|
||||
def get_cname_record(self, hostname: str) -> Optional[str]:
|
||||
"""
|
||||
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
|
||||
"""
|
||||
try:
|
||||
answers = self._resolver.resolve(hostname, "CNAME", search=True)
|
||||
for a in answers:
|
||||
ret = a.to_text()
|
||||
return ret[:-1]
|
||||
except Exception:
|
||||
return None
|
||||
|
||||
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
|
||||
"""
|
||||
return list of (priority, domain name) sorted by priority (lowest priority first)
|
||||
domain name ends with a "." at the end.
|
||||
"""
|
||||
try:
|
||||
answers = self._resolver.resolve(hostname, "MX", search=True)
|
||||
ret = []
|
||||
for a in answers:
|
||||
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
|
||||
parts = record.split(" ")
|
||||
ret.append((int(parts[0]), parts[1]))
|
||||
return sorted(ret, key=lambda x: x[0])
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
def get_txt_record(self, hostname: str) -> List[str]:
|
||||
try:
|
||||
answers = self._resolver.resolve(hostname, "TXT", search=True)
|
||||
ret = []
|
||||
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
||||
for record in a.strings:
|
||||
ret.append(record.decode())
|
||||
return ret
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
||||
class InMemoryDNSClient(DNSClient):
|
||||
def __init__(self):
|
||||
self.cname_records: dict[str, Optional[str]] = {}
|
||||
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
|
||||
self.spf_records: dict[str, List[str]] = {}
|
||||
self.txt_records: dict[str, List[str]] = {}
|
||||
|
||||
def set_cname_record(self, hostname: str, cname: str):
|
||||
self.cname_records[hostname] = cname
|
||||
|
||||
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
|
||||
self.mx_records[hostname] = mx_list
|
||||
|
||||
def set_txt_record(self, hostname: str, txt_list: List[str]):
|
||||
self.txt_records[hostname] = txt_list
|
||||
|
||||
def get_cname_record(self, hostname: str) -> Optional[str]:
|
||||
return self.cname_records.get(hostname)
|
||||
|
||||
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
|
||||
mx_list = self.mx_records.get(hostname, [])
|
||||
return sorted(mx_list, key=lambda x: x[0])
|
||||
|
||||
def get_txt_record(self, hostname: str) -> List[str]:
|
||||
return self.txt_records.get(hostname, [])
|
||||
|
||||
|
||||
def get_network_dns_client() -> NetworkDNSClient:
|
||||
return NetworkDNSClient(NAMESERVERS)
|
||||
|
||||
|
||||
def get_mx_domains(hostname: str) -> [(int, str)]:
|
||||
return get_network_dns_client().get_mx_domains(hostname)
|
||||
|
@ -237,7 +237,7 @@
|
||||
folder.
|
||||
</div>
|
||||
<div class="mb-2">Add the following CNAME DNS records to your domain.</div>
|
||||
{% for dkim_prefix, dkim_cname_value in dkim_records %}
|
||||
{% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
|
||||
|
||||
<div class="mb-2 p-3 dns-record">
|
||||
Record: CNAME
|
||||
|
325
tests/test_custom_domain_validation.py
Normal file
325
tests/test_custom_domain_validation.py
Normal file
@ -0,0 +1,325 @@
|
||||
from typing import Optional
|
||||
|
||||
from app import config
|
||||
from app.constants import DMARC_RECORD
|
||||
from app.custom_domain_validation import CustomDomainValidation
|
||||
from app.db import Session
|
||||
from app.models import CustomDomain, User
|
||||
from app.dns_utils import InMemoryDNSClient
|
||||
from app.utils import random_string
|
||||
from tests.utils import create_new_user, random_domain
|
||||
|
||||
user: Optional[User] = None
|
||||
|
||||
|
||||
def setup_module():
|
||||
global user
|
||||
config.SKIP_MX_LOOKUP_ON_CHECK = True
|
||||
user = create_new_user()
|
||||
user.trial_end = None
|
||||
user.lifetime = True
|
||||
Session.commit()
|
||||
|
||||
|
||||
def create_custom_domain(domain: str) -> CustomDomain:
|
||||
return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
|
||||
|
||||
|
||||
def test_custom_domain_validation_get_dkim_records():
|
||||
domain = random_domain()
|
||||
validator = CustomDomainValidation(domain)
|
||||
records = validator.get_dkim_records()
|
||||
|
||||
assert len(records) == 3
|
||||
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
|
||||
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
|
||||
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
|
||||
|
||||
|
||||
# validate_dkim_records
|
||||
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
res = validator.validate_dkim_records(domain)
|
||||
|
||||
assert len(res) == 3
|
||||
for record_value in res.values():
|
||||
assert record_value == "empty"
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.dkim_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
|
||||
dkim_domain = random_domain()
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(dkim_domain, dns_client)
|
||||
|
||||
user_domain = random_domain()
|
||||
|
||||
# One domain right, two domains wrong
|
||||
dns_client.set_cname_record(
|
||||
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
|
||||
)
|
||||
dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
|
||||
dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
|
||||
|
||||
domain = create_custom_domain(user_domain)
|
||||
res = validator.validate_dkim_records(domain)
|
||||
|
||||
assert len(res) == 2
|
||||
for record_value in res.values():
|
||||
assert record_value == "wrong"
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.dkim_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
|
||||
dkim_domain = random_domain()
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(dkim_domain, dns_client)
|
||||
|
||||
user_domain = random_domain()
|
||||
|
||||
# One domain right, other domains missing
|
||||
dns_client.set_cname_record(
|
||||
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
|
||||
)
|
||||
|
||||
domain = create_custom_domain(user_domain)
|
||||
|
||||
# DKIM is verified
|
||||
domain.dkim_verified = True
|
||||
Session.commit()
|
||||
|
||||
res = validator.validate_dkim_records(domain)
|
||||
assert len(res) == 2
|
||||
assert f"dkim02._domainkey.{user_domain}" in res
|
||||
assert f"dkim03._domainkey.{user_domain}" in res
|
||||
|
||||
# Flag is not cleared
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.dkim_verified is True
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_dkim_records_success():
|
||||
dkim_domain = random_domain()
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(dkim_domain, dns_client)
|
||||
|
||||
user_domain = random_domain()
|
||||
|
||||
# One domain right, two domains wrong
|
||||
dns_client.set_cname_record(
|
||||
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
|
||||
)
|
||||
dns_client.set_cname_record(
|
||||
f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
|
||||
)
|
||||
dns_client.set_cname_record(
|
||||
f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
|
||||
)
|
||||
|
||||
domain = create_custom_domain(user_domain)
|
||||
res = validator.validate_dkim_records(domain)
|
||||
assert len(res) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.dkim_verified is True
|
||||
|
||||
|
||||
# validate_ownership
|
||||
def test_custom_domain_validation_validate_ownership_empty_records_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
res = validator.validate_domain_ownership(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.ownership_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
wrong_records = [random_string()]
|
||||
dns_client.set_txt_record(domain.domain, wrong_records)
|
||||
res = validator.validate_domain_ownership(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert res.errors == wrong_records
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.ownership_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_ownership_success():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
|
||||
res = validator.validate_domain_ownership(domain)
|
||||
|
||||
assert res.success is True
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.ownership_verified is True
|
||||
|
||||
|
||||
# validate_mx_records
|
||||
def test_custom_domain_validation_validate_mx_records_empty_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
res = validator.validate_mx_records(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
wrong_record_1 = random_string()
|
||||
wrong_record_2 = random_string()
|
||||
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
|
||||
dns_client.set_mx_records(domain.domain, wrong_records)
|
||||
res = validator.validate_mx_records(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_mx_records_success():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
|
||||
res = validator.validate_mx_records(domain)
|
||||
|
||||
assert res.success is True
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.verified is True
|
||||
|
||||
|
||||
# validate_spf_records
|
||||
def test_custom_domain_validation_validate_spf_records_empty_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
res = validator.validate_spf_records(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.spf_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
wrong_records = [random_string()]
|
||||
dns_client.set_txt_record(domain.domain, wrong_records)
|
||||
res = validator.validate_spf_records(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert res.errors == wrong_records
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.spf_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_spf_records_success():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
|
||||
res = validator.validate_spf_records(domain)
|
||||
|
||||
assert res.success is True
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.spf_verified is True
|
||||
|
||||
|
||||
# validate_dmarc_records
|
||||
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
res = validator.validate_dmarc_records(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.dmarc_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
wrong_records = [random_string()]
|
||||
dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
|
||||
res = validator.validate_dmarc_records(domain)
|
||||
|
||||
assert res.success is False
|
||||
assert res.errors == wrong_records
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.dmarc_verified is False
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_dmarc_records_success():
|
||||
dns_client = InMemoryDNSClient()
|
||||
validator = CustomDomainValidation(random_domain(), dns_client)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
|
||||
res = validator.validate_dmarc_records(domain)
|
||||
|
||||
assert res.success is True
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.dmarc_verified is True
|
@ -1,10 +1,12 @@
|
||||
from app.dns_utils import (
|
||||
get_mx_domains,
|
||||
get_spf_domain,
|
||||
get_txt_record,
|
||||
get_network_dns_client,
|
||||
is_mx_equivalent,
|
||||
InMemoryDNSClient,
|
||||
)
|
||||
|
||||
from tests.utils import random_domain
|
||||
|
||||
# use our own domain for test
|
||||
_DOMAIN = "simplelogin.io"
|
||||
|
||||
@ -20,12 +22,12 @@ def test_get_mx_domains():
|
||||
|
||||
|
||||
def test_get_spf_domain():
|
||||
r = get_spf_domain(_DOMAIN)
|
||||
r = get_network_dns_client().get_spf_domain(_DOMAIN)
|
||||
assert r == ["simplelogin.co"]
|
||||
|
||||
|
||||
def test_get_txt_record():
|
||||
r = get_txt_record(_DOMAIN)
|
||||
r = get_network_dns_client().get_txt_record(_DOMAIN)
|
||||
assert len(r) > 0
|
||||
|
||||
|
||||
@ -46,3 +48,15 @@ def test_is_mx_equivalent():
|
||||
[(5, "domain1"), (10, "domain2")],
|
||||
[(10, "domain1"), (20, "domain2"), (20, "domain3")],
|
||||
)
|
||||
|
||||
|
||||
def test_get_spf_record():
|
||||
client = InMemoryDNSClient()
|
||||
|
||||
sl_domain = random_domain()
|
||||
domain = random_domain()
|
||||
|
||||
spf_record = f"v=spf1 include:{sl_domain}"
|
||||
client.set_txt_record(domain, [spf_record, "another record"])
|
||||
res = client.get_spf_domain(domain)
|
||||
assert res == [sl_domain]
|
||||
|
Loading…
Reference in New Issue
Block a user