mirror of
https://github.com/simple-login/app.git
synced 2024-09-21 01:11:29 +02:00
feat: allow to define partner records (#2225)
This commit is contained in:
parent
f6708dd0b6
commit
b5866fa779
@ -638,19 +638,22 @@ EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_
|
|||||||
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
|
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
|
||||||
|
|
||||||
|
|
||||||
def read_partner_domains() -> dict[int, str]:
|
def read_partner_dict(var: str) -> dict[int, str]:
|
||||||
partner_domains_dict = get_env_dict("PARTNER_DOMAINS")
|
partner_value = get_env_dict(var)
|
||||||
if len(partner_domains_dict) == 0:
|
if len(partner_value) == 0:
|
||||||
return {}
|
return {}
|
||||||
|
|
||||||
res: dict[int, str] = {}
|
res: dict[int, str] = {}
|
||||||
for partner_id in partner_domains_dict.keys():
|
for partner_id in partner_value.keys():
|
||||||
try:
|
try:
|
||||||
partner_id_int = int(partner_id.strip())
|
partner_id_int = int(partner_id.strip())
|
||||||
res[partner_id_int] = partner_domains_dict[partner_id]
|
res[partner_id_int] = partner_value[partner_id]
|
||||||
except ValueError:
|
except ValueError:
|
||||||
pass
|
pass
|
||||||
return res
|
return res
|
||||||
|
|
||||||
|
|
||||||
PARTNER_DOMAINS: dict[int, str] = read_partner_domains()
|
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
|
||||||
|
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
|
||||||
|
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
|
||||||
|
)
|
||||||
|
@ -1,4 +1,12 @@
|
|||||||
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
|
from dataclasses import dataclass
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
from app.config import (
|
||||||
|
EMAIL_SERVERS_WITH_PRIORITY,
|
||||||
|
EMAIL_DOMAIN,
|
||||||
|
PARTNER_DOMAINS,
|
||||||
|
PARTNER_DOMAIN_VALIDATION_PREFIXES,
|
||||||
|
)
|
||||||
from app.constants import DMARC_RECORD
|
from app.constants import DMARC_RECORD
|
||||||
from app.db import Session
|
from app.db import Session
|
||||||
from app.dns_utils import (
|
from app.dns_utils import (
|
||||||
@ -7,7 +15,6 @@ from app.dns_utils import (
|
|||||||
get_network_dns_client,
|
get_network_dns_client,
|
||||||
)
|
)
|
||||||
from app.models import CustomDomain
|
from app.models import CustomDomain
|
||||||
from dataclasses import dataclass
|
|
||||||
|
|
||||||
|
|
||||||
@dataclass
|
@dataclass
|
||||||
@ -18,22 +25,47 @@ class DomainValidationResult:
|
|||||||
|
|
||||||
class CustomDomainValidation:
|
class CustomDomainValidation:
|
||||||
def __init__(
|
def __init__(
|
||||||
self, dkim_domain: str, dns_client: DNSClient = get_network_dns_client()
|
self,
|
||||||
|
dkim_domain: str,
|
||||||
|
dns_client: DNSClient = get_network_dns_client(),
|
||||||
|
partner_domains: Optional[dict[int, str]] = None,
|
||||||
|
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
|
||||||
):
|
):
|
||||||
self.dkim_domain = dkim_domain
|
self.dkim_domain = dkim_domain
|
||||||
self._dns_client = dns_client
|
self._dns_client = dns_client
|
||||||
self._dkim_records = {
|
self._partner_domains = partner_domains or PARTNER_DOMAINS
|
||||||
f"{key}._domainkey": f"{key}._domainkey.{self.dkim_domain}"
|
self._partner_domain_validation_prefixes = (
|
||||||
|
partner_domains_validation_prefixes or PARTNER_DOMAIN_VALIDATION_PREFIXES
|
||||||
|
)
|
||||||
|
|
||||||
|
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
|
||||||
|
prefix = "sl-verification"
|
||||||
|
if (
|
||||||
|
domain.partner_id is not None
|
||||||
|
and domain.partner_id in self._partner_domain_validation_prefixes
|
||||||
|
):
|
||||||
|
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
|
||||||
|
return f"{prefix}={domain.ownership_txt_token}"
|
||||||
|
|
||||||
|
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
|
||||||
|
"""
|
||||||
|
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
|
||||||
|
it will return the default ones or the partner ones.
|
||||||
|
"""
|
||||||
|
|
||||||
|
# By default use the default domain
|
||||||
|
dkim_domain = self.dkim_domain
|
||||||
|
if domain.partner_id is not None:
|
||||||
|
# Domain is from a partner. Retrieve the partner config and use that domain if exists
|
||||||
|
partner_domain = self._partner_domains.get(domain.partner_id)
|
||||||
|
if partner_domain is not None:
|
||||||
|
dkim_domain = partner_domain
|
||||||
|
|
||||||
|
return {
|
||||||
|
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
|
||||||
for key in ("dkim", "dkim02", "dkim03")
|
for key in ("dkim", "dkim02", "dkim03")
|
||||||
}
|
}
|
||||||
|
|
||||||
def get_dkim_records(self) -> {str: str}:
|
|
||||||
"""
|
|
||||||
Get a list of dkim records to set up. It will be
|
|
||||||
|
|
||||||
"""
|
|
||||||
return self._dkim_records
|
|
||||||
|
|
||||||
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
|
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
|
||||||
"""
|
"""
|
||||||
Check if dkim records are properly set for this custom domain.
|
Check if dkim records are properly set for this custom domain.
|
||||||
@ -41,7 +73,7 @@ class CustomDomainValidation:
|
|||||||
"""
|
"""
|
||||||
correct_records = {}
|
correct_records = {}
|
||||||
invalid_records = {}
|
invalid_records = {}
|
||||||
expected_records = self.get_dkim_records()
|
expected_records = self.get_dkim_records(custom_domain)
|
||||||
for prefix, expected_record in expected_records.items():
|
for prefix, expected_record in expected_records.items():
|
||||||
custom_record = f"{prefix}.{custom_domain.domain}"
|
custom_record = f"{prefix}.{custom_domain.domain}"
|
||||||
dkim_record = self._dns_client.get_cname_record(custom_record)
|
dkim_record = self._dns_client.get_cname_record(custom_record)
|
||||||
@ -75,8 +107,11 @@ class CustomDomainValidation:
|
|||||||
Check if the custom_domain has added the ownership verification records
|
Check if the custom_domain has added the ownership verification records
|
||||||
"""
|
"""
|
||||||
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
|
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
|
||||||
|
expected_verification_record = self.get_ownership_verification_record(
|
||||||
|
custom_domain
|
||||||
|
)
|
||||||
|
|
||||||
if custom_domain.get_ownership_dns_txt_value() in txt_records:
|
if expected_verification_record in txt_records:
|
||||||
custom_domain.ownership_verified = True
|
custom_domain.ownership_verified = True
|
||||||
Session.commit()
|
Session.commit()
|
||||||
return DomainValidationResult(success=True, errors=[])
|
return DomainValidationResult(success=True, errors=[])
|
||||||
|
@ -141,7 +141,10 @@ def domain_detail_dns(custom_domain_id):
|
|||||||
return render_template(
|
return render_template(
|
||||||
"dashboard/domain_detail/dns.html",
|
"dashboard/domain_detail/dns.html",
|
||||||
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
|
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
|
||||||
dkim_records=domain_validator.get_dkim_records(),
|
ownership_record=domain_validator.get_ownership_verification_record(
|
||||||
|
custom_domain
|
||||||
|
),
|
||||||
|
dkim_records=domain_validator.get_dkim_records(custom_domain),
|
||||||
dmarc_record=DMARC_RECORD,
|
dmarc_record=DMARC_RECORD,
|
||||||
**locals(),
|
**locals(),
|
||||||
)
|
)
|
||||||
|
@ -2451,9 +2451,6 @@ class CustomDomain(Base, ModelMixin):
|
|||||||
def get_trash_url(self):
|
def get_trash_url(self):
|
||||||
return config.URL + f"/dashboard/domains/{self.id}/trash"
|
return config.URL + f"/dashboard/domains/{self.id}/trash"
|
||||||
|
|
||||||
def get_ownership_dns_txt_value(self):
|
|
||||||
return f"sl-verification={self.ownership_txt_token}"
|
|
||||||
|
|
||||||
@classmethod
|
@classmethod
|
||||||
def create(cls, **kwargs):
|
def create(cls, **kwargs):
|
||||||
domain = kwargs.get("domain")
|
domain = kwargs.get("domain")
|
||||||
|
@ -38,7 +38,7 @@
|
|||||||
Value: <em data-toggle="tooltip"
|
Value: <em data-toggle="tooltip"
|
||||||
title="Click to copy"
|
title="Click to copy"
|
||||||
class="clipboard"
|
class="clipboard"
|
||||||
data-clipboard-text="{{ custom_domain.get_ownership_dns_txt_value() }}">{{ custom_domain.get_ownership_dns_txt_value() }}</em>
|
data-clipboard-text="{{ ownership_record }}">{{ ownership_record }}</em>
|
||||||
</div>
|
</div>
|
||||||
<form method="post" action="#ownership-form">
|
<form method="post" action="#ownership-form">
|
||||||
{{ csrf_form.csrf_token }}
|
{{ csrf_form.csrf_token }}
|
||||||
|
@ -6,6 +6,7 @@ from app.custom_domain_validation import CustomDomainValidation
|
|||||||
from app.db import Session
|
from app.db import Session
|
||||||
from app.models import CustomDomain, User
|
from app.models import CustomDomain, User
|
||||||
from app.dns_utils import InMemoryDNSClient
|
from app.dns_utils import InMemoryDNSClient
|
||||||
|
from app.proton.utils import get_proton_partner
|
||||||
from app.utils import random_string
|
from app.utils import random_string
|
||||||
from tests.utils import create_new_user, random_domain
|
from tests.utils import create_new_user, random_domain
|
||||||
|
|
||||||
@ -27,8 +28,9 @@ def create_custom_domain(domain: str) -> CustomDomain:
|
|||||||
|
|
||||||
def test_custom_domain_validation_get_dkim_records():
|
def test_custom_domain_validation_get_dkim_records():
|
||||||
domain = random_domain()
|
domain = random_domain()
|
||||||
|
custom_domain = create_custom_domain(domain)
|
||||||
validator = CustomDomainValidation(domain)
|
validator = CustomDomainValidation(domain)
|
||||||
records = validator.get_dkim_records()
|
records = validator.get_dkim_records(custom_domain)
|
||||||
|
|
||||||
assert len(records) == 3
|
assert len(records) == 3
|
||||||
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
|
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{domain}"
|
||||||
@ -36,6 +38,26 @@ def test_custom_domain_validation_get_dkim_records():
|
|||||||
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
|
assert records["dkim._domainkey"] == f"dkim._domainkey.{domain}"
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_get_dkim_records_for_partner():
|
||||||
|
domain = random_domain()
|
||||||
|
custom_domain = create_custom_domain(domain)
|
||||||
|
|
||||||
|
partner_id = get_proton_partner().id
|
||||||
|
custom_domain.partner_id = partner_id
|
||||||
|
Session.commit()
|
||||||
|
|
||||||
|
dkim_domain = random_domain()
|
||||||
|
validator = CustomDomainValidation(
|
||||||
|
domain, partner_domains={partner_id: dkim_domain}
|
||||||
|
)
|
||||||
|
records = validator.get_dkim_records(custom_domain)
|
||||||
|
|
||||||
|
assert len(records) == 3
|
||||||
|
assert records["dkim02._domainkey"] == f"dkim02._domainkey.{dkim_domain}"
|
||||||
|
assert records["dkim03._domainkey"] == f"dkim03._domainkey.{dkim_domain}"
|
||||||
|
assert records["dkim._domainkey"] == f"dkim._domainkey.{dkim_domain}"
|
||||||
|
|
||||||
|
|
||||||
# validate_dkim_records
|
# validate_dkim_records
|
||||||
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
|
def test_custom_domain_validation_validate_dkim_records_empty_records_failure():
|
||||||
dns_client = InMemoryDNSClient()
|
dns_client = InMemoryDNSClient()
|
||||||
@ -169,7 +191,36 @@ def test_custom_domain_validation_validate_ownership_success():
|
|||||||
|
|
||||||
domain = create_custom_domain(random_domain())
|
domain = create_custom_domain(random_domain())
|
||||||
|
|
||||||
dns_client.set_txt_record(domain.domain, [domain.get_ownership_dns_txt_value()])
|
dns_client.set_txt_record(
|
||||||
|
domain.domain, [validator.get_ownership_verification_record(domain)]
|
||||||
|
)
|
||||||
|
res = validator.validate_domain_ownership(domain)
|
||||||
|
|
||||||
|
assert res.success is True
|
||||||
|
assert len(res.errors) == 0
|
||||||
|
|
||||||
|
db_domain = CustomDomain.get_by(id=domain.id)
|
||||||
|
assert db_domain.ownership_verified is True
|
||||||
|
|
||||||
|
|
||||||
|
def test_custom_domain_validation_validate_ownership_from_partner_success():
|
||||||
|
dns_client = InMemoryDNSClient()
|
||||||
|
partner_id = get_proton_partner().id
|
||||||
|
|
||||||
|
prefix = random_string()
|
||||||
|
validator = CustomDomainValidation(
|
||||||
|
random_domain(),
|
||||||
|
dns_client,
|
||||||
|
partner_domains_validation_prefixes={partner_id: prefix},
|
||||||
|
)
|
||||||
|
|
||||||
|
domain = create_custom_domain(random_domain())
|
||||||
|
domain.partner_id = partner_id
|
||||||
|
Session.commit()
|
||||||
|
|
||||||
|
dns_client.set_txt_record(
|
||||||
|
domain.domain, [validator.get_ownership_verification_record(domain)]
|
||||||
|
)
|
||||||
res = validator.validate_domain_ownership(domain)
|
res = validator.validate_domain_ownership(domain)
|
||||||
|
|
||||||
assert res.success is True
|
assert res.success is True
|
||||||
|
Loading…
Reference in New Issue
Block a user