refactor dns_utils and add test_dns_utils
This commit is contained in:
parent
f62a7ffe47
commit
41329782a2
|
@ -1,17 +1,21 @@
|
|||
import dns.resolver
|
||||
|
||||
|
||||
def _get_dns_resolver():
|
||||
my_resolver = dns.resolver.Resolver()
|
||||
|
||||
# 8.8.8.8 is Google's public DNS server
|
||||
my_resolver.nameservers = ["8.8.8.8"]
|
||||
|
||||
return my_resolver
|
||||
|
||||
|
||||
def get_mx_domains(hostname) -> [(int, str)]:
|
||||
"""return list of (priority, domain name).
|
||||
domain name ends with a "." at the end.
|
||||
"""
|
||||
try:
|
||||
my_resolver = dns.resolver.Resolver()
|
||||
|
||||
# 8.8.8.8 is Google's public DNS server
|
||||
my_resolver.nameservers = ["8.8.8.8"]
|
||||
|
||||
answers = my_resolver.query(hostname, "MX")
|
||||
answers = _get_dns_resolver().query(hostname, "MX")
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
@ -32,12 +36,7 @@ _include_spf = "include:"
|
|||
def get_spf_domain(hostname) -> [str]:
|
||||
"""return all domains listed in *include:*"""
|
||||
try:
|
||||
my_resolver = dns.resolver.Resolver()
|
||||
|
||||
# 8.8.8.8 is Google's public DNS server
|
||||
my_resolver.nameservers = ["8.8.8.8"]
|
||||
|
||||
answers = my_resolver.query(hostname, "TXT")
|
||||
answers = _get_dns_resolver().query(hostname, "TXT")
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
|
@ -58,22 +57,14 @@ def get_spf_domain(hostname) -> [str]:
|
|||
|
||||
def get_txt_record(hostname) -> [str]:
|
||||
try:
|
||||
my_resolver = dns.resolver.Resolver()
|
||||
|
||||
# 8.8.8.8 is Google's public DNS server
|
||||
my_resolver.nameservers = ["8.8.8.8"]
|
||||
|
||||
answers = my_resolver.query(hostname, "TXT")
|
||||
answers = _get_dns_resolver().query(hostname, "TXT")
|
||||
except Exception:
|
||||
return []
|
||||
|
||||
ret = []
|
||||
|
||||
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
|
||||
for record in a.strings:
|
||||
record = record.decode() # record is bytes
|
||||
|
||||
ret.append(a)
|
||||
ret.append(a)
|
||||
|
||||
return ret
|
||||
|
||||
|
@ -81,12 +72,7 @@ def get_txt_record(hostname) -> [str]:
|
|||
def get_dkim_record(hostname) -> str:
|
||||
"""query the dkim._domainkey.{hostname} record and returns its value"""
|
||||
try:
|
||||
my_resolver = dns.resolver.Resolver()
|
||||
|
||||
# 8.8.8.8 is Google's public DNS server
|
||||
my_resolver.nameservers = ["8.8.8.8"]
|
||||
|
||||
answers = my_resolver.query(f"dkim._domainkey.{hostname}", "TXT")
|
||||
answers = _get_dns_resolver().query(f"dkim._domainkey.{hostname}", "TXT")
|
||||
except Exception:
|
||||
return ""
|
||||
|
||||
|
|
|
@ -0,0 +1,30 @@
|
|||
from app.dns_utils import *
|
||||
|
||||
# use our own domain for test
|
||||
_DOMAIN = "simplelogin.io"
|
||||
|
||||
|
||||
def test_get_mx_domains():
|
||||
r = get_mx_domains(_DOMAIN)
|
||||
|
||||
assert len(r) > 0
|
||||
|
||||
for x in r:
|
||||
assert x[0] > 0
|
||||
assert x[1]
|
||||
|
||||
|
||||
def test_get_spf_domain():
|
||||
r = get_spf_domain(_DOMAIN)
|
||||
assert r == ["simplelogin.co"]
|
||||
|
||||
|
||||
def test_get_txt_record():
|
||||
|
||||
r = get_txt_record(_DOMAIN)
|
||||
assert len(r) > 0
|
||||
|
||||
|
||||
def test_get_dkim_record():
|
||||
r = get_dkim_record(_DOMAIN)
|
||||
assert r.startswith("v=DKIM1; k=rsa;")
|
Loading…
Reference in New Issue