chore: refactor dns to improve testability (#2224)

* chore: refactor DNS client to its own class

* chore: adapt code calling DNS and add tests to it

* chore: refactor old dkim check not to clear flag
This commit is contained in:
Carlos Quintana 2024-09-17 16:15:10 +02:00 committed by GitHub
parent 065cc3db92
commit f6708dd0b6
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 497 additions and 116 deletions

View File

@ -2,11 +2,9 @@ from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.constants import DMARC_RECORD from app.constants import DMARC_RECORD
from app.db import Session from app.db import Session
from app.dns_utils import ( from app.dns_utils import (
get_cname_record, DNSClient,
get_mx_domains,
get_txt_record,
is_mx_equivalent, is_mx_equivalent,
get_spf_domain, get_network_dns_client,
) )
from app.models import CustomDomain from app.models import CustomDomain
from dataclasses import dataclass from dataclasses import dataclass
@ -19,10 +17,13 @@ class DomainValidationResult:
class CustomDomainValidation: class CustomDomainValidation:
def __init__(self, dkim_domain: str): def __init__(
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
):
self.dkim_domain = dkim_domain self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._dkim_records = { self._dkim_records = {
(f"{key}._domainkey", f"{key}._domainkey.{self.dkim_domain}") f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
for key in ("dkim", "dkim02", "dkim03") for key in ("dkim", "dkim02", "dkim03")
} }
@ -38,15 +39,31 @@ class CustomDomainValidation:
Check if dkim records are properly set for this custom domain. Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
""" """
correct_records = {}
invalid_records = {} invalid_records = {}
for prefix, expected_record in self.get_dkim_records(): expected_records = self.get_dkim_records()
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}" custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = get_cname_record(custom_record) dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record != expected_record: if dkim_record == expected_record:
correct_records[prefix] = custom_record
else:
invalid_records[custom_record] = dkim_record or "empty" invalid_records[custom_record] = dkim_record or "empty"
# HACK: If dkim is enabled, don't disable it to give users time to update their CNAMES
# HACK
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
# the domain validated to continue seeing it as validated (although showing them the missing records).
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
# we will remove the dkim_verified flag.
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
if custom_domain.dkim_verified: if custom_domain.dkim_verified:
return invalid_records # Check if at least the original dkim is there
if correct_records.get("dkim._domainkey") is not None:
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
return invalid_records
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0 custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit() Session.commit()
return invalid_records return invalid_records
@ -57,7 +74,7 @@ class CustomDomainValidation:
""" """
Check if the custom_domain has added the ownership verification records Check if the custom_domain has added the ownership verification records
""" """
txt_records = get_txt_record(custom_domain.domain) txt_records = self._dns_client.get_txt_record(custom_domain.domain)
if custom_domain.get_ownership_dns_txt_value() in txt_records: if custom_domain.get_ownership_dns_txt_value() in txt_records:
custom_domain.ownership_verified = True custom_domain.ownership_verified = True
@ -69,7 +86,7 @@ class CustomDomainValidation:
def validate_mx_records( def validate_mx_records(
self, custom_domain: CustomDomain self, custom_domain: CustomDomain
) -> DomainValidationResult: ) -> DomainValidationResult:
mx_domains = get_mx_domains(custom_domain.domain) mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY): if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult( return DomainValidationResult(
@ -84,7 +101,7 @@ class CustomDomainValidation:
def validate_spf_records( def validate_spf_records(
self, custom_domain: CustomDomain self, custom_domain: CustomDomain
) -> DomainValidationResult: ) -> DomainValidationResult:
spf_domains = get_spf_domain(custom_domain.domain) spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains: if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True custom_domain.spf_verified = True
Session.commit() Session.commit()
@ -93,13 +110,14 @@ class CustomDomainValidation:
custom_domain.spf_verified = False custom_domain.spf_verified = False
Session.commit() Session.commit()
return DomainValidationResult( return DomainValidationResult(
success=False, errors=get_txt_record(custom_domain.domain) success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain),
) )
def validate_dmarc_records( def validate_dmarc_records(
self, custom_domain: CustomDomain self, custom_domain: CustomDomain
) -> DomainValidationResult: ) -> DomainValidationResult:
txt_records = get_txt_record("_dmarc." + custom_domain.domain) txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records: if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True custom_domain.dmarc_verified = True
Session.commit() Session.commit()

View File

@ -1,100 +1,13 @@
from app import config from abc import ABC, abstractmethod
from typing import Optional, List, Tuple from typing import List, Tuple, Optional
import dns.resolver import dns.resolver
from app.config import NAMESERVERS
def _get_dns_resolver():
my_resolver = dns.resolver.Resolver()
my_resolver.nameservers = config.NAMESERVERS
return my_resolver
def get_ns(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
except Exception:
return []
return [a.to_text() for a in answers]
def get_cname_record(hostname) -> Optional[str]:
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
try:
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
except Exception:
return None
for a in answers:
ret = a.to_text()
return ret[:-1]
return None
def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
except Exception:
return []
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda prio_domain: prio_domain[0])
_include_spf = "include:" _include_spf = "include:"
def get_spf_domain(hostname) -> [str]:
"""return all domains listed in *include:*"""
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
return ret
def get_txt_record(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
ret.append(record)
return ret
def is_mx_equivalent( def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
) -> bool: ) -> bool:
@ -105,16 +18,127 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number. The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
""" """
mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0]) mx_domains = sorted(mx_domains, key=lambda x: x[0])
ref_mx_domains = sorted( ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0])
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
)
if len(mx_domains) < len(ref_mx_domains): if len(mx_domains) < len(ref_mx_domains):
return False return False
for i in range(0, len(ref_mx_domains)): for i in range(len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]: if mx_domains[i][1] != ref_mx_domains[i][1]:
return False return False
return True return True
class DNSClient(ABC):
@abstractmethod
def get_cname_record(self, hostname: str) -> Optional[str]:
pass
@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
pass
def get_spf_domain(self, hostname: str) -> List[str]:
"""
return all domains listed in *include:*
"""
try:
records = self.get_txt_record(hostname)
ret = []
for record in records:
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []
@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
pass
class NetworkDNSClient(DNSClient):
def __init__(self, nameservers: List[str]):
self._resolver = dns.resolver.Resolver()
self._resolver.nameservers = nameservers
def get_cname_record(self, hostname: str) -> Optional[str]:
"""
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
"""
try:
answers = self._resolver.resolve(hostname, "CNAME", search=True)
for a in answers:
ret = a.to_text()
return ret[:-1]
except Exception:
return None
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = self._resolver.resolve(hostname, "MX", search=True)
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
except Exception:
return []
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
ret.append(record.decode())
return ret
except Exception:
return []
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> [(int, str)]:
return get_network_dns_client().get_mx_domains(hostname)

