feat: allow to define partner records (#2225)

This commit is contained in:
Carlos Quintana 2024-09-18 12:12:42 +02:00 committed by GitHub
parent f6708dd0b6
commit b5866fa779
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
6 changed files with 116 additions and 27 deletions

View file

@ -638,19 +638,22 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
def read_partner_domains() -> dict[int, str]:
partner_domains_dict = get_env_dict("PARTNER_DOMAINS")
if len(partner_domains_dict) == 0:
def read_partner_dict(var: str) -> dict[int, str]:
partner_value = get_env_dict(var)
if len(partner_value) == 0:
return {}
res: dict[int, str] = {}
for partner_id in partner_domains_dict.keys():
for partner_id in partner_value.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_domains_dict[partner_id]
res[partner_id_int] = partner_value[partner_id]
except ValueError:
pass
return res
PARTNER_DOMAINS: dict[int, str] = read_partner_domains()
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
)

View file

@ -1,4 +1,12 @@
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from dataclasses import dataclass
from typing import Optional
from app.config import (
EMAIL_SERVERS_WITH_PRIORITY,
EMAIL_DOMAIN,
PARTNER_DOMAINS,
PARTNER_DOMAIN_VALIDATION_PREFIXES,
)
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
@ -7,7 +15,6 @@ from app.dns_utils import (
get_network_dns_client,
)
from app.models import CustomDomain
from dataclasses import dataclass
@dataclass
@ -18,22 +25,47 @@ class DomainValidationResult:
class CustomDomainValidation:
def __init__(
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
self,
dkim_domain: str,
dns_client: DNSClient = get_network_dns_client(),
partner_domains: Optional[dict[int, str]] = None,
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
):
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._dkim_records = {
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
self._partner_domains = partner_domains or PARTNER_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes or PARTNER_DOMAIN_VALIDATION_PREFIXES
)
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = "sl-verification"
if (
domain.partner_id is not None
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}={domain.ownership_txt_token}"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
it will return the default ones or the partner ones.
"""
# By default use the default domain
dkim_domain = self.dkim_domain
if domain.partner_id is not None:
# Domain is from a partner. Retrieve the partner config and use that domain if exists
partner_domain = self._partner_domains.get(domain.partner_id)
if partner_domain is not None:
dkim_domain = partner_domain
return {
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}
def get_dkim_records(self) -> {str: str}:
"""
Get a list of dkim records to set up. It will be
"""
return self._dkim_records
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
"""
Check if dkim records are properly set for this custom domain.
@ -41,7 +73,7 @@ class CustomDomainValidation:
"""
correct_records = {}
invalid_records = {}
expected_records = self.get_dkim_records()
expected_records = self.get_dkim_records(custom_domain)
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = self._dns_client.get_cname_record(custom_record)
@ -75,8 +107,11 @@ class CustomDomainValidation:
Check if the custom_domain has added the ownership verification records
"""
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
expected_verification_record = self.get_ownership_verification_record(
custom_domain
)
if custom_domain.get_ownership_dns_txt_value() in txt_records:
if expected_verification_record in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])

View file

@ -141,7 +141,10 @@ def domain_detail_dns(custom_domain_id):
return render_template(
"dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
dkim_records=domain_validator.get_dkim_records(),
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
dkim_records=domain_validator.get_dkim_records(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(),
)

View file

@ -2451,9 +2451,6 @@ class CustomDomain(Base, ModelMixin):
def get_trash_url(self):
return config.URL + f"/dashboard/domains/{self.id}/trash"
def get_ownership_dns_txt_value(self):
return f"sl-verification={self.ownership_txt_token}"
@classmethod
def create(cls, **kwargs):
domain = kwargs.get("domain")

View file

@ -38,7 +38,7 @@
Value: <em data-toggle="tooltip"
title="Click to copy"
class="clipboard"
data-clipboard-text="{{ custom_domain.get_ownership_dns_txt_value() }}">{{ custom_domain.get_ownership_dns_txt_value() }}</em>
data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
</div>
<form method="post" action="#ownership-form">
{{ csrf_form.csrf_token }}

View file

@ -6,6 +6,7 @@ from app.custom_domain_validation import CustomDomainValidation
from app.db import Session
from app.models import CustomDomain, User
from app.dns_utils import InMemoryDNSClient
from app.proton.utils import get_proton_partner
from app.utils import random_string
from tests.utils import create_new_user, random_domain
@ -27,8 +28,9 @@ def create_custom_domain(domain: str) -> CustomDomain:
def test_custom_domain_validation_get_dkim_records():
domain = random_domain()
custom_domain = create_custom_domain(domain)
validator = CustomDomainValidation(domain)
records = validator.get_dkim_records()
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
@ -36,6 +38,26 @@ def test_custom_domain_validation_get_dkim_records():
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
def test_custom_domain_validation_get_dkim_records_for_partner():
domain = random_domain()
custom_domain = create_custom_domain(domain)
partner_id = get_proton_partner().id
custom_domain.partner_id = partner_id
Session.commit()
dkim_domain = random_domain()
validator = CustomDomainValidation(
domain, partner_domains={partner_id: dkim_domain}
)
records = validator.get_dkim_records(custom_domain)
assert len(records) == 3
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}"
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}"
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
# validate_dkim_records
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
dns_client = InMemoryDNSClient()
@ -169,7 +191,36 @@ def test_custom_domain_validation_validate_ownership_success():
domain = create_custom_domain(random_domain())
dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True
assert len(res.errors) == 0
db_domain = CustomDomain.get_by(id=domain.id)
assert db_domain.ownership_verified is True
def test_custom_domain_validation_validate_ownership_from_partner_success():
dns_client = InMemoryDNSClient()
partner_id = get_proton_partner().id
prefix = random_string()
validator = CustomDomainValidation(
random_domain(),
dns_client,
partner_domains_validation_prefixes={partner_id: prefix},
)
domain = create_custom_domain(random_domain())
domain.partner_id = partner_id
Session.commit()
dns_client.set_txt_record(
domain.domain, [validator.get_ownership_verification_record(domain)]
)
res = validator.validate_domain_ownership(domain)
assert res.success is True