Create Partner only domains (#1665)

* Add Partner only domains

* Add hidden domain to the test and revert to default domains after the tests

* Send what to show in each call

* Fix: Pass none instead of false

* Removed flag from partnerusr

---------

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
Adrià Casajús 2023-04-04 15:21:51 +02:00 committed by GitHub
parent 03e5083d97
commit 43b91cd197
No known key found for this signature in database
GPG Key ID: 4AEE18F83AFDEB23
4 changed files with 201 additions and 36 deletions

View File

@ -6,8 +6,7 @@ from typing import Optional
import itsdangerous import itsdangerous
from app import config from app import config
from app.log import LOG from app.log import LOG
from app.models import User from app.models import User, Partner
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET) signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@ -87,7 +86,11 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
return True return True
def get_alias_suffixes(user: User) -> [AliasSuffix]: def get_alias_suffixes(
user: User,
show_domains_for_partner: Optional[Partner] = None,
show_sl_domains: bool = True,
) -> [AliasSuffix]:
""" """
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up. Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
""" """
@ -139,7 +142,7 @@ def get_alias_suffixes(user: User) -> [AliasSuffix]:
alias_suffixes.append(alias_suffix) alias_suffixes.append(alias_suffix)
# then SimpleLogin domain # then SimpleLogin domain
for sl_domain in user.get_sl_domains(): for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains):
suffix = ( suffix = (
( (
"" ""

View File

@ -18,7 +18,7 @@ from flanker.addresslib import address
from flask import url_for from flask import url_for
from flask_login import UserMixin from flask_login import UserMixin
from jinja2 import FileSystemLoader, Environment from jinja2 import FileSystemLoader, Environment
from sqlalchemy import orm from sqlalchemy import orm, or_
from sqlalchemy import text, desc, CheckConstraint, Index, Column from sqlalchemy import text, desc, CheckConstraint, Index, Column
from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.dialects.postgresql import TSVECTOR
from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.ext.declarative import declarative_base
@ -967,13 +967,31 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
""" """
return [sl_domain.domain for sl_domain in self.get_sl_domains()] return [sl_domain.domain for sl_domain in self.get_sl_domains()]
def get_sl_domains(self) -> List["SLDomain"]: def get_sl_domains(
query = SLDomain.filter_by(hidden=False).order_by(SLDomain.order) self,
show_domains_for_partner: Optional[Partner] = None,
if self.is_premium(): show_sl_domains: bool = True,
return query.all() ) -> list["SLDomain"]:
conditions = [SLDomain.hidden == False] # noqa: E712
if not self.is_premium():
conditions.append(SLDomain.premium_only == False) # noqa: E712
partner_domain_cond = [] # noqa:E711
if show_domains_for_partner is not None:
partner_user = PartnerUser.filter_by(
user_id=self.id, partner_id=show_domains_for_partner.id
).first()
if partner_user is not None:
partner_domain_cond.append(
SLDomain.partner_id == partner_user.partner_id
)
if show_sl_domains:
partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711
if len(partner_domain_cond) == 1:
conditions.append(partner_domain_cond[0])
else: else:
return query.filter_by(premium_only=False).all() conditions.append(or_(*partner_domain_cond))
query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order)
return query.all()
def available_alias_domains(self) -> [str]: def available_alias_domains(self) -> [str]:
"""return all domains that user can use when creating a new alias, including: """return all domains that user can use when creating a new alias, including:
@ -2768,6 +2786,31 @@ class Notification(Base, ModelMixin):
) )
class Partner(Base, ModelMixin):
__tablename__ = "partner"
name = sa.Column(sa.String(128), unique=True, nullable=False)
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
@staticmethod
def find_by_token(token: str) -> Optional[Partner]:
hmaced = PartnerApiToken.hmac_token(token)
res = (
Session.query(Partner, PartnerApiToken)
.filter(
and_(
PartnerApiToken.token == hmaced,
Partner.id == PartnerApiToken.partner_id,
)
)
.first()
)
if res:
partner, partner_api_token = res
return partner
return None
class SLDomain(Base, ModelMixin): class SLDomain(Base, ModelMixin):
"""SimpleLogin domains""" """SimpleLogin domains"""
@ -2785,6 +2828,13 @@ class SLDomain(Base, ModelMixin):
sa.Boolean, nullable=False, default=False, server_default="0" sa.Boolean, nullable=False, default=False, server_default="0"
) )
partner_id = sa.Column(
sa.ForeignKey(Partner.id, ondelete="cascade"),
nullable=True,
default=None,
sever_default="NULL",
)
# if enabled, do not show this domain when user creates a custom alias # if enabled, do not show this domain when user creates a custom alias
hidden = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") hidden = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0")
@ -3231,31 +3281,6 @@ class ProviderComplaint(Base, ModelMixin):
refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id]) refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id])
class Partner(Base, ModelMixin):
__tablename__ = "partner"
name = sa.Column(sa.String(128), unique=True, nullable=False)
contact_email = sa.Column(sa.String(128), unique=True, nullable=False)
@staticmethod
def find_by_token(token: str) -> Optional[Partner]:
hmaced = PartnerApiToken.hmac_token(token)
res = (
Session.query(Partner, PartnerApiToken)
.filter(
and_(
PartnerApiToken.token == hmaced,
Partner.id == PartnerApiToken.partner_id,
)
)
.first()
)
if res:
partner, partner_api_token = res
return partner
return None
class PartnerApiToken(Base, ModelMixin): class PartnerApiToken(Base, ModelMixin):
__tablename__ = "partner_api_token" __tablename__ = "partner_api_token"

