From 647c569f9912e8df90e076185efed080040ae73d Mon Sep 17 00:00:00 2001 From: Carlos Quintana <74399022+cquintana92@users.noreply.github.com> Date: Fri, 13 Sep 2024 14:49:48 +0200 Subject: [PATCH] feat: extract custom domain utils to a service (#2215) --- app/config.py | 45 +++++++++++++++++ app/constants.py | 1 + app/custom_domain_utils.py | 0 app/custom_domain_validation.py | 75 +++++++++++++++++++++++++++- app/dashboard/views/domain_detail.py | 63 ++++++++--------------- 5 files changed, 142 insertions(+), 42 deletions(-) create mode 100644 app/custom_domain_utils.py diff --git a/app/config.py b/app/config.py index 38f6e401..d53c502b 100644 --- a/app/config.py +++ b/app/config.py @@ -35,6 +35,33 @@ def sl_getenv(env_var: str, default_factory: Callable = None): return literal_eval(value) +def get_env_dict(env_var: str) -> dict[str, str]: + """ + Get an env variable and convert it into a python dictionary with keys and values as strings. + Args: + env_var (str): env var, example: SL_DB + + Syntax is: key1=value1;key2=value2 + Components separated by ; + key and value separated by = + """ + value = os.getenv(env_var) + if not value: + return {} + + components = value.split(";") + result = {} + for component in components: + if component == "": + continue + parts = component.split("=") + if len(parts) != 2: + raise Exception(f"Invalid config for env var {env_var}") + result[parts[0].strip()] = parts[1].strip() + + return result + + config_file = os.environ.get("CONFIG") if config_file: config_file = get_abs_path(config_file) @@ -609,3 +636,21 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_ # Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool # It defaults to the regular DB_URI in case it's needed EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI) + + +def read_partner_domains() -> dict[int, str]: + partner_domains_dict = get_env_dict("PARTNER_DOMAINS") + if len(partner_domains_dict) == 0: + return {} + + res: dict[int, str] = {} + for partner_id in partner_domains_dict.keys(): + try: + partner_id_int = int(partner_id.strip()) + res[partner_id_int] = partner_domains_dict[partner_id] + except ValueError: + pass + return res + + +PARTNER_DOMAINS: dict[int, str] = read_partner_domains() diff --git a/app/constants.py b/app/constants.py index b20bc6af..837e7fa7 100644 --- a/app/constants.py +++ b/app/constants.py @@ -1 +1,2 @@ HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies" +DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s" diff --git a/app/custom_domain_utils.py b/app/custom_domain_utils.py new file mode 100644 index 00000000..e69de29b diff --git a/app/custom_domain_validation.py b/app/custom_domain_validation.py index 3a2145a8..05b6af37 100644 --- a/app/custom_domain_validation.py +++ b/app/custom_domain_validation.py @@ -1,6 +1,21 @@ +from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN +from app.constants import DMARC_RECORD from app.db import Session -from app.dns_utils import get_cname_record +from app.dns_utils import ( + get_cname_record, + get_mx_domains, + get_txt_record, + is_mx_equivalent, + get_spf_domain, +) from app.models import CustomDomain +from dataclasses import dataclass + + +@dataclass +class DomainValidationResult: + success: bool + errors: [str] class CustomDomainValidation: @@ -35,3 +50,61 @@ class CustomDomainValidation: custom_domain.dkim_verified = len(invalid_records) == 0 Session.commit() return invalid_records + + def validate_domain_ownership( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + """ + Check if the custom_domain has added the ownership verification records + """ + txt_records = get_txt_record(custom_domain.domain) + + if custom_domain.get_ownership_dns_txt_value() in txt_records: + custom_domain.ownership_verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + else: + return DomainValidationResult(success=False, errors=txt_records) + + def validate_mx_records( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + mx_domains = get_mx_domains(custom_domain.domain) + + if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY): + return DomainValidationResult( + success=False, + errors=[f"{priority} {domain}" for (priority, domain) in mx_domains], + ) + else: + custom_domain.verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + + def validate_spf_records( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + spf_domains = get_spf_domain(custom_domain.domain) + if EMAIL_DOMAIN in spf_domains: + custom_domain.spf_verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + else: + custom_domain.spf_verified = False + Session.commit() + return DomainValidationResult( + success=False, errors=get_txt_record(custom_domain.domain) + ) + + def validate_dmarc_records( + self, custom_domain: CustomDomain + ) -> DomainValidationResult: + txt_records = get_txt_record("_dmarc." + custom_domain.domain) + if DMARC_RECORD in txt_records: + custom_domain.dmarc_verified = True + Session.commit() + return DomainValidationResult(success=True, errors=[]) + else: + custom_domain.dmarc_verified = False + Session.commit() + return DomainValidationResult(success=False, errors=txt_records) diff --git a/app/dashboard/views/domain_detail.py b/app/dashboard/views/domain_detail.py index 29089a38..1dc7683d 100644 --- a/app/dashboard/views/domain_detail.py +++ b/app/dashboard/views/domain_detail.py @@ -6,16 +6,11 @@ from flask_login import login_required, current_user from flask_wtf import FlaskForm from wtforms import StringField, validators, IntegerField +from app.constants import DMARC_RECORD from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN from app.custom_domain_validation import CustomDomainValidation from app.dashboard.base import dashboard_bp from app.db import Session -from app.dns_utils import ( - get_mx_domains, - get_spf_domain, - get_txt_record, - is_mx_equivalent, -) from app.log import LOG from app.models import ( CustomDomain, @@ -49,8 +44,6 @@ def domain_detail_dns(custom_domain_id): domain_validator = CustomDomainValidation(EMAIL_DOMAIN) csrf_form = CSRFValidationForm() - dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s" - mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = [] @@ -59,15 +52,14 @@ def domain_detail_dns(custom_domain_id): flash("Invalid request", "warning") return redirect(request.url) if request.form.get("form-name") == "check-ownership": - txt_records = get_txt_record(custom_domain.domain) - - if custom_domain.get_ownership_dns_txt_value() in txt_records: + ownership_validation_result = domain_validator.validate_domain_ownership( + custom_domain + ) + if ownership_validation_result.success: flash( "Domain ownership is verified. Please proceed to the other records setup", "success", ) - custom_domain.ownership_verified = True - Session.commit() return redirect( url_for( "dashboard.domain_detail_dns", @@ -78,36 +70,28 @@ def domain_detail_dns(custom_domain_id): else: flash("We can't find the needed TXT record", "error") ownership_ok = False - ownership_errors = txt_records + ownership_errors = ownership_validation_result.errors elif request.form.get("form-name") == "check-mx": - mx_domains = get_mx_domains(custom_domain.domain) - - if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY): - flash("The MX record is not correctly set", "warning") - - mx_ok = False - # build mx_errors to show to user - mx_errors = [ - f"{priority} {domain}" for (priority, domain) in mx_domains - ] - else: + mx_validation_result = domain_validator.validate_mx_records(custom_domain) + if mx_validation_result.success: flash( "Your domain can start receiving emails. You can now use it to create alias", "success", ) - custom_domain.verified = True - Session.commit() return redirect( url_for( "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id ) ) + else: + flash("The MX record is not correctly set", "warning") + mx_ok = False + mx_errors = mx_validation_result.errors + elif request.form.get("form-name") == "check-spf": - spf_domains = get_spf_domain(custom_domain.domain) - if EMAIL_DOMAIN in spf_domains: - custom_domain.spf_verified = True - Session.commit() + spf_validation_result = domain_validator.validate_spf_records(custom_domain) + if spf_validation_result.success: flash("SPF is setup correctly", "success") return redirect( url_for( @@ -115,14 +99,12 @@ def domain_detail_dns(custom_domain_id): ) ) else: - custom_domain.spf_verified = False - Session.commit() flash( f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.", "warning", ) spf_ok = False - spf_errors = get_txt_record(custom_domain.domain) + spf_errors = spf_validation_result.errors elif request.form.get("form-name") == "check-dkim": dkim_errors = domain_validator.validate_dkim_records(custom_domain) @@ -138,10 +120,10 @@ def domain_detail_dns(custom_domain_id): flash("DKIM: the CNAME record is not correctly set", "warning") elif request.form.get("form-name") == "check-dmarc": - txt_records = get_txt_record("_dmarc." + custom_domain.domain) - if dmarc_record in txt_records: - custom_domain.dmarc_verified = True - Session.commit() + dmarc_validation_result = domain_validator.validate_dmarc_records( + custom_domain + ) + if dmarc_validation_result.success: flash("DMARC is setup correctly", "success") return redirect( url_for( @@ -149,19 +131,18 @@ def domain_detail_dns(custom_domain_id): ) ) else: - custom_domain.dmarc_verified = False - Session.commit() flash( "DMARC: The TXT record is not correctly set", "warning", ) dmarc_ok = False - dmarc_errors = txt_records + dmarc_errors = dmarc_validation_result.errors return render_template( "dashboard/domain_detail/dns.html", EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, dkim_records=domain_validator.get_dkim_records(), + dmarc_record=DMARC_RECORD, **locals(), )