Compare commits

...

27 Commits

Author SHA1 Message Date
Carlos Quintana 9d2a35b9c2
fix: monitoring table name (#2120) 2024-05-24 11:09:10 +02:00
Carlos Quintana 5f190d4b46
fix: monitoring table name 2024-05-24 10:52:08 +02:00
Carlos Quintana 6862ed3602
fix: event listener (#2119)
* fix: commit transaction after taking event

* feat: allow to reconnect to postgres for event listener

* chore: log sync events pending to process to metrics

* fix: make dead_letter runner able to process events without needing to have lock on the event

* chore: close Session after reconnect

* refactor: make EventSource emit only events that can be processed
2024-05-24 10:21:19 +02:00
Carlos Quintana 450322fff1
feat: allow to disable event-webhook (#2118) 2024-05-23 16:50:54 +02:00
Carlos Quintana aad6f59e96
Improve error handling on event sink (#2117)
* chore: make event_sink return success

* fix: add return to ConsoleEventSink
2024-05-23 15:05:47 +02:00
Carlos Quintana 8eccb05e33
feat: implement HTTP event sink (#2116)
* feat: implement HTTP event sink

* Update events/event_sink.py

---------

Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com>
2024-05-23 11:32:45 +02:00
Carlos Quintana 3e0b7bb369
Add sync events (#2113)
* feat: add protocol buffers for events

* chore: add EventDispatcher

* chore: add WebhookEvent class

* chore: emit events

* feat: initial version of event listener

* chore: emit user plan change with new timestamp

* feat: emit metrics + add alias status to create event

* chore: add newrelic decorator to functions

* fix: event emitter fixes

* fix: take null end_time into account

* fix: avoid double-commits

* chore: move UserDeleted event to User.delete method

* db: add index to sync_event created_at and taken_time columns

* chore: add index to model
2024-05-23 10:27:08 +02:00
Son Nguyen Kim 60ab8c15ec
show app page (#2110)
Co-authored-by: Son NK <son@simplelogin.io>
2024-05-22 15:43:36 +02:00
Son Nguyen Kim b5b167479f
Fix admin loop (#2103)
* mailbox page requires sudo

* fix the loop when non-admin user visits an admin URL

https://github.com/simple-login/app/issues/2101

---------

Co-authored-by: Son NK <son@simplelogin.io>
2024-05-10 18:52:12 +02:00
Adrià Casajús 8f12fabd81
Make hibp rate configurable (#2105) 2024-05-10 18:51:16 +02:00
Daniel Mühlbachler-Pietrzykowski b6004f3336
feat: use oidc well-known url (#2077) 2024-05-02 16:17:10 +02:00
Adrià Casajús 80c8bc820b
Do not double count AlilasMailboxes with Aliases (#2095)
* Do not double count aliasmailboxes with aliases

* Keep Sl-Queue-id
2024-04-30 16:41:47 +02:00
Son Nguyen Kim 037bc9da36
mailbox page requires sudo (#2094)
Co-authored-by: Son NK <son@simplelogin.io>
2024-04-23 22:25:37 +02:00
Son Nguyen Kim ee0be3688f
Data breach (#2093)
* add User.enable_data_breach_check column

* user can turn on/off the data breach check

* only run data breach check for user who enables it

* add tips to run tests using a local DB (without docker)

* refactor True check

* trim trailing space

* fix test

* Apply suggestions from code review

Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com>

* format

---------

Co-authored-by: Son NK <son@simplelogin.io>
Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com>
2024-04-23 22:16:36 +02:00
Adrià Casajús 015036b499
Prevent proton mailboxes from enabling pgp encryption (#2086) 2024-04-12 15:19:41 +02:00
Son Nguyen Kim d5df91aab6
Premium user can enable data breach monitoring (#2084)
* add User.enable_data_breach_check column

* user can turn on/off the data breach check

* only run data breach check for user who enables it

* add tips to run tests using a local DB (without docker)

* refactor True check

* trim trailing space

* fix test

* Apply suggestions from code review

Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com>

* format

---------

Co-authored-by: Son NK <son@simplelogin.io>
Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com>
2024-04-12 10:39:23 +02:00
Adrià Casajús 2eb5feaa8f
Small improvements (#2082)
* Update logs with more relevant info for debugging purposes

* Improved logs for alias creation rate-limit

* Reduce sudo time to 120 secs

* log fixes

* Fix missing object to add to the session
2024-04-08 15:05:51 +02:00
Adrià Casajús 3c364da37d
Dmarc fix (#2079)
* Add log to spam check + remove invisible characters on import

* Update log
2024-03-26 11:43:33 +01:00
Adrià Casajús 36cf530ef8
Preserve X-SL-Queue-Id (#2076) 2024-03-22 11:00:06 +01:00
Adrià Casajús 0da1811311
Cleanup old data (#2066)
* Cleanup tasks

* Update

* Added tests

* Create cron job

* Delete old data cron

* Fix import

* import fix

* Added delete + script to disable pgp for proton mboxes
2024-03-18 16:00:21 +01:00
Adrià Casajús f2fcaa6c60
Cleanup also messsage-id headers from linebreaks (#2067) 2024-03-18 14:27:38 +01:00
Adrià Casajús aa2c676b5e
Only check HIBP alias of paid users (#2065) 2024-03-15 10:13:06 +01:00
Adrià Casajús 30ddd4c807
Update oneshot commands (#2064)
* Update oneshot commands

* fix

* Fix test_load command

* Rename to avoid test executing it

* Do HIBP breach check in batches instead of a single load
2024-03-14 16:03:43 +01:00
Son Nguyen Kim f5babd9c81
Move import export back to setting (#2063)
* replace black by ruff

* move alias import/export to settings

* fix html closing tag

* add rate limit for alias import & export

---------

Co-authored-by: Son NK <son@simplelogin.io>
2024-03-14 15:56:35 +01:00
Adrià Casajús 74b811dd35
Update oneshot commands (#2060)
* Update oneshot commands

* fix

* Fix test_load command

* Rename to avoid test executing it
2024-03-14 11:11:50 +01:00
martadams89 e6c51bcf20
ci: remove v7 (#2062) 2024-03-14 09:49:45 +01:00
martadams89 4bfc6b9aca
ci: add arm docker images (#2056)
* ci: add arm docker images

* ci: remove armv6

* ci: update workflow to run only on path

* Update .github/workflows/main.yml

---------

Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com>
2024-03-13 15:20:47 +01:00
66 changed files with 1817 additions and 227 deletions

View File

@@ -1,7 +1,6 @@
 name: Test and lint

-on:
-  push:
+on: [push, pull_request]

 jobs:
   lint:
@@ -139,6 +138,12 @@ jobs:
       with:
         fetch-depth: 0

+      - name: Set up QEMU
+        uses: docker/setup-qemu-action@v2
+
+      - name: Set up Docker Buildx
+        uses: docker/setup-buildx-action@v2
+
       - name: Create Sentry release
         uses: getsentry/action-release@v1
         env:
@@ -158,6 +163,7 @@ jobs:
         uses: docker/build-push-action@v3
         with:
           context: .
+          platforms: linux/amd64,linux/arm64
           push: true
           tags: ${{ steps.meta.outputs.tags }}

View File

@@ -68,6 +68,12 @@ For most tests, you will need to have ``redis`` installed and started on your ma
 sh scripts/run-test.sh
 ```

+You can also run tests using a local Postgres DB to speed things up. This can be done by
+
+- creating an empty test DB and running the database migration by `dropdb test && createdb test && DB_URI=postgresql://localhost:5432/test alembic upgrade head`
+
+- replacing the `DB_URI` in `test.env` file by `DB_URI=postgresql://localhost:5432/test`
+
 ## Run the code locally

 Install npm packages
@@ -151,10 +157,10 @@ Here are the small sum-ups of the directory structures and their roles:

 ## Pull request

-The code is formatted using https://github.com/psf/black, to format the code, simply run
+The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run

 ```
-poetry run black .
+poetry run ruff format .
 ```

 The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by

View File

@@ -46,7 +46,8 @@ class SLModelView(sqla.ModelView):
     def inaccessible_callback(self, name, **kwargs):
         # redirect to login page if user doesn't have access
-        return redirect(url_for("auth.login", next=request.url))
+        flash("You don't have access to the admin page", "error")
+        return redirect(url_for("dashboard.index", next=request.url))

     def on_model_change(self, form, model, is_created):
         changes = {}

View File

@@ -25,6 +25,8 @@ from app.email_utils import (
     render,
 )
 from app.errors import AliasInTrashError
+from app.events.event_dispatcher import EventDispatcher
+from app.events.generated.event_pb2 import AliasDeleted, AliasStatusChange, EventContent
 from app.log import LOG
 from app.models import (
     Alias,
@@ -308,31 +310,36 @@ def delete_alias(alias: Alias, user: User):
     Delete an alias and add it to either global or domain trash
     Should be used instead of Alias.delete, DomainDeletedAlias.create, DeletedAlias.create
     """
-    # save deleted alias to either global or domain trash
+    LOG.i(f"User {user} has deleted alias {alias}")
+    # save deleted alias to either global or domain tra
     if alias.custom_domain_id:
         if not DomainDeletedAlias.get_by(
             email=alias.email, domain_id=alias.custom_domain_id
         ):
-            LOG.d("add %s to domain %s trash", alias, alias.custom_domain_id)
-            Session.add(
-                DomainDeletedAlias(
-                    user_id=user.id,
-                    email=alias.email,
-                    domain_id=alias.custom_domain_id,
-                )
-            )
+            domain_deleted_alias = DomainDeletedAlias(
+                user_id=user.id,
+                email=alias.email,
+                domain_id=alias.custom_domain_id,
+            )
+            Session.add(domain_deleted_alias)
             Session.commit()
+            LOG.i(
+                f"Moving {alias} to domain {alias.custom_domain_id} trash {domain_deleted_alias}"
+            )
     else:
         if not DeletedAlias.get_by(email=alias.email):
-            LOG.d("add %s to global trash", alias)
-            Session.add(DeletedAlias(email=alias.email))
+            deleted_alias = DeletedAlias(email=alias.email)
+            Session.add(deleted_alias)
             Session.commit()
+            LOG.i(f"Moving {alias} to global trash {deleted_alias}")

-    LOG.i("delete alias %s", alias)
     Alias.filter(Alias.id == alias.id).delete()
     Session.commit()

+    EventDispatcher.send_event(
+        user, EventContent(alias_deleted=AliasDeleted(alias_id=alias.id))
+    )
+

 def aliases_for_mailbox(mailbox: Mailbox) -> [Alias]:
     """
@@ -458,3 +465,15 @@ def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
     alias.pinned = False
     Session.commit()
+
+
+def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
+    alias.enabled = enabled
+
+    event = AliasStatusChange(
+        alias_id=alias.id, alias_email=alias.email, enabled=enabled
+    )
+    EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event))
+
+    if commit:
+        Session.commit()

View File

@@ -184,7 +184,7 @@ def toggle_alias(alias_id):
     if not alias or alias.user_id != user.id:
         return jsonify(error="Forbidden"), 403

-    alias.enabled = not alias.enabled
+    alias_utils.change_alias_status(alias, enabled=not alias.enabled)
     Session.commit()

     return jsonify(enabled=alias.enabled), 200
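The hunk above replaces a bare `alias.enabled = ...` assignment with the new `change_alias_status` helper, so the status-change event is emitted from a single place instead of at every call site. A self-contained sketch of that refactor pattern (the `Alias` dataclass and `events` list here are illustrative, not the app's real models):

```python
from dataclasses import dataclass
from typing import List, Tuple


@dataclass
class Alias:
    id: int
    email: str
    enabled: bool = True


# stand-in for the event pipeline: collects emitted events
events: List[Tuple[str, int, bool]] = []


def change_alias_status(alias: Alias, enabled: bool) -> None:
    alias.enabled = enabled
    # single emission point, mirroring EventDispatcher.send_event in the diff
    events.append(("alias_status_change", alias.id, enabled))


alias = Alias(id=184, email="hello@example.com")
change_alias_status(alias, enabled=not alias.enabled)
print(alias.enabled, events)
```

Call sites that previously flipped the flag directly (the API toggle, the dashboard "disable" form, the unsubscribe handler) all go through the helper after this change.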

View File

@@ -3,11 +3,13 @@ from flask_login import login_user

 from app.auth.base import auth_bp
 from app.db import Session
+from app.extensions import limiter
 from app.log import LOG
 from app.models import EmailChange, ResetPasswordCode


 @auth_bp.route("/change_email", methods=["GET", "POST"])
+@limiter.limit("3/hour")
 def change_email():
     code = request.args.get("code")

View File

@@ -1,14 +1,13 @@
 from flask import request, session, redirect, flash, url_for
 from requests_oauthlib import OAuth2Session
+import requests

 from app import config
 from app.auth.base import auth_bp
 from app.auth.views.login_utils import after_login
 from app.config import (
     URL,
-    OIDC_AUTHORIZATION_URL,
-    OIDC_USER_INFO_URL,
-    OIDC_TOKEN_URL,
     OIDC_SCOPES,
     OIDC_NAME_FIELD,
 )
@@ -16,14 +15,15 @@ from app.db import Session
 from app.email_utils import send_welcome_email
 from app.log import LOG
 from app.models import User, SocialAuth
-from app.utils import encode_url, sanitize_email, sanitize_next_url
+from app.utils import sanitize_email, sanitize_next_url

 # need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri
 # when served behind nginx, the redirect_uri is localhost... and not the real url
-_redirect_uri = URL + "/auth/oidc/callback"
+redirect_uri = URL + "/auth/oidc/callback"

 SESSION_STATE_KEY = "oauth_state"
+SESSION_NEXT_KEY = "oauth_redirect_next"


 @auth_bp.route("/oidc/login")
@@ -32,18 +32,17 @@ def oidc_login():
         return redirect(url_for("auth.login"))

     next_url = sanitize_next_url(request.args.get("next"))
-    if next_url:
-        redirect_uri = _redirect_uri + "?next=" + encode_url(next_url)
-    else:
-        redirect_uri = _redirect_uri
+
+    auth_url = requests.get(config.OIDC_WELL_KNOWN_URL).json()["authorization_endpoint"]

     oidc = OAuth2Session(
         config.OIDC_CLIENT_ID, scope=[OIDC_SCOPES], redirect_uri=redirect_uri
     )
-    authorization_url, state = oidc.authorization_url(OIDC_AUTHORIZATION_URL)
+    authorization_url, state = oidc.authorization_url(auth_url)

     # State is used to prevent CSRF, keep this for later.
     session[SESSION_STATE_KEY] = state
+    session[SESSION_NEXT_KEY] = next_url

     return redirect(authorization_url)
@@ -60,19 +59,23 @@ def oidc_callback():
         flash("Please use another sign in method then", "warning")
         return redirect("/")

+    oidc_configuration = requests.get(config.OIDC_WELL_KNOWN_URL).json()
+    user_info_url = oidc_configuration["userinfo_endpoint"]
+    token_url = oidc_configuration["token_endpoint"]
+
     oidc = OAuth2Session(
         config.OIDC_CLIENT_ID,
         state=session[SESSION_STATE_KEY],
         scope=[OIDC_SCOPES],
-        redirect_uri=_redirect_uri,
+        redirect_uri=redirect_uri,
     )
     oidc.fetch_token(
-        OIDC_TOKEN_URL,
+        token_url,
         client_secret=config.OIDC_CLIENT_SECRET,
         authorization_response=request.url,
     )

-    oidc_user_data = oidc.get(OIDC_USER_INFO_URL)
+    oidc_user_data = oidc.get(user_info_url)
     if oidc_user_data.status_code != 200:
         LOG.e(
             f"cannot get oidc user data {oidc_user_data.status_code} {oidc_user_data.text}"
@@ -111,7 +114,8 @@ def oidc_callback():
     Session.commit()

     # The activation link contains the original page, for ex authorize page
-    next_url = sanitize_next_url(request.args.get("next")) if request.args else None
+    next_url = session[SESSION_NEXT_KEY]
+    session[SESSION_NEXT_KEY] = None

     return after_login(user, next_url)
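With this change (#2077), the three separately configured OIDC endpoints are resolved at runtime from the provider's discovery document at `OIDC_WELL_KNOWN_URL` instead of three env vars. The fields the callback reads look like this — the issuer and paths below are sample values, not a real provider:

```python
import json

# Trimmed OIDC discovery document, as a provider would serve it at
# https://<issuer>/.well-known/openid-configuration (sample values only)
discovery_json = """{
  "issuer": "https://id.example.com",
  "authorization_endpoint": "https://id.example.com/oauth2/authorize",
  "token_endpoint": "https://id.example.com/oauth2/token",
  "userinfo_endpoint": "https://id.example.com/oauth2/userinfo"
}"""

oidc_configuration = json.loads(discovery_json)

# the same three lookups the login/callback views in the diff perform
auth_url = oidc_configuration["authorization_endpoint"]
token_url = oidc_configuration["token_endpoint"]
user_info_url = oidc_configuration["userinfo_endpoint"]
print(token_url)
```

One discovery URL keeps the three endpoints consistent with each other; the trade-off is an extra HTTP fetch per login and callback.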

View File

@@ -245,9 +245,7 @@ FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID")
 FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET")

 CONNECT_WITH_OIDC_ICON = os.environ.get("CONNECT_WITH_OIDC_ICON")
-OIDC_AUTHORIZATION_URL = os.environ.get("OIDC_AUTHORIZATION_URL")
-OIDC_USER_INFO_URL = os.environ.get("OIDC_USER_INFO_URL")
-OIDC_TOKEN_URL = os.environ.get("OIDC_TOKEN_URL")
+OIDC_WELL_KNOWN_URL = os.environ.get("OIDC_WELL_KNOWN_URL")
 OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID")
 OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET")
 OIDC_SCOPES = os.environ.get("OIDC_SCOPES")
@@ -431,9 +429,11 @@ except Exception:
     HIBP_SCAN_INTERVAL_DAYS = 7
 HIBP_API_KEYS = sl_getenv("HIBP_API_KEYS", list) or []
 HIBP_MAX_ALIAS_CHECK = 10_000
-HIBP_RPM = 100
+HIBP_RPM = int(os.environ.get("HIBP_API_RPM", 100))
 HIBP_SKIP_PARTNER_ALIAS = os.environ.get("HIBP_SKIP_PARTNER_ALIAS")

+KEEP_OLD_DATA_DAYS = 30
+
 POSTMASTER = os.environ.get("POSTMASTER")

 # store temporary files, especially for debugging
@@ -581,3 +581,9 @@ UPCLOUD_PASSWORD = os.environ.get("UPCLOUD_PASSWORD", None)
 UPCLOUD_DB_ID = os.environ.get("UPCLOUD_DB_ID", None)

 STORE_TRANSACTIONAL_EMAILS = "STORE_TRANSACTIONAL_EMAILS" in os.environ
+
+EVENT_WEBHOOK = os.environ.get("EVENT_WEBHOOK", None)
+
+# We want it disabled by default, so only skip if defined
+EVENT_WEBHOOK_SKIP_VERIFY_SSL = "EVENT_WEBHOOK_SKIP_VERIFY_SSL" in os.environ
+EVENT_WEBHOOK_DISABLE = "EVENT_WEBHOOK_DISABLE" in os.environ
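Note the `"… " in os.environ` pattern used for `EVENT_WEBHOOK_SKIP_VERIFY_SSL` and `EVENT_WEBHOOK_DISABLE`: mere presence of the variable makes the flag true, regardless of its value. A quick demonstration:

```python
import os

# presence, not value, is what the `in os.environ` pattern tests
os.environ.pop("EVENT_WEBHOOK_DISABLE", None)
print("EVENT_WEBHOOK_DISABLE" in os.environ)  # unset -> False

os.environ["EVENT_WEBHOOK_DISABLE"] = "0"  # even "0" counts as set
print("EVENT_WEBHOOK_DISABLE" in os.environ)  # -> True
```

So operators enable these flags by defining the variable at all, and disable them by unsetting it; `EVENT_WEBHOOK_DISABLE=0` would still disable the webhook.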

View File

@@ -2,10 +2,12 @@ from app.dashboard.base import dashboard_bp
 from flask_login import login_required, current_user
 from app.alias_utils import alias_export_csv
 from app.dashboard.views.enter_sudo import sudo_required
+from app.extensions import limiter


 @dashboard_bp.route("/alias_export", methods=["GET"])
 @login_required
 @sudo_required
+@limiter.limit("2/minute")
 def alias_export_route():
     return alias_export_csv(current_user)

View File

@@ -7,6 +7,7 @@ from app.config import JOB_BATCH_IMPORT
 from app.dashboard.base import dashboard_bp
 from app.dashboard.views.enter_sudo import sudo_required
 from app.db import Session
+from app.extensions import limiter
 from app.log import LOG
 from app.models import File, BatchImport, Job
 from app.utils import random_string, CSRFValidationForm
@@ -15,6 +16,7 @@ from app.utils import random_string, CSRFValidationForm
 @dashboard_bp.route("/batch_import", methods=["GET", "POST"])
 @login_required
 @sudo_required
+@limiter.limit("10/minute", methods=["POST"])
 def batch_import_route():
     # only for users who have custom domains
     if not current_user.verified_custom_domains():
@@ -39,7 +41,7 @@ def batch_import_route():
             return redirect(request.url)
         if len(batch_imports) > 10:
             flash(
-                "You have too many imports already. Wait until some get cleaned up",
+                "You have too many imports already. Please wait until some get cleaned up",
                 "error",
             )
             return render_template(

View File

@@ -14,7 +14,7 @@ from app.models import PartnerUser, SocialAuth
 from app.proton.utils import get_proton_partner
 from app.utils import sanitize_next_url

-_SUDO_GAP = 900
+_SUDO_GAP = 120


 class LoginForm(FlaskForm):

View File

@@ -141,12 +141,12 @@ def index():
             )

         if request.form.get("form-name") == "delete-alias":
-            LOG.d("delete alias %s", alias)
+            LOG.i(f"User {current_user} requested deletion of alias {alias}")
             email = alias.email
             alias_utils.delete_alias(alias, current_user)
             flash(f"Alias {email} has been deleted", "success")
         elif request.form.get("form-name") == "disable-alias":
-            alias.enabled = False
+            alias_utils.change_alias_status(alias, enabled=False)
             Session.commit()
             flash(f"Alias {alias.email} has been disabled", "success")

View File

@@ -11,9 +11,11 @@ from wtforms.fields.html5 import EmailField
 from app.config import ENFORCE_SPF, MAILBOX_SECRET
 from app.config import URL
 from app.dashboard.base import dashboard_bp
+from app.dashboard.views.enter_sudo import sudo_required
 from app.db import Session
 from app.email_utils import email_can_be_used_as_mailbox
 from app.email_utils import mailbox_already_used, render, send_email
+from app.extensions import limiter
 from app.log import LOG
 from app.models import Alias, AuthorizedAddress
 from app.models import Mailbox
@@ -29,6 +31,8 @@ class ChangeEmailForm(FlaskForm):

 @dashboard_bp.route("/mailbox/<int:mailbox_id>/", methods=["GET", "POST"])
 @login_required
+@sudo_required
+@limiter.limit("20/minute", methods=["POST"])
 def mailbox_detail_route(mailbox_id):
     mailbox: Mailbox = Mailbox.get(mailbox_id)
     if not mailbox or mailbox.user_id != current_user.id:
@@ -179,8 +183,15 @@ def mailbox_detail_route(mailbox_id):
         elif request.form.get("form-name") == "toggle-pgp":
             if request.form.get("pgp-enabled") == "on":
-                mailbox.disable_pgp = False
-                flash(f"PGP is enabled on {mailbox.email}", "success")
+                if mailbox.is_proton():
+                    mailbox.disable_pgp = True
+                    flash(
+                        "Enabling PGP for a Proton Mail mailbox is redundant and does not add any security benefit",
+                        "info",
+                    )
+                else:
+                    mailbox.disable_pgp = False
+                    flash(f"PGP is enabled on {mailbox.email}", "info")
             else:
                 mailbox.disable_pgp = True
                 flash(f"PGP is disabled on {mailbox.email}", "info")

View File

@@ -227,6 +227,21 @@ def setting():
             Session.commit()
             flash("Your preference has been updated", "success")
             return redirect(url_for("dashboard.setting"))
+        elif request.form.get("form-name") == "enable_data_breach_check":
+            if not current_user.is_premium():
+                flash("Only premium plan can enable data breach monitoring", "warning")
+                return redirect(url_for("dashboard.setting"))
+            choose = request.form.get("enable_data_breach_check")
+            if choose == "on":
+                LOG.i("User {current_user} has enabled data breach monitoring")
+                current_user.enable_data_breach_check = True
+                flash("Data breach monitoring is enabled", "success")
+            else:
+                LOG.i("User {current_user} has disabled data breach monitoring")
+                current_user.enable_data_breach_check = False
+                flash("Data breach monitoring is disabled", "info")
+            Session.commit()
+            return redirect(url_for("dashboard.setting"))
         elif request.form.get("form-name") == "sender-in-ra":
             choose = request.form.get("enable")
             if choose == "on":

View File

@@ -8,6 +8,7 @@ from app.db import Session
 from flask import redirect, url_for, flash, request, render_template
 from flask_login import login_required, current_user

+from app import alias_utils
 from app.dashboard.base import dashboard_bp
 from app.handler.unsubscribe_encoder import UnsubscribeAction
 from app.handler.unsubscribe_handler import UnsubscribeHandler
@@ -31,7 +32,7 @@ def unsubscribe(alias_id):
     # automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058
     if request.method == "POST":
-        alias.enabled = False
+        alias_utils.change_alias_status(alias, False)
         flash(f"Alias {alias.email} has been blocked", "success")
         Session.commit()

View File

@@ -21,6 +21,7 @@ LIST_UNSUBSCRIBE = "List-Unsubscribe"
 LIST_UNSUBSCRIBE_POST = "List-Unsubscribe-Post"
 RETURN_PATH = "Return-Path"
 AUTHENTICATION_RESULTS = "Authentication-Results"
+SL_QUEUE_ID = "X-SL-Queue-Id"

 # headers used to DKIM sign in order of preference
 DKIM_HEADERS = [

View File

@@ -494,9 +494,10 @@ def delete_header(msg: Message, header: str):

 def sanitize_header(msg: Message, header: str):
     """remove trailing space and remove linebreak from a header"""
+    header_lowercase = header.lower()
     for i in reversed(range(len(msg._headers))):
         header_name = msg._headers[i][0].lower()
-        if header_name == header.lower():
+        if header_name == header_lowercase:
             # msg._headers[i] is a tuple like ('From', 'hey@google.com')
             if msg._headers[i][1]:
                 msg._headers[i] = (
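The hunk above hoists `header.lower()` out of the loop; the tail of the tuple rewrite is cut off in this view, but the stated intent (strip linebreaks and trailing spaces from matching header values) can be sketched with the stdlib alone. The replacement logic below is an assumption for illustration, not the repository's exact code:

```python
from email.message import Message


def sanitize_header(msg: Message, header: str) -> None:
    """Remove linebreaks and trailing spaces from every value of `header`."""
    header_lowercase = header.lower()  # hoisted once, as in the diff
    for i in reversed(range(len(msg._headers))):
        name, value = msg._headers[i]
        if name.lower() == header_lowercase and value:
            # assumed cleanup step: drop CR/LF, trim whitespace
            cleaned = value.replace("\r", "").replace("\n", "").strip()
            msg._headers[i] = (name, cleaned)


msg = Message()
msg["Message-ID"] = "<abc\r\n@example.com>"
sanitize_header(msg, "message-id")
print(msg["Message-ID"])
```

This relies on the same private `_headers` list of `(name, value)` tuples that the original code walks, which is why the commit #2067 ("Cleanup also messsage-id headers from linebreaks") can patch values in place.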

app/events/__init__.py Normal file

View File

@@ -0,0 +1,66 @@
from abc import ABC, abstractmethod

from app import config
from app.db import Session
from app.errors import ProtonPartnerNotSetUp
from app.events.generated import event_pb2
from app.models import User, PartnerUser, SyncEvent
from app.proton.utils import get_proton_partner
from typing import Optional

NOTIFICATION_CHANNEL = "simplelogin_sync_events"


class Dispatcher(ABC):
    @abstractmethod
    def send(self, event: bytes):
        pass


class PostgresDispatcher(Dispatcher):
    def send(self, event: bytes):
        instance = SyncEvent.create(content=event, flush=True)
        Session.execute(f"NOTIFY {NOTIFICATION_CHANNEL}, '{instance.id}';")

    @staticmethod
    def get():
        return PostgresDispatcher()


class EventDispatcher:
    @staticmethod
    def send_event(
        user: User,
        content: event_pb2.EventContent,
        dispatcher: Dispatcher = PostgresDispatcher.get(),
        skip_if_webhook_missing: bool = True,
    ):
        if config.EVENT_WEBHOOK_DISABLE:
            return

        if not config.EVENT_WEBHOOK and skip_if_webhook_missing:
            return

        partner_user = EventDispatcher.__partner_user(user.id)
        if not partner_user:
            return

        event = event_pb2.Event(
            user_id=user.id,
            external_user_id=partner_user.external_user_id,
            partner_id=partner_user.partner_id,
            content=content,
        )

        serialized = event.SerializeToString()
        dispatcher.send(serialized)

    @staticmethod
    def __partner_user(user_id: int) -> Optional[PartnerUser]:
        # Check if the current user has a partner_id
        try:
            proton_partner_id = get_proton_partner().id
        except ProtonPartnerNotSetUp:
            return None

        # It has. Retrieve the information for the PartnerUser
        return PartnerUser.get_by(user_id=user_id, partner_id=proton_partner_id)
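`EventDispatcher.send_event` defaults to the Postgres `NOTIFY` transport but accepts any `Dispatcher`, which makes the transport swappable in tests. A self-contained sketch of that seam — `InMemoryDispatcher` and `send_event` here are hypothetical stand-ins, not part of the repo:

```python
from abc import ABC, abstractmethod
from typing import List


class Dispatcher(ABC):
    @abstractmethod
    def send(self, event: bytes) -> None: ...


class InMemoryDispatcher(Dispatcher):
    """Test double: collects serialized events instead of NOTIFY-ing Postgres."""

    def __init__(self) -> None:
        self.events: List[bytes] = []

    def send(self, event: bytes) -> None:
        self.events.append(event)


def send_event(payload: bytes, dispatcher: Dispatcher) -> None:
    # stand-in for EventDispatcher.send_event: serialize, then hand off
    dispatcher.send(payload)


dispatcher = InMemoryDispatcher()
send_event(b"serialized-protobuf-event", dispatcher)
print(len(dispatcher.events))
```

One Python detail worth noting in the real file: the default `dispatcher: Dispatcher = PostgresDispatcher.get()` is evaluated once at import time, so every caller that omits the argument shares a single `PostgresDispatcher` instance.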

View File

@@ -0,0 +1,38 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# source: event.proto
# Protobuf Python Version: 5.26.1
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
# @@protoc_insertion_point(imports)

_sym_db = _symbol_database.Default()


DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"\'\n\x0eUserPlanChange\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"Z\n\x0c\x41liasCreated\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x12\n\nalias_note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\"K\n\x11\x41liasStatusChange\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\"5\n\x0c\x41liasDeleted\x12\x10\n\x08\x61lias_id\x18\x01 \x01(\r\x12\x13\n\x0b\x61lias_email\x18\x02 \x01(\t\"\xce\x02\n\x0c\x45ventContent\x12>\n\x10user_plan_change\x18\x01 \x01(\x0b\x32\".simplelogin_events.UserPlanChangeH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x44\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32%.simplelogin_events.AliasStatusChangeH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')

_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'event_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
  DESCRIPTOR._loaded_options = None
  _globals['_USERPLANCHANGE']._serialized_start=35
  _globals['_USERPLANCHANGE']._serialized_end=74
  _globals['_USERDELETED']._serialized_start=76
  _globals['_USERDELETED']._serialized_end=89
  _globals['_ALIASCREATED']._serialized_start=91
  _globals['_ALIASCREATED']._serialized_end=181
  _globals['_ALIASSTATUSCHANGE']._serialized_start=183
  _globals['_ALIASSTATUSCHANGE']._serialized_end=258
  _globals['_ALIASDELETED']._serialized_start=260
  _globals['_ALIASDELETED']._serialized_end=313
  _globals['_EVENTCONTENT']._serialized_start=316
  _globals['_EVENTCONTENT']._serialized_end=650
  _globals['_EVENT']._serialized_start=652
  _globals['_EVENT']._serialized_end=773
# @@protoc_insertion_point(module_scope)

View File

@@ -0,0 +1,71 @@
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Mapping as _Mapping, Optional as _Optional, Union as _Union

DESCRIPTOR: _descriptor.FileDescriptor

class UserPlanChange(_message.Message):
    __slots__ = ("plan_end_time",)
    PLAN_END_TIME_FIELD_NUMBER: _ClassVar[int]
    plan_end_time: int
    def __init__(self, plan_end_time: _Optional[int] = ...) -> None: ...

class UserDeleted(_message.Message):
    __slots__ = ()
    def __init__(self) -> None: ...

class AliasCreated(_message.Message):
    __slots__ = ("alias_id", "alias_email", "alias_note", "enabled")
    ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
    ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
    ALIAS_NOTE_FIELD_NUMBER: _ClassVar[int]
    ENABLED_FIELD_NUMBER: _ClassVar[int]
    alias_id: int
    alias_email: str
    alias_note: str
    enabled: bool
    def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., alias_note: _Optional[str] = ..., enabled: bool = ...) -> None: ...

class AliasStatusChange(_message.Message):
    __slots__ = ("alias_id", "alias_email", "enabled")
    ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
    ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
    ENABLED_FIELD_NUMBER: _ClassVar[int]
    alias_id: int
    alias_email: str
    enabled: bool
    def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ..., enabled: bool = ...) -> None: ...

class AliasDeleted(_message.Message):
    __slots__ = ("alias_id", "alias_email")
    ALIAS_ID_FIELD_NUMBER: _ClassVar[int]
    ALIAS_EMAIL_FIELD_NUMBER: _ClassVar[int]
    alias_id: int
    alias_email: str
    def __init__(self, alias_id: _Optional[int] = ..., alias_email: _Optional[str] = ...) -> None: ...

class EventContent(_message.Message):
    __slots__ = ("user_plan_change", "user_deleted", "alias_created", "alias_status_change", "alias_deleted")
    USER_PLAN_CHANGE_FIELD_NUMBER: _ClassVar[int]
    USER_DELETED_FIELD_NUMBER: _ClassVar[int]
    ALIAS_CREATED_FIELD_NUMBER: _ClassVar[int]
    ALIAS_STATUS_CHANGE_FIELD_NUMBER: _ClassVar[int]
    ALIAS_DELETED_FIELD_NUMBER: _ClassVar[int]
    user_plan_change: UserPlanChange
    user_deleted: UserDeleted
    alias_created: AliasCreated
    alias_status_change: AliasStatusChange
    alias_deleted: AliasDeleted
    def __init__(self, user_plan_change: _Optional[_Union[UserPlanChange, _Mapping]] = ..., user_deleted: _Optional[_Union[UserDeleted, _Mapping]] = ..., alias_created: _Optional[_Union[AliasCreated, _Mapping]] = ..., alias_status_change: _Optional[_Union[AliasStatusChange, _Mapping]] = ..., alias_deleted: _Optional[_Union[AliasDeleted, _Mapping]] = ...) -> None: ...

class Event(_message.Message):
    __slots__ = ("user_id", "external_user_id", "partner_id", "content")
    USER_ID_FIELD_NUMBER: _ClassVar[int]
    EXTERNAL_USER_ID_FIELD_NUMBER: _ClassVar[int]
    PARTNER_ID_FIELD_NUMBER: _ClassVar[int]
    CONTENT_FIELD_NUMBER: _ClassVar[int]
    user_id: int
    external_user_id: str
    partner_id: int
    content: EventContent
    def __init__(self, user_id: _Optional[int] = ..., external_user_id: _Optional[str] = ..., partner_id: _Optional[int] = ..., content: _Optional[_Union[EventContent, _Mapping]] = ...) -> None: ...
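The `EventContent` message above is a protobuf `oneof`: exactly one payload field is populated per event envelope. A minimal, protobuf-free sketch of that semantics (plain dataclasses stand in for the generated messages, and `which_one_of` mimics protobuf's `WhichOneof`; names here are illustrative only):

```python
from dataclasses import dataclass
from typing import Optional


@dataclass
class AliasCreated:
    alias_id: int
    alias_email: str
    alias_note: str
    enabled: bool


@dataclass
class EventContent:
    # At most one of these is set, mirroring the "oneof content" in the stub.
    alias_created: Optional[AliasCreated] = None
    user_deleted: Optional[object] = None

    def which_one_of(self) -> Optional[str]:
        # Report the single populated field, like protobuf's WhichOneof("content")
        for name in ("alias_created", "user_deleted"):
            if getattr(self, name) is not None:
                return name
        return None


content = EventContent(
    alias_created=AliasCreated(1, "hello@example.com", "note", True)
)
print(content.which_one_of())  # alias_created
```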

View File

@@ -30,7 +30,9 @@ def apply_dmarc_policy_for_forward_phase(
 ) -> Tuple[Message, Optional[str]]:
     spam_result = SpamdResult.extract_from_headers(msg, Phase.forward)
     if not DMARC_CHECK_ENABLED or not spam_result:
+        LOG.i("DMARC check disabled")
         return msg, None

+    LOG.i(f"Spam check result in {spam_result}")
     from_header = get_header_unicode(msg[headers.FROM])
@@ -150,8 +152,10 @@ def apply_dmarc_policy_for_reply_phase(
 ) -> Optional[str]:
     spam_result = SpamdResult.extract_from_headers(msg, Phase.reply)
     if not DMARC_CHECK_ENABLED or not spam_result:
+        LOG.i("DMARC check disabled")
         return None

+    LOG.i(f"Spam check result is {spam_result}")
     if spam_result.dmarc not in (
         DmarcCheckResult.quarantine,
         DmarcCheckResult.reject,

View File

@@ -5,6 +5,7 @@ from typing import Optional
 from aiosmtpd.smtp import Envelope

 from app import config
+from app import alias_utils
 from app.db import Session
 from app.email import headers, status
 from app.email_utils import (
@@ -101,7 +102,7 @@ class UnsubscribeHandler:
             mailbox.email, alias
         ):
             return status.E509
-        alias.enabled = False
+        alias_utils.change_alias_status(alias, enabled=False)
         Session.commit()

         enable_alias_url = config.URL + f"/dashboard/?highlight_alias_id={alias.id}"
         for mailbox in alias.mailboxes:

View File

@@ -30,7 +30,10 @@ def handle_batch_import(batch_import: BatchImport):
     LOG.d("Download file %s from %s", batch_import.file, file_url)
     r = requests.get(file_url)

-    lines = [line.decode("utf-8") for line in r.iter_lines()]
+    # Replace invisible character
+    lines = [
+        line.decode("utf-8").replace("\ufeff", "").strip() for line in r.iter_lines()
+    ]

     import_from_csv(batch_import, user, lines)
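The batch-import change above strips UTF-8 byte-order marks: a leading U+FEFF in a downloaded CSV would otherwise corrupt the first header field. A self-contained sketch of the same line-cleaning expression (the byte strings here are stand-ins for `r.iter_lines()`):

```python
# A CSV exported from Windows tools often starts with a UTF-8 BOM (EF BB BF),
# which decodes to U+FEFF and would glue itself onto the first column name.
raw_lines = [b"\xef\xbb\xbfalias,note\r\n", b"hello@example.com,test\r\n"]

lines = [line.decode("utf-8").replace("\ufeff", "").strip() for line in raw_lines]

print(lines)  # ['alias,note', 'hello@example.com,test']
```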

View File

@@ -525,6 +525,11 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
         sa.Boolean, default=True, nullable=False, server_default="1"
     )

+    # user opted in for data breach check
+    enable_data_breach_check = sa.Column(
+        sa.Boolean, default=False, nullable=False, server_default="0"
+    )
+
     # bitwise flags. Allow for future expansion
     flags = sa.Column(
         sa.BigInteger,
@@ -652,6 +657,21 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):

         return user

+    @classmethod
+    def delete(cls, obj_id, commit=False):
+        # Internal import to avoid global import cycles
+        from app.events.event_dispatcher import EventDispatcher
+        from app.events.generated.event_pb2 import UserDeleted, EventContent
+
+        user: User = cls.get(obj_id)
+        EventDispatcher.send_event(user, EventContent(user_deleted=UserDeleted()))
+
+        res = super(User, cls).delete(obj_id)
+        if commit:
+            Session.commit()
+
+        return res
+
     def get_active_subscription(
         self, include_partner_subscription: bool = True
     ) -> Optional[
@@ -1614,6 +1634,18 @@ class Alias(Base, ModelMixin):
         Session.add(new_alias)
         DailyMetric.get_or_create_today_metric().nb_alias += 1

+        # Internal import to avoid global import cycles
+        from app.events.event_dispatcher import EventDispatcher
+        from app.events.generated.event_pb2 import AliasCreated, EventContent
+
+        event = AliasCreated(
+            alias_id=new_alias.id,
+            alias_email=new_alias.email,
+            alias_note=new_alias.note,
+            enabled=True,
+        )
+        EventDispatcher.send_event(user, EventContent(alias_created=event))
+
         if commit:
             Session.commit()

@@ -2644,10 +2676,15 @@ class Mailbox(Base, ModelMixin):
         return False

     def nb_alias(self):
-        return (
-            AliasMailbox.filter_by(mailbox_id=self.id).count()
-            + Alias.filter_by(mailbox_id=self.id).count()
-        )
+        alias_ids = set(
+            am.alias_id
+            for am in AliasMailbox.filter_by(mailbox_id=self.id).values(
+                AliasMailbox.alias_id
+            )
+        )
+        for alias in Alias.filter_by(mailbox_id=self.id).values(Alias.id):
+            alias_ids.add(alias.id)
+        return len(alias_ids)

     def is_proton(self) -> bool:
         if (
@@ -2696,12 +2733,15 @@ class Mailbox(Base, ModelMixin):

     @property
     def aliases(self) -> [Alias]:
-        ret = Alias.filter_by(mailbox_id=self.id).all()
+        ret = dict(
+            (alias.id, alias) for alias in Alias.filter_by(mailbox_id=self.id).all()
+        )

         for am in AliasMailbox.filter_by(mailbox_id=self.id):
-            ret.append(am.alias)
+            if am.alias_id not in ret:
+                ret[am.alias_id] = am.alias

-        return ret
+        return list(ret.values())

     @classmethod
     def create(cls, **kw):
@@ -3635,3 +3675,52 @@ class ApiToCookieToken(Base, ModelMixin):
         code = secrets.token_urlsafe(32)

         return super().create(code=code, **kwargs)
+
+
+class SyncEvent(Base, ModelMixin):
+    """This model holds the events that need to be sent to the webhook"""
+
+    __tablename__ = "sync_event"
+    content = sa.Column(sa.LargeBinary, unique=False, nullable=False)
+    taken_time = sa.Column(
+        ArrowType, default=None, nullable=True, server_default=None, index=True
+    )
+
+    __table_args__ = (
+        sa.Index("ix_sync_event_created_at", "created_at"),
+        sa.Index("ix_sync_event_taken_time", "taken_time"),
+    )
+
+    def mark_as_taken(self) -> bool:
+        sql = """
+            UPDATE sync_event
+            SET taken_time = :taken_time
+            WHERE id = :sync_event_id
+              AND taken_time IS NULL
+        """
+        args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id}
+
+        res = Session.execute(sql, args)
+        Session.commit()
+
+        return res.rowcount > 0
+
+    @classmethod
+    def get_dead_letter(cls, older_than: Arrow) -> [SyncEvent]:
+        return (
+            SyncEvent.filter(
+                (
+                    (
+                        SyncEvent.taken_time.isnot(None)
+                        & (SyncEvent.taken_time < older_than)
+                    )
+                    | (
+                        SyncEvent.taken_time.is_(None)
+                        & (SyncEvent.created_at < older_than)
+                    )
+                )
+            )
+            .order_by(SyncEvent.id)
+            .limit(100)
+            .all()
+        )
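`SyncEvent.mark_as_taken` above claims an event with a conditional `UPDATE ... WHERE taken_time IS NULL`, so when several runners race for the same event exactly one sees `rowcount == 1` and wins. A runnable sketch of that claim pattern, with SQLite standing in for Postgres purely for illustration:

```python
import sqlite3

# One shared table of pending events; taken_time NULL means "unclaimed".
conn = sqlite3.connect(":memory:")
conn.execute("CREATE TABLE sync_event (id INTEGER PRIMARY KEY, taken_time TEXT)")
conn.execute("INSERT INTO sync_event (id, taken_time) VALUES (1, NULL)")


def mark_as_taken(event_id: int) -> bool:
    # The WHERE clause makes the claim atomic: a second runner's UPDATE
    # matches zero rows because taken_time is no longer NULL.
    cur = conn.execute(
        "UPDATE sync_event SET taken_time = datetime('now') "
        "WHERE id = ? AND taken_time IS NULL",
        (event_id,),
    )
    return cur.rowcount > 0


print(mark_as_taken(1))  # True: first runner claims the event
print(mark_as_taken(1))  # False: the event is already taken
```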

View File

@@ -30,7 +30,9 @@ def check_bucket_limit(
     try:
         value = lock_redis.incr(bucket_lock_name, bucket_seconds)
         if value > max_hits:
-            LOG.i(f"Rate limit hit for {bucket_lock_name} -> {value}/{max_hits}")
+            LOG.i(
+                f"Rate limit hit for {lock_name} (bucket id {bucket_id}) -> {value}/{max_hits}"
+            )
             newrelic.agent.record_custom_event(
                 "BucketRateLimit",
                 {"lock_name": lock_name, "bucket_seconds": bucket_seconds},
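`check_bucket_limit` above is a fixed-window limiter: increment a per-window Redis key, let it expire after `bucket_seconds`, and trip once the counter exceeds `max_hits`. A hedged, in-memory sketch of that idea (a dict stands in for Redis, and the simplified signature is an assumption, not the project's exact API):

```python
import time

# key -> (hit count, window expiry timestamp)
_buckets: dict[str, tuple[int, float]] = {}


def check_bucket_limit(lock_name: str, max_hits: int, bucket_seconds: int) -> bool:
    # The window id changes every bucket_seconds, so a fresh key (and a fresh
    # counter) starts automatically at each window boundary.
    bucket_id = int(time.time() / bucket_seconds)
    key = f"bucket_lock:{lock_name}:{bucket_id}"
    count, expiry = _buckets.get(key, (0, time.time() + bucket_seconds))
    if time.time() > expiry:  # stale entry: reset, as Redis TTL would
        count = 0
    count += 1
    _buckets[key] = (count, expiry)
    return count <= max_hits  # True while still under the limit


print(check_bucket_limit("login", max_hits=2, bucket_seconds=3600))  # True
print(check_bucket_limit("login", max_hits=2, bucket_seconds=3600))  # True
print(check_bucket_limit("login", max_hits=2, bucket_seconds=3600))  # False
```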

View File

@@ -5,19 +5,9 @@ from typing import Optional
 import boto3
 import requests

-from app.config import (
-    AWS_REGION,
-    BUCKET,
-    AWS_ACCESS_KEY_ID,
-    AWS_SECRET_ACCESS_KEY,
-    LOCAL_FILE_UPLOAD,
-    UPLOAD_DIR,
-    URL,
-    AWS_ENDPOINT_URL,
-)
+from app import config
 from app.log import LOG

 _s3_client = None
@@ -25,12 +15,12 @@ def _get_s3client():
     global _s3_client
     if _s3_client is None:
         args = {
-            "aws_access_key_id": AWS_ACCESS_KEY_ID,
-            "aws_secret_access_key": AWS_SECRET_ACCESS_KEY,
-            "region_name": AWS_REGION,
+            "aws_access_key_id": config.AWS_ACCESS_KEY_ID,
+            "aws_secret_access_key": config.AWS_SECRET_ACCESS_KEY,
+            "region_name": config.AWS_REGION,
         }
-        if AWS_ENDPOINT_URL:
-            args["endpoint_url"] = AWS_ENDPOINT_URL
+        if config.AWS_ENDPOINT_URL:
+            args["endpoint_url"] = config.AWS_ENDPOINT_URL
         _s3_client = boto3.client("s3", **args)
     return _s3_client
@@ -38,8 +28,8 @@ def _get_s3client():
 def upload_from_bytesio(key: str, bs: BytesIO, content_type="application/octet-stream"):
     bs.seek(0)

-    if LOCAL_FILE_UPLOAD:
-        file_path = os.path.join(UPLOAD_DIR, key)
+    if config.LOCAL_FILE_UPLOAD:
+        file_path = os.path.join(config.UPLOAD_DIR, key)
         file_dir = os.path.dirname(file_path)
         os.makedirs(file_dir, exist_ok=True)
         with open(file_path, "wb") as f:
@@ -47,7 +37,7 @@ def upload_from_bytesio(key: str, bs: BytesIO, content_type="application/octet-s
     else:
         _get_s3client().put_object(
-            Bucket=BUCKET,
+            Bucket=config.BUCKET,
             Key=key,
             Body=bs,
             ContentType=content_type,
@@ -57,8 +47,8 @@ def upload_email_from_bytesio(path: str, bs: BytesIO, filename):
     bs.seek(0)

-    if LOCAL_FILE_UPLOAD:
-        file_path = os.path.join(UPLOAD_DIR, path)
+    if config.LOCAL_FILE_UPLOAD:
+        file_path = os.path.join(config.UPLOAD_DIR, path)
         file_dir = os.path.dirname(file_path)
         os.makedirs(file_dir, exist_ok=True)
         with open(file_path, "wb") as f:
@@ -66,7 +56,7 @@ def upload_email_from_bytesio(path: str, bs: BytesIO, filename):
     else:
         _get_s3client().put_object(
-            Bucket=BUCKET,
+            Bucket=config.BUCKET,
             Key=path,
             Body=bs,
             # Support saving a remote file using Http header
@@ -77,12 +67,12 @@ def upload_email_from_bytesio(path: str, bs: BytesIO, filename):
 def download_email(path: str) -> Optional[str]:
-    if LOCAL_FILE_UPLOAD:
-        file_path = os.path.join(UPLOAD_DIR, path)
+    if config.LOCAL_FILE_UPLOAD:
+        file_path = os.path.join(config.UPLOAD_DIR, path)
         with open(file_path, "rb") as f:
             return f.read()
     resp = _get_s3client().get_object(
-        Bucket=BUCKET,
+        Bucket=config.BUCKET,
         Key=path,
     )
     if not resp or "Body" not in resp:
@@ -96,29 +86,30 @@ def upload_from_url(url: str, upload_path):
 def get_url(key: str, expires_in=3600) -> str:
-    if LOCAL_FILE_UPLOAD:
-        return URL + "/static/upload/" + key
+    if config.LOCAL_FILE_UPLOAD:
+        return config.URL + "/static/upload/" + key
     else:
         return _get_s3client().generate_presigned_url(
             ExpiresIn=expires_in,
             ClientMethod="get_object",
-            Params={"Bucket": BUCKET, "Key": key},
+            Params={"Bucket": config.BUCKET, "Key": key},
         )


 def delete(path: str):
-    if LOCAL_FILE_UPLOAD:
-        os.remove(os.path.join(UPLOAD_DIR, path))
+    if config.LOCAL_FILE_UPLOAD:
+        file_path = os.path.join(config.UPLOAD_DIR, path)
+        os.remove(file_path)
     else:
-        _get_s3client().delete_object(Bucket=BUCKET, Key=path)
+        _get_s3client().delete_object(Bucket=config.BUCKET, Key=path)


 def create_bucket_if_not_exists():
     s3client = _get_s3client()
     buckets = s3client.list_buckets()
     for bucket in buckets["Buckets"]:
-        if bucket["Name"] == BUCKET:
+        if bucket["Name"] == config.BUCKET:
             LOG.i("Bucket already exists")
             return
-    s3client.create_bucket(Bucket=BUCKET)
-    LOG.i(f"Bucket {BUCKET} created")
+    s3client.create_bucket(Bucket=config.BUCKET)
+    LOG.i(f"Bucket {config.BUCKET} created")

View File

@@ -2,6 +2,8 @@ import requests
 from requests import RequestException

 from app import config
+from app.events.event_dispatcher import EventDispatcher
+from app.events.generated.event_pb2 import EventContent, UserPlanChange
 from app.log import LOG
 from app.models import User
@@ -31,3 +33,6 @@ def execute_subscription_webhook(user: User):
         )
     except RequestException as e:
         LOG.error(f"Subscription request exception: {e}")
+
+    event = UserPlanChange(plan_end_time=sl_subscription_end)
+    EventDispatcher.send_event(user, EventContent(user_plan_change=event))

cron.py
View File

@@ -5,7 +5,7 @@ from typing import List, Tuple
 import arrow
 import requests
-from sqlalchemy import func, desc, or_, and_, nullsfirst
+from sqlalchemy import func, desc, or_, and_
 from sqlalchemy.ext.compiler import compiles
 from sqlalchemy.orm import joinedload
 from sqlalchemy.orm.exc import ObjectDeletedError
@@ -61,6 +61,9 @@ from app.pgp_utils import load_public_key_and_check, PGPException
 from app.proton.utils import get_proton_partner
 from app.utils import sanitize_email
 from server import create_light_app
+from tasks.cleanup_old_imports import cleanup_old_imports
+from tasks.cleanup_old_jobs import cleanup_old_jobs
+from tasks.cleanup_old_notifications import cleanup_old_notifications

 DELETE_GRACE_DAYS = 30
@@ -976,6 +979,9 @@ async def _hibp_check(api_key, queue):
             continue
         user = alias.user
         if user.disabled or not user.is_paid():
+            # Mark it as hibp done to skip it as if it had been checked
+            alias.hibp_last_check = arrow.utcnow()
+            Session.commit()
             continue

         LOG.d("Checking HIBP for %s", alias)
@@ -1030,6 +1036,61 @@ async def _hibp_check(api_key, queue):
         await asyncio.sleep(rate_sleep)


+def get_alias_to_check_hibp(
+    oldest_hibp_allowed: arrow.Arrow,
+    user_ids_to_skip: list[int],
+    min_alias_id: int,
+    max_alias_id: int,
+):
+    now = arrow.now()
+    alias_query = (
+        Session.query(Alias)
+        .join(User, User.id == Alias.user_id)
+        .join(Subscription, User.id == Subscription.user_id, isouter=True)
+        .join(ManualSubscription, User.id == ManualSubscription.user_id, isouter=True)
+        .join(AppleSubscription, User.id == AppleSubscription.user_id, isouter=True)
+        .join(
+            CoinbaseSubscription,
+            User.id == CoinbaseSubscription.user_id,
+            isouter=True,
+        )
+        .join(PartnerUser, User.id == PartnerUser.user_id, isouter=True)
+        .join(
+            PartnerSubscription,
+            PartnerSubscription.partner_user_id == PartnerUser.id,
+            isouter=True,
+        )
+        .filter(
+            or_(
+                Alias.hibp_last_check.is_(None),
+                Alias.hibp_last_check < oldest_hibp_allowed,
+            ),
+            Alias.user_id.notin_(user_ids_to_skip),
+            Alias.enabled,
+            Alias.id >= min_alias_id,
+            Alias.id < max_alias_id,
+            User.disabled == False,  # noqa: E712
+            User.enable_data_breach_check,
+            or_(
+                User.lifetime,
+                ManualSubscription.end_at > now,
+                Subscription.next_bill_date > now.date(),
+                AppleSubscription.expires_date > now,
+                CoinbaseSubscription.end_at > now,
+                PartnerSubscription.end_at > now,
+            ),
+        )
+    )
+    if config.HIBP_SKIP_PARTNER_ALIAS:
+        alias_query = alias_query.filter(
+            Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0
+        )
+    for alias in (
+        alias_query.order_by(Alias.id.asc()).enable_eagerloads(False).yield_per(500)
+    ):
+        yield alias
+
+
 async def check_hibp():
     """
     Check all aliases on the HIBP (Have I Been Pwned) API
@@ -1056,43 +1117,43 @@ async def check_hibp():
     user_ids = [row[0] for row in rows]
     LOG.d("Got %d users to skip" % len(user_ids))

-    LOG.d("Preparing list of aliases to check")
+    LOG.d("Checking aliases")
     queue = asyncio.Queue()
-    max_date = arrow.now().shift(days=-config.HIBP_SCAN_INTERVAL_DAYS)
-    alias_query = Alias.filter(
-        or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date),
-        Alias.user_id.notin_(user_ids),
-        Alias.enabled,
-    )
-    if config.HIBP_SKIP_PARTNER_ALIAS:
-        alias_query = alias_query(Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0)
-    for alias in (
-        alias_query.order_by(nullsfirst(Alias.hibp_last_check.asc()), Alias.id.asc())
-        .yield_per(500)
-        .enable_eagerloads(False)
-    ):
-        await queue.put(alias.id)
-
-    LOG.d("Need to check about %s aliases", queue.qsize())
-
-    # Start one checking process per API key
-    # Each checking process will take one alias from the queue, get the info
-    # and then sleep for 1.5 seconds (due to HIBP API request limits)
-    checkers = []
-    for i in range(len(config.HIBP_API_KEYS)):
-        checker = asyncio.create_task(
-            _hibp_check(
-                config.HIBP_API_KEYS[i],
-                queue,
-            )
-        )
-        checkers.append(checker)
-
-    # Wait until all checking processes are done
-    for checker in checkers:
-        await checker
-
-    LOG.d("Done checking HIBP API for aliases in breaches")
+    min_alias_id = 0
+    max_alias_id = Session.query(func.max(Alias.id)).scalar()
+    step = 10000
+    now = arrow.now()
+    oldest_hibp_allowed = now.shift(days=-config.HIBP_SCAN_INTERVAL_DAYS)
+    alias_checked = 0
+    for alias_batch_id in range(min_alias_id, max_alias_id, step):
+        for alias in get_alias_to_check_hibp(
+            oldest_hibp_allowed, user_ids, alias_batch_id, alias_batch_id + step
+        ):
+            await queue.put(alias.id)
+
+        alias_checked += queue.qsize()
+        LOG.d(
+            f"Need to check about {queue.qsize()} aliases in this loop {alias_batch_id}/{max_alias_id}"
+        )
+
+        # Start one checking process per API key
+        # Each checking process will take one alias from the queue, get the info
+        # and then sleep for 1.5 seconds (due to HIBP API request limits)
+        checkers = []
+        for i in range(len(config.HIBP_API_KEYS)):
+            checker = asyncio.create_task(
+                _hibp_check(
+                    config.HIBP_API_KEYS[i],
+                    queue,
+                )
+            )
+            checkers.append(checker)
+
+        # Wait until all checking processes are done
+        for checker in checkers:
+            await checker
+
+    LOG.d(f"Done checking {alias_checked} HIBP API for aliases in breaches")


 def notify_hibp():
@@ -1164,6 +1225,13 @@ def clear_users_scheduled_to_be_deleted(dry_run=False):
     Session.commit()


+def delete_old_data():
+    oldest_valid = arrow.now().shift(days=-config.KEEP_OLD_DATA_DAYS)
+    cleanup_old_imports(oldest_valid)
+    cleanup_old_jobs(oldest_valid)
+    cleanup_old_notifications(oldest_valid)
+
+
 if __name__ == "__main__":
     LOG.d("Start running cronjob")
     parser = argparse.ArgumentParser()
@@ -1178,6 +1246,7 @@ if __name__ == "__main__":
             "notify_manual_subscription_end",
             "notify_premium_end",
             "delete_logs",
+            "delete_old_data",
             "poll_apple_subscription",
             "sanity_check",
             "delete_old_monitoring",
@@ -1206,6 +1275,9 @@ if __name__ == "__main__":
     elif args.job == "delete_logs":
         LOG.d("Deleted Logs")
         delete_logs()
+    elif args.job == "delete_old_data":
+        LOG.d("Delete old data")
+        delete_old_data()
     elif args.job == "poll_apple_subscription":
         LOG.d("Poll Apple Subscriptions")
         poll_apple_subscription()
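The reworked `check_hibp` above no longer streams one unbounded query; it walks the alias id space in fixed-size ranges so each database query stays bounded. A minimal sketch of that batched-scan pattern (plain lists stand in for the alias table and the filtered query):

```python
def scan_in_batches(min_id: int, max_id: int, step: int, fetch):
    """Yield rows range-by-range; `fetch(lo, hi)` stands in for the
    id-bounded query get_alias_to_check_hibp runs per batch."""
    for batch_start in range(min_id, max_id, step):
        yield from fetch(batch_start, batch_start + step)


rows = list(range(0, 25))  # pretend alias ids


def fetch(lo: int, hi: int):
    # Equivalent of WHERE id >= lo AND id < hi
    return [r for r in rows if lo <= r < hi]


batched = list(scan_in_batches(0, 25, 10, fetch))
print(batched == rows)  # True: three bounded queries cover the same rows
```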

View File

@@ -37,6 +37,12 @@ jobs:
     schedule: "15 5 * * *"
     captureStderr: true

+  - name: SimpleLogin Delete Old data
+    command: python /code/cron.py -j delete_old_data
+    shell: /bin/bash
+    schedule: "30 5 * * *"
+    captureStderr: true
+
   - name: SimpleLogin Poll Apple Subscriptions
     command: python /code/cron.py -j poll_apple_subscription
     shell: /bin/bash

View File

@@ -53,7 +53,7 @@ from flanker.addresslib.address import EmailAddress
 from sqlalchemy.exc import IntegrityError

 from app import pgp_utils, s3, config
-from app.alias_utils import try_auto_create
+from app.alias_utils import try_auto_create, change_alias_status
 from app.config import (
     EMAIL_DOMAIN,
     URL,
@@ -875,6 +875,7 @@ def forward_email_to_mailbox(
         # References and In-Reply-To are used for keeping the email thread
         headers.REFERENCES,
         headers.IN_REPLY_TO,
+        headers.SL_QUEUE_ID,
         headers.LIST_UNSUBSCRIBE,
         headers.LIST_UNSUBSCRIBE_POST,
     ] + headers.MIME_HEADERS
@@ -1179,6 +1180,7 @@ def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
             # References and In-Reply-To are used for keeping the email thread
             headers.REFERENCES,
             headers.IN_REPLY_TO,
+            headers.SL_QUEUE_ID,
         ]
         + headers.MIME_HEADERS,
     )
@@ -1583,7 +1585,7 @@ def handle_bounce_forward_phase(msg: Message, email_log: EmailLog):
     LOG.w(
         f"Disable alias {alias} because {reason}. {alias.mailboxes} {alias.user}. Last contact {contact}"
     )
-    alias.enabled = False
+    change_alias_status(alias, enabled=False)

     Notification.create(
         user_id=user.id,
@@ -2040,10 +2042,11 @@ def handle(envelope: Envelope, msg: Message) -> str:
         return status.E204

     # sanitize email headers
-    sanitize_header(msg, "from")
-    sanitize_header(msg, "to")
-    sanitize_header(msg, "cc")
-    sanitize_header(msg, "reply-to")
+    sanitize_header(msg, headers.FROM)
+    sanitize_header(msg, headers.TO)
+    sanitize_header(msg, headers.CC)
+    sanitize_header(msg, headers.REPLY_TO)
+    sanitize_header(msg, headers.MESSAGE_ID)

     LOG.d(
         "==>> Handle mail_from:%s, rcpt_tos:%s, header_from:%s, header_to:%s, "

event_listener.py Normal file
View File

@@ -0,0 +1,64 @@
import argparse
from enum import Enum
from sys import argv, exit

from app.config import DB_URI
from app.log import LOG
from events.runner import Runner
from events.event_source import DeadLetterEventSource, PostgresEventSource
from events.event_sink import ConsoleEventSink, HttpEventSink


class Mode(Enum):
    DEAD_LETTER = "dead_letter"
    LISTENER = "listener"

    @staticmethod
    def from_str(value: str):
        if value == Mode.DEAD_LETTER.value:
            return Mode.DEAD_LETTER
        elif value == Mode.LISTENER.value:
            return Mode.LISTENER
        else:
            raise ValueError(f"Invalid mode: {value}")


def main(mode: Mode, dry_run: bool):
    if mode == Mode.DEAD_LETTER:
        LOG.i("Using DeadLetterEventSource")
        source = DeadLetterEventSource()
    elif mode == Mode.LISTENER:
        LOG.i("Using PostgresEventSource")
        source = PostgresEventSource(DB_URI)
    else:
        raise ValueError(f"Invalid mode: {mode}")

    if dry_run:
        LOG.i("Starting with ConsoleEventSink")
        sink = ConsoleEventSink()
    else:
        LOG.i("Starting with HttpEventSink")
        sink = HttpEventSink()

    runner = Runner(source=source, sink=sink)
    runner.run()


def args():
    parser = argparse.ArgumentParser(description="Run event listener")
    parser.add_argument(
        "mode",
        help="Mode to run",
        choices=[Mode.DEAD_LETTER.value, Mode.LISTENER.value],
    )
    parser.add_argument("--dry-run", help="Dry run mode", action="store_true")
    return parser.parse_args()


if __name__ == "__main__":
    if len(argv) < 2:
        print("Invalid usage. Pass 'listener' or 'dead_letter' as argument")
        exit(1)

    args = args()
    main(Mode.from_str(args.mode), args.dry_run)

events/__init__.py Normal file
View File

events/event_sink.py Normal file
View File

@@ -0,0 +1,42 @@
import requests
from abc import ABC, abstractmethod

from app.config import EVENT_WEBHOOK, EVENT_WEBHOOK_SKIP_VERIFY_SSL
from app.log import LOG
from app.models import SyncEvent


class EventSink(ABC):
    @abstractmethod
    def process(self, event: SyncEvent) -> bool:
        pass


class HttpEventSink(EventSink):
    def process(self, event: SyncEvent) -> bool:
        if not EVENT_WEBHOOK:
            LOG.warning("Skipping sending event because there is no webhook configured")
            return False

        LOG.info(f"Sending event {event.id} to {EVENT_WEBHOOK}")

        res = requests.post(
            url=EVENT_WEBHOOK,
            data=event.content,
            headers={"Content-Type": "application/x-protobuf"},
            verify=not EVENT_WEBHOOK_SKIP_VERIFY_SSL,
        )
        if res.status_code != 200:
            LOG.warning(
                f"Failed to send event to webhook: {res.status_code} {res.text}"
            )
            return False
        else:
            LOG.info(f"Event {event.id} sent successfully to webhook")
            return True


class ConsoleEventSink(EventSink):
    def process(self, event: SyncEvent) -> bool:
        LOG.info(f"Handling event {event.id}")
        return True
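`HttpEventSink` above POSTs the raw protobuf bytes and treats anything other than HTTP 200 as a failure (so the event stays pending and is retried by the dead-letter path). A self-contained sketch of the receiving side, using only the standard library; the endpoint path and payload bytes are illustrative:

```python
import threading
import urllib.request
from http.server import BaseHTTPRequestHandler, HTTPServer


class WebhookHandler(BaseHTTPRequestHandler):
    received = []

    def do_POST(self):
        # Consume the protobuf body and acknowledge with 200 so the sink's
        # process() returns True and the event gets deleted.
        body = self.rfile.read(int(self.headers["Content-Length"]))
        WebhookHandler.received.append(body)
        self.send_response(200)
        self.end_headers()

    def log_message(self, *args):  # silence per-request logging
        pass


server = HTTPServer(("127.0.0.1", 0), WebhookHandler)  # ephemeral port
threading.Thread(target=server.serve_forever, daemon=True).start()

url = f"http://127.0.0.1:{server.server_port}/webhook"
req = urllib.request.Request(
    url, data=b"\x08\x01", headers={"Content-Type": "application/x-protobuf"}
)
with urllib.request.urlopen(req) as resp:
    print(resp.status)  # 200

server.shutdown()
```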

events/event_source.py Normal file
View File

@@ -0,0 +1,100 @@
import arrow
import newrelic.agent
import psycopg2
import select

from abc import ABC, abstractmethod
from app.log import LOG
from app.models import SyncEvent
from app.events.event_dispatcher import NOTIFICATION_CHANNEL
from time import sleep
from typing import Callable, NoReturn

_DEAD_LETTER_THRESHOLD_MINUTES = 10
_DEAD_LETTER_INTERVAL_SECONDS = 30

_POSTGRES_RECONNECT_INTERVAL_SECONDS = 5


class EventSource(ABC):
    @abstractmethod
    def run(self, on_event: Callable[[SyncEvent], NoReturn]):
        pass


class PostgresEventSource(EventSource):
    def __init__(self, connection_string: str):
        self.__connection_string = connection_string
        self.__connect()

    def run(self, on_event: Callable[[SyncEvent], NoReturn]):
        while True:
            try:
                self.__listen(on_event)
            except Exception as e:
                LOG.warn(f"Error listening to events: {e}")
                sleep(_POSTGRES_RECONNECT_INTERVAL_SECONDS)
                self.__connect()

    def __listen(self, on_event: Callable[[SyncEvent], NoReturn]):
        self.__connection.set_isolation_level(
            psycopg2.extensions.ISOLATION_LEVEL_AUTOCOMMIT
        )

        cursor = self.__connection.cursor()
        cursor.execute(f"LISTEN {NOTIFICATION_CHANNEL};")

        while True:
            if select.select([self.__connection], [], [], 5) != ([], [], []):
                self.__connection.poll()
                while self.__connection.notifies:
                    notify = self.__connection.notifies.pop(0)
                    LOG.debug(
                        f"Got NOTIFY: pid={notify.pid} channel={notify.channel} payload={notify.payload}"
                    )
                    try:
                        webhook_id = int(notify.payload)
                        event = SyncEvent.get_by(id=webhook_id)
                        if event is not None:
                            if event.mark_as_taken():
                                on_event(event)
                            else:
                                LOG.info(
                                    f"Event {event.id} was handled by another runner"
                                )
                        else:
                            LOG.info(f"Could not find event with id={notify.payload}")
                    except Exception as e:
                        LOG.warn(f"Error getting event: {e}")

    def __connect(self):
        self.__connection = psycopg2.connect(self.__connection_string)

        from app.db import Session

        Session.close()


class DeadLetterEventSource(EventSource):
    @newrelic.agent.background_task()
    def run(self, on_event: Callable[[SyncEvent], NoReturn]):
        while True:
            try:
                threshold = arrow.utcnow().shift(
                    minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
                )
                events = SyncEvent.get_dead_letter(older_than=threshold)
                if events:
                    LOG.info(f"Got {len(events)} dead letter events")
                    if events:
                        newrelic.agent.record_custom_metric(
                            "Custom/dead_letter_events_to_process", len(events)
                        )
                    for event in events:
                        on_event(event)
                else:
                    LOG.debug("No dead letter events")
                    sleep(_DEAD_LETTER_INTERVAL_SECONDS)
            except Exception as e:
                LOG.warn(f"Error getting dead letter event: {e}")
                sleep(_DEAD_LETTER_INTERVAL_SECONDS)

events/runner.py Normal file
View File

@@ -0,0 +1,42 @@
import arrow
import newrelic.agent

from app.log import LOG
from app.models import SyncEvent
from events.event_sink import EventSink
from events.event_source import EventSource


class Runner:
    def __init__(self, source: EventSource, sink: EventSink):
        self.__source = source
        self.__sink = sink

    def run(self):
        self.__source.run(self.__on_event)

    @newrelic.agent.background_task()
    def __on_event(self, event: SyncEvent):
        try:
            event_created_at = event.created_at
            start_time = arrow.now()
            success = self.__sink.process(event)
            if success:
                event_id = event.id
                SyncEvent.delete(event.id, commit=True)
                LOG.info(f"Marked {event_id} as done")

                end_time = arrow.now() - start_time
                time_between_taken_and_created = start_time - event_created_at

                newrelic.agent.record_custom_metric("Custom/sync_event_processed", 1)
                newrelic.agent.record_custom_metric(
                    "Custom/sync_event_process_time", end_time.total_seconds()
                )
                newrelic.agent.record_custom_metric(
                    "Custom/sync_event_elapsed_time",
                    time_between_taken_and_created.total_seconds(),
                )
        except Exception as e:
            LOG.warn(f"Exception processing event [id={event.id}]: {e}")
            newrelic.agent.record_custom_metric("Custom/sync_event_failed", 1)


@@ -118,9 +118,7 @@ WORDS_FILE_PATH=local_data/test_words.txt
# Login with OIDC
# CONNECT_WITH_OIDC_ICON=fa-github
-# OIDC_AUTHORIZATION_URL=to_fill
-# OIDC_USER_INFO_URL=to_fill
-# OIDC_TOKEN_URL=to_fill
+# OIDC_WELL_KNOWN_URL=to_fill
# OIDC_SCOPES=openid email profile
# OIDC_NAME_FIELD=name
# OIDC_CLIENT_ID=to_fill


@@ -0,0 +1,29 @@
"""empty message
Revision ID: fa2f19bb4e5a
Revises: 52510a633d6f
Create Date: 2024-04-09 13:12:26.305340
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = 'fa2f19bb4e5a'
down_revision = '52510a633d6f'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.add_column('users', sa.Column('enable_data_breach_check', sa.Boolean(), server_default='0', nullable=False))
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_column('users', 'enable_data_breach_check')
# ### end Alembic commands ###


@@ -0,0 +1,38 @@
"""Create sync_event table
Revision ID: 06a9a7133445
Revises: fa2f19bb4e5a
Create Date: 2024-05-17 13:11:20.402259
"""
import sqlalchemy_utils
from alembic import op
import sqlalchemy as sa
# revision identifiers, used by Alembic.
revision = '06a9a7133445'
down_revision = 'fa2f19bb4e5a'
branch_labels = None
depends_on = None
def upgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.create_table('sync_event',
sa.Column('id', sa.Integer(), autoincrement=True, nullable=False),
sa.Column('created_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=False),
sa.Column('updated_at', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.Column('content', sa.LargeBinary(), nullable=False),
sa.Column('taken_time', sqlalchemy_utils.types.arrow.ArrowType(), nullable=True),
sa.PrimaryKeyConstraint('id')
)
op.create_index(op.f('ix_sync_event_created_at'), 'sync_event', ['created_at'], unique=False)
op.create_index(op.f('ix_sync_event_taken_time'), 'sync_event', ['taken_time'], unique=False)
# ### end Alembic commands ###
def downgrade():
# ### commands auto generated by Alembic - please adjust! ###
op.drop_table('sync_event')
# ### end Alembic commands ###


@@ -4,6 +4,7 @@ import subprocess
from time import sleep
from typing import List, Dict
import arrow
import newrelic.agent
from app.db import Session
@@ -93,11 +94,44 @@ def log_nb_db_connection():
newrelic.agent.record_custom_metric("Custom/nb_db_connections", nb_connection)
@newrelic.agent.background_task()
def log_pending_to_process_events():
r = Session.execute("select count(*) from sync_event WHERE taken_time IS NULL;")
events_pending = list(r)[0][0]
LOG.d("number of events pending to process %s", events_pending)
newrelic.agent.record_custom_metric(
"Custom/sync_events_pending_to_process", events_pending
)
@newrelic.agent.background_task()
def log_events_pending_dead_letter():
since = arrow.now().shift(minutes=-10).datetime
r = Session.execute(
"""
SELECT COUNT(*)
FROM sync_event
WHERE (taken_time IS NOT NULL AND taken_time < :since)
OR (taken_time IS NULL AND created_at < :since)
""",
{"since": since},
)
events_pending = list(r)[0][0]
LOG.d("number of events pending dead letter %s", events_pending)
newrelic.agent.record_custom_metric(
"Custom/sync_events_pending_dead_letter", events_pending
)
if __name__ == "__main__":
    exporter = MetricExporter(get_newrelic_license())
    while True:
        log_postfix_metrics()
        log_nb_db_connection()
        log_pending_to_process_events()
        log_events_pending_dead_letter()
        Session.close()
        exporter.run()


@@ -0,0 +1,37 @@
#!/usr/bin/env python3
import argparse
import random
import time
from sqlalchemy import func
from app import config
from app.models import Alias, Contact
from app.db import Session
parser = argparse.ArgumentParser(
prog=f"Replace {config.NOREPLY}",
description=f"Replace {config.NOREPLY} from contacts reply email",
)
args = parser.parse_args()
max_alias_id: int = Session.query(func.max(Alias.id)).scalar()
start = time.time()
tests = 1000
for i in range(tests):
alias = (
Alias.filter(Alias.id > int(random.random() * max_alias_id))
.order_by(Alias.id.asc())
.limit(1)
.first()
)
contact = Contact.filter_by(alias_id=alias.id).order_by(Contact.id.asc()).first()
mailboxes = alias.mailboxes
user = alias.user
if i % 10 == 0:
print(f"{i} -> {alias.id}")
end = time.time()
time_taken = end - start
print(f"Took {time_taken} -> {time_taken/tests} per test")


@@ -1,29 +1,56 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.log import LOG
from app.models import Alias, SLDomain
from app.db import Session
parser = argparse.ArgumentParser(
    prog="Mark partner created aliases with the PARTNER_CREATED flag",
)
parser.add_argument(
    "-s", "--start_alias_id", default=0, type=int, help="Initial alias_id"
)
parser.add_argument("-e", "--end_alias_id", default=0, type=int, help="Last alias_id")
args = parser.parse_args()
alias_id_start = args.start_alias_id
max_alias_id = args.end_alias_id
if max_alias_id == 0:
    max_alias_id = Session.query(func.max(Alias.id)).scalar()
print(f"Updating aliases from {alias_id_start} to {max_alias_id}")
domains = SLDomain.filter(SLDomain.partner_id.isnot(None)).all()
cond = [f"email like '%{domain.domain}'" for domain in domains]
sql_or_cond = " OR ".join(cond)
sql = f"UPDATE alias set flags = (flags | :flag) WHERE id >= :start and id<:end and flags & :flag = 0 and ({sql_or_cond})"
print(sql)
-for domain in domains:
-    LOG.i(f"Checking aliases for domain {domain.domain}")
-    for alias in (
-        Alias.filter(
-            Alias.email.like(f"%{domain.domain}"),
-            Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0,
-        )
-        .enable_eagerloads(False)
-        .yield_per(100)
-        .all()
-    ):
-        alias.flags = alias.flags | Alias.FLAG_PARTNER_CREATED
-        LOG.i(f" * Updating {alias.email} to {alias.flags}")
-        Session.commit()
+step = 1000
+updated = 0
+start_time = time.time()
+for batch_start in range(alias_id_start, max_alias_id, step):
+    updated += Session.execute(
+        sql,
+        {
+            "start": batch_start,
+            "end": batch_start + step,
+            "flag": Alias.FLAG_PARTNER_CREATED,
+        },
+    ).rowcount
+    elapsed = time.time() - start_time
+    time_per_alias = elapsed / (batch_start - alias_id_start + step)
+    last_batch_id = batch_start + step
+    remaining = max_alias_id - last_batch_id
+    time_remaining = (max_alias_id - last_batch_id) * time_per_alias
+    hours_remaining = time_remaining / 3600.0
+    percent = int(
+        ((batch_start - alias_id_start) * 100) / (max_alias_id - alias_id_start)
+    )
+    print(
+        f"\rAlias {batch_start}/{max_alias_id} {percent}% {updated} updated {hours_remaining:.2f}hrs remaining"
+    )
+print(f"Updated aliases up to {max_alias_id}")
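The progress arithmetic in the script above (percent done and an ETA derived from ids processed per second) generalizes to any id-ranged batch job; a small self-contained sketch with illustrative names:

```python
def progress(batch_start, start_id, end_id, step, elapsed_seconds):
    """Estimate completion percentage and remaining hours for a ranged batch job."""
    processed = batch_start - start_id + step           # ids covered so far
    time_per_id = elapsed_seconds / processed           # seconds per id
    remaining_ids = end_id - (batch_start + step)       # ids still to cover
    percent = int((batch_start - start_id) * 100 / (end_id - start_id))
    hours_remaining = remaining_ids * time_per_id / 3600.0
    return percent, hours_remaining
```

Note the estimate assumes ids are roughly uniformly distributed; sparse id ranges will make early ETAs pessimistic.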

proto/event.proto

@@ -0,0 +1,45 @@
syntax = "proto3";
package simplelogin_events;
message UserPlanChange {
uint32 plan_end_time = 1;
}
message UserDeleted {
}
message AliasCreated {
uint32 alias_id = 1;
string alias_email = 2;
string alias_note = 3;
bool enabled = 4;
}
message AliasStatusChange {
uint32 alias_id = 1;
string alias_email = 2;
bool enabled = 3;
}
message AliasDeleted {
uint32 alias_id = 1;
string alias_email = 2;
}
message EventContent {
oneof content {
UserPlanChange user_plan_change = 1;
UserDeleted user_deleted = 2;
AliasCreated alias_created = 3;
AliasStatusChange alias_status_change = 4;
AliasDeleted alias_deleted = 5;
}
}
message Event {
uint32 user_id = 1;
string external_user_id = 2;
uint32 partner_id = 3;
EventContent content = 4;
}
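The `Event` envelope carries identity fields plus exactly one payload through the `oneof`. A plain-Python illustration of how a consumer might dispatch on the payload type (stand-in dataclasses, not the protoc-generated classes, which would use `WhichOneof()`):

```python
from dataclasses import dataclass

@dataclass
class AliasCreated:
    alias_id: int
    alias_email: str
    alias_note: str
    enabled: bool

@dataclass
class AliasDeleted:
    alias_id: int
    alias_email: str

@dataclass
class Event:
    user_id: int
    external_user_id: str
    partner_id: int
    content: object  # exactly one payload, mirroring the proto oneof

def describe(event: Event) -> str:
    # Dispatch on the concrete payload type, as a consumer of the
    # generated oneof accessor would.
    if isinstance(event.content, AliasCreated):
        return f"alias {event.content.alias_email} created"
    if isinstance(event.content, AliasDeleted):
        return f"alias {event.content.alias_email} deleted"
    return "unknown event"
```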


@@ -14,13 +14,14 @@ exclude = '''
  | build
  | dist
  | migrations # migrations/ is generated by alembic
  | app/events/generated
)/
)
'''

[tool.ruff]
ignore-init-module-imports = true
-exclude = [".venv", "migrations"]
+exclude = [".venv", "migrations", "app/events/generated"]

[tool.djlint]
indent = 2

scripts/generate-proto-files.sh

@@ -0,0 +1,24 @@
#!/bin/bash
set -euxo pipefail
SCRIPT_DIR="$(cd "$(dirname "$0")" || exit 1; pwd -P)"
REPO_ROOT=$(echo "${SCRIPT_DIR}" | sed 's:scripts::g')
DEST_DIR="${REPO_ROOT}/app/events/generated"
PROTOC=${PROTOC:-"protoc"}
if ! eval "${PROTOC} --version" &> /dev/null ; then
echo "Cannot find $PROTOC"
exit 1
fi
rm -rf "${DEST_DIR}"
mkdir -p "${DEST_DIR}"
pushd $REPO_ROOT || exit 1
eval "${PROTOC} --proto_path=proto --python_out=\"${DEST_DIR}\" --pyi_out=\"${DEST_DIR}\" proto/event.proto"
popd || exit 1

tasks/__init__.py


@@ -0,0 +1,19 @@
import arrow
from app import s3
from app.log import LOG
from app.models import BatchImport
def cleanup_old_imports(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting imports older than {oldest_allowed}")
for batch_import in (
BatchImport.filter(BatchImport.created_at < oldest_allowed).yield_per(500).all()
):
LOG.i(
f"Deleting batch import {batch_import} with file {batch_import.file.path}"
)
file = batch_import.file
if file is not None:
s3.delete(file.path)
BatchImport.delete(batch_import.id, commit=True)

tasks/cleanup_old_jobs.py

@@ -0,0 +1,24 @@
import arrow
from sqlalchemy import or_, and_
from app import config
from app.db import Session
from app.log import LOG
from app.models import Job, JobState
def cleanup_old_jobs(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting jobs older than {oldest_allowed}")
count = Job.filter(
or_(
Job.state == JobState.done.value,
Job.state == JobState.error.value,
and_(
Job.state == JobState.taken.value,
Job.attempts >= config.JOB_MAX_ATTEMPTS,
),
),
Job.updated_at < oldest_allowed,
).delete()
Session.commit()
LOG.i(f"Deleted {count} jobs")


@@ -0,0 +1,12 @@
import arrow
from app.db import Session
from app.log import LOG
from app.models import Notification
def cleanup_old_notifications(oldest_allowed: arrow.Arrow):
LOG.i(f"Deleting notifications older than {oldest_allowed}")
count = Notification.filter(Notification.created_at < oldest_allowed).delete()
Session.commit()
LOG.i(f"Deleted {count} notifications")


@@ -120,21 +120,6 @@
</div>
</div>
<!-- END WebAuthn -->
<!-- Alias import/export -->
<div class="card">
<div class="card-body">
<div class="card-title">Alias import/export</div>
<div class="mb-3">
You can import your aliases created on other platforms into SimpleLogin.
You can also export your aliases to a readable csv format for a future batch import.
</div>
<a href="{{ url_for('dashboard.batch_import_route') }}"
class="btn btn-outline-primary">Batch Import</a>
<a href="{{ url_for('dashboard.alias_export_route') }}"
class="btn btn-outline-secondary">Export Aliases</a>
</div>
</div>
<!-- END Alias import/export -->
<!-- data export -->
<div class="card">
<div class="card-body">


@@ -249,6 +249,42 @@
</div>
</div>
<!-- END Random alias -->
<!-- Data breach check -->
<div class="card" id="data-breach">
<div class="card-body">
<div class="card-title">Data breach monitoring</div>
<div class="mt-1 mb-3">
{% if not current_user.is_premium() %}
<div class="alert alert-info" role="alert">
This feature is only available on Premium plan.
<a href="{{ url_for('dashboard.pricing') }}"
target="_blank"
rel="noopener noreferrer">
Upgrade<i class="fe fe-external-link"></i>
</a>
</div>
{% endif %}
If enabled, we will inform you via email if one of your aliases appears in a data breach.
<br>
SimpleLogin uses <a href="https://haveibeenpwned.com/">HaveIBeenPwned</a> API for checking for data breaches.
</div>
<form method="post" action="#data-breach">
{{ csrf_form.csrf_token }}
<input type="hidden" name="form-name" value="enable_data_breach_check">
<div class="form-check">
<input type="checkbox"
id="enable_data_breach_check"
name="enable_data_breach_check"
{% if current_user.enable_data_breach_check %} checked{% endif %}
class="form-check-input">
<label for="enable_data_breach_check">Enable data breach monitoring</label>
</div>
<button type="submit" class="btn btn-outline-primary">Update</button>
</form>
</div>
</div>
<!-- END Data breach check -->
<!-- Sender Format -->
<div class="card" id="sender-format">
<div class="card-body">
@@ -285,7 +321,9 @@
No Name (i.e. only reverse-alias)
</option>
</select>
-<button class="btn btn-outline-primary mt-3">Update</button>
+<button class="btn btn-outline-primary mt-3">
+Update
+</button>
</form>
</div>
</div>
@@ -295,7 +333,9 @@
<div class="card-body">
<div class="card-title">
Reverse Alias Replacement
-<div class="badge badge-warning">Experimental</div>
+<div class="badge badge-warning">
+Experimental
+</div>
</div>
<div class="mb-3">
When replying to a forwarded email, the <b>reverse-alias</b> can be automatically included
@@ -312,9 +352,13 @@
name="replace-ra"
{% if current_user.replace_reverse_alias %} checked{% endif %}
class="form-check-input">
-<label for="replace-ra">Enable replacing reverse alias</label>
+<label for="replace-ra">
+Enable replacing reverse alias
+</label>
</div>
-<button type="submit" class="btn btn-outline-primary">Update</button>
+<button type="submit" class="btn btn-outline-primary">
+Update
+</button>
</form>
</div>
</div>
@@ -559,7 +603,7 @@
sender address.
<br />
If this option is enabled, the original sender addresses is stored in the email header <b>X-SimpleLogin-Envelope-From</b>
-and the original From header is stored in <b>X-SimpleLogin-Original-From<b>.
+and the original From header is stored in <b>X-SimpleLogin-Original-From</b>.
You can choose to display this header in your email client.
<br />
As email headers aren't encrypted, your mailbox service can know the sender address via this header.
@@ -583,6 +627,23 @@
</form>
</div>
</div>
<!-- Alias import/export -->
<div class="card">
<div class="card-body">
<div class="card-title">
Alias import/export
</div>
<div class="mb-3">
You can import your aliases created on other platforms into SimpleLogin.
You can also export your aliases to a readable csv format for a future batch import.
</div>
<a href="{{ url_for('dashboard.batch_import_route') }}"
class="btn btn-outline-primary">Batch Import</a>
<a href="{{ url_for('dashboard.alias_export_route') }}"
class="btn btn-outline-secondary">Export Aliases</a>
</div>
</div>
<!-- END Alias import/export -->
</div>
{% endblock %}
{% block script %}


@@ -48,15 +48,16 @@
{# SIWSL#}
{# </a>#}
{# </li>#}
-{# {% if current_user.should_show_app_page() %}#}
-{# <li class="nav-item">#}
-{# <a href="{{ url_for('dashboard.app_route') }}"#}
-{# class="nav-link {{ 'active' if active_page == 'app' }}">#}
-{# <i class="fe fe-grid"></i>#}
-{# Apps#}
-{# </a>#}
-{# </li>#}
-{# {% endif %}#}
+{% if current_user.should_show_app_page() %}
+<li class="nav-item">
+<a href="{{ url_for('dashboard.app_route') }}"
+class="nav-link {{ 'active' if active_page == 'app' }}">
+<i class="fe fe-grid"></i>
+Apps
+</a>
+</li>
+{% endif %}
<li class="nav-item">
<a href="{{ url_for('dashboard.setting') }}"
class="nav-link {{ 'active' if active_page == 'setting' }}">

@@ -1,5 +1,5 @@
from app import config
-from flask import url_for
+from flask import url_for, session
from urllib.parse import parse_qs
from urllib3.util import parse_url
from app.auth.views.oidc import create_user
@@ -10,7 +10,21 @@ from app.models import User
from app.config import URL, OIDC_CLIENT_ID

-def test_oidc_login(flask_client):
mock_well_known_response = {
"authorization_endpoint": "http://localhost:7777/authorization-endpoint",
"userinfo_endpoint": "http://localhost:7777/userinfo-endpoint",
"token_endpoint": "http://localhost:7777/token-endpoint",
}
@patch("requests.get")
def test_oidc_login(mock_get, flask_client):
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_login"),
follow_redirects=False,
@@ -28,8 +42,41 @@ def test_oidc_login(flask_client):
assert expected_redirect_url == query["redirect_uri"][0]
-def test_oidc_login_no_client_id(flask_client):
@patch("requests.get")
def test_oidc_login_next_url(mock_get, flask_client):
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
mock_get.return_value.json.return_value = mock_well_known_response
with flask_client:
r = flask_client.get(
url_for("auth.oidc_login", next="/dashboard/settings/"),
follow_redirects=False,
)
location = r.headers.get("Location")
assert location is not None
parsed = parse_url(location)
query = parse_qs(parsed.query)
expected_redirect_url = f"{URL}/auth/oidc/callback"
assert "code" == query["response_type"][0]
assert OIDC_CLIENT_ID == query["client_id"][0]
assert expected_redirect_url == query["redirect_uri"][0]
assert session["oauth_redirect_next"] == "/dashboard/settings/"
@patch("requests.get")
def test_oidc_login_no_client_id(mock_get, flask_client):
config.OIDC_CLIENT_ID = None
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_login"),
@@ -47,8 +94,14 @@ def test_oidc_login_no_client_id(flask_client):
config.OIDC_CLIENT_ID = "to_fill"
-def test_oidc_login_no_client_secret(flask_client):
@patch("requests.get")
def test_oidc_login_no_client_secret(mock_get, flask_client):
config.OIDC_CLIENT_SECRET = None
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_login"),
@@ -66,9 +119,14 @@ def test_oidc_login_no_client_secret(flask_client):
config.OIDC_CLIENT_SECRET = "to_fill"
-def test_oidc_callback_no_oauth_state(flask_client):
-with flask_client.session_transaction() as session:
-session["oauth_state"] = None
@patch("requests.get")
def test_oidc_callback_no_oauth_state(mock_get, flask_client):
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = None
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_callback"),
@@ -78,11 +136,16 @@ def test_oidc_callback_no_oauth_state(flask_client):
assert location is None

-def test_oidc_callback_no_client_id(flask_client):
-with flask_client.session_transaction() as session:
-session["oauth_state"] = "state"
@patch("requests.get")
def test_oidc_callback_no_client_id(mock_get, flask_client):
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = "state"
config.OIDC_CLIENT_ID = None
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
@@ -97,15 +160,20 @@ def test_oidc_callback_no_client_id(flask_client):
assert expected_redirect_url == parsed.path
config.OIDC_CLIENT_ID = "to_fill"
-with flask_client.session_transaction() as session:
-session["oauth_state"] = None
+with flask_client.session_transaction() as sess:
+sess["oauth_state"] = None
-def test_oidc_callback_no_client_secret(flask_client):
-with flask_client.session_transaction() as session:
-session["oauth_state"] = "state"
@patch("requests.get")
def test_oidc_callback_no_client_secret(mock_get, flask_client):
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = "state"
config.OIDC_CLIENT_SECRET = None
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_callback"),
follow_redirects=False,
@@ -120,16 +188,23 @@ def test_oidc_callback_no_client_secret(flask_client):
assert expected_redirect_url == parsed.path
config.OIDC_CLIENT_SECRET = "to_fill"

-with flask_client.session_transaction() as session:
-session["oauth_state"] = None
+with flask_client.session_transaction() as sess:
+sess["oauth_state"] = None
@patch("requests.get")
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
-def test_oidc_callback_invalid_user(mock_get, mock_fetch_token, flask_client):
-mock_get.return_value = MockResponse(400, {})
-with flask_client.session_transaction() as session:
-session["oauth_state"] = "state"
+def test_oidc_callback_invalid_user(
+mock_oauth_get, mock_fetch_token, mock_get, flask_client
+):
+mock_oauth_get.return_value = MockResponse(400, {})
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = "state"
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_callback"),
@@ -143,18 +218,25 @@ def test_oidc_callback_invalid_user(mock_get, mock_fetch_token, flask_client):
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
-assert mock_get.called
+assert mock_oauth_get.called

-with flask_client.session_transaction() as session:
-session["oauth_state"] = None
+with flask_client.session_transaction() as sess:
+sess["oauth_state"] = None
@patch("requests.get")
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
-def test_oidc_callback_no_email(mock_get, mock_fetch_token, flask_client):
-mock_get.return_value = MockResponse(200, {})
-with flask_client.session_transaction() as session:
-session["oauth_state"] = "state"
+def test_oidc_callback_no_email(
+mock_oauth_get, mock_fetch_token, mock_get, flask_client
+):
+mock_oauth_get.return_value = MockResponse(200, {})
config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = "state"
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_callback"),
@@ -168,20 +250,27 @@ def test_oidc_callback_no_email(mock_get, mock_fetch_token, flask_client):
expected_redirect_url = "/auth/login"
assert expected_redirect_url == parsed.path
-assert mock_get.called
+assert mock_oauth_get.called

with flask_client.session_transaction() as session:
session["oauth_state"] = None
@patch("requests.get")
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
-def test_oidc_callback_disabled_registration(mock_get, mock_fetch_token, flask_client):
+def test_oidc_callback_disabled_registration(
+mock_oauth_get, mock_fetch_token, mock_get, flask_client
+):
config.DISABLE_REGISTRATION = True
email = random_string()
-mock_get.return_value = MockResponse(200, {"email": email})
-with flask_client.session_transaction() as session:
-session["oauth_state"] = "state"
+mock_oauth_get.return_value = MockResponse(200, {"email": email})
+config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
+with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = "state"
mock_get.return_value.json.return_value = mock_well_known_response
r = flask_client.get(
url_for("auth.oidc_callback"),
@@ -195,26 +284,33 @@ def test_oidc_callback_disabled_registration(mock_get, mock_fetch_token, flask_c
expected_redirect_url = "/auth/register"
assert expected_redirect_url == parsed.path
-assert mock_get.called
+assert mock_oauth_get.called
config.DISABLE_REGISTRATION = False

-with flask_client.session_transaction() as session:
-session["oauth_state"] = None
+with flask_client.session_transaction() as sess:
+sess["oauth_state"] = None
@patch("requests.get")
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
-def test_oidc_callback_registration(mock_get, mock_fetch_token, flask_client):
+def test_oidc_callback_registration(
+mock_oauth_get, mock_fetch_token, mock_get, flask_client
+):
email = random_string()
-mock_get.return_value = MockResponse(
+mock_oauth_get.return_value = MockResponse(
200,
{
"email": email,
config.OIDC_NAME_FIELD: "name",
},
)
-with flask_client.session_transaction() as session:
-session["oauth_state"] = "state"
+config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
+with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = "state"
mock_get.return_value.json.return_value = mock_well_known_response
user = User.get_by(email=email)
assert user is None
@@ -231,28 +327,33 @@ def test_oidc_callback_registration(mock_get, mock_fetch_token, flask_client):
expected_redirect_url = "/dashboard/"
assert expected_redirect_url == parsed.path
-assert mock_get.called
+assert mock_oauth_get.called
user = User.get_by(email=email)
assert user is not None
assert user.email == email

-with flask_client.session_transaction() as session:
-session["oauth_state"] = None
+with flask_client.session_transaction() as sess:
+sess["oauth_state"] = None
@patch("requests.get")
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
-def test_oidc_callback_login(mock_get, mock_fetch_token, flask_client):
+def test_oidc_callback_login(mock_oauth_get, mock_fetch_token, mock_get, flask_client):
email = random_string()
-mock_get.return_value = MockResponse(
+mock_oauth_get.return_value = MockResponse(
200,
{
"email": email,
},
)
-with flask_client.session_transaction() as session:
-session["oauth_state"] = "state"
+config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
+with flask_client.session_transaction() as sess:
sess["oauth_redirect_next"] = None
sess["oauth_state"] = "state"
mock_get.return_value.json.return_value = mock_well_known_response
user = User.create(
email=email,
@@ -275,10 +376,57 @@ def test_oidc_callback_login(mock_get, mock_fetch_token, flask_client):
expected_redirect_url = "/dashboard/"
assert expected_redirect_url == parsed.path
-assert mock_get.called
+assert mock_oauth_get.called

-with flask_client.session_transaction() as session:
-session["oauth_state"] = None
+with flask_client.session_transaction() as sess:
+sess["oauth_state"] = None
@patch("requests.get")
@patch("requests_oauthlib.OAuth2Session.fetch_token")
@patch("requests_oauthlib.OAuth2Session.get")
def test_oidc_callback_login_with_next_url(
    mock_oauth_get, mock_fetch_token, mock_get, flask_client
):
    email = random_string()
    mock_oauth_get.return_value = MockResponse(
        200,
        {
            "email": email,
        },
    )
    config.OIDC_WELL_KNOWN_URL = "http://localhost:7777/well-known-url"
    with flask_client.session_transaction() as sess:
        sess["oauth_redirect_next"] = "/dashboard/settings/"
        sess["oauth_state"] = "state"
    mock_get.return_value.json.return_value = mock_well_known_response

    user = User.create(
        email=email,
        name="name",
        password="",
        activated=True,
    )
    user = User.get_by(email=email)
    assert user is not None

    r = flask_client.get(
        url_for("auth.oidc_callback"),
        follow_redirects=False,
    )
    location = r.headers.get("Location")
    assert location is not None

    parsed = parse_url(location)
    expected_redirect_url = "/dashboard/settings/"
    assert expected_redirect_url == parsed.path
    assert mock_oauth_get.called

    with flask_client.session_transaction() as sess:
        sess["oauth_state"] = None
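Between them, the two callback tests pin down the redirect behavior: the callback honors a `oauth_redirect_next` value stashed in the session and falls back to `/dashboard/` otherwise. A minimal standalone sketch of that rule (the helper name and default are illustrative, not the app's API):

```python
from typing import Optional

DEFAULT_REDIRECT = "/dashboard/"


def resolve_oidc_redirect(oauth_redirect_next: Optional[str]) -> str:
    # Fall back to the dashboard when no "next" URL was stashed in the
    # session before the OIDC round-trip started.
    return oauth_redirect_next or DEFAULT_REDIRECT
```

Usage mirrors the two tests: `resolve_oidc_redirect(None)` yields `/dashboard/`, while a stored `/dashboard/settings/` is returned unchanged.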
def test_create_user():

tests/cron/__init__.py — new file
@@ -0,0 +1,168 @@
import arrow
import pytest

import cron
from app.db import Session
from app.models import (
    Alias,
    AppleSubscription,
    PlanEnum,
    CoinbaseSubscription,
    ManualSubscription,
    Subscription,
    PartnerUser,
    PartnerSubscription,
    User,
)
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token


def test_get_alias_for_free_user_has_no_alias():
    user = create_new_user()
    alias_id = Alias.create_new_random(user).id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
    )
    assert len(aliases) == 0


def test_get_alias_for_lifetime_with_null_hibp_date():
    user = create_new_user()
    user.lifetime = True
    user.enable_data_breach_check = True
    alias_id = Alias.create_new_random(user).id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
    )
    assert alias_id == aliases[0].id


def test_get_alias_for_lifetime_with_old_hibp_date():
    user = create_new_user()
    user.lifetime = True
    user.enable_data_breach_check = True
    alias = Alias.create_new_random(user)
    alias.hibp_last_check = arrow.now().shift(days=-1)
    alias_id = alias.id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
    )
    assert alias_id == aliases[0].id


def create_partner_sub(user: User):
    pu = PartnerUser.create(
        partner_id=get_proton_partner().id,
        partner_email=user.email,
        external_user_id=random_token(10),
        user_id=user.id,
        flush=True,
    )
    PartnerSubscription.create(
        partner_user_id=pu.id, end_at=arrow.utcnow().shift(days=15)
    )


sub_generator_list = [
    lambda u: AppleSubscription.create(
        user_id=u.id,
        expires_date=arrow.now().shift(days=15),
        original_transaction_id=random_token(10),
        receipt_data=random_token(10),
        plan=PlanEnum.monthly,
    ),
    lambda u: CoinbaseSubscription.create(
        user_id=u.id,
        end_at=arrow.now().shift(days=15),
    ),
    lambda u: ManualSubscription.create(
        user_id=u.id,
        end_at=arrow.now().shift(days=15),
    ),
    lambda u: Subscription.create(
        user_id=u.id,
        cancel_url="",
        update_url="",
        subscription_id=random_token(10),
        event_time=arrow.now(),
        next_bill_date=arrow.now().shift(days=15).date(),
        plan=PlanEnum.monthly,
    ),
    create_partner_sub,
]


@pytest.mark.parametrize("sub_generator", sub_generator_list)
def test_get_alias_for_sub(sub_generator):
    user = create_new_user()
    user.enable_data_breach_check = True
    sub_generator(user)
    alias_id = Alias.create_new_random(user).id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
    )
    assert alias_id == aliases[0].id


def test_disabled_user_is_not_checked():
    user = create_new_user()
    user.lifetime = True
    user.disabled = True
    alias_id = Alias.create_new_random(user).id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
    )
    assert len(aliases) == 0


def test_skipped_user_is_not_checked():
    user = create_new_user()
    user.lifetime = True
    alias_id = Alias.create_new_random(user).id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [user.id], alias_id, alias_id + 1)
    )
    assert len(aliases) == 0


def test_already_checked_is_not_checked():
    user = create_new_user()
    user.lifetime = True
    alias = Alias.create_new_random(user)
    alias.hibp_last_check = arrow.now().shift(days=1)
    alias_id = alias.id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [user.id], alias_id, alias_id + 1)
    )
    assert len(aliases) == 0


def test_outed_in_user_is_checked():
    user = create_new_user()
    user.lifetime = True
    user.enable_data_breach_check = True
    alias_id = Alias.create_new_random(user).id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
    )
    assert len(aliases) == 1


def test_outed_out_user_is_not_checked():
    user = create_new_user()
    user.lifetime = True
    alias_id = Alias.create_new_random(user).id
    Session.commit()
    aliases = list(
        cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
    )
    assert len(aliases) == 0
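Taken together, these tests pin down when an alias is eligible for a HIBP check: the user must be on a paid or lifetime plan, have opted in to data-breach checks, not be disabled or in the skip list, and the alias must never have been checked or have a check older than the cutoff. A pure-Python sketch of that rule (a hypothetical helper summarizing the test cases, not the repo's `get_alias_to_check_hibp` implementation):

```python
from datetime import datetime
from typing import Optional


def should_check_alias_for_hibp(
    has_paid_or_lifetime: bool,
    enable_data_breach_check: bool,
    disabled: bool,
    user_is_skipped: bool,
    hibp_last_check: Optional[datetime],
    oldest_allowed: datetime,
) -> bool:
    """Hypothetical predicate mirroring the cases the cron tests assert."""
    if not has_paid_or_lifetime:  # free users are never checked
        return False
    if not enable_data_breach_check:  # opted-out users are skipped
        return False
    if disabled or user_is_skipped:  # disabled or explicitly skipped users
        return False
    # never checked before, or last check is older than the cutoff
    return hibp_last_check is None or hibp_last_check < oldest_allowed
```

Each test above corresponds to toggling exactly one of these conditions while holding the rest eligible.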

tests/events/__init__.py — new file
@@ -0,0 +1,56 @@
from app.events.event_dispatcher import EventDispatcher, Dispatcher
from app.events.generated.event_pb2 import EventContent, UserDeleted
from app.models import PartnerUser, User
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
from typing import Tuple


class OnMemoryDispatcher(Dispatcher):
    def __init__(self):
        self.memory = []

    def send(self, event: bytes):
        self.memory.append(event)


def _create_unlinked_user() -> User:
    return create_new_user()


def _create_linked_user() -> Tuple[User, PartnerUser]:
    user = _create_unlinked_user()
    partner_user = PartnerUser.create(
        partner_id=get_proton_partner().id,
        user_id=user.id,
        external_user_id=random_token(10),
        flush=True,
    )
    return user, partner_user


def test_event_dispatcher_stores_events():
    dispatcher = OnMemoryDispatcher()
    (user, partner) = _create_linked_user()
    content = EventContent(user_deleted=UserDeleted())
    EventDispatcher.send_event(user, content, dispatcher, skip_if_webhook_missing=False)
    assert len(dispatcher.memory) == 1

    content = EventContent(user_deleted=UserDeleted())
    EventDispatcher.send_event(user, content, dispatcher, skip_if_webhook_missing=False)
    assert len(dispatcher.memory) == 2


def test_event_dispatcher_does_not_send_event_if_user_not_linked():
    dispatcher = OnMemoryDispatcher()
    user = _create_unlinked_user()
    content = EventContent(user_deleted=UserDeleted())
    EventDispatcher.send_event(user, content, dispatcher, skip_if_webhook_missing=False)
    assert len(dispatcher.memory) == 0

    content = EventContent(user_deleted=UserDeleted())
    EventDispatcher.send_event(user, content, dispatcher, skip_if_webhook_missing=False)
    assert len(dispatcher.memory) == 0
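The two tests above fix the dispatch contract: serialized events reach the dispatcher only for users linked to a partner account; unlinked users produce nothing even with `skip_if_webhook_missing=False`. A self-contained sketch of that gate (names are illustrative stand-ins, not the app's `EventDispatcher` internals):

```python
from typing import List, Optional


class InMemoryDispatcher:
    """Collects serialized events, like the OnMemoryDispatcher test double."""

    def __init__(self) -> None:
        self.memory: List[bytes] = []

    def send(self, event: bytes) -> None:
        self.memory.append(event)


def dispatch_event(
    partner_user_id: Optional[int], payload: bytes, dispatcher: InMemoryDispatcher
) -> bool:
    # Hypothetical gate mirroring what the tests assert: only users with a
    # linked partner account produce outgoing events.
    if partner_user_id is None:
        return False
    dispatcher.send(payload)
    return True
```

An in-memory dispatcher like this keeps the tests hermetic: assertions run against `memory` instead of a real webhook endpoint.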

tests/tasks/__init__.py — new file
@@ -0,0 +1,35 @@
import tempfile
from io import BytesIO

import arrow

from app import s3, config
from app.models import File, BatchImport
from tasks.cleanup_old_imports import cleanup_old_imports
from tests.utils import random_token, create_new_user


def test_cleanup_old_imports():
    BatchImport.filter().delete()
    with tempfile.TemporaryDirectory() as tmpdir:
        config.UPLOAD_DIR = tmpdir
        user = create_new_user()
        path = random_token()
        s3.upload_from_bytesio(path, BytesIO("data".encode("utf-8")))
        file = File.create(path=path, commit=True)  # noqa: F821
        now = arrow.now()
        delete_batch_import_id = BatchImport.create(
            user_id=user.id,
            file_id=file.id,
            created_at=now.shift(minutes=-1),
            flush=True,
        ).id
        keep_batch_import_id = BatchImport.create(
            user_id=user.id,
            file_id=file.id,
            created_at=now.shift(minutes=+1),
            commit=True,
        ).id

        cleanup_old_imports(now)

        assert BatchImport.get(id=delete_batch_import_id) is None
        assert BatchImport.get(id=keep_batch_import_id) is not None

@@ -0,0 +1,72 @@
import arrow

from app import config
from app.models import Job, JobState
from tasks.cleanup_old_jobs import cleanup_old_jobs


def test_cleanup_old_jobs():
    Job.filter().delete()
    now = arrow.now()
    delete_ids = [
        Job.create(
            updated_at=now.shift(minutes=-1),
            state=JobState.done.value,
            name="",
            payload="",
            flush=True,
        ).id,
        Job.create(
            updated_at=now.shift(minutes=-1),
            state=JobState.error.value,
            name="",
            payload="",
            flush=True,
        ).id,
        Job.create(
            updated_at=now.shift(minutes=-1),
            state=JobState.taken.value,
            attempts=config.JOB_MAX_ATTEMPTS,
            name="",
            payload="",
            flush=True,
        ).id,
    ]

    keep_ids = [
        Job.create(
            updated_at=now.shift(minutes=+1),
            state=JobState.done.value,
            name="",
            payload="",
            flush=True,
        ).id,
        Job.create(
            updated_at=now.shift(minutes=+1),
            state=JobState.error.value,
            name="",
            payload="",
            flush=True,
        ).id,
        Job.create(
            updated_at=now.shift(minutes=+1),
            state=JobState.taken.value,
            attempts=config.JOB_MAX_ATTEMPTS,
            name="",
            payload="",
            flush=True,
        ).id,
        Job.create(
            updated_at=now.shift(minutes=-1),
            state=JobState.taken.value,
            attempts=config.JOB_MAX_ATTEMPTS - 1,
            name="",
            payload="",
            flush=True,
        ).id,
    ]

    cleanup_old_jobs(now)

    for delete_id in delete_ids:
        assert Job.get(id=delete_id) is None

    for keep_id in keep_ids:
        assert Job.get(id=keep_id) is not None


@@ -0,0 +1,26 @@
import arrow

from app.models import Notification
from tasks.cleanup_old_notifications import cleanup_old_notifications
from tests.utils import create_new_user


def test_cleanup_old_notifications():
    Notification.filter().delete()
    user = create_new_user()
    now = arrow.now()
    delete_id = Notification.create(
        user_id=user.id,
        created_at=now.shift(minutes=-1),
        message="",
        flush=True,
    ).id
    keep_id = Notification.create(
        user_id=user.id,
        created_at=now.shift(minutes=+1),
        message="",
        flush=True,
    ).id

    cleanup_old_notifications(now)

    assert Notification.get(id=delete_id) is None
    assert Notification.get(id=keep_id) is not None


@@ -51,9 +51,7 @@ FACEBOOK_CLIENT_SECRET=to_fill
# Login with OIDC
CONNECT_WITH_OIDC_ICON=fa-github
-OIDC_AUTHORIZATION_URL=to_fill
-OIDC_USER_INFO_URL=to_fill
-OIDC_TOKEN_URL=to_fill
+OIDC_WELL_KNOWN_URL=to_fill
OIDC_SCOPES=openid email profile
OIDC_NAME_FIELD=name
OIDC_CLIENT_ID=to_fill


@@ -384,3 +384,30 @@ def test_break_loop_alias_as_mailbox(flask_client):
    msg[headers.SUBJECT] = random_string()
    result = email_handler.handle(envelope, msg)
    assert result == status.E525


@mail_sender.store_emails_test_decorator
def test_preserve_headers(flask_client):
    headers_to_keep = [
        headers.SUBJECT,
        headers.DATE,
        headers.MESSAGE_ID,
        headers.REFERENCES,
        headers.IN_REPLY_TO,
        headers.SL_QUEUE_ID,
    ] + headers.MIME_HEADERS
    user = create_new_user()
    alias = Alias.create_new_random(user)
    envelope = Envelope()
    envelope.mail_from = "somewhere@lo.cal"
    envelope.rcpt_tos = [alias.email]
    msg = EmailMessage()
    for header in headers_to_keep:
        msg[header] = header + "keep"

    result = email_handler.handle(envelope, msg)
    assert result == status.E200

    sent_mails = mail_sender.get_stored_emails()
    assert len(sent_mails) == 1
    msg = sent_mails[0].msg
    for header in headers_to_keep:
        assert msg[header] == header + "keep"


@@ -17,6 +17,7 @@ from app.models import (
    Subscription,
    PlanEnum,
    PADDLE_SUBSCRIPTION_GRACE_DAYS,
+    SyncEvent,
)
from tests.utils import login, create_new_user, random_token

@@ -325,3 +326,51 @@ def test_user_can_send_receive():
    user.disabled = False
    user.delete_on = arrow.now()
    assert not user.can_send_or_receive()


def test_sync_event_dead_letter():
    # remove all SyncEvents before the test
    all_events = SyncEvent.all()
    for event in all_events:
        SyncEvent.delete(event.id, commit=True)

    # create an expired, not-taken event
    e1 = SyncEvent.create(
        content=b"content",
        created_at=arrow.now().shift(minutes=-15),
        taken_time=None,
        commit=True,
    )

    # create an expired taken event (taken too long ago)
    e2 = SyncEvent.create(
        content=b"content",
        created_at=arrow.now().shift(minutes=-15),
        taken_time=arrow.now().shift(minutes=-14),
        commit=True,
    )

    # create an expired taken event (taken recently)
    e3 = SyncEvent.create(
        content=b"content",
        created_at=arrow.now().shift(minutes=-15),
        taken_time=arrow.now().shift(minutes=-1),
        commit=True,
    )

    # create a normal event
    e4 = SyncEvent.create(
        content=b"content",
        created_at=arrow.now(),
        commit=True,
    )

    # get dead letter events
    dead_letter_events = SyncEvent.get_dead_letter(
        older_than=arrow.now().shift(minutes=-10)
    )
    assert len(dead_letter_events) == 2
    assert e1 in dead_letter_events
    assert e2 in dead_letter_events
    assert e3 not in dead_letter_events
    assert e4 not in dead_letter_events
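The four fixtures assert a compact dead-letter rule: an event is reclaimable when it was never taken and is older than the cutoff, or when it was taken long enough ago that its processor is presumed dead. A standalone restatement of that rule (function and variable names are illustrative, not `SyncEvent.get_dead_letter` itself):

```python
from datetime import datetime, timedelta
from typing import Optional


def is_dead_letter(
    created_at: datetime, taken_time: Optional[datetime], older_than: datetime
) -> bool:
    # Taken events are judged by when they were taken; untaken events by
    # when they were created.
    if taken_time is not None:
        return taken_time < older_than
    return created_at < older_than


now = datetime(2024, 5, 24, 12, 0, 0)
cutoff = now - timedelta(minutes=10)
assert is_dead_letter(now - timedelta(minutes=15), None, cutoff)  # e1
assert is_dead_letter(now - timedelta(minutes=15), now - timedelta(minutes=14), cutoff)  # e2
assert not is_dead_letter(now - timedelta(minutes=15), now - timedelta(minutes=1), cutoff)  # e3
assert not is_dead_letter(now, None, cutoff)  # e4
```

Judging taken events by `taken_time` rather than `created_at` is what lets a recently claimed event (e3) survive even though it was created long before the cutoff.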