mirror of
https://github.com/simple-login/app.git
synced 2024-11-16 17:08:30 +01:00
9d5697b624
* chore: DNS validation improvements * fix: do not show domains pending deletion * fix: generate verification token if null * revert: dmarc cleanup
201 lines
8.2 KiB
Python
201 lines
8.2 KiB
Python
from dataclasses import dataclass
|
|
from typing import List, Optional
|
|
|
|
from app import config
|
|
from app.constants import DMARC_RECORD
|
|
from app.db import Session
|
|
from app.dns_utils import (
|
|
MxRecord,
|
|
DNSClient,
|
|
is_mx_equivalent,
|
|
get_network_dns_client,
|
|
)
|
|
from app.models import CustomDomain
|
|
from app.utils import random_string
|
|
|
|
|
|
@dataclass
class DomainValidationResult:
    """Outcome of a single DNS validation check on a custom domain."""

    # True when the records found in DNS match what is expected.
    success: bool
    # On failure, the records actually found in DNS (rendered as strings) so the user can
    # see what to fix; an empty list on success.
    errors: List[str]
|
|
|
|
|
|
class CustomDomainValidation:
    """
    Validate the DNS setup (ownership TXT, MX, SPF, DKIM, DMARC) of a custom domain.

    Expected record values depend on whether the custom domain belongs to a partner listed in
    ``partner_domains``: such domains must point at the partner's own hosts, and may use a
    partner-specific ownership verification prefix. The ``validate_*`` methods may update
    verification flags on the ``CustomDomain`` and commit the session.
    """

    def __init__(
        self,
        dkim_domain: str,
        dns_client: Optional[DNSClient] = None,
        partner_domains: Optional[dict[int, str]] = None,
        partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
    ):
        """
        :param dkim_domain: domain the default DKIM CNAME records must point to.
        :param dns_client: client used for DNS lookups. When omitted, one is created for this
            instance with ``get_network_dns_client()``.
        :param partner_domains: partner_id -> partner domain used for MX/SPF/DKIM records.
            Defaults to ``config.PARTNER_DNS_CUSTOM_DOMAINS`` when None.
        :param partner_domains_validation_prefixes: partner_id -> ownership TXT prefix.
            Defaults to ``config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES`` when None.
        """
        self.dkim_domain = dkim_domain
        # Not a default argument value: defaults are evaluated once when the `def` runs, which
        # would call get_network_dns_client() at import time and share one client everywhere.
        self._dns_client = (
            dns_client if dns_client is not None else get_network_dns_client()
        )
        # `is None` rather than `or`, so an explicit empty mapping means "no partners" instead
        # of silently falling back to the global configuration.
        self._partner_domains = (
            partner_domains
            if partner_domains is not None
            else config.PARTNER_DNS_CUSTOM_DOMAINS
        )
        self._partner_domain_validation_prefixes = (
            partner_domains_validation_prefixes
            if partner_domains_validation_prefixes is not None
            else config.PARTNER_CUSTOM_DOMAIN_VALIDATION_PREFIXES
        )

    def _get_partner_domain(self, domain: CustomDomain) -> Optional[str]:
        """Return the configured partner domain for ``domain``, or None if it has none."""
        if domain.partner_id is None:
            return None
        return self._partner_domains.get(domain.partner_id)

    def get_ownership_verification_record(self, domain: CustomDomain) -> str:
        """
        Return the TXT record value proving ownership of ``domain``.

        Format: ``<prefix>-verification=<token>``. The prefix is "sl" unless the domain's
        partner has a configured prefix. If the domain has no ownership token yet, one is
        generated with ``random_string(30)`` and committed, so later calls return the same value.
        """
        prefix = "sl"
        if (
            domain.partner_id is not None
            and domain.partner_id in self._partner_domain_validation_prefixes
        ):
            prefix = self._partner_domain_validation_prefixes[domain.partner_id]

        if not domain.ownership_txt_token:
            domain.ownership_txt_token = random_string(30)
            Session.commit()

        return f"{prefix}-verification={domain.ownership_txt_token}"

    def get_expected_mx_records(self, domain: CustomDomain) -> list[MxRecord]:
        """
        Return the MX records ``domain`` should publish.

        Partner domains expect ``mx1.<partner>.`` (priority 10) and ``mx2.<partner>.``
        (priority 20); all other domains expect ``config.EMAIL_SERVERS_WITH_PRIORITY``.
        """
        partner_domain = self._get_partner_domain(domain)
        if partner_domain is not None:
            return [
                MxRecord(10, f"mx1.{partner_domain}."),
                MxRecord(20, f"mx2.{partner_domain}."),
            ]
        # Default ones
        return [
            MxRecord(priority, server)
            for priority, server in config.EMAIL_SERVERS_WITH_PRIORITY
        ]

    def get_expected_spf_domain(self, domain: CustomDomain) -> str:
        """Return the domain the SPF ``include:`` must reference (partner domain or EMAIL_DOMAIN)."""
        partner_domain = self._get_partner_domain(domain)
        return partner_domain if partner_domain is not None else config.EMAIL_DOMAIN

    def get_expected_spf_record(self, domain: CustomDomain) -> str:
        """Return the full SPF TXT record value expected for ``domain``."""
        spf_domain = self.get_expected_spf_domain(domain)
        return f"v=spf1 include:{spf_domain} ~all"

    def get_dkim_records(self, domain: CustomDomain) -> dict[str, str]:
        """
        Get the DKIM CNAME records to set up, as ``{host_prefix: cname_target}``.

        Targets point at the partner's domain when the custom domain belongs to a configured
        partner, otherwise at ``self.dkim_domain``.
        """
        partner_domain = self._get_partner_domain(domain)
        dkim_domain = partner_domain if partner_domain is not None else self.dkim_domain
        return {
            f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
            for key in ("dkim", "dkim02", "dkim03")
        }

    def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
        """
        Check whether the DKIM CNAME records of ``custom_domain`` are properly set.

        Returns an empty dict when every record is correct. Otherwise returns a dict mapping each
        misconfigured host (``<prefix>.<custom domain>``) to the CNAME value found, or "empty"
        when none was found. Updates ``custom_domain.dkim_verified`` and commits, except in the
        legacy case described below, where the flag is left untouched.
        """
        correct_records = {}
        invalid_records = {}
        expected_records = self.get_dkim_records(custom_domain)
        for prefix, expected_record in expected_records.items():
            custom_record = f"{prefix}.{custom_domain.domain}"
            dkim_record = self._dns_client.get_cname_record(custom_record)
            if dkim_record == expected_record:
                correct_records[prefix] = custom_record
            else:
                invalid_records[custom_record] = dkim_record or "empty"

        # HACK
        # As initially we only had one dkim record, we want to allow users that had only the original dkim record and
        # the domain validated to continue seeing it as validated (although showing them the missing records).
        # However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
        # we will remove the dkim_verified flag.
        # This is done in order to give users with the old dkim config (only one) to update their CNAMEs
        if custom_domain.dkim_verified:
            # Check if at least the original dkim is there
            if correct_records.get("dkim._domainkey") is not None:
                # Original dkim record is there. Return the missing records (if any) and don't clear the flag
                return invalid_records

            # Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
            # rest of the code path, returning the invalid records and clearing the flag
        custom_domain.dkim_verified = len(invalid_records) == 0
        Session.commit()
        return invalid_records

    def validate_domain_ownership(
        self, custom_domain: CustomDomain
    ) -> DomainValidationResult:
        """
        Check if the custom_domain has added the ownership verification TXT record.

        On success sets ``ownership_verified`` and commits; on failure returns the TXT records
        that were found as errors (the flag is not cleared).
        """
        txt_records = self._dns_client.get_txt_record(custom_domain.domain)
        expected_verification_record = self.get_ownership_verification_record(
            custom_domain
        )

        if expected_verification_record in txt_records:
            custom_domain.ownership_verified = True
            Session.commit()
            return DomainValidationResult(success=True, errors=[])
        else:
            return DomainValidationResult(success=False, errors=txt_records)

    def validate_mx_records(
        self, custom_domain: CustomDomain
    ) -> DomainValidationResult:
        """
        Check the domain's MX records against the expected ones.

        On success sets ``verified`` and commits; on failure returns the MX records found,
        formatted as "<priority> <domain>" (the flag is not cleared).
        """
        mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
        expected_mx_records = self.get_expected_mx_records(custom_domain)

        if not is_mx_equivalent(mx_domains, expected_mx_records):
            return DomainValidationResult(
                success=False,
                errors=[f"{record.priority} {record.domain}" for record in mx_domains],
            )
        else:
            custom_domain.verified = True
            Session.commit()
            return DomainValidationResult(success=True, errors=[])

    def validate_spf_records(
        self, custom_domain: CustomDomain
    ) -> DomainValidationResult:
        """
        Check that the domain's SPF includes the expected domain; sets ``spf_verified`` either way.

        On failure, returns the domain's TXT records (minus the ownership verification record)
        as errors.
        """
        spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
        expected_spf_domain = self.get_expected_spf_domain(custom_domain)
        if expected_spf_domain in spf_domains:
            custom_domain.spf_verified = True
            Session.commit()
            return DomainValidationResult(success=True, errors=[])
        else:
            custom_domain.spf_verified = False
            Session.commit()
            txt_records = self._dns_client.get_txt_record(custom_domain.domain)
            cleaned_records = self.__clean_spf_records(txt_records, custom_domain)
            return DomainValidationResult(
                success=False,
                errors=cleaned_records,
            )

    def validate_dmarc_records(
        self, custom_domain: CustomDomain
    ) -> DomainValidationResult:
        """
        Check that ``_dmarc.<domain>`` publishes ``DMARC_RECORD``; sets ``dmarc_verified`` either way.

        On failure, returns the TXT records found at ``_dmarc.<domain>`` as errors.
        """
        txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
        if DMARC_RECORD in txt_records:
            custom_domain.dmarc_verified = True
            Session.commit()
            return DomainValidationResult(success=True, errors=[])
        else:
            custom_domain.dmarc_verified = False
            Session.commit()
            return DomainValidationResult(success=False, errors=txt_records)

    def __clean_spf_records(
        self, txt_records: List[str], custom_domain: CustomDomain
    ) -> List[str]:
        """Drop the ownership verification record from ``txt_records`` so only SPF-relevant ones remain."""
        verification_record = self.get_ownership_verification_record(custom_domain)
        return [record for record in txt_records if record != verification_record]
|