diff --git a/app/alias_suffix.py b/app/alias_suffix.py index 8d21884a..f5e5b03d 100644 --- a/app/alias_suffix.py +++ b/app/alias_suffix.py @@ -6,8 +6,7 @@ from typing import Optional import itsdangerous from app import config from app.log import LOG -from app.models import User - +from app.models import User, Partner signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET) @@ -87,7 +86,11 @@ def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool: return True -def get_alias_suffixes(user: User) -> [AliasSuffix]: +def get_alias_suffixes( + user: User, + show_domains_for_partner: Optional[Partner] = None, + show_sl_domains: bool = True, +) -> [AliasSuffix]: """ Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up. """ @@ -139,7 +142,7 @@ def get_alias_suffixes(user: User) -> [AliasSuffix]: alias_suffixes.append(alias_suffix) # then SimpleLogin domain - for sl_domain in user.get_sl_domains(): + for sl_domain in user.get_sl_domains(show_domains_for_partner, show_sl_domains): suffix = ( ( "" diff --git a/app/models.py b/app/models.py index 525118fa..e4021dcc 100644 --- a/app/models.py +++ b/app/models.py @@ -18,7 +18,7 @@ from flanker.addresslib import address from flask import url_for from flask_login import UserMixin from jinja2 import FileSystemLoader, Environment -from sqlalchemy import orm +from sqlalchemy import orm, or_ from sqlalchemy import text, desc, CheckConstraint, Index, Column from sqlalchemy.dialects.postgresql import TSVECTOR from sqlalchemy.ext.declarative import declarative_base @@ -967,13 +967,31 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle): """ return [sl_domain.domain for sl_domain in self.get_sl_domains()] - def get_sl_domains(self) -> List["SLDomain"]: - query = SLDomain.filter_by(hidden=False).order_by(SLDomain.order) - - if self.is_premium(): - return query.all() + def get_sl_domains( + self, + show_domains_for_partner: Optional[Partner] = None, + show_sl_domains: bool = True, + ) -> list["SLDomain"]: + conditions = [SLDomain.hidden == False] # noqa: E712 + if not self.is_premium(): + conditions.append(SLDomain.premium_only == False) # noqa: E712 + partner_domain_cond = [] # noqa:E711 + if show_domains_for_partner is not None: + partner_user = PartnerUser.filter_by( + user_id=self.id, partner_id=show_domains_for_partner.id + ).first() + if partner_user is not None: + partner_domain_cond.append( + SLDomain.partner_id == partner_user.partner_id + ) + if show_sl_domains: + partner_domain_cond.append(SLDomain.partner_id == None) # noqa:E711 + if len(partner_domain_cond) == 1: + conditions.append(partner_domain_cond[0]) else: - return query.filter_by(premium_only=False).all() + conditions.append(or_(*partner_domain_cond)) + query = Session.query(SLDomain).filter(*conditions).order_by(SLDomain.order) + return query.all() def available_alias_domains(self) -> [str]: """return all domains that user can use when creating a new alias, including: @@ -2768,6 +2786,31 @@ class Notification(Base, ModelMixin): ) +class Partner(Base, ModelMixin): + __tablename__ = "partner" + + name = sa.Column(sa.String(128), unique=True, nullable=False) + contact_email = sa.Column(sa.String(128), unique=True, nullable=False) + + @staticmethod + def find_by_token(token: str) -> Optional[Partner]: + hmaced = PartnerApiToken.hmac_token(token) + res = ( + Session.query(Partner, PartnerApiToken) + .filter( + and_( + PartnerApiToken.token == hmaced, + Partner.id == PartnerApiToken.partner_id, + ) + ) + .first() + ) + if res: + partner, partner_api_token = res + return partner + return None + + class SLDomain(Base, ModelMixin): """SimpleLogin domains""" @@ -2785,6 +2828,13 @@ class SLDomain(Base, ModelMixin): sa.Boolean, nullable=False, default=False, server_default="0" ) + partner_id = sa.Column( + sa.ForeignKey(Partner.id, ondelete="cascade"), + nullable=True, + default=None, + sever_default="NULL", + ) + # if enabled, do not show this domain when user creates a custom alias hidden = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") @@ -3231,31 +3281,6 @@ class ProviderComplaint(Base, ModelMixin): refused_email = orm.relationship(RefusedEmail, foreign_keys=[refused_email_id]) -class Partner(Base, ModelMixin): - __tablename__ = "partner" - - name = sa.Column(sa.String(128), unique=True, nullable=False) - contact_email = sa.Column(sa.String(128), unique=True, nullable=False) - - @staticmethod - def find_by_token(token: str) -> Optional[Partner]: - hmaced = PartnerApiToken.hmac_token(token) - res = ( - Session.query(Partner, PartnerApiToken) - .filter( - and_( - PartnerApiToken.token == hmaced, - Partner.id == PartnerApiToken.partner_id, - ) - ) - .first() - ) - if res: - partner, partner_api_token = res - return partner - return None - - class PartnerApiToken(Base, ModelMixin): __tablename__ = "partner_api_token" diff --git a/migrations/versions/2023_040318_5f4a5625da66_.py b/migrations/versions/2023_040318_5f4a5625da66_.py new file mode 100644 index 00000000..b912bea5 --- /dev/null +++ b/migrations/versions/2023_040318_5f4a5625da66_.py @@ -0,0 +1,31 @@ +"""empty message + +Revision ID: 5f4a5625da66 +Revises: 2c2093c82bc0 +Create Date: 2023-04-03 18:30:46.488231 + +""" +import sqlalchemy_utils +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = '5f4a5625da66' +down_revision = '2c2093c82bc0' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('public_domain', sa.Column('partner_id', sa.Integer(), nullable=True, sever_default='NULL')) + op.create_foreign_key(None, 'public_domain', 'partner', ['partner_id'], ['id'], ondelete='cascade') + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_constraint(None, 'public_domain', type_='foreignkey') + op.drop_column('public_domain', 'partner_id') + # ### end Alembic commands ### diff --git a/tests/test_domains.py b/tests/test_domains.py new file mode 100644 index 00000000..07e73cc4 --- /dev/null +++ b/tests/test_domains.py @@ -0,0 +1,106 @@ +from app.db import Session +from app.models import SLDomain, PartnerUser +from app.proton.utils import get_proton_partner +from init_app import add_sl_domains +from tests.utils import create_new_user, random_token + + +def setup_module(): + Session.query(SLDomain).delete() + SLDomain.create( + domain="hidden", premium_only=False, flush=True, order=5, hidden=True + ) + SLDomain.create(domain="free_non_partner", premium_only=False, flush=True, order=4) + SLDomain.create( + domain="premium_non_partner", premium_only=True, flush=True, order=3 + ) + SLDomain.create( + domain="free_partner", + premium_only=False, + flush=True, + partner_id=get_proton_partner().id, + order=2, + ) + SLDomain.create( + domain="premium_partner", + premium_only=True, + flush=True, + partner_id=get_proton_partner().id, + order=1, + ) + Session.commit() + + +def teardown_module(): + Session.query(SLDomain).delete() + add_sl_domains() + + +def test_get_non_partner_domains(): + user = create_new_user() + domains = user.get_sl_domains() + # Premium + assert len(domains) == 2 + assert domains[0].domain == "premium_non_partner" + assert domains[1].domain == "free_non_partner" + # Free + user.trial_end = None + Session.flush() + domains = user.get_sl_domains() + assert len(domains) == 1 + assert domains[0].domain == "free_non_partner" + + +def test_get_free_with_partner_domains(): + user = create_new_user() + user.trial_end = None + PartnerUser.create( + partner_id=get_proton_partner().id, + user_id=user.id, + external_user_id=random_token(10), + flush=True, + ) + domains = user.get_sl_domains() + # Default + assert len(domains) == 1 + assert domains[0].domain == "free_non_partner" + # Show partner domains + domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner()) + assert len(domains) == 2 + assert domains[0].domain == "free_partner" + assert domains[1].domain == "free_non_partner" + # Only partner domains + domains = user.get_sl_domains( + show_domains_for_partner=get_proton_partner(), show_sl_domains=False + ) + assert len(domains) == 1 + assert domains[0].domain == "free_partner" + + +def test_get_premium_with_partner_domains(): + user = create_new_user() + PartnerUser.create( + partner_id=get_proton_partner().id, + user_id=user.id, + external_user_id=random_token(10), + flush=True, + ) + domains = user.get_sl_domains() + # Default + assert len(domains) == 2 + assert domains[0].domain == "premium_non_partner" + assert domains[1].domain == "free_non_partner" + # Show partner domains + domains = user.get_sl_domains(show_domains_for_partner=get_proton_partner()) + assert len(domains) == 4 + assert domains[0].domain == "premium_partner" + assert domains[1].domain == "free_partner" + assert domains[2].domain == "premium_non_partner" + assert domains[3].domain == "free_non_partner" + # Only partner domains + domains = user.get_sl_domains( + show_domains_for_partner=get_proton_partner(), show_sl_domains=False + ) + assert len(domains) == 2 + assert domains[0].domain == "premium_partner" + assert domains[1].domain == "free_partner"