mirror of
https://github.com/simple-login/app.git
synced 2024-11-13 07:31:12 +01:00
feat: allow to define partner records (#2225)
This commit is contained in:
parent
f6708dd0b6
commit
b5866fa779
6 changed files with 116 additions and 27 deletions
|
@ -638,19 +638,22 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_
|
|||
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
|
||||
|
||||
|
||||
def read_partner_domains() -> dict[int, str]:
|
||||
partner_domains_dict = get_env_dict("PARTNER_DOMAINS")
|
||||
if len(partner_domains_dict) == 0:
|
||||
def read_partner_dict(var: str) -> dict[int, str]:
|
||||
partner_value = get_env_dict(var)
|
||||
if len(partner_value) == 0:
|
||||
return {}
|
||||
|
||||
res: dict[int, str] = {}
|
||||
for partner_id in partner_domains_dict.keys():
|
||||
for partner_id in partner_value.keys():
|
||||
try:
|
||||
partner_id_int = int(partner_id.strip())
|
||||
res[partner_id_int] = partner_domains_dict[partner_id]
|
||||
res[partner_id_int] = partner_value[partner_id]
|
||||
except ValueError:
|
||||
pass
|
||||
return res
|
||||
|
||||
|
||||
PARTNER_DOMAINS: dict[int, str] = read_partner_domains()
|
||||
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
|
||||
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
|
||||
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
|
||||
)
|
||||
|
|
|
@ -1,4 +1,12 @@
|
|||
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
|
||||
from dataclasses import dataclass
|
||||
from typing import Optional
|
||||
|
||||
from app.config import (
|
||||
EMAIL_SERVERS_WITH_PRIORITY,
|
||||
EMAIL_DOMAIN,
|
||||
PARTNER_DOMAINS,
|
||||
PARTNER_DOMAIN_VALIDATION_PREFIXES,
|
||||
)
|
||||
from app.constants import DMARC_RECORD
|
||||
from app.db import Session
|
||||
from app.dns_utils import (
|
||||
|
@ -7,7 +15,6 @@ from app.dns_utils import (
|
|||
get_network_dns_client,
|
||||
)
|
||||
from app.models import CustomDomain
|
||||
from dataclasses import dataclass
|
||||
|
||||
|
||||
@dataclass
|
||||
|
@ -18,22 +25,47 @@ class DomainValidationResult:
|
|||
|
||||
class CustomDomainValidation:
|
||||
def __init__(
|
||||
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
|
||||
self,
|
||||
dkim_domain: str,
|
||||
dns_client: DNSClient = get_network_dns_client(),
|
||||
partner_domains: Optional[dict[int, str]] = None,
|
||||
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
|
||||
):
|
||||
self.dkim_domain = dkim_domain
|
||||
self._dns_client = dns_client
|
||||
self._dkim_records = {
|
||||
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
|
||||
self._partner_domains = partner_domains or PARTNER_DOMAINS
|
||||
self._partner_domain_validation_prefixes = (
|
||||
partner_domains_validation_prefixes or PARTNER_DOMAIN_VALIDATION_PREFIXES
|
||||
)
|
||||
|
||||
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
|
||||
prefix = "sl-verification"
|
||||
if (
|
||||
domain.partner_id is not None
|
||||
and domain.partner_id in self._partner_domain_validation_prefixes
|
||||
):
|
||||
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
|
||||
return f"{prefix}={domain.ownership_txt_token}"
|
||||
|
||||
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
|
||||
"""
|
||||
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
|
||||
it will return the default ones or the partner ones.
|
||||
"""
|
||||
|
||||
# By default use the default domain
|
||||
dkim_domain = self.dkim_domain
|
||||
if domain.partner_id is not None:
|
||||
# Domain is from a partner. Retrieve the partner config and use that domain if exists
|
||||
partner_domain = self._partner_domains.get(domain.partner_id)
|
||||
if partner_domain is not None:
|
||||
dkim_domain = partner_domain
|
||||
|
||||
return {
|
||||
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
|
||||
for key in ("dkim", "dkim02", "dkim03")
|
||||
}
|
||||
|
||||
def get_dkim_records(self) -> {str: str}:
|
||||
"""
|
||||
Get a list of dkim records to set up. It will be
|
||||
|
||||
"""
|
||||
return self._dkim_records
|
||||
|
||||
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
|
||||
"""
|
||||
Check if dkim records are properly set for this custom domain.
|
||||
|
@ -41,7 +73,7 @@ class CustomDomainValidation:
|
|||
"""
|
||||
correct_records = {}
|
||||
invalid_records = {}
|
||||
expected_records = self.get_dkim_records()
|
||||
expected_records = self.get_dkim_records(custom_domain)
|
||||
for prefix, expected_record in expected_records.items():
|
||||
custom_record = f"{prefix}.{custom_domain.domain}"
|
||||
dkim_record = self._dns_client.get_cname_record(custom_record)
|
||||
|
@ -75,8 +107,11 @@ class CustomDomainValidation:
|
|||
Check if the custom_domain has added the ownership verification records
|
||||
"""
|
||||
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
|
||||
expected_verification_record = self.get_ownership_verification_record(
|
||||
custom_domain
|
||||
)
|
||||
|
||||
if custom_domain.get_ownership_dns_txt_value() in txt_records:
|
||||
if expected_verification_record in txt_records:
|
||||
custom_domain.ownership_verified = True
|
||||
Session.commit()
|
||||
return DomainValidationResult(success=True, errors=[])
|
||||
|
|
|
@ -141,7 +141,10 @@ def domain_detail_dns(custom_domain_id):
|
|||
return render_template(
|
||||
"dashboard/domain_detail/dns.html",
|
||||
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
|
||||
dkim_records=domain_validator.get_dkim_records(),
|
||||
ownership_record=domain_validator.get_ownership_verification_record(
|
||||
custom_domain
|
||||
),
|
||||
dkim_records=domain_validator.get_dkim_records(custom_domain),
|
||||
dmarc_record=DMARC_RECORD,
|
||||
**locals(),
|
||||
)
|
||||
|
|
|
@ -2451,9 +2451,6 @@ class CustomDomain(Base, ModelMixin):
|
|||
def get_trash_url(self):
|
||||
return config.URL + f"/dashboard/domains/{self.id}/trash"
|
||||
|
||||
def get_ownership_dns_txt_value(self):
|
||||
return f"sl-verification={self.ownership_txt_token}"
|
||||
|
||||
@classmethod
|
||||
def create(cls, **kwargs):
|
||||
domain = kwargs.get("domain")
|
||||
|
|
|
@ -38,7 +38,7 @@
|
|||
Value: <em data-toggle="tooltip"
|
||||
title="Click to copy"
|
||||
class="clipboard"
|
||||
data-clipboard-text="{{ custom_domain.get_ownership_dns_txt_value() }}">{{ custom_domain.get_ownership_dns_txt_value() }}</em>
|
||||
data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
|
||||
</div>
|
||||
<form method="post" action="#ownership-form">
|
||||
{{ csrf_form.csrf_token }}
|
||||
|
|
|
@ -6,6 +6,7 @@ from app.custom_domain_validation import CustomDomainValidation
|
|||
from app.db import Session
|
||||
from app.models import CustomDomain, User
|
||||
from app.dns_utils import InMemoryDNSClient
|
||||
from app.proton.utils import get_proton_partner
|
||||
from app.utils import random_string
|
||||
from tests.utils import create_new_user, random_domain
|
||||
|
||||
|
@ -27,8 +28,9 @@ def create_custom_domain(domain: str) -> CustomDomain:
|
|||
|
||||
def test_custom_domain_validation_get_dkim_records():
|
||||
domain = random_domain()
|
||||
custom_domain = create_custom_domain(domain)
|
||||
validator = CustomDomainValidation(domain)
|
||||
records = validator.get_dkim_records()
|
||||
records = validator.get_dkim_records(custom_domain)
|
||||
|
||||
assert len(records) == 3
|
||||
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
|
||||
|
@ -36,6 +38,26 @@ def test_custom_domain_validation_get_dkim_records():
|
|||
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
|
||||
|
||||
|
||||
def test_custom_domain_validation_get_dkim_records_for_partner():
|
||||
domain = random_domain()
|
||||
custom_domain = create_custom_domain(domain)
|
||||
|
||||
partner_id = get_proton_partner().id
|
||||
custom_domain.partner_id = partner_id
|
||||
Session.commit()
|
||||
|
||||
dkim_domain = random_domain()
|
||||
validator = CustomDomainValidation(
|
||||
domain, partner_domains={partner_id: dkim_domain}
|
||||
)
|
||||
records = validator.get_dkim_records(custom_domain)
|
||||
|
||||
assert len(records) == 3
|
||||
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}"
|
||||
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}"
|
||||
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
|
||||
|
||||
|
||||
# validate_dkim_records
|
||||
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
|
||||
dns_client = InMemoryDNSClient()
|
||||
|
@ -169,7 +191,36 @@ def test_custom_domain_validation_validate_ownership_success():
|
|||
|
||||
domain = create_custom_domain(random_domain())
|
||||
|
||||
dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
|
||||
dns_client.set_txt_record(
|
||||
domain.domain, [validator.get_ownership_verification_record(domain)]
|
||||
)
|
||||
res = validator.validate_domain_ownership(domain)
|
||||
|
||||
assert res.success is True
|
||||
assert len(res.errors) == 0
|
||||
|
||||
db_domain = CustomDomain.get_by(id=domain.id)
|
||||
assert db_domain.ownership_verified is True
|
||||
|
||||
|
||||
def test_custom_domain_validation_validate_ownership_from_partner_success():
|
||||
dns_client = InMemoryDNSClient()
|
||||
partner_id = get_proton_partner().id
|
||||
|
||||
prefix = random_string()
|
||||
validator = CustomDomainValidation(
|
||||
random_domain(),
|
||||
dns_client,
|
||||
partner_domains_validation_prefixes={partner_id: prefix},
|
||||
)
|
||||
|
||||
domain = create_custom_domain(random_domain())
|
||||
domain.partner_id = partner_id
|
||||
Session.commit()
|
||||
|
||||
dns_client.set_txt_record(
|
||||
domain.domain, [validator.get_ownership_verification_record(domain)]
|
||||
)
|
||||
res = validator.validate_domain_ownership(domain)
|
||||
|
||||
assert res.success is True
|
||||
|
|
Loading…
Reference in a new issue