View File

@ -237,7 +237,7 @@
folder. folder.
</div> </div>
<div class="mb-2">Add the following CNAME DNS records to your domain.</div> <div class="mb-2">Add the following CNAME DNS records to your domain.</div>
{% for dkim_prefix, dkim_cname_value in dkim_records %} {% for dkim_prefix, dkim_cname_value in dkim_records.items() %}
<div class="mb-2 p-3 dns-record"> <div class="mb-2 p-3 dns-record">
Record: CNAME Record: CNAME

View File

@ -0,0 +1,325 @@
from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient
from app.utils import random_string
from tests.utils import create_new_user, random_domain
user: Optional[User] = None
def setup_module():
global user
config.SKIP_MX_LOOKUP_ON_CHECK = True
user = create_new_user()
user.trial_end = None
user.lifetime = True
Session.commit()
def create_custom_domain(domain: str) -> CustomDomain:
return CustomDomain.create(user_id=user.id, domain=domain, commit=True)
def test_custom_domain_validation_get_dkim_records():
domain = random_domain()
validator = CustomDomainValidation(domain)
records = validator.get_dkim_records()
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dkim_records(domain)
assert len(res) == 3
for record_value in res.values():
assert record_value == "empty"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_wrong_records_failure():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(f"dkim02._domainkey.{user_domain}", "wrong")
dns_client.set_cname_record(f"dkim03._domainkey.{user_domain}", "wrong")
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 2
for record_value in res.values():
assert record_value == "wrong"
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is False
def test_custom_domain_validation_validate_dkim_records_success_with_old_system():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, other domains missing
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
# DKIM is verified
domain.dkim_verified = True
Session.commit()
res = validator.validate_dkim_records(domain)
assert len(res) == 2
assert f"dkim02._domainkey.{user_domain}" in res
assert f"dkim03._domainkey.{user_domain}" in res
# Flag is not cleared
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
def test_custom_domain_validation_validate_dkim_records_success():
dkim_domain = random_domain()
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(dkim_domain, dns_client)
user_domain = random_domain()
# One domain right, two domains wrong
dns_client.set_cname_record(
f"dkim._domainkey.{user_domain}", f"dkim._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim02._domainkey.{user_domain}", f"dkim02._domainkey.{dkim_domain}"
)
dns_client.set_cname_record(
f"dkim03._domainkey.{user_domain}", f"dkim03._domainkey.{dkim_domain}"
)
domain = create_custom_domain(user_domain)
res = validator.validate_dkim_records(domain)
assert len(res) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dkim_verified is True
# validate_ownership
def test_custom_domain_validation_validate_ownership_empty_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_domain_ownership(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is False
def test_custom_domain_validation_validate_ownership_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
# validate_mx_records
def test_custom_domain_validation_validate_mx_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_mx_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_record_1 = random_string()
wrong_record_2 = random_string()
wrong_records = [(10, wrong_record_1), (20, wrong_record_2)]
dns_client.set_mx_records(domain.domain, wrong_records)
res = validator.validate_mx_records(domain)
assert res.success is False
assert res.errors == [f"10 {wrong_record_1}", f"20 {wrong_record_2}"]
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is False
def test_custom_domain_validation_validate_mx_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_mx_records(domain.domain, config.EMAIL_SERVERS_WITH_PRIORITY)
res = validator.validate_mx_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.verified is True
# validate_spf_records
def test_custom_domain_validation_validate_spf_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_spf_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(domain.domain, wrong_records)
res = validator.validate_spf_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is False
def test_custom_domain_validation_validate_spf_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(domain.domain, [f"v=spf1 include:{config.EMAIL_DOMAIN}"])
res = validator.validate_spf_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.spf_verified is True
# validate_dmarc_records
def test_custom_domain_validation_validate_dmarc_records_empty_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_wrong_records_failure():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
wrong_records = [random_string()]
dns_client.set_txt_record(f"_dmarc.{domain.domain}", wrong_records)
res = validator.validate_dmarc_records(domain)
assert res.success is False
assert res.errors == wrong_records
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is False
def test_custom_domain_validation_validate_dmarc_records_success():
dns_client = InMemoryDNSClient()
validator = CustomDomainValidation(random_domain(), dns_client)
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(f"_dmarc.{domain.domain}", [DMARC_RECORD])
res = validator.validate_dmarc_records(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.dmarc_verified is True

View File

@ -1,10 +1,12 @@
from app.dns_utils import ( from app.dns_utils import (
get_mx_domains, get_mx_domains,
get_spf_domain, get_network_dns_client,
get_txt_record,
is_mx_equivalent, is_mx_equivalent,
InMemoryDNSClient,
) )
from tests.utils import random_domain
# use our own domain for test # use our own domain for test
_DOMAIN = "simplelogin.io" _DOMAIN = "simplelogin.io"
@ -20,12 +22,12 @@ def test_get_mx_domains():
def test_get_spf_domain(): def test_get_spf_domain():
r = get_spf_domain(_DOMAIN) r = get_network_dns_client().get_spf_domain(_DOMAIN)
assert r == ["simplelogin.co"] assert r == ["simplelogin.co"]
def test_get_txt_record(): def test_get_txt_record():
r = get_txt_record(_DOMAIN) r = get_network_dns_client().get_txt_record(_DOMAIN)
assert len(r) > 0 assert len(r) > 0
@ -46,3 +48,15 @@ def test_is_mx_equivalent():
[(5, "domain1"), (10, "domain2")], [(5, "domain1"), (10, "domain2")],
[(10, "domain1"), (20, "domain2"), (20, "domain3")], [(10, "domain1"), (20, "domain2"), (20, "domain3")],
) )
def test_get_spf_record():
client = InMemoryDNSClient()
sl_domain = random_domain()
domain = random_domain()
spf_record = f"v=spf1 include:{sl_domain}"
client.set_txt_record(domain, [spf_record, "another record"])
res = client.get_spf_domain(domain)
assert res == [sl_domain]