feat: extract custom domain utils to a service (#2215)

This commit is contained in:
Carlos Quintana 2024-09-13 14:49:48 +02:00 committed by GitHub
parent 5301d2410d
commit 647c569f99
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
5 changed files with 142 additions and 42 deletions

View File

@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None):
return literal_eval(value) return literal_eval(value)
def get_env_dict(env_var: str) -> dict[str, str]:
"""
Get an env variable and convert it into a python dictionary with keys and values as strings.
Args:
env_var (str): env var, example: SL_DB
Syntax is: key1=value1;key2=value2
Components separated by ;
key and value separated by =
"""
value = os.getenv(env_var)
if not value:
return {}
components = value.split(";")
result = {}
for component in components:
if component == "":
continue
parts = component.split("=")
if len(parts) != 2:
raise Exception(f"Invalid config for env var {env_var}")
result[parts[0].strip()] = parts[1].strip()
return result
config_file = os.environ.get("CONFIG") config_file = os.environ.get("CONFIG")
if config_file: if config_file:
config_file = get_abs_path(config_file) config_file = get_abs_path(config_file)
@ -609,3 +636,21 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_
# Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool # Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool
# It defaults to the regular DB_URI in case it's needed # It defaults to the regular DB_URI in case it's needed
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI) EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
def read_partner_domains() -> dict[int, str]:
partner_domains_dict = get_env_dict("PARTNER_DOMAINS")
if len(partner_domains_dict) == 0:
return {}
res: dict[int, str] = {}
for partner_id in partner_domains_dict.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_domains_dict[partner_id]
except ValueError:
pass
return res
PARTNER_DOMAINS: dict[int, str] = read_partner_domains()

View File

@ -1 +1,2 @@
HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies" HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"

View File

View File

@ -1,6 +1,21 @@
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.constants import DMARC_RECORD
from app.db import Session from app.db import Session
from app.dns_utils import get_cname_record from app.dns_utils import (
get_cname_record,
get_mx_domains,
get_txt_record,
is_mx_equivalent,
get_spf_domain,
)
from app.models import CustomDomain from app.models import CustomDomain
from dataclasses import dataclass
@dataclass
class DomainValidationResult:
success: bool
errors: [str]
class CustomDomainValidation: class CustomDomainValidation:
@ -35,3 +50,61 @@ class CustomDomainValidation:
custom_domain.dkim_verified = len(invalid_records) == 0 custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit() Session.commit()
return invalid_records return invalid_records
def validate_domain_ownership(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = get_txt_record(custom_domain.domain)
if custom_domain.get_ownership_dns_txt_value() in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
return DomainValidationResult(success=False, errors=txt_records)
def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
)
else:
custom_domain.verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
success=False, errors=get_txt_record(custom_domain.domain)
)
def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)

View File

@ -6,16 +6,11 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField from wtforms import StringField, validators, IntegerField
from app.constants import DMARC_RECORD
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN
from app.custom_domain_validation import CustomDomainValidation from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
is_mx_equivalent,
)
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
CustomDomain, CustomDomain,
@ -49,8 +44,6 @@ def domain_detail_dns(custom_domain_id):
domain_validator = CustomDomainValidation(EMAIL_DOMAIN) domain_validator = CustomDomainValidation(EMAIL_DOMAIN)
csrf_form = CSRFValidationForm() csrf_form = CSRFValidationForm()
dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = [] mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []
@ -59,15 +52,14 @@ def domain_detail_dns(custom_domain_id):
flash("Invalid request", "warning") flash("Invalid request", "warning")
return redirect(request.url) return redirect(request.url)
if request.form.get("form-name") == "check-ownership": if request.form.get("form-name") == "check-ownership":
txt_records = get_txt_record(custom_domain.domain) ownership_validation_result = domain_validator.validate_domain_ownership(
custom_domain
if custom_domain.get_ownership_dns_txt_value() in txt_records: )
if ownership_validation_result.success:
flash( flash(
"Domain ownership is verified. Please proceed to the other records setup", "Domain ownership is verified. Please proceed to the other records setup",
"success", "success",
) )
custom_domain.ownership_verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", "dashboard.domain_detail_dns",
@ -78,36 +70,28 @@ def domain_detail_dns(custom_domain_id):
else: else:
flash("We can't find the needed TXT record", "error") flash("We can't find the needed TXT record", "error")
ownership_ok = False ownership_ok = False
ownership_errors = txt_records ownership_errors = ownership_validation_result.errors
elif request.form.get("form-name") == "check-mx": elif request.form.get("form-name") == "check-mx":
mx_domains = get_mx_domains(custom_domain.domain) mx_validation_result = domain_validator.validate_mx_records(custom_domain)
if mx_validation_result.success:
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
flash("The MX record is not correctly set", "warning")
mx_ok = False
# build mx_errors to show to user
mx_errors = [
f"{priority} {domain}" for (priority, domain) in mx_domains
]
else:
flash( flash(
"Your domain can start receiving emails. You can now use it to create alias", "Your domain can start receiving emails. You can now use it to create alias",
"success", "success",
) )
custom_domain.verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
) )
) )
else:
flash("The MX record is not correctly set", "warning")
mx_ok = False
mx_errors = mx_validation_result.errors
elif request.form.get("form-name") == "check-spf": elif request.form.get("form-name") == "check-spf":
spf_domains = get_spf_domain(custom_domain.domain) spf_validation_result = domain_validator.validate_spf_records(custom_domain)
if EMAIL_DOMAIN in spf_domains: if spf_validation_result.success:
custom_domain.spf_verified = True
Session.commit()
flash("SPF is setup correctly", "success") flash("SPF is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -115,14 +99,12 @@ def domain_detail_dns(custom_domain_id):
) )
) )
else: else:
custom_domain.spf_verified = False
Session.commit()
flash( flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.", f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning", "warning",
) )
spf_ok = False spf_ok = False
spf_errors = get_txt_record(custom_domain.domain) spf_errors = spf_validation_result.errors
elif request.form.get("form-name") == "check-dkim": elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain) dkim_errors = domain_validator.validate_dkim_records(custom_domain)
@ -138,10 +120,10 @@ def domain_detail_dns(custom_domain_id):
flash("DKIM: the CNAME record is not correctly set", "warning") flash("DKIM: the CNAME record is not correctly set", "warning")
elif request.form.get("form-name") == "check-dmarc": elif request.form.get("form-name") == "check-dmarc":
txt_records = get_txt_record("_dmarc." + custom_domain.domain) dmarc_validation_result = domain_validator.validate_dmarc_records(
if dmarc_record in txt_records: custom_domain
custom_domain.dmarc_verified = True )
Session.commit() if dmarc_validation_result.success:
flash("DMARC is setup correctly", "success") flash("DMARC is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -149,19 +131,18 @@ def domain_detail_dns(custom_domain_id):
) )
) )
else: else:
custom_domain.dmarc_verified = False
Session.commit()
flash( flash(
"DMARC: The TXT record is not correctly set", "DMARC: The TXT record is not correctly set",
"warning", "warning",
) )
dmarc_ok = False dmarc_ok = False
dmarc_errors = txt_records dmarc_errors = dmarc_validation_result.errors
return render_template( return render_template(
"dashboard/domain_detail/dns.html", "dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(), dkim_records=domain_validator.get_dkim_records(),
dmarc_record=DMARC_RECORD,
**locals(), **locals(),
) )