View File

@ -0,0 +1,31 @@
"""empty message
Revision ID: 5f4a5625da66
Revises: 2c2093c82bc0
Create Date: 2023-04-03 18:30:46.488231
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '5f4a5625da66'
down_revision = '2c2093c82bc0'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('public_domain', sa.Column('partner_id', sa.Integer(), nullable=True, sever_default='NULL'))
op.create_foreign_key(None, 'public_domain', 'partner', ['partner_id'], ['id'], ondelete='cascade')
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_constraint(None, 'public_domain', type_='foreignkey')
op.drop_column('public_domain', 'partner_id')
# ### end Alembic commands ###

106
tests/test_domains.py Normal file
View File

@ -0,0 +1,106 @@
from app.db import Session
from app.models import SLDomain, PartnerUser
from app.proton.utils import get_proton_partner
from init_app import add_sl_domains
from tests.utils import create_new_user, random_token
def setup_module():
Session.query(SLDomain).delete()
SLDomain.create(
domain="hidden", premium_only=False, flush=True, order=5, hidden=True
)
SLDomain.create(domain="free_non_partner", premium_only=False, flush=True, order=4)
SLDomain.create(
domain="premium_non_partner", premium_only=True, flush=True, order=3
)
SLDomain.create(
domain="free_partner",
premium_only=False,
flush=True,
partner_id=get_proton_partner().id,
order=2,
)
SLDomain.create(
domain="premium_partner",
premium_only=True,
flush=True,
partner_id=get_proton_partner().id,
order=1,
)
Session.commit()
def teardown_module():
Session.query(SLDomain).delete()
add_sl_domains()
def test_get_non_partner_domains():
user = create_new_user()
domains = user.get_sl_domains()
# Premium
assert len(domains) == 2
assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner"
# Free
user.trial_end = None
Session.flush()
domains = user.get_sl_domains()
assert len(domains) == 1
assert domains[0].domain == "free_non_partner"
def test_get_free_with_partner_domains():
user = create_new_user()
user.trial_end = None
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
domains = user.get_sl_domains()
# Default
assert len(domains) == 1
assert domains[0].domain == "free_non_partner"
# Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
assert len(domains) == 2
assert domains[0].domain == "free_partner"
assert domains[1].domain == "free_non_partner"
# Only partner domains
domains = user.get_sl_domains(
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
)
assert len(domains) == 1
assert domains[0].domain == "free_partner"
def test_get_premium_with_partner_domains():
user = create_new_user()
PartnerUser.create(
partner_id=get_proton_partner().id,
user_id=user.id,
external_user_id=random_token(10),
flush=True,
)
domains = user.get_sl_domains()
# Default
assert len(domains) == 2
assert domains[0].domain == "premium_non_partner"
assert domains[1].domain == "free_non_partner"
# Show partner domains
domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner())
assert len(domains) == 4
assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner"
assert domains[2].domain == "premium_non_partner"
assert domains[3].domain == "free_non_partner"
# Only partner domains
domains = user.get_sl_domains(
show_domains_for_partner=get_proton_partner(), show_sl_domains=False
)
assert len(domains) == 2
assert domains[0].domain == "premium_partner"
assert domains[1].domain == "free_partner"