do not use flask-sqlalchemy

- add __tablename__ for all models
- use sa and orm instead of db
- rollback all changes in tests
- remove session in @app.teardown_appcontext
This commit is contained in:
Son 2021-10-12 14:36:47 +02:00
parent 653a03ac11
commit 372466ab06
98 changed files with 1338 additions and 1234 deletions

View File

@ -5,7 +5,7 @@ from flask_admin.actions import action
from flask_admin.contrib import sqla from flask_admin.contrib import sqla
from flask_login import current_user, login_user from flask_login import current_user, login_user
from app.extensions import db from app.db import Session
from app.models import User, ManualSubscription from app.models import User, ManualSubscription
@ -99,7 +99,7 @@ class UserAdmin(SLModelView):
"Extend trial for 1 week more?", "Extend trial for 1 week more?",
) )
def extend_trial_1w(self, ids): def extend_trial_1w(self, ids):
for user in User.query.filter(User.id.in_(ids)): for user in User.filter(User.id.in_(ids)):
if user.trial_end and user.trial_end > arrow.now(): if user.trial_end and user.trial_end > arrow.now():
user.trial_end = user.trial_end.shift(weeks=1) user.trial_end = user.trial_end.shift(weeks=1)
else: else:
@ -107,7 +107,7 @@ class UserAdmin(SLModelView):
flash(f"Extend trial for {user} to {user.trial_end}", "success") flash(f"Extend trial for {user} to {user.trial_end}", "success")
db.session.commit() Session.commit()
@action( @action(
"disable_otp", "disable_otp",
@ -115,12 +115,12 @@ class UserAdmin(SLModelView):
"Disable OTP?", "Disable OTP?",
) )
def disable_otp(self, ids): def disable_otp(self, ids):
for user in User.query.filter(User.id.in_(ids)): for user in User.filter(User.id.in_(ids)):
if user.enable_otp: if user.enable_otp:
user.enable_otp = False user.enable_otp = False
flash(f"Disable OTP for {user}", "info") flash(f"Disable OTP for {user}", "info")
db.session.commit() Session.commit()
@action( @action(
"login_as", "login_as",
@ -132,16 +132,14 @@ class UserAdmin(SLModelView):
flash("only 1 user can be selected", "error") flash("only 1 user can be selected", "error")
return return
for user in User.query.filter(User.id.in_(ids)): for user in User.filter(User.id.in_(ids)):
login_user(user) login_user(user)
flash(f"Login as user {user}", "success") flash(f"Login as user {user}", "success")
return redirect("/") return redirect("/")
def manual_upgrade(way: str, ids: [int], is_giveaway: bool): def manual_upgrade(way: str, ids: [int], is_giveaway: bool):
query = User.query.filter(User.id.in_(ids)) for user in User.filter(User.id.in_(ids)).all():
for user in query.all():
manual_sub: ManualSubscription = ManualSubscription.get_by(user_id=user.id) manual_sub: ManualSubscription = ManualSubscription.get_by(user_id=user.id)
if manual_sub: if manual_sub:
# renew existing subscription # renew existing subscription
@ -149,7 +147,7 @@ def manual_upgrade(way: str, ids: [int], is_giveaway: bool):
manual_sub.end_at = manual_sub.end_at.shift(years=1) manual_sub.end_at = manual_sub.end_at.shift(years=1)
else: else:
manual_sub.end_at = arrow.now().shift(years=1, days=1) manual_sub.end_at = arrow.now().shift(years=1, days=1)
db.session.commit() Session.commit()
flash(f"Subscription extended to {manual_sub.end_at.humanize()}", "success") flash(f"Subscription extended to {manual_sub.end_at.humanize()}", "success")
continue continue
@ -211,11 +209,11 @@ class ManualSubscriptionAdmin(SLModelView):
"Extend 1 year more?", "Extend 1 year more?",
) )
def extend_1y(self, ids): def extend_1y(self, ids):
for ms in ManualSubscription.query.filter(ManualSubscription.id.in_(ids)): for ms in ManualSubscription.filter(ManualSubscription.id.in_(ids)):
ms.end_at = ms.end_at.shift(years=1) ms.end_at = ms.end_at.shift(years=1)
flash(f"Extend subscription for {ms.user}", "success") flash(f"Extend subscription for {ms.user}", "success")
db.session.commit() Session.commit()
class ClientAdmin(SLModelView): class ClientAdmin(SLModelView):

View File

