diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py
index 05b6af37..6887e81a 100644
--- a/app/custom_domain_validation.py
+++ b/app/custom_domain_validation.py
@@ -2,11 +2,9 @@ from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
- get_cname_record,
- get_mx_domains,
- get_txt_record,
+ DNSClient,
is_mx_equivalent,
- get_spf_domain,
+ get_network_dns_client,
)
from app.models import CustomDomain
from dataclasses import dataclass
@@ -19,10 +17,13 @@ class DomainValidationResult:
class CustomDomainValidation:
- def __init__(self, dkim_domain: str):
+ def __init__(
+ self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
+ ):
self.dkim_domain = dkim_domain
+ self._dns_client = dns_client
self._dkim_records = {
- (f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}")
+ f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}
@@ -38,15 +39,31 @@ class CustomDomainValidation:
Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
+ correct_records = {}
invalid_records = {}
- for prefix, expected_record in self.get_dkim_records():
+ expected_records = self.get_dkim_records()
+ for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}"
- dkim_record = get_cname_record(custom_record)
- if dkim_record != expected_record:
+ dkim_record = self._dns_client.get_cname_record(custom_record)
+ if dkim_record == expected_record:
+ correct_records[prefix] = custom_record
+ else:
invalid_records[custom_record] = dkim_record or "empty"
- # HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
+
+ # HACK
+ # As initially we only had one dkim record, we want to allow users that had only the original dkim record and
+ # the domain validated to continue seeing it as validated (although showing them the missing records).
+ # However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
+ # we will remove the dkim_verified flag.
+ # This is done in order to give users with the old dkim config (only one) to update their CNAMEs
if custom_domain.dkim_verified:
- return invalid_records
+ # Check if at least the original dkim is there
+ if correct_records.get("dkim._domainkey") is not None:
+ # Original dkim record is there. Return the missing records (if any) and don't clear the flag
+ return invalid_records
+
+ # Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
+ # rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit()
return invalid_records
@@ -57,7 +74,7 @@ class CustomDomainValidation:
"""
Check if the custom_domain has added the ownership verification records
"""
- txt_records = get_txt_record(custom_domain.domain)
+ txt_records = self._dns_client.get_txt_record(custom_domain.domain)
if custom_domain.get_ownership_dns_txt_value() in txt_records:
custom_domain.ownership_verified = True
@@ -69,7 +86,7 @@ class CustomDomainValidation:
def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
- mx_domains = get_mx_domains(custom_domain.domain)
+ mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
@@ -84,7 +101,7 @@ class CustomDomainValidation:
def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
- spf_domains = get_spf_domain(custom_domain.domain)
+ spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
@@ -93,13 +110,14 @@ class CustomDomainValidation:
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
- success=False, errors=get_txt_record(custom_domain.domain)
+ success=False,
+ errors=self._dns_client.get_txt_record(custom_domain.domain),
)
def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
- txt_records = get_txt_record("_dmarc." + custom_domain.domain)
+ txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
diff --git a/app/dns_utils.py b/app/dns_utils.py
index 429d0aa2..2ce69934 100644
--- a/app/dns_utils.py
+++ b/app/dns_utils.py
@@ -1,100 +1,13 @@
-from app import config
-from typing import Optional, List, Tuple
+from abc import ABC, abstractmethod
+from typing import List, Tuple, Optional
import dns.resolver
-
-def _get_dns_resolver():
- my_resolver = dns.resolver.Resolver()
- my_resolver.nameservers = config.NAMESERVERS
-
- return my_resolver
-
-
-def get_ns(hostname) -> [str]:
- try:
- answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
- except Exception:
- return []
- return [a.to_text() for a in answers]
-
-
-def get_cname_record(hostname) -> Optional[str]:
- """Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
- try:
- answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
- except Exception:
- return None
-
- for a in answers:
- ret = a.to_text()
- return ret[:-1]
-
- return None
-
-
-def get_mx_domains(hostname) -> [(int, str)]:
- """return list of (priority, domain name) sorted by priority (lowest priority first)
- domain name ends with a "." at the end.
- """
- try:
- answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
- except Exception:
- return []
-
- ret = []
-
- for a in answers:
- record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
- parts = record.split(" ")
-
- ret.append((int(parts[0]), parts[1]))
-
- return sorted(ret, key=lambda prio_domain: prio_domain[0])
-
+from app.config import NAMESERVERS
_include_spf = "include:"
-def get_spf_domain(hostname) -> [str]:
- """return all domains listed in *include:*"""
- try:
- answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
- except Exception:
- return []
-
- ret = []
-
- for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
- for record in a.strings:
- record = record.decode() # record is bytes
-
- if record.startswith("v=spf1"):
- parts = record.split(" ")
- for part in parts:
- if part.startswith(_include_spf):
- ret.append(part[part.find(_include_spf) + len(_include_spf) :])
-
- return ret
-
-
-def get_txt_record(hostname) -> [str]:
- try:
- answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
- except Exception:
- return []
-
- ret = []
-
- for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
- for record in a.strings:
- record = record.decode() # record is bytes
-
- ret.append(record)
-
- return ret
-
-
def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
) -> bool:
@@ -105,16 +18,127 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
"""
- mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0])
- ref_mx_domains = sorted(
- ref_mx_domains, key=lambda priority_domain: priority_domain[0]
- )
+ mx_domains = sorted(mx_domains, key=lambda x: x[0])
+ ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
if len(mx_domains) < len(ref_mx_domains):
return False
- for i in range(0, len(ref_mx_domains)):
+ for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]:
return False
return True
+
+
+class DNSClient(ABC):
+ @abstractmethod
+ def get_cname_record(self, hostname: str) -> Optional[str]:
+ pass
+
+ @abstractmethod
+ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
+ pass
+
+ def get_spf_domain(self, hostname: str) -> List[str]:
+ """
+ return all domains listed in *include:*
+ """
+ try:
+ records = self.get_txt_record(hostname)
+ ret = []
+ for record in records:
+ if record.startswith("v=spf1"):
+ parts = record.split(" ")
+ for part in parts:
+ if part.startswith(_include_spf):
+ ret.append(
+ part[part.find(_include_spf) + len(_include_spf) :]
+ )
+ return ret
+ except Exception:
+ return []
+
+ @abstractmethod
+ def get_txt_record(self, hostname: str) -> List[str]:
+ pass
+
+
+class NetworkDNSClient(DNSClient):
+ def __init__(self, nameservers: List[str]):
+ self._resolver = dns.resolver.Resolver()
+ self._resolver.nameservers = nameservers
+
+ def get_cname_record(self, hostname: str) -> Optional[str]:
+ """
+ Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
+ """
+ try:
+ answers = self._resolver.resolve(hostname, "CNAME", search=True)
+ for a in answers:
+ ret = a.to_text()
+ return ret[:-1]
+ except Exception:
+ return None
+
+ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
+ """
+ return list of (priority, domain name) sorted by priority (lowest priority first)
+ domain name ends with a "." at the end.
+ """
+ try:
+ answers = self._resolver.resolve(hostname, "MX", search=True)
+ ret = []
+ for a in answers:
+ record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
+ parts = record.split(" ")
+ ret.append((int(parts[0]), parts[1]))
+ return sorted(ret, key=lambda x: x[0])
+ except Exception:
+ return []
+
+ def get_txt_record(self, hostname: str) -> List[str]:
+ try:
+ answers = self._resolver.resolve(hostname, "TXT", search=True)
+ ret = []
+ for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
+ for record in a.strings:
+ ret.append(record.decode())
+ return ret
+ except Exception:
+ return []
+
+
+class InMemoryDNSClient(DNSClient):
+ def __init__(self):
+ self.cname_records: dict[str, Optional[str]] = {}
+ self.mx_records: dict[str, List[Tuple[int, str]]] = {}
+ self.spf_records: dict[str, List[str]] = {}
+ self.txt_records: dict[str, List[str]] = {}
+
+ def set_cname_record(self, hostname: str, cname: str):
+ self.cname_records[hostname] = cname
+
+ def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
+ self.mx_records[hostname] = mx_list
+
+ def set_txt_record(self, hostname: str, txt_list: List[str]):
+ self.txt_records[hostname] = txt_list
+
+ def get_cname_record(self, hostname: str) -> Optional[str]:
+ return self.cname_records.get(hostname)
+
+ def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
+ mx_list = self.mx_records.get(hostname, [])
+ return sorted(mx_list, key=lambda x: x[0])
+
+ def get_txt_record(self, hostname: str) -> List[str]:
+ return self.txt_records.get(hostname, [])
+
+
+def get_network_dns_client() -> NetworkDNSClient:
+ return NetworkDNSClient(NAMESERVERS)
+
+
+def get_mx_domains(hostname: str) -> [(int, str)]:
+ return get_network_dns_client().get_mx_domains(hostname)
diff --git a/templates/dashboard/domain_detail/dns.html b/templates/dashboard/domain_detail/dns.html
index 15ef346f..810aa302 100644
--- a/templates/dashboard/domain_detail/dns.html
+++ b/templates/dashboard/domain_detail/dns.html
@@ -237,7 +237,7 @@
folder.
Add the following CNAME DNS records to your domain.
- {% for dkim_prefix, dkim_cname_value in dkim_records %}
+ {% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
Record: CNAME
diff --git a/tests/test_custom_domain_validation.py b/tests/test_custom_domain_validation.py
new file mode 100644
index 00000000..7a3882bd
--- /dev/null
+++ b/tests/test_custom_domain_validation.py
@@ -0,0 +1,325 @@
+from typing import Optional
+
+from app import config
+from app.constants import DMARC_RECORD
+from app.custom_domain_validation import CustomDomainValidation
+from app.db import Session
+from app.models import CustomDomain, User
+from app.dns_utils import InMemoryDNSClient
+from app.utils import random_string
+from tests.utils import create_new_user, random_domain
+
+user: Optional[User] = None
+
+
+def setup_module():
+ global user
+ config.SKIP_MX_LOOKUP_ON_CHECK = True
+ user = create_new_user()
+ user.trial_end = None
+ user.lifetime = True
+ Session.commit()
+
+
+def create_custom_domain(domain: str) -> CustomDomain:
+ return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
+
+
+def test_custom_domain_validation_get_dkim_records():
+ domain = random_domain()
+ validator = CustomDomainValidation(domain)
+ records = validator.get_dkim_records()
+
+ assert len(records) == 3
+ assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
+ assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
+ assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
+
+
+# validate_dkim_records
+def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_dkim_records(domain)
+
+ assert len(res) == 3
+ for record_value in res.values():
+ assert record_value == "empty"
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dkim_verified is False
+
+
+def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
+ dkim_domain = random_domain()
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(dkim_domain, dns_client)
+
+ user_domain = random_domain()
+
+ # One domain right, two domains wrong
+ dns_client.set_cname_record(
+ f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
+ )
+ dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
+ dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
+
+ domain = create_custom_domain(user_domain)
+ res = validator.validate_dkim_records(domain)
+
+ assert len(res) == 2
+ for record_value in res.values():
+ assert record_value == "wrong"
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dkim_verified is False
+
+
+def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
+ dkim_domain = random_domain()
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(dkim_domain, dns_client)
+
+ user_domain = random_domain()
+
+ # One domain right, other domains missing
+ dns_client.set_cname_record(
+ f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
+ )
+
+ domain = create_custom_domain(user_domain)
+
+ # DKIM is verified
+ domain.dkim_verified = True
+ Session.commit()
+
+ res = validator.validate_dkim_records(domain)
+ assert len(res) == 2
+ assert f"dkim02._domainkey.{user_domain}" in res
+ assert f"dkim03._domainkey.{user_domain}" in res
+
+ # Flag is not cleared
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dkim_verified is True
+
+
+def test_custom_domain_validation_validate_dkim_records_success():
+ dkim_domain = random_domain()
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(dkim_domain, dns_client)
+
+ user_domain = random_domain()
+
+ # One domain right, two domains wrong
+ dns_client.set_cname_record(
+ f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
+ )
+ dns_client.set_cname_record(
+ f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
+ )
+ dns_client.set_cname_record(
+ f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
+ )
+
+ domain = create_custom_domain(user_domain)
+ res = validator.validate_dkim_records(domain)
+ assert len(res) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dkim_verified is True
+
+
+# validate_ownership
+def test_custom_domain_validation_validate_ownership_empty_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_domain_ownership(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.ownership_verified is False
+
+
+def test_custom_domain_validation_validate_ownership_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_records = [random_string()]
+ dns_client.set_txt_record(domain.domain, wrong_records)
+ res = validator.validate_domain_ownership(domain)
+
+ assert res.success is False
+ assert res.errors == wrong_records
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.ownership_verified is False
+
+
+def test_custom_domain_validation_validate_ownership_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
+ res = validator.validate_domain_ownership(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.ownership_verified is True
+
+
+# validate_mx_records
+def test_custom_domain_validation_validate_mx_records_empty_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_mx_records(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.verified is False
+
+
+def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_record_1 = random_string()
+ wrong_record_2 = random_string()
+ wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
+ dns_client.set_mx_records(domain.domain, wrong_records)
+ res = validator.validate_mx_records(domain)
+
+ assert res.success is False
+ assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.verified is False
+
+
+def test_custom_domain_validation_validate_mx_records_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
+ res = validator.validate_mx_records(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.verified is True
+
+
+# validate_spf_records
+def test_custom_domain_validation_validate_spf_records_empty_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_spf_records(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.spf_verified is False
+
+
+def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_records = [random_string()]
+ dns_client.set_txt_record(domain.domain, wrong_records)
+ res = validator.validate_spf_records(domain)
+
+ assert res.success is False
+ assert res.errors == wrong_records
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.spf_verified is False
+
+
+def test_custom_domain_validation_validate_spf_records_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
+ res = validator.validate_spf_records(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.spf_verified is True
+
+
+# validate_dmarc_records
+def test_custom_domain_validation_validate_dmarc_records_empty_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+ res = validator.validate_dmarc_records(domain)
+
+ assert res.success is False
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dmarc_verified is False
+
+
+def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ wrong_records = [random_string()]
+ dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
+ res = validator.validate_dmarc_records(domain)
+
+ assert res.success is False
+ assert res.errors == wrong_records
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dmarc_verified is False
+
+
+def test_custom_domain_validation_validate_dmarc_records_success():
+ dns_client = InMemoryDNSClient()
+ validator = CustomDomainValidation(random_domain(), dns_client)
+
+ domain = create_custom_domain(random_domain())
+
+ dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
+ res = validator.validate_dmarc_records(domain)
+
+ assert res.success is True
+ assert len(res.errors) == 0
+
+ db_domain = CustomDomain.get_by(id=domain.id)
+ assert db_domain.dmarc_verified is True
diff --git a/tests/test_dns_utils.py b/tests/test_dns_utils.py
index 46a6ccdd..374983c8 100644
--- a/tests/test_dns_utils.py
+++ b/tests/test_dns_utils.py
@@ -1,10 +1,12 @@
from app.dns_utils import (
get_mx_domains,
- get_spf_domain,
- get_txt_record,
+ get_network_dns_client,
is_mx_equivalent,
+ InMemoryDNSClient,
)
+from tests.utils import random_domain
+
# use our own domain for test
_DOMAIN = "simplelogin.io"
@@ -20,12 +22,12 @@ def test_get_mx_domains():
def test_get_spf_domain():
- r = get_spf_domain(_DOMAIN)
+ r = get_network_dns_client().get_spf_domain(_DOMAIN)
assert r == ["simplelogin.co"]
def test_get_txt_record():
- r = get_txt_record(_DOMAIN)
+ r = get_network_dns_client().get_txt_record(_DOMAIN)
assert len(r) > 0
@@ -46,3 +48,15 @@ def test_is_mx_equivalent():
[(5, "domain1"), (10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")],
)
+
+
+def test_get_spf_record():
+ client = InMemoryDNSClient()
+
+ sl_domain = random_domain()
+ domain = random_domain()
+
+ spf_record = f"v=spf1 include:{sl_domain}"
+ client.set_txt_record(domain, [spf_record, "another record"])
+ res = client.get_spf_domain(domain)
+ assert res == [sl_domain]