refactor dns_utils and add test_dns_utils

This commit is contained in:
Son NK 2020-01-05 19:01:38 +01:00
parent f62a7ffe47
commit 41329782a2
2 changed files with 44 additions and 28 deletions

View File

@ -1,17 +1,21 @@
import dns.resolver
def _get_dns_resolver():
my_resolver = dns.resolver.Resolver()
# 8.8.8.8 is Google's public DNS server
my_resolver.nameservers = ["8.8.8.8"]
return my_resolver
def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name).
domain name ends with a "." at the end.
"""
try:
my_resolver = dns.resolver.Resolver()
# 8.8.8.8 is Google's public DNS server
my_resolver.nameservers = ["8.8.8.8"]
answers = my_resolver.query(hostname, "MX")
answers = _get_dns_resolver().query(hostname, "MX")
except Exception:
return []
@ -32,12 +36,7 @@ _include_spf = "include:"
def get_spf_domain(hostname) -> [str]:
"""return all domains listed in *include:*"""
try:
my_resolver = dns.resolver.Resolver()
# 8.8.8.8 is Google's public DNS server
my_resolver.nameservers = ["8.8.8.8"]
answers = my_resolver.query(hostname, "TXT")
answers = _get_dns_resolver().query(hostname, "TXT")
except Exception:
return []
@ -58,22 +57,14 @@ def get_spf_domain(hostname) -> [str]:
def get_txt_record(hostname) -> [str]:
try:
my_resolver = dns.resolver.Resolver()
# 8.8.8.8 is Google's public DNS server
my_resolver.nameservers = ["8.8.8.8"]
answers = my_resolver.query(hostname, "TXT")
answers = _get_dns_resolver().query(hostname, "TXT")
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
ret.append(a)
ret.append(a)
return ret
@ -81,12 +72,7 @@ def get_txt_record(hostname) -> [str]:
def get_dkim_record(hostname) -> str:
"""query the dkim._domainkey.{hostname} record and returns its value"""
try:
my_resolver = dns.resolver.Resolver()
# 8.8.8.8 is Google's public DNS server
my_resolver.nameservers = ["8.8.8.8"]
answers = my_resolver.query(f"dkim._domainkey.{hostname}", "TXT")
answers = _get_dns_resolver().query(f"dkim._domainkey.{hostname}", "TXT")
except Exception:
return ""

30
tests/test_dns_utils.py Normal file
View File

@ -0,0 +1,30 @@
from app.dns_utils import *
# use our own domain for test
_DOMAIN = "simplelogin.io"
def test_get_mx_domains():
r = get_mx_domains(_DOMAIN)
assert len(r) > 0
for x in r:
assert x[0] > 0
assert x[1]
def test_get_spf_domain():
r = get_spf_domain(_DOMAIN)
assert r == ["simplelogin.co"]
def test_get_txt_record():
r = get_txt_record(_DOMAIN)
assert len(r) > 0
def test_get_dkim_record():
r = get_dkim_record(_DOMAIN)
assert r.startswith("v=DKIM1; k=rsa;")