@ -1,10 +1,11 @@
import re2 as re
from typing import Optional from typing import Optional
import re2 as re
from email_validator import validate_email, EmailNotValidError from email_validator import validate_email, EmailNotValidError
from sqlalchemy.exc import IntegrityError, DataError from sqlalchemy.exc import IntegrityError, DataError
from app.config import BOUNCE_PREFIX_FOR_REPLY_PHASE from app.config import BOUNCE_PREFIX_FOR_REPLY_PHASE
from app.db import Session
from app.email_utils import ( from app.email_utils import (
get_email_domain_part, get_email_domain_part,
send_cannot_create_directory_alias, send_cannot_create_directory_alias,
@ -14,7 +15,6 @@ from app.email_utils import (
get_email_local_part, get_email_local_part,
) )
from app.errors import AliasInTrashError from app.errors import AliasInTrashError
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
@ -97,14 +97,14 @@ def try_auto_create_directory(address: str) -> Optional[Alias]:
mailbox_id=mailboxes[0].id, mailbox_id=mailboxes[0].id,
note=f"Created by directory {directory.name}", note=f"Created by directory {directory.name}",
) )
db.session.flush() Session.flush()
for i in range(1, len(mailboxes)): for i in range(1, len(mailboxes)):
AliasMailbox.create( AliasMailbox.create(
alias_id=alias.id, alias_id=alias.id,
mailbox_id=mailboxes[i].id, mailbox_id=mailboxes[i].id,
) )
db.session.commit() Session.commit()
return alias return alias
except AliasInTrashError: except AliasInTrashError:
LOG.w( LOG.w(
@ -116,7 +116,7 @@ def try_auto_create_directory(address: str) -> Optional[Alias]:
return None return None
except IntegrityError: except IntegrityError:
LOG.w("Alias %s already exists", address) LOG.w("Alias %s already exists", address)
db.session.rollback() Session.rollback()
alias = Alias.get_by(email=address) alias = Alias.get_by(email=address)
return alias return alias
@ -173,13 +173,13 @@ def try_auto_create_via_domain(address: str) -> Optional[Alias]:
mailbox_id=mailboxes[0].id, mailbox_id=mailboxes[0].id,
note=alias_note, note=alias_note,
) )
db.session.flush() Session.flush()
for i in range(1, len(mailboxes)): for i in range(1, len(mailboxes)):
AliasMailbox.create( AliasMailbox.create(
alias_id=alias.id, alias_id=alias.id,
mailbox_id=mailboxes[i].id, mailbox_id=mailboxes[i].id,
) )
db.session.commit() Session.commit()
return alias return alias
except AliasInTrashError: except AliasInTrashError:
LOG.w( LOG.w(
@ -191,12 +191,12 @@ def try_auto_create_via_domain(address: str) -> Optional[Alias]:
return None return None
except IntegrityError: except IntegrityError:
LOG.w("Alias %s already exists", address) LOG.w("Alias %s already exists", address)
db.session.rollback() Session.rollback()
alias = Alias.get_by(email=address) alias = Alias.get_by(email=address)
return alias return alias
except DataError: except DataError:
LOG.w("Cannot create alias %s", address) LOG.w("Cannot create alias %s", address)
db.session.rollback() Session.rollback()
return None return None
@ -211,30 +211,30 @@ def delete_alias(alias: Alias, user: User):
email=alias.email, domain_id=alias.custom_domain_id email=alias.email, domain_id=alias.custom_domain_id
): ):
LOG.d("add %s to domain %s trash", alias, alias.custom_domain_id) LOG.d("add %s to domain %s trash", alias, alias.custom_domain_id)
db.session.add( Session.add(
DomainDeletedAlias( DomainDeletedAlias(
user_id=user.id, email=alias.email, domain_id=alias.custom_domain_id user_id=user.id, email=alias.email, domain_id=alias.custom_domain_id
) )
) )
db.session.commit() Session.commit()
else: else:
if not DeletedAlias.get_by(email=alias.email): if not DeletedAlias.get_by(email=alias.email):
LOG.d("add %s to global trash", alias) LOG.d("add %s to global trash", alias)
db.session.add(DeletedAlias(email=alias.email)) Session.add(DeletedAlias(email=alias.email))
db.session.commit() Session.commit()
Alias.query.filter(Alias.id == alias.id).delete() Alias.filter(Alias.id == alias.id).delete()
db.session.commit() Session.commit()
def aliases_for_mailbox(mailbox: Mailbox) -> [Alias]: def aliases_for_mailbox(mailbox: Mailbox) -> [Alias]:
""" """
get list of aliases for a given mailbox get list of aliases for a given mailbox
""" """
ret = set(Alias.query.filter(Alias.mailbox_id == mailbox.id).all()) ret = set(Alias.filter(Alias.mailbox_id == mailbox.id).all())
for alias in ( for alias in (
db.session.query(Alias) Session.query(Alias)
.join(AliasMailbox, Alias.id == AliasMailbox.alias_id) .join(AliasMailbox, Alias.id == AliasMailbox.alias_id)
.filter(AliasMailbox.mailbox_id == mailbox.id) .filter(AliasMailbox.mailbox_id == mailbox.id)
): ):
@ -247,7 +247,7 @@ def nb_email_log_for_mailbox(mailbox: Mailbox):
aliases = aliases_for_mailbox(mailbox) aliases = aliases_for_mailbox(mailbox)
alias_ids = [alias.id for alias in aliases] alias_ids = [alias.id for alias in aliases]
return ( return (
db.session.query(EmailLog) Session.query(EmailLog)
.join(Contact, EmailLog.contact_id == Contact.id) .join(Contact, EmailLog.contact_id == Contact.id)
.filter(Contact.alias_id.in_(alias_ids)) .filter(Contact.alias_id.in_(alias_ids))
.count() .count()

View File

@ -4,7 +4,7 @@ import arrow
from flask import Blueprint, request, jsonify, g from flask import Blueprint, request, jsonify, g
from flask_login import current_user from flask_login import current_user
from app.extensions import db from app.db import Session
from app.models import ApiKey from app.models import ApiKey
api_bp = Blueprint(name="api", import_name=__name__, url_prefix="/api") api_bp = Blueprint(name="api", import_name=__name__, url_prefix="/api")
@ -26,7 +26,7 @@ def require_api_auth(f):
# Update api key stats # Update api key stats
api_key.last_used = arrow.now() api_key.last_used = arrow.now()
api_key.times += 1 api_key.times += 1
db.session.commit() Session.commit()
g.user = api_key.user g.user = api_key.user

View File

@ -6,7 +6,7 @@ from sqlalchemy import or_, func, case, and_
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.config import PAGE_LIMIT from app.config import PAGE_LIMIT
from app.extensions import db from app.db import Session
from app.models import ( from app.models import (
Alias, Alias,
Contact, Contact,
@ -117,7 +117,7 @@ def serialize_contact(contact: Contact, existed=False) -> dict:
def get_alias_infos_with_pagination(user, page_id=0, query=None) -> [AliasInfo]: def get_alias_infos_with_pagination(user, page_id=0, query=None) -> [AliasInfo]:
ret = [] ret = []
q = ( q = (
db.session.query(Alias) Session.query(Alias)
.options(joinedload(Alias.mailbox)) .options(joinedload(Alias.mailbox))
.filter(Alias.user_id == user.id) .filter(Alias.user_id == user.id)
.order_by(Alias.created_at.desc()) .order_by(Alias.created_at.desc())
@ -221,7 +221,7 @@ def get_alias_infos_with_pagination_v3(
def get_alias_info(alias: Alias) -> AliasInfo: def get_alias_info(alias: Alias) -> AliasInfo:
q = ( q = (
db.session.query(Contact, EmailLog) Session.query(Contact, EmailLog)
.filter(Contact.alias_id == alias.id) .filter(Contact.alias_id == alias.id)
.filter(EmailLog.contact_id == Contact.id) .filter(EmailLog.contact_id == Contact.id)
) )
@ -251,7 +251,7 @@ def get_alias_info_v2(alias: Alias, mailbox=None) -> AliasInfo:
mailbox = alias.mailbox mailbox = alias.mailbox
q = ( q = (
db.session.query(Contact, EmailLog) Session.query(Contact, EmailLog)
.filter(Contact.alias_id == alias.id) .filter(Contact.alias_id == alias.id)
.filter(EmailLog.contact_id == Contact.id) .filter(EmailLog.contact_id == Contact.id)
) )
@ -297,7 +297,7 @@ def get_alias_info_v2(alias: Alias, mailbox=None) -> AliasInfo:
def get_alias_contacts(alias, page_id: int) -> [dict]: def get_alias_contacts(alias, page_id: int) -> [dict]:
q = ( q = (
Contact.query.filter_by(alias_id=alias.id) Contact.filter_by(alias_id=alias.id)
.order_by(Contact.id.desc()) .order_by(Contact.id.desc())
.limit(PAGE_LIMIT) .limit(PAGE_LIMIT)
.offset(page_id * PAGE_LIMIT) .offset(page_id * PAGE_LIMIT)
@ -332,7 +332,7 @@ def get_alias_info_v3(user: User, alias_id: int) -> AliasInfo:
def construct_alias_query(user: User): def construct_alias_query(user: User):
# subquery on alias annotated with nb_reply, nb_blocked, nb_forward, max_created_at, latest_email_log_created_at # subquery on alias annotated with nb_reply, nb_blocked, nb_forward, max_created_at, latest_email_log_created_at
alias_activity_subquery = ( alias_activity_subquery = (
db.session.query( Session.query(
Alias.id, Alias.id,
func.sum(case([(EmailLog.is_reply, 1)], else_=0)).label("nb_reply"), func.sum(case([(EmailLog.is_reply, 1)], else_=0)).label("nb_reply"),
func.sum( func.sum(
@ -364,7 +364,7 @@ def construct_alias_query(user: User):
) )
alias_contact_subquery = ( alias_contact_subquery = (
db.session.query(Alias.id, func.max(Contact.id).label("max_contact_id")) Session.query(Alias.id, func.max(Contact.id).label("max_contact_id"))
.join(Contact, Alias.id == Contact.alias_id, isouter=True) .join(Contact, Alias.id == Contact.alias_id, isouter=True)
.filter(Alias.user_id == user.id) .filter(Alias.user_id == user.id)
.group_by(Alias.id) .group_by(Alias.id)
@ -372,7 +372,7 @@ def construct_alias_query(user: User):
) )
return ( return (
db.session.query( Session.query(
Alias, Alias,
Contact, Contact,
EmailLog, EmailLog,

View File

@ -17,10 +17,10 @@ from app.api.serializer import (
get_alias_infos_with_pagination_v3, get_alias_infos_with_pagination_v3,
) )
from app.dashboard.views.alias_log import get_alias_log from app.dashboard.views.alias_log import get_alias_log
from app.db import Session
from app.email_utils import ( from app.email_utils import (
generate_reply_email, generate_reply_email,
) )
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, Mailbox, AliasMailbox from app.models import Alias, Contact, Mailbox, AliasMailbox
from app.utils import sanitize_email from app.utils import sanitize_email
@ -164,7 +164,7 @@ def toggle_alias(alias_id):
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
alias.enabled = not alias.enabled alias.enabled = not alias.enabled
db.session.commit() Session.commit()
return jsonify(enabled=alias.enabled), 200 return jsonify(enabled=alias.enabled), 200
@ -280,8 +280,8 @@ def update_alias(alias_id):
# <<< update alias mailboxes >>> # <<< update alias mailboxes >>>
# first remove all existing alias-mailboxes links # first remove all existing alias-mailboxes links
AliasMailbox.query.filter_by(alias_id=alias.id).delete() AliasMailbox.filter_by(alias_id=alias.id).delete()
db.session.flush() Session.flush()
# then add all new mailboxes # then add all new mailboxes
for i, mailbox in enumerate(mailboxes): for i, mailbox in enumerate(mailboxes):
@ -310,7 +310,7 @@ def update_alias(alias_id):
changed = True changed = True
if changed: if changed:
db.session.commit() Session.commit()
return jsonify(ok=True), 200 return jsonify(ok=True), 200
@ -422,7 +422,7 @@ def create_contact_route(alias_id):
) )
LOG.d("create reverse-alias for %s %s", contact_addr, alias) LOG.d("create reverse-alias for %s %s", contact_addr, alias)
db.session.commit() Session.commit()
return jsonify(**serialize_contact(contact)), 201 return jsonify(**serialize_contact(contact)), 201
@ -444,6 +444,6 @@ def delete_contact(contact_id):
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
Contact.delete(contact_id) Contact.delete(contact_id)
db.session.commit() Session.commit()
return jsonify(deleted=True), 200 return jsonify(deleted=True), 200

View File

@ -5,7 +5,7 @@ from app.api.base import api_bp, require_api_auth
from app.dashboard.views.custom_alias import ( from app.dashboard.views.custom_alias import (
get_available_suffixes, get_available_suffixes,
) )
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import AliasUsedOn, Alias, User from app.models import AliasUsedOn, Alias, User
from app.utils import convert_to_id from app.utils import convert_to_id
@ -43,7 +43,7 @@ def options_v4():
if hostname: if hostname:
# put the latest used alias first # put the latest used alias first
q = ( q = (
db.session.query(AliasUsedOn, Alias, User) Session.query(AliasUsedOn, Alias, User)
.filter( .filter(
AliasUsedOn.alias_id == Alias.id, AliasUsedOn.alias_id == Alias.id,
Alias.user_id == user.id, Alias.user_id == user.id,
@ -114,7 +114,7 @@ def options_v5():
if hostname: if hostname:
# put the latest used alias first # put the latest used alias first
q = ( q = (
db.session.query(AliasUsedOn, Alias, User) Session.query(AliasUsedOn, Alias, User)
.filter( .filter(
AliasUsedOn.alias_id == Alias.id, AliasUsedOn.alias_id == Alias.id,
Alias.user_id == user.id, Alias.user_id == user.id,

View File

@ -9,7 +9,7 @@ from requests import RequestException
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import APPLE_API_SECRET, MACAPP_APPLE_API_SECRET from app.config import APPLE_API_SECRET, MACAPP_APPLE_API_SECRET
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import PlanEnum, AppleSubscription from app.models import PlanEnum, AppleSubscription
@ -279,7 +279,7 @@ def apple_update_notification():
apple_sub.receipt_data = data["unified_receipt"]["latest_receipt"] apple_sub.receipt_data = data["unified_receipt"]["latest_receipt"]
apple_sub.expires_date = expires_date apple_sub.expires_date = expires_date
apple_sub.plan = plan apple_sub.plan = plan
db.session.commit() Session.commit()
return jsonify(ok=True), 200 return jsonify(ok=True), 200
else: else:
LOG.w( LOG.w(
@ -544,6 +544,6 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
plan=plan, plan=plan,
) )
db.session.commit() Session.commit()
return apple_sub return apple_sub

View File

@ -11,13 +11,14 @@ from app import email_utils
from app.api.base import api_bp from app.api.base import api_bp
from app.config import FLASK_SECRET, DISABLE_REGISTRATION from app.config import FLASK_SECRET, DISABLE_REGISTRATION
from app.dashboard.views.setting import send_reset_password_email from app.dashboard.views.setting import send_reset_password_email
from app.db import Session
from app.email_utils import ( from app.email_utils import (
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
personal_email_already_used, personal_email_already_used,
send_email, send_email,
render, render,
) )
from app.extensions import db, limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User, ApiKey, SocialAuth, AccountActivation from app.models import User, ApiKey, SocialAuth, AccountActivation
from app.utils import sanitize_email from app.utils import sanitize_email
@ -98,12 +99,12 @@ def auth_register():
LOG.d("create user %s", email) LOG.d("create user %s", email)
user = User.create(email=email, name="", password=password) user = User.create(email=email, name="", password=password)
db.session.flush() Session.flush()
# create activation code # create activation code
code = "".join([str(random.randint(0, 9)) for _ in range(6)]) code = "".join([str(random.randint(0, 9)) for _ in range(6)])
AccountActivation.create(user_id=user.id, code=code) AccountActivation.create(user_id=user.id, code=code)
db.session.commit() Session.commit()
send_email( send_email(
email, email,
@ -155,13 +156,13 @@ def auth_activate():
if account_activation.code != code: if account_activation.code != code:
# decrement nb tries # decrement nb tries
account_activation.tries -= 1 account_activation.tries -= 1
db.session.commit() Session.commit()
# Trigger rate limiter # Trigger rate limiter
g.deduct_limit = True g.deduct_limit = True
if account_activation.tries == 0: if account_activation.tries == 0:
AccountActivation.delete(account_activation.id) AccountActivation.delete(account_activation.id)
db.session.commit() Session.commit()
return jsonify(error="Too many wrong tries"), 410 return jsonify(error="Too many wrong tries"), 410
return jsonify(error="Wrong email or code"), 400 return jsonify(error="Wrong email or code"), 400
@ -169,7 +170,7 @@ def auth_activate():
LOG.d("activate user %s", user) LOG.d("activate user %s", user)
user.activated = True user.activated = True
AccountActivation.delete(account_activation.id) AccountActivation.delete(account_activation.id)
db.session.commit() Session.commit()
return jsonify(msg="Account is activated, user can login now"), 200 return jsonify(msg="Account is activated, user can login now"), 200
@ -198,12 +199,12 @@ def auth_reactivate():
account_activation = AccountActivation.get_by(user_id=user.id) account_activation = AccountActivation.get_by(user_id=user.id)
if account_activation: if account_activation:
AccountActivation.delete(account_activation.id) AccountActivation.delete(account_activation.id)
db.session.commit() Session.commit()
# create activation code # create activation code
code = "".join([str(random.randint(0, 9)) for _ in range(6)]) code = "".join([str(random.randint(0, 9)) for _ in range(6)])
AccountActivation.create(user_id=user.id, code=code) AccountActivation.create(user_id=user.id, code=code)
db.session.commit() Session.commit()
send_email( send_email(
email, email,
@ -255,12 +256,12 @@ def auth_facebook():
LOG.d("create facebook user with %s", user_info) LOG.d("create facebook user with %s", user_info)
user = User.create(email=email, name=user_info["name"], activated=True) user = User.create(email=email, name=user_info["name"], activated=True)
db.session.commit() Session.commit()
email_utils.send_welcome_email(user) email_utils.send_welcome_email(user)
if not SocialAuth.get_by(user_id=user.id, social="facebook"): if not SocialAuth.get_by(user_id=user.id, social="facebook"):
SocialAuth.create(user_id=user.id, social="facebook") SocialAuth.create(user_id=user.id, social="facebook")
db.session.commit() Session.commit()
return jsonify(**auth_payload(user, device)), 200 return jsonify(**auth_payload(user, device)), 200
@ -308,12 +309,12 @@ def auth_google():
LOG.d("create Google user with %s", user_info) LOG.d("create Google user with %s", user_info)
user = User.create(email=email, name="", activated=True) user = User.create(email=email, name="", activated=True)
db.session.commit() Session.commit()
email_utils.send_welcome_email(user) email_utils.send_welcome_email(user)
if not SocialAuth.get_by(user_id=user.id, social="google"): if not SocialAuth.get_by(user_id=user.id, social="google"):
SocialAuth.create(user_id=user.id, social="google") SocialAuth.create(user_id=user.id, social="google")
db.session.commit() Session.commit()
return jsonify(**auth_payload(user, device)), 200 return jsonify(**auth_payload(user, device)), 200
@ -331,7 +332,7 @@ def auth_payload(user, device) -> dict:
if not api_key: if not api_key:
LOG.d("create new api key for %s and %s", user, device) LOG.d("create new api key for %s and %s", user, device)
api_key = ApiKey.create(user.id, device) api_key = ApiKey.create(user.id, device)
db.session.commit() Session.commit()
ret["mfa_key"] = None ret["mfa_key"] = None
ret["api_key"] = api_key.code ret["api_key"] = api_key.code

View File

@ -5,7 +5,7 @@ from itsdangerous import Signer
from app.api.base import api_bp from app.api.base import api_bp
from app.config import FLASK_SECRET from app.config import FLASK_SECRET
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import User, ApiKey from app.models import User, ApiKey
@ -61,7 +61,7 @@ def auth_mfa():
if not api_key: if not api_key:
LOG.d("create new api key for %s and %s", user, device) LOG.d("create new api key for %s and %s", user, device)
api_key = ApiKey.create(user.id, device) api_key = ApiKey.create(user.id, device)
db.session.commit() Session.commit()
ret["api_key"] = api_key.code ret["api_key"] = api_key.code

View File

@ -2,7 +2,7 @@ from flask import g, request
from flask import jsonify from flask import jsonify
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.extensions import db from app.db import Session
from app.models import CustomDomain, DomainDeletedAlias, Mailbox, DomainMailbox from app.models import CustomDomain, DomainDeletedAlias, Mailbox, DomainMailbox
@ -108,8 +108,8 @@ def update_custom_domain(custom_domain_id):
mailboxes.append(mailbox) mailboxes.append(mailbox)
# first remove all existing domain-mailboxes links # first remove all existing domain-mailboxes links
DomainMailbox.query.filter_by(domain_id=custom_domain.id).delete() DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
db.session.flush() Session.flush()
for mailbox in mailboxes: for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id) DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
@ -117,6 +117,6 @@ def update_custom_domain(custom_domain_id):
changed = True changed = True
if changed: if changed:
db.session.commit() Session.commit()
return jsonify(ok=True), 200 return jsonify(ok=True), 200

View File

@ -7,12 +7,12 @@ from flask import request
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.dashboard.views.mailbox import send_verification_email from app.dashboard.views.mailbox import send_verification_email
from app.dashboard.views.mailbox_detail import verify_mailbox_change from app.dashboard.views.mailbox_detail import verify_mailbox_change
from app.db import Session
from app.email_utils import ( from app.email_utils import (
mailbox_already_used, mailbox_already_used,
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
is_valid_email, is_valid_email,
) )
from app.extensions import db
from app.models import Mailbox from app.models import Mailbox
from app.utils import sanitize_email from app.utils import sanitize_email
@ -58,7 +58,7 @@ def create_mailbox():
) )
else: else:
new_mailbox = Mailbox.create(email=mailbox_email, user_id=user.id) new_mailbox = Mailbox.create(email=mailbox_email, user_id=user.id)
db.session.commit() Session.commit()
send_verification_email(user, new_mailbox) send_verification_email(user, new_mailbox)
@ -89,7 +89,7 @@ def delete_mailbox(mailbox_id):
return jsonify(error="You cannot delete the default mailbox"), 400 return jsonify(error="You cannot delete the default mailbox"), 400
Mailbox.delete(mailbox_id) Mailbox.delete(mailbox_id)
db.session.commit() Session.commit()
return jsonify(deleted=True), 200 return jsonify(deleted=True), 200
@ -158,7 +158,7 @@ def update_mailbox(mailbox_id):
changed = True changed = True
if changed: if changed:
db.session.commit() Session.commit()
return jsonify(updated=True), 200 return jsonify(updated=True), 200
@ -190,7 +190,7 @@ def get_mailboxes_v2():
user = g.user user = g.user
mailboxes = [] mailboxes = []
for mailbox in Mailbox.query.filter_by(user_id=user.id): for mailbox in Mailbox.filter_by(user_id=user.id):
mailboxes.append(mailbox) mailboxes.append(mailbox)
return ( return (

View File

@ -10,7 +10,8 @@ from app.api.serializer import (
) )
from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT
from app.dashboard.views.custom_alias import verify_prefix_suffix, signer from app.dashboard.views.custom_alias import verify_prefix_suffix, signer
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
@ -108,11 +109,11 @@ def new_custom_alias_v2():
custom_domain_id=custom_domain_id, custom_domain_id=custom_domain_id,
) )
db.session.commit() Session.commit()
if hostname: if hostname:
AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id) AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id)
db.session.commit() Session.commit()
return ( return (
jsonify(alias=full_alias, **serialize_alias_info_v2(get_alias_info_v2(alias))), jsonify(alias=full_alias, **serialize_alias_info_v2(get_alias_info_v2(alias))),
@ -217,7 +218,7 @@ def new_custom_alias_v3():
mailbox_id=mailboxes[0].id, mailbox_id=mailboxes[0].id,
custom_domain_id=custom_domain_id, custom_domain_id=custom_domain_id,
) )
db.session.flush() Session.flush()
for i in range(1, len(mailboxes)): for i in range(1, len(mailboxes)):
AliasMailbox.create( AliasMailbox.create(
@ -225,11 +226,11 @@ def new_custom_alias_v3():
mailbox_id=mailboxes[i].id, mailbox_id=mailboxes[i].id,
) )
db.session.commit() Session.commit()
if hostname: if hostname:
AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id) AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id)
db.session.commit() Session.commit()
return ( return (
jsonify(alias=full_alias, **serialize_alias_info_v2(get_alias_info_v2(alias))), jsonify(alias=full_alias, **serialize_alias_info_v2(get_alias_info_v2(alias))),

View File

@ -7,7 +7,8 @@ from app.api.serializer import (
serialize_alias_info_v2, serialize_alias_info_v2,
) )
from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import Alias, AliasUsedOn, AliasGeneratorEnum from app.models import Alias, AliasUsedOn, AliasGeneratorEnum
@ -51,12 +52,12 @@ def new_random_alias():
return jsonify(error=f"{mode} must be either word or uuid"), 400 return jsonify(error=f"{mode} must be either word or uuid"), 400
alias = Alias.create_new_random(user=user, scheme=scheme, note=note) alias = Alias.create_new_random(user=user, scheme=scheme, note=note)
db.session.commit() Session.commit()
hostname = request.args.get("hostname") hostname = request.args.get("hostname")
if hostname: if hostname:
AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id) AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id)
db.session.commit() Session.commit()
return ( return (
jsonify(alias=alias.email, **serialize_alias_info_v2(get_alias_info_v2(alias))), jsonify(alias=alias.email, **serialize_alias_info_v2(get_alias_info_v2(alias))),

View File

@ -4,7 +4,7 @@ from flask import request
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import PAGE_LIMIT from app.config import PAGE_LIMIT
from app.extensions import db from app.db import Session
from app.models import Notification from app.models import Notification
@ -32,7 +32,7 @@ def get_notifications():
return jsonify(error="page must be provided in request query"), 400 return jsonify(error="page must be provided in request query"), 400
notifications = ( notifications = (
Notification.query.filter_by(user_id=user.id) Notification.filter_by(user_id=user.id)
.order_by(Notification.read, Notification.created_at.desc()) .order_by(Notification.read, Notification.created_at.desc())
.limit(PAGE_LIMIT + 1) # load a record more to know whether there's more .limit(PAGE_LIMIT + 1) # load a record more to know whether there's more
.offset(page * PAGE_LIMIT) .offset(page * PAGE_LIMIT)
@ -76,6 +76,6 @@ def mark_as_read(notification_id):
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
notification.read = True notification.read = True
db.session.commit() Session.commit()
return jsonify(done=True), 200 return jsonify(done=True), 200

View File

@ -2,7 +2,7 @@ import arrow
from flask import jsonify, g, request from flask import jsonify, g, request
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
User, User,
@ -93,7 +93,7 @@ def update_setting():
user.default_alias_custom_domain_id = custom_domain.id user.default_alias_custom_domain_id = custom_domain.id
user.default_alias_public_domain_id = None user.default_alias_public_domain_id = None
db.session.commit() Session.commit()
return jsonify(setting_to_dict(user)) return jsonify(setting_to_dict(user))

View File

@ -7,7 +7,7 @@ from flask_login import logout_user
from app import s3 from app import s3
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import SESSION_COOKIE_NAME from app.config import SESSION_COOKIE_NAME
from app.extensions import db from app.db import Session
from app.models import ApiKey, File, User from app.models import ApiKey, File, User
from app.utils import random_string from app.utils import random_string
@ -56,24 +56,24 @@ def update_user_info():
if user.profile_picture_id: if user.profile_picture_id:
file = user.profile_picture file = user.profile_picture
user.profile_picture_id = None user.profile_picture_id = None
db.session.flush() Session.flush()
if file: if file:
File.delete(file.id) File.delete(file.id)
s3.delete(file.path) s3.delete(file.path)
db.session.flush() Session.flush()
else: else:
raw_data = base64.decodebytes(data["profile_picture"].encode()) raw_data = base64.decodebytes(data["profile_picture"].encode())
file_path = random_string(30) file_path = random_string(30)
file = File.create(user_id=user.id, path=file_path) file = File.create(user_id=user.id, path=file_path)
db.session.flush() Session.flush()
s3.upload_from_bytesio(file_path, BytesIO(raw_data)) s3.upload_from_bytesio(file_path, BytesIO(raw_data))
user.profile_picture_id = file.id user.profile_picture_id = file.id
db.session.flush() Session.flush()
if "name" in data: if "name" in data:
user.name = data["name"] user.name = data["name"]
db.session.commit() Session.commit()
return jsonify(user_to_dict(user)) return jsonify(user_to_dict(user))
@ -95,7 +95,7 @@ def create_api_key():
device = data.get("device") device = data.get("device")
api_key = ApiKey.create(user_id=g.user.id, name=device) api_key = ApiKey.create(user_id=g.user.id, name=device)
db.session.commit() Session.commit()
return jsonify(api_key=api_key.code), 201 return jsonify(api_key=api_key.code), 201

View File

@ -3,7 +3,8 @@ from flask_login import login_user, current_user
from app import email_utils from app import email_utils
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ActivationCode from app.models import ActivationCode
@ -50,7 +51,7 @@ def activate():
# activation code is to be used only once # activation code is to be used only once
ActivationCode.delete(activation_code.id) ActivationCode.delete(activation_code.id)
db.session.commit() Session.commit()
flash("Your account has been activated", "success") flash("Your account has been activated", "success")

View File

@ -2,7 +2,7 @@ from flask import request, flash, render_template, redirect, url_for
from flask_login import login_user from flask_login import login_user
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.extensions import db from app.db import Session
from app.models import EmailChange from app.models import EmailChange
@ -18,14 +18,14 @@ def change_email():
if email_change.is_expired(): if email_change.is_expired():
# delete the expired email # delete the expired email
EmailChange.delete(email_change.id) EmailChange.delete(email_change.id)
db.session.commit() Session.commit()
return render_template("auth/change_email.html") return render_template("auth/change_email.html")
user = email_change.user user = email_change.user
user.email = email_change.new_email user.email = email_change.new_email
EmailChange.delete(email_change.id) EmailChange.delete(email_change.id)
db.session.commit() Session.commit()
flash("Your new email has been updated", "success") flash("Your new email has been updated", "success")

View File

@ -9,7 +9,7 @@ from app.config import (
FACEBOOK_CLIENT_ID, FACEBOOK_CLIENT_ID,
FACEBOOK_CLIENT_SECRET, FACEBOOK_CLIENT_SECRET,
) )
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import User, SocialAuth from app.models import User, SocialAuth
from .login_utils import after_login from .login_utils import after_login
@ -102,7 +102,7 @@ def facebook_callback():
LOG.d("set user profile picture to %s", picture_url) LOG.d("set user profile picture to %s", picture_url)
file = create_file_from_url(user, picture_url) file = create_file_from_url(user, picture_url)
user.profile_picture_id = file.id user.profile_picture_id = file.id
db.session.commit() Session.commit()
else: else:
flash( flash(
@ -122,6 +122,6 @@ def facebook_callback():
if not SocialAuth.get_by(user_id=user.id, social="facebook"): if not SocialAuth.get_by(user_id=user.id, social="facebook"):
SocialAuth.create(user_id=user.id, social="facebook") SocialAuth.create(user_id=user.id, social="facebook")
db.session.commit() Session.commit()
return after_login(user, next_url) return after_login(user, next_url)

View File

@ -19,7 +19,8 @@ from wtforms import HiddenField, validators, BooleanField
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import MFA_USER_ID from app.config import MFA_USER_ID
from app.config import RP_ID, URL from app.config import RP_ID, URL
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User, Fido, MfaBrowser from app.models import User, Fido, MfaBrowser
@ -102,7 +103,7 @@ def fido():
auto_activate = False auto_activate = False
else: else:
user.fido_sign_count = new_sign_count user.fido_sign_count = new_sign_count
db.session.commit() Session.commit()
del session[MFA_USER_ID] del session[MFA_USER_ID]
login_user(user) login_user(user)
@ -113,7 +114,7 @@ def fido():
if fido_token_form.remember.data: if fido_token_form.remember.data:
browser = MfaBrowser.create_new(user=user) browser = MfaBrowser.create_new(user=user)
db.session.commit() Session.commit()
response.set_cookie( response.set_cookie(
"mfa", "mfa",
value=browser.token, value=browser.token,

View File

@ -4,7 +4,7 @@ from requests_oauthlib import OAuth2Session
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login from app.auth.views.login_utils import after_login
from app.config import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET, URL from app.config import GITHUB_CLIENT_ID, GITHUB_CLIENT_SECRET, URL
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import User, SocialAuth from app.models import User, SocialAuth
from app.utils import encode_url, sanitize_email from app.utils import encode_url, sanitize_email
@ -94,7 +94,7 @@ def github_callback():
if not SocialAuth.get_by(user_id=user.id, social="github"): if not SocialAuth.get_by(user_id=user.id, social="github"):
SocialAuth.create(user_id=user.id, social="github") SocialAuth.create(user_id=user.id, social="github")
db.session.commit() Session.commit()
# The activation link contains the original page, for ex authorize page # The activation link contains the original page, for ex authorize page
next_url = request.args.get("next") if request.args else None next_url = request.args.get("next") if request.args else None

View File

@ -4,7 +4,7 @@ from requests_oauthlib import OAuth2Session
from app import s3 from app import s3
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import URL, GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET from app.config import URL, GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import User, File, SocialAuth from app.models import User, File, SocialAuth
from app.utils import random_string, sanitize_email from app.utils import random_string, sanitize_email
@ -89,7 +89,7 @@ def google_callback():
LOG.d("set user profile picture to %s", picture_url) LOG.d("set user profile picture to %s", picture_url)
file = create_file_from_url(user, picture_url) file = create_file_from_url(user, picture_url)
user.profile_picture_id = file.id user.profile_picture_id = file.id
db.session.commit() Session.commit()
else: else:
flash( flash(
"Sorry you cannot sign up via Google, please use email/password sign-up instead", "Sorry you cannot sign up via Google, please use email/password sign-up instead",
@ -108,7 +108,7 @@ def google_callback():
if not SocialAuth.get_by(user_id=user.id, social="google"): if not SocialAuth.get_by(user_id=user.id, social="google"):
SocialAuth.create(user_id=user.id, social="google") SocialAuth.create(user_id=user.id, social="google")
db.session.commit() Session.commit()
return after_login(user, next_url) return after_login(user, next_url)
@ -119,7 +119,7 @@ def create_file_from_url(user, url) -> File:
s3.upload_from_url(url, file_path) s3.upload_from_url(url, file_path)
db.session.flush() Session.flush()
LOG.d("upload file %s to s3", file) LOG.d("upload file %s to s3", file)
return file return file

View File

@ -15,7 +15,8 @@ from wtforms import BooleanField, StringField, validators
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import MFA_USER_ID, URL from app.config import MFA_USER_ID, URL
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.models import User, MfaBrowser from app.models import User, MfaBrowser
@ -67,7 +68,7 @@ def mfa():
if totp.verify(token) and user.last_otp != token: if totp.verify(token) and user.last_otp != token:
del session[MFA_USER_ID] del session[MFA_USER_ID]
user.last_otp = token user.last_otp = token
db.session.commit() Session.commit()
login_user(user) login_user(user)
flash(f"Welcome back!", "success") flash(f"Welcome back!", "success")
@ -77,7 +78,7 @@ def mfa():
if otp_token_form.remember.data: if otp_token_form.remember.data:
browser = MfaBrowser.create_new(user=user) browser = MfaBrowser.create_new(user=user)
db.session.commit() Session.commit()
response.set_cookie( response.set_cookie(
"mfa", "mfa",
value=browser.token, value=browser.token,

View File

@ -6,7 +6,8 @@ from wtforms import StringField, validators
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import MFA_USER_ID from app.config import MFA_USER_ID
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User, RecoveryCode from app.models import User, RecoveryCode
@ -54,7 +55,7 @@ def recovery_route():
recovery_code.used = True recovery_code.used = True
recovery_code.used_at = arrow.now() recovery_code.used_at = arrow.now()
db.session.commit() Session.commit()
# User comes to login page from another page # User comes to login page from another page
if next_url: if next_url:

View File

@ -8,11 +8,11 @@ from app import email_utils, config
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.login_utils import get_referral from app.auth.views.login_utils import get_referral
from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY
from app.db import Session
from app.email_utils import ( from app.email_utils import (
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
personal_email_already_used, personal_email_already_used,
) )
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import User, ActivationCode from app.models import User, ActivationCode
from app.utils import random_string, encode_url, sanitize_email from app.utils import random_string, encode_url, sanitize_email
@ -81,7 +81,7 @@ def register():
password=form.password.data, password=form.password.data,
referral=get_referral(), referral=get_referral(),
) )
db.session.commit() Session.commit()
try: try:
send_activation_email(user, next_url) send_activation_email(user, next_url)
@ -102,7 +102,7 @@ def register():
def send_activation_email(user, next_url): def send_activation_email(user, next_url):
# the activation code is valid for 1h # the activation code is valid for 1h
activation = ActivationCode.create(user_id=user.id, code=random_string(30)) activation = ActivationCode.create(user_id=user.id, code=random_string(30))
db.session.commit() Session.commit()
# Send user activation email # Send user activation email
activation_link = f"{URL}/auth/activate?code={activation.code}" activation_link = f"{URL}/auth/activate?code={activation.code}"

View File

@ -6,7 +6,8 @@ from wtforms import StringField, validators
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login from app.auth.views.login_utils import after_login
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.models import ResetPasswordCode from app.models import ResetPasswordCode
@ -64,7 +65,7 @@ def reset_password():
# change the alternative_id to log user out on other browsers # change the alternative_id to log user out on other browsers
user.alternative_id = str(uuid.uuid4()) user.alternative_id = str(uuid.uuid4())
db.session.commit() Session.commit()
# do not use login_user(user) here # do not use login_user(user) here
# to make sure user needs to go through MFA if enabled # to make sure user needs to go through MFA if enabled

View File

@ -10,12 +10,12 @@ from wtforms import StringField, validators, ValidationError
from app.config import PAGE_LIMIT from app.config import PAGE_LIMIT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import ( from app.email_utils import (
is_valid_email, is_valid_email,
generate_reply_email, generate_reply_email,
parse_full_address, parse_full_address,
) )
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, EmailLog from app.models import Alias, Contact, EmailLog
@ -64,7 +64,7 @@ def get_contact_infos(
) -> [ContactInfo]: ) -> [ContactInfo]:
"""if contact_id is set, only return the contact info for this contact""" """if contact_id is set, only return the contact info for this contact"""
sub = ( sub = (
db.session.query( Session.query(
Contact.id, Contact.id,
func.sum(case([(EmailLog.is_reply, 1)], else_=0)).label("nb_reply"), func.sum(case([(EmailLog.is_reply, 1)], else_=0)).label("nb_reply"),
func.sum( func.sum(
@ -94,7 +94,7 @@ def get_contact_infos(
) )
q = ( q = (
db.session.query( Session.query(
Contact, Contact,
EmailLog, EmailLog,
sub.c.nb_reply, sub.c.nb_reply,
@ -221,7 +221,7 @@ def alias_contact_manager(alias_id):
) )
LOG.d("create reverse-alias for %s", contact_addr) LOG.d("create reverse-alias for %s", contact_addr)
db.session.commit() Session.commit()
flash(f"Reverse alias for {contact_addr} is created", "success") flash(f"Reverse alias for {contact_addr} is created", "success")
return redirect( return redirect(
@ -248,7 +248,7 @@ def alias_contact_manager(alias_id):
delete_contact_email = contact.website_email delete_contact_email = contact.website_email
Contact.delete(contact_id) Contact.delete(contact_id)
db.session.commit() Session.commit()
flash( flash(
f"Reverse-alias for {delete_contact_email} has been deleted", "success" f"Reverse-alias for {delete_contact_email} has been deleted", "success"

View File

@ -4,7 +4,7 @@ from flask_login import login_required, current_user
from app.config import PAGE_LIMIT from app.config import PAGE_LIMIT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db from app.db import Session
from app.models import Alias, EmailLog, Contact from app.models import Alias, EmailLog, Contact
@ -43,7 +43,7 @@ def alias_log(alias_id, page_id):
logs = get_alias_log(alias, page_id) logs = get_alias_log(alias, page_id)
base = ( base = (
db.session.query(Contact, EmailLog) Session.query(Contact, EmailLog)
.filter(Contact.id == EmailLog.contact_id) .filter(Contact.id == EmailLog.contact_id)
.filter(Contact.alias_id == alias.id) .filter(Contact.alias_id == alias.id)
) )
@ -66,7 +66,7 @@ def get_alias_log(alias: Alias, page_id=0) -> [AliasLog]:
logs: [AliasLog] = [] logs: [AliasLog] = []
q = ( q = (
db.session.query(Contact, EmailLog) Session.query(Contact, EmailLog)
.filter(Contact.id == EmailLog.contact_id) .filter(Contact.id == EmailLog.contact_id)
.filter(Contact.alias_id == alias.id) .filter(Contact.alias_id == alias.id)
.order_by(EmailLog.id.desc()) .order_by(EmailLog.id.desc())

View File

@ -5,8 +5,9 @@ from flask_login import login_required, current_user
from app.config import URL from app.config import URL
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import send_email, render from app.email_utils import send_email, render
from app.extensions import db, limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
@ -25,20 +26,20 @@ def transfer(alias, new_user, new_mailboxes: [Mailbox]):
raise Exception("Cannot transfer alias that's used to receive newsletter") raise Exception("Cannot transfer alias that's used to receive newsletter")
# update user_id # update user_id
db.session.query(Contact).filter(Contact.alias_id == alias.id).update( Session.query(Contact).filter(Contact.alias_id == alias.id).update(
{"user_id": new_user.id} {"user_id": new_user.id}
) )
db.session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update( Session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update(
{"user_id": new_user.id} {"user_id": new_user.id}
) )
db.session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update( Session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update(
{"user_id": new_user.id} {"user_id": new_user.id}
) )
# remove existing mailboxes from the alias # remove existing mailboxes from the alias
db.session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete() Session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete()
# set mailboxes # set mailboxes
alias.mailbox_id = new_mailboxes.pop().id alias.mailbox_id = new_mailboxes.pop().id
@ -71,7 +72,7 @@ def transfer(alias, new_user, new_mailboxes: [Mailbox]):
alias.disable_pgp = False alias.disable_pgp = False
alias.pinned = False alias.pinned = False
db.session.commit() Session.commit()
@dashboard_bp.route("/alias_transfer/send/<int:alias_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/alias_transfer/send/<int:alias_id>/", methods=["GET", "POST"])
@ -100,7 +101,7 @@ def alias_transfer_send_route(alias_id):
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
alias.transfer_token = str(uuid4()) alias.transfer_token = str(uuid4())
db.session.commit() Session.commit()
alias_transfer_url = ( alias_transfer_url = (
URL URL
+ "/dashboard/alias_transfer/receive" + "/dashboard/alias_transfer/receive"
@ -111,7 +112,7 @@ def alias_transfer_send_route(alias_id):
# request.form.get("form-name") == "remove" # request.form.get("form-name") == "remove"
else: else:
alias.transfer_token = None alias.transfer_token = None
db.session.commit() Session.commit()
alias_transfer_url = None alias_transfer_url = None
flash("Share URL deleted", "success") flash("Share URL deleted", "success")
return redirect(request.url) return redirect(request.url)

View File

@ -4,7 +4,7 @@ from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db from app.db import Session
from app.models import ApiKey from app.models import ApiKey
@ -16,7 +16,7 @@ class NewApiKeyForm(FlaskForm):
@login_required @login_required
def api_key(): def api_key():
api_keys = ( api_keys = (
ApiKey.query.filter(ApiKey.user_id == current_user.id) ApiKey.filter(ApiKey.user_id == current_user.id)
.order_by(ApiKey.created_at.desc()) .order_by(ApiKey.created_at.desc())
.all() .all()
) )
@ -38,7 +38,7 @@ def api_key():
name = api_key.name name = api_key.name
ApiKey.delete(api_key_id) ApiKey.delete(api_key_id)
db.session.commit() Session.commit()
flash(f"API Key {name} has been deleted", "success") flash(f"API Key {name} has been deleted", "success")
return redirect(url_for("dashboard.api_key")) return redirect(url_for("dashboard.api_key"))
@ -48,7 +48,7 @@ def api_key():
new_api_key = ApiKey.create( new_api_key = ApiKey.create(
name=new_api_key_form.name.data, user_id=current_user.id name=new_api_key_form.name.data, user_id=current_user.id
) )
db.session.commit() Session.commit()
flash(f"New API Key {new_api_key.name} has been created", "success") flash(f"New API Key {new_api_key.name} has been created", "success")
return redirect(url_for("dashboard.api_key")) return redirect(url_for("dashboard.api_key"))

View File

@ -1,3 +1,5 @@
from app.db import Session
""" """
List of apps that user has used via the "Sign in with SimpleLogin" List of apps that user has used via the "Sign in with SimpleLogin"
""" """
@ -7,7 +9,6 @@ from flask_login import login_required, current_user
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db
from app.models import ( from app.models import (
ClientUser, ClientUser,
) )
@ -36,7 +37,7 @@ def app_route():
client = client_user.client client = client_user.client
ClientUser.delete(client_user_id) ClientUser.delete(client_user_id)
db.session.commit() Session.commit()
flash(f"Link with {client.name} has been removed", "success") flash(f"Link with {client.name} has been removed", "success")
return redirect(request.url) return redirect(request.url)

View File

@ -5,7 +5,7 @@ from flask_login import login_required, current_user
from app import s3 from app import s3
from app.config import JOB_BATCH_IMPORT from app.config import JOB_BATCH_IMPORT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import File, BatchImport, Job from app.models import File, BatchImport, Job
from app.utils import random_string from app.utils import random_string
@ -18,7 +18,7 @@ def batch_import_route():
if not current_user.verified_custom_domains(): if not current_user.verified_custom_domains():
flash("Alias batch import is only available for custom domains", "warning") flash("Alias batch import is only available for custom domains", "warning")
batch_imports = BatchImport.query.filter_by(user_id=current_user.id).all() batch_imports = BatchImport.filter_by(user_id=current_user.id).all()
if request.method == "POST": if request.method == "POST":
alias_file = request.files["alias-file"] alias_file = request.files["alias-file"]
@ -26,11 +26,11 @@ def batch_import_route():
file_path = random_string(20) + ".csv" file_path = random_string(20) + ".csv"
file = File.create(user_id=current_user.id, path=file_path) file = File.create(user_id=current_user.id, path=file_path)
s3.upload_from_bytesio(file_path, alias_file) s3.upload_from_bytesio(file_path, alias_file)
db.session.flush() Session.flush()
LOG.d("upload file %s to s3 at %s", file, file_path) LOG.d("upload file %s to s3 at %s", file, file_path)
bi = BatchImport.create(user_id=current_user.id, file_id=file.id) bi = BatchImport.create(user_id=current_user.id, file_id=file.id)
db.session.flush() Session.flush()
LOG.d("Add a batch import job %s for %s", bi, current_user) LOG.d("Add a batch import job %s for %s", bi, current_user)
# Schedule batch import job # Schedule batch import job
@ -39,7 +39,7 @@ def batch_import_route():
payload={"batch_import_id": bi.id}, payload={"batch_import_id": bi.id},
run_at=arrow.now(), run_at=arrow.now(),
) )
db.session.commit() Session.commit()
flash( flash(
"The file has been uploaded successfully and the import will start shortly", "The file has been uploaded successfully and the import will start shortly",

View File

@ -3,7 +3,7 @@ from flask_login import login_required, current_user
from app.config import PADDLE_MONTHLY_PRODUCT_ID, PADDLE_YEARLY_PRODUCT_ID from app.config import PADDLE_MONTHLY_PRODUCT_ID, PADDLE_YEARLY_PRODUCT_ID
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import Subscription, PlanEnum from app.models import Subscription, PlanEnum
from app.paddle_utils import cancel_subscription, change_plan from app.paddle_utils import cancel_subscription, change_plan
@ -26,7 +26,7 @@ def billing():
if success: if success:
sub.cancelled = True sub.cancelled = True
db.session.commit() Session.commit()
flash("Your subscription has been canceled successfully", "success") flash("Your subscription has been canceled successfully", "success")
else: else:
flash( flash(
@ -44,7 +44,7 @@ def billing():
if success: if success:
sub.plan = PlanEnum.monthly sub.plan = PlanEnum.monthly
db.session.commit() Session.commit()
flash("Your subscription has been updated", "success") flash("Your subscription has been updated", "success")
else: else:
if msg: if msg:
@ -65,7 +65,7 @@ def billing():
if success: if success:
sub.plan = PlanEnum.yearly sub.plan = PlanEnum.yearly
db.session.commit() Session.commit()
flash("Your subscription has been updated", "success") flash("Your subscription has been updated", "success")
else: else:
if msg: if msg:

View File

@ -2,7 +2,7 @@ from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db from app.db import Session
from app.models import Contact from app.models import Contact
from app.pgp_utils import PGPException, load_public_key_and_check from app.pgp_utils import PGPException, load_public_key_and_check
@ -34,7 +34,7 @@ def contact_detail_route(contact_id):
except PGPException: except PGPException:
flash("Cannot add the public key, please verify it", "error") flash("Cannot add the public key, please verify it", "error")
else: else:
db.session.commit() Session.commit()
flash( flash(
f"PGP public key for {contact.email} is saved successfully", f"PGP public key for {contact.email} is saved successfully",
"success", "success",
@ -46,7 +46,7 @@ def contact_detail_route(contact_id):
# Free user can decide to remove contact PGP key # Free user can decide to remove contact PGP key
contact.pgp_public_key = None contact.pgp_public_key = None
contact.pgp_finger_print = None contact.pgp_finger_print = None
db.session.commit() Session.commit()
flash(f"PGP public key for {contact.email} is removed", "success") flash(f"PGP public key for {contact.email} is removed", "success")
return redirect( return redirect(
url_for("dashboard.contact_detail_route", contact_id=contact_id) url_for("dashboard.contact_detail_route", contact_id=contact_id)

View File

@ -6,8 +6,8 @@ from wtforms import StringField, validators
from app.config import ADMIN_EMAIL from app.config import ADMIN_EMAIL
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import send_email from app.email_utils import send_email
from app.extensions import db
from app.models import ( from app.models import (
ManualSubscription, ManualSubscription,
Coupon, Coupon,
@ -57,7 +57,7 @@ def coupon_route():
if coupon and not coupon.used: if coupon and not coupon.used:
coupon.used_by_user_id = current_user.id coupon.used_by_user_id = current_user.id
coupon.used = True coupon.used = True
db.session.commit() Session.commit()
manual_sub: ManualSubscription = ManualSubscription.get_by( manual_sub: ManualSubscription = ManualSubscription.get_by(
user_id=current_user.id user_id=current_user.id
@ -68,7 +68,7 @@ def coupon_route():
manual_sub.end_at = manual_sub.end_at.shift(years=coupon.nb_year) manual_sub.end_at = manual_sub.end_at.shift(years=coupon.nb_year)
else: else:
manual_sub.end_at = arrow.now().shift(years=coupon.nb_year, days=1) manual_sub.end_at = arrow.now().shift(years=coupon.nb_year, days=1)
db.session.commit() Session.commit()
flash( flash(
f"Your current subscription is extended to {manual_sub.end_at.humanize()}", f"Your current subscription is extended to {manual_sub.end_at.humanize()}",
"success", "success",

View File

@ -13,7 +13,8 @@ from app.config import (
ALIAS_LIMIT, ALIAS_LIMIT,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
@ -307,10 +308,10 @@ def custom_alias():
mailbox_id=mailboxes[0].id, mailbox_id=mailboxes[0].id,
custom_domain_id=custom_domain_id, custom_domain_id=custom_domain_id,
) )
db.session.flush() Session.flush()
except IntegrityError: except IntegrityError:
LOG.w("Alias %s already exists", full_alias) LOG.w("Alias %s already exists", full_alias)
db.session.rollback() Session.rollback()
flash("Unknown error, please retry", "error") flash("Unknown error, please retry", "error")
return redirect(url_for("dashboard.custom_alias")) return redirect(url_for("dashboard.custom_alias"))
@ -320,7 +321,7 @@ def custom_alias():
mailbox_id=mailboxes[i].id, mailbox_id=mailboxes[i].id,
) )
db.session.commit() Session.commit()
flash(f"Alias {full_alias} has been created", "success") flash(f"Alias {full_alias} has been created", "success")
return redirect(url_for("dashboard.index", highlight_alias_id=alias.id)) return redirect(url_for("dashboard.index", highlight_alias_id=alias.id))

View File

@ -5,8 +5,8 @@ from wtforms import StringField, validators
from app.config import EMAIL_SERVERS_WITH_PRIORITY from app.config import EMAIL_SERVERS_WITH_PRIORITY
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import get_email_domain_part from app.email_utils import get_email_domain_part
from app.extensions import db
from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain
@ -19,7 +19,7 @@ class NewCustomDomainForm(FlaskForm):
@dashboard_bp.route("/custom_domain", methods=["GET", "POST"]) @dashboard_bp.route("/custom_domain", methods=["GET", "POST"])
@login_required @login_required
def custom_domain(): def custom_domain():
custom_domains = CustomDomain.query.filter_by(user_id=current_user.id).all() custom_domains = CustomDomain.filter_by(user_id=current_user.id).all()
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
new_custom_domain_form = NewCustomDomainForm() new_custom_domain_form = NewCustomDomainForm()
@ -54,7 +54,7 @@ def custom_domain():
new_custom_domain = CustomDomain.create( new_custom_domain = CustomDomain.create(
domain=new_domain, user_id=current_user.id domain=new_domain, user_id=current_user.id
) )
db.session.commit() Session.commit()
mailbox_ids = request.form.getlist("mailbox_ids") mailbox_ids = request.form.getlist("mailbox_ids")
if mailbox_ids: if mailbox_ids:
@ -76,7 +76,7 @@ def custom_domain():
domain_id=new_custom_domain.id, mailbox_id=mailbox.id domain_id=new_custom_domain.id, mailbox_id=mailbox.id
) )
db.session.commit() Session.commit()
flash( flash(
f"New domain {new_custom_domain.domain} is created", "success" f"New domain {new_custom_domain.domain} is created", "success"

View File

@ -10,7 +10,7 @@ from app.config import (
BOUNCE_PREFIX_FOR_REPLY_PHASE, BOUNCE_PREFIX_FOR_REPLY_PHASE,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db from app.db import Session
from app.models import Directory, Mailbox, DirectoryMailbox from app.models import Directory, Mailbox, DirectoryMailbox
@ -24,7 +24,7 @@ class NewDirForm(FlaskForm):
@login_required @login_required
def directory(): def directory():
dirs = ( dirs = (
Directory.query.filter_by(user_id=current_user.id) Directory.filter_by(user_id=current_user.id)
.order_by(Directory.created_at.desc()) .order_by(Directory.created_at.desc())
.all() .all()
) )
@ -47,7 +47,7 @@ def directory():
name = dir.name name = dir.name
Directory.delete(dir_id) Directory.delete(dir_id)
db.session.commit() Session.commit()
flash(f"Directory {name} has been deleted", "success") flash(f"Directory {name} has been deleted", "success")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
@ -67,7 +67,7 @@ def directory():
dir.disabled = True dir.disabled = True
flash(f"On-the-fly is disabled for {dir.name}", "warning") flash(f"On-the-fly is disabled for {dir.name}", "warning")
db.session.commit() Session.commit()
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
@ -98,13 +98,13 @@ def directory():
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
# first remove all existing directory-mailboxes links # first remove all existing directory-mailboxes links
DirectoryMailbox.query.filter_by(directory_id=dir.id).delete() DirectoryMailbox.filter_by(directory_id=dir.id).delete()
db.session.flush() Session.flush()
for mailbox in mailboxes: for mailbox in mailboxes:
DirectoryMailbox.create(directory_id=dir.id, mailbox_id=mailbox.id) DirectoryMailbox.create(directory_id=dir.id, mailbox_id=mailbox.id)
db.session.commit() Session.commit()
flash(f"Directory {dir.name} has been updated", "success") flash(f"Directory {dir.name} has been updated", "success")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
@ -141,7 +141,7 @@ def directory():
new_dir = Directory.create( new_dir = Directory.create(
name=new_dir_name, user_id=current_user.id name=new_dir_name, user_id=current_user.id
) )
db.session.commit() Session.commit()
mailbox_ids = request.form.getlist("mailbox_ids") mailbox_ids = request.form.getlist("mailbox_ids")
if mailbox_ids: if mailbox_ids:
# check if mailbox is not tempered with # check if mailbox is not tempered with
@ -162,7 +162,7 @@ def directory():
directory_id=new_dir.id, mailbox_id=mailbox.id directory_id=new_dir.id, mailbox_id=mailbox.id
) )
db.session.commit() Session.commit()
flash(f"Directory {new_dir.name} is created", "success") flash(f"Directory {new_dir.name} is created", "success")

View File

@ -1,6 +1,6 @@
import re2 as re
from threading import Thread from threading import Thread
import re2 as re
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
@ -8,6 +8,7 @@ from wtforms import StringField, validators, IntegerField
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.dns_utils import ( from app.dns_utils import (
get_mx_domains, get_mx_domains,
get_spf_domain, get_spf_domain,
@ -15,7 +16,6 @@ from app.dns_utils import (
get_cname_record, get_cname_record,
) )
from app.email_utils import send_email from app.email_utils import send_email
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
CustomDomain, CustomDomain,
@ -40,7 +40,7 @@ def domain_detail_dns(custom_domain_id):
# generate a domain ownership txt token if needed # generate a domain ownership txt token if needed
if not custom_domain.ownership_verified and not custom_domain.ownership_txt_token: if not custom_domain.ownership_verified and not custom_domain.ownership_txt_token:
custom_domain.ownership_txt_token = random_string(30) custom_domain.ownership_txt_token = random_string(30)
db.session.commit() Session.commit()
spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all" spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"
@ -62,7 +62,7 @@ def domain_detail_dns(custom_domain_id):
"success", "success",
) )
custom_domain.ownership_verified = True custom_domain.ownership_verified = True
db.session.commit() Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", "dashboard.domain_detail_dns",
@ -92,7 +92,7 @@ def domain_detail_dns(custom_domain_id):
"success", "success",
) )
custom_domain.verified = True custom_domain.verified = True
db.session.commit() Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
@ -102,7 +102,7 @@ def domain_detail_dns(custom_domain_id):
spf_domains = get_spf_domain(custom_domain.domain) spf_domains = get_spf_domain(custom_domain.domain)
if EMAIL_DOMAIN in spf_domains: if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True custom_domain.spf_verified = True
db.session.commit() Session.commit()
flash("SPF is setup correctly", "success") flash("SPF is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -111,7 +111,7 @@ def domain_detail_dns(custom_domain_id):
) )
else: else:
custom_domain.spf_verified = False custom_domain.spf_verified = False
db.session.commit() Session.commit()
flash( flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.", f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning", "warning",
@ -124,7 +124,7 @@ def domain_detail_dns(custom_domain_id):
if dkim_record == dkim_cname: if dkim_record == dkim_cname:
flash("DKIM is setup correctly.", "success") flash("DKIM is setup correctly.", "success")
custom_domain.dkim_verified = True custom_domain.dkim_verified = True
db.session.commit() Session.commit()
return redirect( return redirect(
url_for( url_for(
@ -133,7 +133,7 @@ def domain_detail_dns(custom_domain_id):
) )
else: else:
custom_domain.dkim_verified = False custom_domain.dkim_verified = False
db.session.commit() Session.commit()
flash("DKIM: the CNAME record is not correctly set", "warning") flash("DKIM: the CNAME record is not correctly set", "warning")
dkim_ok = False dkim_ok = False
dkim_errors = [dkim_record or "[Empty]"] dkim_errors = [dkim_record or "[Empty]"]
@ -142,7 +142,7 @@ def domain_detail_dns(custom_domain_id):
txt_records = get_txt_record("_dmarc." + custom_domain.domain) txt_records = get_txt_record("_dmarc." + custom_domain.domain)
if dmarc_record in txt_records: if dmarc_record in txt_records:
custom_domain.dmarc_verified = True custom_domain.dmarc_verified = True
db.session.commit() Session.commit()
flash("DMARC is setup correctly", "success") flash("DMARC is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -151,7 +151,7 @@ def domain_detail_dns(custom_domain_id):
) )
else: else:
custom_domain.dmarc_verified = False custom_domain.dmarc_verified = False
db.session.commit() Session.commit()
flash( flash(
"DMARC: The TXT record is not correctly set", "DMARC: The TXT record is not correctly set",
"warning", "warning",
@ -179,7 +179,7 @@ def domain_detail(custom_domain_id):
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "switch-catch-all": if request.form.get("form-name") == "switch-catch-all":
custom_domain.catch_all = not custom_domain.catch_all custom_domain.catch_all = not custom_domain.catch_all
db.session.commit() Session.commit()
if custom_domain.catch_all: if custom_domain.catch_all:
flash( flash(
@ -197,14 +197,14 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "set-name": elif request.form.get("form-name") == "set-name":
if request.form.get("action") == "save": if request.form.get("action") == "save":
custom_domain.name = request.form.get("alias-name").replace("\n", "") custom_domain.name = request.form.get("alias-name").replace("\n", "")
db.session.commit() Session.commit()
flash( flash(
f"Default alias name for Domain {custom_domain.domain} has been set", f"Default alias name for Domain {custom_domain.domain} has been set",
"success", "success",
) )
else: else:
custom_domain.name = None custom_domain.name = None
db.session.commit() Session.commit()
flash( flash(
f"Default alias name for Domain {custom_domain.domain} has been removed", f"Default alias name for Domain {custom_domain.domain} has been removed",
"info", "info",
@ -217,7 +217,7 @@ def domain_detail(custom_domain_id):
custom_domain.random_prefix_generation = ( custom_domain.random_prefix_generation = (
not custom_domain.random_prefix_generation not custom_domain.random_prefix_generation
) )
db.session.commit() Session.commit()
if custom_domain.random_prefix_generation: if custom_domain.random_prefix_generation:
flash( flash(
@ -260,13 +260,13 @@ def domain_detail(custom_domain_id):
) )
# first remove all existing domain-mailboxes links # first remove all existing domain-mailboxes links
DomainMailbox.query.filter_by(domain_id=custom_domain.id).delete() DomainMailbox.filter_by(domain_id=custom_domain.id).delete()
db.session.flush() Session.flush()
for mailbox in mailboxes: for mailbox in mailboxes:
DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id) DomainMailbox.create(domain_id=custom_domain.id, mailbox_id=mailbox.id)
db.session.commit() Session.commit()
flash(f"{custom_domain.domain} mailboxes has been updated", "success") flash(f"{custom_domain.domain} mailboxes has been updated", "success")
return redirect( return redirect(
@ -302,7 +302,7 @@ def delete_domain(custom_domain_id: int):
user = custom_domain.user user = custom_domain.user
CustomDomain.delete(custom_domain.id) CustomDomain.delete(custom_domain.id)
db.session.commit() Session.commit()
LOG.d("Domain %s deleted", domain_name) LOG.d("Domain %s deleted", domain_name)
@ -328,7 +328,7 @@ def domain_detail_trash(custom_domain_id):
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "empty-all": if request.form.get("form-name") == "empty-all":
DomainDeletedAlias.filter_by(domain_id=custom_domain.id).delete() DomainDeletedAlias.filter_by(domain_id=custom_domain.id).delete()
db.session.commit() Session.commit()
flash("All deleted aliases can now be re-created", "success") flash("All deleted aliases can now be re-created", "success")
return redirect( return redirect(
@ -349,7 +349,7 @@ def domain_detail_trash(custom_domain_id):
) )
DomainDeletedAlias.delete(deleted_alias.id) DomainDeletedAlias.delete(deleted_alias.id)
db.session.commit() Session.commit()
flash( flash(
f"{deleted_alias.email} can now be re-created", f"{deleted_alias.email} can now be re-created",
"success", "success",
@ -477,7 +477,7 @@ def domain_detail_auto_create(custom_domain_id):
auto_create_rule_id=rule.id, mailbox_id=mailbox.id auto_create_rule_id=rule.id, mailbox_id=mailbox.id
) )
db.session.commit() Session.commit()
flash("New auto create rule has been created", "success") flash("New auto create rule has been created", "success")
@ -502,7 +502,7 @@ def domain_detail_auto_create(custom_domain_id):
rule_order = rule.order rule_order = rule.order
AutoCreateRule.delete(rule_id) AutoCreateRule.delete(rule_id)
db.session.commit() Session.commit()
flash(f"Rule #{rule_order} has been deleted", "success") flash(f"Rule #{rule_order} has been deleted", "success")
return redirect( return redirect(
url_for( url_for(

View File

@ -5,7 +5,7 @@ from wtforms import HiddenField, validators
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import RecoveryCode, Fido from app.models import RecoveryCode, Fido
@ -34,7 +34,7 @@ def fido_manage():
return redirect(url_for("dashboard.fido_manage")) return redirect(url_for("dashboard.fido_manage"))
Fido.delete(fido_key.id) Fido.delete(fido_key.id)
db.session.commit() Session.commit()
LOG.d(f"FIDO Key ID={fido_key.id} Removed") LOG.d(f"FIDO Key ID={fido_key.id} Removed")
flash(f"Key {fido_key.name} successfully unlinked", "success") flash(f"Key {fido_key.name} successfully unlinked", "success")
@ -42,7 +42,7 @@ def fido_manage():
# Disable FIDO for the user if all keys have been deleted # Disable FIDO for the user if all keys have been deleted
if not Fido.filter_by(uuid=current_user.fido_uuid).all(): if not Fido.filter_by(uuid=current_user.fido_uuid).all():
current_user.fido_uuid = None current_user.fido_uuid = None
db.session.commit() Session.commit()
# user does not have any 2FA enabled left, delete all recovery codes # user does not have any 2FA enabled left, delete all recovery codes
if not current_user.two_factor_authentication_enabled(): if not current_user.two_factor_authentication_enabled():

View File

@ -11,7 +11,7 @@ from wtforms import StringField, HiddenField, validators
from app.config import RP_ID, URL from app.config import RP_ID, URL
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import Fido, RecoveryCode from app.models import Fido, RecoveryCode
@ -61,7 +61,7 @@ def fido_setup():
if current_user.fido_uuid is None: if current_user.fido_uuid is None:
current_user.fido_uuid = fido_uuid current_user.fido_uuid = fido_uuid
db.session.flush() Session.flush()
Fido.create( Fido.create(
credential_id=str(fido_credential.credential_id, "utf-8"), credential_id=str(fido_credential.credential_id, "utf-8"),
@ -70,14 +70,14 @@ def fido_setup():
sign_count=fido_credential.sign_count, sign_count=fido_credential.sign_count,
name=fido_token_form.key_name.data, name=fido_token_form.key_name.data,
) )
db.session.commit() Session.commit()
LOG.d( LOG.d(
f"credential_id={str(fido_credential.credential_id, 'utf-8')} added for {fido_uuid}" f"credential_id={str(fido_credential.credential_id, 'utf-8')} added for {fido_uuid}"
) )
flash("Security key has been activated", "success") flash("Security key has been activated", "success")
if not RecoveryCode.query.filter_by(user_id=current_user.id).all(): if not RecoveryCode.filter_by(user_id=current_user.id).all():
return redirect(url_for("dashboard.recovery_code_route")) return redirect(url_for("dashboard.recovery_code_route"))
else: else:
return redirect(url_for("dashboard.fido_manage")) return redirect(url_for("dashboard.fido_manage"))

View File

@ -7,7 +7,8 @@ from app import alias_utils
from app.api.serializer import get_alias_infos_with_pagination_v3, get_alias_info_v3 from app.api.serializer import get_alias_infos_with_pagination_v3, get_alias_info_v3
from app.config import PAGE_LIMIT, ALIAS_LIMIT from app.config import PAGE_LIMIT, ALIAS_LIMIT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db, limiter from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
@ -26,19 +27,19 @@ class Stats:
def get_stats(user: User) -> Stats: def get_stats(user: User) -> Stats:
nb_alias = Alias.query.filter_by(user_id=user.id).count() nb_alias = Alias.filter_by(user_id=user.id).count()
nb_forward = ( nb_forward = (
db.session.query(EmailLog) Session.query(EmailLog)
.filter_by(user_id=user.id, is_reply=False, blocked=False, bounced=False) .filter_by(user_id=user.id, is_reply=False, blocked=False, bounced=False)
.count() .count()
) )
nb_reply = ( nb_reply = (
db.session.query(EmailLog) Session.query(EmailLog)
.filter_by(user_id=user.id, is_reply=True, blocked=False, bounced=False) .filter_by(user_id=user.id, is_reply=True, blocked=False, bounced=False)
.count() .count()
) )
nb_block = ( nb_block = (
db.session.query(EmailLog) Session.query(EmailLog)
.filter_by(user_id=user.id, is_reply=False, blocked=True, bounced=False) .filter_by(user_id=user.id, is_reply=False, blocked=True, bounced=False)
.count() .count()
) )
@ -92,7 +93,7 @@ def index():
alias.mailbox_id = current_user.default_mailbox_id alias.mailbox_id = current_user.default_mailbox_id
db.session.commit() Session.commit()
LOG.d("create new random alias %s for user %s", alias, current_user) LOG.d("create new random alias %s for user %s", alias, current_user)
flash(f"Alias {alias.email} has been created", "success") flash(f"Alias {alias.email} has been created", "success")
@ -130,7 +131,7 @@ def index():
flash(f"Alias {email} has been deleted", "success") flash(f"Alias {email} has been deleted", "success")
elif request.form.get("form-name") == "disable-alias": elif request.form.get("form-name") == "disable-alias":
alias.enabled = False alias.enabled = False
db.session.commit() Session.commit()
flash(f"Alias {alias.email} has been disabled", "success") flash(f"Alias {alias.email} has been disabled", "success")
return redirect( return redirect(
@ -146,7 +147,7 @@ def index():
# to make sure not showing intro to user again # to make sure not showing intro to user again
current_user.intro_shown = True current_user.intro_shown = True
db.session.commit() Session.commit()
stats = get_stats(current_user) stats = get_stats(current_user)

View File

@ -5,8 +5,8 @@ from wtforms import StringField, validators
from app.config import ADMIN_EMAIL from app.config import ADMIN_EMAIL
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import send_email from app.email_utils import send_email
from app.extensions import db
from app.models import LifetimeCoupon from app.models import LifetimeCoupon
@ -40,7 +40,7 @@ def lifetime_licence():
current_user.lifetime_coupon_id = coupon.id current_user.lifetime_coupon_id = coupon.id
if coupon.paid: if coupon.paid:
current_user.paid_lifetime = True current_user.paid_lifetime = True
db.session.commit() Session.commit()
# notify admin # notify admin
send_email( send_email(

View File

@ -9,6 +9,7 @@ from wtforms.fields.html5 import EmailField
from app.config import MAILBOX_SECRET, URL from app.config import MAILBOX_SECRET, URL
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import ( from app.email_utils import (
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
mailbox_already_used, mailbox_already_used,
@ -16,7 +17,6 @@ from app.email_utils import (
send_email, send_email,
is_valid_email, is_valid_email,
) )
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import Mailbox from app.models import Mailbox
@ -31,7 +31,7 @@ class NewMailboxForm(FlaskForm):
@login_required @login_required
def mailbox_route(): def mailbox_route():
mailboxes = ( mailboxes = (
Mailbox.query.filter_by(user_id=current_user.id) Mailbox.filter_by(user_id=current_user.id)
.order_by(Mailbox.created_at.desc()) .order_by(Mailbox.created_at.desc())
.all() .all()
) )
@ -77,7 +77,7 @@ def mailbox_route():
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
current_user.default_mailbox_id = mailbox.id current_user.default_mailbox_id = mailbox.id
db.session.commit() Session.commit()
flash(f"Mailbox {mailbox.email} is set as Default Mailbox", "success") flash(f"Mailbox {mailbox.email} is set as Default Mailbox", "success")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
@ -102,7 +102,7 @@ def mailbox_route():
new_mailbox = Mailbox.create( new_mailbox = Mailbox.create(
email=mailbox_email, user_id=current_user.id email=mailbox_email, user_id=current_user.id
) )
db.session.commit() Session.commit()
send_verification_email(current_user, new_mailbox) send_verification_email(current_user, new_mailbox)
@ -136,7 +136,7 @@ def delete_mailbox(mailbox_id: int):
user = mailbox.user user = mailbox.user
Mailbox.delete(mailbox_id) Mailbox.delete(mailbox_id)
db.session.commit() Session.commit()
LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email) LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email)
send_email( send_email(
@ -191,7 +191,7 @@ def mailbox_verify():
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
mailbox.verified = True mailbox.verified = True
db.session.commit() Session.commit()
LOG.d("Mailbox %s is verified", mailbox) LOG.d("Mailbox %s is verified", mailbox)

View File

@ -10,9 +10,9 @@ from wtforms.fields.html5 import EmailField
from app.config import ENFORCE_SPF, MAILBOX_SECRET from app.config import ENFORCE_SPF, MAILBOX_SECRET
from app.config import URL from app.config import URL
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import email_can_be_used_as_mailbox from app.email_utils import email_can_be_used_as_mailbox
from app.email_utils import mailbox_already_used, render, send_email from app.email_utils import mailbox_already_used, render, send_email
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import Alias, AuthorizedAddress from app.models import Alias, AuthorizedAddress
from app.models import Mailbox from app.models import Mailbox
@ -57,7 +57,7 @@ def mailbox_detail_route(mailbox_id):
flash("You cannot use this email address as your mailbox", "error") flash("You cannot use this email address as your mailbox", "error")
else: else:
mailbox.new_email = new_email mailbox.new_email = new_email
db.session.commit() Session.commit()
try: try:
verify_mailbox_change(current_user, mailbox, new_email) verify_mailbox_change(current_user, mailbox, new_email)
@ -82,7 +82,7 @@ def mailbox_detail_route(mailbox_id):
mailbox.force_spf = ( mailbox.force_spf = (
True if request.form.get("spf-status") == "on" else False True if request.form.get("spf-status") == "on" else False
) )
db.session.commit() Session.commit()
flash( flash(
"SPF enforcement was " + "enabled" "SPF enforcement was " + "enabled"
if request.form.get("spf-status") if request.form.get("spf-status")
@ -118,7 +118,7 @@ def mailbox_detail_route(mailbox_id):
else: else:
address = authorized_address.email address = authorized_address.email
AuthorizedAddress.delete(authorized_address_id) AuthorizedAddress.delete(authorized_address_id)
db.session.commit() Session.commit()
flash(f"{address} has been deleted", "success") flash(f"{address} has been deleted", "success")
return redirect( return redirect(
@ -140,7 +140,7 @@ def mailbox_detail_route(mailbox_id):
except PGPException: except PGPException:
flash("Cannot add the public key, please verify it", "error") flash("Cannot add the public key, please verify it", "error")
else: else:
db.session.commit() Session.commit()
flash("Your PGP public key is saved successfully", "success") flash("Your PGP public key is saved successfully", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
@ -150,7 +150,7 @@ def mailbox_detail_route(mailbox_id):
mailbox.pgp_public_key = None mailbox.pgp_public_key = None
mailbox.pgp_finger_print = None mailbox.pgp_finger_print = None
mailbox.disable_pgp = False mailbox.disable_pgp = False
db.session.commit() Session.commit()
flash("Your PGP public key is removed successfully", "success") flash("Your PGP public key is removed successfully", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
@ -164,7 +164,7 @@ def mailbox_detail_route(mailbox_id):
mailbox.disable_pgp = True mailbox.disable_pgp = True
flash(f"PGP is disabled on {mailbox.email}", "info") flash(f"PGP is disabled on {mailbox.email}", "info")
db.session.commit() Session.commit()
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
@ -180,14 +180,14 @@ def mailbox_detail_route(mailbox_id):
) )
mailbox.generic_subject = request.form.get("generic-subject") mailbox.generic_subject = request.form.get("generic-subject")
db.session.commit() Session.commit()
flash("Generic subject for PGP-encrypted email is enabled", "success") flash("Generic subject for PGP-encrypted email is enabled", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
elif request.form.get("action") == "remove": elif request.form.get("action") == "remove":
mailbox.generic_subject = None mailbox.generic_subject = None
db.session.commit() Session.commit()
flash("Generic subject for PGP-encrypted email is disabled", "success") flash("Generic subject for PGP-encrypted email is disabled", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
@ -236,7 +236,7 @@ def cancel_mailbox_change_route(mailbox_id):
if mailbox.new_email: if mailbox.new_email:
mailbox.new_email = None mailbox.new_email = None
db.session.commit() Session.commit()
flash("Your mailbox change is cancelled", "success") flash("Your mailbox change is cancelled", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
@ -274,7 +274,7 @@ def mailbox_confirm_change_route():
# mark mailbox as verified if the change request is sent from an unverified mailbox # mark mailbox as verified if the change request is sent from an unverified mailbox
mailbox.verified = True mailbox.verified = True
db.session.commit() Session.commit()
LOG.d("Mailbox change %s is verified", mailbox) LOG.d("Mailbox change %s is verified", mailbox)
flash(f"The {mailbox.email} is updated", "success") flash(f"The {mailbox.email} is updated", "success")

View File

@ -3,7 +3,7 @@ from flask_login import login_required, current_user
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.extensions import db from app.db import Session
from app.models import RecoveryCode from app.models import RecoveryCode
@ -19,7 +19,7 @@ def mfa_cancel():
if request.method == "POST": if request.method == "POST":
current_user.enable_otp = False current_user.enable_otp = False
current_user.otp_secret = None current_user.otp_secret = None
db.session.commit() Session.commit()
# user does not have any 2FA enabled left, delete all recovery codes # user does not have any 2FA enabled left, delete all recovery codes
if not current_user.two_factor_authentication_enabled(): if not current_user.two_factor_authentication_enabled():

View File

@ -6,7 +6,7 @@ from wtforms import StringField, validators
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
@ -27,7 +27,7 @@ def mfa_setup():
if not current_user.otp_secret: if not current_user.otp_secret:
LOG.d("Generate otp_secret for user %s", current_user) LOG.d("Generate otp_secret for user %s", current_user)
current_user.otp_secret = pyotp.random_base32() current_user.otp_secret = pyotp.random_base32()
db.session.commit() Session.commit()
totp = pyotp.TOTP(current_user.otp_secret) totp = pyotp.TOTP(current_user.otp_secret)
@ -37,7 +37,7 @@ def mfa_setup():
if totp.verify(token) and current_user.last_otp != token: if totp.verify(token) and current_user.last_otp != token:
current_user.enable_otp = True current_user.enable_otp = True
current_user.last_otp = token current_user.last_otp = token
db.session.commit() Session.commit()
flash("MFA has been activated", "success") flash("MFA has been activated", "success")
return redirect(url_for("dashboard.recovery_code_route")) return redirect(url_for("dashboard.recovery_code_route"))

View File

@ -13,12 +13,12 @@ def recovery_code_route():
flash("you need to enable either TOTP or WebAuthn", "warning") flash("you need to enable either TOTP or WebAuthn", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
recovery_codes = RecoveryCode.query.filter_by(user_id=current_user.id).all() recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all()
if request.method == "GET" and not recovery_codes: if request.method == "GET" and not recovery_codes:
# user arrives at this page for the first time # user arrives at this page for the first time
LOG.d("%s has no recovery keys, generate", current_user) LOG.d("%s has no recovery keys, generate", current_user)
RecoveryCode.generate(current_user) RecoveryCode.generate(current_user)
recovery_codes = RecoveryCode.query.filter_by(user_id=current_user.id).all() recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all()
if request.method == "POST": if request.method == "POST":
RecoveryCode.generate(current_user) RecoveryCode.generate(current_user)

View File

@ -1,10 +1,9 @@
import re2 as re import re2 as re
from flask import render_template, request, flash, redirect, url_for from flask import render_template, request, flash, redirect, url_for
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db from app.db import Session
from app.models import Referral, Payout from app.models import Referral, Payout
_REFERRAL_PATTERN = r"[0-9a-z-_]{3,}" _REFERRAL_PATTERN = r"[0-9a-z-_]{3,}"
@ -30,7 +29,7 @@ def referral_route():
name = request.form.get("name") name = request.form.get("name")
referral = Referral.create(user_id=current_user.id, code=code, name=name) referral = Referral.create(user_id=current_user.id, code=code, name=name)
db.session.commit() Session.commit()
flash("A new referral code has been created", "success") flash("A new referral code has been created", "success")
return redirect( return redirect(
url_for("dashboard.referral_route", highlight_id=referral.id) url_for("dashboard.referral_route", highlight_id=referral.id)
@ -40,7 +39,7 @@ def referral_route():
referral = Referral.get(referral_id) referral = Referral.get(referral_id)
if referral and referral.user_id == current_user.id: if referral and referral.user_id == current_user.id:
referral.name = request.form.get("name") referral.name = request.form.get("name")
db.session.commit() Session.commit()
flash("Referral name updated", "success") flash("Referral name updated", "success")
return redirect( return redirect(
url_for("dashboard.referral_route", highlight_id=referral.id) url_for("dashboard.referral_route", highlight_id=referral.id)
@ -50,7 +49,7 @@ def referral_route():
referral = Referral.get(referral_id) referral = Referral.get(referral_id)
if referral and referral.user_id == current_user.id: if referral and referral.user_id == current_user.id:
Referral.delete(referral.id) Referral.delete(referral.id)
db.session.commit() Session.commit()
flash("Referral deleted", "success") flash("Referral deleted", "success")
return redirect(url_for("dashboard.referral_route")) return redirect(url_for("dashboard.referral_route"))
@ -59,7 +58,7 @@ def referral_route():
if highlight_id: if highlight_id:
highlight_id = int(highlight_id) highlight_id = int(highlight_id)
referrals = Referral.query.filter_by(user_id=current_user.id).all() referrals = Referral.filter_by(user_id=current_user.id).all()
# make sure the highlighted referral is the first referral # make sure the highlighted referral is the first referral
highlight_index = None highlight_index = None
for ix, referral in enumerate(referrals): for ix, referral in enumerate(referrals):
@ -70,6 +69,6 @@ def referral_route():
if highlight_index: if highlight_index:
referrals.insert(0, referrals.pop(highlight_index)) referrals.insert(0, referrals.pop(highlight_index))
payouts = Payout.query.filter_by(user_id=current_user.id).all() payouts = Payout.filter_by(user_id=current_user.id).all()
return render_template("dashboard/referral.html", **locals()) return render_template("dashboard/referral.html", **locals())

View File

@ -19,7 +19,7 @@ def refused_email_route():
highlight_id = None highlight_id = None
email_logs: [EmailLog] = ( email_logs: [EmailLog] = (
EmailLog.query.filter( EmailLog.filter(
EmailLog.user_id == current_user.id, EmailLog.refused_email_id.isnot(None) EmailLog.user_id == current_user.id, EmailLog.refused_email_id.isnot(None)
) )
.order_by(EmailLog.id.desc()) .order_by(EmailLog.id.desc())

View File

@ -22,11 +22,11 @@ from app.config import (
ALIAS_RANDOM_SUFFIX_LENGTH, ALIAS_RANDOM_SUFFIX_LENGTH,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.email_utils import ( from app.email_utils import (
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
personal_email_already_used, personal_email_already_used,
) )
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
PlanEnum, PlanEnum,
@ -116,7 +116,7 @@ def setting():
"delete the expired email change %s", other_email_change "delete the expired email change %s", other_email_change
) )
EmailChange.delete(other_email_change.id) EmailChange.delete(other_email_change.id)
db.session.commit() Session.commit()
else: else:
flash( flash(
"You cannot use this email address as your personal inbox.", "You cannot use this email address as your personal inbox.",
@ -132,7 +132,7 @@ def setting():
), # todo: make sure the code is unique ), # todo: make sure the code is unique
new_email=new_email, new_email=new_email,
) )
db.session.commit() Session.commit()
send_change_email_confirmation(current_user, email_change) send_change_email_confirmation(current_user, email_change)
flash( flash(
"A confirmation email is on the way, please check your inbox", "A confirmation email is on the way, please check your inbox",
@ -145,7 +145,7 @@ def setting():
# update user info # update user info
if form.name.data != current_user.name: if form.name.data != current_user.name:
current_user.name = form.name.data current_user.name = form.name.data
db.session.commit() Session.commit()
profile_updated = True profile_updated = True
if form.profile_picture.data: if form.profile_picture.data:
@ -156,11 +156,11 @@ def setting():
file_path, BytesIO(form.profile_picture.data.read()) file_path, BytesIO(form.profile_picture.data.read())
) )
db.session.flush() Session.flush()
LOG.d("upload file %s to s3", file) LOG.d("upload file %s to s3", file)
current_user.profile_picture_id = file.id current_user.profile_picture_id = file.id
db.session.commit() Session.commit()
profile_updated = True profile_updated = True
if profile_updated: if profile_updated:
@ -181,7 +181,7 @@ def setting():
current_user.notification = True current_user.notification = True
else: else:
current_user.notification = False current_user.notification = False
db.session.commit() Session.commit()
flash("Your notification preference has been updated", "success") flash("Your notification preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -212,7 +212,7 @@ def setting():
scheme = int(request.form.get("alias-generator-scheme")) scheme = int(request.form.get("alias-generator-scheme"))
if AliasGeneratorEnum.has_value(scheme): if AliasGeneratorEnum.has_value(scheme):
current_user.alias_generator = scheme current_user.alias_generator = scheme
db.session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -249,7 +249,7 @@ def setting():
current_user.default_alias_custom_domain_id = None current_user.default_alias_custom_domain_id = None
current_user.default_alias_public_domain_id = None current_user.default_alias_public_domain_id = None
db.session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -257,7 +257,7 @@ def setting():
scheme = int(request.form.get("random-alias-suffix-generator")) scheme = int(request.form.get("random-alias-suffix-generator"))
if AliasSuffixEnum.has_value(scheme): if AliasSuffixEnum.has_value(scheme):
current_user.random_alias_suffix = scheme current_user.random_alias_suffix = scheme
db.session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -266,9 +266,9 @@ def setting():
if SenderFormatEnum.has_value(sender_format): if SenderFormatEnum.has_value(sender_format):
current_user.sender_format = sender_format current_user.sender_format = sender_format
current_user.sender_format_updated_at = arrow.now() current_user.sender_format_updated_at = arrow.now()
db.session.commit() Session.commit()
flash("Your sender format preference has been updated", "success") flash("Your sender format preference has been updated", "success")
db.session.commit() Session.commit()
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "replace-ra": elif request.form.get("form-name") == "replace-ra":
@ -277,7 +277,7 @@ def setting():
current_user.replace_reverse_alias = True current_user.replace_reverse_alias = True
else: else:
current_user.replace_reverse_alias = False current_user.replace_reverse_alias = False
db.session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -287,7 +287,7 @@ def setting():
current_user.include_sender_in_reverse_alias = True current_user.include_sender_in_reverse_alias = True
else: else:
current_user.include_sender_in_reverse_alias = False current_user.include_sender_in_reverse_alias = False
db.session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -297,7 +297,7 @@ def setting():
current_user.expand_alias_info = True current_user.expand_alias_info = True
else: else:
current_user.expand_alias_info = False current_user.expand_alias_info = False
db.session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "ignore-loop-email": elif request.form.get("form-name") == "ignore-loop-email":
@ -306,7 +306,7 @@ def setting():
current_user.ignore_loop_email = True current_user.ignore_loop_email = True
else: else:
current_user.ignore_loop_email = False current_user.ignore_loop_email = False
db.session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -344,7 +344,7 @@ def send_reset_password_email(user):
reset_password_code = ResetPasswordCode.create( reset_password_code = ResetPasswordCode.create(
user_id=user.id, code=random_string(60) user_id=user.id, code=random_string(60)
) )
db.session.commit() Session.commit()
reset_password_link = f"{URL}/auth/reset_password?code={reset_password_code.code}" reset_password_link = f"{URL}/auth/reset_password?code={reset_password_code.code}"
@ -368,7 +368,7 @@ def resend_email_change():
if email_change: if email_change:
# extend email change expiration # extend email change expiration
email_change.expired = arrow.now().shift(hours=12) email_change.expired = arrow.now().shift(hours=12)
db.session.commit() Session.commit()
send_change_email_confirmation(current_user, email_change) send_change_email_confirmation(current_user, email_change)
flash("A confirmation email is on the way, please check your inbox", "success") flash("A confirmation email is on the way, please check your inbox", "success")
@ -386,7 +386,7 @@ def cancel_email_change():
email_change = EmailChange.get_by(user_id=current_user.id) email_change = EmailChange.get_by(user_id=current_user.id)
if email_change: if email_change:
EmailChange.delete(email_change.id) EmailChange.delete(email_change.id)
db.session.commit() Session.commit()
flash("Your email change is cancelled", "success") flash("Your email change is cancelled", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
else: else:

View File

@ -1,3 +1,5 @@
from app.db import Session
""" """
Allow user to "unsubscribe", aka block an email alias Allow user to "unsubscribe", aka block an email alias
""" """
@ -6,7 +8,6 @@ from flask import redirect, url_for, flash, request, render_template
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import db
from app.models import Alias from app.models import Alias
@ -29,7 +30,7 @@ def unsubscribe(alias_id):
if request.method == "POST": if request.method == "POST":
alias.enabled = False alias.enabled = False
flash(f"Alias {alias.email} has been blocked", "success") flash(f"Alias {alias.email} has been blocked", "success")
db.session.commit() Session.commit()
return redirect(url_for("dashboard.index", highlight_alias_id=alias.id)) return redirect(url_for("dashboard.index", highlight_alias_id=alias.id))
else: # ask user confirmation else: # ask user confirmation

10
app/db.py Normal file
View File

@ -0,0 +1,10 @@
from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker
from app.config import DB_URI
engine = create_engine(DB_URI)
connection = engine.connect()
Session = scoped_session(sessionmaker(bind=connection))

View File

@ -8,9 +8,9 @@ from wtforms import StringField, validators, TextAreaField
from app import s3 from app import s3
from app.config import ADMIN_EMAIL from app.config import ADMIN_EMAIL
from app.db import Session
from app.developer.base import developer_bp from app.developer.base import developer_bp
from app.email_utils import send_email from app.email_utils import send_email
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import Client, RedirectUri, File from app.models import Client, RedirectUri, File
from app.utils import random_string from app.utils import random_string
@ -55,13 +55,13 @@ def client_detail(client_id):
s3.upload_from_bytesio(file_path, BytesIO(form.icon.data.read())) s3.upload_from_bytesio(file_path, BytesIO(form.icon.data.read()))
db.session.flush() Session.flush()
LOG.d("upload file %s to s3", file) LOG.d("upload file %s to s3", file)
client.icon_id = file.id client.icon_id = file.id
db.session.flush() Session.flush()
db.session.commit() Session.commit()
flash(f"{client.name} has been updated", "success") flash(f"{client.name} has been updated", "success")
@ -69,7 +69,7 @@ def client_detail(client_id):
if action == "submit" and approval_form.validate_on_submit(): if action == "submit" and approval_form.validate_on_submit():
client.description = approval_form.description.data client.description = approval_form.description.data
db.session.commit() Session.commit()
send_email( send_email(
ADMIN_EMAIL, ADMIN_EMAIL,
@ -127,7 +127,7 @@ def client_detail_oauth_setting(client_id):
for uri in uris: for uri in uris:
RedirectUri.create(client_id=client_id, uri=uri) RedirectUri.create(client_id=client_id, uri=uri)
db.session.commit() Session.commit()
flash(f"{client.name} has been updated", "success") flash(f"{client.name} has been updated", "success")
@ -178,7 +178,7 @@ def client_detail_advanced(client_id):
# delete client # delete client
client_name = client.name client_name = client.name
Client.delete(client.id) Client.delete(client.id)
db.session.commit() Session.commit()
LOG.d("Remove client %s", client) LOG.d("Remove client %s", client)
flash(f"{client_name} has been deleted", "success") flash(f"{client_name} has been deleted", "success")

View File

@ -3,8 +3,8 @@ from flask_login import current_user, login_required
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app.db import Session
from app.developer.base import developer_bp from app.developer.base import developer_bp
from app.extensions import db
from app.models import Client from app.models import Client
@ -19,7 +19,7 @@ def new_client():
if form.validate_on_submit(): if form.validate_on_submit():
client = Client.create_new(form.name.data, current_user.id) client = Client.create_new(form.name.data, current_user.id)
db.session.commit() Session.commit()
flash("Your app has been created", "success") flash("Your app has been created", "success")

View File

@ -5,8 +5,8 @@ from app.config import (
MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS, MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS,
MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX, MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX,
) )
from app.db import Session
from app.email_utils import is_reply_email from app.email_utils import is_reply_email
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import Alias, EmailLog, Contact from app.models import Alias, EmailLog, Contact
@ -16,7 +16,7 @@ def rate_limited_for_alias(alias: Alias) -> bool:
# get the nb of activity on this alias # get the nb of activity on this alias
nb_activity = ( nb_activity = (
db.session.query(EmailLog) Session.query(EmailLog)
.join(Contact, EmailLog.contact_id == Contact.id) .join(Contact, EmailLog.contact_id == Contact.id)
.filter( .filter(
Contact.alias_id == alias.id, Contact.alias_id == alias.id,
@ -42,7 +42,7 @@ def rate_limited_for_mailbox(alias: Alias) -> bool:
# get nb of activity on this mailbox # get nb of activity on this mailbox
nb_activity = ( nb_activity = (
db.session.query(EmailLog) Session.query(EmailLog)
.join(Contact, EmailLog.contact_id == Contact.id) .join(Contact, EmailLog.contact_id == Contact.id)
.join(Alias, Contact.alias_id == Alias.id) .join(Alias, Contact.alias_id == Alias.id)
.filter( .filter(

View File

@ -53,9 +53,9 @@ from app.config import (
TEMP_DIR, TEMP_DIR,
ALIAS_AUTOMATIC_DISABLE, ALIAS_AUTOMATIC_DISABLE,
) )
from app.db import Session
from app.dns_utils import get_mx_domains from app.dns_utils import get_mx_domains
from app.email import headers from app.email import headers
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Mailbox, Mailbox,
@ -324,7 +324,7 @@ def send_email_with_rate_control(
to_email = sanitize_email(to_email) to_email = sanitize_email(to_email)
min_dt = arrow.now().shift(days=-1 * nb_day) min_dt = arrow.now().shift(days=-1 * nb_day)
nb_alert = ( nb_alert = (
SentAlert.query.filter_by(alert_type=alert_type, to_email=to_email) SentAlert.filter_by(alert_type=alert_type, to_email=to_email)
.filter(SentAlert.created_at > min_dt) .filter(SentAlert.created_at > min_dt)
.count() .count()
) )
@ -340,7 +340,7 @@ def send_email_with_rate_control(
return False return False
SentAlert.create(user_id=user.id, alert_type=alert_type, to_email=to_email) SentAlert.create(user_id=user.id, alert_type=alert_type, to_email=to_email)
db.session.commit() Session.commit()
if ignore_smtp_error: if ignore_smtp_error:
try: try:
@ -369,9 +369,7 @@ def send_email_at_most_times(
Return true if the email is sent, otherwise False Return true if the email is sent, otherwise False
""" """
to_email = sanitize_email(to_email) to_email = sanitize_email(to_email)
nb_alert = SentAlert.query.filter_by( nb_alert = SentAlert.filter_by(alert_type=alert_type, to_email=to_email).count()
alert_type=alert_type, to_email=to_email
).count()
if nb_alert >= max_times: if nb_alert >= max_times:
LOG.w( LOG.w(
@ -383,7 +381,7 @@ def send_email_at_most_times(
return False return False
SentAlert.create(user_id=user.id, alert_type=alert_type, to_email=to_email) SentAlert.create(user_id=user.id, alert_type=alert_type, to_email=to_email)
db.session.commit() Session.commit()
send_email(to_email, subject, plaintext, html) send_email(to_email, subject, plaintext, html)
return True return True
@ -1036,7 +1034,7 @@ def should_disable(alias: Alias) -> bool:
yesterday = arrow.now().shift(days=-1) yesterday = arrow.now().shift(days=-1)
nb_bounced_last_24h = ( nb_bounced_last_24h = (
db.session.query(EmailLog) Session.query(EmailLog)
.filter( .filter(
EmailLog.bounced.is_(True), EmailLog.bounced.is_(True),
EmailLog.is_reply.is_(False), EmailLog.is_reply.is_(False),
@ -1054,7 +1052,7 @@ def should_disable(alias: Alias) -> bool:
elif nb_bounced_last_24h > 5: elif nb_bounced_last_24h > 5:
one_week_ago = arrow.now().shift(days=-8) one_week_ago = arrow.now().shift(days=-8)
nb_bounced_7d_1d = ( nb_bounced_7d_1d = (
db.session.query(EmailLog) Session.query(EmailLog)
.filter( .filter(
EmailLog.bounced.is_(True), EmailLog.bounced.is_(True),
EmailLog.is_reply.is_(False), EmailLog.is_reply.is_(False),
@ -1075,7 +1073,7 @@ def should_disable(alias: Alias) -> bool:
# alias level # alias level
# if bounces at least 9 days in the last 10 days -> disable alias # if bounces at least 9 days in the last 10 days -> disable alias
query = ( query = (
db.session.query( Session.query(
func.date(EmailLog.created_at).label("date"), func.date(EmailLog.created_at).label("date"),
func.count(EmailLog.id).label("count"), func.count(EmailLog.id).label("count"),
) )
@ -1097,7 +1095,7 @@ def should_disable(alias: Alias) -> bool:
# account level # account level
query = ( query = (
db.session.query( Session.query(
func.date(EmailLog.created_at).label("date"), func.date(EmailLog.created_at).label("date"),
func.count(EmailLog.id).label("count"), func.count(EmailLog.id).label("count"),
) )

View File

@ -3,8 +3,8 @@ import csv
import requests import requests
from app import s3 from app import s3
from app.db import Session
from app.email_utils import get_email_domain_part from app.email_utils import get_email_domain_part
from app.extensions import db
from app.models import ( from app.models import (
Alias, Alias,
AliasMailbox, AliasMailbox,
@ -23,7 +23,7 @@ def handle_batch_import(batch_import: BatchImport):
user = batch_import.user user = batch_import.user
batch_import.processed = True batch_import.processed = True
db.session.commit() Session.commit()
LOG.d("Start batch import for %s %s", batch_import, user) LOG.d("Start batch import for %s %s", batch_import, user)
file_url = s3.get_url(batch_import.file.path) file_url = s3.get_url(batch_import.file.path)
@ -97,5 +97,5 @@ def import_from_csv(batch_import: BatchImport, user: User, lines):
AliasMailbox.create( AliasMailbox.create(
alias_id=alias.id, mailbox_id=mailboxes[i], commit=True alias_id=alias.id, mailbox_id=mailboxes[i], commit=True
) )
db.session.commit() Session.commit()
LOG.d("Add %s to mailbox %s", alias, mailboxes[i]) LOG.d("Add %s to mailbox %s", alias, mailboxes[i])

File diff suppressed because it is too large Load Diff

View File

@ -8,7 +8,7 @@ from itsdangerous import SignatureExpired
from app.alias_utils import check_alias_prefix from app.alias_utils import check_alias_prefix
from app.config import EMAIL_DOMAIN from app.config import EMAIL_DOMAIN
from app.dashboard.views.custom_alias import signer, get_available_suffixes from app.dashboard.views.custom_alias import signer, get_available_suffixes
from app.extensions import db from app.db import Session
from app.jose_utils import make_id_token from app.jose_utils import make_id_token
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
@ -206,7 +206,7 @@ def authorize():
if domain: if domain:
alias.custom_domain_id = domain.id alias.custom_domain_id = domain.id
db.session.flush() Session.flush()
flash(f"Alias {full_alias} has been created", "success") flash(f"Alias {full_alias} has been created", "success")
# only happen if the request has been "hacked" # only happen if the request has been "hacked"
else: else:
@ -224,7 +224,7 @@ def authorize():
user_id=current_user.id, user_id=current_user.id,
mailbox_id=current_user.default_mailbox_id, mailbox_id=current_user.default_mailbox_id,
) )
db.session.flush() Session.flush()
suggested_name = request.form.get("suggested-name") suggested_name = request.form.get("suggested-name")
custom_name = request.form.get("custom-name") custom_name = request.form.get("custom-name")
@ -247,7 +247,7 @@ def authorize():
LOG.d("use default avatar for user %s client %s", current_user, client) LOG.d("use default avatar for user %s client %s", current_user, client)
client_user.default_avatar = True client_user.default_avatar = True
db.session.flush() Session.flush()
LOG.d("create client-user for client %s, user %s", client, current_user) LOG.d("create client-user for client %s, user %s", client, current_user)
redirect_args = {} redirect_args = {}
@ -284,7 +284,7 @@ def authorize():
access_token=generate_access_token(), access_token=generate_access_token(),
response_type=response_types_to_str(response_types), response_type=response_types_to_str(response_types),
) )
db.session.add(oauth_token) Session.add(oauth_token)
redirect_args["access_token"] = oauth_token.access_token redirect_args["access_token"] = oauth_token.access_token
if ResponseType.ID_TOKEN in response_types: if ResponseType.ID_TOKEN in response_types:
@ -295,7 +295,7 @@ def authorize():
auth_code.code if auth_code else None, auth_code.code if auth_code else None,
) )
db.session.commit() Session.commit()
# should all params appended the url using fragment (#) or query # should all params appended the url using fragment (#) or query
fragment = False fragment = False

View File

@ -1,7 +1,7 @@
from flask import request, jsonify from flask import request, jsonify
from flask_cors import cross_origin from flask_cors import cross_origin
from app.extensions import db from app.db import Session
from app.jose_utils import make_id_token from app.jose_utils import make_id_token
from app.log import LOG from app.log import LOG
from app.models import Client, AuthorizationCode, OauthToken, ClientUser from app.models import Client, AuthorizationCode, OauthToken, ClientUser
@ -49,7 +49,7 @@ def token():
return jsonify(error=f"no such authorization code {code}"), 400 return jsonify(error=f"no such authorization code {code}"), 400
elif auth_code.is_expired(): elif auth_code.is_expired():
AuthorizationCode.delete(auth_code.id) AuthorizationCode.delete(auth_code.id)
db.session.commit() Session.commit()
LOG.d("delete expired authorization code:%s", auth_code) LOG.d("delete expired authorization code:%s", auth_code)
return jsonify(error=f"{code} already expired"), 400 return jsonify(error=f"{code} already expired"), 400
@ -94,6 +94,6 @@ def token():
# Auth code can be used only once # Auth code can be used only once
AuthorizationCode.delete(auth_code.id) AuthorizationCode.delete(auth_code.id)
db.session.commit() Session.commit()
return jsonify(res) return jsonify(res)

View File

@ -1,7 +1,7 @@
from flask import request, jsonify from flask import request, jsonify
from flask_cors import cross_origin from flask_cors import cross_origin
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import OauthToken, ClientUser from app.models import OauthToken, ClientUser
from app.oauth.base import oauth_bp from app.oauth.base import oauth_bp
@ -27,7 +27,7 @@ def user_info():
elif oauth_token.is_expired(): elif oauth_token.is_expired():
LOG.d("delete oauth token %s", oauth_token) LOG.d("delete oauth token %s", oauth_token)
OauthToken.delete(oauth_token.id) OauthToken.delete(oauth_token.id)
db.session.commit() Session.commit()
return jsonify(error="Expired access token"), 400 return jsonify(error="Expired access token"), 400
client_user = ClientUser.get_or_create( client_user = ClientUser.get_or_create(

View File

@ -1,14 +1,12 @@
import unicodedata
import bcrypt import bcrypt
import sqlalchemy as sa
from app.extensions import db import unicodedata
_NORMALIZATION_FORM = "NFKC" _NORMALIZATION_FORM = "NFKC"
class PasswordOracle: class PasswordOracle:
password = db.Column(db.String(128), nullable=True) password = sa.Column(sa.String(128), nullable=True)
def set_password(self, password): def set_password(self, password):
password = unicodedata.normalize(_NORMALIZATION_FORM, password) password = unicodedata.normalize(_NORMALIZATION_FORM, password)

View File

@ -2,11 +2,12 @@
https://github.com/petermat/spamassassin_client https://github.com/petermat/spamassassin_client
""" """
import logging import logging
import re2 as re
import select
import socket import socket
from io import BytesIO from io import BytesIO
import re2 as re
import select
from app.log import LOG from app.log import LOG
divider_pattern = re.compile(br"^(.*?)\r?\n(.*?)\r?\n\r?\n", re.DOTALL) divider_pattern = re.compile(br"^(.*?)\r?\n(.*?)\r?\n\r?\n", re.DOTALL)

122
cron.py
View File

@ -22,6 +22,7 @@ from app.config import (
HIBP_API_KEYS, HIBP_API_KEYS,
HIBP_SCAN_INTERVAL_DAYS, HIBP_SCAN_INTERVAL_DAYS,
) )
from app.db import Session
from app.dns_utils import get_mx_domains from app.dns_utils import get_mx_domains
from app.email_utils import ( from app.email_utils import (
send_email, send_email,
@ -33,7 +34,6 @@ from app.email_utils import (
is_valid_email, is_valid_email,
get_email_domain_part, get_email_domain_part,
) )
from app.extensions import db
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Subscription, Subscription,
@ -63,7 +63,7 @@ from server import create_app
def notify_trial_end(): def notify_trial_end():
for user in User.query.filter( for user in User.filter(
User.activated.is_(True), User.trial_end.isnot(None), User.lifetime.is_(False) User.activated.is_(True), User.trial_end.isnot(None), User.lifetime.is_(False)
).all(): ).all():
if user.in_trial() and arrow.now().shift( if user.in_trial() and arrow.now().shift(
@ -78,27 +78,27 @@ def delete_logs():
delete_refused_emails() delete_refused_emails()
delete_old_monitoring() delete_old_monitoring()
for t in TransactionalEmail.query.filter( for t in TransactionalEmail.filter(
TransactionalEmail.created_at < arrow.now().shift(days=-7) TransactionalEmail.created_at < arrow.now().shift(days=-7)
): ):
TransactionalEmail.delete(t.id) TransactionalEmail.delete(t.id)
for b in Bounce.query.filter(Bounce.created_at < arrow.now().shift(days=-7)): for b in Bounce.filter(Bounce.created_at < arrow.now().shift(days=-7)):
Bounce.delete(b.id) Bounce.delete(b.id)
db.session.commit() Session.commit()
LOG.d("Delete EmailLog older than 2 weeks") LOG.d("Delete EmailLog older than 2 weeks")
max_dt = arrow.now().shift(weeks=-2) max_dt = arrow.now().shift(weeks=-2)
nb_deleted = EmailLog.query.filter(EmailLog.created_at < max_dt).delete() nb_deleted = EmailLog.filter(EmailLog.created_at < max_dt).delete()
db.session.commit() Session.commit()
LOG.i("Delete %s email logs", nb_deleted) LOG.i("Delete %s email logs", nb_deleted)
def delete_refused_emails(): def delete_refused_emails():
for refused_email in RefusedEmail.query.filter_by(deleted=False).all(): for refused_email in RefusedEmail.filter_by(deleted=False).all():
if arrow.now().shift(days=1) > refused_email.delete_at >= arrow.now(): if arrow.now().shift(days=1) > refused_email.delete_at >= arrow.now():
LOG.d("Delete refused email %s", refused_email) LOG.d("Delete refused email %s", refused_email)
if refused_email.path: if refused_email.path:
@ -109,14 +109,14 @@ def delete_refused_emails():
# do not set path and full_report_path to null # do not set path and full_report_path to null
# so we can check later that the files are indeed deleted # so we can check later that the files are indeed deleted
refused_email.deleted = True refused_email.deleted = True
db.session.commit() Session.commit()
LOG.d("Finish delete_refused_emails") LOG.d("Finish delete_refused_emails")
def notify_premium_end(): def notify_premium_end():
"""sent to user who has canceled their subscription and who has their subscription ending soon""" """sent to user who has canceled their subscription and who has their subscription ending soon"""
for sub in Subscription.query.filter_by(cancelled=True).all(): for sub in Subscription.filter_by(cancelled=True).all():
if ( if (
arrow.now().shift(days=3).date() arrow.now().shift(days=3).date()
> sub.next_bill_date > sub.next_bill_date
@ -146,7 +146,7 @@ def notify_premium_end():
def notify_manual_sub_end(): def notify_manual_sub_end():
for manual_sub in ManualSubscription.query.all(): for manual_sub in ManualSubscription.all():
need_reminder = False need_reminder = False
if arrow.now().shift(days=14) > manual_sub.end_at > arrow.now().shift(days=13): if arrow.now().shift(days=14) > manual_sub.end_at > arrow.now().shift(days=13):
need_reminder = True need_reminder = True
@ -172,7 +172,7 @@ def notify_manual_sub_end():
) )
extend_subscription_url = URL + "/dashboard/coinbase_checkout" extend_subscription_url = URL + "/dashboard/coinbase_checkout"
for coinbase_subscription in CoinbaseSubscription.query.all(): for coinbase_subscription in CoinbaseSubscription.all():
need_reminder = False need_reminder = False
if ( if (
arrow.now().shift(days=14) arrow.now().shift(days=14)
@ -211,7 +211,7 @@ def notify_manual_sub_end():
def poll_apple_subscription(): def poll_apple_subscription():
"""Poll Apple API to update AppleSubscription""" """Poll Apple API to update AppleSubscription"""
# todo: only near the end of the subscription # todo: only near the end of the subscription
for apple_sub in AppleSubscription.query.all(): for apple_sub in AppleSubscription.all():
user = apple_sub.user user = apple_sub.user
verify_receipt(apple_sub.receipt_data, user, APPLE_API_SECRET) verify_receipt(apple_sub.receipt_data, user, APPLE_API_SECRET)
verify_receipt(apple_sub.receipt_data, user, MACAPP_APPLE_API_SECRET) verify_receipt(apple_sub.receipt_data, user, MACAPP_APPLE_API_SECRET)
@ -224,49 +224,49 @@ def compute_metric2() -> Metric2:
_24h_ago = now.shift(days=-1) _24h_ago = now.shift(days=-1)
nb_referred_user_paid = 0 nb_referred_user_paid = 0
for user in User.query.filter(User.referral_id.isnot(None)): for user in User.filter(User.referral_id.isnot(None)):
if user.is_paid(): if user.is_paid():
nb_referred_user_paid += 1 nb_referred_user_paid += 1
return Metric2.create( return Metric2.create(
date=now, date=now,
# user stats # user stats
nb_user=User.query.count(), nb_user=User.count(),
nb_activated_user=User.query.filter_by(activated=True).count(), nb_activated_user=User.filter_by(activated=True).count(),
# subscription stats # subscription stats
nb_premium=Subscription.query.filter(Subscription.cancelled.is_(False)).count(), nb_premium=Subscription.filter(Subscription.cancelled.is_(False)).count(),
nb_cancelled_premium=Subscription.query.filter( nb_cancelled_premium=Subscription.filter(
Subscription.cancelled.is_(True) Subscription.cancelled.is_(True)
).count(), ).count(),
# todo: filter by expires_date > now # todo: filter by expires_date > now
nb_apple_premium=AppleSubscription.query.count(), nb_apple_premium=AppleSubscription.count(),
nb_manual_premium=ManualSubscription.query.filter( nb_manual_premium=ManualSubscription.filter(
ManualSubscription.end_at > now, ManualSubscription.end_at > now,
ManualSubscription.is_giveaway.is_(False), ManualSubscription.is_giveaway.is_(False),
).count(), ).count(),
nb_coinbase_premium=CoinbaseSubscription.query.filter( nb_coinbase_premium=CoinbaseSubscription.filter(
CoinbaseSubscription.end_at > now CoinbaseSubscription.end_at > now
).count(), ).count(),
# referral stats # referral stats
nb_referred_user=User.query.filter(User.referral_id.isnot(None)).count(), nb_referred_user=User.filter(User.referral_id.isnot(None)).count(),
nb_referred_user_paid=nb_referred_user_paid, nb_referred_user_paid=nb_referred_user_paid,
nb_alias=Alias.query.count(), nb_alias=Alias.count(),
# email log stats # email log stats
nb_forward_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) nb_forward_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago)
.filter_by(bounced=False, is_spam=False, is_reply=False, blocked=False) .filter_by(bounced=False, is_spam=False, is_reply=False, blocked=False)
.count(), .count(),
nb_bounced_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) nb_bounced_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago)
.filter_by(bounced=True) .filter_by(bounced=True)
.count(), .count(),
nb_reply_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) nb_reply_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago)
.filter_by(is_reply=True) .filter_by(is_reply=True)
.count(), .count(),
nb_block_last_24h=EmailLog.query.filter(EmailLog.created_at > _24h_ago) nb_block_last_24h=EmailLog.filter(EmailLog.created_at > _24h_ago)
.filter_by(blocked=True) .filter_by(blocked=True)
.count(), .count(),
# other stats # other stats
nb_verified_custom_domain=CustomDomain.query.filter_by(verified=True).count(), nb_verified_custom_domain=CustomDomain.filter_by(verified=True).count(),
nb_app=Client.query.count(), nb_app=Client.count(),
commit=True, commit=True,
) )
@ -309,7 +309,7 @@ def bounce_report() -> List[Tuple[str, int]]:
""" """
min_dt = arrow.now().shift(days=-1) min_dt = arrow.now().shift(days=-1)
query = ( query = (
db.session.query(User.email, func.count(EmailLog.id).label("count")) Session.query(User.email, func.count(EmailLog.id).label("count"))
.join(EmailLog, EmailLog.user_id == User.id) .join(EmailLog, EmailLog.user_id == User.id)
.filter(EmailLog.bounced, EmailLog.created_at > min_dt) .filter(EmailLog.bounced, EmailLog.created_at > min_dt)
.group_by(User.email) .group_by(User.email)
@ -354,7 +354,7 @@ def alias_creation_report() -> List[Tuple[str, int]]:
""" """
min_dt = arrow.now().shift(days=-7) min_dt = arrow.now().shift(days=-7)
query = ( query = (
db.session.query( Session.query(
User.email, User.email,
func.count(Alias.id).label("count"), func.count(Alias.id).label("count"),
func.date(Alias.created_at).label("date"), func.date(Alias.created_at).label("date"),
@ -381,7 +381,7 @@ def stats():
stats_today = compute_metric2() stats_today = compute_metric2()
stats_yesterday = ( stats_yesterday = (
Metric2.query.filter(Metric2.date < stats_today.date) Metric2.filter(Metric2.date < stats_today.date)
.order_by(Metric2.date.desc()) .order_by(Metric2.date.desc())
.first() .first()
) )
@ -442,13 +442,13 @@ nb_referred_user_upgrade: {stats_today.nb_referred_user_paid} - {increase_percen
def migrate_domain_trash(): def migrate_domain_trash():
"""Move aliases from global trash to domain trash if applicable""" """Move aliases from global trash to domain trash if applicable"""
for deleted_alias in DeletedAlias.query.all(): for deleted_alias in DeletedAlias.all():
alias_domain = get_email_domain_part(deleted_alias.email) alias_domain = get_email_domain_part(deleted_alias.email)
if not SLDomain.get_by(domain=alias_domain): if not SLDomain.get_by(domain=alias_domain):
custom_domain = CustomDomain.get_by(domain=alias_domain) custom_domain = CustomDomain.get_by(domain=alias_domain)
if custom_domain: if custom_domain:
LOG.e("move %s to domain %s trash", deleted_alias, custom_domain) LOG.e("move %s to domain %s trash", deleted_alias, custom_domain)
db.session.add( Session.add(
DomainDeletedAlias( DomainDeletedAlias(
user_id=custom_domain.user_id, user_id=custom_domain.user_id,
email=deleted_alias.email, email=deleted_alias.email,
@ -458,13 +458,13 @@ def migrate_domain_trash():
) )
DeletedAlias.delete(deleted_alias.id) DeletedAlias.delete(deleted_alias.id)
db.session.commit() Session.commit()
def set_custom_domain_for_alias(): def set_custom_domain_for_alias():
"""Go through all aliases and make sure custom_domain is correctly set""" """Go through all aliases and make sure custom_domain is correctly set"""
sl_domains = [sl_domain.domain for sl_domain in SLDomain.query.all()] sl_domains = [sl_domain.domain for sl_domain in SLDomain.all()]
for alias in Alias.query.filter(Alias.custom_domain_id.is_(None)): for alias in Alias.filter(Alias.custom_domain_id.is_(None)):
if ( if (
not any(alias.email.endswith(f"@{sl_domain}") for sl_domain in sl_domains) not any(alias.email.endswith(f"@{sl_domain}") for sl_domain in sl_domains)
and not alias.custom_domain_id and not alias.custom_domain_id
@ -477,7 +477,7 @@ def set_custom_domain_for_alias():
else: # phantom domain else: # phantom domain
LOG.d("phantom domain %s %s %s", alias.user, alias, alias.enabled) LOG.d("phantom domain %s %s %s", alias.user, alias, alias.enabled)
db.session.commit() Session.commit()
def sanity_check(): def sanity_check():
@ -487,7 +487,7 @@ def sanity_check():
- detect if there's mailbox that's using a invalid domain - detect if there's mailbox that's using a invalid domain
""" """
mailbox_ids = ( mailbox_ids = (
db.session.query(Mailbox.id) Session.query(Mailbox.id)
.filter(Mailbox.verified.is_(True), Mailbox.disabled.is_(False)) .filter(Mailbox.verified.is_(True), Mailbox.disabled.is_(False))
.all() .all()
) )
@ -544,23 +544,23 @@ def sanity_check():
else: # reset nb check else: # reset nb check
mailbox.nb_failed_checks = 0 mailbox.nb_failed_checks = 0
db.session.commit() Session.commit()
for user in User.filter_by(activated=True).all(): for user in User.filter_by(activated=True).all():
if sanitize_email(user.email) != user.email: if sanitize_email(user.email) != user.email:
LOG.e("%s does not have sanitized email", user) LOG.e("%s does not have sanitized email", user)
for alias in Alias.query.all(): for alias in Alias.all():
if sanitize_email(alias.email) != alias.email: if sanitize_email(alias.email) != alias.email:
LOG.e("Alias %s email not sanitized", alias) LOG.e("Alias %s email not sanitized", alias)
if alias.name and "\n" in alias.name: if alias.name and "\n" in alias.name:
alias.name = alias.name.replace("\n", "") alias.name = alias.name.replace("\n", "")
db.session.commit() Session.commit()
LOG.e("Alias %s name contains linebreak %s", alias, alias.name) LOG.e("Alias %s name contains linebreak %s", alias, alias.name)
contact_email_sanity_date = arrow.get("2021-01-12") contact_email_sanity_date = arrow.get("2021-01-12")
for contact in Contact.query.all(): for contact in Contact.all():
if sanitize_email(contact.reply_email) != contact.reply_email: if sanitize_email(contact.reply_email) != contact.reply_email:
LOG.e("Contact %s reply-email not sanitized", contact) LOG.e("Contact %s reply-email not sanitized", contact)
@ -573,13 +573,13 @@ def sanity_check():
if not contact.invalid_email and not is_valid_email(contact.website_email): if not contact.invalid_email and not is_valid_email(contact.website_email):
LOG.e("%s invalid email", contact) LOG.e("%s invalid email", contact)
contact.invalid_email = True contact.invalid_email = True
db.session.commit() Session.commit()
for mailbox in Mailbox.query.all(): for mailbox in Mailbox.all():
if sanitize_email(mailbox.email) != mailbox.email: if sanitize_email(mailbox.email) != mailbox.email:
LOG.e("Mailbox %s address not sanitized", mailbox) LOG.e("Mailbox %s address not sanitized", mailbox)
for contact in Contact.query.all(): for contact in Contact.all():
if normalize_reply_email(contact.reply_email) != contact.reply_email: if normalize_reply_email(contact.reply_email) != contact.reply_email:
LOG.e( LOG.e(
"Contact %s reply email is not normalized %s", "Contact %s reply email is not normalized %s",
@ -587,7 +587,7 @@ def sanity_check():
contact.reply_email, contact.reply_email,
) )
for domain in CustomDomain.query.all(): for domain in CustomDomain.all():
if domain.name and "\n" in domain.name: if domain.name and "\n" in domain.name:
LOG.e("Domain %s name contain linebreak %s", domain, domain.name) LOG.e("Domain %s name contain linebreak %s", domain, domain.name)
@ -600,9 +600,7 @@ def sanity_check():
def check_custom_domain(): def check_custom_domain():
LOG.d("Check verified domain for DNS issues") LOG.d("Check verified domain for DNS issues")
for custom_domain in CustomDomain.query.filter_by( for custom_domain in CustomDomain.filter_by(verified=True): # type: CustomDomain
verified=True
): # type: CustomDomain
mx_domains = get_mx_domains(custom_domain.domain) mx_domains = get_mx_domains(custom_domain.domain)
if sorted(mx_domains) != sorted(EMAIL_SERVERS_WITH_PRIORITY): if sorted(mx_domains) != sorted(EMAIL_SERVERS_WITH_PRIORITY):
@ -644,7 +642,7 @@ def check_custom_domain():
# reset checks # reset checks
custom_domain.nb_failed_checks = 0 custom_domain.nb_failed_checks = 0
db.session.commit() Session.commit()
def delete_old_monitoring(): def delete_old_monitoring():
@ -652,8 +650,8 @@ def delete_old_monitoring():
Delete old monitoring records Delete old monitoring records
""" """
max_time = arrow.now().shift(days=-30) max_time = arrow.now().shift(days=-30)
nb_row = Monitoring.query.filter(Monitoring.created_at < max_time).delete() nb_row = Monitoring.filter(Monitoring.created_at < max_time).delete()
db.session.commit() Session.commit()
LOG.d("delete monitoring records older than %s, nb row %s", max_time, nb_row) LOG.d("delete monitoring records older than %s, nb row %s", max_time, nb_row)
@ -713,8 +711,8 @@ async def _hibp_check(api_key, queue):
return return
alias.hibp_last_check = arrow.utcnow() alias.hibp_last_check = arrow.utcnow()
db.session.add(alias) Session.add(alias)
db.session.commit() Session.commit()
LOG.d("Updated breaches info for %s", alias) LOG.d("Updated breaches info for %s", alias)
@ -738,14 +736,14 @@ async def check_hibp():
hibp_entry.date = arrow.get(entry["BreachDate"]) hibp_entry.date = arrow.get(entry["BreachDate"])
hibp_entry.description = entry["Description"] hibp_entry.description = entry["Description"]
db.session.commit() Session.commit()
LOG.d("Updated list of known breaches") LOG.d("Updated list of known breaches")
LOG.d("Preparing list of aliases to check") LOG.d("Preparing list of aliases to check")
queue = asyncio.Queue() queue = asyncio.Queue()
max_date = arrow.now().shift(days=-HIBP_SCAN_INTERVAL_DAYS) max_date = arrow.now().shift(days=-HIBP_SCAN_INTERVAL_DAYS)
for alias in ( for alias in (
Alias.query.filter( Alias.filter(
or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date) or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date)
) )
.filter(Alias.enabled) .filter(Alias.enabled)
@ -782,19 +780,19 @@ def notify_hibp():
""" """
# to get a list of users that have at least a breached alias # to get a list of users that have at least a breached alias
alias_query = ( alias_query = (
db.session.query(Alias) Session.query(Alias)
.options(joinedload(Alias.hibp_breaches)) .options(joinedload(Alias.hibp_breaches))
.filter(Alias.hibp_breaches.any()) .filter(Alias.hibp_breaches.any())
.filter(Alias.id.notin_(db.session.query(HibpNotifiedAlias.alias_id))) .filter(Alias.id.notin_(Session.query(HibpNotifiedAlias.alias_id)))
.distinct(Alias.user_id) .distinct(Alias.user_id)
.all() .all()
) )
user_ids = [alias.user_id for alias in alias_query] user_ids = [alias.user_id for alias in alias_query]
for user in User.query.filter(User.id.in_(user_ids)): for user in User.filter(User.id.in_(user_ids)):
breached_aliases = ( breached_aliases = (
db.session.query(Alias) Session.query(Alias)
.options(joinedload(Alias.hibp_breaches)) .options(joinedload(Alias.hibp_breaches))
.filter(Alias.hibp_breaches.any(), Alias.user_id == user.id) .filter(Alias.hibp_breaches.any(), Alias.user_id == user.id)
.all() .all()
@ -824,7 +822,7 @@ def notify_hibp():
# add the breached aliases to HibpNotifiedAlias to avoid sending another email # add the breached aliases to HibpNotifiedAlias to avoid sending another email
for alias in breached_aliases: for alias in breached_aliases:
HibpNotifiedAlias.create(user_id=user.id, alias_id=alias.id) HibpNotifiedAlias.create(user_id=user.id, alias_id=alias.id)
db.session.commit() Session.commit()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -85,6 +85,7 @@ from app.config import (
ALERT_YAHOO_COMPLAINT, ALERT_YAHOO_COMPLAINT,
TEMP_DIR, TEMP_DIR,
) )
from app.db import Session
from app.email import status, headers from app.email import status, headers
from app.email.rate_limit import rate_limited from app.email.rate_limit import rate_limited
from app.email.spam import get_spam_score from app.email.spam import get_spam_score
@ -123,7 +124,6 @@ from app.email_utils import (
parse_full_address, parse_full_address,
get_orig_message_from_yahoo_complaint, get_orig_message_from_yahoo_complaint,
) )
from app.extensions import db
from app.log import LOG, set_message_id from app.log import LOG, set_message_id
from app.models import ( from app.models import (
Alias, Alias,
@ -141,7 +141,6 @@ from app.utils import sanitize_email
from init_app import load_pgp_public_keys from init_app import load_pgp_public_keys
from server import create_app, create_light_app from server import create_app, create_light_app
newrelic_app = None newrelic_app = None
if NEWRELIC_CONFIG_PATH: if NEWRELIC_CONFIG_PATH:
newrelic.agent.initialize(NEWRELIC_CONFIG_PATH) newrelic.agent.initialize(NEWRELIC_CONFIG_PATH)
@ -157,10 +156,7 @@ def new_app():
@app.teardown_appcontext @app.teardown_appcontext
def shutdown_session(response_or_exc): def shutdown_session(response_or_exc):
# same as shutdown_session() in flask-sqlalchemy but this is not enough # same as shutdown_session() in flask-sqlalchemy but this is not enough
db.session.remove() Session.remove()
# dispose the engine too
db.engine.dispose()
return app return app
@ -210,7 +206,7 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
contact_name, contact_name,
) )
contact.name = contact_name contact.name = contact_name
db.session.commit() Session.commit()
# contact created in the past does not have mail_from and from_header field # contact created in the past does not have mail_from and from_header field
if not contact.mail_from and mail_from: if not contact.mail_from and mail_from:
@ -221,7 +217,7 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
mail_from, mail_from,
) )
contact.mail_from = mail_from contact.mail_from = mail_from
db.session.commit() Session.commit()
else: else:
LOG.d( LOG.d(
"create contact %s for alias %s", "create contact %s for alias %s",
@ -244,10 +240,10 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
LOG.d("Create a contact with invalid email for %s", alias) LOG.d("Create a contact with invalid email for %s", alias)
contact.invalid_email = True contact.invalid_email = True
db.session.commit() Session.commit()
except IntegrityError: except IntegrityError:
LOG.w("Contact %s %s already exist", alias, contact_email) LOG.w("Contact %s %s already exist", alias, contact_email)
db.session.rollback() Session.rollback()
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email) contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
return contact return contact
@ -291,10 +287,10 @@ def get_or_create_reply_to_contact(
name=contact_name, name=contact_name,
reply_email=generate_reply_email(contact_address, alias.user), reply_email=generate_reply_email(contact_address, alias.user),
) )
db.session.commit() Session.commit()
except IntegrityError: except IntegrityError:
LOG.w("Contact %s %s already exist", alias, contact_address) LOG.w("Contact %s %s already exist", alias, contact_address)
db.session.rollback() Session.rollback()
contact = Contact.get_by(alias_id=alias.id, website_email=contact_address) contact = Contact.get_by(alias_id=alias.id, website_email=contact_address)
return contact return contact
@ -341,7 +337,7 @@ def replace_header_when_forward(msg: Message, alias: Alias, header: str):
full_address.display_name, full_address.display_name,
) )
contact.name = full_address.display_name contact.name = full_address.display_name
db.session.commit() Session.commit()
else: else:
LOG.d( LOG.d(
"create contact for alias %s and email %s, header %s", "create contact for alias %s and email %s, header %s",
@ -359,10 +355,10 @@ def replace_header_when_forward(msg: Message, alias: Alias, header: str):
reply_email=generate_reply_email(contact_email, alias.user), reply_email=generate_reply_email(contact_email, alias.user),
is_cc=header.lower() == "cc", is_cc=header.lower() == "cc",
) )
db.session.commit() Session.commit()
except IntegrityError: except IntegrityError:
LOG.w("Contact %s %s already exist", alias, contact_email) LOG.w("Contact %s %s already exist", alias, contact_email)
db.session.rollback() Session.rollback()
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email) contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
new_addrs.append(contact.new_addr()) new_addrs.append(contact.new_addr())
@ -501,7 +497,7 @@ def handle_email_sent_to_ourself(alias, mailbox, msg: Message, user):
refused_email = RefusedEmail.create( refused_email = RefusedEmail.create(
path=None, full_report_path=full_report_path, user_id=alias.user_id path=None, full_report_path=full_report_path, user_id=alias.user_id
) )
db.session.commit() Session.commit()
LOG.d("Create refused email %s", refused_email) LOG.d("Create refused email %s", refused_email)
# link available for 6 days as it gets deleted in 7 days # link available for 6 days as it gets deleted in 7 days
refused_email_url = refused_email.get_url(expires_in=518400) refused_email_url = refused_email.get_url(expires_in=518400)
@ -588,7 +584,7 @@ def handle_forward(envelope, msg: Message, rcpt_to: str) -> List[Tuple[bool, str
alias_id=contact.alias_id, alias_id=contact.alias_id,
commit=True, commit=True,
) )
db.session.commit() Session.commit()
# do not return 5** to allow user to receive emails later when alias is enabled # do not return 5** to allow user to receive emails later when alias is enabled
return [(True, status.E200)] return [(True, status.E200)]
@ -695,7 +691,7 @@ def forward_email_to_mailbox(
spam_report, spam_report,
) )
email_log.spam_score = spam_score email_log.spam_score = spam_score
db.session.commit() Session.commit()
if (user.max_spam_score and spam_score > user.max_spam_score) or ( if (user.max_spam_score and spam_score > user.max_spam_score) or (
not user.max_spam_score and spam_score > MAX_SPAM_SCORE not user.max_spam_score and spam_score > MAX_SPAM_SCORE
@ -716,7 +712,7 @@ def forward_email_to_mailbox(
) )
email_log.is_spam = True email_log.is_spam = True
email_log.spam_status = spam_status email_log.spam_status = spam_status
db.session.commit() Session.commit()
handle_spam(contact, alias, msg, user, mailbox, email_log) handle_spam(contact, alias, msg, user, mailbox, email_log)
return False, status.E519 return False, status.E519
@ -846,7 +842,7 @@ def forward_email_to_mailbox(
else: else:
return False, status.E521 return False, status.E521
else: else:
db.session.commit() Session.commit()
return True, status.E200 return True, status.E200
@ -963,7 +959,7 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
email_log.is_spam = True email_log.is_spam = True
email_log.spam_status = spam_status email_log.spam_status = spam_status
db.session.commit() Session.commit()
handle_spam(contact, alias, msg, user, mailbox, email_log, is_reply=True) handle_spam(contact, alias, msg, user, mailbox, email_log, is_reply=True)
return False, status.E506 return False, status.E506
@ -1003,11 +999,11 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
) )
# to not save the email_log # to not save the email_log
EmailLog.delete(email_log.id) EmailLog.delete(email_log.id)
db.session.commit() Session.commit()
# return 421 so the client can retry later # return 421 so the client can retry later
return False, status.E402 return False, status.E402
db.session.commit() Session.commit()
# make the email comes from alias # make the email comes from alias
from_header = alias.email from_header = alias.email
@ -1065,7 +1061,7 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
) )
except Exception: except Exception:
# to not save the email_log # to not save the email_log
db.session.rollback() Session.rollback()
LOG.w("Cannot send email from %s to %s", alias, contact) LOG.w("Cannot send email from %s to %s", alias, contact)
send_email( send_email(
@ -1218,13 +1214,13 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog):
refused_email = RefusedEmail.create( refused_email = RefusedEmail.create(
path=file_path, full_report_path=full_report_path, user_id=user.id path=file_path, full_report_path=full_report_path, user_id=user.id
) )
db.session.flush() Session.flush()
LOG.d("Create refused email %s", refused_email) LOG.d("Create refused email %s", refused_email)
email_log.bounced = True email_log.bounced = True
email_log.refused_email_id = refused_email.id email_log.refused_email_id = refused_email.id
email_log.bounced_mailbox_id = mailbox.id email_log.bounced_mailbox_id = mailbox.id
db.session.commit() Session.commit()
refused_email_url = f"{URL}/dashboard/refused_email?highlight_id={email_log.id}" refused_email_url = f"{URL}/dashboard/refused_email?highlight_id={email_log.id}"
@ -1268,7 +1264,7 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog):
alias, alias,
) )
alias.enabled = False alias.enabled = False
db.session.commit() Session.commit()
send_email_with_rate_control( send_email_with_rate_control(
user, user,
@ -1411,7 +1407,7 @@ def handle_bounce_reply_phase(envelope, msg: Message, email_log: EmailLog):
email_log.bounced_mailbox_id = mailbox.id email_log.bounced_mailbox_id = mailbox.id
db.session.commit() Session.commit()
refused_email_url = f"{URL}/dashboard/refused_email?highlight_id={email_log.id}" refused_email_url = f"{URL}/dashboard/refused_email?highlight_id={email_log.id}"
@ -1469,10 +1465,10 @@ def handle_spam(
refused_email = RefusedEmail.create( refused_email = RefusedEmail.create(
path=file_path, full_report_path=full_report_path, user_id=user.id path=file_path, full_report_path=full_report_path, user_id=user.id
) )
db.session.flush() Session.flush()
email_log.refused_email_id = refused_email.id email_log.refused_email_id = refused_email.id
db.session.commit() Session.commit()
LOG.d("Create spam email %s", refused_email) LOG.d("Create spam email %s", refused_email)
@ -1574,7 +1570,7 @@ def handle_unsubscribe(envelope: Envelope, msg: Message) -> str:
# Sender is owner of this alias # Sender is owner of this alias
alias.enabled = False alias.enabled = False
db.session.commit() Session.commit()
user = alias.user user = alias.user
enable_alias_url = URL + f"/dashboard/?highlight_alias_id={alias.id}" enable_alias_url = URL + f"/dashboard/?highlight_alias_id={alias.id}"
@ -1611,7 +1607,7 @@ def handle_unsubscribe_user(user_id: int, mail_from: str) -> str:
return status.E511 return status.E511
user.notification = False user.notification = False
db.session.commit() Session.commit()
send_email( send_email(
user.email, user.email,
@ -1676,7 +1672,7 @@ def handle_bounce(envelope, email_log: EmailLog, msg: Message) -> str:
alias = contact.alias alias = contact.alias
email_log.auto_replied = True email_log.auto_replied = True
db.session.commit() Session.commit()
# replace the BOUNCE_EMAIL by alias in To field # replace the BOUNCE_EMAIL by alias in To field
add_or_replace_header(msg, "To", alias.email) add_or_replace_header(msg, "To", alias.email)
@ -1871,8 +1867,8 @@ def handle(envelope: Envelope) -> str:
"total number email log on %s, %s is %s, %s", "total number email log on %s, %s is %s, %s",
alias, alias,
alias.user, alias.user,
EmailLog.query.filter(EmailLog.alias_id == alias.id).count(), EmailLog.filter(EmailLog.alias_id == alias.id).count(),
EmailLog.query.filter(EmailLog.user_id == alias.user_id).count(), EmailLog.filter(EmailLog.user_id == alias.user_id).count(),
) )
if should_ignore_bounce(envelope.mail_from): if should_ignore_bounce(envelope.mail_from):

View File

@ -1,15 +1,14 @@
"""Initial loading script"""
from app.config import ALIAS_DOMAINS, PREMIUM_ALIAS_DOMAINS from app.config import ALIAS_DOMAINS, PREMIUM_ALIAS_DOMAINS
from app.models import Mailbox, Contact, SLDomain from app.db import Session
from app.log import LOG from app.log import LOG
from app.extensions import db from app.models import Mailbox, Contact, SLDomain
from app.pgp_utils import load_public_key from app.pgp_utils import load_public_key
from server import create_app from server import create_app
def load_pgp_public_keys(): def load_pgp_public_keys():
"""Load PGP public key to keyring""" """Load PGP public key to keyring"""
for mailbox in Mailbox.query.filter(Mailbox.pgp_public_key.isnot(None)).all(): for mailbox in Mailbox.filter(Mailbox.pgp_public_key.isnot(None)).all():
LOG.d("Load PGP key for mailbox %s", mailbox) LOG.d("Load PGP key for mailbox %s", mailbox)
fingerprint = load_public_key(mailbox.pgp_public_key) fingerprint = load_public_key(mailbox.pgp_public_key)
@ -17,9 +16,9 @@ def load_pgp_public_keys():
if fingerprint != mailbox.pgp_finger_print: if fingerprint != mailbox.pgp_finger_print:
LOG.e("fingerprint %s different for mailbox %s", fingerprint, mailbox) LOG.e("fingerprint %s different for mailbox %s", fingerprint, mailbox)
mailbox.pgp_finger_print = fingerprint mailbox.pgp_finger_print = fingerprint
db.session.commit() Session.commit()
for contact in Contact.query.filter(Contact.pgp_public_key.isnot(None)).all(): for contact in Contact.filter(Contact.pgp_public_key.isnot(None)).all():
LOG.d("Load PGP key for %s", contact) LOG.d("Load PGP key for %s", contact)
fingerprint = load_public_key(contact.pgp_public_key) fingerprint = load_public_key(contact.pgp_public_key)
@ -28,7 +27,7 @@ def load_pgp_public_keys():
LOG.e("fingerprint %s different for contact %s", fingerprint, contact) LOG.e("fingerprint %s different for contact %s", fingerprint, contact)
contact.pgp_finger_print = fingerprint contact.pgp_finger_print = fingerprint
db.session.commit() Session.commit()
LOG.d("Finish load_pgp_public_keys") LOG.d("Finish load_pgp_public_keys")
@ -48,7 +47,7 @@ def add_sl_domains():
LOG.i("Add %s to SL domain", premium_domain) LOG.i("Add %s to SL domain", premium_domain)
SLDomain.create(domain=premium_domain, premium_only=True) SLDomain.create(domain=premium_domain, premium_only=True)
db.session.commit() Session.commit()
if __name__ == "__main__": if __name__ == "__main__":

View File

@ -13,11 +13,11 @@ from app.config import (
JOB_BATCH_IMPORT, JOB_BATCH_IMPORT,
JOB_DELETE_ACCOUNT, JOB_DELETE_ACCOUNT,
) )
from app.db import Session
from app.email_utils import ( from app.email_utils import (
send_email, send_email,
render, render,
) )
from app.extensions import db
from app.import_utils import handle_batch_import from app.import_utils import handle_batch_import
from app.log import LOG from app.log import LOG
from app.models import User, Job, BatchImport from app.models import User, Job, BatchImport
@ -32,10 +32,7 @@ def new_app():
@app.teardown_appcontext @app.teardown_appcontext
def shutdown_session(response_or_exc): def shutdown_session(response_or_exc):
# same as shutdown_session() in flask-sqlalchemy but this is not enough # same as shutdown_session() in flask-sqlalchemy but this is not enough
db.session.remove() Session.remove()
# dispose the engine too
db.engine.dispose()
return app return app
@ -109,14 +106,14 @@ if __name__ == "__main__":
app = new_app() app = new_app()
with app.app_context(): with app.app_context():
for job in Job.query.filter( for job in Job.filter(
Job.taken.is_(False), Job.run_at > min_dt, Job.run_at <= max_dt Job.taken.is_(False), Job.run_at > min_dt, Job.run_at <= max_dt
).all(): ).all():
LOG.d("Take job %s", job) LOG.d("Take job %s", job)
# mark the job as taken, whether it will be executed successfully or not # mark the job as taken, whether it will be executed successfully or not
job.taken = True job.taken = True
db.session.commit() Session.commit()
if job.name == JOB_ONBOARDING_1: if job.name == JOB_ONBOARDING_1:
user_id = job.payload.get("user_id") user_id = job.payload.get("user_id")
@ -161,7 +158,7 @@ if __name__ == "__main__":
user_email = user.email user_email = user.email
LOG.w("Delete user %s", user) LOG.w("Delete user %s", user)
User.delete(user.id) User.delete(user.id)
db.session.commit() Session.commit()
send_email( send_email(
user_email, user_email,

View File

@ -2,7 +2,7 @@ import os
from time import sleep from time import sleep
from app.config import HOST from app.config import HOST
from app.extensions import db from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import Monitoring from app.models import Monitoring
from server import create_app from server import create_app
@ -32,7 +32,7 @@ def get_stats():
active_queue=active_queue, active_queue=active_queue,
deferred_queue=deferred_queue, deferred_queue=deferred_queue,
) )
db.session.commit() Session.commit()
global _nb_failed global _nb_failed
# alert when too many emails in incoming + active queue # alert when too many emails in incoming + active queue

View File

@ -71,10 +71,11 @@ from app.config import (
ROOT_DIR, ROOT_DIR,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.developer.base import developer_bp from app.developer.base import developer_bp
from app.discover.base import discover_bp from app.discover.base import discover_bp
from app.email_utils import send_email, render from app.email_utils import send_email, render
from app.extensions import db, login_manager, migrate, limiter from app.extensions import login_manager, migrate, limiter
from app.jose_utils import get_jwk_key from app.jose_utils import get_jwk_key
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
@ -129,8 +130,6 @@ def create_light_app() -> Flask:
app.config["SQLALCHEMY_DATABASE_URI"] = DB_URI app.config["SQLALCHEMY_DATABASE_URI"] = DB_URI
app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False app.config["SQLALCHEMY_TRACK_MODIFICATIONS"] = False
db.init_app(app)
return app return app
@ -200,6 +199,10 @@ def create_app() -> Flask:
session.permanent = True session.permanent = True
app.permanent_session_lifetime = timedelta(days=7) app.permanent_session_lifetime = timedelta(days=7)
@app.teardown_appcontext
def cleanup(resp_or_exc):
Session.remove()
return app return app
@ -219,7 +222,7 @@ def fake_data():
fido_uuid=None, fido_uuid=None,
) )
user.trial_end = None user.trial_end = None
db.session.commit() Session.commit()
# add a profile picture # add a profile picture
file_path = "profile_pic.svg" file_path = "profile_pic.svg"
@ -230,11 +233,11 @@ def fake_data():
) )
file = File.create(user_id=user.id, path=file_path, commit=True) file = File.create(user_id=user.id, path=file_path, commit=True)
user.profile_picture_id = file.id user.profile_picture_id = file.id
db.session.commit() Session.commit()
# create a bounced email # create a bounced email
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
bounce_email_file_path = "bounce.eml" bounce_email_file_path = "bounce.eml"
s3.upload_email_from_bytesio( s3.upload_email_from_bytesio(
@ -298,7 +301,7 @@ def fake_data():
pgp_public_key=pgp_public_key, pgp_public_key=pgp_public_key,
) )
m1.pgp_finger_print = load_public_key(pgp_public_key) m1.pgp_finger_print = load_public_key(pgp_public_key)
db.session.commit() Session.commit()
# example@example.com is in a LOT of data breaches # example@example.com is in a LOT of data breaches
Alias.create(email="example@example.com", user_id=user.id, mailbox_id=m1.id) Alias.create(email="example@example.com", user_id=user.id, mailbox_id=m1.id)
@ -314,14 +317,14 @@ def fake_data():
user_id=user.id, user_id=user.id,
mailbox_id=user.default_mailbox_id, mailbox_id=user.default_mailbox_id,
) )
db.session.commit() Session.commit()
if i % 5 == 0: if i % 5 == 0:
if i % 2 == 0: if i % 2 == 0:
AliasMailbox.create(alias_id=a.id, mailbox_id=user.default_mailbox_id) AliasMailbox.create(alias_id=a.id, mailbox_id=user.default_mailbox_id)
else: else:
AliasMailbox.create(alias_id=a.id, mailbox_id=m1.id) AliasMailbox.create(alias_id=a.id, mailbox_id=m1.id)
db.session.commit() Session.commit()
# some aliases don't have any activity # some aliases don't have any activity
# if i % 3 != 0: # if i % 3 != 0:
@ -331,18 +334,18 @@ def fake_data():
# website_email=f"contact{i}@example.com", # website_email=f"contact{i}@example.com",
# reply_email=f"rep{i}@sl.local", # reply_email=f"rep{i}@sl.local",
# ) # )
# db.session.commit() # Session.commit()
# for _ in range(3): # for _ in range(3):
# EmailLog.create(user_id=user.id, contact_id=contact.id, alias_id=contact.alias_id) # EmailLog.create(user_id=user.id, contact_id=contact.id, alias_id=contact.alias_id)
# db.session.commit() # Session.commit()
# have some disabled alias # have some disabled alias
if i % 5 == 0: if i % 5 == 0:
a.enabled = False a.enabled = False
db.session.commit() Session.commit()
custom_domain1 = CustomDomain.create(user_id=user.id, domain="ab.cd", verified=True) custom_domain1 = CustomDomain.create(user_id=user.id, domain="ab.cd", verified=True)
db.session.commit() Session.commit()
Alias.create( Alias.create(
user_id=user.id, user_id=user.id,
@ -362,13 +365,13 @@ def fake_data():
Directory.create(user_id=user.id, name="abcd") Directory.create(user_id=user.id, name="abcd")
Directory.create(user_id=user.id, name="xyzt") Directory.create(user_id=user.id, name="xyzt")
db.session.commit() Session.commit()
# Create a client # Create a client
client1 = Client.create_new(name="Demo", user_id=user.id) client1 = Client.create_new(name="Demo", user_id=user.id)
client1.oauth_client_id = "client-id" client1.oauth_client_id = "client-id"
client1.oauth_client_secret = "client-secret" client1.oauth_client_secret = "client-secret"
db.session.commit() Session.commit()
RedirectUri.create( RedirectUri.create(
client_id=client1.id, uri="https://your-website.com/oauth-callback" client_id=client1.id, uri="https://your-website.com/oauth-callback"
@ -377,7 +380,7 @@ def fake_data():
client2 = Client.create_new(name="Demo 2", user_id=user.id) client2 = Client.create_new(name="Demo 2", user_id=user.id)
client2.oauth_client_id = "client-id2" client2.oauth_client_id = "client-id2"
client2.oauth_client_secret = "client-secret2" client2.oauth_client_secret = "client-secret2"
db.session.commit() Session.commit()
ClientUser.create(user_id=user.id, client_id=client1.id, name="Fake Name") ClientUser.create(user_id=user.id, client_id=client1.id, name="Fake Name")
@ -392,11 +395,11 @@ def fake_data():
number_upgraded_account=200, number_upgraded_account=200,
payment_method="PayPal", payment_method="PayPal",
) )
db.session.commit() Session.commit()
for i in range(6): for i in range(6):
Notification.create(user_id=user.id, message=f"""Hey hey <b>{i}</b> """ * 10) Notification.create(user_id=user.id, message=f"""Hey hey <b>{i}</b> """ * 10)
db.session.commit() Session.commit()
user2 = User.create( user2 = User.create(
email="winston@continental.com", email="winston@continental.com",
@ -405,7 +408,7 @@ def fake_data():
referral_id=referral.id, referral_id=referral.id,
) )
Mailbox.create(user_id=user2.id, email="winston2@high.table", verified=True) Mailbox.create(user_id=user2.id, email="winston2@high.table", verified=True)
db.session.commit() Session.commit()
ManualSubscription.create( ManualSubscription.create(
user_id=user2.id, user_id=user2.id,
@ -695,7 +698,7 @@ def setup_paddle_callback(app: Flask):
LOG.d("User %s upgrades!", user) LOG.d("User %s upgrades!", user)
db.session.commit() Session.commit()
elif request.form.get("alert_name") == "subscription_payment_succeeded": elif request.form.get("alert_name") == "subscription_payment_succeeded":
subscription_id = request.form.get("subscription_id") subscription_id = request.form.get("subscription_id")
@ -710,7 +713,7 @@ def setup_paddle_callback(app: Flask):
request.form.get("next_bill_date"), "YYYY-MM-DD" request.form.get("next_bill_date"), "YYYY-MM-DD"
).date() ).date()
db.session.commit() Session.commit()
elif request.form.get("alert_name") == "subscription_cancelled": elif request.form.get("alert_name") == "subscription_cancelled":
subscription_id = request.form.get("subscription_id") subscription_id = request.form.get("subscription_id")
@ -728,7 +731,7 @@ def setup_paddle_callback(app: Flask):
sub.event_time = arrow.now() sub.event_time = arrow.now()
sub.cancelled = True sub.cancelled = True
db.session.commit() Session.commit()
user = sub.user user = sub.user
@ -774,7 +777,7 @@ def setup_paddle_callback(app: Flask):
# make sure to set the new plan as not-cancelled # make sure to set the new plan as not-cancelled
sub.cancelled = False sub.cancelled = False
db.session.commit() Session.commit()
else: else:
return "No such subscription", 400 return "No such subscription", 400
return "OK" return "OK"
@ -847,7 +850,7 @@ def handle_coinbase_event(event) -> bool:
else: # already expired subscription else: # already expired subscription
coinbase_subscription.end_at = arrow.now().shift(years=1) coinbase_subscription.end_at = arrow.now().shift(years=1)
db.session.commit() Session.commit()
send_email( send_email(
user.email, user.email,
@ -867,7 +870,6 @@ def handle_coinbase_event(event) -> bool:
def init_extensions(app: Flask): def init_extensions(app: Flask):
login_manager.init_app(app) login_manager.init_app(app)
db.init_app(app)
migrate.init_app(app) migrate.init_app(app)
@ -875,17 +877,17 @@ def init_admin(app):
admin = Admin(name="SimpleLogin", template_mode="bootstrap4") admin = Admin(name="SimpleLogin", template_mode="bootstrap4")
admin.init_app(app, index_view=SLAdminIndexView()) admin.init_app(app, index_view=SLAdminIndexView())
admin.add_view(UserAdmin(User, db.session)) admin.add_view(UserAdmin(User, Session))
admin.add_view(AliasAdmin(Alias, db.session)) admin.add_view(AliasAdmin(Alias, Session))
admin.add_view(MailboxAdmin(Mailbox, db.session)) admin.add_view(MailboxAdmin(Mailbox, Session))
admin.add_view(EmailLogAdmin(EmailLog, db.session)) admin.add_view(EmailLogAdmin(EmailLog, Session))
admin.add_view(LifetimeCouponAdmin(LifetimeCoupon, db.session)) admin.add_view(LifetimeCouponAdmin(LifetimeCoupon, Session))
admin.add_view(CouponAdmin(Coupon, db.session)) admin.add_view(CouponAdmin(Coupon, Session))
admin.add_view(ManualSubscriptionAdmin(ManualSubscription, db.session)) admin.add_view(ManualSubscriptionAdmin(ManualSubscription, Session))
admin.add_view(ClientAdmin(Client, db.session)) admin.add_view(ClientAdmin(Client, Session))
admin.add_view(CustomDomainAdmin(CustomDomain, db.session)) admin.add_view(CustomDomainAdmin(CustomDomain, Session))
admin.add_view(ReferralAdmin(Referral, db.session)) admin.add_view(ReferralAdmin(Referral, Session))
admin.add_view(PayoutAdmin(Payout, db.session)) admin.add_view(PayoutAdmin(Payout, Session))
def register_custom_commands(app): def register_custom_commands(app):
@ -900,12 +902,12 @@ def register_custom_commands(app):
def fill_up_email_log_alias(): def fill_up_email_log_alias():
"""Fill up email_log.alias_id column""" """Fill up email_log.alias_id column"""
# split all emails logs into 1000-size trunks # split all emails logs into 1000-size trunks
nb_email_log = EmailLog.query.count() nb_email_log = EmailLog.count()
LOG.d("total trunks %s", nb_email_log // 1000 + 2) LOG.d("total trunks %s", nb_email_log // 1000 + 2)
for trunk in reversed(range(1, nb_email_log // 1000 + 2)): for trunk in reversed(range(1, nb_email_log // 1000 + 2)):
nb_update = 0 nb_update = 0
for email_log, contact in ( for email_log, contact in (
db.session.query(EmailLog, Contact) Session.query(EmailLog, Contact)
.filter(EmailLog.contact_id == Contact.id) .filter(EmailLog.contact_id == Contact.id)
.filter(EmailLog.id <= trunk * 1000) .filter(EmailLog.id <= trunk * 1000)
.filter(EmailLog.id > (trunk - 1) * 1000) .filter(EmailLog.id > (trunk - 1) * 1000)
@ -915,7 +917,7 @@ def register_custom_commands(app):
nb_update += 1 nb_update += 1
LOG.d("finish trunk %s, update %s email logs", trunk, nb_update) LOG.d("finish trunk %s, update %s email logs", trunk, nb_update)
db.session.commit() Session.commit()
@app.cli.command("dummy-data") @app.cli.command("dummy-data")
def dummy_data(): def dummy_data():

View File

@ -1,3 +1,5 @@
from app.db import Session
from app.db import Session
from time import sleep from time import sleep
import flask_migrate import flask_migrate
@ -34,7 +36,7 @@ def create_db():
def change_password(user_id, new_password): def change_password(user_id, new_password):
user = User.get(user_id) user = User.get(user_id)
user.set_password(new_password) user.set_password(new_password)
db.session.commit() Session.commit()
def reset_db(): def reset_db():
@ -44,7 +46,7 @@ def reset_db():
def send_mailbox_newsletter(): def send_mailbox_newsletter():
for user in User.query.order_by(User.id).all(): for user in User.order_by(User.id).all():
if user.notification and user.activated: if user.notification and user.activated:
try: try:
LOG.d("Send newsletter to %s", user) LOG.d("Send newsletter to %s", user)
@ -60,7 +62,7 @@ def send_mailbox_newsletter():
def send_pgp_newsletter(): def send_pgp_newsletter():
for user in User.query.order_by(User.id).all(): for user in User.order_by(User.id).all():
if user.notification and user.activated: if user.notification and user.activated:
try: try:
LOG.d("Send PGP newsletter to %s", user) LOG.d("Send PGP newsletter to %s", user)
@ -77,7 +79,7 @@ def send_pgp_newsletter():
def send_mobile_newsletter(): def send_mobile_newsletter():
count = 0 count = 0
for user in User.query.order_by(User.id).all(): for user in User.order_by(User.id).all():
if user.notification and user.activated: if user.notification and user.activated:
count += 1 count += 1
try: try:
@ -104,7 +106,7 @@ def disable_mailbox(mailbox_id):
for alias in mailbox.aliases: for alias in mailbox.aliases:
alias.enabled = False alias.enabled = False
db.session.commit() Session.commit()
email_msg = f"""Hi, email_msg = f"""Hi,

View File

@ -1,8 +1,8 @@
from flask import url_for from flask import url_for
from app.config import PAGE_LIMIT from app.config import PAGE_LIMIT
from app.db import Session
from app.email_utils import is_reply_email from app.email_utils import is_reply_email
from app.extensions import db
from app.models import User, ApiKey, Alias, Contact, EmailLog, Mailbox from app.models import User, ApiKey, Alias, Contact, EmailLog, Mailbox
from tests.utils import login from tests.utils import login
@ -18,7 +18,7 @@ def test_get_aliases_error_without_pagination(flask_client):
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for("api.get_aliases"), headers={"Authentication": api_key.code} url_for("api.get_aliases"), headers={"Authentication": api_key.code}
@ -39,12 +39,12 @@ def test_get_aliases_with_pagination(flask_client):
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
# create more aliases than PAGE_LIMIT # create more aliases than PAGE_LIMIT
for _ in range(PAGE_LIMIT + 1): for _ in range(PAGE_LIMIT + 1):
Alias.create_new_random(user) Alias.create_new_random(user)
db.session.commit() Session.commit()
# get aliases on the 1st page, should return PAGE_LIMIT aliases # get aliases on the 1st page, should return PAGE_LIMIT aliases
r = flask_client.get( r = flask_client.get(
@ -79,16 +79,16 @@ def test_get_aliases_query(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
# create more aliases than PAGE_LIMIT # create more aliases than PAGE_LIMIT
Alias.create_new(user, "prefix1") Alias.create_new(user, "prefix1")
Alias.create_new(user, "prefix2") Alias.create_new(user, "prefix2")
db.session.commit() Session.commit()
# get aliases without query, should return 3 aliases as one alias is created when user is created # get aliases without query, should return 3 aliases as one alias is created when user is created
r = flask_client.get( r = flask_client.get(
@ -111,15 +111,15 @@ def test_get_aliases_v2(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
a0 = Alias.create_new(user, "prefix0") a0 = Alias.create_new(user, "prefix0")
a1 = Alias.create_new(user, "prefix1") a1 = Alias.create_new(user, "prefix1")
db.session.commit() Session.commit()
# << Aliases have no activity >> # << Aliases have no activity >>
r = flask_client.get( r = flask_client.get(
@ -154,13 +154,13 @@ def test_get_aliases_v2(flask_client):
website_email="c0@example.com", website_email="c0@example.com",
reply_email="re0@SL", reply_email="re0@SL",
) )
db.session.commit() Session.commit()
EmailLog.create( EmailLog.create(
contact_id=c0.id, contact_id=c0.id,
user_id=user.id, user_id=user.id,
alias_id=c0.alias_id, alias_id=c0.alias_id,
) )
db.session.commit() Session.commit()
# a1 has more recent activity # a1 has more recent activity
c1 = Contact.create( c1 = Contact.create(
@ -169,13 +169,13 @@ def test_get_aliases_v2(flask_client):
website_email="c1@example.com", website_email="c1@example.com",
reply_email="re1@SL", reply_email="re1@SL",
) )
db.session.commit() Session.commit()
EmailLog.create( EmailLog.create(
contact_id=c1.id, contact_id=c1.id,
user_id=user.id, user_id=user.id,
alias_id=c1.alias_id, alias_id=c1.alias_id,
) )
db.session.commit() Session.commit()
# get aliases v2 # get aliases v2
r = flask_client.get( r = flask_client.get(
@ -199,14 +199,14 @@ def test_delete_alias(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
r = flask_client.delete( r = flask_client.delete(
url_for("api.delete_alias", alias_id=alias.id), url_for("api.delete_alias", alias_id=alias.id),
@ -221,14 +221,14 @@ def test_toggle_alias(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.toggle_alias", alias_id=alias.id), url_for("api.toggle_alias", alias_id=alias.id),
@ -243,14 +243,14 @@ def test_alias_activities(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
# create some alias log # create some alias log
contact = Contact.create( contact = Contact.create(
@ -259,7 +259,7 @@ def test_alias_activities(flask_client):
alias_id=alias.id, alias_id=alias.id,
user_id=alias.user_id, user_id=alias.user_id,
) )
db.session.commit() Session.commit()
for _ in range(int(PAGE_LIMIT / 2)): for _ in range(int(PAGE_LIMIT / 2)):
EmailLog.create( EmailLog.create(
@ -304,14 +304,14 @@ def test_update_alias(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
r = flask_client.put( r = flask_client.put(
url_for("api.update_alias", alias_id=alias.id), url_for("api.update_alias", alias_id=alias.id),
@ -326,16 +326,16 @@ def test_update_alias_mailbox(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
mb = Mailbox.create(user_id=user.id, email="ab@cd.com", verified=True) mb = Mailbox.create(user_id=user.id, email="ab@cd.com", verified=True)
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
r = flask_client.put( r = flask_client.put(
url_for("api.update_alias", alias_id=alias.id), url_for("api.update_alias", alias_id=alias.id),
@ -358,14 +358,14 @@ def test_update_alias_name(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
r = flask_client.put( r = flask_client.put(
url_for("api.update_alias", alias_id=alias.id), url_for("api.update_alias", alias_id=alias.id),
@ -391,17 +391,17 @@ def test_update_alias_mailboxes(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
mb1 = Mailbox.create(user_id=user.id, email="ab1@cd.com", verified=True) mb1 = Mailbox.create(user_id=user.id, email="ab1@cd.com", verified=True)
mb2 = Mailbox.create(user_id=user.id, email="ab2@cd.com", verified=True) mb2 = Mailbox.create(user_id=user.id, email="ab2@cd.com", verified=True)
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
r = flask_client.put( r = flask_client.put(
url_for("api.update_alias", alias_id=alias.id), url_for("api.update_alias", alias_id=alias.id),
@ -428,14 +428,14 @@ def test_update_disable_pgp(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
assert not alias.disable_pgp assert not alias.disable_pgp
r = flask_client.put( r = flask_client.put(
@ -468,14 +468,14 @@ def test_alias_contacts(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
# create some alias log # create some alias log
for i in range(PAGE_LIMIT + 1): for i in range(PAGE_LIMIT + 1):
@ -485,7 +485,7 @@ def test_alias_contacts(flask_client):
alias_id=alias.id, alias_id=alias.id,
user_id=alias.user_id, user_id=alias.user_id,
) )
db.session.commit() Session.commit()
EmailLog.create( EmailLog.create(
contact_id=contact.id, contact_id=contact.id,
@ -493,7 +493,7 @@ def test_alias_contacts(flask_client):
user_id=contact.user_id, user_id=contact.user_id,
alias_id=contact.alias_id, alias_id=contact.alias_id,
) )
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for("api.get_alias_contacts_route", alias_id=alias.id, page_id=0), url_for("api.get_alias_contacts_route", alias_id=alias.id, page_id=0),
@ -523,14 +523,14 @@ def test_create_contact_route(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.create_contact_route", alias_id=alias.id), url_for("api.create_contact_route", alias_id=alias.id),
@ -560,7 +560,7 @@ def test_create_contact_route(flask_client):
def test_create_contact_route_empty_contact_address(flask_client): def test_create_contact_route_empty_contact_address(flask_client):
login(flask_client) login(flask_client)
alias = Alias.query.first() alias = Alias.first()
r = flask_client.post( r = flask_client.post(
url_for("api.create_contact_route", alias_id=alias.id), url_for("api.create_contact_route", alias_id=alias.id),
@ -573,7 +573,7 @@ def test_create_contact_route_empty_contact_address(flask_client):
def test_create_contact_route_invalid_contact_email(flask_client): def test_create_contact_route_invalid_contact_email(flask_client):
login(flask_client) login(flask_client)
alias = Alias.query.first() alias = Alias.first()
r = flask_client.post( r = flask_client.post(
url_for("api.create_contact_route", alias_id=alias.id), url_for("api.create_contact_route", alias_id=alias.id),
@ -588,14 +588,14 @@ def test_delete_contact(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
contact = Contact.create( contact = Contact.create(
alias_id=alias.id, alias_id=alias.id,
@ -603,7 +603,7 @@ def test_delete_contact(flask_client):
reply_email="reply+random@sl.io", reply_email="reply+random@sl.io",
user_id=alias.user_id, user_id=alias.user_id,
) )
db.session.commit() Session.commit()
r = flask_client.delete( r = flask_client.delete(
url_for("api.delete_contact", contact_id=contact.id), url_for("api.delete_contact", contact_id=contact.id),
@ -618,15 +618,15 @@ def test_get_alias(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
# create more aliases than PAGE_LIMIT # create more aliases than PAGE_LIMIT
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
# get aliases on the 1st page, should return PAGE_LIMIT aliases # get aliases on the 1st page, should return PAGE_LIMIT aliases
r = flask_client.get( r = flask_client.get(

View File

@ -2,7 +2,7 @@ import json
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from app.models import User, ApiKey, AliasUsedOn, Alias from app.models import User, ApiKey, AliasUsedOn, Alias
@ -10,11 +10,11 @@ def test_different_scenarios_v4(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
# <<< without hostname >>> # <<< without hostname >>>
r = flask_client.get( r = flask_client.get(
@ -37,11 +37,11 @@ def test_different_scenarios_v4(flask_client):
# <<< with recommendation >>> # <<< with recommendation >>>
alias = Alias.create_new(user, prefix="test") alias = Alias.create_new(user, prefix="test")
db.session.commit() Session.commit()
AliasUsedOn.create( AliasUsedOn.create(
alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id
) )
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for("api.options_v4", hostname="www.test.com"), url_for("api.options_v4", hostname="www.test.com"),
@ -55,11 +55,11 @@ def test_different_scenarios_v4_2(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
# <<< without hostname >>> # <<< without hostname >>>
r = flask_client.get( r = flask_client.get(
@ -85,11 +85,11 @@ def test_different_scenarios_v4_2(flask_client):
# <<< with recommendation >>> # <<< with recommendation >>>
alias = Alias.create_new(user, prefix="test") alias = Alias.create_new(user, prefix="test")
db.session.commit() Session.commit()
AliasUsedOn.create( AliasUsedOn.create(
alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id
) )
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for("api.options_v4", hostname="www.test.com"), url_for("api.options_v4", hostname="www.test.com"),
@ -103,11 +103,11 @@ def test_different_scenarios_v5(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
# <<< without hostname >>> # <<< without hostname >>>
r = flask_client.get( r = flask_client.get(
@ -138,11 +138,11 @@ def test_different_scenarios_v5(flask_client):
# <<< with recommendation >>> # <<< with recommendation >>>
alias = Alias.create_new(user, prefix="test") alias = Alias.create_new(user, prefix="test")
db.session.commit() Session.commit()
AliasUsedOn.create( AliasUsedOn.create(
alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id alias_id=alias.id, hostname="www.test.com", user_id=alias.user_id
) )
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for("api.options_v4", hostname="www.test.com"), url_for("api.options_v4", hostname="www.test.com"),

File diff suppressed because one or more lines are too long

View File

@ -1,9 +1,8 @@
import unicodedata
import pytest import pytest
import unicodedata
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from app.models import User, AccountActivation from app.models import User, AccountActivation
PASSWORD_1 = "Aurélie" PASSWORD_1 = "Aurélie"
@ -20,7 +19,7 @@ def test_auth_login_success(flask_client, mfa: bool):
activated=True, activated=True,
enable_otp=mfa, enable_otp=mfa,
) )
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.auth_login"), url_for("api.auth_login"),
@ -49,7 +48,7 @@ def test_auth_login_device_exist(flask_client):
User.create( User.create(
email="abcd@gmail.com", password="password", name="Test User", activated=True email="abcd@gmail.com", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.auth_login"), url_for("api.auth_login"),
@ -138,7 +137,7 @@ def test_auth_activate_user_already_activated(flask_client):
User.create( User.create(
email="abcd@gmail.com", password="password", name="Test User", activated=True email="abcd@gmail.com", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.auth_activate"), json={"email": "abcd@gmail.com", "code": "123456"} url_for("api.auth_activate"), json={"email": "abcd@gmail.com", "code": "123456"}
@ -214,7 +213,7 @@ def test_auth_activate_too_many_wrong_code(flask_client):
def test_auth_reactivate_success(flask_client): def test_auth_reactivate_success(flask_client):
User.create(email="abcd@gmail.com", password="password", name="Test User") User.create(email="abcd@gmail.com", password="password", name="Test User")
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.auth_reactivate"), json={"email": "abcd@gmail.com"} url_for("api.auth_reactivate"), json={"email": "abcd@gmail.com"}
@ -232,7 +231,7 @@ def test_auth_login_forgot_password(flask_client):
User.create( User.create(
email="abcd@gmail.com", password="password", name="Test User", activated=True email="abcd@gmail.com", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.forgot_password"), url_for("api.forgot_password"),

View File

@ -3,7 +3,7 @@ from flask import url_for
from itsdangerous import Signer from itsdangerous import Signer
from app.config import FLASK_SECRET from app.config import FLASK_SECRET
from app.extensions import db from app.db import Session
from app.models import User from app.models import User
@ -16,7 +16,7 @@ def test_auth_mfa_success(flask_client):
enable_otp=True, enable_otp=True,
otp_secret="base32secret3232", otp_secret="base32secret3232",
) )
db.session.commit() Session.commit()
totp = pyotp.TOTP(user.otp_secret) totp = pyotp.TOTP(user.otp_secret)
s = Signer(FLASK_SECRET) s = Signer(FLASK_SECRET)
@ -42,7 +42,7 @@ def test_auth_wrong_mfa_key(flask_client):
enable_otp=True, enable_otp=True,
otp_secret="base32secret3232", otp_secret="base32secret3232",
) )
db.session.commit() Session.commit()
totp = pyotp.TOTP(user.otp_secret) totp = pyotp.TOTP(user.otp_secret)

View File

@ -1,7 +1,8 @@
from flask import url_for from flask import url_for
from app import alias_utils from app import alias_utils
from app.extensions import db from app.db import Session
from app.import_utils import import_from_csv
from app.models import ( from app.models import (
User, User,
CustomDomain, CustomDomain,
@ -11,7 +12,6 @@ from app.models import (
BatchImport, BatchImport,
File, File,
) )
from app.import_utils import import_from_csv
from tests.utils import login from tests.utils import login
@ -21,14 +21,14 @@ def test_export(flask_client):
user2 = User.create( user2 = User.create(
email="x@y.z", password="password", name="Wrong user", activated=True email="x@y.z", password="password", name="Wrong user", activated=True
) )
db.session.commit() Session.commit()
# Remove onboarding aliases # Remove onboarding aliases
for alias in Alias.filter_by(user_id=user1.id).all(): for alias in Alias.filter_by(user_id=user1.id).all():
alias_utils.delete_alias(alias, user1) alias_utils.delete_alias(alias, user1)
for alias in Alias.filter_by(user_id=user2.id).all(): for alias in Alias.filter_by(user_id=user2.id).all():
alias_utils.delete_alias(alias, user2) alias_utils.delete_alias(alias, user2)
db.session.commit() Session.commit()
# Create domains # Create domains
CustomDomain.create( CustomDomain.create(
@ -37,7 +37,7 @@ def test_export(flask_client):
CustomDomain.create( CustomDomain.create(
user_id=user2.id, domain="bad-destionation-domain.com", verified=True user_id=user2.id, domain="bad-destionation-domain.com", verified=True
) )
db.session.commit() Session.commit()
# Create mailboxes # Create mailboxes
mailbox1 = Mailbox.create( mailbox1 = Mailbox.create(
@ -51,7 +51,7 @@ def test_export(flask_client):
email="baddestination@bad-destination-domain.com", email="baddestination@bad-destination-domain.com",
verified=True, verified=True,
) )
db.session.commit() Session.commit()
# Create aliases # Create aliases
Alias.create( Alias.create(
@ -72,14 +72,14 @@ def test_export(flask_client):
note="Should not appear", note="Should not appear",
mailbox_id=badmailbox1.id, mailbox_id=badmailbox1.id,
) )
db.session.commit() Session.commit()
# Add second mailbox to an alias # Add second mailbox to an alias
AliasMailbox.create( AliasMailbox.create(
alias_id=alias2.id, alias_id=alias2.id,
mailbox_id=mailbox2.id, mailbox_id=mailbox2.id,
) )
db.session.commit() Session.commit()
# Export # Export
r = flask_client.get(url_for("api.export_aliases")) r = flask_client.get(url_for("api.export_aliases"))
@ -128,7 +128,7 @@ def test_import_no_mailboxes(flask_client):
CustomDomain.create( CustomDomain.create(
user_id=user.id, domain="my-domain.com", ownership_verified=True user_id=user.id, domain="my-domain.com", ownership_verified=True
) )
db.session.commit() Session.commit()
alias_data = [ alias_data = [
"alias,note", "alias,note",
@ -180,7 +180,7 @@ def test_import(flask_client):
CustomDomain.create( CustomDomain.create(
user_id=user.id, domain="my-destination-domain.com", ownership_verified=True user_id=user.id, domain="my-destination-domain.com", ownership_verified=True
) )
db.session.commit() Session.commit()
# Create mailboxes # Create mailboxes
mailbox1 = Mailbox.create( mailbox1 = Mailbox.create(
@ -189,7 +189,7 @@ def test_import(flask_client):
mailbox2 = Mailbox.create( mailbox2 = Mailbox.create(
user_id=user.id, email="destination2@my-destination-domain.com", verified=True user_id=user.id, email="destination2@my-destination-domain.com", verified=True
) )
db.session.commit() Session.commit()
alias_data = [ alias_data = [
"alias,note,mailboxes", "alias,note,mailboxes",

View File

@ -1,6 +1,6 @@
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from app.models import Mailbox from app.models import Mailbox
from tests.utils import login from tests.utils import login
@ -34,7 +34,7 @@ def test_create_mailbox(flask_client):
def test_create_mailbox_fail_for_free_user(flask_client): def test_create_mailbox_fail_for_free_user(flask_client):
user = login(flask_client) user = login(flask_client)
user.trial_end = None user.trial_end = None
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
"/api/mailboxes", "/api/mailboxes",
@ -50,7 +50,7 @@ def test_delete_mailbox(flask_client):
# create a mailbox # create a mailbox
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") mb = Mailbox.create(user_id=user.id, email="mb@gmail.com")
db.session.commit() Session.commit()
r = flask_client.delete( r = flask_client.delete(
f"/api/mailboxes/{mb.id}", f"/api/mailboxes/{mb.id}",
@ -88,7 +88,7 @@ def test_set_mailbox_as_default(flask_client):
# <<< Cannot set an unverified mailbox as default >>> # <<< Cannot set an unverified mailbox as default >>>
mb.verified = False mb.verified = False
db.session.commit() Session.commit()
r = flask_client.put( r = flask_client.put(
f"/api/mailboxes/{mb.id}", f"/api/mailboxes/{mb.id}",
@ -104,7 +104,7 @@ def test_update_mailbox_email(flask_client):
# create a mailbox # create a mailbox
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") mb = Mailbox.create(user_id=user.id, email="mb@gmail.com")
db.session.commit() Session.commit()
r = flask_client.put( r = flask_client.put(
f"/api/mailboxes/{mb.id}", f"/api/mailboxes/{mb.id}",
@ -122,7 +122,7 @@ def test_cancel_mailbox_email_change(flask_client):
# create a mailbox # create a mailbox
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") mb = Mailbox.create(user_id=user.id, email="mb@gmail.com")
db.session.commit() Session.commit()
# update mailbox email # update mailbox email
r = flask_client.put( r = flask_client.put(
@ -150,7 +150,7 @@ def test_get_mailboxes(flask_client):
Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) Mailbox.create(user_id=user.id, email="m1@example.com", verified=True)
Mailbox.create(user_id=user.id, email="m2@example.com", verified=False) Mailbox.create(user_id=user.id, email="m2@example.com", verified=False)
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
"/api/mailboxes", "/api/mailboxes",
@ -173,7 +173,7 @@ def test_get_mailboxes_v2(flask_client):
Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) Mailbox.create(user_id=user.id, email="m1@example.com", verified=True)
Mailbox.create(user_id=user.id, email="m2@example.com", verified=False) Mailbox.create(user_id=user.id, email="m2@example.com", verified=False)
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
"/api/v2/mailboxes", "/api/v2/mailboxes",

View File

@ -3,7 +3,7 @@ from flask import g
from app.alias_utils import delete_alias from app.alias_utils import delete_alias
from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN
from app.dashboard.views.custom_alias import signer from app.dashboard.views.custom_alias import signer
from app.extensions import db from app.db import Session
from app.models import Alias, CustomDomain, Mailbox, AliasUsedOn from app.models import Alias, CustomDomain, Mailbox, AliasUsedOn
from app.utils import random_word from app.utils import random_word
from tests.utils import login from tests.utils import login
@ -86,13 +86,13 @@ def test_full_payload(flask_client):
# create another mailbox # create another mailbox
mb = Mailbox.create(user_id=user.id, email="abcd@gmail.com", verified=True) mb = Mailbox.create(user_id=user.id, email="abcd@gmail.com", verified=True)
db.session.commit() Session.commit()
word = random_word() word = random_word()
suffix = f".{word}@{EMAIL_DOMAIN}" suffix = f".{word}@{EMAIL_DOMAIN}"
signed_suffix = signer.sign(suffix).decode() signed_suffix = signer.sign(suffix).decode()
assert AliasUsedOn.query.count() == 0 assert AliasUsedOn.count() == 0
r = flask_client.post( r = flask_client.post(
"/api/v3/alias/custom/new?hostname=example.com", "/api/v3/alias/custom/new?hostname=example.com",
@ -146,7 +146,7 @@ def test_custom_domain_alias(flask_client):
def test_out_of_quota(flask_client): def test_out_of_quota(flask_client):
user = login(flask_client) user = login(flask_client)
user.trial_end = None user.trial_end = None
db.session.commit() Session.commit()
# create MAX_NB_EMAIL_FREE_PLAN custom alias to run out of quota # create MAX_NB_EMAIL_FREE_PLAN custom alias to run out of quota
for _ in range(MAX_NB_EMAIL_FREE_PLAN): for _ in range(MAX_NB_EMAIL_FREE_PLAN):

View File

@ -3,7 +3,7 @@ import uuid
from flask import url_for, g from flask import url_for, g
from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN
from app.extensions import db from app.db import Session
from app.models import Alias from app.models import Alias
from tests.utils import login from tests.utils import login
@ -60,7 +60,7 @@ def test_custom_mode(flask_client):
def test_out_of_quota(flask_client): def test_out_of_quota(flask_client):
user = login(flask_client) user = login(flask_client)
user.trial_end = None user.trial_end = None
db.session.commit() Session.commit()
# create MAX_NB_EMAIL_FREE_PLAN random alias to run out of quota # create MAX_NB_EMAIL_FREE_PLAN random alias to run out of quota
for _ in range(MAX_NB_EMAIL_FREE_PLAN): for _ in range(MAX_NB_EMAIL_FREE_PLAN):

View File

@ -1,6 +1,6 @@
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from app.models import User, ApiKey, Notification from app.models import User, ApiKey, Notification
@ -8,16 +8,16 @@ def test_get_notifications(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
# create some notifications # create some notifications
Notification.create(user_id=user.id, message="Test message 1") Notification.create(user_id=user.id, message="Test message 1")
Notification.create(user_id=user.id, message="Test message 2") Notification.create(user_id=user.id, message="Test message 2")
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for("api.get_notifications", page=0), url_for("api.get_notifications", page=0),
@ -46,14 +46,14 @@ def test_mark_notification_as_read(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
Notification.create(id=1, user_id=user.id, message="Test message 1") Notification.create(id=1, user_id=user.id, message="Test message 1")
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.mark_as_read", notification_id=1), url_for("api.mark_as_read", notification_id=1),

View File

@ -1,6 +1,6 @@
from app.api.serializer import get_alias_infos_with_pagination_v3 from app.api.serializer import get_alias_infos_with_pagination_v3
from app.config import PAGE_LIMIT from app.config import PAGE_LIMIT
from app.extensions import db from app.db import Session
from app.models import User, Alias, Mailbox, Contact from app.models import User, Alias, Mailbox, Contact
from tests.utils import create_user from tests.utils import create_user
@ -75,7 +75,7 @@ def test_get_alias_infos_with_pagination_v3_query_alias_mailboxes(flask_client):
alias = Alias.first() alias = Alias.first()
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") mb = Mailbox.create(user_id=user.id, email="mb@gmail.com")
alias._mailboxes.append(mb) alias._mailboxes.append(mb)
db.session.commit() Session.commit()
alias_infos = get_alias_infos_with_pagination_v3(user, mailbox_id=mb.id) alias_infos = get_alias_infos_with_pagination_v3(user, mailbox_id=mb.id)
assert len(alias_infos) == 1 assert len(alias_infos) == 1
@ -96,7 +96,7 @@ def test_get_alias_infos_with_pagination_v3_query_alias_note(flask_client):
alias = Alias.first() alias = Alias.first()
alias.note = "test note" alias.note = "test note"
db.session.commit() Session.commit()
alias_infos = get_alias_infos_with_pagination_v3(user, query="test note") alias_infos = get_alias_infos_with_pagination_v3(user, query="test note")
assert len(alias_infos) == 1 assert len(alias_infos) == 1
@ -114,7 +114,7 @@ def test_get_alias_infos_with_pagination_v3_query_alias_name(flask_client):
alias = Alias.first() alias = Alias.first()
alias.name = "Test Name" alias.name = "Test Name"
db.session.commit() Session.commit()
alias_infos = get_alias_infos_with_pagination_v3(user, query="test name") alias_infos = get_alias_infos_with_pagination_v3(user, query="test name")
assert len(alias_infos) == 1 assert len(alias_infos) == 1
@ -135,7 +135,7 @@ def test_get_alias_infos_with_pagination_v3_no_duplicate(flask_client):
alias = Alias.first() alias = Alias.first()
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com") mb = Mailbox.create(user_id=user.id, email="mb@gmail.com")
alias._mailboxes.append(mb) alias._mailboxes.append(mb)
db.session.commit() Session.commit()
alias_infos = get_alias_infos_with_pagination_v3(user) alias_infos = get_alias_infos_with_pagination_v3(user)
assert len(alias_infos) == 1 assert len(alias_infos) == 1
@ -182,7 +182,7 @@ def test_get_alias_infos_pinned_alias(flask_client):
for i in range(2 * PAGE_LIMIT): for i in range(2 * PAGE_LIMIT):
Alias.create_new_random(user) Alias.create_new_random(user)
first_alias = Alias.query.order_by(Alias.id).first() first_alias = Alias.order_by(Alias.id).first()
# should return PAGE_LIMIT alias # should return PAGE_LIMIT alias
alias_infos = get_alias_infos_with_pagination_v3(user) alias_infos = get_alias_infos_with_pagination_v3(user)
@ -192,7 +192,7 @@ def test_get_alias_infos_pinned_alias(flask_client):
# pin the first alias # pin the first alias
first_alias.pinned = True first_alias.pinned = True
db.session.commit() Session.commit()
alias_infos = get_alias_infos_with_pagination_v3(user) alias_infos = get_alias_infos_with_pagination_v3(user)
# now first_alias is the first result # now first_alias is the first result

View File

@ -1,6 +1,6 @@
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from app.models import User, ApiKey from app.models import User, ApiKey
from tests.utils import login from tests.utils import login
@ -9,11 +9,11 @@ def test_user_in_trial(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# create api_key # create api_key
api_key = ApiKey.create(user.id, "for test") api_key = ApiKey.create(user.id, "for test")
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for("api.user_info"), headers={"Authentication": api_key.code} url_for("api.user_info"), headers={"Authentication": api_key.code}
@ -42,7 +42,7 @@ def test_wrong_api_key(flask_client):
def test_create_api_key(flask_client): def test_create_api_key(flask_client):
# create user, user is activated # create user, user is activated
User.create(email="a@b.c", password="password", name="Test User", activated=True) User.create(email="a@b.c", password="password", name="Test User", activated=True)
db.session.commit() Session.commit()
# login user # login user
flask_client.post( flask_client.post(
@ -61,7 +61,7 @@ def test_create_api_key(flask_client):
def test_logout(flask_client): def test_logout(flask_client):
# create user, user is activated # create user, user is activated
User.create(email="a@b.c", password="password", name="Test User", activated=True) User.create(email="a@b.c", password="password", name="Test User", activated=True)
db.session.commit() Session.commit()
# login user # login user
flask_client.post( flask_client.post(

View File

@ -1,6 +1,6 @@
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from app.models import User from app.models import User
@ -9,7 +9,7 @@ def test_unactivated_user_login(flask_client):
# create user, user is not activated # create user, user is not activated
User.create(email="a@b.c", password="password", name="Test User") User.create(email="a@b.c", password="password", name="Test User")
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("auth.login"), url_for("auth.login"),
@ -29,7 +29,7 @@ def test_activated_user_login(flask_client):
# create user, user is activated # create user, user is activated
User.create(email="a@b.c", password="password", name="Test User", activated=True) User.create(email="a@b.c", password="password", name="Test User", activated=True)
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("auth.login"), url_for("auth.login"),

View File

@ -2,18 +2,20 @@ import os
# use the tests/test.env config fle # use the tests/test.env config fle
# flake8: noqa: E402 # flake8: noqa: E402
import sqlalchemy
os.environ["CONFIG"] = os.path.abspath( os.environ["CONFIG"] = os.path.abspath(
os.path.join(os.path.dirname(os.path.dirname(__file__)), "tests/test.env") os.path.join(os.path.dirname(os.path.dirname(__file__)), "tests/test.env")
) )
import sqlalchemy
from app.db import Session, engine, connection
from app.models import Base
from psycopg2 import errors from psycopg2 import errors
from psycopg2.errorcodes import DEPENDENT_OBJECTS_STILL_EXIST from psycopg2.errorcodes import DEPENDENT_OBJECTS_STILL_EXIST
import pytest import pytest
from app.extensions import db
from server import create_app from server import create_app
from init_app import add_sl_domains from init_app import add_sl_domains
@ -24,7 +26,7 @@ app.config["SERVER_NAME"] = "sl.test"
with app.app_context(): with app.app_context():
# enable pg_trgm extension # enable pg_trgm extension
with db.engine.connect() as conn: with engine.connect() as conn:
try: try:
conn.execute("DROP EXTENSION if exists pg_trgm") conn.execute("DROP EXTENSION if exists pg_trgm")
conn.execute("CREATE EXTENSION pg_trgm") conn.execute("CREATE EXTENSION pg_trgm")
@ -33,7 +35,7 @@ with app.app_context():
print(">>> pg_trgm can't be dropped, ignore") print(">>> pg_trgm can't be dropped, ignore")
conn.execute("Rollback") conn.execute("Rollback")
db.create_all() Base.metadata.create_all(engine)
add_sl_domains() add_sl_domains()
@ -45,20 +47,14 @@ def flask_app():
@pytest.fixture @pytest.fixture
def flask_client(): def flask_client():
with app.app_context(): transaction = connection.begin()
# replace db.session to that we can rollback all commits that can be made during a test
# inspired from http://alexmic.net/flask-sqlalchemy-pytest/
connection = db.engine.connect()
transaction = connection.begin()
options = dict(bind=connection, binds={})
session = db.create_scoped_session(options=options)
db.session = session
with app.app_context():
try: try:
client = app.test_client() client = app.test_client()
yield client yield client
finally: finally:
# roll back all commits made during a test # roll back all commits made during a test
transaction.rollback() transaction.rollback()
connection.close() Session.rollback()
session.remove() Session.close()

View File

@ -11,7 +11,7 @@ def test_add_contact_success(flask_client):
login(flask_client) login(flask_client)
alias = Alias.first() alias = Alias.first()
assert Contact.query.count() == 0 assert Contact.count() == 0
# <<< Create a new contact >>> # <<< Create a new contact >>>
flask_client.post( flask_client.post(
@ -23,7 +23,7 @@ def test_add_contact_success(flask_client):
follow_redirects=True, follow_redirects=True,
) )
# a new contact is added # a new contact is added
assert Contact.query.count() == 1 assert Contact.count() == 1
contact = Contact.first() contact = Contact.first()
assert contact.website_email == "abcd@gmail.com" assert contact.website_email == "abcd@gmail.com"
@ -37,8 +37,8 @@ def test_add_contact_success(flask_client):
follow_redirects=True, follow_redirects=True,
) )
# a new contact is added # a new contact is added
assert Contact.query.count() == 2 assert Contact.count() == 2
contact = Contact.query.filter(Contact.id != contact.id).first() contact = Contact.filter(Contact.id != contact.id).first()
assert contact.website_email == "another@gmail.com" assert contact.website_email == "another@gmail.com"
assert contact.name == "First Last" assert contact.name == "First Last"
@ -53,5 +53,5 @@ def test_add_contact_success(flask_client):
) )
# no new contact is added # no new contact is added
assert Contact.query.count() == 2 assert Contact.count() == 2
assert "Invalid email format. Email must be either email@example.com" in str(r.data) assert "Invalid email format. Email must be either email@example.com" in str(r.data)

View File

@ -1,5 +1,5 @@
from app.dashboard.views import alias_transfer from app.dashboard.views import alias_transfer
from app.extensions import db from app.db import Session
from app.models import ( from app.models import (
Alias, Alias,
Mailbox, Mailbox,
@ -14,7 +14,7 @@ def test_alias_transfer(flask_client):
mb = Mailbox.create(user_id=user.id, email="mb@gmail.com", commit=True) mb = Mailbox.create(user_id=user.id, email="mb@gmail.com", commit=True)
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True) AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id, commit=True)

View File

@ -8,7 +8,7 @@ from app.dashboard.views.custom_alias import (
get_available_suffixes, get_available_suffixes,
AliasSuffix, AliasSuffix,
) )
from app.extensions import db from app.db import Session
from app.models import ( from app.models import (
Mailbox, Mailbox,
CustomDomain, CustomDomain,
@ -46,13 +46,13 @@ def test_add_alias_success(flask_client):
assert r.status_code == 200 assert r.status_code == 200
assert f"Alias prefix.12345@{EMAIL_DOMAIN} has been created" in str(r.data) assert f"Alias prefix.12345@{EMAIL_DOMAIN} has been created" in str(r.data)
alias = Alias.query.order_by(Alias.created_at.desc()).first() alias = Alias.order_by(Alias.created_at.desc()).first()
assert not alias._mailboxes assert not alias._mailboxes
def test_add_alias_multiple_mailboxes(flask_client): def test_add_alias_multiple_mailboxes(flask_client):
user = login(flask_client) user = login(flask_client)
db.session.commit() Session.commit()
alias_suffix = AliasSuffix( alias_suffix = AliasSuffix(
is_custom=False, is_custom=False,
@ -64,7 +64,7 @@ def test_add_alias_multiple_mailboxes(flask_client):
# create with a multiple mailboxes # create with a multiple mailboxes
mb1 = Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) mb1 = Mailbox.create(user_id=user.id, email="m1@example.com", verified=True)
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("dashboard.custom_alias"), url_for("dashboard.custom_alias"),
@ -78,18 +78,18 @@ def test_add_alias_multiple_mailboxes(flask_client):
assert r.status_code == 200 assert r.status_code == 200
assert f"Alias prefix.12345@{EMAIL_DOMAIN} has been created" in str(r.data) assert f"Alias prefix.12345@{EMAIL_DOMAIN} has been created" in str(r.data)
alias = Alias.query.order_by(Alias.created_at.desc()).first() alias = Alias.order_by(Alias.created_at.desc()).first()
assert alias._mailboxes assert alias._mailboxes
def test_not_show_unverified_mailbox(flask_client): def test_not_show_unverified_mailbox(flask_client):
"""make sure user unverified mailbox is not shown to user""" """make sure user unverified mailbox is not shown to user"""
user = login(flask_client) user = login(flask_client)
db.session.commit() Session.commit()
Mailbox.create(user_id=user.id, email="m1@example.com", verified=True) Mailbox.create(user_id=user.id, email="m1@example.com", verified=True)
Mailbox.create(user_id=user.id, email="m2@example.com", verified=False) Mailbox.create(user_id=user.id, email="m2@example.com", verified=False)
db.session.commit() Session.commit()
r = flask_client.get(url_for("dashboard.custom_alias")) r = flask_client.get(url_for("dashboard.custom_alias"))
@ -99,7 +99,7 @@ def test_not_show_unverified_mailbox(flask_client):
def test_verify_prefix_suffix(flask_client): def test_verify_prefix_suffix(flask_client):
user = login(flask_client) user = login(flask_client)
db.session.commit() Session.commit()
CustomDomain.create(user_id=user.id, domain="test.com", verified=True) CustomDomain.create(user_id=user.id, domain="test.com", verified=True)
@ -128,7 +128,7 @@ def test_available_suffixes(flask_client):
def test_available_suffixes_default_domain(flask_client): def test_available_suffixes_default_domain(flask_client):
user = login(flask_client) user = login(flask_client)
sl_domain = SLDomain.query.first() sl_domain = SLDomain.first()
CustomDomain.create(user_id=user.id, domain="test.com", verified=True, commit=True) CustomDomain.create(user_id=user.id, domain="test.com", verified=True, commit=True)
user.default_alias_public_domain_id = sl_domain.id user.default_alias_public_domain_id = sl_domain.id
@ -166,7 +166,7 @@ def test_available_suffixes_random_prefix_generation(flask_client):
def test_add_already_existed_alias(flask_client): def test_add_already_existed_alias(flask_client):
user = login(flask_client) user = login(flask_client)
db.session.commit() Session.commit()
another_user = User.create( another_user = User.create(
email="a2@b.c", email="a2@b.c",
@ -208,7 +208,7 @@ def test_add_already_existed_alias(flask_client):
def test_add_alias_in_global_trash(flask_client): def test_add_alias_in_global_trash(flask_client):
user = login(flask_client) user = login(flask_client)
db.session.commit() Session.commit()
another_user = User.create( another_user = User.create(
email="a2@b.c", email="a2@b.c",
@ -233,9 +233,9 @@ def test_add_alias_in_global_trash(flask_client):
commit=True, commit=True,
) )
assert DeletedAlias.query.count() == 0 assert DeletedAlias.count() == 0
delete_alias(alias, another_user) delete_alias(alias, another_user)
assert DeletedAlias.query.count() == 1 assert DeletedAlias.count() == 1
# create the same alias, should return error # create the same alias, should return error
r = flask_client.post( r = flask_client.post(
@ -267,9 +267,9 @@ def test_add_alias_in_custom_domain_trash(flask_client):
commit=True, commit=True,
) )
assert DomainDeletedAlias.query.count() == 0 assert DomainDeletedAlias.count() == 0
delete_alias(alias, user) delete_alias(alias, user)
assert DomainDeletedAlias.query.count() == 1 assert DomainDeletedAlias.count() == 1
# create the same alias, should return error # create the same alias, should return error
suffix = "@ab.cd" suffix = "@ab.cd"

View File

@ -1,13 +1,13 @@
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from tests.utils import login from tests.utils import login
def test_add_domain_success(flask_client): def test_add_domain_success(flask_client):
user = login(flask_client) user = login(flask_client)
user.lifetime = True user.lifetime = True
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("dashboard.custom_domain"), url_for("dashboard.custom_domain"),
@ -23,7 +23,7 @@ def test_add_domain_same_as_user_email(flask_client):
"""cannot add domain if user personal email uses this domain""" """cannot add domain if user personal email uses this domain"""
user = login(flask_client) user = login(flask_client)
user.lifetime = True user.lifetime = True
db.session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("dashboard.custom_domain"), url_for("dashboard.custom_domain"),

View File

@ -8,7 +8,7 @@ from tests.utils import login
def test_create_random_alias_success(flask_client): def test_create_random_alias_success(flask_client):
login(flask_client) login(flask_client)
assert Alias.query.count() == 1 assert Alias.count() == 1
r = flask_client.post( r = flask_client.post(
url_for("dashboard.index"), url_for("dashboard.index"),
@ -16,7 +16,7 @@ def test_create_random_alias_success(flask_client):
follow_redirects=True, follow_redirects=True,
) )
assert r.status_code == 200 assert r.status_code == 200
assert Alias.query.count() == 2 assert Alias.count() == 2
def test_too_many_requests(flask_client): def test_too_many_requests(flask_client):

View File

@ -2,7 +2,7 @@ from app.config import (
MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS, MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS,
MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX, MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX,
) )
from app.extensions import db from app.db import Session
from app.email.rate_limit import ( from app.email.rate_limit import (
rate_limited_forward_phase, rate_limited_forward_phase,
rate_limited_for_alias, rate_limited_for_alias,
@ -16,11 +16,11 @@ def test_rate_limited_forward_phase_for_alias(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
# no rate limiting for a new alias # no rate limiting for a new alias
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
assert not rate_limited_for_alias(alias) assert not rate_limited_for_alias(alias)
# rate limit when there's a previous activity on alias # rate limit when there's a previous activity on alias
@ -30,14 +30,14 @@ def test_rate_limited_forward_phase_for_alias(flask_client):
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.local", reply_email="rep@sl.local",
) )
db.session.commit() Session.commit()
for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1): for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1):
EmailLog.create( EmailLog.create(
user_id=user.id, user_id=user.id,
contact_id=contact.id, contact_id=contact.id,
alias_id=contact.alias_id, alias_id=contact.alias_id,
) )
db.session.commit() Session.commit()
assert rate_limited_for_alias(alias) assert rate_limited_for_alias(alias)
@ -46,10 +46,10 @@ def test_rate_limited_forward_phase_for_mailbox(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
contact = Contact.create( contact = Contact.create(
user_id=user.id, user_id=user.id,
@ -57,14 +57,14 @@ def test_rate_limited_forward_phase_for_mailbox(flask_client):
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.local", reply_email="rep@sl.local",
) )
db.session.commit() Session.commit()
for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX + 1): for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_MAILBOX + 1):
EmailLog.create( EmailLog.create(
user_id=user.id, user_id=user.id,
contact_id=contact.id, contact_id=contact.id,
alias_id=contact.alias_id, alias_id=contact.alias_id,
) )
db.session.commit() Session.commit()
EmailLog.create( EmailLog.create(
user_id=user.id, user_id=user.id,
@ -75,7 +75,7 @@ def test_rate_limited_forward_phase_for_mailbox(flask_client):
# Create another alias with the same mailbox # Create another alias with the same mailbox
# will be rate limited as there's a previous activity on mailbox # will be rate limited as there's a previous activity on mailbox
alias2 = Alias.create_new_random(user) alias2 = Alias.create_new_random(user)
db.session.commit() Session.commit()
assert rate_limited_for_mailbox(alias2) assert rate_limited_for_mailbox(alias2)
@ -91,10 +91,10 @@ def test_rate_limited_reply_phase(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
contact = Contact.create( contact = Contact.create(
user_id=user.id, user_id=user.id,
@ -102,13 +102,13 @@ def test_rate_limited_reply_phase(flask_client):
website_email="contact@example.com", website_email="contact@example.com",
reply_email="rep@sl.local", reply_email="rep@sl.local",
) )
db.session.commit() Session.commit()
for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1): for _ in range(MAX_ACTIVITY_DURING_MINUTE_PER_ALIAS + 1):
EmailLog.create( EmailLog.create(
user_id=user.id, user_id=user.id,
contact_id=contact.id, contact_id=contact.id,
alias_id=contact.alias_id, alias_id=contact.alias_id,
) )
db.session.commit() Session.commit()
assert rate_limited_reply_phase("rep@sl.local") assert rate_limited_reply_phase("rep@sl.local")

View File

@ -4,7 +4,7 @@ from urllib.parse import urlparse, parse_qs
from flask import url_for from flask import url_for
from app.extensions import db from app.db import Session
from app.jose_utils import verify_id_token, decode_id_token from app.jose_utils import verify_id_token, decode_id_token
from app.models import Client, User, ClientUser from app.models import Client, User, ClientUser
from app.oauth.views.authorize import ( from app.oauth.views.authorize import (
@ -39,10 +39,10 @@ def test_construct_url():
def test_authorize_page_non_login_user(flask_client): def test_authorize_page_non_login_user(flask_client):
"""make sure to display login page for non-authenticated user""" """make sure to display login page for non-authenticated user"""
user = User.create("test@test.com", "test user") user = User.create("test@test.com", "test user")
db.session.commit() Session.commit()
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for( url_for(
@ -63,7 +63,7 @@ def test_authorize_page_login_user_non_supported_flow(flask_client):
"""return 400 if the flow is not supported""" """return 400 if the flow is not supported"""
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
# Not provide any flow # Not provide any flow
r = flask_client.get( r = flask_client.get(
@ -102,7 +102,7 @@ def test_authorize_page_login_user(flask_client):
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for( url_for(
@ -128,7 +128,7 @@ def test_authorize_code_flow_no_openid_scope(flask_client):
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
# user allows client on the authorization page # user allows client on the authorization page
r = flask_client.post( r = flask_client.post(
@ -217,7 +217,7 @@ def test_authorize_code_flow_with_openid_scope(flask_client):
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
# user allows client on the authorization page # user allows client on the authorization page
r = flask_client.post( r = flask_client.post(
@ -310,7 +310,7 @@ def test_authorize_token_flow(flask_client):
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
# user allows client on the authorization page # user allows client on the authorization page
r = flask_client.post( r = flask_client.post(
@ -357,7 +357,7 @@ def test_authorize_id_token_flow(flask_client):
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
# user allows client on the authorization page # user allows client on the authorization page
r = flask_client.post( r = flask_client.post(
@ -406,7 +406,7 @@ def test_authorize_token_id_token_flow(flask_client):
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
# user allows client on the authorization page # user allows client on the authorization page
r = flask_client.post( r = flask_client.post(
@ -496,7 +496,7 @@ def test_authorize_code_id_token_flow(flask_client):
user = login(flask_client) user = login(flask_client)
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
db.session.commit() Session.commit()
# user allows client on the authorization page # user allows client on the authorization page
r = flask_client.post( r = flask_client.post(
@ -629,7 +629,7 @@ def test_authorize_page_invalid_client_id(flask_client):
user = login(flask_client) user = login(flask_client)
Client.create_new("test client", user.id) Client.create_new("test client", user.id)
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for( url_for(
@ -654,7 +654,7 @@ def test_authorize_page_http_not_allowed(flask_client):
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
client.approved = True client.approved = True
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for( url_for(
@ -676,7 +676,7 @@ def test_authorize_page_unknown_redirect_uri(flask_client):
client = Client.create_new("test client", user.id) client = Client.create_new("test client", user.id)
client.approved = True client.approved = True
db.session.commit() Session.commit()
r = flask_client.get( r = flask_client.get(
url_for( url_for(

View File

@ -1,5 +1,5 @@
from app.alias_utils import delete_alias, check_alias_prefix from app.alias_utils import delete_alias, check_alias_prefix
from app.extensions import db from app.db import Session
from app.models import User, Alias, DeletedAlias from app.models import User, Alias, DeletedAlias
@ -41,8 +41,8 @@ def test_delete_alias_already_in_trash(flask_client):
) )
# add the alias to global trash # add the alias to global trash
db.session.add(DeletedAlias(email=alias.email)) Session.add(DeletedAlias(email=alias.email))
db.session.commit() Session.commit()
delete_alias(alias, user) delete_alias(alias, user)
assert Alias.get_by(email="first@d1.test") is None assert Alias.get_by(email="first@d1.test") is None

View File

@ -1,4 +1,5 @@
import pytest import pytest
from app.config import sl_getenv from app.config import sl_getenv

View File

@ -4,6 +4,7 @@ from email.message import EmailMessage
import arrow import arrow
from app.config import MAX_ALERT_24H, EMAIL_DOMAIN, BOUNCE_EMAIL from app.config import MAX_ALERT_24H, EMAIL_DOMAIN, BOUNCE_EMAIL
from app.db import Session
from app.email_utils import ( from app.email_utils import (
get_email_domain_part, get_email_domain_part,
can_create_directory_for_address, can_create_directory_for_address,
@ -31,7 +32,6 @@ from app.email_utils import (
get_header_unicode, get_header_unicode,
parse_full_address, parse_full_address,
) )
from app.extensions import db
from app.models import User, CustomDomain, Alias, Contact, EmailLog, IgnoreBounceSender from app.models import User, CustomDomain, Alias, Contact, EmailLog, IgnoreBounceSender
# flake8: noqa: E101, W191 # flake8: noqa: E101, W191
@ -136,7 +136,7 @@ def test_send_email_with_rate_control(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
for _ in range(MAX_ALERT_24H): for _ in range(MAX_ALERT_24H):
assert send_email_with_rate_control( assert send_email_with_rate_control(
@ -598,7 +598,7 @@ def test_should_disable(flask_client):
include_sender_in_reverse_alias=True, include_sender_in_reverse_alias=True,
) )
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
assert not should_disable(alias) assert not should_disable(alias)
@ -623,7 +623,7 @@ def test_should_disable(flask_client):
# should not affect another alias # should not affect another alias
alias2 = Alias.create_new_random(user) alias2 = Alias.create_new_random(user)
db.session.commit() Session.commit()
assert not should_disable(alias2) assert not should_disable(alias2)
@ -631,7 +631,7 @@ def test_should_disable_bounces_every_day(flask_client):
"""if an alias has bounces every day at least 9 days in the last 10 days, disable alias""" """if an alias has bounces every day at least 9 days in the last 10 days, disable alias"""
user = login(flask_client) user = login(flask_client)
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
assert not should_disable(alias) assert not should_disable(alias)
@ -661,7 +661,7 @@ def test_should_disable_bounces_account(flask_client):
user = login(flask_client) user = login(flask_client)
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
# create a lot of bounces on alias # create a lot of bounces on alias
contact = Contact.create( contact = Contact.create(
@ -690,7 +690,7 @@ def test_should_disable_bounces_account(flask_client):
def test_should_disable_bounce_consecutive_days(flask_client): def test_should_disable_bounce_consecutive_days(flask_client):
user = login(flask_client) user = login(flask_client)
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
contact = Contact.create( contact = Contact.create(
user_id=user.id, user_id=user.id,

View File

@ -1,4 +1,4 @@
from app.extensions import db from app.db import Session
from app.jose_utils import make_id_token, verify_id_token from app.jose_utils import make_id_token, verify_id_token
from app.models import ClientUser, User, Client from app.models import ClientUser, User, Client
@ -7,15 +7,15 @@ def test_encode_decode(flask_client):
user = User.create( user = User.create(
email="a@b.c", password="password", name="Test User", activated=True email="a@b.c", password="password", name="Test User", activated=True
) )
db.session.commit() Session.commit()
client1 = Client.create_new(name="Demo", user_id=user.id) client1 = Client.create_new(name="Demo", user_id=user.id)
client1.oauth_client_id = "client-id" client1.oauth_client_id = "client-id"
client1.oauth_client_secret = "client-secret" client1.oauth_client_secret = "client-secret"
db.session.commit() Session.commit()
client_user = ClientUser.create(client_id=client1.id, user_id=user.id) client_user = ClientUser.create(client_id=client1.id, user_id=user.id)
db.session.commit() Session.commit()
jwt_token = make_id_token(client_user) jwt_token = make_id_token(client_user)

View File

@ -3,8 +3,8 @@ from uuid import UUID
import pytest import pytest
from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN from app.config import EMAIL_DOMAIN, MAX_NB_EMAIL_FREE_PLAN
from app.db import Session
from app.email_utils import parse_full_address from app.email_utils import parse_full_address
from app.extensions import db
from app.models import ( from app.models import (
generate_email, generate_email,
User, User,
@ -53,7 +53,7 @@ def test_suggested_emails_for_user_who_cannot_create_new_alias(flask_client):
# make sure user runs out of quota to create new email # make sure user runs out of quota to create new email
for i in range(MAX_NB_EMAIL_FREE_PLAN): for i in range(MAX_NB_EMAIL_FREE_PLAN):
Alias.create_new(user=user, prefix="test") Alias.create_new(user=user, prefix="test")
db.session.commit() Session.commit()
suggested_email, other_emails = user.suggested_emails(website_name="test") suggested_email, other_emails = user.suggested_emails(website_name="test")
@ -88,7 +88,7 @@ def test_website_send_to(flask_client):
) )
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
# non-empty name # non-empty name
c1 = Contact.create( c1 = Contact.create(
@ -122,7 +122,7 @@ def test_new_addr(flask_client):
) )
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
db.session.commit() Session.commit()
# default sender_format is 'via' # default sender_format is 'via'
c1 = Contact.create( c1 = Contact.create(
@ -137,18 +137,18 @@ def test_new_addr(flask_client):
# Make sure email isn't duplicated if sender name equals email # Make sure email isn't duplicated if sender name equals email
c1.name = "abcd@example.com" c1.name = "abcd@example.com"
db.session.commit() Session.commit()
assert c1.new_addr() == '"abcd(a)example.com" <rep@SL>' assert c1.new_addr() == '"abcd(a)example.com" <rep@SL>'
# set sender_format = AT # set sender_format = AT
user.sender_format = SenderFormatEnum.AT.value user.sender_format = SenderFormatEnum.AT.value
c1.name = "First Last" c1.name = "First Last"
db.session.commit() Session.commit()
assert c1.new_addr() == '"First Last - abcd at example.com" <rep@SL>' assert c1.new_addr() == '"First Last - abcd at example.com" <rep@SL>'
# unicode name # unicode name
c1.name = "Nhơn Nguyễn" c1.name = "Nhơn Nguyễn"
db.session.commit() Session.commit()
assert ( assert (
c1.new_addr() c1.new_addr()
== "=?utf-8?q?Nh=C6=A1n_Nguy=E1=BB=85n_-_abcd_at_example=2Ecom?= <rep@SL>" == "=?utf-8?q?Nh=C6=A1n_Nguy=E1=BB=85n_-_abcd_at_example=2Ecom?= <rep@SL>"
@ -182,11 +182,11 @@ def test_mailbox_delete(flask_client):
# alias has 2 mailboxes # alias has 2 mailboxes
alias = Alias.create_new(user, "prefix", mailbox_id=m1.id) alias = Alias.create_new(user, "prefix", mailbox_id=m1.id)
db.session.commit() Session.commit()
alias._mailboxes.append(m2) alias._mailboxes.append(m2)
alias._mailboxes.append(m3) alias._mailboxes.append(m3)
db.session.commit() Session.commit()
assert len(alias.mailboxes) == 3 assert len(alias.mailboxes) == 3

View File

@ -1,6 +1,6 @@
import arrow import arrow
from app.extensions import db from app.db import Session
from app.models import User, CoinbaseSubscription from app.models import User, CoinbaseSubscription
from server import handle_coinbase_event from server import handle_coinbase_event
@ -44,7 +44,7 @@ def test_handle_coinbase_event_extend_subscription(flask_client):
activated=True, activated=True,
) )
user.trial_end = None user.trial_end = None
db.session.commit() Session.commit()
cb = CoinbaseSubscription.create( cb = CoinbaseSubscription.create(
user_id=user.id, end_at=arrow.now().shift(days=-400), commit=True user_id=user.id, end_at=arrow.now().shift(days=-400), commit=True