diff --git a/app/admin_model.py b/app/admin_model.py index 17c0e01e..d1437352 100644 --- a/app/admin_model.py +++ b/app/admin_model.py @@ -5,7 +5,7 @@ from flask_admin.actions import action from flask_admin.contrib import sqla from flask_login import current_user, login_user -from app.extensions import db +from app.db import Session from app.models import User, ManualSubscription @@ -99,7 +99,7 @@ class UserAdmin(SLModelView): "Extend trial for 1 week more?", ) def extend_trial_1w(self, ids): - for user in User.query.filter(User.id.in_(ids)): + for user in User.filter(User.id.in_(ids)): if user.trial_end and user.trial_end > arrow.now(): user.trial_end = user.trial_end.shift(weeks=1) else: @@ -107,7 +107,7 @@ class UserAdmin(SLModelView): flash(f"Extend trial for {user} to {user.trial_end}", "success") - db.session.commit() + Session.commit() @action( "disable_otp", @@ -115,12 +115,12 @@ class UserAdmin(SLModelView): "Disable OTP?", ) def disable_otp(self, ids): - for user in User.query.filter(User.id.in_(ids)): + for user in User.filter(User.id.in_(ids)): if user.enable_otp: user.enable_otp = False flash(f"Disable OTP for {user}", "info") - db.session.commit() + Session.commit() @action( "login_as", @@ -132,16 +132,14 @@ class UserAdmin(SLModelView): flash("only 1 user can be selected", "error") return - for user in User.query.filter(User.id.in_(ids)): + for user in User.filter(User.id.in_(ids)): login_user(user) flash(f"Login as user {user}", "success") return redirect("/") def manual_upgrade(way: str, ids: [int], is_giveaway: bool): - query = User.query.filter(User.id.in_(ids)) - - for user in query.all(): + for user in User.filter(User.id.in_(ids)).all(): manual_sub: ManualSubscription = ManualSubscription.get_by(user_id=user.id) if manual_sub: # renew existing subscription @@ -149,7 +147,7 @@ def manual_upgrade(way: str, ids: [int], is_giveaway: bool): manual_sub.end_at = manual_sub.end_at.shift(years=1) else: manual_sub.end_at = arrow.now().shift(years=1, days=1) - db.session.commit() + Session.commit() flash(f"Subscription extended to {manual_sub.end_at.humanize()}", "success") continue @@ -211,11 +209,11 @@ class ManualSubscriptionAdmin(SLModelView): "Extend 1 year more?", ) def extend_1y(self, ids): - for ms in ManualSubscription.query.filter(ManualSubscription.id.in_(ids)): + for ms in ManualSubscription.filter(ManualSubscription.id.in_(ids)): ms.end_at = ms.end_at.shift(years=1) flash(f"Extend subscription for {ms.user}", "success") - db.session.commit() + Session.commit() class ClientAdmin(SLModelView): diff --git a/app/alias_utils.py b/app/alias_utils.py index a5f1a7d5..0c9e4501 100644 --- a/app/alias_utils.py +++ b/app/alias_utils.py @@ -1,10 +1,11 @@ -import re2 as re from typing import Optional +import re2 as re from email_validator import validate_email, EmailNotValidError from sqlalchemy.exc import IntegrityError, DataError from app.config import BOUNCE_PREFIX_FOR_REPLY_PHASE +from app.db import Session from app.email_utils import ( get_email_domain_part, send_cannot_create_directory_alias, @@ -14,7 +15,6 @@ from app.email_utils import ( get_email_local_part, ) from app.errors import AliasInTrashError -from app.extensions import db from app.log import LOG from app.models import ( Alias, @@ -97,14 +97,14 @@ def try_auto_create_directory(address: str) -> Optional[Alias]: mailbox_id=mailboxes[0].id, note=f"Created by directory {directory.name}", ) - db.session.flush() + Session.flush() for i in range(1, len(mailboxes)): AliasMailbox.create( alias_id=alias.id, mailbox_id=mailboxes[i].id, ) - db.session.commit() + Session.commit() return alias except AliasInTrashError: LOG.w( @@ -116,7 +116,7 @@ def try_auto_create_directory(address: str) -> Optional[Alias]: return None except IntegrityError: LOG.w("Alias %s already exists", address) - db.session.rollback() + Session.rollback() alias = Alias.get_by(email=address) return alias @@ -173,13 +173,13 @@ def try_auto_create_via_domain(address: str) -> Optional[Alias]: mailbox_id=mailboxes[0].id, note=alias_note, ) - db.session.flush() + Session.flush() for i in range(1, len(mailboxes)): AliasMailbox.create( alias_id=alias.id, mailbox_id=mailboxes[i].id, ) - db.session.commit() + Session.commit() return alias except AliasInTrashError: LOG.w( @@ -191,12 +191,12 @@ def try_auto_create_via_domain(address: str) -> Optional[Alias]: return None except IntegrityError: LOG.w("Alias %s already exists", address) - db.session.rollback() + Session.rollback() alias = Alias.get_by(email=address) return alias except DataError: LOG.w("Cannot create alias %s", address) - db.session.rollback() + Session.rollback() return None @@ -211,30 +211,30 @@ def delete_alias(alias: Alias, user: User): email=alias.email, domain_id=alias.custom_domain_id ): LOG.d("add %s to domain %s trash", alias, alias.custom_domain_id) - db.session.add( + Session.add( DomainDeletedAlias( user_id=user.id, email=alias.email, domain_id=alias.custom_domain_id ) ) - db.session.commit() + Session.commit() else: if not DeletedAlias.get_by(email=alias.email): LOG.d("add %s to global trash", alias) - db.session.add(DeletedAlias(email=alias.email)) - db.session.commit() + Session.add(DeletedAlias(email=alias.email)) + Session.commit() - Alias.query.filter(Alias.id == alias.id).delete() - db.session.commit() + Alias.filter(Alias.id == alias.id).delete() + Session.commit() def aliases_for_mailbox(mailbox: Mailbox) -> [Alias]: """ get list of aliases for a given mailbox """ - ret = set(Alias.query.filter(Alias.mailbox_id == mailbox.id).all()) + ret = set(Alias.filter(Alias.mailbox_id == mailbox.id).all()) for alias in ( - db.session.query(Alias) + Session.query(Alias) .join(AliasMailbox, Alias.id == AliasMailbox.alias_id) .filter(AliasMailbox.mailbox_id == mailbox.id) ): @@ -247,7 +247,7 @@ def nb_email_log_for_mailbox(mailbox: Mailbox): aliases = aliases_for_mailbox(mailbox) alias_ids = [alias.id for alias in aliases] return ( - db.session.query(EmailLog) + Session.query(EmailLog) .join(Contact, EmailLog.contact_id == Contact.id) .filter(Contact.alias_id.in_(alias_ids)) .count() diff --git a/app/api/base.py b/app/api/base.py index 6138464f..dc1c86af 100644 --- a/app/api/base.py +++ b/app/api/base.py @@ -4,7 +4,7 @@ import arrow from flask import Blueprint, request, jsonify, g from flask_login import current_user -from app.extensions import db +from app.db import Session from app.models import ApiKey api_bp = Blueprint(name="api", import_name=__name__, url_prefix="/api") @@ -26,7 +26,7 @@ def require_api_auth(f): # Update api key stats api_key.last_used = arrow.now() api_key.times += 1 - db.session.commit() + Session.commit() g.user = api_key.user diff --git a/app/api/serializer.py b/app/api/serializer.py index 2128a05e..e59776b7 100644 --- a/app/api/serializer.py +++ b/app/api/serializer.py @@ -6,7 +6,7 @@ from sqlalchemy import or_, func, case, and_ from sqlalchemy.orm import joinedload from app.config import PAGE_LIMIT -from app.extensions import db +from app.db import Session from app.models import ( Alias, Contact, @@ -117,7 +117,7 @@ def serialize_contact(contact: Contact, existed=False) -> dict: def get_alias_infos_with_pagination(user, page_id=0, query=None) -> [AliasInfo]: ret = [] q = ( - db.session.query(Alias) + Session.query(Alias) .options(joinedload(Alias.mailbox)) .filter(Alias.user_id == user.id) .order_by(Alias.created_at.desc()) @@ -221,7 +221,7 @@ def get_alias_infos_with_pagination_v3( def get_alias_info(alias: Alias) -> AliasInfo: q = ( - db.session.query(Contact, EmailLog) + Session.query(Contact, EmailLog) .filter(Contact.alias_id == alias.id) .filter(EmailLog.contact_id == Contact.id) ) @@ -251,7 +251,7 @@ def get_alias_info_v2(alias: Alias, mailbox=None) -> AliasInfo: mailbox = alias.mailbox q = ( - db.session.query(Contact, EmailLog) + Session.query(Contact, EmailLog) .filter(Contact.alias_id == alias.id) .filter(EmailLog.contact_id == Contact.id) ) @@ -297,7 +297,7 @@ def get_alias_info_v2(alias: Alias, mailbox=None) -> AliasInfo: def get_alias_contacts(alias, page_id: int) -> [dict]: q = ( - Contact.query.filter_by(alias_id=alias.id) + Contact.filter_by(alias_id=alias.id) .order_by(Contact.id.desc()) .limit(PAGE_LIMIT) .offset(page_id * PAGE_LIMIT) @@ -332,7 +332,7 @@ def get_alias_info_v3(user: User, alias_id: int) -> AliasInfo: def construct_alias_query(user: User): # subquery on alias annotated with nb_reply, nb_blocked, nb_forward, max_created_at, latest_email_log_created_at alias_activity_subquery = ( - db.session.query( + Session.query( Alias.id, func.sum(case([(EmailLog.is_reply, 1)], else_=0)).label("nb_reply"), func.sum( @@ -364,7 +364,7 @@ def construct_alias_query(user: User): ) alias_contact_subquery = ( - db.session.query(Alias.id, func.max(Contact.id).label("max_contact_id")) + Session.query(Alias.id, func.max(Contact.id).label("max_contact_id")) .join(Contact, Alias.id == Contact.alias_id, isouter=True) .filter(Alias.user_id == user.id) .group_by(Alias.id) @@ -372,7 +372,7 @@ def construct_alias_query(user: User): ) return ( - db.session.query( + Session.query( Alias, Contact, EmailLog, diff --git a/app/api/views/alias.py b/app/api/views/alias.py index 64151bb7..666cafb2 100644 --- a/app/api/views/alias.py +++ b/app/api/views/alias.py @@ -17,10 +17,10 @@ from app.api.serializer import ( get_alias_infos_with_pagination_v3, ) from app.dashboard.views.alias_log import get_alias_log +from app.db import Session from app.email_utils import ( generate_reply_email, ) -from app.extensions import db from app.log import LOG from app.models import Alias, Contact, Mailbox, AliasMailbox from app.utils import sanitize_email @@ -164,7 +164,7 @@ def toggle_alias(alias_id): return jsonify(error="Forbidden"), 403 alias.enabled = not alias.enabled - db.session.commit() + Session.commit() return jsonify(enabled=alias.enabled), 200 @@ -280,8 +280,8 @@ def update_alias(alias_id): # <<< update alias mailboxes >>> # first remove all existing alias-mailboxes links - AliasMailbox.query.filter_by(alias_id=alias.id).delete() - db.session.flush() + AliasMailbox.filter_by(alias_id=alias.id).delete() + Session.flush() # then add all new mailboxes for i, mailbox in enumerate(mailboxes): @@ -310,7 +310,7 @@ def update_alias(alias_id): changed = True if changed: - db.session.commit() + Session.commit() return jsonify(ok=True), 200 @@ -422,7 +422,7 @@ def create_contact_route(alias_id): ) LOG.d("create reverse-alias for %s %s", contact_addr, alias) - db.session.commit() + Session.commit() return jsonify(**serialize_contact(contact)), 201 @@ -444,6 +444,6 @@ def delete_contact(contact_id): return jsonify(error="Forbidden"), 403 Contact.delete(contact_id) - db.session.commit() + Session.commit() return jsonify(deleted=True), 200 diff --git a/app/api/views/alias_options.py b/app/api/views/alias_options.py index a140b309..3312922a 100644 --- a/app/api/views/alias_options.py +++ b/app/api/views/alias_options.py @@ -5,7 +5,7 @@ from app.api.base import api_bp, require_api_auth from app.dashboard.views.custom_alias import ( get_available_suffixes, ) -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import AliasUsedOn, Alias, User from app.utils import convert_to_id @@ -43,7 +43,7 @@ def options_v4(): if hostname: # put the latest used alias first q = ( - db.session.query(AliasUsedOn, Alias, User) + Session.query(AliasUsedOn, Alias, User) .filter( AliasUsedOn.alias_id == Alias.id, Alias.user_id == user.id, @@ -114,7 +114,7 @@ def options_v5(): if hostname: # put the latest used alias first q = ( - db.session.query(AliasUsedOn, Alias, User) + Session.query(AliasUsedOn, Alias, User) .filter( AliasUsedOn.alias_id == Alias.id, Alias.user_id == user.id, diff --git a/app/api/views/apple.py b/app/api/views/apple.py index 563c9d0b..a8e6c589 100644 --- a/app/api/views/apple.py +++ b/app/api/views/apple.py @@ -9,7 +9,7 @@ from requests import RequestException from app.api.base import api_bp, require_api_auth from app.config import APPLE_API_SECRET, MACAPP_APPLE_API_SECRET -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import PlanEnum, AppleSubscription @@ -279,7 +279,7 @@ def apple_update_notification(): apple_sub.receipt_data = data["unified_receipt"]["latest_receipt"] apple_sub.expires_date = expires_date apple_sub.plan = plan - db.session.commit() + Session.commit() return jsonify(ok=True), 200 else: LOG.w( @@ -544,6 +544,6 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]: plan=plan, ) - db.session.commit() + Session.commit() return apple_sub diff --git a/app/api/views/auth.py b/app/api/views/auth.py index 4b7a072f..a29f85b0 100644 --- a/app/api/views/auth.py +++ b/app/api/views/auth.py @@ -11,13 +11,14 @@ from app import email_utils from app.api.base import api_bp from app.config import FLASK_SECRET, DISABLE_REGISTRATION from app.dashboard.views.setting import send_reset_password_email +from app.db import Session from app.email_utils import ( email_can_be_used_as_mailbox, personal_email_already_used, send_email, render, ) -from app.extensions import db, limiter +from app.extensions import limiter from app.log import LOG from app.models import User, ApiKey, SocialAuth, AccountActivation from app.utils import sanitize_email @@ -98,12 +99,12 @@ def auth_register(): LOG.d("create user %s", email) user = User.create(email=email, name="", password=password) - db.session.flush() + Session.flush() # create activation code code = "".join([str(random.randint(0, 9)) for _ in range(6)]) AccountActivation.create(user_id=user.id, code=code) - db.session.commit() + Session.commit() send_email( email, @@ -155,13 +156,13 @@ def auth_activate(): if account_activation.code != code: # decrement nb tries account_activation.tries -= 1 - db.session.commit() + Session.commit() # Trigger rate limiter g.deduct_limit = True if account_activation.tries == 0: AccountActivation.delete(account_activation.id) - db.session.commit() + Session.commit() return jsonify(error="Too many wrong tries"), 410 return jsonify(error="Wrong email or code"), 400 @@ -169,7 +170,7 @@ def auth_activate(): LOG.d("activate user %s", user) user.activated = True AccountActivation.delete(account_activation.id) - db.session.commit() + Session.commit() return jsonify(msg="Account is activated, user can login now"), 200 @@ -198,12 +199,12 @@ def auth_reactivate(): account_activation = AccountActivation.get_by(user_id=user.id) if account_activation: AccountActivation.delete(account_activation.id) - db.session.commit() + Session.commit() # create activation code code = "".join([str(random.randint(0, 9)) for _ in range(6)]) AccountActivation.create(user_id=user.id, code=code) - db.session.commit() + Session.commit() send_email( email, @@ -255,12 +256,12 @@ def auth_facebook(): LOG.d("create facebook user with %s", user_info) user = User.create(email=email, name=user_info["name"], activated=True) - db.session.commit() + Session.commit() email_utils.send_welcome_email(user) if not SocialAuth.get_by(user_id=user.id, social="facebook"): SocialAuth.create(user_id=user.id, social="facebook") - db.session.commit() + Session.commit() return jsonify(**auth_payload(user, device)), 200 @@ -308,12 +309,12 @@ def auth_google(): LOG.d("create Google user with %s", user_info) user = User.create(email=email, name="", activated=True) - db.session.commit() + Session.commit() email_utils.send_welcome_email(user) if not SocialAuth.get_by(user_id=user.id, social="google"): SocialAuth.create(user_id=user.id, social="google") - db.session.commit() + Session.commit() return jsonify(**auth_payload(user, device)), 200 @@ -331,7 +332,7 @@ def auth_payload(user, device) -> dict: if not api_key: LOG.d("create new api key for %s and %s", user, device) api_key = ApiKey.create(user.id, device) - db.session.commit() + Session.commit() ret["mfa_key"] = None ret["api_key"] = api_key.code diff --git a/app/api/views/auth_mfa.py b/app/api/views/auth_mfa.py index d0297970..ebefc992 100644 --- a/app/api/views/auth_mfa.py +++ b/app/api/views/auth_mfa.py @@ -5,7 +5,7 @@ from itsdangerous import Signer from app.api.base import api_bp from app.config import FLASK_SECRET -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import User, ApiKey @@ -61,7 +61,7 @@ def auth_mfa(): if not api_key: LOG.d("create new api key for %s and %s", user, device) api_key = ApiKey.create(user.id, device) - db.session.commit() + Session.commit() ret["api_key"] = api_key.code diff --git a/app/api/views/custom_domain.py b/app/api/views/custom_domain.py index 7eab2638..d7a73a86 100644 --- a/app/api/views/custom_domain.py +++ b/app/api/views/custom_domain.py @@ -2,7 +2,7 @@ from flask import g, request from flask import jsonify from app.api.base import api_bp, require_api_auth -from app.extensions import db +from app.db import Session from app.models import CustomDomain, DomainDeletedAlias, Mailbox, DomainMailbox @@ -108,8 +108,8 @@ def update_custom_domain(custom_domain_id): mailboxes.append(mailbox) # first remove all existing domain-mailboxes links - DomainMailbox.query.filter_by(domain_id=custom_domain.id).delete() - db.session.flush() + DomainMailbox.filter_by(domain_id=custom_domain.id).delete() + Session.flush() for mailbox in mailboxes: DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id) @@ -117,6 +117,6 @@ def update_custom_domain(custom_domain_id): changed = True if changed: - db.session.commit() + Session.commit() return jsonify(ok=True), 200 diff --git a/app/api/views/mailbox.py b/app/api/views/mailbox.py index 82d75b0d..87f9dc79 100644 --- a/app/api/views/mailbox.py +++ b/app/api/views/mailbox.py @@ -7,12 +7,12 @@ from flask import request from app.api.base import api_bp, require_api_auth from app.dashboard.views.mailbox import send_verification_email from app.dashboard.views.mailbox_detail import verify_mailbox_change +from app.db import Session from app.email_utils import ( mailbox_already_used, email_can_be_used_as_mailbox, is_valid_email, ) -from app.extensions import db from app.models import Mailbox from app.utils import sanitize_email @@ -58,7 +58,7 @@ def create_mailbox(): ) else: new_mailbox = Mailbox.create(email=mailbox_email, user_id=user.id) - db.session.commit() + Session.commit() send_verification_email(user, new_mailbox) @@ -89,7 +89,7 @@ def delete_mailbox(mailbox_id): return jsonify(error="You cannot delete the default mailbox"), 400 Mailbox.delete(mailbox_id) - db.session.commit() + Session.commit() return jsonify(deleted=True), 200 @@ -158,7 +158,7 @@ def update_mailbox(mailbox_id): changed = True if changed: - db.session.commit() + Session.commit() return jsonify(updated=True), 200 @@ -190,7 +190,7 @@ def get_mailboxes_v2(): user = g.user mailboxes = [] - for mailbox in Mailbox.query.filter_by(user_id=user.id): + for mailbox in Mailbox.filter_by(user_id=user.id): mailboxes.append(mailbox) return ( diff --git a/app/api/views/new_custom_alias.py b/app/api/views/new_custom_alias.py index 5148da72..24b1c328 100644 --- a/app/api/views/new_custom_alias.py +++ b/app/api/views/new_custom_alias.py @@ -10,7 +10,8 @@ from app.api.serializer import ( ) from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT from app.dashboard.views.custom_alias import verify_prefix_suffix, signer -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.log import LOG from app.models import ( Alias, @@ -108,11 +109,11 @@ def new_custom_alias_v2(): custom_domain_id=custom_domain_id, ) - db.session.commit() + Session.commit() if hostname: AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id) - db.session.commit() + Session.commit() return ( jsonify(alias=full_alias, **serialize_alias_info_v2(get_alias_info_v2(alias))), @@ -217,7 +218,7 @@ def new_custom_alias_v3(): mailbox_id=mailboxes[0].id, custom_domain_id=custom_domain_id, ) - db.session.flush() + Session.flush() for i in range(1, len(mailboxes)): AliasMailbox.create( @@ -225,11 +226,11 @@ def new_custom_alias_v3(): mailbox_id=mailboxes[i].id, ) - db.session.commit() + Session.commit() if hostname: AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id) - db.session.commit() + Session.commit() return ( jsonify(alias=full_alias, **serialize_alias_info_v2(get_alias_info_v2(alias))), diff --git a/app/api/views/new_random_alias.py b/app/api/views/new_random_alias.py index 19d2e3bd..aef5fd08 100644 --- a/app/api/views/new_random_alias.py +++ b/app/api/views/new_random_alias.py @@ -7,7 +7,8 @@ from app.api.serializer import ( serialize_alias_info_v2, ) from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.log import LOG from app.models import Alias, AliasUsedOn, AliasGeneratorEnum @@ -51,12 +52,12 @@ def new_random_alias(): return jsonify(error=f"{mode} must be either word or uuid"), 400 alias = Alias.create_new_random(user=user, scheme=scheme, note=note) - db.session.commit() + Session.commit() hostname = request.args.get("hostname") if hostname: AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id) - db.session.commit() + Session.commit() return ( jsonify(alias=alias.email, **serialize_alias_info_v2(get_alias_info_v2(alias))), diff --git a/app/api/views/notification.py b/app/api/views/notification.py index b1cddae5..67013147 100644 --- a/app/api/views/notification.py +++ b/app/api/views/notification.py @@ -4,7 +4,7 @@ from flask import request from app.api.base import api_bp, require_api_auth from app.config import PAGE_LIMIT -from app.extensions import db +from app.db import Session from app.models import Notification @@ -32,7 +32,7 @@ def get_notifications(): return jsonify(error="page must be provided in request query"), 400 notifications = ( - Notification.query.filter_by(user_id=user.id) + Notification.filter_by(user_id=user.id) .order_by(Notification.read, Notification.created_at.desc()) .limit(PAGE_LIMIT + 1) # load a record more to know whether there's more .offset(page * PAGE_LIMIT) @@ -76,6 +76,6 @@ def mark_as_read(notification_id): return jsonify(error="Forbidden"), 403 notification.read = True - db.session.commit() + Session.commit() return jsonify(done=True), 200 diff --git a/app/api/views/setting.py b/app/api/views/setting.py index a209265e..33eb3a27 100644 --- a/app/api/views/setting.py +++ b/app/api/views/setting.py @@ -2,7 +2,7 @@ import arrow from flask import jsonify, g, request from app.api.base import api_bp, require_api_auth -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import ( User, @@ -93,7 +93,7 @@ def update_setting(): user.default_alias_custom_domain_id = custom_domain.id user.default_alias_public_domain_id = None - db.session.commit() + Session.commit() return jsonify(setting_to_dict(user)) diff --git a/app/api/views/user_info.py b/app/api/views/user_info.py index 1ac69585..1a2a5053 100644 --- a/app/api/views/user_info.py +++ b/app/api/views/user_info.py @@ -7,7 +7,7 @@ from flask_login import logout_user from app import s3 from app.api.base import api_bp, require_api_auth from app.config import SESSION_COOKIE_NAME -from app.extensions import db +from app.db import Session from app.models import ApiKey, File, User from app.utils import random_string @@ -56,24 +56,24 @@ def update_user_info(): if user.profile_picture_id: file = user.profile_picture user.profile_picture_id = None - db.session.flush() + Session.flush() if file: File.delete(file.id) s3.delete(file.path) - db.session.flush() + Session.flush() else: raw_data = base64.decodebytes(data["profile_picture"].encode()) file_path = random_string(30) file = File.create(user_id=user.id, path=file_path) - db.session.flush() + Session.flush() s3.upload_from_bytesio(file_path, BytesIO(raw_data)) user.profile_picture_id = file.id - db.session.flush() + Session.flush() if "name" in data: user.name = data["name"] - db.session.commit() + Session.commit() return jsonify(user_to_dict(user)) @@ -95,7 +95,7 @@ def create_api_key(): device = data.get("device") api_key = ApiKey.create(user_id=g.user.id, name=device) - db.session.commit() + Session.commit() return jsonify(api_key=api_key.code), 201 diff --git a/app/auth/views/activate.py b/app/auth/views/activate.py index e9943f99..96eb18e9 100644 --- a/app/auth/views/activate.py +++ b/app/auth/views/activate.py @@ -3,7 +3,8 @@ from flask_login import login_user, current_user from app import email_utils from app.auth.base import auth_bp -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.log import LOG from app.models import ActivationCode @@ -50,7 +51,7 @@ def activate(): # activation code is to be used only once ActivationCode.delete(activation_code.id) - db.session.commit() + Session.commit() flash("Your account has been activated", "success") diff --git a/app/auth/views/change_email.py b/app/auth/views/change_email.py index 803c26f9..544b6fc9 100644 --- a/app/auth/views/change_email.py +++ b/app/auth/views/change_email.py @@ -2,7 +2,7 @@ from flask import request, flash, render_template, redirect, url_for from flask_login import login_user from app.auth.base import auth_bp -from app.extensions import db +from app.db import Session from app.models import EmailChange @@ -18,14 +18,14 @@ def change_email(): if email_change.is_expired(): # delete the expired email EmailChange.delete(email_change.id) - db.session.commit() + Session.commit() return render_template("auth/change_email.html") user = email_change.user user.email = email_change.new_email EmailChange.delete(email_change.id) - db.session.commit() + Session.commit() flash("Your new email has been updated", "success") diff --git a/app/auth/views/facebook.py b/app/auth/views/facebook.py index 66d2d695..9ea5eb57 100644 --- a/app/auth/views/facebook.py +++ b/app/auth/views/facebook.py @@ -9,7 +9,7 @@ from app.config import ( FACEBOOK_CLIENT_ID, FACEBOOK_CLIENT_SECRET, ) -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import User, SocialAuth from .login_utils import after_login @@ -102,7 +102,7 @@ def facebook_callback(): LOG.d("set user profile picture to %s", picture_url) file = create_file_from_url(user, picture_url) user.profile_picture_id = file.id - db.session.commit() + Session.commit() else: flash( @@ -122,6 +122,6 @@ def facebook_callback(): if not SocialAuth.get_by(user_id=user.id, social="facebook"): SocialAuth.create(user_id=user.id, social="facebook") - db.session.commit() + Session.commit() return after_login(user, next_url) diff --git a/app/auth/views/fido.py b/app/auth/views/fido.py index aa1ed063..4258870f 100644 --- a/app/auth/views/fido.py +++ b/app/auth/views/fido.py @@ -19,7 +19,8 @@ from wtforms import HiddenField, validators, BooleanField from app.auth.base import auth_bp from app.config import MFA_USER_ID from app.config import RP_ID, URL -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.log import LOG from app.models import User, Fido, MfaBrowser @@ -102,7 +103,7 @@ def fido(): auto_activate = False else: user.fido_sign_count = new_sign_count - db.session.commit() + Session.commit() del session[MFA_USER_ID] login_user(user) @@ -113,7 +114,7 @@ def fido(): if fido_token_form.remember.data: browser = MfaBrowser.create_new(user=user) - db.session.commit() + Session.commit() response.set_cookie( "mfa", value=browser.token, diff --git a/app/auth/views/github.py b/app/auth/views/github.py index 99a93dde..abc2df2b 100644 --- a/app/auth/views/github.py +++ b/app/auth/views/github.py @@ -4,7 +4,7 @@ from requests_oauthlib import OAuth2Session from app.auth.base import auth_bp from app.auth.views.login_utils import after_login from app.config import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET, URL -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import User, SocialAuth from app.utils import encode_url, sanitize_email @@ -94,7 +94,7 @@ def github_callback(): if not SocialAuth.get_by(user_id=user.id, social="github"): SocialAuth.create(user_id=user.id, social="github") - db.session.commit() + Session.commit() # The activation link contains the original page, for ex authorize page next_url = request.args.get("next") if request.args else None diff --git a/app/auth/views/google.py b/app/auth/views/google.py index 08dddc3d..25f45d3c 100644 --- a/app/auth/views/google.py +++ b/app/auth/views/google.py @@ -4,7 +4,7 @@ from requests_oauthlib import OAuth2Session from app import s3 from app.auth.base import auth_bp from app.config import URL, GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import User, File, SocialAuth from app.utils import random_string, sanitize_email @@ -89,7 +89,7 @@ def google_callback(): LOG.d("set user profile picture to %s", picture_url) file = create_file_from_url(user, picture_url) user.profile_picture_id = file.id - db.session.commit() + Session.commit() else: flash( "Sorry you cannot sign up via Google, please use email/password sign-up instead", @@ -108,7 +108,7 @@ def google_callback(): if not SocialAuth.get_by(user_id=user.id, social="google"): SocialAuth.create(user_id=user.id, social="google") - db.session.commit() + Session.commit() return after_login(user, next_url) @@ -119,7 +119,7 @@ def create_file_from_url(user, url) -> File: s3.upload_from_url(url, file_path) - db.session.flush() + Session.flush() LOG.d("upload file %s to s3", file) return file diff --git a/app/auth/views/mfa.py b/app/auth/views/mfa.py index a9bacec1..b6430d44 100644 --- a/app/auth/views/mfa.py +++ b/app/auth/views/mfa.py @@ -15,7 +15,8 @@ from wtforms import BooleanField, StringField, validators from app.auth.base import auth_bp from app.config import MFA_USER_ID, URL -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.models import User, MfaBrowser @@ -67,7 +68,7 @@ def mfa(): if totp.verify(token) and user.last_otp != token: del session[MFA_USER_ID] user.last_otp = token - db.session.commit() + Session.commit() login_user(user) flash(f"Welcome back!", "success") @@ -77,7 +78,7 @@ def mfa(): if otp_token_form.remember.data: browser = MfaBrowser.create_new(user=user) - db.session.commit() + Session.commit() response.set_cookie( "mfa", value=browser.token, diff --git a/app/auth/views/recovery.py b/app/auth/views/recovery.py index fac50932..3d2a4249 100644 --- a/app/auth/views/recovery.py +++ b/app/auth/views/recovery.py @@ -6,7 +6,8 @@ from wtforms import StringField, validators from app.auth.base import auth_bp from app.config import MFA_USER_ID -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.log import LOG from app.models import User, RecoveryCode @@ -54,7 +55,7 @@ def recovery_route(): recovery_code.used = True recovery_code.used_at = arrow.now() - db.session.commit() + Session.commit() # User comes to login page from another page if next_url: diff --git a/app/auth/views/register.py b/app/auth/views/register.py index 30714cdf..07c2ff1d 100644 --- a/app/auth/views/register.py +++ b/app/auth/views/register.py @@ -8,11 +8,11 @@ from app import email_utils, config from app.auth.base import auth_bp from app.auth.views.login_utils import get_referral from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY +from app.db import Session from app.email_utils import ( email_can_be_used_as_mailbox, personal_email_already_used, ) -from app.extensions import db from app.log import LOG from app.models import User, ActivationCode from app.utils import random_string, encode_url, sanitize_email @@ -81,7 +81,7 @@ def register(): password=form.password.data, referral=get_referral(), ) - db.session.commit() + Session.commit() try: send_activation_email(user, next_url) @@ -102,7 +102,7 @@ def register(): def send_activation_email(user, next_url): # the activation code is valid for 1h activation = ActivationCode.create(user_id=user.id, code=random_string(30)) - db.session.commit() + Session.commit() # Send user activation email activation_link = f"{URL}/auth/activate?code={activation.code}" diff --git a/app/auth/views/reset_password.py b/app/auth/views/reset_password.py index acec963d..9aa201bb 100644 --- a/app/auth/views/reset_password.py +++ b/app/auth/views/reset_password.py @@ -6,7 +6,8 @@ from wtforms import StringField, validators from app.auth.base import auth_bp from app.auth.views.login_utils import after_login -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.models import ResetPasswordCode @@ -64,7 +65,7 @@ def reset_password(): # change the alternative_id to log user out on other browsers user.alternative_id = str(uuid.uuid4()) - db.session.commit() + Session.commit() # do not use login_user(user) here # to make sure user needs to go through MFA if enabled diff --git a/app/dashboard/views/alias_contact_manager.py b/app/dashboard/views/alias_contact_manager.py index c9e644b1..1f4aeafe 100644 --- a/app/dashboard/views/alias_contact_manager.py +++ b/app/dashboard/views/alias_contact_manager.py @@ -10,12 +10,12 @@ from wtforms import StringField, validators, ValidationError from app.config import PAGE_LIMIT from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import ( is_valid_email, generate_reply_email, parse_full_address, ) -from app.extensions import db from app.log import LOG from app.models import Alias, Contact, EmailLog @@ -64,7 +64,7 @@ def get_contact_infos( ) -> [ContactInfo]: """if contact_id is set, only return the contact info for this contact""" sub = ( - db.session.query( + Session.query( Contact.id, func.sum(case([(EmailLog.is_reply, 1)], else_=0)).label("nb_reply"), func.sum( @@ -94,7 +94,7 @@ def get_contact_infos( ) q = ( - db.session.query( + Session.query( Contact, EmailLog, sub.c.nb_reply, @@ -221,7 +221,7 @@ def alias_contact_manager(alias_id): ) LOG.d("create reverse-alias for %s", contact_addr) - db.session.commit() + Session.commit() flash(f"Reverse alias for {contact_addr} is created", "success") return redirect( @@ -248,7 +248,7 @@ def alias_contact_manager(alias_id): delete_contact_email = contact.website_email Contact.delete(contact_id) - db.session.commit() + Session.commit() flash( f"Reverse-alias for {delete_contact_email} has been deleted", "success" diff --git a/app/dashboard/views/alias_log.py b/app/dashboard/views/alias_log.py index bf06c1ea..06708446 100644 --- a/app/dashboard/views/alias_log.py +++ b/app/dashboard/views/alias_log.py @@ -4,7 +4,7 @@ from flask_login import login_required, current_user from app.config import PAGE_LIMIT from app.dashboard.base import dashboard_bp -from app.extensions import db +from app.db import Session from app.models import Alias, EmailLog, Contact @@ -43,7 +43,7 @@ def alias_log(alias_id, page_id): logs = get_alias_log(alias, page_id) base = ( - db.session.query(Contact, EmailLog) + Session.query(Contact, EmailLog) .filter(Contact.id == EmailLog.contact_id) .filter(Contact.alias_id == alias.id) ) @@ -66,7 +66,7 @@ def get_alias_log(alias: Alias, page_id=0) -> [AliasLog]: logs: [AliasLog] = [] q = ( - db.session.query(Contact, EmailLog) + Session.query(Contact, EmailLog) .filter(Contact.id == EmailLog.contact_id) .filter(Contact.alias_id == alias.id) .order_by(EmailLog.id.desc()) diff --git a/app/dashboard/views/alias_transfer.py b/app/dashboard/views/alias_transfer.py index 800243e3..aec5aae5 100644 --- a/app/dashboard/views/alias_transfer.py +++ b/app/dashboard/views/alias_transfer.py @@ -5,8 +5,9 @@ from flask_login import login_required, current_user from app.config import URL from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import send_email, render -from app.extensions import db, limiter +from app.extensions import limiter from app.log import LOG from app.models import ( Alias, @@ -25,20 +26,20 @@ def transfer(alias, new_user, new_mailboxes: [Mailbox]): raise Exception("Cannot transfer alias that's used to receive newsletter") # update user_id - db.session.query(Contact).filter(Contact.alias_id == alias.id).update( + Session.query(Contact).filter(Contact.alias_id == alias.id).update( {"user_id": new_user.id} ) - db.session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update( + Session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update( {"user_id": new_user.id} ) - db.session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update( + Session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update( {"user_id": new_user.id} ) # remove existing mailboxes from the alias - db.session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete() + Session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete() # set mailboxes alias.mailbox_id = new_mailboxes.pop().id @@ -71,7 +72,7 @@ def transfer(alias, new_user, new_mailboxes: [Mailbox]): alias.disable_pgp = False alias.pinned = False - db.session.commit() + Session.commit() @dashboard_bp.route("/alias_transfer/send//", methods=["GET", "POST"]) @@ -100,7 +101,7 @@ def alias_transfer_send_route(alias_id): if request.method == "POST": if request.form.get("form-name") == "create": alias.transfer_token = str(uuid4()) - db.session.commit() + Session.commit() alias_transfer_url = ( URL + "/dashboard/alias_transfer/receive" @@ -111,7 +112,7 @@ def alias_transfer_send_route(alias_id): # request.form.get("form-name") == "remove" else: alias.transfer_token = None - db.session.commit() + Session.commit() alias_transfer_url = None flash("Share URL deleted", "success") return redirect(request.url) diff --git a/app/dashboard/views/api_key.py b/app/dashboard/views/api_key.py index e1628b9d..4850c2db 100644 --- a/app/dashboard/views/api_key.py +++ b/app/dashboard/views/api_key.py @@ -4,7 +4,7 @@ from flask_wtf import FlaskForm from wtforms import StringField, validators from app.dashboard.base import dashboard_bp -from app.extensions import db +from app.db import Session from app.models import ApiKey @@ -16,7 +16,7 @@ class NewApiKeyForm(FlaskForm): @login_required def api_key(): api_keys = ( - ApiKey.query.filter(ApiKey.user_id == current_user.id) + ApiKey.filter(ApiKey.user_id == current_user.id) .order_by(ApiKey.created_at.desc()) .all() ) @@ -38,7 +38,7 @@ def api_key(): name = api_key.name ApiKey.delete(api_key_id) - db.session.commit() + Session.commit() flash(f"API Key {name} has been deleted", "success") return redirect(url_for("dashboard.api_key")) @@ -48,7 +48,7 @@ def api_key(): new_api_key = ApiKey.create( name=new_api_key_form.name.data, user_id=current_user.id ) - db.session.commit() + Session.commit() flash(f"New API Key {new_api_key.name} has been created", "success") return redirect(url_for("dashboard.api_key")) diff --git a/app/dashboard/views/app.py b/app/dashboard/views/app.py index 0c881ee5..6d4f913a 100644 --- a/app/dashboard/views/app.py +++ b/app/dashboard/views/app.py @@ -1,3 +1,5 @@ +from app.db import Session + """ List of apps that user has used via the "Sign in with SimpleLogin" """ @@ -7,7 +9,6 @@ from flask_login import login_required, current_user from sqlalchemy.orm import joinedload from app.dashboard.base import dashboard_bp -from app.extensions import db from app.models import ( ClientUser, ) @@ -36,7 +37,7 @@ def app_route(): client = client_user.client ClientUser.delete(client_user_id) - db.session.commit() + Session.commit() flash(f"Link with {client.name} has been removed", "success") return redirect(request.url) diff --git a/app/dashboard/views/batch_import.py b/app/dashboard/views/batch_import.py index 343a05f4..3a3cfb13 100644 --- a/app/dashboard/views/batch_import.py +++ b/app/dashboard/views/batch_import.py @@ -5,7 +5,7 @@ from flask_login import login_required, current_user from app import s3 from app.config import JOB_BATCH_IMPORT from app.dashboard.base import dashboard_bp -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import File, BatchImport, Job from app.utils import random_string @@ -18,7 +18,7 @@ def batch_import_route(): if not current_user.verified_custom_domains(): flash("Alias batch import is only available for custom domains", "warning") - batch_imports = BatchImport.query.filter_by(user_id=current_user.id).all() + batch_imports = BatchImport.filter_by(user_id=current_user.id).all() if request.method == "POST": alias_file = request.files["alias-file"] @@ -26,11 +26,11 @@ def batch_import_route(): file_path = random_string(20) + ".csv" file = File.create(user_id=current_user.id, path=file_path) s3.upload_from_bytesio(file_path, alias_file) - db.session.flush() + Session.flush() LOG.d("upload file %s to s3 at %s", file, file_path) bi = BatchImport.create(user_id=current_user.id, file_id=file.id) - db.session.flush() + Session.flush() LOG.d("Add a batch import job %s for %s", bi, current_user) # Schedule batch import job @@ -39,7 +39,7 @@ def batch_import_route(): payload={"batch_import_id": bi.id}, run_at=arrow.now(), ) - db.session.commit() + Session.commit() flash( "The file has been uploaded successfully and the import will start shortly", diff --git a/app/dashboard/views/billing.py b/app/dashboard/views/billing.py index d4cb9297..a0fa87ee 100644 --- a/app/dashboard/views/billing.py +++ b/app/dashboard/views/billing.py @@ -3,7 +3,7 @@ from flask_login import login_required, current_user from app.config import PADDLE_MONTHLY_PRODUCT_ID, PADDLE_YEARLY_PRODUCT_ID from app.dashboard.base import dashboard_bp -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import Subscription, PlanEnum from app.paddle_utils import cancel_subscription, change_plan @@ -26,7 +26,7 @@ def billing(): if success: sub.cancelled = True - db.session.commit() + Session.commit() flash("Your subscription has been canceled successfully", "success") else: flash( @@ -44,7 +44,7 @@ def billing(): if success: sub.plan = PlanEnum.monthly - db.session.commit() + Session.commit() flash("Your subscription has been updated", "success") else: if msg: @@ -65,7 +65,7 @@ def billing(): if success: sub.plan = PlanEnum.yearly - db.session.commit() + Session.commit() flash("Your subscription has been updated", "success") else: if msg: diff --git a/app/dashboard/views/contact_detail.py b/app/dashboard/views/contact_detail.py index 919b5723..5fc40fb1 100644 --- a/app/dashboard/views/contact_detail.py +++ b/app/dashboard/views/contact_detail.py @@ -2,7 +2,7 @@ from flask import render_template, request, redirect, url_for, flash from flask_login import login_required, current_user from app.dashboard.base import dashboard_bp -from app.extensions import db +from app.db import Session from app.models import Contact from app.pgp_utils import PGPException, load_public_key_and_check @@ -34,7 +34,7 @@ def contact_detail_route(contact_id): except PGPException: flash("Cannot add the public key, please verify it", "error") else: - db.session.commit() + Session.commit() flash( f"PGP public key for {contact.email} is saved successfully", "success", @@ -46,7 +46,7 @@ def contact_detail_route(contact_id): # Free user can decide to remove contact PGP key contact.pgp_public_key = None contact.pgp_finger_print = None - db.session.commit() + Session.commit() flash(f"PGP public key for {contact.email} is removed", "success") return redirect( url_for("dashboard.contact_detail_route", contact_id=contact_id) diff --git a/app/dashboard/views/coupon.py b/app/dashboard/views/coupon.py index 7b3228fd..0dde5a4d 100644 --- a/app/dashboard/views/coupon.py +++ b/app/dashboard/views/coupon.py @@ -6,8 +6,8 @@ from wtforms import StringField, validators from app.config import ADMIN_EMAIL from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import send_email -from app.extensions import db from app.models import ( ManualSubscription, Coupon, @@ -57,7 +57,7 @@ def coupon_route(): if coupon and not coupon.used: coupon.used_by_user_id = current_user.id coupon.used = True - db.session.commit() + Session.commit() manual_sub: ManualSubscription = ManualSubscription.get_by( user_id=current_user.id @@ -68,7 +68,7 @@ def coupon_route(): manual_sub.end_at = manual_sub.end_at.shift(years=coupon.nb_year) else: manual_sub.end_at = arrow.now().shift(years=coupon.nb_year, days=1) - db.session.commit() + Session.commit() flash( f"Your current subscription is extended to {manual_sub.end_at.humanize()}", "success", diff --git a/app/dashboard/views/custom_alias.py b/app/dashboard/views/custom_alias.py index a6eb5066..d869004f 100644 --- a/app/dashboard/views/custom_alias.py +++ b/app/dashboard/views/custom_alias.py @@ -13,7 +13,8 @@ from app.config import ( ALIAS_LIMIT, ) from app.dashboard.base import dashboard_bp -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.log import LOG from app.models import ( Alias, @@ -307,10 +308,10 @@ def custom_alias(): mailbox_id=mailboxes[0].id, custom_domain_id=custom_domain_id, ) - db.session.flush() + Session.flush() except IntegrityError: LOG.w("Alias %s already exists", full_alias) - db.session.rollback() + Session.rollback() flash("Unknown error, please retry", "error") return redirect(url_for("dashboard.custom_alias")) @@ -320,7 +321,7 @@ def custom_alias(): mailbox_id=mailboxes[i].id, ) - db.session.commit() + Session.commit() flash(f"Alias {full_alias} has been created", "success") return redirect(url_for("dashboard.index", highlight_alias_id=alias.id)) diff --git a/app/dashboard/views/custom_domain.py b/app/dashboard/views/custom_domain.py index a386b50e..cf8a8edb 100644 --- a/app/dashboard/views/custom_domain.py +++ b/app/dashboard/views/custom_domain.py @@ -5,8 +5,8 @@ from wtforms import StringField, validators from app.config import EMAIL_SERVERS_WITH_PRIORITY from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import get_email_domain_part -from app.extensions import db from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain @@ -19,7 +19,7 @@ class NewCustomDomainForm(FlaskForm): @dashboard_bp.route("/custom_domain", methods=["GET", "POST"]) @login_required def custom_domain(): - custom_domains = CustomDomain.query.filter_by(user_id=current_user.id).all() + custom_domains = CustomDomain.filter_by(user_id=current_user.id).all() mailboxes = current_user.mailboxes() new_custom_domain_form = NewCustomDomainForm() @@ -54,7 +54,7 @@ def custom_domain(): new_custom_domain = CustomDomain.create( domain=new_domain, user_id=current_user.id ) - db.session.commit() + Session.commit() mailbox_ids = request.form.getlist("mailbox_ids") if mailbox_ids: @@ -76,7 +76,7 @@ def custom_domain(): domain_id=new_custom_domain.id, mailbox_id=mailbox.id ) - db.session.commit() + Session.commit() flash( f"New domain {new_custom_domain.domain} is created", "success" diff --git a/app/dashboard/views/directory.py b/app/dashboard/views/directory.py index b4a8d94b..41d26b67 100644 --- a/app/dashboard/views/directory.py +++ b/app/dashboard/views/directory.py @@ -10,7 +10,7 @@ from app.config import ( BOUNCE_PREFIX_FOR_REPLY_PHASE, ) from app.dashboard.base import dashboard_bp -from app.extensions import db +from app.db import Session from app.models import Directory, Mailbox, DirectoryMailbox @@ -24,7 +24,7 @@ class NewDirForm(FlaskForm): @login_required def directory(): dirs = ( - Directory.query.filter_by(user_id=current_user.id) + Directory.filter_by(user_id=current_user.id) .order_by(Directory.created_at.desc()) .all() ) @@ -47,7 +47,7 @@ def directory(): name = dir.name Directory.delete(dir_id) - db.session.commit() + Session.commit() flash(f"Directory {name} has been deleted", "success") return redirect(url_for("dashboard.directory")) @@ -67,7 +67,7 @@ def directory(): dir.disabled = True flash(f"On-the-fly is disabled for {dir.name}", "warning") - db.session.commit() + Session.commit() return redirect(url_for("dashboard.directory")) @@ -98,13 +98,13 @@ def directory(): return redirect(url_for("dashboard.directory")) # first remove all existing directory-mailboxes links - DirectoryMailbox.query.filter_by(directory_id=dir.id).delete() - db.session.flush() + DirectoryMailbox.filter_by(directory_id=dir.id).delete() + Session.flush() for mailbox in mailboxes: DirectoryMailbox.create(directory_id=dir.id, mailbox_id=mailbox.id) - db.session.commit() + Session.commit() flash(f"Directory {dir.name} has been updated", "success") return redirect(url_for("dashboard.directory")) @@ -141,7 +141,7 @@ def directory(): new_dir = Directory.create( name=new_dir_name, user_id=current_user.id ) - db.session.commit() + Session.commit() mailbox_ids = request.form.getlist("mailbox_ids") if mailbox_ids: # check if mailbox is not tempered with @@ -162,7 +162,7 @@ def directory(): directory_id=new_dir.id, mailbox_id=mailbox.id ) - db.session.commit() + Session.commit() flash(f"Directory {new_dir.name} is created", "success") diff --git a/app/dashboard/views/domain_detail.py b/app/dashboard/views/domain_detail.py index 3d40e686..0bdc45da 100644 --- a/app/dashboard/views/domain_detail.py +++ b/app/dashboard/views/domain_detail.py @@ -1,6 +1,6 @@ -import re2 as re from threading import Thread +import re2 as re from flask import render_template, request, redirect, url_for, flash from flask_login import login_required, current_user from flask_wtf import FlaskForm @@ -8,6 +8,7 @@ from wtforms import StringField, validators, IntegerField from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN from app.dashboard.base import dashboard_bp +from app.db import Session from app.dns_utils import ( get_mx_domains, get_spf_domain, @@ -15,7 +16,6 @@ from app.dns_utils import ( get_cname_record, ) from app.email_utils import send_email -from app.extensions import db from app.log import LOG from app.models import ( CustomDomain, @@ -40,7 +40,7 @@ def domain_detail_dns(custom_domain_id): # generate a domain ownership txt token if needed if not custom_domain.ownership_verified and not custom_domain.ownership_txt_token: custom_domain.ownership_txt_token = random_string(30) - db.session.commit() + Session.commit() spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all" @@ -62,7 +62,7 @@ def domain_detail_dns(custom_domain_id): "success", ) custom_domain.ownership_verified = True - db.session.commit() + Session.commit() return redirect( url_for( "dashboard.domain_detail_dns", @@ -92,7 +92,7 @@ def domain_detail_dns(custom_domain_id): "success", ) custom_domain.verified = True - db.session.commit() + Session.commit() return redirect( url_for( "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id @@ -102,7 +102,7 @@ def domain_detail_dns(custom_domain_id): spf_domains = get_spf_domain(custom_domain.domain) if EMAIL_DOMAIN in spf_domains: custom_domain.spf_verified = True - db.session.commit() + Session.commit() flash("SPF is setup correctly", "success") return redirect( url_for( @@ -111,7 +111,7 @@ def domain_detail_dns(custom_domain_id): ) else: custom_domain.spf_verified = False - db.session.commit() + Session.commit() flash( f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.", "warning", @@ -124,7 +124,7 @@ def domain_detail_dns(custom_domain_id): if dkim_record == dkim_cname: flash("DKIM is setup correctly.", "success") custom_domain.dkim_verified = True - db.session.commit() + Session.commit() return redirect( url_for( @@ -133,7 +133,7 @@ def domain_detail_dns(custom_domain_id): ) else: custom_domain.dkim_verified = False - db.session.commit() + Session.commit() flash("DKIM: the CNAME record is not correctly set", "warning") dkim_ok = False dkim_errors = [dkim_record or "[Empty]"] @@ -142,7 +142,7 @@ def domain_detail_dns(custom_domain_id): txt_records = get_txt_record("_dmarc." + custom_domain.domain) if dmarc_record in txt_records: custom_domain.dmarc_verified = True - db.session.commit() + Session.commit() flash("DMARC is setup correctly", "success") return redirect( url_for( @@ -151,7 +151,7 @@ def domain_detail_dns(custom_domain_id): ) else: custom_domain.dmarc_verified = False - db.session.commit() + Session.commit() flash( "DMARC: The TXT record is not correctly set", "warning", @@ -179,7 +179,7 @@ def domain_detail(custom_domain_id): if request.method == "POST": if request.form.get("form-name") == "switch-catch-all": custom_domain.catch_all = not custom_domain.catch_all - db.session.commit() + Session.commit() if custom_domain.catch_all: flash( @@ -197,14 +197,14 @@ def domain_detail(custom_domain_id): elif request.form.get("form-name") == "set-name": if request.form.get("action") == "save": custom_domain.name = request.form.get("alias-name").replace("\n", "") - db.session.commit() + Session.commit() flash( f"Default alias name for Domain {custom_domain.domain} has been set", "success", ) else: custom_domain.name = None - db.session.commit() + Session.commit() flash( f"Default alias name for Domain {custom_domain.domain} has been removed", "info", @@ -217,7 +217,7 @@ def domain_detail(custom_domain_id): custom_domain.random_prefix_generation = ( not custom_domain.random_prefix_generation ) - db.session.commit() + Session.commit() if custom_domain.random_prefix_generation: flash( @@ -260,13 +260,13 @@ def domain_detail(custom_domain_id): ) # first remove all existing domain-mailboxes links - DomainMailbox.query.filter_by(domain_id=custom_domain.id).delete() - db.session.flush() + DomainMailbox.filter_by(domain_id=custom_domain.id).delete() + Session.flush() for mailbox in mailboxes: DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id) - db.session.commit() + Session.commit() flash(f"{custom_domain.domain} mailboxes has been updated", "success") return redirect( @@ -302,7 +302,7 @@ def delete_domain(custom_domain_id: int): user = custom_domain.user CustomDomain.delete(custom_domain.id) - db.session.commit() + Session.commit() LOG.d("Domain %s deleted", domain_name) @@ -328,7 +328,7 @@ def domain_detail_trash(custom_domain_id): if request.method == "POST": if request.form.get("form-name") == "empty-all": DomainDeletedAlias.filter_by(domain_id=custom_domain.id).delete() - db.session.commit() + Session.commit() flash("All deleted aliases can now be re-created", "success") return redirect( @@ -349,7 +349,7 @@ def domain_detail_trash(custom_domain_id): ) DomainDeletedAlias.delete(deleted_alias.id) - db.session.commit() + Session.commit() flash( f"{deleted_alias.email} can now be re-created", "success", @@ -477,7 +477,7 @@ def domain_detail_auto_create(custom_domain_id): auto_create_rule_id=rule.id, mailbox_id=mailbox.id ) - db.session.commit() + Session.commit() flash("New auto create rule has been created", "success") @@ -502,7 +502,7 @@ def domain_detail_auto_create(custom_domain_id): rule_order = rule.order AutoCreateRule.delete(rule_id) - db.session.commit() + Session.commit() flash(f"Rule #{rule_order} has been deleted", "success") return redirect( url_for( diff --git a/app/dashboard/views/fido_manage.py b/app/dashboard/views/fido_manage.py index 6a089666..1a0aa44b 100644 --- a/app/dashboard/views/fido_manage.py +++ b/app/dashboard/views/fido_manage.py @@ -5,7 +5,7 @@ from wtforms import HiddenField, validators from app.dashboard.base import dashboard_bp from app.dashboard.views.enter_sudo import sudo_required -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import RecoveryCode, Fido @@ -34,7 +34,7 @@ def fido_manage(): return redirect(url_for("dashboard.fido_manage")) Fido.delete(fido_key.id) - db.session.commit() + Session.commit() LOG.d(f"FIDO Key ID={fido_key.id} Removed") flash(f"Key {fido_key.name} successfully unlinked", "success") @@ -42,7 +42,7 @@ def fido_manage(): # Disable FIDO for the user if all keys have been deleted if not Fido.filter_by(uuid=current_user.fido_uuid).all(): current_user.fido_uuid = None - db.session.commit() + Session.commit() # user does not have any 2FA enabled left, delete all recovery codes if not current_user.two_factor_authentication_enabled(): diff --git a/app/dashboard/views/fido_setup.py b/app/dashboard/views/fido_setup.py index 282336f8..76f17fb6 100644 --- a/app/dashboard/views/fido_setup.py +++ b/app/dashboard/views/fido_setup.py @@ -11,7 +11,7 @@ from wtforms import StringField, HiddenField, validators from app.config import RP_ID, URL from app.dashboard.base import dashboard_bp from app.dashboard.views.enter_sudo import sudo_required -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import Fido, RecoveryCode @@ -61,7 +61,7 @@ def fido_setup(): if current_user.fido_uuid is None: current_user.fido_uuid = fido_uuid - db.session.flush() + Session.flush() Fido.create( credential_id=str(fido_credential.credential_id, "utf-8"), @@ -70,14 +70,14 @@ def fido_setup(): sign_count=fido_credential.sign_count, name=fido_token_form.key_name.data, ) - db.session.commit() + Session.commit() LOG.d( f"credential_id={str(fido_credential.credential_id, 'utf-8')} added for {fido_uuid}" ) flash("Security key has been activated", "success") - if not RecoveryCode.query.filter_by(user_id=current_user.id).all(): + if not RecoveryCode.filter_by(user_id=current_user.id).all(): return redirect(url_for("dashboard.recovery_code_route")) else: return redirect(url_for("dashboard.fido_manage")) diff --git a/app/dashboard/views/index.py b/app/dashboard/views/index.py index d119ef2d..60ca050e 100644 --- a/app/dashboard/views/index.py +++ b/app/dashboard/views/index.py @@ -7,7 +7,8 @@ from app import alias_utils from app.api.serializer import get_alias_infos_with_pagination_v3, get_alias_info_v3 from app.config import PAGE_LIMIT, ALIAS_LIMIT from app.dashboard.base import dashboard_bp -from app.extensions import db, limiter +from app.db import Session +from app.extensions import limiter from app.log import LOG from app.models import ( Alias, @@ -26,19 +27,19 @@ class Stats: def get_stats(user: User) -> Stats: - nb_alias = Alias.query.filter_by(user_id=user.id).count() + nb_alias = Alias.filter_by(user_id=user.id).count() nb_forward = ( - db.session.query(EmailLog) + Session.query(EmailLog) .filter_by(user_id=user.id, is_reply=False, blocked=False, bounced=False) .count() ) nb_reply = ( - db.session.query(EmailLog) + Session.query(EmailLog) .filter_by(user_id=user.id, is_reply=True, blocked=False, bounced=False) .count() ) nb_block = ( - db.session.query(EmailLog) + Session.query(EmailLog) .filter_by(user_id=user.id, is_reply=False, blocked=True, bounced=False) .count() ) @@ -92,7 +93,7 @@ def index(): alias.mailbox_id = current_user.default_mailbox_id - db.session.commit() + Session.commit() LOG.d("create new random alias %s for user %s", alias, current_user) flash(f"Alias {alias.email} has been created", "success") @@ -130,7 +131,7 @@ def index(): flash(f"Alias {email} has been deleted", "success") elif request.form.get("form-name") == "disable-alias": alias.enabled = False - db.session.commit() + Session.commit() flash(f"Alias {alias.email} has been disabled", "success") return redirect( @@ -146,7 +147,7 @@ def index(): # to make sure not showing intro to user again current_user.intro_shown = True - db.session.commit() + Session.commit() stats = get_stats(current_user) diff --git a/app/dashboard/views/lifetime_licence.py b/app/dashboard/views/lifetime_licence.py index 2a2a8130..eac3dc15 100644 --- a/app/dashboard/views/lifetime_licence.py +++ b/app/dashboard/views/lifetime_licence.py @@ -5,8 +5,8 @@ from wtforms import StringField, validators from app.config import ADMIN_EMAIL from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import send_email -from app.extensions import db from app.models import LifetimeCoupon @@ -40,7 +40,7 @@ def lifetime_licence(): current_user.lifetime_coupon_id = coupon.id if coupon.paid: current_user.paid_lifetime = True - db.session.commit() + Session.commit() # notify admin send_email( diff --git a/app/dashboard/views/mailbox.py b/app/dashboard/views/mailbox.py index 2f6fd867..c69d9a55 100644 --- a/app/dashboard/views/mailbox.py +++ b/app/dashboard/views/mailbox.py @@ -9,6 +9,7 @@ from wtforms.fields.html5 import EmailField from app.config import MAILBOX_SECRET, URL from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import ( email_can_be_used_as_mailbox, mailbox_already_used, @@ -16,7 +17,6 @@ from app.email_utils import ( send_email, is_valid_email, ) -from app.extensions import db from app.log import LOG from app.models import Mailbox @@ -31,7 +31,7 @@ class NewMailboxForm(FlaskForm): @login_required def mailbox_route(): mailboxes = ( - Mailbox.query.filter_by(user_id=current_user.id) + Mailbox.filter_by(user_id=current_user.id) .order_by(Mailbox.created_at.desc()) .all() ) @@ -77,7 +77,7 @@ def mailbox_route(): return redirect(url_for("dashboard.mailbox_route")) current_user.default_mailbox_id = mailbox.id - db.session.commit() + Session.commit() flash(f"Mailbox {mailbox.email} is set as Default Mailbox", "success") return redirect(url_for("dashboard.mailbox_route")) @@ -102,7 +102,7 @@ def mailbox_route(): new_mailbox = Mailbox.create( email=mailbox_email, user_id=current_user.id ) - db.session.commit() + Session.commit() send_verification_email(current_user, new_mailbox) @@ -136,7 +136,7 @@ def delete_mailbox(mailbox_id: int): user = mailbox.user Mailbox.delete(mailbox_id) - db.session.commit() + Session.commit() LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email) send_email( @@ -191,7 +191,7 @@ def mailbox_verify(): return redirect(url_for("dashboard.mailbox_route")) mailbox.verified = True - db.session.commit() + Session.commit() LOG.d("Mailbox %s is verified", mailbox) diff --git a/app/dashboard/views/mailbox_detail.py b/app/dashboard/views/mailbox_detail.py index 03a8a265..b7aa658f 100644 --- a/app/dashboard/views/mailbox_detail.py +++ b/app/dashboard/views/mailbox_detail.py @@ -10,9 +10,9 @@ from wtforms.fields.html5 import EmailField from app.config import ENFORCE_SPF, MAILBOX_SECRET from app.config import URL from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import email_can_be_used_as_mailbox from app.email_utils import mailbox_already_used, render, send_email -from app.extensions import db from app.log import LOG from app.models import Alias, AuthorizedAddress from app.models import Mailbox @@ -57,7 +57,7 @@ def mailbox_detail_route(mailbox_id): flash("You cannot use this email address as your mailbox", "error") else: mailbox.new_email = new_email - db.session.commit() + Session.commit() try: verify_mailbox_change(current_user, mailbox, new_email) @@ -82,7 +82,7 @@ def mailbox_detail_route(mailbox_id): mailbox.force_spf = ( True if request.form.get("spf-status") == "on" else False ) - db.session.commit() + Session.commit() flash( "SPF enforcement was " + "enabled" if request.form.get("spf-status") @@ -118,7 +118,7 @@ def mailbox_detail_route(mailbox_id): else: address = authorized_address.email AuthorizedAddress.delete(authorized_address_id) - db.session.commit() + Session.commit() flash(f"{address} has been deleted", "success") return redirect( @@ -140,7 +140,7 @@ def mailbox_detail_route(mailbox_id): except PGPException: flash("Cannot add the public key, please verify it", "error") else: - db.session.commit() + Session.commit() flash("Your PGP public key is saved successfully", "success") return redirect( url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) @@ -150,7 +150,7 @@ def mailbox_detail_route(mailbox_id): mailbox.pgp_public_key = None mailbox.pgp_finger_print = None mailbox.disable_pgp = False - db.session.commit() + Session.commit() flash("Your PGP public key is removed successfully", "success") return redirect( url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) @@ -164,7 +164,7 @@ def mailbox_detail_route(mailbox_id): mailbox.disable_pgp = True flash(f"PGP is disabled on {mailbox.email}", "info") - db.session.commit() + Session.commit() return redirect( url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) ) @@ -180,14 +180,14 @@ def mailbox_detail_route(mailbox_id): ) mailbox.generic_subject = request.form.get("generic-subject") - db.session.commit() + Session.commit() flash("Generic subject for PGP-encrypted email is enabled", "success") return redirect( url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) ) elif request.form.get("action") == "remove": mailbox.generic_subject = None - db.session.commit() + Session.commit() flash("Generic subject for PGP-encrypted email is disabled", "success") return redirect( url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) @@ -236,7 +236,7 @@ def cancel_mailbox_change_route(mailbox_id): if mailbox.new_email: mailbox.new_email = None - db.session.commit() + Session.commit() flash("Your mailbox change is cancelled", "success") return redirect( url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) @@ -274,7 +274,7 @@ def mailbox_confirm_change_route(): # mark mailbox as verified if the change request is sent from an unverified mailbox mailbox.verified = True - db.session.commit() + Session.commit() LOG.d("Mailbox change %s is verified", mailbox) flash(f"The {mailbox.email} is updated", "success") diff --git a/app/dashboard/views/mfa_cancel.py b/app/dashboard/views/mfa_cancel.py index bd9da45a..c6c4964e 100644 --- a/app/dashboard/views/mfa_cancel.py +++ b/app/dashboard/views/mfa_cancel.py @@ -3,7 +3,7 @@ from flask_login import login_required, current_user from app.dashboard.base import dashboard_bp from app.dashboard.views.enter_sudo import sudo_required -from app.extensions import db +from app.db import Session from app.models import RecoveryCode @@ -19,7 +19,7 @@ def mfa_cancel(): if request.method == "POST": current_user.enable_otp = False current_user.otp_secret = None - db.session.commit() + Session.commit() # user does not have any 2FA enabled left, delete all recovery codes if not current_user.two_factor_authentication_enabled(): diff --git a/app/dashboard/views/mfa_setup.py b/app/dashboard/views/mfa_setup.py index 0ec18e1c..b3b7a8fc 100644 --- a/app/dashboard/views/mfa_setup.py +++ b/app/dashboard/views/mfa_setup.py @@ -6,7 +6,7 @@ from wtforms import StringField, validators from app.dashboard.base import dashboard_bp from app.dashboard.views.enter_sudo import sudo_required -from app.extensions import db +from app.db import Session from app.log import LOG @@ -27,7 +27,7 @@ def mfa_setup(): if not current_user.otp_secret: LOG.d("Generate otp_secret for user %s", current_user) current_user.otp_secret = pyotp.random_base32() - db.session.commit() + Session.commit() totp = pyotp.TOTP(current_user.otp_secret) @@ -37,7 +37,7 @@ def mfa_setup(): if totp.verify(token) and current_user.last_otp != token: current_user.enable_otp = True current_user.last_otp = token - db.session.commit() + Session.commit() flash("MFA has been activated", "success") return redirect(url_for("dashboard.recovery_code_route")) diff --git a/app/dashboard/views/recovery_code.py b/app/dashboard/views/recovery_code.py index fef7509c..ace874b8 100644 --- a/app/dashboard/views/recovery_code.py +++ b/app/dashboard/views/recovery_code.py @@ -13,12 +13,12 @@ def recovery_code_route(): flash("you need to enable either TOTP or WebAuthn", "warning") return redirect(url_for("dashboard.index")) - recovery_codes = RecoveryCode.query.filter_by(user_id=current_user.id).all() + recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all() if request.method == "GET" and not recovery_codes: # user arrives at this page for the first time LOG.d("%s has no recovery keys, generate", current_user) RecoveryCode.generate(current_user) - recovery_codes = RecoveryCode.query.filter_by(user_id=current_user.id).all() + recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all() if request.method == "POST": RecoveryCode.generate(current_user) diff --git a/app/dashboard/views/referral.py b/app/dashboard/views/referral.py index a037cb02..b190fd13 100644 --- a/app/dashboard/views/referral.py +++ b/app/dashboard/views/referral.py @@ -1,10 +1,9 @@ import re2 as re - from flask import render_template, request, flash, redirect, url_for from flask_login import login_required, current_user from app.dashboard.base import dashboard_bp -from app.extensions import db +from app.db import Session from app.models import Referral, Payout _REFERRAL_PATTERN = r"[0-9a-z-_]{3,}" @@ -30,7 +29,7 @@ def referral_route(): name = request.form.get("name") referral = Referral.create(user_id=current_user.id, code=code, name=name) - db.session.commit() + Session.commit() flash("A new referral code has been created", "success") return redirect( url_for("dashboard.referral_route", highlight_id=referral.id) @@ -40,7 +39,7 @@ def referral_route(): referral = Referral.get(referral_id) if referral and referral.user_id == current_user.id: referral.name = request.form.get("name") - db.session.commit() + Session.commit() flash("Referral name updated", "success") return redirect( url_for("dashboard.referral_route", highlight_id=referral.id) @@ -50,7 +49,7 @@ def referral_route(): referral = Referral.get(referral_id) if referral and referral.user_id == current_user.id: Referral.delete(referral.id) - db.session.commit() + Session.commit() flash("Referral deleted", "success") return redirect(url_for("dashboard.referral_route")) @@ -59,7 +58,7 @@ def referral_route(): if highlight_id: highlight_id = int(highlight_id) - referrals = Referral.query.filter_by(user_id=current_user.id).all() + referrals = Referral.filter_by(user_id=current_user.id).all() # make sure the highlighted referral is the first referral highlight_index = None for ix, referral in enumerate(referrals): @@ -70,6 +69,6 @@ def referral_route(): if highlight_index: referrals.insert(0, referrals.pop(highlight_index)) - payouts = Payout.query.filter_by(user_id=current_user.id).all() + payouts = Payout.filter_by(user_id=current_user.id).all() return render_template("dashboard/referral.html", **locals()) diff --git a/app/dashboard/views/refused_email.py b/app/dashboard/views/refused_email.py index f3a5cab6..32878ef1 100644 --- a/app/dashboard/views/refused_email.py +++ b/app/dashboard/views/refused_email.py @@ -19,7 +19,7 @@ def refused_email_route(): highlight_id = None email_logs: [EmailLog] = ( - EmailLog.query.filter( + EmailLog.filter( EmailLog.user_id == current_user.id, EmailLog.refused_email_id.isnot(None) ) .order_by(EmailLog.id.desc()) diff --git a/app/dashboard/views/setting.py b/app/dashboard/views/setting.py index 873556af..a59a95c1 100644 --- a/app/dashboard/views/setting.py +++ b/app/dashboard/views/setting.py @@ -22,11 +22,11 @@ from app.config import ( ALIAS_RANDOM_SUFFIX_LENGTH, ) from app.dashboard.base import dashboard_bp +from app.db import Session from app.email_utils import ( email_can_be_used_as_mailbox, personal_email_already_used, ) -from app.extensions import db from app.log import LOG from app.models import ( PlanEnum, @@ -116,7 +116,7 @@ def setting(): "delete the expired email change %s", other_email_change ) EmailChange.delete(other_email_change.id) - db.session.commit() + Session.commit() else: flash( "You cannot use this email address as your personal inbox.", @@ -132,7 +132,7 @@ def setting(): ), # todo: make sure the code is unique new_email=new_email, ) - db.session.commit() + Session.commit() send_change_email_confirmation(current_user, email_change) flash( "A confirmation email is on the way, please check your inbox", @@ -145,7 +145,7 @@ def setting(): # update user info if form.name.data != current_user.name: current_user.name = form.name.data - db.session.commit() + Session.commit() profile_updated = True if form.profile_picture.data: @@ -156,11 +156,11 @@ def setting(): file_path, BytesIO(form.profile_picture.data.read()) ) - db.session.flush() + Session.flush() LOG.d("upload file %s to s3", file) current_user.profile_picture_id = file.id - db.session.commit() + Session.commit() profile_updated = True if profile_updated: @@ -181,7 +181,7 @@ def setting(): current_user.notification = True else: current_user.notification = False - db.session.commit() + Session.commit() flash("Your notification preference has been updated", "success") return redirect(url_for("dashboard.setting")) @@ -212,7 +212,7 @@ def setting(): scheme = int(request.form.get("alias-generator-scheme")) if AliasGeneratorEnum.has_value(scheme): current_user.alias_generator = scheme - db.session.commit() + Session.commit() flash("Your preference has been updated", "success") return redirect(url_for("dashboard.setting")) @@ -249,7 +249,7 @@ def setting(): current_user.default_alias_custom_domain_id = None current_user.default_alias_public_domain_id = None - db.session.commit() + Session.commit() flash("Your preference has been updated", "success") return redirect(url_for("dashboard.setting")) @@ -257,7 +257,7 @@ def setting(): scheme = int(request.form.get("random-alias-suffix-generator")) if AliasSuffixEnum.has_value(scheme): current_user.random_alias_suffix = scheme - db.session.commit() + Session.commit() flash("Your preference has been updated", "success") return redirect(url_for("dashboard.setting")) @@ -266,9 +266,9 @@ def setting(): if SenderFormatEnum.has_value(sender_format): current_user.sender_format = sender_format current_user.sender_format_updated_at = arrow.now() - db.session.commit() + Session.commit() flash("Your sender format preference has been updated", "success") - db.session.commit() + Session.commit() return redirect(url_for("dashboard.setting")) elif request.form.get("form-name") == "replace-ra": @@ -277,7 +277,7 @@ def setting(): current_user.replace_reverse_alias = True else: current_user.replace_reverse_alias = False - db.session.commit() + Session.commit() flash("Your preference has been updated", "success") return redirect(url_for("dashboard.setting")) @@ -287,7 +287,7 @@ def setting(): current_user.include_sender_in_reverse_alias = True else: current_user.include_sender_in_reverse_alias = False - db.session.commit() + Session.commit() flash("Your preference has been updated", "success") return redirect(url_for("dashboard.setting")) @@ -297,7 +297,7 @@ def setting(): current_user.expand_alias_info = True else: current_user.expand_alias_info = False - db.session.commit() + Session.commit() flash("Your preference has been updated", "success") return redirect(url_for("dashboard.setting")) elif request.form.get("form-name") == "ignore-loop-email": @@ -306,7 +306,7 @@ def setting(): current_user.ignore_loop_email = True else: current_user.ignore_loop_email = False - db.session.commit() + Session.commit() flash("Your preference has been updated", "success") return redirect(url_for("dashboard.setting")) @@ -344,7 +344,7 @@ def send_reset_password_email(user): reset_password_code = ResetPasswordCode.create( user_id=user.id, code=random_string(60) ) - db.session.commit() + Session.commit() reset_password_link = f"{URL}/auth/reset_password?code={reset_password_code.code}" @@ -368,7 +368,7 @@ def resend_email_change(): if email_change: # extend email change expiration email_change.expired = arrow.now().shift(hours=12) - db.session.commit() + Session.commit() send_change_email_confirmation(current_user, email_change) flash("A confirmation email is on the way, please check your inbox", "success") @@ -386,7 +386,7 @@ def cancel_email_change(): email_change = EmailChange.get_by(user_id=current_user.id) if email_change: EmailChange.delete(email_change.id) - db.session.commit() + Session.commit() flash("Your email change is cancelled", "success") return redirect(url_for("dashboard.setting")) else: diff --git a/app/dashboard/views/unsubscribe.py b/app/dashboard/views/unsubscribe.py index ca5d859d..269fe53a 100644 --- a/app/dashboard/views/unsubscribe.py +++ b/app/dashboard/views/unsubscribe.py @@ -1,3 +1,5 @@ +from app.db import Session + """ Allow user to "unsubscribe", aka block an email alias """ @@ -6,7 +8,6 @@ from flask import redirect, url_for, flash, request, render_template from flask_login import login_required, current_user from app.dashboard.base import dashboard_bp -from app.extensions import db from app.models import Alias @@ -29,7 +30,7 @@ def unsubscribe(alias_id): if request.method == "POST": alias.enabled = False flash(f"Alias {alias.email} has been blocked", "success") - db.session.commit() + Session.commit() return redirect(url_for("dashboard.index", highlight_alias_id=alias.id)) else: # ask user confirmation diff --git a/app/db.py b/app/db.py new file mode 100644 index 00000000..1c7131c0 --- /dev/null +++ b/app/db.py @@ -0,0 +1,10 @@ +from sqlalchemy import create_engine +from sqlalchemy.orm import scoped_session +from sqlalchemy.orm import sessionmaker + +from app.config import DB_URI + +engine = create_engine(DB_URI) +connection = engine.connect() + +Session = scoped_session(sessionmaker(bind=connection)) diff --git a/app/developer/views/client_detail.py b/app/developer/views/client_detail.py index 4908fd8c..0506f81e 100644 --- a/app/developer/views/client_detail.py +++ b/app/developer/views/client_detail.py @@ -8,9 +8,9 @@ from wtforms import StringField, validators, TextAreaField from app import s3 from app.config import ADMIN_EMAIL +from app.db import Session from app.developer.base import developer_bp from app.email_utils import send_email -from app.extensions import db from app.log import LOG from app.models import Client, RedirectUri, File from app.utils import random_string @@ -55,13 +55,13 @@ def client_detail(client_id): s3.upload_from_bytesio(file_path, BytesIO(form.icon.data.read())) - db.session.flush() + Session.flush() LOG.d("upload file %s to s3", file) client.icon_id = file.id - db.session.flush() + Session.flush() - db.session.commit() + Session.commit() flash(f"{client.name} has been updated", "success") @@ -69,7 +69,7 @@ def client_detail(client_id): if action == "submit" and approval_form.validate_on_submit(): client.description = approval_form.description.data - db.session.commit() + Session.commit() send_email( ADMIN_EMAIL, @@ -127,7 +127,7 @@ def client_detail_oauth_setting(client_id): for uri in uris: RedirectUri.create(client_id=client_id, uri=uri) - db.session.commit() + Session.commit() flash(f"{client.name} has been updated", "success") @@ -178,7 +178,7 @@ def client_detail_advanced(client_id): # delete client client_name = client.name Client.delete(client.id) - db.session.commit() + Session.commit() LOG.d("Remove client %s", client) flash(f"{client_name} has been deleted", "success") diff --git a/app/developer/views/new_client.py b/app/developer/views/new_client.py index 53c06036..0bb913f0 100644 --- a/app/developer/views/new_client.py +++ b/app/developer/views/new_client.py @@ -3,8 +3,8 @@ from flask_login import current_user, login_required from flask_wtf import FlaskForm from wtforms import StringField, validators +from app.db import Session from app.developer.base import developer_bp -from app.extensions import db from app.models import Client @@ -19,7 +19,7 @@ def new_client(): if form.validate_on_submit(): client = Client.create_new(form.name.data, current_user.id) - db.session.commit() + Session.commit() flash("Your app has been created", "success") diff --git a/app/email/rate_limit.py b/app/email/rate_limit.py index 0fec2619..7e78022e 100644 --- a/app/email/rate_limit.py +++ b/app/email/rate_limit.py @@ -5,8 +5,8 @@ from app.config import ( MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS, MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX, ) +from app.db import Session from app.email_utils import is_reply_email -from app.extensions import db from app.log import LOG from app.models import Alias, EmailLog, Contact @@ -16,7 +16,7 @@ def rate_limited_for_alias(alias: Alias) -> bool: # get the nb of activity on this alias nb_activity = ( - db.session.query(EmailLog) + Session.query(EmailLog) .join(Contact, EmailLog.contact_id == Contact.id) .filter( Contact.alias_id == alias.id, @@ -42,7 +42,7 @@ def rate_limited_for_mailbox(alias: Alias) -> bool: # get nb of activity on this mailbox nb_activity = ( - db.session.query(EmailLog) + Session.query(EmailLog) .join(Contact, EmailLog.contact_id == Contact.id) .join(Alias, Contact.alias_id == Alias.id) .filter( diff --git a/app/email_utils.py b/app/email_utils.py index 42d87299..c80aabc0 100644 --- a/app/email_utils.py +++ b/app/email_utils.py @@ -53,9 +53,9 @@ from app.config import ( TEMP_DIR, ALIAS_AUTOMATIC_DISABLE, ) +from app.db import Session from app.dns_utils import get_mx_domains from app.email import headers -from app.extensions import db from app.log import LOG from app.models import ( Mailbox, @@ -324,7 +324,7 @@ def send_email_with_rate_control( to_email = sanitize_email(to_email) min_dt = arrow.now().shift(days=-1 * nb_day) nb_alert = ( - SentAlert.query.filter_by(alert_type=alert_type, to_email=to_email) + SentAlert.filter_by(alert_type=alert_type, to_email=to_email) .filter(SentAlert.created_at > min_dt) .count() ) @@ -340,7 +340,7 @@ def send_email_with_rate_control( return False SentAlert.create(user_id=user.id, alert_type=alert_type, to_email=to_email) - db.session.commit() + Session.commit() if ignore_smtp_error: try: @@ -369,9 +369,7 @@ def send_email_at_most_times( Return true if the email is sent, otherwise False """ to_email = sanitize_email(to_email) - nb_alert = SentAlert.query.filter_by( - alert_type=alert_type, to_email=to_email - ).count() + nb_alert = SentAlert.filter_by(alert_type=alert_type, to_email=to_email).count() if nb_alert >= max_times: LOG.w( @@ -383,7 +381,7 @@ def send_email_at_most_times( return False SentAlert.create(user_id=user.id, alert_type=alert_type, to_email=to_email) - db.session.commit() + Session.commit() send_email(to_email, subject, plaintext, html) return True @@ -1036,7 +1034,7 @@ def should_disable(alias: Alias) -> bool: yesterday = arrow.now().shift(days=-1) nb_bounced_last_24h = ( - db.session.query(EmailLog) + Session.query(EmailLog) .filter( EmailLog.bounced.is_(True), EmailLog.is_reply.is_(False), @@ -1054,7 +1052,7 @@ def should_disable(alias: Alias) -> bool: elif nb_bounced_last_24h > 5: one_week_ago = arrow.now().shift(days=-8) nb_bounced_7d_1d = ( - db.session.query(EmailLog) + Session.query(EmailLog) .filter( EmailLog.bounced.is_(True), EmailLog.is_reply.is_(False), @@ -1075,7 +1073,7 @@ def should_disable(alias: Alias) -> bool: # alias level # if bounces at least 9 days in the last 10 days -> disable alias query = ( - db.session.query( + Session.query( func.date(EmailLog.created_at).label("date"), func.count(EmailLog.id).label("count"), ) @@ -1097,7 +1095,7 @@ def should_disable(alias: Alias) -> bool: # account level query = ( - db.session.query( + Session.query( func.date(EmailLog.created_at).label("date"), func.count(EmailLog.id).label("count"), ) diff --git a/app/import_utils.py b/app/import_utils.py index c8526e67..f1ac8bc5 100644 --- a/app/import_utils.py +++ b/app/import_utils.py @@ -3,8 +3,8 @@ import csv import requests from app import s3 +from app.db import Session from app.email_utils import get_email_domain_part -from app.extensions import db from app.models import ( Alias, AliasMailbox, @@ -23,7 +23,7 @@ def handle_batch_import(batch_import: BatchImport): user = batch_import.user batch_import.processed = True - db.session.commit() + Session.commit() LOG.d("Start batch import for %s %s", batch_import, user) file_url = s3.get_url(batch_import.file.path) @@ -97,5 +97,5 @@ def import_from_csv(batch_import: BatchImport, user: User, lines): AliasMailbox.create( alias_id=alias.id, mailbox_id=mailboxes[i], commit=True ) - db.session.commit() + Session.commit() LOG.d("Add %s to mailbox %s", alias, mailboxes[i]) diff --git a/app/models.py b/app/models.py index 27dcd20c..f1e09f0b 100644 --- a/app/models.py +++ b/app/models.py @@ -10,8 +10,10 @@ from arrow import Arrow from flanker.addresslib import address from flask import url_for from flask_login import UserMixin +from sqlalchemy import orm from sqlalchemy import text, desc, CheckConstraint, Index, Column from sqlalchemy.dialects.postgresql import TSVECTOR +from sqlalchemy.ext.declarative import declarative_base from sqlalchemy.orm import deferred from sqlalchemy_utils import ArrowType @@ -29,8 +31,8 @@ from app.config import ( UNSUBSCRIBER, ALIAS_RANDOM_SUFFIX_LENGTH, ) +from app.db import Session from app.errors import AliasInTrashError -from app.extensions import db from app.log import LOG from app.oauth_models import Scope from app.pw_models import PasswordOracle @@ -42,70 +44,88 @@ from app.utils import ( random_word, ) +Base = declarative_base() + class TSVector(sa.types.TypeDecorator): impl = TSVECTOR class ModelMixin(object): - id = db.Column(db.Integer, primary_key=True, autoincrement=True) - created_at = db.Column(ArrowType, default=arrow.utcnow, nullable=False) - updated_at = db.Column(ArrowType, default=None, onupdate=arrow.utcnow) + id = sa.Column(sa.Integer, primary_key=True, autoincrement=True) + created_at = sa.Column(ArrowType, default=arrow.utcnow, nullable=False) + updated_at = sa.Column(ArrowType, default=None, onupdate=arrow.utcnow) _repr_hide = ["created_at", "updated_at"] @classmethod def query(cls): - return db.session.query(cls) + return Session.query(cls) @classmethod def get(cls, id): - return cls.query.get(id) + return Session.query(cls).get(id) @classmethod def get_by(cls, **kw): - return cls.query.filter_by(**kw).first() + return Session.query(cls).filter_by(**kw).first() @classmethod def filter_by(cls, **kw): - return cls.query.filter_by(**kw) + return Session.query(cls).filter_by(**kw) + + @classmethod + def filter(cls, *args, **kw): + return Session.query(cls).filter(*args, **kw) + + @classmethod + def order_by(cls, *args, **kw): + return Session.query(cls).order_by(*args, **kw) + + @classmethod + def all(cls): + return Session.query(cls).all() + + @classmethod + def count(cls): + return Session.query(cls).count() @classmethod def get_or_create(cls, **kw): r = cls.get_by(**kw) if not r: r = cls(**kw) - db.session.add(r) + Session.add(r) return r @classmethod def create(cls, **kw): - # whether should call db.session.commit + # whether should call Session.commit commit = kw.pop("commit", False) flush = kw.pop("flush", False) r = cls(**kw) - db.session.add(r) + Session.add(r) if commit: - db.session.commit() + Session.commit() if flush: - db.session.flush() + Session.flush() return r def save(self): - db.session.add(self) + Session.add(self) @classmethod def delete(cls, obj_id): - cls.query.filter(cls.id == obj_id).delete() + Session.query(cls).filter(cls.id == obj_id).delete() @classmethod def first(cls): - return cls.query.first() + return Session.query(cls).first() def __repr__(self): values = ", ".join( @@ -116,9 +136,10 @@ class ModelMixin(object): return "%s(%s)" % (self.__class__.__name__, values) -class File(db.Model, ModelMixin): - path = db.Column(db.String(128), unique=True, nullable=False) - user_id = db.Column(db.ForeignKey("users.id", ondelete="cascade"), nullable=True) +class File(Base, ModelMixin): + __tablename__ = "file" + path = sa.Column(sa.String(128), unique=True, nullable=False) + user_id = sa.Column(sa.ForeignKey("users.id", ondelete="cascade"), nullable=True) def get_url(self, expires_in=3600): return s3.get_url(self.path, expires_in) @@ -178,102 +199,102 @@ class AliasSuffixEnum(EnumE): random_string = 1 # Completely random string -class Hibp(db.Model, ModelMixin): +class Hibp(Base, ModelMixin): __tablename__ = "hibp" - name = db.Column(db.String(), nullable=False, unique=True, index=True) - breached_aliases = db.relationship("Alias", secondary="alias_hibp") + name = sa.Column(sa.String(), nullable=False, unique=True, index=True) + breached_aliases = orm.relationship("Alias", secondary="alias_hibp") - description = db.Column(db.Text) - date = db.Column(ArrowType, nullable=True) + description = sa.Column(sa.Text) + date = sa.Column(ArrowType, nullable=True) def __repr__(self): return f"" -class HibpNotifiedAlias(db.Model, ModelMixin): +class HibpNotifiedAlias(Base, ModelMixin): """Contain list of aliases that have been notified to users So that we can only notify users of new aliases. """ __tablename__ = "hibp_notified_alias" - alias_id = db.Column(db.ForeignKey("alias.id", ondelete="cascade"), nullable=False) - user_id = db.Column(db.ForeignKey("users.id", ondelete="cascade"), nullable=False) + alias_id = sa.Column(sa.ForeignKey("alias.id", ondelete="cascade"), nullable=False) + user_id = sa.Column(sa.ForeignKey("users.id", ondelete="cascade"), nullable=False) - notified_at = db.Column(ArrowType, default=arrow.utcnow, nullable=False) + notified_at = sa.Column(ArrowType, default=arrow.utcnow, nullable=False) -class Fido(db.Model, ModelMixin): +class Fido(Base, ModelMixin): __tablename__ = "fido" - credential_id = db.Column(db.String(), nullable=False, unique=True, index=True) - uuid = db.Column( - db.ForeignKey("users.fido_uuid", ondelete="cascade"), + credential_id = sa.Column(sa.String(), nullable=False, unique=True, index=True) + uuid = sa.Column( + sa.ForeignKey("users.fido_uuid", ondelete="cascade"), unique=False, nullable=False, ) - public_key = db.Column(db.String(), nullable=False, unique=True) - sign_count = db.Column(db.Integer(), nullable=False) - name = db.Column(db.String(128), nullable=False, unique=False) + public_key = sa.Column(sa.String(), nullable=False, unique=True) + sign_count = sa.Column(sa.Integer(), nullable=False) + name = sa.Column(sa.String(128), nullable=False, unique=False) -class User(db.Model, ModelMixin, UserMixin, PasswordOracle): +class User(Base, ModelMixin, UserMixin, PasswordOracle): __tablename__ = "users" - email = db.Column(db.String(256), unique=True, nullable=False) + email = sa.Column(sa.String(256), unique=True, nullable=False) - name = db.Column(db.String(128), nullable=True) - is_admin = db.Column(db.Boolean, nullable=False, default=False) - alias_generator = db.Column( - db.Integer, + name = sa.Column(sa.String(128), nullable=True) + is_admin = sa.Column(sa.Boolean, nullable=False, default=False) + alias_generator = sa.Column( + sa.Integer, nullable=False, default=AliasGeneratorEnum.word.value, server_default=str(AliasGeneratorEnum.word.value), ) - notification = db.Column( - db.Boolean, default=True, nullable=False, server_default="1" + notification = sa.Column( + sa.Boolean, default=True, nullable=False, server_default="1" ) - activated = db.Column(db.Boolean, default=False, nullable=False) + activated = sa.Column(sa.Boolean, default=False, nullable=False) # an account can be disabled if having harmful behavior - disabled = db.Column(db.Boolean, default=False, nullable=False, server_default="0") + disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") - profile_picture_id = db.Column(db.ForeignKey(File.id), nullable=True) + profile_picture_id = sa.Column(sa.ForeignKey(File.id), nullable=True) - otp_secret = db.Column(db.String(16), nullable=True) - enable_otp = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + otp_secret = sa.Column(sa.String(16), nullable=True) + enable_otp = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) - last_otp = db.Column(db.String(12), nullable=True, default=False) + last_otp = sa.Column(sa.String(12), nullable=True, default=False) # Fields for WebAuthn - fido_uuid = db.Column(db.String(), nullable=True, unique=True) + fido_uuid = sa.Column(sa.String(), nullable=True, unique=True) # the default domain that's used when user creates a new random alias # default_alias_custom_domain_id XOR default_alias_public_domain_id - default_alias_custom_domain_id = db.Column( - db.ForeignKey("custom_domain.id", ondelete="SET NULL"), + default_alias_custom_domain_id = sa.Column( + sa.ForeignKey("custom_domain.id", ondelete="SET NULL"), nullable=True, default=None, ) - default_alias_public_domain_id = db.Column( - db.ForeignKey("public_domain.id", ondelete="SET NULL"), + default_alias_public_domain_id = sa.Column( + sa.ForeignKey("public_domain.id", ondelete="SET NULL"), nullable=True, default=None, ) # some users could have lifetime premium - lifetime = db.Column(db.Boolean, default=False, nullable=False, server_default="0") - paid_lifetime = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + lifetime = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") + paid_lifetime = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) - lifetime_coupon_id = db.Column( - db.ForeignKey("lifetime_coupon.id", ondelete="SET NULL"), + lifetime_coupon_id = sa.Column( + sa.ForeignKey("lifetime_coupon.id", ondelete="SET NULL"), nullable=True, default=None, ) # user can use all premium features until this date - trial_end = db.Column( + trial_end = sa.Column( ArrowType, default=lambda: arrow.now().shift(days=7, hours=1), nullable=True ) @@ -281,78 +302,78 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): # this field is nullable but in practice, it's always set # it cannot be set to non-nullable though # as this will create foreign key cycle between User and Mailbox - default_mailbox_id = db.Column( - db.ForeignKey("mailbox.id"), nullable=True, default=None + default_mailbox_id = sa.Column( + sa.ForeignKey("mailbox.id"), nullable=True, default=None ) - profile_picture = db.relationship(File, foreign_keys=[profile_picture_id]) + profile_picture = orm.relationship(File, foreign_keys=[profile_picture_id]) # Specify the format for sender address # John Wick - john at wick.com -> 0 # john@wick.com via SimpleLogin -> 1 # John Wick - john(a)wick.com -> 2 # John Wick - john@wick.com -> 3 - sender_format = db.Column( - db.Integer, default="0", nullable=False, server_default="0" + sender_format = sa.Column( + sa.Integer, default="0", nullable=False, server_default="0" ) # to know whether user has explicitly chosen a sender format as opposed to those who use the default ones. # users who haven't chosen a sender format and are using 1 or 3 format, their sender format will be set to 0 - sender_format_updated_at = db.Column(ArrowType, default=None) + sender_format_updated_at = sa.Column(ArrowType, default=None) - replace_reverse_alias = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + replace_reverse_alias = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) - referral_id = db.Column( - db.ForeignKey("referral.id", ondelete="SET NULL"), nullable=True, default=None + referral_id = sa.Column( + sa.ForeignKey("referral.id", ondelete="SET NULL"), nullable=True, default=None ) - referral = db.relationship("Referral", foreign_keys=[referral_id]) + referral = orm.relationship("Referral", foreign_keys=[referral_id]) # whether intro has been shown to user - intro_shown = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + intro_shown = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) - default_mailbox = db.relationship("Mailbox", foreign_keys=[default_mailbox_id]) + default_mailbox = orm.relationship("Mailbox", foreign_keys=[default_mailbox_id]) # user can set a more strict max_spam score to block spams more aggressively - max_spam_score = db.Column(db.Integer, nullable=True) + max_spam_score = sa.Column(sa.Integer, nullable=True) # newsletter is sent to this address - newsletter_alias_id = db.Column( - db.ForeignKey("alias.id", ondelete="SET NULL"), nullable=True, default=None + newsletter_alias_id = sa.Column( + sa.ForeignKey("alias.id", ondelete="SET NULL"), nullable=True, default=None ) # whether to include the sender address in reverse-alias - include_sender_in_reverse_alias = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + include_sender_in_reverse_alias = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) # whether to use random string or random word as suffix # Random word from dictionary file -> 0 # Completely random string -> 1 - random_alias_suffix = db.Column( - db.Integer, + random_alias_suffix = sa.Column( + sa.Integer, nullable=False, default=AliasSuffixEnum.random_string.value, server_default=str(AliasSuffixEnum.random_string.value), ) # always expand the alias info, i.e. without needing to press "More" - expand_alias_info = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + expand_alias_info = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) # ignore emails send from a mailbox to its alias. This can happen when replying all to a forwarded email # can automatically re-includes the alias - ignore_loop_email = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + ignore_loop_email = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) # used for flask-login as an "alternative token" # cf https://flask-login.readthedocs.io/en/latest/#alternative-tokens - alternative_id = db.Column(db.String(128), unique=True, nullable=True) + alternative_id = sa.Column(sa.String(128), unique=True, nullable=True) # implement flask-login "alternative token" def get_id(self): @@ -368,10 +389,10 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): if password: user.set_password(password) - db.session.flush() + Session.flush() mb = Mailbox.create(user_id=user.id, email=user.email, verified=True) - db.session.flush() + Session.flush() user.default_mailbox_id = mb.id # create a first alias mail to show user how to use when they login @@ -382,10 +403,10 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): note="This is your first alias. It's used to receive SimpleLogin communications " "like new features announcements, newsletters.", ) - db.session.flush() + Session.flush() user.newsletter_alias_id = alias.id - db.session.flush() + Session.flush() # generate an alternative_id if needed if "alternative_id" not in kwargs: @@ -411,7 +432,7 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): payload={"user_id": user.id}, run_at=arrow.now().shift(days=3), ) - db.session.flush() + Session.flush() return user @@ -666,19 +687,19 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): return sub def verified_custom_domains(self) -> List["CustomDomain"]: - return CustomDomain.query.filter_by(user_id=self.id, verified=True).all() + return CustomDomain.filter_by(user_id=self.id, verified=True).all() def mailboxes(self) -> List["Mailbox"]: """list of mailbox that user own""" mailboxes = [] - for mailbox in Mailbox.query.filter_by(user_id=self.id, verified=True): + for mailbox in Mailbox.filter_by(user_id=self.id, verified=True): mailboxes.append(mailbox) return mailboxes def nb_directory(self): - return Directory.query.filter_by(user_id=self.id).count() + return Directory.filter_by(user_id=self.id).count() def has_custom_domain(self): return CustomDomain.filter_by(user_id=self.id, verified=True).count() > 0 @@ -731,7 +752,7 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): ) self.default_alias_custom_domain_id = None self.default_alias_public_domain_id = None - db.session.commit() + Session.commit() return FIRST_ALIAS_DOMAIN return sl_domain.domain @@ -780,11 +801,9 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): def get_sl_domains(self) -> List["SLDomain"]: if self.is_premium(): - query = SLDomain.query + return SLDomain.all() else: - query = SLDomain.filter_by(premium_only=False) - - return query.all() + return SLDomain.filter_by(premium_only=False).all() def available_alias_domains(self) -> [str]: """return all domains that user can use when creating a new alias, including: @@ -805,9 +824,9 @@ class User(db.Model, ModelMixin, UserMixin, PasswordOracle): """whether to show the app page""" return ( # when user has used the "Sign in with SL" button before - ClientUser.query.filter(ClientUser.user_id == self.id).count() + ClientUser.filter(ClientUser.user_id == self.id).count() # or when user has created an app - + Client.query.filter(Client.user_id == self.id).count() + + Client.filter(Client.user_id == self.id).count() > 0 ) @@ -842,43 +861,49 @@ def _expiration_7d(): return arrow.now().shift(days=7) -class ActivationCode(db.Model, ModelMixin): +class ActivationCode(Base, ModelMixin): """For activate user account""" - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - code = db.Column(db.String(128), unique=True, nullable=False) + __tablename__ = "activation_code" - user = db.relationship(User) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + code = sa.Column(sa.String(128), unique=True, nullable=False) - expired = db.Column(ArrowType, nullable=False, default=_expiration_1h) + user = orm.relationship(User) + + expired = sa.Column(ArrowType, nullable=False, default=_expiration_1h) def is_expired(self): return self.expired < arrow.now() -class ResetPasswordCode(db.Model, ModelMixin): +class ResetPasswordCode(Base, ModelMixin): """For resetting password""" - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - code = db.Column(db.String(128), unique=True, nullable=False) + __tablename__ = "reset_password_code" - user = db.relationship(User) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + code = sa.Column(sa.String(128), unique=True, nullable=False) - expired = db.Column(ArrowType, nullable=False, default=_expiration_1h) + user = orm.relationship(User) + + expired = sa.Column(ArrowType, nullable=False, default=_expiration_1h) def is_expired(self): return self.expired < arrow.now() -class SocialAuth(db.Model, ModelMixin): +class SocialAuth(Base, ModelMixin): """Store how user authenticates with social login""" - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) + __tablename__ = "social_auth" + + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) # name of the social login used, could be facebook, google or github - social = db.Column(db.String(128), nullable=False) + social = sa.Column(sa.String(128), nullable=False) - __table_args__ = (db.UniqueConstraint("user_id", "social", name="uq_social_auth"),) + __table_args__ = (sa.UniqueConstraint("user_id", "social", name="uq_social_auth"),) # <<< OAUTH models >>> @@ -897,12 +922,14 @@ def generate_oauth_client_id(client_name) -> str: return generate_oauth_client_id(client_name) -class MfaBrowser(db.Model, ModelMixin): - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - token = db.Column(db.String(64), default=False, unique=True, nullable=False) - expires = db.Column(ArrowType, default=False, nullable=False) +class MfaBrowser(Base, ModelMixin): + __tablename__ = "mfa_browser" - user = db.relationship(User) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + token = sa.Column(sa.String(64), default=False, unique=True, nullable=False) + expires = sa.Column(ArrowType, default=False, nullable=False) + + user = orm.relationship(User) @classmethod def create_new(cls, user, token_length=64) -> "MfaBrowser": @@ -921,13 +948,13 @@ class MfaBrowser(db.Model, ModelMixin): @classmethod def delete(cls, token): - cls.query.filter(cls.token == token).delete() - db.session.commit() + cls.filter(cls.token == token).delete() + Session.commit() @classmethod def delete_expired(cls): - cls.query.filter(cls.expires < arrow.now()).delete() - db.session.commit() + cls.filter(cls.expires < arrow.now()).delete() + Session.commit() def is_expired(self): return self.expires < arrow.now() @@ -936,23 +963,24 @@ class MfaBrowser(db.Model, ModelMixin): self.expires = arrow.now().shift(days=30) -class Client(db.Model, ModelMixin): - oauth_client_id = db.Column(db.String(128), unique=True, nullable=False) - oauth_client_secret = db.Column(db.String(128), nullable=False) +class Client(Base, ModelMixin): + __tablename__ = "client" + oauth_client_id = sa.Column(sa.String(128), unique=True, nullable=False) + oauth_client_secret = sa.Column(sa.String(128), nullable=False) - name = db.Column(db.String(128), nullable=False) - home_url = db.Column(db.String(1024)) + name = sa.Column(sa.String(128), nullable=False) + home_url = sa.Column(sa.String(1024)) # user who created this client - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - icon_id = db.Column(db.ForeignKey(File.id), nullable=True) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + icon_id = sa.Column(sa.ForeignKey(File.id), nullable=True) # an app needs to be approved by SimpleLogin team - approved = db.Column(db.Boolean, nullable=False, default=False, server_default="0") - description = db.Column(db.Text, nullable=True) + approved = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") + description = sa.Column(sa.Text, nullable=True) - icon = db.relationship(File) - user = db.relationship(User) + icon = orm.relationship(File) + user = orm.relationship(User) def nb_user(self): return ClientUser.filter_by(client_id=self.id).count() @@ -983,7 +1011,7 @@ class Client(db.Model, ModelMixin): def last_user_login(self) -> "ClientUser": client_user = ( - ClientUser.query.filter(ClientUser.client_id == self.id) + ClientUser.filter(ClientUser.client_id == self.id) .order_by(ClientUser.updated_at) .first() ) @@ -992,52 +1020,58 @@ class Client(db.Model, ModelMixin): return None -class RedirectUri(db.Model, ModelMixin): +class RedirectUri(Base, ModelMixin): """Valid redirect uris for a client""" - client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False) - uri = db.Column(db.String(1024), nullable=False) + __tablename__ = "redirect_uri" - client = db.relationship(Client, backref="redirect_uris") + client_id = sa.Column(sa.ForeignKey(Client.id, ondelete="cascade"), nullable=False) + uri = sa.Column(sa.String(1024), nullable=False) + + client = orm.relationship(Client, backref="redirect_uris") -class AuthorizationCode(db.Model, ModelMixin): - code = db.Column(db.String(128), unique=True, nullable=False) - client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False) - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) +class AuthorizationCode(Base, ModelMixin): + __tablename__ = "authorization_code" - scope = db.Column(db.String(128)) - redirect_uri = db.Column(db.String(1024)) + code = sa.Column(sa.String(128), unique=True, nullable=False) + client_id = sa.Column(sa.ForeignKey(Client.id, ondelete="cascade"), nullable=False) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + + scope = sa.Column(sa.String(128)) + redirect_uri = sa.Column(sa.String(1024)) # what is the input response_type, e.g. "code", "code,id_token", ... - response_type = db.Column(db.String(128)) + response_type = sa.Column(sa.String(128)) - nonce = db.Column(db.Text, nullable=True, default=None, server_default=text("NULL")) + nonce = sa.Column(sa.Text, nullable=True, default=None, server_default=text("NULL")) - user = db.relationship(User, lazy=False) - client = db.relationship(Client, lazy=False) + user = orm.relationship(User, lazy=False) + client = orm.relationship(Client, lazy=False) - expired = db.Column(ArrowType, nullable=False, default=_expiration_5m) + expired = sa.Column(ArrowType, nullable=False, default=_expiration_5m) def is_expired(self): return self.expired < arrow.now() -class OauthToken(db.Model, ModelMixin): - access_token = db.Column(db.String(128), unique=True) - client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False) - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) +class OauthToken(Base, ModelMixin): + __tablename__ = "oauth_token" - scope = db.Column(db.String(128)) - redirect_uri = db.Column(db.String(1024)) + access_token = sa.Column(sa.String(128), unique=True) + client_id = sa.Column(sa.ForeignKey(Client.id, ondelete="cascade"), nullable=False) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + + scope = sa.Column(sa.String(128)) + redirect_uri = sa.Column(sa.String(1024)) # what is the input response_type, e.g. "token", "token,id_token", ... - response_type = db.Column(db.String(128)) + response_type = sa.Column(sa.String(128)) - user = db.relationship(User) - client = db.relationship(Client) + user = orm.relationship(User) + client = orm.relationship(Client) - expired = db.Column(ArrowType, nullable=False, default=_expiration_1h) + expired = sa.Column(ArrowType, nullable=False, default=_expiration_1h) def is_expired(self): return self.expired < arrow.now() @@ -1073,88 +1107,89 @@ def generate_email( return generate_email(scheme=scheme, in_hex=in_hex) -class Alias(db.Model, ModelMixin): - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True +class Alias(Base, ModelMixin): + __tablename__ = "alias" + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True ) - email = db.Column(db.String(128), unique=True, nullable=False) + email = sa.Column(sa.String(128), unique=True, nullable=False) # the name to use when user replies/sends from alias - name = db.Column(db.String(128), nullable=True, default=None) + name = sa.Column(sa.String(128), nullable=True, default=None) - enabled = db.Column(db.Boolean(), default=True, nullable=False) + enabled = sa.Column(sa.Boolean(), default=True, nullable=False) - custom_domain_id = db.Column( - db.ForeignKey("custom_domain.id", ondelete="cascade"), nullable=True + custom_domain_id = sa.Column( + sa.ForeignKey("custom_domain.id", ondelete="cascade"), nullable=True ) - custom_domain = db.relationship("CustomDomain", foreign_keys=[custom_domain_id]) + custom_domain = orm.relationship("CustomDomain", foreign_keys=[custom_domain_id]) # To know whether an alias is created "on the fly", i.e. via the custom domain catch-all feature - automatic_creation = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + automatic_creation = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) # to know whether an alias belongs to a directory - directory_id = db.Column( - db.ForeignKey("directory.id", ondelete="cascade"), nullable=True + directory_id = sa.Column( + sa.ForeignKey("directory.id", ondelete="cascade"), nullable=True ) - note = db.Column(db.Text, default=None, nullable=True) + note = sa.Column(sa.Text, default=None, nullable=True) # an alias can be owned by another mailbox - mailbox_id = db.Column( - db.ForeignKey("mailbox.id", ondelete="cascade"), nullable=False, index=True + mailbox_id = sa.Column( + sa.ForeignKey("mailbox.id", ondelete="cascade"), nullable=False, index=True ) # prefix _ to avoid this object being used accidentally. # To have the list of all mailboxes, should use AliasInfo instead - _mailboxes = db.relationship("Mailbox", secondary="alias_mailbox", lazy="joined") + _mailboxes = orm.relationship("Mailbox", secondary="alias_mailbox", lazy="joined") # If the mailbox has PGP-enabled, user can choose disable the PGP on the alias # this is useful when some senders already support PGP - disable_pgp = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + disable_pgp = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) # a way to bypass the bounce automatic disable mechanism - cannot_be_disabled = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + cannot_be_disabled = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) # when a mailbox wants to send an email on behalf of the alias via the reverse-alias # several checks are performed to avoid email spoofing # this option allow disabling these checks - disable_email_spoofing_check = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + disable_email_spoofing_check = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) # to know whether an alias is added using a batch import - batch_import_id = db.Column( - db.ForeignKey("batch_import.id", ondelete="SET NULL"), + batch_import_id = sa.Column( + sa.ForeignKey("batch_import.id", ondelete="SET NULL"), nullable=True, default=None, ) # set in case of alias transfer. - original_owner_id = db.Column( - db.ForeignKey(User.id, ondelete="SET NULL"), nullable=True + original_owner_id = sa.Column( + sa.ForeignKey(User.id, ondelete="SET NULL"), nullable=True ) # alias is pinned on top - pinned = db.Column(db.Boolean, nullable=False, default=False, server_default="0") + pinned = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") # used to transfer an alias to another user - transfer_token = db.Column(db.String(64), default=None, unique=True, nullable=True) + transfer_token = sa.Column(sa.String(64), default=None, unique=True, nullable=True) # have I been pwned - hibp_last_check = db.Column(ArrowType, default=None) - hibp_breaches = db.relationship("Hibp", secondary="alias_hibp") + hibp_last_check = sa.Column(ArrowType, default=None) + hibp_breaches = orm.relationship("Hibp", secondary="alias_hibp") # to use Postgres full text search. Only applied on "note" column for now # this is a generated Postgres column - ts_vector = db.Column( - TSVector(), db.Computed("to_tsvector('english', note)", persisted=True) + ts_vector = sa.Column( + TSVector(), sa.Computed("to_tsvector('english', note)", persisted=True) ) __table_args__ = ( @@ -1168,8 +1203,8 @@ class Alias(db.Model, ModelMixin): ), ) - user = db.relationship(User, foreign_keys=[user_id]) - mailbox = db.relationship("Mailbox", lazy="joined") + user = orm.relationship(User, foreign_keys=[user_id]) + mailbox = orm.relationship("Mailbox", lazy="joined") @property def mailboxes(self): @@ -1196,7 +1231,7 @@ class Alias(db.Model, ModelMixin): @classmethod def create(cls, **kw): - # whether should call db.session.commit + # whether should call Session.commit commit = kw.pop("commit", False) r = cls(**kw) @@ -1212,9 +1247,9 @@ class Alias(db.Model, ModelMixin): if DomainDeletedAlias.get_by(email=email): raise AliasInTrashError - db.session.add(r) + Session.add(r) if commit: - db.session.commit() + Session.commit() return r @classmethod @@ -1304,31 +1339,32 @@ class Alias(db.Model, ModelMixin): return f"" -class ClientUser(db.Model, ModelMixin): +class ClientUser(Base, ModelMixin): + __tablename__ = "client_user" __table_args__ = ( - db.UniqueConstraint("user_id", "client_id", name="uq_client_user"), + sa.UniqueConstraint("user_id", "client_id", name="uq_client_user"), ) - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - client_id = db.Column(db.ForeignKey(Client.id, ondelete="cascade"), nullable=False) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + client_id = sa.Column(sa.ForeignKey(Client.id, ondelete="cascade"), nullable=False) # Null means client has access to user original email - alias_id = db.Column(db.ForeignKey(Alias.id, ondelete="cascade"), nullable=True) + alias_id = sa.Column(sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=True) # user can decide to send to client another name - name = db.Column( - db.String(128), nullable=True, default=None, server_default=text("NULL") + name = sa.Column( + sa.String(128), nullable=True, default=None, server_default=text("NULL") ) # user can decide to send to client a default avatar - default_avatar = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + default_avatar = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) - alias = db.relationship(Alias, backref="client_users") + alias = orm.relationship(Alias, backref="client_users") - user = db.relationship(User) - client = db.relationship(Client) + user = orm.relationship(User) + client = orm.relationship(Client) def get_email(self): return self.alias.email if self.alias_id else self.user.email @@ -1390,57 +1426,59 @@ class ClientUser(db.Model, ModelMixin): return res -class Contact(db.Model, ModelMixin): +class Contact(Base, ModelMixin): """ Store configuration of sender (website-email) and alias. """ + __tablename__ = "contact" + __table_args__ = ( - db.UniqueConstraint("alias_id", "website_email", name="uq_contact"), + sa.UniqueConstraint("alias_id", "website_email", name="uq_contact"), ) - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True ) - alias_id = db.Column( - db.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True + alias_id = sa.Column( + sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True ) - name = db.Column( - db.String(512), nullable=True, default=None, server_default=text("NULL") + name = sa.Column( + sa.String(512), nullable=True, default=None, server_default=text("NULL") ) - website_email = db.Column(db.String(512), nullable=False) + website_email = sa.Column(sa.String(512), nullable=False) # the email from header, e.g. AB CD # nullable as this field is added after website_email - website_from = db.Column(db.String(1024), nullable=True) + website_from = sa.Column(sa.String(1024), nullable=True) # when user clicks on "reply", they will reply to this address. # This address allows to hide user personal email # this reply email is created every time a website sends an email to user # it has the prefix "reply+" or "ra+" to distinguish with other email - reply_email = db.Column(db.String(512), nullable=False, index=True) + reply_email = sa.Column(sa.String(512), nullable=False, index=True) # whether a contact is created via CC - is_cc = db.Column(db.Boolean, nullable=False, default=False, server_default="0") + is_cc = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") - pgp_public_key = db.Column(db.Text, nullable=True) - pgp_finger_print = db.Column(db.String(512), nullable=True) + pgp_public_key = sa.Column(sa.Text, nullable=True) + pgp_finger_print = sa.Column(sa.String(512), nullable=True) - alias = db.relationship(Alias, backref="contacts") - user = db.relationship(User) + alias = orm.relationship(Alias, backref="contacts") + user = orm.relationship(User) # the latest reply sent to this contact latest_reply: Optional[Arrow] = None # to investigate why the website_email is sometimes not correctly parsed # the envelope mail_from - mail_from = db.Column(db.Text, nullable=True, default=None) + mail_from = sa.Column(sa.Text, nullable=True, default=None) # a contact can have an empty email address, in this case it can't receive emails - invalid_email = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + invalid_email = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) @property @@ -1523,7 +1561,7 @@ class Contact(db.Model, ModelMixin): def last_reply(self) -> "EmailLog": """return the most recent reply""" return ( - EmailLog.query.filter_by(contact_id=self.id, is_reply=True) + EmailLog.filter_by(contact_id=self.id, is_reply=True) .order_by(desc(EmailLog.created_at)) .first() ) @@ -1532,62 +1570,64 @@ class Contact(db.Model, ModelMixin): return f"" -class EmailLog(db.Model, ModelMixin): - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True +class EmailLog(Base, ModelMixin): + __tablename__ = "email_log" + + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True ) - contact_id = db.Column( - db.ForeignKey(Contact.id, ondelete="cascade"), nullable=False, index=True + contact_id = sa.Column( + sa.ForeignKey(Contact.id, ondelete="cascade"), nullable=False, index=True ) - alias_id = db.Column( - db.ForeignKey(Alias.id, ondelete="cascade"), nullable=True, index=True + alias_id = sa.Column( + sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=True, index=True ) # whether this is a reply - is_reply = db.Column(db.Boolean, nullable=False, default=False) + is_reply = sa.Column(sa.Boolean, nullable=False, default=False) # for ex if alias is disabled, this forwarding is blocked - blocked = db.Column(db.Boolean, nullable=False, default=False) + blocked = sa.Column(sa.Boolean, nullable=False, default=False) # can happen when user mailbox refuses the forwarded email # usually because the forwarded email is too spammy - bounced = db.Column(db.Boolean, nullable=False, default=False, server_default="0") + bounced = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") # happen when an email with auto (holiday) reply - auto_replied = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + auto_replied = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) # SpamAssassin result - is_spam = db.Column(db.Boolean, nullable=False, default=False, server_default="0") - spam_score = db.Column(db.Float, nullable=True) - spam_status = db.Column(db.Text, nullable=True, default=None) + is_spam = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") + spam_score = sa.Column(sa.Float, nullable=True) + spam_status = sa.Column(sa.Text, nullable=True, default=None) # do not load this column - spam_report = deferred(db.Column(db.JSON, nullable=True)) + spam_report = deferred(sa.Column(sa.JSON, nullable=True)) # Point to the email that has been refused - refused_email_id = db.Column( - db.ForeignKey("refused_email.id", ondelete="SET NULL"), nullable=True + refused_email_id = sa.Column( + sa.ForeignKey("refused_email.id", ondelete="SET NULL"), nullable=True ) # in forward phase, this is the mailbox that will receive the email # in reply phase, this is the mailbox (or a mailbox's authorized address) that sends the email - mailbox_id = db.Column( - db.ForeignKey("mailbox.id", ondelete="cascade"), nullable=True + mailbox_id = sa.Column( + sa.ForeignKey("mailbox.id", ondelete="cascade"), nullable=True ) # in case of bounce, record on what mailbox the email has been bounced # useful when an alias has several mailboxes - bounced_mailbox_id = db.Column( - db.ForeignKey("mailbox.id", ondelete="cascade"), nullable=True + bounced_mailbox_id = sa.Column( + sa.ForeignKey("mailbox.id", ondelete="cascade"), nullable=True ) - refused_email = db.relationship("RefusedEmail") - forward = db.relationship(Contact) + refused_email = orm.relationship("RefusedEmail") + forward = orm.relationship(Contact) - contact = db.relationship(Contact, backref="email_logs") - mailbox = db.relationship("Mailbox", lazy="joined", foreign_keys=[mailbox_id]) - user = db.relationship(User) + contact = orm.relationship(Contact, backref="email_logs") + mailbox = orm.relationship("Mailbox", lazy="joined", foreign_keys=[mailbox_id]) + user = orm.relationship(User) def bounced_mailbox(self) -> str: if self.bounced_mailbox_id: @@ -1616,25 +1656,27 @@ class EmailLog(db.Model, ModelMixin): return f"" -class Subscription(db.Model, ModelMixin): +class Subscription(Base, ModelMixin): """Paddle subscription""" + __tablename__ = "subscription" + # Come from Paddle - cancel_url = db.Column(db.String(1024), nullable=False) - update_url = db.Column(db.String(1024), nullable=False) - subscription_id = db.Column(db.String(1024), nullable=False, unique=True) - event_time = db.Column(ArrowType, nullable=False) - next_bill_date = db.Column(db.Date, nullable=False) + cancel_url = sa.Column(sa.String(1024), nullable=False) + update_url = sa.Column(sa.String(1024), nullable=False) + subscription_id = sa.Column(sa.String(1024), nullable=False, unique=True) + event_time = sa.Column(ArrowType, nullable=False) + next_bill_date = sa.Column(sa.Date, nullable=False) - cancelled = db.Column(db.Boolean, nullable=False, default=False) + cancelled = sa.Column(sa.Boolean, nullable=False, default=False) - plan = db.Column(db.Enum(PlanEnum), nullable=False) + plan = sa.Column(sa.Enum(PlanEnum), nullable=False) - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True ) - user = db.relationship(User) + user = orm.relationship(User) def plan_name(self): if self.plan == PlanEnum.monthly: @@ -1646,48 +1688,52 @@ class Subscription(db.Model, ModelMixin): return f"" -class ManualSubscription(db.Model, ModelMixin): +class ManualSubscription(Base, ModelMixin): """ For users who use other forms of payment and therefore not pass by Paddle """ - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True + __tablename__ = "manual_subscription" + + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True ) # an reminder is sent several days before the subscription ends - end_at = db.Column(ArrowType, nullable=False) + end_at = sa.Column(ArrowType, nullable=False) # for storing note about this subscription - comment = db.Column(db.Text, nullable=True) + comment = sa.Column(sa.Text, nullable=True) # manual subscription are also used for Premium giveaways - is_giveaway = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + is_giveaway = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) - user = db.relationship(User) + user = orm.relationship(User) def is_active(self): return self.end_at > arrow.now() -class CoinbaseSubscription(db.Model, ModelMixin): +class CoinbaseSubscription(Base, ModelMixin): """ For subscriptions using Coinbase Commerce """ - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True + __tablename__ = "coinbase_subscription" + + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True ) # an reminder is sent several days before the subscription ends - end_at = db.Column(ArrowType, nullable=False) + end_at = sa.Column(ArrowType, nullable=False) # the Coinbase code - code = db.Column(db.String(64), nullable=True) + code = sa.Column(sa.String(64), nullable=True) - user = db.relationship(User) + user = orm.relationship(User) def is_active(self): return self.end_at > arrow.now() @@ -1697,34 +1743,38 @@ class CoinbaseSubscription(db.Model, ModelMixin): _APPLE_GRACE_PERIOD_DAYS = 16 -class AppleSubscription(db.Model, ModelMixin): +class AppleSubscription(Base, ModelMixin): """ For users who have subscribed via Apple in-app payment """ - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True + __tablename__ = "apple_subscription" + + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True ) - expires_date = db.Column(ArrowType, nullable=False) + expires_date = sa.Column(ArrowType, nullable=False) # to avoid using "Restore Purchase" on another account - original_transaction_id = db.Column(db.String(256), nullable=False, unique=True) - receipt_data = db.Column(db.Text(), nullable=False) + original_transaction_id = sa.Column(sa.String(256), nullable=False, unique=True) + receipt_data = sa.Column(sa.Text(), nullable=False) - plan = db.Column(db.Enum(PlanEnum), nullable=False) + plan = sa.Column(sa.Enum(PlanEnum), nullable=False) - user = db.relationship(User) + user = orm.relationship(User) def is_valid(self): # Todo: take into account grace period? return self.expires_date > arrow.now().shift(days=-_APPLE_GRACE_PERIOD_DAYS) -class DeletedAlias(db.Model, ModelMixin): +class DeletedAlias(Base, ModelMixin): """Store all deleted alias to make sure they are NOT reused""" - email = db.Column(db.String(256), unique=True, nullable=False) + __tablename__ = "deleted_alias" + + email = sa.Column(sa.String(256), unique=True, nullable=False) @classmethod def create(cls, **kw): @@ -1734,20 +1784,22 @@ class DeletedAlias(db.Model, ModelMixin): return f"" -class EmailChange(db.Model, ModelMixin): +class EmailChange(Base, ModelMixin): """Used when user wants to update their email""" - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), + __tablename__ = "email_change" + + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True, index=True, ) - new_email = db.Column(db.String(256), unique=True, nullable=False) - code = db.Column(db.String(128), unique=True, nullable=False) - expired = db.Column(ArrowType, nullable=False, default=_expiration_12h) + new_email = sa.Column(sa.String(256), unique=True, nullable=False) + code = sa.Column(sa.String(128), unique=True, nullable=False) + expired = sa.Column(ArrowType, nullable=False, default=_expiration_12h) - user = db.relationship(User) + user = orm.relationship(User) def is_expired(self): return self.expired < arrow.now() @@ -1756,31 +1808,35 @@ class EmailChange(db.Model, ModelMixin): return f"" -class AliasUsedOn(db.Model, ModelMixin): +class AliasUsedOn(Base, ModelMixin): """Used to know where an alias is created""" + __tablename__ = "alias_used_on" + __table_args__ = ( - db.UniqueConstraint("alias_id", "hostname", name="uq_alias_used"), + sa.UniqueConstraint("alias_id", "hostname", name="uq_alias_used"), ) - alias_id = db.Column(db.ForeignKey(Alias.id, ondelete="cascade"), nullable=False) - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) + alias_id = sa.Column(sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) - alias = db.relationship(Alias) + alias = orm.relationship(Alias) - hostname = db.Column(db.String(1024), nullable=False) + hostname = sa.Column(sa.String(1024), nullable=False) -class ApiKey(db.Model, ModelMixin): +class ApiKey(Base, ModelMixin): """used in browser extension to identify user""" - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - code = db.Column(db.String(128), unique=True, nullable=False) - name = db.Column(db.String(128), nullable=True) - last_used = db.Column(ArrowType, default=None) - times = db.Column(db.Integer, default=0, nullable=False) + __tablename__ = "api_key" - user = db.relationship(User) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + code = sa.Column(sa.String(128), unique=True, nullable=False) + name = sa.Column(sa.String(128), nullable=True) + last_used = sa.Column(ArrowType, default=None) + times = sa.Column(sa.Integer, default=0, nullable=False) + + user = orm.relationship(User) @classmethod def create(cls, user_id, name=None, **kwargs): @@ -1791,52 +1847,54 @@ class ApiKey(db.Model, ModelMixin): return super().create(user_id=user_id, name=name, code=code, **kwargs) -class CustomDomain(db.Model, ModelMixin): - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - domain = db.Column(db.String(128), unique=True, nullable=False) +class CustomDomain(Base, ModelMixin): + __tablename__ = "custom_domain" + + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + domain = sa.Column(sa.String(128), unique=True, nullable=False) # default name to use when user replies/sends from alias - name = db.Column(db.String(128), nullable=True, default=None) + name = sa.Column(sa.String(128), nullable=True, default=None) # mx verified - verified = db.Column(db.Boolean, nullable=False, default=False) - dkim_verified = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + verified = sa.Column(sa.Boolean, nullable=False, default=False) + dkim_verified = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) - spf_verified = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + spf_verified = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) - dmarc_verified = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + dmarc_verified = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) - _mailboxes = db.relationship("Mailbox", secondary="domain_mailbox", lazy="joined") + _mailboxes = orm.relationship("Mailbox", secondary="domain_mailbox", lazy="joined") # an alias is created automatically the first time it receives an email - catch_all = db.Column(db.Boolean, nullable=False, default=False, server_default="0") + catch_all = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") # option to generate random prefix version automatically - random_prefix_generation = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + random_prefix_generation = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) # incremented when a check is failed on the domain # alert when the number exceeds a threshold # used in check_custom_domain() - nb_failed_checks = db.Column( - db.Integer, default=0, server_default="0", nullable=False + nb_failed_checks = sa.Column( + sa.Integer, default=0, server_default="0", nullable=False ) # only domain has the ownership verified can go the next DNS step # MX verified domains before this change don't have to do the TXT check # and therefore have ownership_verified=True - ownership_verified = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + ownership_verified = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) # randomly generated TXT value for verifying domain ownership # the TXT record should be sl-verification=txt_token - ownership_txt_token = db.Column(db.String(128), nullable=True) + ownership_txt_token = sa.Column(sa.String(128), nullable=True) __table_args__ = ( Index( @@ -1847,7 +1905,7 @@ class CustomDomain(db.Model, ModelMixin): ), # The condition ) - user = db.relationship(User, foreign_keys=[user_id]) + user = orm.relationship(User, foreign_keys=[user_id]) @property def mailboxes(self): @@ -1872,7 +1930,7 @@ class CustomDomain(db.Model, ModelMixin): # generate a domain ownership txt token if not domain.ownership_txt_token: domain.ownership_txt_token = random_string(30) - db.session.commit() + Session.commit() return domain @@ -1884,63 +1942,67 @@ class CustomDomain(db.Model, ModelMixin): return f"" -class AutoCreateRule(db.Model, ModelMixin): +class AutoCreateRule(Base, ModelMixin): """Alias auto creation rule for custom domain""" + __tablename__ = "auto_create_rule" + __table_args__ = ( - db.UniqueConstraint( + sa.UniqueConstraint( "custom_domain_id", "order", name="uq_auto_create_rule_order" ), ) - custom_domain_id = db.Column( - db.ForeignKey(CustomDomain.id, ondelete="cascade"), nullable=False + custom_domain_id = sa.Column( + sa.ForeignKey(CustomDomain.id, ondelete="cascade"), nullable=False ) # an alias is auto created if it matches the regex - regex = db.Column(db.String(512), nullable=False) + regex = sa.Column(sa.String(512), nullable=False) # the order in which rules are evaluated in case there are multiple rules - order = db.Column(db.Integer, default=0, nullable=False) + order = sa.Column(sa.Integer, default=0, nullable=False) - custom_domain = db.relationship(CustomDomain, backref="_auto_create_rules") + custom_domain = orm.relationship(CustomDomain, backref="_auto_create_rules") - mailboxes = db.relationship( + mailboxes = orm.relationship( "Mailbox", secondary="auto_create_rule__mailbox", lazy="joined" ) -class AutoCreateRuleMailbox(db.Model, ModelMixin): +class AutoCreateRuleMailbox(Base, ModelMixin): """store auto create rule - mailbox association""" __tablename__ = "auto_create_rule__mailbox" __table_args__ = ( - db.UniqueConstraint( + sa.UniqueConstraint( "auto_create_rule_id", "mailbox_id", name="uq_auto_create_rule_mailbox" ), ) - auto_create_rule_id = db.Column( - db.ForeignKey(AutoCreateRule.id, ondelete="cascade"), nullable=False + auto_create_rule_id = sa.Column( + sa.ForeignKey(AutoCreateRule.id, ondelete="cascade"), nullable=False ) - mailbox_id = db.Column( - db.ForeignKey("mailbox.id", ondelete="cascade"), nullable=False + mailbox_id = sa.Column( + sa.ForeignKey("mailbox.id", ondelete="cascade"), nullable=False ) -class DomainDeletedAlias(db.Model, ModelMixin): +class DomainDeletedAlias(Base, ModelMixin): """Store all deleted alias for a domain""" + __tablename__ = "domain_deleted_alias" + __table_args__ = ( - db.UniqueConstraint("domain_id", "email", name="uq_domain_trash"), + sa.UniqueConstraint("domain_id", "email", name="uq_domain_trash"), ) - email = db.Column(db.String(256), nullable=False) - domain_id = db.Column( - db.ForeignKey("custom_domain.id", ondelete="cascade"), nullable=False + email = sa.Column(sa.String(256), nullable=False) + domain_id = sa.Column( + sa.ForeignKey("custom_domain.id", ondelete="cascade"), nullable=False ) - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) - domain = db.relationship(CustomDomain) + domain = orm.relationship(CustomDomain) @classmethod def create(cls, **kw): @@ -1950,42 +2012,47 @@ class DomainDeletedAlias(db.Model, ModelMixin): return f"" -class LifetimeCoupon(db.Model, ModelMixin): - code = db.Column(db.String(128), nullable=False, unique=True) - nb_used = db.Column(db.Integer, nullable=False) - paid = db.Column(db.Boolean, default=False, server_default="0", nullable=False) - comment = db.Column(db.Text, nullable=True) +class LifetimeCoupon(Base, ModelMixin): + __tablename__ = "lifetime_coupon" + + code = sa.Column(sa.String(128), nullable=False, unique=True) + nb_used = sa.Column(sa.Integer, nullable=False) + paid = sa.Column(sa.Boolean, default=False, server_default="0", nullable=False) + comment = sa.Column(sa.Text, nullable=True) -class Coupon(db.Model, ModelMixin): - code = db.Column(db.String(128), nullable=False, unique=True) +class Coupon(Base, ModelMixin): + __tablename__ = "coupon" + + code = sa.Column(sa.String(128), nullable=False, unique=True) # by default a coupon is for 1 year - nb_year = db.Column(db.Integer, nullable=False, server_default="1", default=1) + nb_year = sa.Column(sa.Integer, nullable=False, server_default="1", default=1) # whether the coupon has been used - used = db.Column(db.Boolean, default=False, server_default="0", nullable=False) + used = sa.Column(sa.Boolean, default=False, server_default="0", nullable=False) # the user who uses the code # non-null when the coupon is used - used_by_user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=True + used_by_user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=True ) - is_giveaway = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + is_giveaway = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) -class Directory(db.Model, ModelMixin): - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - name = db.Column(db.String(128), unique=True, nullable=False) +class Directory(Base, ModelMixin): + __tablename__ = "directory" + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + name = sa.Column(sa.String(128), unique=True, nullable=False) # when a directory is disabled, new alias can't be created on the fly - disabled = db.Column(db.Boolean, default=False, nullable=False, server_default="0") + disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") - user = db.relationship(User, backref="directories") + user = orm.relationship(User, backref="directories") - _mailboxes = db.relationship( + _mailboxes = orm.relationship( "Mailbox", secondary="directory_mailbox", lazy="joined" ) @@ -2004,64 +2071,67 @@ class Directory(db.Model, ModelMixin): obj: Directory = cls.get(obj_id) user = obj.user # Put all aliases belonging to this directory to global or domain trash - for alias in Alias.query.filter_by(directory_id=obj_id): + for alias in Alias.filter_by(directory_id=obj_id): from app import alias_utils alias_utils.delete_alias(alias, user) - cls.query.filter(cls.id == obj_id).delete() - db.session.commit() + cls.filter(cls.id == obj_id).delete() + Session.commit() def __repr__(self): return f"" -class Job(db.Model, ModelMixin): +class Job(Base, ModelMixin): """Used to schedule one-time job in the future""" - name = db.Column(db.String(128), nullable=False) - payload = db.Column(db.JSON) + __tablename__ = "job" + + name = sa.Column(sa.String(128), nullable=False) + payload = sa.Column(sa.JSON) # whether the job has been taken by the job runner - taken = db.Column(db.Boolean, default=False, nullable=False) - run_at = db.Column(ArrowType) + taken = sa.Column(sa.Boolean, default=False, nullable=False) + run_at = sa.Column(ArrowType) def __repr__(self): return f"" -class Mailbox(db.Model, ModelMixin): - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True +class Mailbox(Base, ModelMixin): + __tablename__ = "mailbox" + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, index=True ) - email = db.Column(db.String(256), nullable=False, index=True) - verified = db.Column(db.Boolean, default=False, nullable=False) - force_spf = db.Column(db.Boolean, default=True, server_default="1", nullable=False) + email = sa.Column(sa.String(256), nullable=False, index=True) + verified = sa.Column(sa.Boolean, default=False, nullable=False) + force_spf = sa.Column(sa.Boolean, default=True, server_default="1", nullable=False) # used when user wants to update mailbox email - new_email = db.Column(db.String(256), unique=True) + new_email = sa.Column(sa.String(256), unique=True) - pgp_public_key = db.Column(db.Text, nullable=True) - pgp_finger_print = db.Column(db.String(512), nullable=True) - disable_pgp = db.Column( - db.Boolean, default=False, nullable=False, server_default="0" + pgp_public_key = sa.Column(sa.Text, nullable=True) + pgp_finger_print = sa.Column(sa.String(512), nullable=True) + disable_pgp = sa.Column( + sa.Boolean, default=False, nullable=False, server_default="0" ) # incremented when a check is failed on the mailbox # alert when the number exceeds a threshold # used in sanity_check() - nb_failed_checks = db.Column( - db.Integer, default=0, server_default="0", nullable=False + nb_failed_checks = sa.Column( + sa.Integer, default=0, server_default="0", nullable=False ) # a mailbox can be disabled if it can't be reached - disabled = db.Column(db.Boolean, default=False, nullable=False, server_default="0") + disabled = sa.Column(sa.Boolean, default=False, nullable=False, server_default="0") - generic_subject = db.Column(db.String(78), nullable=True) + generic_subject = sa.Column(sa.String(78), nullable=True) - __table_args__ = (db.UniqueConstraint("user_id", "email", name="uq_mailbox_user"),) + __table_args__ = (sa.UniqueConstraint("user_id", "email", name="uq_mailbox_user"),) - user = db.relationship(User, foreign_keys=[user_id]) + user = orm.relationship(User, foreign_keys=[user_id]) def pgp_enabled(self) -> bool: if self.pgp_finger_print and not self.disable_pgp: @@ -2081,7 +2151,7 @@ class Mailbox(db.Model, ModelMixin): user = mailbox.user # Put all aliases belonging to this mailbox to global or domain trash - for alias in Alias.query.filter_by(mailbox_id=obj_id): + for alias in Alias.filter_by(mailbox_id=obj_id): # special handling for alias that has several mailboxes and has mailbox_id=obj_id if len(alias.mailboxes) > 1: # use the first mailbox found in alias._mailboxes @@ -2093,10 +2163,10 @@ class Mailbox(db.Model, ModelMixin): # only put aliases that have mailbox as a single mailbox into trash alias_utils.delete_alias(alias, user) - db.session.commit() + Session.commit() - cls.query.filter(cls.id == obj_id).delete() - db.session.commit() + cls.filter(cls.id == obj_id).delete() + Session.commit() @property def aliases(self) -> [Alias]: @@ -2111,17 +2181,19 @@ class Mailbox(db.Model, ModelMixin): return f"" -class AccountActivation(db.Model, ModelMixin): +class AccountActivation(Base, ModelMixin): """contains code to activate the user account when they sign up on mobile""" - user_id = db.Column( - db.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True + __tablename__ = "account_activation" + + user_id = sa.Column( + sa.ForeignKey(User.id, ondelete="cascade"), nullable=False, unique=True ) # the activation code is usually 6 digits - code = db.Column(db.String(10), nullable=False) + code = sa.Column(sa.String(10), nullable=False) # nb tries decrements each time user enters wrong code - tries = db.Column(db.Integer, default=3, nullable=False) + tries = sa.Column(sa.Integer, default=3, nullable=False) __table_args__ = ( CheckConstraint(tries >= 0, name="account_activation_tries_positive"), @@ -2129,22 +2201,24 @@ class AccountActivation(db.Model, ModelMixin): ) -class RefusedEmail(db.Model, ModelMixin): +class RefusedEmail(Base, ModelMixin): """Store emails that have been refused, i.e. bounced or classified as spams""" + __tablename__ = "refused_email" + # Store the full report, including logs from Sending & Receiving MTA - full_report_path = db.Column(db.String(128), unique=True, nullable=False) + full_report_path = sa.Column(sa.String(128), unique=True, nullable=False) # The original email, to display to user - path = db.Column(db.String(128), unique=True, nullable=True) + path = sa.Column(sa.String(128), unique=True, nullable=True) - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) # the email content will be deleted at this date - delete_at = db.Column(ArrowType, nullable=False, default=_expiration_7d) + delete_at = sa.Column(ArrowType, nullable=False, default=_expiration_7d) # toggle this when email content (stored at full_report_path & path are deleted) - deleted = db.Column(db.Boolean, nullable=False, default=False, server_default="0") + deleted = sa.Column(sa.Boolean, nullable=False, default=False, server_default="0") def get_url(self, expires_in=3600): if self.path: @@ -2156,15 +2230,17 @@ class RefusedEmail(db.Model, ModelMixin): return f"" -class Referral(db.Model, ModelMixin): +class Referral(Base, ModelMixin): """Referral code so user can invite others""" - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - name = db.Column(db.String(512), nullable=True, default=None) + __tablename__ = "referral" - code = db.Column(db.String(128), unique=True, nullable=False) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + name = sa.Column(sa.String(512), nullable=True, default=None) - user = db.relationship(User, foreign_keys=[user_id]) + code = sa.Column(sa.String(128), unique=True, nullable=False) + + user = orm.relationship(User, foreign_keys=[user_id]) @property def nb_user(self) -> int: @@ -2183,7 +2259,7 @@ class Referral(db.Model, ModelMixin): return f"{LANDING_PAGE_URL}?slref={self.code}" -class SentAlert(db.Model, ModelMixin): +class SentAlert(Base, ModelMixin): """keep track of alerts sent to user. User can receive an alert when there's abnormal activity on their aliases such as - reverse-alias not used by the owning mailbox @@ -2196,71 +2272,77 @@ class SentAlert(db.Model, ModelMixin): - max number of sent per 24H: an alert type should not be sent more than X times in 24h """ - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - to_email = db.Column(db.String(256), nullable=False) - alert_type = db.Column(db.String(256), nullable=False) + __tablename__ = "sent_alert" + + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + to_email = sa.Column(sa.String(256), nullable=False) + alert_type = sa.Column(sa.String(256), nullable=False) -class AliasMailbox(db.Model, ModelMixin): +class AliasMailbox(Base, ModelMixin): + __tablename__ = "alias_mailbox" __table_args__ = ( - db.UniqueConstraint("alias_id", "mailbox_id", name="uq_alias_mailbox"), + sa.UniqueConstraint("alias_id", "mailbox_id", name="uq_alias_mailbox"), ) - alias_id = db.Column( - db.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True + alias_id = sa.Column( + sa.ForeignKey(Alias.id, ondelete="cascade"), nullable=False, index=True ) - mailbox_id = db.Column( - db.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False, index=True + mailbox_id = sa.Column( + sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False, index=True ) - alias = db.relationship(Alias) + alias = orm.relationship(Alias) -class AliasHibp(db.Model, ModelMixin): +class AliasHibp(Base, ModelMixin): __tablename__ = "alias_hibp" - __table_args__ = (db.UniqueConstraint("alias_id", "hibp_id", name="uq_alias_hibp"),) + __table_args__ = (sa.UniqueConstraint("alias_id", "hibp_id", name="uq_alias_hibp"),) - alias_id = db.Column( - db.Integer(), db.ForeignKey("alias.id", ondelete="cascade"), index=True + alias_id = sa.Column( + sa.Integer(), sa.ForeignKey("alias.id", ondelete="cascade"), index=True ) - hibp_id = db.Column( - db.Integer(), db.ForeignKey("hibp.id", ondelete="cascade"), index=True + hibp_id = sa.Column( + sa.Integer(), sa.ForeignKey("hibp.id", ondelete="cascade"), index=True ) - alias = db.relationship( - "Alias", backref=db.backref("alias_hibp", cascade="all, delete-orphan") + alias = orm.relationship( + "Alias", backref=orm.backref("alias_hibp", cascade="all, delete-orphan") ) - hibp = db.relationship( - "Hibp", backref=db.backref("alias_hibp", cascade="all, delete-orphan") + hibp = orm.relationship( + "Hibp", backref=orm.backref("alias_hibp", cascade="all, delete-orphan") ) -class DirectoryMailbox(db.Model, ModelMixin): +class DirectoryMailbox(Base, ModelMixin): + __tablename__ = "directory_mailbox" __table_args__ = ( - db.UniqueConstraint("directory_id", "mailbox_id", name="uq_directory_mailbox"), + sa.UniqueConstraint("directory_id", "mailbox_id", name="uq_directory_mailbox"), ) - directory_id = db.Column( - db.ForeignKey(Directory.id, ondelete="cascade"), nullable=False + directory_id = sa.Column( + sa.ForeignKey(Directory.id, ondelete="cascade"), nullable=False ) - mailbox_id = db.Column( - db.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False + mailbox_id = sa.Column( + sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False ) -class DomainMailbox(db.Model, ModelMixin): +class DomainMailbox(Base, ModelMixin): """store the owning mailboxes for a domain""" + __tablename__ = "domain_mailbox" + __table_args__ = ( - db.UniqueConstraint("domain_id", "mailbox_id", name="uq_domain_mailbox"), + sa.UniqueConstraint("domain_id", "mailbox_id", name="uq_domain_mailbox"), ) - domain_id = db.Column( - db.ForeignKey(CustomDomain.id, ondelete="cascade"), nullable=False + domain_id = sa.Column( + sa.ForeignKey(CustomDomain.id, ondelete="cascade"), nullable=False ) - mailbox_id = db.Column( - db.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False + mailbox_id = sa.Column( + sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False ) @@ -2268,24 +2350,25 @@ _NB_RECOVERY_CODE = 8 _RECOVERY_CODE_LENGTH = 8 -class RecoveryCode(db.Model, ModelMixin): +class RecoveryCode(Base, ModelMixin): """allow user to login in case you lose any of your authenticators""" - __table_args__ = (db.UniqueConstraint("user_id", "code", name="uq_recovery_code"),) + __tablename__ = "recovery_code" + __table_args__ = (sa.UniqueConstraint("user_id", "code", name="uq_recovery_code"),) - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - code = db.Column(db.String(16), nullable=False) - used = db.Column(db.Boolean, nullable=False, default=False) - used_at = db.Column(ArrowType, nullable=True, default=None) + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + code = sa.Column(sa.String(16), nullable=False) + used = sa.Column(sa.Boolean, nullable=False, default=False) + used_at = sa.Column(ArrowType, nullable=True, default=None) - user = db.relationship(User) + user = orm.relationship(User) @classmethod def generate(cls, user): """generate recovery codes for user""" # delete all existing codes - cls.query.filter_by(user_id=user.id).delete() - db.session.flush() + cls.filter_by(user_id=user.id).delete() + Session.flush() nb_code = 0 while nb_code < _NB_RECOVERY_CODE: @@ -2295,174 +2378,188 @@ class RecoveryCode(db.Model, ModelMixin): nb_code += 1 LOG.d("Create recovery codes for %s", user) - db.session.commit() + Session.commit() @classmethod def empty(cls, user): """Delete all recovery codes for user""" - cls.query.filter_by(user_id=user.id).delete() - db.session.commit() + cls.filter_by(user_id=user.id).delete() + Session.commit() -class Notification(db.Model, ModelMixin): - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - message = db.Column(db.Text, nullable=False) +class Notification(Base, ModelMixin): + __tablename__ = "notification" + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + message = sa.Column(sa.Text, nullable=False) # whether user has marked the notification as read - read = db.Column(db.Boolean, nullable=False, default=False) + read = sa.Column(sa.Boolean, nullable=False, default=False) -class SLDomain(db.Model, ModelMixin): +class SLDomain(Base, ModelMixin): """SimpleLogin domains""" __tablename__ = "public_domain" - domain = db.Column(db.String(128), unique=True, nullable=False) + domain = sa.Column(sa.String(128), unique=True, nullable=False) # only available for premium accounts - premium_only = db.Column( - db.Boolean, nullable=False, default=False, server_default="0" + premium_only = sa.Column( + sa.Boolean, nullable=False, default=False, server_default="0" ) def __repr__(self): return f"" -class AuthorizedAddress(db.Model, ModelMixin): +class AuthorizedAddress(Base, ModelMixin): """Authorize other addresses to send emails from aliases that are owned by a mailbox""" - user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - mailbox_id = db.Column( - db.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False + __tablename__ = "authorized_address" + + user_id = sa.Column(sa.ForeignKey(User.id, ondelete="cascade"), nullable=False) + mailbox_id = sa.Column( + sa.ForeignKey(Mailbox.id, ondelete="cascade"), nullable=False ) - email = db.Column(db.String(256), nullable=False) + email = sa.Column(sa.String(256), nullable=False) __table_args__ = ( - db.UniqueConstraint("mailbox_id", "email", name="uq_authorize_address"), + sa.UniqueConstraint("mailbox_id", "email", name="uq_authorize_address"), ) - mailbox = db.relationship(Mailbox, backref="authorized_addresses") + mailbox = orm.relationship(Mailbox, backref="authorized_addresses") def __repr__(self): return f"" -class Metric2(db.Model, ModelMixin): +class Metric2(Base, ModelMixin): """ For storing different metrics like number of users, etc Store each metric as a column as opposed to having different rows as in Metric """ - date = db.Column(ArrowType, default=arrow.utcnow, nullable=False) + __tablename__ = "metric2" + date = sa.Column(ArrowType, default=arrow.utcnow, nullable=False) - nb_user = db.Column(db.Float, nullable=True) - nb_activated_user = db.Column(db.Float, nullable=True) + nb_user = sa.Column(sa.Float, nullable=True) + nb_activated_user = sa.Column(sa.Float, nullable=True) - nb_premium = db.Column(db.Float, nullable=True) - nb_apple_premium = db.Column(db.Float, nullable=True) - nb_cancelled_premium = db.Column(db.Float, nullable=True) - nb_manual_premium = db.Column(db.Float, nullable=True) - nb_coinbase_premium = db.Column(db.Float, nullable=True) + nb_premium = sa.Column(sa.Float, nullable=True) + nb_apple_premium = sa.Column(sa.Float, nullable=True) + nb_cancelled_premium = sa.Column(sa.Float, nullable=True) + nb_manual_premium = sa.Column(sa.Float, nullable=True) + nb_coinbase_premium = sa.Column(sa.Float, nullable=True) # nb users who have been referred - nb_referred_user = db.Column(db.Float, nullable=True) - nb_referred_user_paid = db.Column(db.Float, nullable=True) + nb_referred_user = sa.Column(sa.Float, nullable=True) + nb_referred_user_paid = sa.Column(sa.Float, nullable=True) - nb_alias = db.Column(db.Float, nullable=True) + nb_alias = sa.Column(sa.Float, nullable=True) # Obsolete as only for the last 14 days - nb_forward = db.Column(db.Float, nullable=True) - nb_block = db.Column(db.Float, nullable=True) - nb_reply = db.Column(db.Float, nullable=True) - nb_bounced = db.Column(db.Float, nullable=True) - nb_spam = db.Column(db.Float, nullable=True) + nb_forward = sa.Column(sa.Float, nullable=True) + nb_block = sa.Column(sa.Float, nullable=True) + nb_reply = sa.Column(sa.Float, nullable=True) + nb_bounced = sa.Column(sa.Float, nullable=True) + nb_spam = sa.Column(sa.Float, nullable=True) # should be used instead - nb_forward_last_24h = db.Column(db.Float, nullable=True) - nb_block_last_24h = db.Column(db.Float, nullable=True) - nb_reply_last_24h = db.Column(db.Float, nullable=True) - nb_bounced_last_24h = db.Column(db.Float, nullable=True) + nb_forward_last_24h = sa.Column(sa.Float, nullable=True) + nb_block_last_24h = sa.Column(sa.Float, nullable=True) + nb_reply_last_24h = sa.Column(sa.Float, nullable=True) + nb_bounced_last_24h = sa.Column(sa.Float, nullable=True) - nb_verified_custom_domain = db.Column(db.Float, nullable=True) + nb_verified_custom_domain = sa.Column(sa.Float, nullable=True) - nb_app = db.Column(db.Float, nullable=True) + nb_app = sa.Column(sa.Float, nullable=True) -class Bounce(db.Model, ModelMixin): +class Bounce(Base, ModelMixin): """Record all bounces. Deleted after 7 days""" - email = db.Column(db.String(256), nullable=False, index=True) + __tablename__ = "bounce" + email = sa.Column(sa.String(256), nullable=False, index=True) -class TransactionalEmail(db.Model, ModelMixin): +class TransactionalEmail(Base, ModelMixin): """Storing all email addresses that receive transactional emails, including account email and mailboxes. Deleted after 7 days """ - email = db.Column(db.String(256), nullable=False, unique=False) + __tablename__ = "transactional_email" + email = sa.Column(sa.String(256), nullable=False, unique=False) -class Payout(db.Model, ModelMixin): +class Payout(Base, ModelMixin): """Referral payouts""" - user_id = db.Column(db.ForeignKey("users.id", ondelete="cascade"), nullable=False) + __tablename__ = "payout" + user_id = sa.Column(sa.ForeignKey("users.id", ondelete="cascade"), nullable=False) # in USD - amount = db.Column(db.Float, nullable=False) + amount = sa.Column(sa.Float, nullable=False) # BTC, PayPal, etc - payment_method = db.Column(db.String(256), nullable=False) + payment_method = sa.Column(sa.String(256), nullable=False) # number of upgraded user included in this payout - number_upgraded_account = db.Column(db.Integer, nullable=False) + number_upgraded_account = sa.Column(sa.Integer, nullable=False) - comment = db.Column(db.Text) + comment = sa.Column(sa.Text) - user = db.relationship(User) + user = orm.relationship(User) -class IgnoredEmail(db.Model, ModelMixin): +class IgnoredEmail(Base, ModelMixin): """If an email has mail_from and rcpt_to present in this table, discard it by returning 250 status.""" - mail_from = db.Column(db.String(512), nullable=False) - rcpt_to = db.Column(db.String(512), nullable=False) + __tablename__ = "ignored_email" + + mail_from = sa.Column(sa.String(512), nullable=False) + rcpt_to = sa.Column(sa.String(512), nullable=False) -class IgnoreBounceSender(db.Model, ModelMixin): +class IgnoreBounceSender(Base, ModelMixin): """Ignore sender that doesn't correctly handle bounces, for example noreply@github.com""" - mail_from = db.Column(db.String(512), nullable=False, unique=True) + __tablename__ = "ignored_bounce_sender" + + mail_from = sa.Column(sa.String(512), nullable=False, unique=True) def __repr__(self): return f" refused_email.delete_at >= arrow.now(): LOG.d("Delete refused email %s", refused_email) if refused_email.path: @@ -109,14 +109,14 @@ def delete_refused_emails(): # do not set path and full_report_path to null # so we can check later that the files are indeed deleted refused_email.deleted = True - db.session.commit() + Session.commit() LOG.d("Finish delete_refused_emails") def notify_premium_end(): """sent to user who has canceled their subscription and who has their subscription ending soon""" - for sub in Subscription.query.filter_by(cancelled=True).all(): + for sub in Subscription.filter_by(cancelled=True).all(): if ( arrow.now().shift(days=3).date() > sub.next_bill_date @@ -146,7 +146,7 @@ def notify_premium_end(): def notify_manual_sub_end(): - for manual_sub in ManualSubscription.query.all(): + for manual_sub in ManualSubscription.all(): need_reminder = False if arrow.now().shift(days=14) > manual_sub.end_at > arrow.now().shift(days=13): need_reminder = True @@ -172,7 +172,7 @@ def notify_manual_sub_end(): ) extend_subscription_url = URL + "/dashboard/coinbase_checkout" - for coinbase_subscription in CoinbaseSubscription.query.all(): + for coinbase_subscription in CoinbaseSubscription.all(): need_reminder = False if ( arrow.now().shift(days=14) @@ -211,7 +211,7 @@ def notify_manual_sub_end(): def poll_apple_subscription(): """Poll Apple API to update AppleSubscription""" # todo: only near the end of the subscription - for apple_sub in AppleSubscription.query.all(): + for apple_sub in AppleSubscription.all(): user = apple_sub.user verify_receipt(apple_sub.receipt_data, user, APPLE_API_SECRET) verify_receipt(apple_sub.receipt_data, user, MACAPP_APPLE_API_SECRET) @@ -224,49 +224,49 @@ def compute_metric2() -> Metric2: _24h_ago = now.shift(days=-1) nb_referred_user_paid = 0 - for user in User.query.filter(User.referral_id.isnot(None)): + for user in User.filter(User.referral_id.isnot(None)): if user.is_paid(): nb_referred_user_paid += 1 return Metric2.create( date=now, # user stats - nb_user=User.query.count(), - nb_activated_user=User.query.filter_by(activated=True).count(), + nb_user=User.count(), + nb_activated_user=User.filter_by(activated=True).count(), # subscription stats - nb_premium=Subscription.query.filter(Subscription.cancelled.is_(False)).count(), - nb_cancelled_premium=Subscription.query.filter( + nb_premium=Subscription.filter(Subscription.cancelled.is_(False)).count(), + nb_cancelled_premium=Subscription.filter( Subscription.cancelled.is_(True) ).count(), # todo: filter by expires_date > now - nb_apple_premium=AppleSubscription.query.count(), - nb_manual_premium=ManualSubscription.query.filter( + nb_apple_premium=AppleSubscription.count(), + nb_manual_premium=ManualSubscription.filter( ManualSubscription.end_at > now, ManualSubscription.is_giveaway.is_(False), ).count(), - nb_coinbase_premium=CoinbaseSubscription.query.filter( + nb_coinbase_premium=CoinbaseSubscription.filter( CoinbaseSubscription.end_at > now ).count(), # referral stats - nb_referred_user=User.query.filter(User.referral_id.isnot(None)).count(), + nb_referred_user=User.filter(User.referral_id.isnot(None)).count(), nb_referred_user_paid=nb_referred_user_paid, - nb_alias=Alias.query.count(), + nb_alias=Alias.count(), # email log stats - nb_forward_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) + nb_forward_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago) .filter_by(bounced=False, is_spam=False, is_reply=False, blocked=False) .count(), - nb_bounced_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) + nb_bounced_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago) .filter_by(bounced=True) .count(), - nb_reply_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) + nb_reply_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago) .filter_by(is_reply=True) .count(), - nb_block_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) + nb_block_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago) .filter_by(blocked=True) .count(), # other stats - nb_verified_custom_domain=CustomDomain.query.filter_by(verified=True).count(), - nb_app=Client.query.count(), + nb_verified_custom_domain=CustomDomain.filter_by(verified=True).count(), + nb_app=Client.count(), commit=True, ) @@ -309,7 +309,7 @@ def bounce_report() -> List[Tuple[str, int]]: """ min_dt = arrow.now().shift(days=-1) query = ( - db.session.query(User.email, func.count(EmailLog.id).label("count")) + Session.query(User.email, func.count(EmailLog.id).label("count")) .join(EmailLog, EmailLog.user_id == User.id) .filter(EmailLog.bounced, EmailLog.created_at > min_dt) .group_by(User.email) @@ -354,7 +354,7 @@ def alias_creation_report() -> List[Tuple[str, int]]: """ min_dt = arrow.now().shift(days=-7) query = ( - db.session.query( + Session.query( User.email, func.count(Alias.id).label("count"), func.date(Alias.created_at).label("date"), @@ -381,7 +381,7 @@ def stats(): stats_today = compute_metric2() stats_yesterday = ( - Metric2.query.filter(Metric2.date < stats_today.date) + Metric2.filter(Metric2.date < stats_today.date) .order_by(Metric2.date.desc()) .first() ) @@ -442,13 +442,13 @@ nb_referred_user_upgrade: {stats_today.nb_referred_user_paid} - {increase_percen def migrate_domain_trash(): """Move aliases from global trash to domain trash if applicable""" - for deleted_alias in DeletedAlias.query.all(): + for deleted_alias in DeletedAlias.all(): alias_domain = get_email_domain_part(deleted_alias.email) if not SLDomain.get_by(domain=alias_domain): custom_domain = CustomDomain.get_by(domain=alias_domain) if custom_domain: LOG.e("move %s to domain %s trash", deleted_alias, custom_domain) - db.session.add( + Session.add( DomainDeletedAlias( user_id=custom_domain.user_id, email=deleted_alias.email, @@ -458,13 +458,13 @@ def migrate_domain_trash(): ) DeletedAlias.delete(deleted_alias.id) - db.session.commit() + Session.commit() def set_custom_domain_for_alias(): """Go through all aliases and make sure custom_domain is correctly set""" - sl_domains = [sl_domain.domain for sl_domain in SLDomain.query.all()] - for alias in Alias.query.filter(Alias.custom_domain_id.is_(None)): + sl_domains = [sl_domain.domain for sl_domain in SLDomain.all()] + for alias in Alias.filter(Alias.custom_domain_id.is_(None)): if ( not any(alias.email.endswith(f"@{sl_domain}") for sl_domain in sl_domains) and not alias.custom_domain_id @@ -477,7 +477,7 @@ def set_custom_domain_for_alias(): else: # phantom domain LOG.d("phantom domain %s %s %s", alias.user, alias, alias.enabled) - db.session.commit() + Session.commit() def sanity_check(): @@ -487,7 +487,7 @@ def sanity_check(): - detect if there's mailbox that's using a invalid domain """ mailbox_ids = ( - db.session.query(Mailbox.id) + Session.query(Mailbox.id) .filter(Mailbox.verified.is_(True), Mailbox.disabled.is_(False)) .all() ) @@ -544,23 +544,23 @@ def sanity_check(): else: # reset nb check mailbox.nb_failed_checks = 0 - db.session.commit() + Session.commit() for user in User.filter_by(activated=True).all(): if sanitize_email(user.email) != user.email: LOG.e("%s does not have sanitized email", user) - for alias in Alias.query.all(): + for alias in Alias.all(): if sanitize_email(alias.email) != alias.email: LOG.e("Alias %s email not sanitized", alias) if alias.name and "\n" in alias.name: alias.name = alias.name.replace("\n", "") - db.session.commit() + Session.commit() LOG.e("Alias %s name contains linebreak %s", alias, alias.name) contact_email_sanity_date = arrow.get("2021-01-12") - for contact in Contact.query.all(): + for contact in Contact.all(): if sanitize_email(contact.reply_email) != contact.reply_email: LOG.e("Contact %s reply-email not sanitized", contact) @@ -573,13 +573,13 @@ def sanity_check(): if not contact.invalid_email and not is_valid_email(contact.website_email): LOG.e("%s invalid email", contact) contact.invalid_email = True - db.session.commit() + Session.commit() - for mailbox in Mailbox.query.all(): + for mailbox in Mailbox.all(): if sanitize_email(mailbox.email) != mailbox.email: LOG.e("Mailbox %s address not sanitized", mailbox) - for contact in Contact.query.all(): + for contact in Contact.all(): if normalize_reply_email(contact.reply_email) != contact.reply_email: LOG.e( "Contact %s reply email is not normalized %s", @@ -587,7 +587,7 @@ def sanity_check(): contact.reply_email, ) - for domain in CustomDomain.query.all(): + for domain in CustomDomain.all(): if domain.name and "\n" in domain.name: LOG.e("Domain %s name contain linebreak %s", domain, domain.name) @@ -600,9 +600,7 @@ def sanity_check(): def check_custom_domain(): LOG.d("Check verified domain for DNS issues") - for custom_domain in CustomDomain.query.filter_by( - verified=True - ): # type: CustomDomain + for custom_domain in CustomDomain.filter_by(verified=True): # type: CustomDomain mx_domains = get_mx_domains(custom_domain.domain) if sorted(mx_domains) != sorted(EMAIL_SERVERS_WITH_PRIORITY): @@ -644,7 +642,7 @@ def check_custom_domain(): # reset checks custom_domain.nb_failed_checks = 0 - db.session.commit() + Session.commit() def delete_old_monitoring(): @@ -652,8 +650,8 @@ def delete_old_monitoring(): Delete old monitoring records """ max_time = arrow.now().shift(days=-30) - nb_row = Monitoring.query.filter(Monitoring.created_at < max_time).delete() - db.session.commit() + nb_row = Monitoring.filter(Monitoring.created_at < max_time).delete() + Session.commit() LOG.d("delete monitoring records older than %s, nb row %s", max_time, nb_row) @@ -713,8 +711,8 @@ async def _hibp_check(api_key, queue): return alias.hibp_last_check = arrow.utcnow() - db.session.add(alias) - db.session.commit() + Session.add(alias) + Session.commit() LOG.d("Updated breaches info for %s", alias) @@ -738,14 +736,14 @@ async def check_hibp(): hibp_entry.date = arrow.get(entry["BreachDate"]) hibp_entry.description = entry["Description"] - db.session.commit() + Session.commit() LOG.d("Updated list of known breaches") LOG.d("Preparing list of aliases to check") queue = asyncio.Queue() max_date = arrow.now().shift(days=-HIBP_SCAN_INTERVAL_DAYS) for alias in ( - Alias.query.filter( + Alias.filter( or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date) ) .filter(Alias.enabled) @@ -782,19 +780,19 @@ def notify_hibp(): """ # to get a list of users that have at least a breached alias alias_query = ( - db.session.query(Alias) + Session.query(Alias) .options(joinedload(Alias.hibp_breaches)) .filter(Alias.hibp_breaches.any()) - .filter(Alias.id.notin_(db.session.query(HibpNotifiedAlias.alias_id))) + .filter(Alias.id.notin_(Session.query(HibpNotifiedAlias.alias_id))) .distinct(Alias.user_id) .all() ) user_ids = [alias.user_id for alias in alias_query] - for user in User.query.filter(User.id.in_(user_ids)): + for user in User.filter(User.id.in_(user_ids)): breached_aliases = ( - db.session.query(Alias) + Session.query(Alias) .options(joinedload(Alias.hibp_breaches)) .filter(Alias.hibp_breaches.any(), Alias.user_id == user.id) .all() @@ -824,7 +822,7 @@ def notify_hibp(): # add the breached aliases to HibpNotifiedAlias to avoid sending another email for alias in breached_aliases: HibpNotifiedAlias.create(user_id=user.id, alias_id=alias.id) - db.session.commit() + Session.commit() if __name__ == "__main__": diff --git a/email_handler.py b/email_handler.py index 72d75234..5dca07fc 100644 --- a/email_handler.py +++ b/email_handler.py @@ -85,6 +85,7 @@ from app.config import ( ALERT_YAHOO_COMPLAINT, TEMP_DIR, ) +from app.db import Session from app.email import status, headers from app.email.rate_limit import rate_limited from app.email.spam import get_spam_score @@ -123,7 +124,6 @@ from app.email_utils import ( parse_full_address, get_orig_message_from_yahoo_complaint, ) -from app.extensions import db from app.log import LOG, set_message_id from app.models import ( Alias, @@ -141,7 +141,6 @@ from app.utils import sanitize_email from init_app import load_pgp_public_keys from server import create_app, create_light_app - newrelic_app = None if NEWRELIC_CONFIG_PATH: newrelic.agent.initialize(NEWRELIC_CONFIG_PATH) @@ -157,10 +156,7 @@ def new_app(): @app.teardown_appcontext def shutdown_session(response_or_exc): # same as shutdown_session() in flask-sqlalchemy but this is not enough - db.session.remove() - - # dispose the engine too - db.engine.dispose() + Session.remove() return app @@ -210,7 +206,7 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con contact_name, ) contact.name = contact_name - db.session.commit() + Session.commit() # contact created in the past does not have mail_from and from_header field if not contact.mail_from and mail_from: @@ -221,7 +217,7 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con mail_from, ) contact.mail_from = mail_from - db.session.commit() + Session.commit() else: LOG.d( "create contact %s for alias %s", @@ -244,10 +240,10 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con LOG.d("Create a contact with invalid email for %s", alias) contact.invalid_email = True - db.session.commit() + Session.commit() except IntegrityError: LOG.w("Contact %s %s already exist", alias, contact_email) - db.session.rollback() + Session.rollback() contact = Contact.get_by(alias_id=alias.id, website_email=contact_email) return contact @@ -291,10 +287,10 @@ def get_or_create_reply_to_contact( name=contact_name, reply_email=generate_reply_email(contact_address, alias.user), ) - db.session.commit() + Session.commit() except IntegrityError: LOG.w("Contact %s %s already exist", alias, contact_address) - db.session.rollback() + Session.rollback() contact = Contact.get_by(alias_id=alias.id, website_email=contact_address) return contact @@ -341,7 +337,7 @@ def replace_header_when_forward(msg: Message, alias: Alias, header: str): full_address.display_name, ) contact.name = full_address.display_name - db.session.commit() + Session.commit() else: LOG.d( "create contact for alias %s and email %s, header %s", @@ -359,10 +355,10 @@ def replace_header_when_forward(msg: Message, alias: Alias, header: str): reply_email=generate_reply_email(contact_email, alias.user), is_cc=header.lower() == "cc", ) - db.session.commit() + Session.commit() except IntegrityError: LOG.w("Contact %s %s already exist", alias, contact_email) - db.session.rollback() + Session.rollback() contact = Contact.get_by(alias_id=alias.id, website_email=contact_email) new_addrs.append(contact.new_addr()) @@ -501,7 +497,7 @@ def handle_email_sent_to_ourself(alias, mailbox, msg: Message, user): refused_email = RefusedEmail.create( path=None, full_report_path=full_report_path, user_id=alias.user_id ) - db.session.commit() + Session.commit() LOG.d("Create refused email %s", refused_email) # link available for 6 days as it gets deleted in 7 days refused_email_url = refused_email.get_url(expires_in=518400) @@ -588,7 +584,7 @@ def handle_forward(envelope, msg: Message, rcpt_to: str) -> List[Tuple[bool, str alias_id=contact.alias_id, commit=True, ) - db.session.commit() + Session.commit() # do not return 5** to allow user to receive emails later when alias is enabled return [(True, status.E200)] @@ -695,7 +691,7 @@ def forward_email_to_mailbox( spam_report, ) email_log.spam_score = spam_score - db.session.commit() + Session.commit() if (user.max_spam_score and spam_score > user.max_spam_score) or ( not user.max_spam_score and spam_score > MAX_SPAM_SCORE @@ -716,7 +712,7 @@ def forward_email_to_mailbox( ) email_log.is_spam = True email_log.spam_status = spam_status - db.session.commit() + Session.commit() handle_spam(contact, alias, msg, user, mailbox, email_log) return False, status.E519 @@ -846,7 +842,7 @@ def forward_email_to_mailbox( else: return False, status.E521 else: - db.session.commit() + Session.commit() return True, status.E200 @@ -963,7 +959,7 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str): email_log.is_spam = True email_log.spam_status = spam_status - db.session.commit() + Session.commit() handle_spam(contact, alias, msg, user, mailbox, email_log, is_reply=True) return False, status.E506 @@ -1003,11 +999,11 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str): ) # to not save the email_log EmailLog.delete(email_log.id) - db.session.commit() + Session.commit() # return 421 so the client can retry later return False, status.E402 - db.session.commit() + Session.commit() # make the email comes from alias from_header = alias.email @@ -1065,7 +1061,7 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str): ) except Exception: # to not save the email_log - db.session.rollback() + Session.rollback() LOG.w("Cannot send email from %s to %s", alias, contact) send_email( @@ -1218,13 +1214,13 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog): refused_email = RefusedEmail.create( path=file_path, full_report_path=full_report_path, user_id=user.id ) - db.session.flush() + Session.flush() LOG.d("Create refused email %s", refused_email) email_log.bounced = True email_log.refused_email_id = refused_email.id email_log.bounced_mailbox_id = mailbox.id - db.session.commit() + Session.commit() refused_email_url = f"{URL}/dashboard/refused_email?highlight_id={email_log.id}" @@ -1268,7 +1264,7 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog): alias, ) alias.enabled = False - db.session.commit() + Session.commit() send_email_with_rate_control( user, @@ -1411,7 +1407,7 @@ def handle_bounce_reply_phase(envelope, msg: Message, email_log: EmailLog): email_log.bounced_mailbox_id = mailbox.id - db.session.commit() + Session.commit() refused_email_url = f"{URL}/dashboard/refused_email?highlight_id={email_log.id}" @@ -1469,10 +1465,10 @@ def handle_spam( refused_email = RefusedEmail.create( path=file_path, full_report_path=full_report_path, user_id=user.id ) - db.session.flush() + Session.flush() email_log.refused_email_id = refused_email.id - db.session.commit() + Session.commit() LOG.d("Create spam email %s", refused_email) @@ -1574,7 +1570,7 @@ def handle_unsubscribe(envelope: Envelope, msg: Message) -> str: # Sender is owner of this alias alias.enabled = False - db.session.commit() + Session.commit() user = alias.user enable_alias_url = URL + f"/dashboard/?highlight_alias_id={alias.id}" @@ -1611,7 +1607,7 @@ def handle_unsubscribe_user(user_id: int, mail_from: str) -> str: return status.E511 user.notification = False - db.session.commit() + Session.commit() send_email( user.email, @@ -1676,7 +1672,7 @@ def handle_bounce(envelope, email_log: EmailLog, msg: Message) -> str: alias = contact.alias email_log.auto_replied = True - db.session.commit() + Session.commit() # replace the BOUNCE_EMAIL by alias in To field add_or_replace_header(msg, "To", alias.email) @@ -1871,8 +1867,8 @@ def handle(envelope: Envelope) -> str: "total number email log on %s, %s is %s, %s", alias, alias.user, - EmailLog.query.filter(EmailLog.alias_id == alias.id).count(), - EmailLog.query.filter(EmailLog.user_id == alias.user_id).count(), + EmailLog.filter(EmailLog.alias_id == alias.id).count(), + EmailLog.filter(EmailLog.user_id == alias.user_id).count(), ) if should_ignore_bounce(envelope.mail_from): diff --git a/init_app.py b/init_app.py index 4c13e15e..542f7229 100644 --- a/init_app.py +++ b/init_app.py @@ -1,15 +1,14 @@ -"""Initial loading script""" from app.config import ALIAS_DOMAINS, PREMIUM_ALIAS_DOMAINS -from app.models import Mailbox, Contact, SLDomain +from app.db import Session from app.log import LOG -from app.extensions import db +from app.models import Mailbox, Contact, SLDomain from app.pgp_utils import load_public_key from server import create_app def load_pgp_public_keys(): """Load PGP public key to keyring""" - for mailbox in Mailbox.query.filter(Mailbox.pgp_public_key.isnot(None)).all(): + for mailbox in Mailbox.filter(Mailbox.pgp_public_key.isnot(None)).all(): LOG.d("Load PGP key for mailbox %s", mailbox) fingerprint = load_public_key(mailbox.pgp_public_key) @@ -17,9 +16,9 @@ def load_pgp_public_keys(): if fingerprint != mailbox.pgp_finger_print: LOG.e("fingerprint %s different for mailbox %s", fingerprint, mailbox) mailbox.pgp_finger_print = fingerprint - db.session.commit() + Session.commit() - for contact in Contact.query.filter(Contact.pgp_public_key.isnot(None)).all(): + for contact in Contact.filter(Contact.pgp_public_key.isnot(None)).all(): LOG.d("Load PGP key for %s", contact) fingerprint = load_public_key(contact.pgp_public_key) @@ -28,7 +27,7 @@ def load_pgp_public_keys(): LOG.e("fingerprint %s different for contact %s", fingerprint, contact) contact.pgp_finger_print = fingerprint - db.session.commit() + Session.commit() LOG.d("Finish load_pgp_public_keys") @@ -48,7 +47,7 @@ def add_sl_domains(): LOG.i("Add %s to SL domain", premium_domain) SLDomain.create(domain=premium_domain, premium_only=True) - db.session.commit() + Session.commit() if __name__ == "__main__": diff --git a/job_runner.py b/job_runner.py index efab71b4..b2f5daec 100644 --- a/job_runner.py +++ b/job_runner.py @@ -13,11 +13,11 @@ from app.config import ( JOB_BATCH_IMPORT, JOB_DELETE_ACCOUNT, ) +from app.db import Session from app.email_utils import ( send_email, render, ) -from app.extensions import db from app.import_utils import handle_batch_import from app.log import LOG from app.models import User, Job, BatchImport @@ -32,10 +32,7 @@ def new_app(): @app.teardown_appcontext def shutdown_session(response_or_exc): # same as shutdown_session() in flask-sqlalchemy but this is not enough - db.session.remove() - - # dispose the engine too - db.engine.dispose() + Session.remove() return app @@ -109,14 +106,14 @@ if __name__ == "__main__": app = new_app() with app.app_context(): - for job in Job.query.filter( + for job in Job.filter( Job.taken.is_(False), Job.run_at > min_dt, Job.run_at <= max_dt ).all(): LOG.d("Take job %s", job) # mark the job as taken, whether it will be executed successfully or not job.taken = True - db.session.commit() + Session.commit() if job.name == JOB_ONBOARDING_1: user_id = job.payload.get("user_id") @@ -161,7 +158,7 @@ if __name__ == "__main__": user_email = user.email LOG.w("Delete user %s", user) User.delete(user.id) - db.session.commit() + Session.commit() send_email( user_email, diff --git a/monitoring.py b/monitoring.py index ff92f672..631041ad 100644 --- a/monitoring.py +++ b/monitoring.py @@ -2,7 +2,7 @@ import os from time import sleep from app.config import HOST -from app.extensions import db +from app.db import Session from app.log import LOG from app.models import Monitoring from server import create_app @@ -32,7 +32,7 @@ def get_stats(): active_queue=active_queue, deferred_queue=deferred_queue, ) - db.session.commit() + Session.commit() global _nb_failed # alert when too many emails in incoming + active queue diff --git a/server.py b/server.py index 634498c0..beea93f0 100644 --- a/server.py +++ b/server.py @@ -71,10 +71,11 @@ from app.config import ( ROOT_DIR, ) from app.dashboard.base import dashboard_bp +from app.db import Session from app.developer.base import developer_bp from app.discover.base import discover_bp from app.email_utils import send_email, render -from app.extensions import db, login_manager, migrate, limiter +from app.extensions import login_manager, migrate, limiter from app.jose_utils import get_jwk_key from app.log import LOG from app.models import ( @@ -129,8 +130,6 @@ def create_light_app() -> Flask: app.config["SQLALCHEMY_DATABASE_URI"] = DB_URI app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False - db.init_app(app) - return app @@ -200,6 +199,10 @@ def create_app() -> Flask: session.permanent = True app.permanent_session_lifetime = timedelta(days=7) + @app.teardown_appcontext + def cleanup(resp_or_exc): + Session.remove() + return app @@ -219,7 +222,7 @@ def fake_data(): fido_uuid=None, ) user.trial_end = None - db.session.commit() + Session.commit() # add a profile picture file_path = "profile_pic.svg" @@ -230,11 +233,11 @@ def fake_data(): ) file = File.create(user_id=user.id, path=file_path, commit=True) user.profile_picture_id = file.id - db.session.commit() + Session.commit() # create a bounced email alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() bounce_email_file_path = "bounce.eml" s3.upload_email_from_bytesio( @@ -298,7 +301,7 @@ def fake_data(): pgp_public_key=pgp_public_key, ) m1.pgp_finger_print = load_public_key(pgp_public_key) - db.session.commit() + Session.commit() # example@example.com is in a LOT of data breaches Alias.create(email="example@example.com", user_id=user.id, mailbox_id=m1.id) @@ -314,14 +317,14 @@ def fake_data(): user_id=user.id, mailbox_id=user.default_mailbox_id, ) - db.session.commit() + Session.commit() if i % 5 == 0: if i % 2 == 0: AliasMailbox.create(alias_id=a.id, mailbox_id=user.default_mailbox_id) else: AliasMailbox.create(alias_id=a.id, mailbox_id=m1.id) - db.session.commit() + Session.commit() # some aliases don't have any activity # if i % 3 != 0: @@ -331,18 +334,18 @@ def fake_data(): # website_email=f"contact{i}@example.com", # reply_email=f"rep{i}@sl.local", # ) - # db.session.commit() + # Session.commit() # for _ in range(3): # EmailLog.create(user_id=user.id, contact_id=contact.id, alias_id=contact.alias_id) - # db.session.commit() + # Session.commit() # have some disabled alias if i % 5 == 0: a.enabled = False - db.session.commit() + Session.commit() custom_domain1 = CustomDomain.create(user_id=user.id, domain="ab.cd", verified=True) - db.session.commit() + Session.commit() Alias.create( user_id=user.id, @@ -362,13 +365,13 @@ def fake_data(): Directory.create(user_id=user.id, name="abcd") Directory.create(user_id=user.id, name="xyzt") - db.session.commit() + Session.commit() # Create a client client1 = Client.create_new(name="Demo", user_id=user.id) client1.oauth_client_id = "client-id" client1.oauth_client_secret = "client-secret" - db.session.commit() + Session.commit() RedirectUri.create( client_id=client1.id, uri="https://your-website.com/oauth-callback" @@ -377,7 +380,7 @@ def fake_data(): client2 = Client.create_new(name="Demo 2", user_id=user.id) client2.oauth_client_id = "client-id2" client2.oauth_client_secret = "client-secret2" - db.session.commit() + Session.commit() ClientUser.create(user_id=user.id, client_id=client1.id, name="Fake Name") @@ -392,11 +395,11 @@ def fake_data(): number_upgraded_account=200, payment_method="PayPal", ) - db.session.commit() + Session.commit() for i in range(6): Notification.create(user_id=user.id, message=f"""Hey hey {i} """ * 10) - db.session.commit() + Session.commit() user2 = User.create( email="winston@continental.com", @@ -405,7 +408,7 @@ def fake_data(): referral_id=referral.id, ) Mailbox.create(user_id=user2.id, email="winston2@high.table", verified=True) - db.session.commit() + Session.commit() ManualSubscription.create( user_id=user2.id, @@ -695,7 +698,7 @@ def setup_paddle_callback(app: Flask): LOG.d("User %s upgrades!", user) - db.session.commit() + Session.commit() elif request.form.get("alert_name") == "subscription_payment_succeeded": subscription_id = request.form.get("subscription_id") @@ -710,7 +713,7 @@ def setup_paddle_callback(app: Flask): request.form.get("next_bill_date"), "YYYY-MM-DD" ).date() - db.session.commit() + Session.commit() elif request.form.get("alert_name") == "subscription_cancelled": subscription_id = request.form.get("subscription_id") @@ -728,7 +731,7 @@ def setup_paddle_callback(app: Flask): sub.event_time = arrow.now() sub.cancelled = True - db.session.commit() + Session.commit() user = sub.user @@ -774,7 +777,7 @@ def setup_paddle_callback(app: Flask): # make sure to set the new plan as not-cancelled sub.cancelled = False - db.session.commit() + Session.commit() else: return "No such subscription", 400 return "OK" @@ -847,7 +850,7 @@ def handle_coinbase_event(event) -> bool: else: # already expired subscription coinbase_subscription.end_at = arrow.now().shift(years=1) - db.session.commit() + Session.commit() send_email( user.email, @@ -867,7 +870,6 @@ def handle_coinbase_event(event) -> bool: def init_extensions(app: Flask): login_manager.init_app(app) - db.init_app(app) migrate.init_app(app) @@ -875,17 +877,17 @@ def init_admin(app): admin = Admin(name="SimpleLogin", template_mode="bootstrap4") admin.init_app(app, index_view=SLAdminIndexView()) - admin.add_view(UserAdmin(User, db.session)) - admin.add_view(AliasAdmin(Alias, db.session)) - admin.add_view(MailboxAdmin(Mailbox, db.session)) - admin.add_view(EmailLogAdmin(EmailLog, db.session)) - admin.add_view(LifetimeCouponAdmin(LifetimeCoupon, db.session)) - admin.add_view(CouponAdmin(Coupon, db.session)) - admin.add_view(ManualSubscriptionAdmin(ManualSubscription, db.session)) - admin.add_view(ClientAdmin(Client, db.session)) - admin.add_view(CustomDomainAdmin(CustomDomain, db.session)) - admin.add_view(ReferralAdmin(Referral, db.session)) - admin.add_view(PayoutAdmin(Payout, db.session)) + admin.add_view(UserAdmin(User, Session)) + admin.add_view(AliasAdmin(Alias, Session)) + admin.add_view(MailboxAdmin(Mailbox, Session)) + admin.add_view(EmailLogAdmin(EmailLog, Session)) + admin.add_view(LifetimeCouponAdmin(LifetimeCoupon, Session)) + admin.add_view(CouponAdmin(Coupon, Session)) + admin.add_view(ManualSubscriptionAdmin(ManualSubscription, Session)) + admin.add_view(ClientAdmin(Client, Session)) + admin.add_view(CustomDomainAdmin(CustomDomain, Session)) + admin.add_view(ReferralAdmin(Referral, Session)) + admin.add_view(PayoutAdmin(Payout, Session)) def register_custom_commands(app): @@ -900,12 +902,12 @@ def register_custom_commands(app): def fill_up_email_log_alias(): """Fill up email_log.alias_id column""" # split all emails logs into 1000-size trunks - nb_email_log = EmailLog.query.count() + nb_email_log = EmailLog.count() LOG.d("total trunks %s", nb_email_log // 1000 + 2) for trunk in reversed(range(1, nb_email_log // 1000 + 2)): nb_update = 0 for email_log, contact in ( - db.session.query(EmailLog, Contact) + Session.query(EmailLog, Contact) .filter(EmailLog.contact_id == Contact.id) .filter(EmailLog.id <= trunk * 1000) .filter(EmailLog.id > (trunk - 1) * 1000) @@ -915,7 +917,7 @@ def register_custom_commands(app): nb_update += 1 LOG.d("finish trunk %s, update %s email logs", trunk, nb_update) - db.session.commit() + Session.commit() @app.cli.command("dummy-data") def dummy_data(): diff --git a/shell.py b/shell.py index d3df9c1e..2553df68 100644 --- a/shell.py +++ b/shell.py @@ -1,3 +1,5 @@ +from app.db import Session +from app.db import Session from time import sleep import flask_migrate @@ -34,7 +36,7 @@ def create_db(): def change_password(user_id, new_password): user = User.get(user_id) user.set_password(new_password) - db.session.commit() + Session.commit() def reset_db(): @@ -44,7 +46,7 @@ def reset_db(): def send_mailbox_newsletter(): - for user in User.query.order_by(User.id).all(): + for user in User.order_by(User.id).all(): if user.notification and user.activated: try: LOG.d("Send newsletter to %s", user) @@ -60,7 +62,7 @@ def send_mailbox_newsletter(): def send_pgp_newsletter(): - for user in User.query.order_by(User.id).all(): + for user in User.order_by(User.id).all(): if user.notification and user.activated: try: LOG.d("Send PGP newsletter to %s", user) @@ -77,7 +79,7 @@ def send_pgp_newsletter(): def send_mobile_newsletter(): count = 0 - for user in User.query.order_by(User.id).all(): + for user in User.order_by(User.id).all(): if user.notification and user.activated: count += 1 try: @@ -104,7 +106,7 @@ def disable_mailbox(mailbox_id): for alias in mailbox.aliases: alias.enabled = False - db.session.commit() + Session.commit() email_msg = f"""Hi, diff --git a/tests/api/test_alias.py b/tests/api/test_alias.py index 473cc41e..a64022ec 100644 --- a/tests/api/test_alias.py +++ b/tests/api/test_alias.py @@ -1,8 +1,8 @@ from flask import url_for from app.config import PAGE_LIMIT +from app.db import Session from app.email_utils import is_reply_email -from app.extensions import db from app.models import User, ApiKey, Alias, Contact, EmailLog, Mailbox from tests.utils import login @@ -18,7 +18,7 @@ def test_get_aliases_error_without_pagination(flask_client): # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() r = flask_client.get( url_for("api.get_aliases"), headers={"Authentication": api_key.code} @@ -39,12 +39,12 @@ def test_get_aliases_with_pagination(flask_client): # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() # create more aliases than PAGE_LIMIT for _ in range(PAGE_LIMIT + 1): Alias.create_new_random(user) - db.session.commit() + Session.commit() # get aliases on the 1st page, should return PAGE_LIMIT aliases r = flask_client.get( @@ -79,16 +79,16 @@ def test_get_aliases_query(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() # create more aliases than PAGE_LIMIT Alias.create_new(user, "prefix1") Alias.create_new(user, "prefix2") - db.session.commit() + Session.commit() # get aliases without query, should return 3 aliases as one alias is created when user is created r = flask_client.get( @@ -111,15 +111,15 @@ def test_get_aliases_v2(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() a0 = Alias.create_new(user, "prefix0") a1 = Alias.create_new(user, "prefix1") - db.session.commit() + Session.commit() # << Aliases have no activity >> r = flask_client.get( @@ -154,13 +154,13 @@ def test_get_aliases_v2(flask_client): website_email="c0@example.com", reply_email="re0@SL", ) - db.session.commit() + Session.commit() EmailLog.create( contact_id=c0.id, user_id=user.id, alias_id=c0.alias_id, ) - db.session.commit() + Session.commit() # a1 has more recent activity c1 = Contact.create( @@ -169,13 +169,13 @@ def test_get_aliases_v2(flask_client): website_email="c1@example.com", reply_email="re1@SL", ) - db.session.commit() + Session.commit() EmailLog.create( contact_id=c1.id, user_id=user.id, alias_id=c1.alias_id, ) - db.session.commit() + Session.commit() # get aliases v2 r = flask_client.get( @@ -199,14 +199,14 @@ def test_delete_alias(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() r = flask_client.delete( url_for("api.delete_alias", alias_id=alias.id), @@ -221,14 +221,14 @@ def test_toggle_alias(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.toggle_alias", alias_id=alias.id), @@ -243,14 +243,14 @@ def test_alias_activities(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() # create some alias log contact = Contact.create( @@ -259,7 +259,7 @@ def test_alias_activities(flask_client): alias_id=alias.id, user_id=alias.user_id, ) - db.session.commit() + Session.commit() for _ in range(int(PAGE_LIMIT / 2)): EmailLog.create( @@ -304,14 +304,14 @@ def test_update_alias(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() r = flask_client.put( url_for("api.update_alias", alias_id=alias.id), @@ -326,16 +326,16 @@ def test_update_alias_mailbox(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() mb = Mailbox.create(user_id=user.id, email="ab@cd.com", verified=True) # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() r = flask_client.put( url_for("api.update_alias", alias_id=alias.id), @@ -358,14 +358,14 @@ def test_update_alias_name(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() r = flask_client.put( url_for("api.update_alias", alias_id=alias.id), @@ -391,17 +391,17 @@ def test_update_alias_mailboxes(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() mb1 = Mailbox.create(user_id=user.id, email="ab1@cd.com", verified=True) mb2 = Mailbox.create(user_id=user.id, email="ab2@cd.com", verified=True) # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() r = flask_client.put( url_for("api.update_alias", alias_id=alias.id), @@ -428,14 +428,14 @@ def test_update_disable_pgp(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() assert not alias.disable_pgp r = flask_client.put( @@ -468,14 +468,14 @@ def test_alias_contacts(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() # create some alias log for i in range(PAGE_LIMIT + 1): @@ -485,7 +485,7 @@ def test_alias_contacts(flask_client): alias_id=alias.id, user_id=alias.user_id, ) - db.session.commit() + Session.commit() EmailLog.create( contact_id=contact.id, @@ -493,7 +493,7 @@ def test_alias_contacts(flask_client): user_id=contact.user_id, alias_id=contact.alias_id, ) - db.session.commit() + Session.commit() r = flask_client.get( url_for("api.get_alias_contacts_route", alias_id=alias.id, page_id=0), @@ -523,14 +523,14 @@ def test_create_contact_route(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.create_contact_route", alias_id=alias.id), @@ -560,7 +560,7 @@ def test_create_contact_route(flask_client): def test_create_contact_route_empty_contact_address(flask_client): login(flask_client) - alias = Alias.query.first() + alias = Alias.first() r = flask_client.post( url_for("api.create_contact_route", alias_id=alias.id), @@ -573,7 +573,7 @@ def test_create_contact_route_empty_contact_address(flask_client): def test_create_contact_route_invalid_contact_email(flask_client): login(flask_client) - alias = Alias.query.first() + alias = Alias.first() r = flask_client.post( url_for("api.create_contact_route", alias_id=alias.id), @@ -588,14 +588,14 @@ def test_delete_contact(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() contact = Contact.create( alias_id=alias.id, @@ -603,7 +603,7 @@ def test_delete_contact(flask_client): reply_email="reply+random@sl.io", user_id=alias.user_id, ) - db.session.commit() + Session.commit() r = flask_client.delete( url_for("api.delete_contact", contact_id=contact.id), @@ -618,15 +618,15 @@ def test_get_alias(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() # create more aliases than PAGE_LIMIT alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() # get aliases on the 1st page, should return PAGE_LIMIT aliases r = flask_client.get( diff --git a/tests/api/test_alias_options.py b/tests/api/test_alias_options.py index 75703f9a..81239c79 100644 --- a/tests/api/test_alias_options.py +++ b/tests/api/test_alias_options.py @@ -2,7 +2,7 @@ import json from flask import url_for -from app.extensions import db +from app.db import Session from app.models import User, ApiKey, AliasUsedOn, Alias @@ -10,11 +10,11 @@ def test_different_scenarios_v4(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() # <<< without hostname >>> r = flask_client.get( @@ -37,11 +37,11 @@ def test_different_scenarios_v4(flask_client): # <<< with recommendation >>> alias = Alias.create_new(user, prefix="test") - db.session.commit() + Session.commit() AliasUsedOn.create( alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id ) - db.session.commit() + Session.commit() r = flask_client.get( url_for("api.options_v4", hostname="www.test.com"), @@ -55,11 +55,11 @@ def test_different_scenarios_v4_2(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() # <<< without hostname >>> r = flask_client.get( @@ -85,11 +85,11 @@ def test_different_scenarios_v4_2(flask_client): # <<< with recommendation >>> alias = Alias.create_new(user, prefix="test") - db.session.commit() + Session.commit() AliasUsedOn.create( alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id ) - db.session.commit() + Session.commit() r = flask_client.get( url_for("api.options_v4", hostname="www.test.com"), @@ -103,11 +103,11 @@ def test_different_scenarios_v5(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() # <<< without hostname >>> r = flask_client.get( @@ -138,11 +138,11 @@ def test_different_scenarios_v5(flask_client): # <<< with recommendation >>> alias = Alias.create_new(user, prefix="test") - db.session.commit() + Session.commit() AliasUsedOn.create( alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id ) - db.session.commit() + Session.commit() r = flask_client.get( url_for("api.options_v4", hostname="www.test.com"), diff --git a/tests/api/test_apple.py b/tests/api/test_apple.py index 9259423c..4351ef50 100644 --- a/tests/api/test_apple.py +++ b/tests/api/test_apple.py @@ -1,6 +1,6 @@ from flask import url_for -from app.extensions import db +from app.db import Session from app.models import User, ApiKey @@ -8,11 +8,11 @@ def test_apple_process_payment(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() receipt_data = """MIIUHgYJKoZIhvcNAQcCoIIUDzCCFAsCAQExCzAJBgUrDgMCGgUAMIIDvwYJKoZIhvcNAQcBoIIDsASCA6wxggOoMAoCAQgCAQEEAhYAMAoCARQCAQEEAgwAMAsCAQECAQEEAwIBADALAgEDAgEBBAMMATIwCwIBCwIBAQQDAgEAMAsCAQ8CAQEEAwIBADALAgEQAgEBBAMCAQAwCwIBGQIBAQQDAgEDMAwCAQoCAQEEBBYCNCswDAIBDgIBAQQEAgIAjjANAgENAgEBBAUCAwH8/TANAgETAgEBBAUMAzEuMDAOAgEJAgEBBAYCBFAyNTMwGAIBBAIBAgQQS28CkyUrKkayzHXyZEQ8/zAbAgEAAgEBBBMMEVByb2R1Y3Rpb25TYW5kYm94MBwCAQUCAQEEFCvruJwvAhV9s7ODIiM3KShyPW3kMB4CAQwCAQEEFhYUMjAyMC0wNC0xOFQxNjoyOToyNlowHgIBEgIBAQQWFhQyMDEzLTA4LTAxVDA3OjAwOjAwWjAgAgECAgEBBBgMFmlvLnNpbXBsZWxvZ2luLmlvcy1hcHAwSAIBBwIBAQRAHWlCA6fQTbOn0QFDAOH79MzMxIwODI0g6I8LZ6OyThRArQ6krRg6M8UPQgF4Jq6lIrz0owFG+xn0IV2Rq8ejFzBRAgEGAgEBBEkx7BUjdVQv+PiguvEl7Wd4pd+3QIrNt+oSRwl05KQdBeoBKU78eBFp48fUNkCFA/xaibj0U4EF/iq0Lgx345M2RSNqqWvRbzsIMIIBoAIBEQIBAQSCAZYxggGSMAsCAgatAgEBBAIMADALAgIGsAIBAQQCFgAwCwICBrICAQEEAgwAMAsCAgazAgEBBAIMADALAgIGtAIBAQQCDAAwCwICBrUCAQEEAgwAMAsCAga2AgEBBAIMADAMAgIGpQIBAQQDAgEBMAwCAgarAgEBBAMCAQMwDAICBq4CAQEEAwIBADAMAgIGsQIBAQQDAgEAMAwCAga3AgEBBAMCAQAwEgICBq8CAQEECQIHA41+p92hIzAbAgIGpwIBAQQSDBAxMDAwMDAwNjUzNTg0NDc0MBsCAgapAgEBBBIMEDEwMDAwMDA2NTM1ODQ0NzQwHwICBqgCAQEEFhYUMjAyMC0wNC0xOFQxNjoyNzo0MlowHwICBqoCAQEEFhYUMjAyMC0wNC0xOFQxNjoyNzo0NFowHwICBqwCAQEEFhYUMjAyMC0wNC0xOFQxNjozMjo0MlowPgICBqYCAQEENQwzaW8uc2ltcGxlbG9naW4uaW9zX2FwcC5zdWJzY3JpcHRpb24ucHJlbWl1bS5tb250aGx5oIIOZTCCBXwwggRkoAMCAQICCA7rV4fnngmNMA0GCSqGSIb3DQEBBQUAMIGWMQswCQYDVQQGEwJVUzETMBEGA1UECgwKQXBwbGUgSW5jLjEsMCoGA1UECwwjQXBwbGUgV29ybGR3aWRlIERldmVsb3BlciBSZWxhdGlvbnMxRDBCBgNVBAMMO0FwcGxlIFdvcmxkd2lkZSBEZXZlbG9wZXIgUmVsYXRpb25zIENlcnRpZmljYXRpb24gQXV0aG9yaXR5MB4XDTE1MTExMzAyMTUwOVoXDTIzMDIwNzIxNDg0N1owgYkxNzA1BgNVBAMMLk1hYyBBcHAgU3RvcmUgYW5kIGlUdW5lcyBTdG9yZSBSZWNlaXB0IFNpZ25pbmcxLDAqBgNVBAsMI0FwcGxlIFdvcmxkd2lkZSBEZXZlbG9wZXIgUmVsYXRpb25zMRMwEQYDVQQKDApBcHBsZSBJbmMuMQswCQYDVQQGEwJVUzCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAKXPgf0looFb1oftI9ozHI7iI8ClxCbLPcaf7EoNVYb/pALXl8o5VG19f7JUGJ3ELFJxjmR7gs6JuknWCOW0iHHPP1tGLsbEHbgDqViiBD4heNXbt9COEo2DTFsqaDeTwvK9HsTSoQxKWFKrEuPt3R+YFZA1LcLMEsqNSIH3WHhUa+iMMTYfSgYMR1TzN5C4spKJfV+khUrhwJzguqS7gpdj9CuTwf0+b8rB9Typj1IawCUKdg7e/pn+/8Jr9VterHNRSQhWicxDkMyOgQLQoJe2XLGhaWmHkBBoJiY5uB0Qc7AKXcVz0N92O9gt2Yge4+wHz+KO0NP6JlWB7+IDSSMCAwEAAaOCAdcwggHTMD8GCCsGAQUFBwEBBDMwMTAvBggrBgEFBQcwAYYjaHR0cDovL29jc3AuYXBwbGUuY29tL29jc3AwMy13d2RyMDQwHQYDVR0OBBYEFJGknPzEdrefoIr0TfWPNl3tKwSFMAwGA1UdEwEB/wQCMAAwHwYDVR0jBBgwFoAUiCcXCam2GGCL7Ou69kdZxVJUo7cwggEeBgNVHSAEggEVMIIBETCCAQ0GCiqGSIb3Y2QFBgEwgf4wgcMGCCsGAQUFBwICMIG2DIGzUmVsaWFuY2Ugb24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRhbmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNvbmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmljYXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wNgYIKwYBBQUHAgEWKmh0dHA6Ly93d3cuYXBwbGUuY29tL2NlcnRpZmljYXRlYXV0aG9yaXR5LzAOBgNVHQ8BAf8EBAMCB4AwEAYKKoZIhvdjZAYLAQQCBQAwDQYJKoZIhvcNAQEFBQADggEBAA2mG9MuPeNbKwduQpZs0+iMQzCCX+Bc0Y2+vQ+9GvwlktuMhcOAWd/j4tcuBRSsDdu2uP78NS58y60Xa45/H+R3ubFnlbQTXqYZhnb4WiCV52OMD3P86O3GH66Z+GVIXKDgKDrAEDctuaAEOR9zucgF/fLefxoqKm4rAfygIFzZ630npjP49ZjgvkTbsUxn/G4KT8niBqjSl/OnjmtRolqEdWXRFgRi48Ff9Qipz2jZkgDJwYyz+I0AZLpYYMB8r491ymm5WyrWHWhumEL1TKc3GZvMOxx6GUPzo22/SGAGDDaSK+zeGLUR2i0j0I78oGmcFxuegHs5R0UwYS/HE6gwggQiMIIDCqADAgECAggB3rzEOW2gEDANBgkqhkiG9w0BAQUFADBiMQswCQYDVQQGEwJVUzETMBEGA1UEChMKQXBwbGUgSW5jLjEmMCQGA1UECxMdQXBwbGUgQ2VydGlmaWNhdGlvbiBBdXRob3JpdHkxFjAUBgNVBAMTDUFwcGxlIFJvb3QgQ0EwHhcNMTMwMjA3MjE0ODQ3WhcNMjMwMjA3MjE0ODQ3WjCBljELMAkGA1UEBhMCVVMxEzARBgNVBAoMCkFwcGxlIEluYy4xLDAqBgNVBAsMI0FwcGxlIFdvcmxkd2lkZSBEZXZlbG9wZXIgUmVsYXRpb25zMUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAMo4VKbLVqrIJDlI6Yzu7F+4fyaRvDRTes58Y4Bhd2RepQcjtjn+UC0VVlhwLX7EbsFKhT4v8N6EGqFXya97GP9q+hUSSRUIGayq2yoy7ZZjaFIVPYyK7L9rGJXgA6wBfZcFZ84OhZU3au0Jtq5nzVFkn8Zc0bxXbmc1gHY2pIeBbjiP2CsVTnsl2Fq/ToPBjdKT1RpxtWCcnTNOVfkSWAyGuBYNweV3RY1QSLorLeSUheHoxJ3GaKWwo/xnfnC6AllLd0KRObn1zeFM78A7SIym5SFd/Wpqu6cWNWDS5q3zRinJ6MOL6XnAamFnFbLw/eVovGJfbs+Z3e8bY/6SZasCAwEAAaOBpjCBozAdBgNVHQ4EFgQUiCcXCam2GGCL7Ou69kdZxVJUo7cwDwYDVR0TAQH/BAUwAwEB/zAfBgNVHSMEGDAWgBQr0GlHlHYJ/vRrjS5ApvdHTX8IXjAuBgNVHR8EJzAlMCOgIaAfhh1odHRwOi8vY3JsLmFwcGxlLmNvbS9yb290LmNybDAOBgNVHQ8BAf8EBAMCAYYwEAYKKoZIhvdjZAYCAQQCBQAwDQYJKoZIhvcNAQEFBQADggEBAE/P71m+LPWybC+P7hOHMugFNahui33JaQy52Re8dyzUZ+L9mm06WVzfgwG9sq4qYXKxr83DRTCPo4MNzh1HtPGTiqN0m6TDmHKHOz6vRQuSVLkyu5AYU2sKThC22R1QbCGAColOV4xrWzw9pv3e9w0jHQtKJoc/upGSTKQZEhltV/V6WId7aIrkhoxK6+JJFKql3VUAqa67SzCu4aCxvCmA5gl35b40ogHKf9ziCuY7uLvsumKV8wVjQYLNDzsdTJWk26v5yZXpT+RN5yaZgem8+bQp0gF6ZuEujPYhisX4eOGBrr/TkJ2prfOv/TgalmcwHFGlXOxxioK0bA8MFR8wggS7MIIDo6ADAgECAgECMA0GCSqGSIb3DQEBBQUAMGIxCzAJBgNVBAYTAlVTMRMwEQYDVQQKEwpBcHBsZSBJbmMuMSYwJAYDVQQLEx1BcHBsZSBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTEWMBQGA1UEAxMNQXBwbGUgUm9vdCBDQTAeFw0wNjA0MjUyMTQwMzZaFw0zNTAyMDkyMTQwMzZaMGIxCzAJBgNVBAYTAlVTMRMwEQYDVQQKEwpBcHBsZSBJbmMuMSYwJAYDVQQLEx1BcHBsZSBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eTEWMBQGA1UEAxMNQXBwbGUgUm9vdCBDQTCCASIwDQYJKoZIhvcNAQEBBQADggEPADCCAQoCggEBAOSRqQkfkdseR1DrBe1eeYQt6zaiV0xV7IsZid75S2z1B6siMALoGD74UAnTf0GomPnRymacJGsR0KO75Bsqwx+VnnoMpEeLW9QWNzPLxA9NzhRp0ckZcvVdDtV/X5vyJQO6VY9NXQ3xZDUjFUsVWR2zlPf2nJ7PULrBWFBnjwi0IPfLrCwgb3C2PwEwjLdDzw+dPfMrSSgayP7OtbkO2V4c1ss9tTqt9A8OAJILsSEWLnTVPA3bYharo3GSR1NVwa8vQbP4++NwzeajTEV+H0xrUJZBicR0YgsQg0GHM4qBsTBY7FoEMoxos48d3mVz/2deZbxJ2HafMxRloXeUyS0CAwEAAaOCAXowggF2MA4GA1UdDwEB/wQEAwIBBjAPBgNVHRMBAf8EBTADAQH/MB0GA1UdDgQWBBQr0GlHlHYJ/vRrjS5ApvdHTX8IXjAfBgNVHSMEGDAWgBQr0GlHlHYJ/vRrjS5ApvdHTX8IXjCCAREGA1UdIASCAQgwggEEMIIBAAYJKoZIhvdjZAUBMIHyMCoGCCsGAQUFBwIBFh5odHRwczovL3d3dy5hcHBsZS5jb20vYXBwbGVjYS8wgcMGCCsGAQUFBwICMIG2GoGzUmVsaWFuY2Ugb24gdGhpcyBjZXJ0aWZpY2F0ZSBieSBhbnkgcGFydHkgYXNzdW1lcyBhY2NlcHRhbmNlIG9mIHRoZSB0aGVuIGFwcGxpY2FibGUgc3RhbmRhcmQgdGVybXMgYW5kIGNvbmRpdGlvbnMgb2YgdXNlLCBjZXJ0aWZpY2F0ZSBwb2xpY3kgYW5kIGNlcnRpZmljYXRpb24gcHJhY3RpY2Ugc3RhdGVtZW50cy4wDQYJKoZIhvcNAQEFBQADggEBAFw2mUwteLftjJvc83eb8nbSdzBPwR+Fg4UbmT1HN/Kpm0COLNSxkBLYvvRzm+7SZA/LeU802KI++Xj/a8gH7H05g4tTINM4xLG/mk8Ka/8r/FmnBQl8F0BWER5007eLIztHo9VvJOLr0bdw3w9F4SfK8W147ee1Fxeo3H4iNcol1dkP1mvUoiQjEfehrI9zgWDGG1sJL5Ky+ERI8GA4nhX1PSZnIIozavcNgs/e66Mv+VNqW2TAYzN39zoHLFbr2g8hDtq6cxlPtdk2f8GHVdmnmbkyQvvY1XGefqFStxu9k0IkEirHDx22TZxeY8hLgBdQqorV2uT80AkHN7B1dSExggHLMIIBxwIBATCBozCBljELMAkGA1UEBhMCVVMxEzARBgNVBAoMCkFwcGxlIEluYy4xLDAqBgNVBAsMI0FwcGxlIFdvcmxkd2lkZSBEZXZlbG9wZXIgUmVsYXRpb25zMUQwQgYDVQQDDDtBcHBsZSBXb3JsZHdpZGUgRGV2ZWxvcGVyIFJlbGF0aW9ucyBDZXJ0aWZpY2F0aW9uIEF1dGhvcml0eQIIDutXh+eeCY0wCQYFKw4DAhoFADANBgkqhkiG9w0BAQEFAASCAQCjIWg69JwLxrmuZL7R0isYWjNGR0wvs3YKtWSwHZG/gDaxPWlgZI0oszcMOI07leGl73vQRVFO89ngbDkNp1Mmo9Mmbc/m8EJtvaVkJp0gYICKpWyMMJPNL5CT+MinMj9gBkRrd5rwFlfRkNBSmD6bt/I23B1AKcmmMwklAuF/mxGzOF4PFiPukEtaQAOe7j4w+QLzEeEAi57DIQppp+uRupKQpZRnn/Q9MyGxXA30ei6C1suxPCoRqCKrRXfWp73UsGP5jH6tOLigkVoO4CtJs3fLWpkLi9by6/K6eoGbP5MOklsBJWYGVZbRRDiNROxqPOgWnS1+p+/KGIdIC4+u""" @@ -31,11 +31,11 @@ def test_apple_update_notification(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() payload = { "unified_receipt": { diff --git a/tests/api/test_auth_login.py b/tests/api/test_auth_login.py index 3ebda76c..4ef52bd7 100644 --- a/tests/api/test_auth_login.py +++ b/tests/api/test_auth_login.py @@ -1,9 +1,8 @@ -import unicodedata - import pytest +import unicodedata from flask import url_for -from app.extensions import db +from app.db import Session from app.models import User, AccountActivation PASSWORD_1 = "Aurélie" @@ -20,7 +19,7 @@ def test_auth_login_success(flask_client, mfa: bool): activated=True, enable_otp=mfa, ) - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.auth_login"), @@ -49,7 +48,7 @@ def test_auth_login_device_exist(flask_client): User.create( email="abcd@gmail.com", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.auth_login"), @@ -138,7 +137,7 @@ def test_auth_activate_user_already_activated(flask_client): User.create( email="abcd@gmail.com", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.auth_activate"), json={"email": "abcd@gmail.com", "code": "123456"} @@ -214,7 +213,7 @@ def test_auth_activate_too_many_wrong_code(flask_client): def test_auth_reactivate_success(flask_client): User.create(email="abcd@gmail.com", password="password", name="Test User") - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.auth_reactivate"), json={"email": "abcd@gmail.com"} @@ -232,7 +231,7 @@ def test_auth_login_forgot_password(flask_client): User.create( email="abcd@gmail.com", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.forgot_password"), diff --git a/tests/api/test_auth_mfa.py b/tests/api/test_auth_mfa.py index 86ad19a6..b27077c8 100644 --- a/tests/api/test_auth_mfa.py +++ b/tests/api/test_auth_mfa.py @@ -3,7 +3,7 @@ from flask import url_for from itsdangerous import Signer from app.config import FLASK_SECRET -from app.extensions import db +from app.db import Session from app.models import User @@ -16,7 +16,7 @@ def test_auth_mfa_success(flask_client): enable_otp=True, otp_secret="base32secret3232", ) - db.session.commit() + Session.commit() totp = pyotp.TOTP(user.otp_secret) s = Signer(FLASK_SECRET) @@ -42,7 +42,7 @@ def test_auth_wrong_mfa_key(flask_client): enable_otp=True, otp_secret="base32secret3232", ) - db.session.commit() + Session.commit() totp = pyotp.TOTP(user.otp_secret) diff --git a/tests/api/test_import_export.py b/tests/api/test_import_export.py index 8dcdf433..31b2f505 100644 --- a/tests/api/test_import_export.py +++ b/tests/api/test_import_export.py @@ -1,7 +1,8 @@ from flask import url_for from app import alias_utils -from app.extensions import db +from app.db import Session +from app.import_utils import import_from_csv from app.models import ( User, CustomDomain, @@ -11,7 +12,6 @@ from app.models import ( BatchImport, File, ) -from app.import_utils import import_from_csv from tests.utils import login @@ -21,14 +21,14 @@ def test_export(flask_client): user2 = User.create( email="x@y.z", password="password", name="Wrong user", activated=True ) - db.session.commit() + Session.commit() # Remove onboarding aliases for alias in Alias.filter_by(user_id=user1.id).all(): alias_utils.delete_alias(alias, user1) for alias in Alias.filter_by(user_id=user2.id).all(): alias_utils.delete_alias(alias, user2) - db.session.commit() + Session.commit() # Create domains CustomDomain.create( @@ -37,7 +37,7 @@ def test_export(flask_client): CustomDomain.create( user_id=user2.id, domain="bad-destionation-domain.com", verified=True ) - db.session.commit() + Session.commit() # Create mailboxes mailbox1 = Mailbox.create( @@ -51,7 +51,7 @@ def test_export(flask_client): email="baddestination@bad-destination-domain.com", verified=True, ) - db.session.commit() + Session.commit() # Create aliases Alias.create( @@ -72,14 +72,14 @@ def test_export(flask_client): note="Should not appear", mailbox_id=badmailbox1.id, ) - db.session.commit() + Session.commit() # Add second mailbox to an alias AliasMailbox.create( alias_id=alias2.id, mailbox_id=mailbox2.id, ) - db.session.commit() + Session.commit() # Export r = flask_client.get(url_for("api.export_aliases")) @@ -128,7 +128,7 @@ def test_import_no_mailboxes(flask_client): CustomDomain.create( user_id=user.id, domain="my-domain.com", ownership_verified=True ) - db.session.commit() + Session.commit() alias_data = [ "alias,note", @@ -180,7 +180,7 @@ def test_import(flask_client): CustomDomain.create( user_id=user.id, domain="my-destination-domain.com", ownership_verified=True ) - db.session.commit() + Session.commit() # Create mailboxes mailbox1 = Mailbox.create( @@ -189,7 +189,7 @@ def test_import(flask_client): mailbox2 = Mailbox.create( user_id=user.id, email="destination2@my-destination-domain.com", verified=True ) - db.session.commit() + Session.commit() alias_data = [ "alias,note,mailboxes", diff --git a/tests/api/test_mailbox.py b/tests/api/test_mailbox.py index 4da34feb..8a28d1ff 100644 --- a/tests/api/test_mailbox.py +++ b/tests/api/test_mailbox.py @@ -1,6 +1,6 @@ from flask import url_for -from app.extensions import db +from app.db import Session from app.models import Mailbox from tests.utils import login @@ -34,7 +34,7 @@ def test_create_mailbox(flask_client): def test_create_mailbox_fail_for_free_user(flask_client): user = login(flask_client) user.trial_end = None - db.session.commit() + Session.commit() r = flask_client.post( "/api/mailboxes", @@ -50,7 +50,7 @@ def test_delete_mailbox(flask_client): # create a mailbox mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") - db.session.commit() + Session.commit() r = flask_client.delete( f"/api/mailboxes/{mb.id}", @@ -88,7 +88,7 @@ def test_set_mailbox_as_default(flask_client): # <<< Cannot set an unverified mailbox as default >>> mb.verified = False - db.session.commit() + Session.commit() r = flask_client.put( f"/api/mailboxes/{mb.id}", @@ -104,7 +104,7 @@ def test_update_mailbox_email(flask_client): # create a mailbox mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") - db.session.commit() + Session.commit() r = flask_client.put( f"/api/mailboxes/{mb.id}", @@ -122,7 +122,7 @@ def test_cancel_mailbox_email_change(flask_client): # create a mailbox mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") - db.session.commit() + Session.commit() # update mailbox email r = flask_client.put( @@ -150,7 +150,7 @@ def test_get_mailboxes(flask_client): Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) Mailbox.create(user_id=user.id, email="m2@example.com", verified=False) - db.session.commit() + Session.commit() r = flask_client.get( "/api/mailboxes", @@ -173,7 +173,7 @@ def test_get_mailboxes_v2(flask_client): Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) Mailbox.create(user_id=user.id, email="m2@example.com", verified=False) - db.session.commit() + Session.commit() r = flask_client.get( "/api/v2/mailboxes", diff --git a/tests/api/test_new_custom_alias.py b/tests/api/test_new_custom_alias.py index 4d9e3fe8..2bc058b4 100644 --- a/tests/api/test_new_custom_alias.py +++ b/tests/api/test_new_custom_alias.py @@ -3,7 +3,7 @@ from flask import g from app.alias_utils import delete_alias from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN from app.dashboard.views.custom_alias import signer -from app.extensions import db +from app.db import Session from app.models import Alias, CustomDomain, Mailbox, AliasUsedOn from app.utils import random_word from tests.utils import login @@ -86,13 +86,13 @@ def test_full_payload(flask_client): # create another mailbox mb = Mailbox.create(user_id=user.id, email="abcd@gmail.com", verified=True) - db.session.commit() + Session.commit() word = random_word() suffix = f".{word}@{EMAIL_DOMAIN}" signed_suffix = signer.sign(suffix).decode() - assert AliasUsedOn.query.count() == 0 + assert AliasUsedOn.count() == 0 r = flask_client.post( "/api/v3/alias/custom/new?hostname=example.com", @@ -146,7 +146,7 @@ def test_custom_domain_alias(flask_client): def test_out_of_quota(flask_client): user = login(flask_client) user.trial_end = None - db.session.commit() + Session.commit() # create MAX_NB_EMAIL_FREE_PLAN custom alias to run out of quota for _ in range(MAX_NB_EMAIL_FREE_PLAN): diff --git a/tests/api/test_new_random_alias.py b/tests/api/test_new_random_alias.py index dc8c367f..c0894a77 100644 --- a/tests/api/test_new_random_alias.py +++ b/tests/api/test_new_random_alias.py @@ -3,7 +3,7 @@ import uuid from flask import url_for, g from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN -from app.extensions import db +from app.db import Session from app.models import Alias from tests.utils import login @@ -60,7 +60,7 @@ def test_custom_mode(flask_client): def test_out_of_quota(flask_client): user = login(flask_client) user.trial_end = None - db.session.commit() + Session.commit() # create MAX_NB_EMAIL_FREE_PLAN random alias to run out of quota for _ in range(MAX_NB_EMAIL_FREE_PLAN): diff --git a/tests/api/test_notification.py b/tests/api/test_notification.py index 362317b3..26f17195 100644 --- a/tests/api/test_notification.py +++ b/tests/api/test_notification.py @@ -1,6 +1,6 @@ from flask import url_for -from app.extensions import db +from app.db import Session from app.models import User, ApiKey, Notification @@ -8,16 +8,16 @@ def test_get_notifications(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() # create some notifications Notification.create(user_id=user.id, message="Test message 1") Notification.create(user_id=user.id, message="Test message 2") - db.session.commit() + Session.commit() r = flask_client.get( url_for("api.get_notifications", page=0), @@ -46,14 +46,14 @@ def test_mark_notification_as_read(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() Notification.create(id=1, user_id=user.id, message="Test message 1") - db.session.commit() + Session.commit() r = flask_client.post( url_for("api.mark_as_read", notification_id=1), diff --git a/tests/api/test_serializer.py b/tests/api/test_serializer.py index 11bc8a11..0adba539 100644 --- a/tests/api/test_serializer.py +++ b/tests/api/test_serializer.py @@ -1,6 +1,6 @@ from app.api.serializer import get_alias_infos_with_pagination_v3 from app.config import PAGE_LIMIT -from app.extensions import db +from app.db import Session from app.models import User, Alias, Mailbox, Contact from tests.utils import create_user @@ -75,7 +75,7 @@ def test_get_alias_infos_with_pagination_v3_query_alias_mailboxes(flask_client): alias = Alias.first() mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") alias._mailboxes.append(mb) - db.session.commit() + Session.commit() alias_infos = get_alias_infos_with_pagination_v3(user, mailbox_id=mb.id) assert len(alias_infos) == 1 @@ -96,7 +96,7 @@ def test_get_alias_infos_with_pagination_v3_query_alias_note(flask_client): alias = Alias.first() alias.note = "test note" - db.session.commit() + Session.commit() alias_infos = get_alias_infos_with_pagination_v3(user, query="test note") assert len(alias_infos) == 1 @@ -114,7 +114,7 @@ def test_get_alias_infos_with_pagination_v3_query_alias_name(flask_client): alias = Alias.first() alias.name = "Test Name" - db.session.commit() + Session.commit() alias_infos = get_alias_infos_with_pagination_v3(user, query="test name") assert len(alias_infos) == 1 @@ -135,7 +135,7 @@ def test_get_alias_infos_with_pagination_v3_no_duplicate(flask_client): alias = Alias.first() mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") alias._mailboxes.append(mb) - db.session.commit() + Session.commit() alias_infos = get_alias_infos_with_pagination_v3(user) assert len(alias_infos) == 1 @@ -182,7 +182,7 @@ def test_get_alias_infos_pinned_alias(flask_client): for i in range(2 * PAGE_LIMIT): Alias.create_new_random(user) - first_alias = Alias.query.order_by(Alias.id).first() + first_alias = Alias.order_by(Alias.id).first() # should return PAGE_LIMIT alias alias_infos = get_alias_infos_with_pagination_v3(user) @@ -192,7 +192,7 @@ def test_get_alias_infos_pinned_alias(flask_client): # pin the first alias first_alias.pinned = True - db.session.commit() + Session.commit() alias_infos = get_alias_infos_with_pagination_v3(user) # now first_alias is the first result diff --git a/tests/api/test_user_info.py b/tests/api/test_user_info.py index 02687513..ed030c2b 100644 --- a/tests/api/test_user_info.py +++ b/tests/api/test_user_info.py @@ -1,6 +1,6 @@ from flask import url_for -from app.extensions import db +from app.db import Session from app.models import User, ApiKey from tests.utils import login @@ -9,11 +9,11 @@ def test_user_in_trial(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # create api_key api_key = ApiKey.create(user.id, "for test") - db.session.commit() + Session.commit() r = flask_client.get( url_for("api.user_info"), headers={"Authentication": api_key.code} @@ -42,7 +42,7 @@ def test_wrong_api_key(flask_client): def test_create_api_key(flask_client): # create user, user is activated User.create(email="a@b.c", password="password", name="Test User", activated=True) - db.session.commit() + Session.commit() # login user flask_client.post( @@ -61,7 +61,7 @@ def test_create_api_key(flask_client): def test_logout(flask_client): # create user, user is activated User.create(email="a@b.c", password="password", name="Test User", activated=True) - db.session.commit() + Session.commit() # login user flask_client.post( diff --git a/tests/auth/test_login.py b/tests/auth/test_login.py index 553f3b54..5c5ae515 100644 --- a/tests/auth/test_login.py +++ b/tests/auth/test_login.py @@ -1,6 +1,6 @@ from flask import url_for -from app.extensions import db +from app.db import Session from app.models import User @@ -9,7 +9,7 @@ def test_unactivated_user_login(flask_client): # create user, user is not activated User.create(email="a@b.c", password="password", name="Test User") - db.session.commit() + Session.commit() r = flask_client.post( url_for("auth.login"), @@ -29,7 +29,7 @@ def test_activated_user_login(flask_client): # create user, user is activated User.create(email="a@b.c", password="password", name="Test User", activated=True) - db.session.commit() + Session.commit() r = flask_client.post( url_for("auth.login"), diff --git a/tests/conftest.py b/tests/conftest.py index 3966ee7c..9f909c1f 100644 --- a/tests/conftest.py +++ b/tests/conftest.py @@ -2,18 +2,20 @@ import os # use the tests/test.env config fle # flake8: noqa: E402 -import sqlalchemy os.environ["CONFIG"] = os.path.abspath( os.path.join(os.path.dirname(os.path.dirname(__file__)), "tests/test.env") ) +import sqlalchemy + +from app.db import Session, engine, connection +from app.models import Base from psycopg2 import errors from psycopg2.errorcodes import DEPENDENT_OBJECTS_STILL_EXIST import pytest -from app.extensions import db from server import create_app from init_app import add_sl_domains @@ -24,7 +26,7 @@ app.config["SERVER_NAME"] = "sl.test" with app.app_context(): # enable pg_trgm extension - with db.engine.connect() as conn: + with engine.connect() as conn: try: conn.execute("DROP EXTENSION if exists pg_trgm") conn.execute("CREATE EXTENSION pg_trgm") @@ -33,7 +35,7 @@ with app.app_context(): print(">>> pg_trgm can't be dropped, ignore") conn.execute("Rollback") - db.create_all() + Base.metadata.create_all(engine) add_sl_domains() @@ -45,20 +47,14 @@ def flask_app(): @pytest.fixture def flask_client(): - with app.app_context(): - # replace db.session to that we can rollback all commits that can be made during a test - # inspired from http://alexmic.net/flask-sqlalchemy-pytest/ - connection = db.engine.connect() - transaction = connection.begin() - options = dict(bind=connection, binds={}) - session = db.create_scoped_session(options=options) - db.session = session + transaction = connection.begin() + with app.app_context(): try: client = app.test_client() yield client finally: # roll back all commits made during a test transaction.rollback() - connection.close() - session.remove() + Session.rollback() + Session.close() diff --git a/tests/dashboard/test_alias_contact_manager.py b/tests/dashboard/test_alias_contact_manager.py index 8d2e30ea..e79801fb 100644 --- a/tests/dashboard/test_alias_contact_manager.py +++ b/tests/dashboard/test_alias_contact_manager.py @@ -11,7 +11,7 @@ def test_add_contact_success(flask_client): login(flask_client) alias = Alias.first() - assert Contact.query.count() == 0 + assert Contact.count() == 0 # <<< Create a new contact >>> flask_client.post( @@ -23,7 +23,7 @@ def test_add_contact_success(flask_client): follow_redirects=True, ) # a new contact is added - assert Contact.query.count() == 1 + assert Contact.count() == 1 contact = Contact.first() assert contact.website_email == "abcd@gmail.com" @@ -37,8 +37,8 @@ def test_add_contact_success(flask_client): follow_redirects=True, ) # a new contact is added - assert Contact.query.count() == 2 - contact = Contact.query.filter(Contact.id != contact.id).first() + assert Contact.count() == 2 + contact = Contact.filter(Contact.id != contact.id).first() assert contact.website_email == "another@gmail.com" assert contact.name == "First Last" @@ -53,5 +53,5 @@ def test_add_contact_success(flask_client): ) # no new contact is added - assert Contact.query.count() == 2 + assert Contact.count() == 2 assert "Invalid email format. Email must be either email@example.com" in str(r.data) diff --git a/tests/dashboard/test_alias_transfer.py b/tests/dashboard/test_alias_transfer.py index a66bcfbe..604c7074 100644 --- a/tests/dashboard/test_alias_transfer.py +++ b/tests/dashboard/test_alias_transfer.py @@ -1,5 +1,5 @@ from app.dashboard.views import alias_transfer -from app.extensions import db +from app.db import Session from app.models import ( Alias, Mailbox, @@ -14,7 +14,7 @@ def test_alias_transfer(flask_client): mb = Mailbox.create(user_id=user.id, email="mb@gmail.com", commit=True) alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True) diff --git a/tests/dashboard/test_custom_alias.py b/tests/dashboard/test_custom_alias.py index 1d669905..8099415a 100644 --- a/tests/dashboard/test_custom_alias.py +++ b/tests/dashboard/test_custom_alias.py @@ -8,7 +8,7 @@ from app.dashboard.views.custom_alias import ( get_available_suffixes, AliasSuffix, ) -from app.extensions import db +from app.db import Session from app.models import ( Mailbox, CustomDomain, @@ -46,13 +46,13 @@ def test_add_alias_success(flask_client): assert r.status_code == 200 assert f"Alias prefix.12345@{EMAIL_DOMAIN} has been created" in str(r.data) - alias = Alias.query.order_by(Alias.created_at.desc()).first() + alias = Alias.order_by(Alias.created_at.desc()).first() assert not alias._mailboxes def test_add_alias_multiple_mailboxes(flask_client): user = login(flask_client) - db.session.commit() + Session.commit() alias_suffix = AliasSuffix( is_custom=False, @@ -64,7 +64,7 @@ def test_add_alias_multiple_mailboxes(flask_client): # create with a multiple mailboxes mb1 = Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) - db.session.commit() + Session.commit() r = flask_client.post( url_for("dashboard.custom_alias"), @@ -78,18 +78,18 @@ def test_add_alias_multiple_mailboxes(flask_client): assert r.status_code == 200 assert f"Alias prefix.12345@{EMAIL_DOMAIN} has been created" in str(r.data) - alias = Alias.query.order_by(Alias.created_at.desc()).first() + alias = Alias.order_by(Alias.created_at.desc()).first() assert alias._mailboxes def test_not_show_unverified_mailbox(flask_client): """make sure user unverified mailbox is not shown to user""" user = login(flask_client) - db.session.commit() + Session.commit() Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) Mailbox.create(user_id=user.id, email="m2@example.com", verified=False) - db.session.commit() + Session.commit() r = flask_client.get(url_for("dashboard.custom_alias")) @@ -99,7 +99,7 @@ def test_not_show_unverified_mailbox(flask_client): def test_verify_prefix_suffix(flask_client): user = login(flask_client) - db.session.commit() + Session.commit() CustomDomain.create(user_id=user.id, domain="test.com", verified=True) @@ -128,7 +128,7 @@ def test_available_suffixes(flask_client): def test_available_suffixes_default_domain(flask_client): user = login(flask_client) - sl_domain = SLDomain.query.first() + sl_domain = SLDomain.first() CustomDomain.create(user_id=user.id, domain="test.com", verified=True, commit=True) user.default_alias_public_domain_id = sl_domain.id @@ -166,7 +166,7 @@ def test_available_suffixes_random_prefix_generation(flask_client): def test_add_already_existed_alias(flask_client): user = login(flask_client) - db.session.commit() + Session.commit() another_user = User.create( email="a2@b.c", @@ -208,7 +208,7 @@ def test_add_already_existed_alias(flask_client): def test_add_alias_in_global_trash(flask_client): user = login(flask_client) - db.session.commit() + Session.commit() another_user = User.create( email="a2@b.c", @@ -233,9 +233,9 @@ def test_add_alias_in_global_trash(flask_client): commit=True, ) - assert DeletedAlias.query.count() == 0 + assert DeletedAlias.count() == 0 delete_alias(alias, another_user) - assert DeletedAlias.query.count() == 1 + assert DeletedAlias.count() == 1 # create the same alias, should return error r = flask_client.post( @@ -267,9 +267,9 @@ def test_add_alias_in_custom_domain_trash(flask_client): commit=True, ) - assert DomainDeletedAlias.query.count() == 0 + assert DomainDeletedAlias.count() == 0 delete_alias(alias, user) - assert DomainDeletedAlias.query.count() == 1 + assert DomainDeletedAlias.count() == 1 # create the same alias, should return error suffix = "@ab.cd" diff --git a/tests/dashboard/test_custom_domain.py b/tests/dashboard/test_custom_domain.py index 30306b28..d024a8cb 100644 --- a/tests/dashboard/test_custom_domain.py +++ b/tests/dashboard/test_custom_domain.py @@ -1,13 +1,13 @@ from flask import url_for -from app.extensions import db +from app.db import Session from tests.utils import login def test_add_domain_success(flask_client): user = login(flask_client) user.lifetime = True - db.session.commit() + Session.commit() r = flask_client.post( url_for("dashboard.custom_domain"), @@ -23,7 +23,7 @@ def test_add_domain_same_as_user_email(flask_client): """cannot add domain if user personal email uses this domain""" user = login(flask_client) user.lifetime = True - db.session.commit() + Session.commit() r = flask_client.post( url_for("dashboard.custom_domain"), diff --git a/tests/dashboard/test_index.py b/tests/dashboard/test_index.py index ade7a6c4..01b20a22 100644 --- a/tests/dashboard/test_index.py +++ b/tests/dashboard/test_index.py @@ -8,7 +8,7 @@ from tests.utils import login def test_create_random_alias_success(flask_client): login(flask_client) - assert Alias.query.count() == 1 + assert Alias.count() == 1 r = flask_client.post( url_for("dashboard.index"), @@ -16,7 +16,7 @@ def test_create_random_alias_success(flask_client): follow_redirects=True, ) assert r.status_code == 200 - assert Alias.query.count() == 2 + assert Alias.count() == 2 def test_too_many_requests(flask_client): diff --git a/tests/email/test_rate_limit.py b/tests/email/test_rate_limit.py index 851c2d40..30722e25 100644 --- a/tests/email/test_rate_limit.py +++ b/tests/email/test_rate_limit.py @@ -2,7 +2,7 @@ from app.config import ( MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS, MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX, ) -from app.extensions import db +from app.db import Session from app.email.rate_limit import ( rate_limited_forward_phase, rate_limited_for_alias, @@ -16,11 +16,11 @@ def test_rate_limited_forward_phase_for_alias(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() # no rate limiting for a new alias alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() assert not rate_limited_for_alias(alias) # rate limit when there's a previous activity on alias @@ -30,14 +30,14 @@ def test_rate_limited_forward_phase_for_alias(flask_client): website_email="contact@example.com", reply_email="rep@sl.local", ) - db.session.commit() + Session.commit() for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1): EmailLog.create( user_id=user.id, contact_id=contact.id, alias_id=contact.alias_id, ) - db.session.commit() + Session.commit() assert rate_limited_for_alias(alias) @@ -46,10 +46,10 @@ def test_rate_limited_forward_phase_for_mailbox(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() contact = Contact.create( user_id=user.id, @@ -57,14 +57,14 @@ def test_rate_limited_forward_phase_for_mailbox(flask_client): website_email="contact@example.com", reply_email="rep@sl.local", ) - db.session.commit() + Session.commit() for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX + 1): EmailLog.create( user_id=user.id, contact_id=contact.id, alias_id=contact.alias_id, ) - db.session.commit() + Session.commit() EmailLog.create( user_id=user.id, @@ -75,7 +75,7 @@ def test_rate_limited_forward_phase_for_mailbox(flask_client): # Create another alias with the same mailbox # will be rate limited as there's a previous activity on mailbox alias2 = Alias.create_new_random(user) - db.session.commit() + Session.commit() assert rate_limited_for_mailbox(alias2) @@ -91,10 +91,10 @@ def test_rate_limited_reply_phase(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() contact = Contact.create( user_id=user.id, @@ -102,13 +102,13 @@ def test_rate_limited_reply_phase(flask_client): website_email="contact@example.com", reply_email="rep@sl.local", ) - db.session.commit() + Session.commit() for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1): EmailLog.create( user_id=user.id, contact_id=contact.id, alias_id=contact.alias_id, ) - db.session.commit() + Session.commit() assert rate_limited_reply_phase("rep@sl.local") diff --git a/tests/oauth/test_authorize.py b/tests/oauth/test_authorize.py index dcf44bf1..82f1e7e0 100644 --- a/tests/oauth/test_authorize.py +++ b/tests/oauth/test_authorize.py @@ -4,7 +4,7 @@ from urllib.parse import urlparse, parse_qs from flask import url_for -from app.extensions import db +from app.db import Session from app.jose_utils import verify_id_token, decode_id_token from app.models import Client, User, ClientUser from app.oauth.views.authorize import ( @@ -39,10 +39,10 @@ def test_construct_url(): def test_authorize_page_non_login_user(flask_client): """make sure to display login page for non-authenticated user""" user = User.create("test@test.com", "test user") - db.session.commit() + Session.commit() client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() r = flask_client.get( url_for( @@ -63,7 +63,7 @@ def test_authorize_page_login_user_non_supported_flow(flask_client): """return 400 if the flow is not supported""" user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() # Not provide any flow r = flask_client.get( @@ -102,7 +102,7 @@ def test_authorize_page_login_user(flask_client): user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() r = flask_client.get( url_for( @@ -128,7 +128,7 @@ def test_authorize_code_flow_no_openid_scope(flask_client): user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() # user allows client on the authorization page r = flask_client.post( @@ -217,7 +217,7 @@ def test_authorize_code_flow_with_openid_scope(flask_client): user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() # user allows client on the authorization page r = flask_client.post( @@ -310,7 +310,7 @@ def test_authorize_token_flow(flask_client): user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() # user allows client on the authorization page r = flask_client.post( @@ -357,7 +357,7 @@ def test_authorize_id_token_flow(flask_client): user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() # user allows client on the authorization page r = flask_client.post( @@ -406,7 +406,7 @@ def test_authorize_token_id_token_flow(flask_client): user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() # user allows client on the authorization page r = flask_client.post( @@ -496,7 +496,7 @@ def test_authorize_code_id_token_flow(flask_client): user = login(flask_client) client = Client.create_new("test client", user.id) - db.session.commit() + Session.commit() # user allows client on the authorization page r = flask_client.post( @@ -629,7 +629,7 @@ def test_authorize_page_invalid_client_id(flask_client): user = login(flask_client) Client.create_new("test client", user.id) - db.session.commit() + Session.commit() r = flask_client.get( url_for( @@ -654,7 +654,7 @@ def test_authorize_page_http_not_allowed(flask_client): client = Client.create_new("test client", user.id) client.approved = True - db.session.commit() + Session.commit() r = flask_client.get( url_for( @@ -676,7 +676,7 @@ def test_authorize_page_unknown_redirect_uri(flask_client): client = Client.create_new("test client", user.id) client.approved = True - db.session.commit() + Session.commit() r = flask_client.get( url_for( diff --git a/tests/test_alias_utils.py b/tests/test_alias_utils.py index e2281539..a9d48aff 100644 --- a/tests/test_alias_utils.py +++ b/tests/test_alias_utils.py @@ -1,5 +1,5 @@ from app.alias_utils import delete_alias, check_alias_prefix -from app.extensions import db +from app.db import Session from app.models import User, Alias, DeletedAlias @@ -41,8 +41,8 @@ def test_delete_alias_already_in_trash(flask_client): ) # add the alias to global trash - db.session.add(DeletedAlias(email=alias.email)) - db.session.commit() + Session.add(DeletedAlias(email=alias.email)) + Session.commit() delete_alias(alias, user) assert Alias.get_by(email="first@d1.test") is None diff --git a/tests/test_config.py b/tests/test_config.py index 21c7ebab..99c31c33 100644 --- a/tests/test_config.py +++ b/tests/test_config.py @@ -1,4 +1,5 @@ import pytest + from app.config import sl_getenv diff --git a/tests/test_email_utils.py b/tests/test_email_utils.py index 3395bcca..66f3e9b0 100644 --- a/tests/test_email_utils.py +++ b/tests/test_email_utils.py @@ -4,6 +4,7 @@ from email.message import EmailMessage import arrow from app.config import MAX_ALERT_24H, EMAIL_DOMAIN, BOUNCE_EMAIL +from app.db import Session from app.email_utils import ( get_email_domain_part, can_create_directory_for_address, @@ -31,7 +32,6 @@ from app.email_utils import ( get_header_unicode, parse_full_address, ) -from app.extensions import db from app.models import User, CustomDomain, Alias, Contact, EmailLog, IgnoreBounceSender # flake8: noqa: E101, W191 @@ -136,7 +136,7 @@ def test_send_email_with_rate_control(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() for _ in range(MAX_ALERT_24H): assert send_email_with_rate_control( @@ -598,7 +598,7 @@ def test_should_disable(flask_client): include_sender_in_reverse_alias=True, ) alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() assert not should_disable(alias) @@ -623,7 +623,7 @@ def test_should_disable(flask_client): # should not affect another alias alias2 = Alias.create_new_random(user) - db.session.commit() + Session.commit() assert not should_disable(alias2) @@ -631,7 +631,7 @@ def test_should_disable_bounces_every_day(flask_client): """if an alias has bounces every day at least 9 days in the last 10 days, disable alias""" user = login(flask_client) alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() assert not should_disable(alias) @@ -661,7 +661,7 @@ def test_should_disable_bounces_account(flask_client): user = login(flask_client) alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() # create a lot of bounces on alias contact = Contact.create( @@ -690,7 +690,7 @@ def test_should_disable_bounces_account(flask_client): def test_should_disable_bounce_consecutive_days(flask_client): user = login(flask_client) alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() contact = Contact.create( user_id=user.id, diff --git a/tests/test_jose_utils.py b/tests/test_jose_utils.py index 557beb66..5e4387fb 100644 --- a/tests/test_jose_utils.py +++ b/tests/test_jose_utils.py @@ -1,4 +1,4 @@ -from app.extensions import db +from app.db import Session from app.jose_utils import make_id_token, verify_id_token from app.models import ClientUser, User, Client @@ -7,15 +7,15 @@ def test_encode_decode(flask_client): user = User.create( email="a@b.c", password="password", name="Test User", activated=True ) - db.session.commit() + Session.commit() client1 = Client.create_new(name="Demo", user_id=user.id) client1.oauth_client_id = "client-id" client1.oauth_client_secret = "client-secret" - db.session.commit() + Session.commit() client_user = ClientUser.create(client_id=client1.id, user_id=user.id) - db.session.commit() + Session.commit() jwt_token = make_id_token(client_user) diff --git a/tests/test_models.py b/tests/test_models.py index 6bcf4ca9..81fff9f2 100644 --- a/tests/test_models.py +++ b/tests/test_models.py @@ -3,8 +3,8 @@ from uuid import UUID import pytest from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN +from app.db import Session from app.email_utils import parse_full_address -from app.extensions import db from app.models import ( generate_email, User, @@ -53,7 +53,7 @@ def test_suggested_emails_for_user_who_cannot_create_new_alias(flask_client): # make sure user runs out of quota to create new email for i in range(MAX_NB_EMAIL_FREE_PLAN): Alias.create_new(user=user, prefix="test") - db.session.commit() + Session.commit() suggested_email, other_emails = user.suggested_emails(website_name="test") @@ -88,7 +88,7 @@ def test_website_send_to(flask_client): ) alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() # non-empty name c1 = Contact.create( @@ -122,7 +122,7 @@ def test_new_addr(flask_client): ) alias = Alias.create_new_random(user) - db.session.commit() + Session.commit() # default sender_format is 'via' c1 = Contact.create( @@ -137,18 +137,18 @@ def test_new_addr(flask_client): # Make sure email isn't duplicated if sender name equals email c1.name = "abcd@example.com" - db.session.commit() + Session.commit() assert c1.new_addr() == '"abcd(a)example.com" ' # set sender_format = AT user.sender_format = SenderFormatEnum.AT.value c1.name = "First Last" - db.session.commit() + Session.commit() assert c1.new_addr() == '"First Last - abcd at example.com" ' # unicode name c1.name = "Nhơn Nguyễn" - db.session.commit() + Session.commit() assert ( c1.new_addr() == "=?utf-8?q?Nh=C6=A1n_Nguy=E1=BB=85n_-_abcd_at_example=2Ecom?= " @@ -182,11 +182,11 @@ def test_mailbox_delete(flask_client): # alias has 2 mailboxes alias = Alias.create_new(user, "prefix", mailbox_id=m1.id) - db.session.commit() + Session.commit() alias._mailboxes.append(m2) alias._mailboxes.append(m3) - db.session.commit() + Session.commit() assert len(alias.mailboxes) == 3 diff --git a/tests/test_server.py b/tests/test_server.py index ed362342..ee1b1301 100644 --- a/tests/test_server.py +++ b/tests/test_server.py @@ -1,6 +1,6 @@ import arrow -from app.extensions import db +from app.db import Session from app.models import User, CoinbaseSubscription from server import handle_coinbase_event @@ -44,7 +44,7 @@ def test_handle_coinbase_event_extend_subscription(flask_client): activated=True, ) user.trial_end = None - db.session.commit() + Session.commit() cb = CoinbaseSubscription.create( user_id=user.id, end_at=arrow.now().shift(days=-400), commit=True