Compare commits

..

No commits in common. "master" and "v4.7.3" have entirely different histories.

470 changed files with 704493 additions and 31491 deletions

View File

@ -1,5 +1,5 @@
{ {
"template": "${{CHANGELOG}}\n\n<details>\n<summary>Uncategorized</summary>\n\n${{UNCATEGORIZED}}\n</details>", "template": "${{CHANGELOG}}",
"pr_template": "- ${{TITLE}} #${{NUMBER}}", "pr_template": "- ${{TITLE}} #${{NUMBER}}",
"empty_template": "- no changes", "empty_template": "- no changes",
"categories": [ "categories": [
@ -18,6 +18,6 @@
], ],
"ignore_labels": ["ignore"], "ignore_labels": ["ignore"],
"tag_resolver": { "tag_resolver": {
"method": "semver" "method": "sort"
} }
} }

View File

@ -1,43 +1,15 @@
name: Test and lint name: Run tests & Publish to Docker Registry
on: [push, pull_request] on:
push:
jobs: jobs:
lint:
runs-on: ubuntu-latest
steps:
- name: Check out repo
uses: actions/checkout@v3
- name: Install poetry
run: pipx install poetry
- uses: actions/setup-python@v4
with:
python-version: '3.10'
cache: 'poetry'
- name: Install OS dependencies
if: ${{ matrix.python-version }} == '3.10'
run: |
sudo apt update
sudo apt install -y libre2-dev libpq-dev
- name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction
- name: Check formatting & linting
run: |
poetry run pre-commit run --all-files
test: test:
runs-on: ubuntu-latest runs-on: ubuntu-latest
strategy: strategy:
max-parallel: 4 max-parallel: 4
matrix: matrix:
python-version: ["3.10"] python-version: [3.7, "3.10"]
# service containers to run with `postgres-job` # service containers to run with `postgres-job`
services: services:
@ -66,16 +38,27 @@ jobs:
--health-retries 5 --health-retries 5
steps: steps:
- name: Check out repo - name: Check out repository
uses: actions/checkout@v3 uses: actions/checkout@v2
- name: Install poetry - name: Set up Python ${{ matrix.python-version }}
run: pipx install poetry uses: actions/setup-python@v2
- uses: actions/setup-python@v4
with: with:
python-version: ${{ matrix.python-version }} python-version: ${{ matrix.python-version }}
cache: 'poetry'
- name: Install poetry
uses: snok/install-poetry@v1
with:
virtualenvs-create: true
virtualenvs-in-project: true
installer-parallel: true
- name: Run caching
id: cached-poetry-dependencies
uses: actions/cache@v2
with:
path: .venv
key: venv-${{ runner.os }}-${{ matrix.python-version }}-${{ hashFiles('**/poetry.lock') }}
- name: Install OS dependencies - name: Install OS dependencies
if: ${{ matrix.python-version }} == '3.10' if: ${{ matrix.python-version }} == '3.10'
@ -85,13 +68,16 @@ jobs:
- name: Install dependencies - name: Install dependencies
if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true' if: steps.cached-poetry-dependencies.outputs.cache-hit != 'true'
run: poetry install --no-interaction --no-root
- name: Install library
run: poetry install --no-interaction run: poetry install --no-interaction
- name: Check formatting & linting
- name: Start Redis v6 run: |
uses: superchargejs/redis-github-action@1.1.0 poetry run black --check .
with: poetry run flake8
redis-version: 6 poetry run djlint --check templates
- name: Run db migration - name: Run db migration
run: | run: |
@ -109,14 +95,14 @@ jobs:
GITHUB_ACTIONS_TEST: true GITHUB_ACTIONS_TEST: true
- name: Archive code coverage results - name: Archive code coverage results
uses: actions/upload-artifact@v4 uses: actions/upload-artifact@v2
with: with:
name: code-coverage-report name: code-coverage-report
path: htmlcov path: htmlcov
build: build:
runs-on: ubuntu-latest runs-on: ubuntu-latest
needs: ['test', 'lint'] needs: ['test']
if: github.event_name == 'push' && (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/v')) if: github.event_name == 'push' && (github.ref == 'refs/heads/master' || startsWith(github.ref, 'refs/tags/v'))
steps: steps:
@ -134,25 +120,7 @@ jobs:
# We need to checkout the repository in order for the "Create Sentry release" to work # We need to checkout the repository in order for the "Create Sentry release" to work
- name: Checkout repository - name: Checkout repository
uses: actions/checkout@v3 uses: actions/checkout@v2
with:
fetch-depth: 0
- name: Set up QEMU
uses: docker/setup-qemu-action@v2
- name: Set up Docker Buildx
uses: docker/setup-buildx-action@v2
- name: Create Sentry release
uses: getsentry/action-release@v1
env:
SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }}
SENTRY_ORG: ${{ secrets.SENTRY_ORG }}
SENTRY_PROJECT: ${{ secrets.SENTRY_PROJECT }}
with:
ignore_missing: true
ignore_empty: true
- name: Prepare version file - name: Prepare version file
run: | run: |
@ -163,17 +131,22 @@ jobs:
uses: docker/build-push-action@v3 uses: docker/build-push-action@v3
with: with:
context: . context: .
platforms: linux/amd64,linux/arm64
push: true push: true
tags: ${{ steps.meta.outputs.tags }} tags: ${{ steps.meta.outputs.tags }}
- name: Create Sentry release
uses: getsentry/action-release@v1
env:
SENTRY_AUTH_TOKEN: ${{ secrets.SENTRY_AUTH_TOKEN }}
SENTRY_ORG: ${{ secrets.SENTRY_ORG }}
SENTRY_PROJECT: ${{ secrets.SENTRY_PROJECT }}
#- name: Send Telegram message - name: Send Telegram message
# uses: appleboy/telegram-action@master uses: appleboy/telegram-action@master
# with: with:
# to: ${{ secrets.TELEGRAM_TO }} to: ${{ secrets.TELEGRAM_TO }}
# token: ${{ secrets.TELEGRAM_TOKEN }} token: ${{ secrets.TELEGRAM_TOKEN }}
# args: Docker image pushed on ${{ github.ref }} args: Docker image pushed on ${{ github.ref }}
# If we have generated a tag, generate the changelog, send a notification to slack and create the GitHub release # If we have generated a tag, generate the changelog, send a notification to slack and create the GitHub release
- name: Build Changelog - name: Build Changelog

2
.gitignore vendored
View File

@ -11,7 +11,7 @@ db.sqlite-journal
static/upload static/upload
venv/ venv/
.venv .venv
.python-version
.coverage .coverage
htmlcov htmlcov
adhoc adhoc
.env.*

View File

@ -1,24 +1,10 @@
exclude: "(migrations|static/node_modules|static/assets|static/vendor)"
default_language_version:
python: python3
repos: repos:
- repo: https://github.com/pre-commit/pre-commit-hooks - repo: https://github.com/psf/black
rev: v4.2.0 rev: 22.1.0
hooks: hooks:
- id: check-yaml - id: black
- id: trailing-whitespace language_version: python3.7
- repo: https://github.com/Riverside-Healthcare/djLint - repo: https://github.com/pycqa/flake8
rev: v1.34.1 rev: 4.0.1
hooks: hooks:
- id: djlint-jinja - id: flake8
files: '.*\.html'
entry: djlint --reformat
- repo: https://github.com/astral-sh/ruff-pre-commit
# Ruff version.
rev: v0.1.5
hooks:
# Run the linter.
- id: ruff
args: [ --fix ]
# Run the formatter.
- id: ruff-format

227
.pylintrc
View File

@ -1,227 +0,0 @@
[MASTER]
extension-pkg-allow-list=re2
fail-under=7.0
ignore=CVS
ignore-paths=migrations
ignore-patterns=^\.#
jobs=0
[MESSAGES CONTROL]
disable=missing-function-docstring,
missing-module-docstring,
duplicate-code,
#import-error,
missing-class-docstring,
useless-object-inheritance,
use-dict-literal,
logging-format-interpolation,
consider-using-f-string,
unnecessary-comprehension,
inconsistent-return-statements,
wrong-import-order,
line-too-long,
invalid-name,
global-statement,
no-else-return,
unspecified-encoding,
logging-fstring-interpolation,
too-few-public-methods,
bare-except,
fixme,
unnecessary-pass,
f-string-without-interpolation,
super-init-not-called,
unused-argument,
ungrouped-imports,
too-many-locals,
consider-using-with,
too-many-statements,
consider-using-set-comprehension,
unidiomatic-typecheck,
useless-else-on-loop,
too-many-return-statements,
broad-except,
protected-access,
consider-using-enumerate,
too-many-nested-blocks,
too-many-branches,
simplifiable-if-expression,
possibly-unused-variable,
pointless-string-statement,
wrong-import-position,
redefined-outer-name,
raise-missing-from,
logging-too-few-args,
redefined-builtin,
too-many-arguments,
import-outside-toplevel,
redefined-argument-from-local,
logging-too-many-args,
too-many-instance-attributes,
unreachable,
no-name-in-module,
no-member,
consider-using-ternary,
too-many-lines,
arguments-differ,
too-many-public-methods,
unused-variable,
consider-using-dict-items,
consider-using-in,
reimported,
too-many-boolean-expressions,
cyclic-import,
not-callable, # (paddle_utils.py) verifier.verify cannot be called (although it can)
abstract-method, # (models.py)
[BASIC]
# Naming style matching correct argument names.
argument-naming-style=snake_case
# Regular expression matching correct argument names. Overrides argument-
# naming-style. If left empty, argument names will be checked with the set
# naming style.
#argument-rgx=
# Naming style matching correct attribute names.
attr-naming-style=snake_case
# Regular expression matching correct attribute names. Overrides attr-naming-
# style. If left empty, attribute names will be checked with the set naming
# style.
#attr-rgx=
# Bad variable names which should always be refused, separated by a comma.
bad-names=foo,
bar,
baz,
toto,
tutu,
tata
# Bad variable names regexes, separated by a comma. If names match any regex,
# they will always be refused
bad-names-rgxs=
# Naming style matching correct class attribute names.
class-attribute-naming-style=any
# Regular expression matching correct class attribute names. Overrides class-
# attribute-naming-style. If left empty, class attribute names will be checked
# with the set naming style.
#class-attribute-rgx=
# Naming style matching correct class constant names.
class-const-naming-style=UPPER_CASE
# Regular expression matching correct class constant names. Overrides class-
# const-naming-style. If left empty, class constant names will be checked with
# the set naming style.
#class-const-rgx=
# Naming style matching correct class names.
class-naming-style=PascalCase
# Regular expression matching correct class names. Overrides class-naming-
# style. If left empty, class names will be checked with the set naming style.
#class-rgx=
# Naming style matching correct constant names.
const-naming-style=UPPER_CASE
# Regular expression matching correct constant names. Overrides const-naming-
# style. If left empty, constant names will be checked with the set naming
# style.
#const-rgx=
# Minimum line length for functions/classes that require docstrings, shorter
# ones are exempt.
docstring-min-length=-1
# Naming style matching correct function names.
function-naming-style=snake_case
# Regular expression matching correct function names. Overrides function-
# naming-style. If left empty, function names will be checked with the set
# naming style.
#function-rgx=
# Good variable names which should always be accepted, separated by a comma.
good-names=i,
j,
k,
ex,
Run,
_
# Good variable names regexes, separated by a comma. If names match any regex,
# they will always be accepted
good-names-rgxs=
# Include a hint for the correct naming format with invalid-name.
include-naming-hint=no
# Naming style matching correct inline iteration names.
inlinevar-naming-style=any
# Regular expression matching correct inline iteration names. Overrides
# inlinevar-naming-style. If left empty, inline iteration names will be checked
# with the set naming style.
#inlinevar-rgx=
# Naming style matching correct method names.
method-naming-style=snake_case
# Regular expression matching correct method names. Overrides method-naming-
# style. If left empty, method names will be checked with the set naming style.
#method-rgx=
# Naming style matching correct module names.
module-naming-style=snake_case
# Regular expression matching correct module names. Overrides module-naming-
# style. If left empty, module names will be checked with the set naming style.
#module-rgx=
# Colon-delimited sets of names that determine each other's naming style when
# the name regexes allow several styles.
name-group=
# Regular expression which should only match function or class names that do
# not require a docstring.
no-docstring-rgx=^_
# List of decorators that produce properties, such as abc.abstractproperty. Add
# to this list to register other decorators that produce valid properties.
# These decorators are taken in consideration only for invalid-name.
property-classes=abc.abstractproperty
# Regular expression matching correct type variable names. If left empty, type
# variable names will be checked with the set naming style.
#typevar-rgx=
# Naming style matching correct variable names.
variable-naming-style=snake_case
# Regular expression matching correct variable names. Overrides variable-
# naming-style. If left empty, variable names will be checked with the set
# naming style.
#variable-rgx=
[STRING]
# This flag controls whether inconsistent-quotes generates a warning when the
# character used as a quote delimiter is used inconsistently within a module.
check-quote-consistency=no
# This flag controls whether the implicit-str-concat should generate a warning
# on implicit string concatenation in sequences defined over several lines.
check-str-concat-over-line-jumps=no
[FORMAT]
max-line-length=88
single-line-if-stmt=yes

View File

@ -20,21 +20,21 @@ SimpleLogin backend consists of 2 main components:
## Install dependencies ## Install dependencies
The project requires: The project requires:
- Python 3.10 and poetry to manage dependencies - Python 3.7+ and [poetry](https://python-poetry.org/) to manage dependencies
- Node v10 for front-end. - Node v10 for front-end.
- Postgres 13+ - Postgres 12+
First, install all dependencies by running the following command. First, install all dependencies by running the following command.
Feel free to use `virtualenv` or similar tools to isolate development environment. Feel free to use `virtualenv` or similar tools to isolate development environment.
```bash ```bash
poetry sync poetry install
``` ```
On Mac, sometimes you might need to install some other packages via `brew`: On Mac, sometimes you might need to install some other packages via `brew`:
```bash ```bash
brew install pkg-config libffi openssl postgresql@13 brew install pkg-config libffi openssl postgresql
``` ```
You also need to install `gpg` tool, on Mac it can be done with: You also need to install `gpg` tool, on Mac it can be done with:
@ -50,30 +50,12 @@ More info on https://github.com/andreasvc/pyre2
brew install -s re2 pybind11 brew install -s re2 pybind11
``` ```
## Linting and static analysis
We use pre-commit to run all our linting and static analysis checks. Please run
```bash
poetry run pre-commit install
```
To install it in your development environment.
## Run tests ## Run tests
For most tests, you will need to have ``redis`` installed and started on your machine (listening on port 6379).
```bash ```bash
sh scripts/run-test.sh sh scripts/run-test.sh
``` ```
You can also run tests using a local Postgres DB to speed things up. This can be done by
- creating an empty test DB and running the database migration by `dropdb test && createdb test && DB_URI=postgresql://localhost:5432/test alembic upgrade head`
- replacing the `DB_URI` in `test.env` file by `DB_URI=postgresql://localhost:5432/test`
## Run the code locally ## Run the code locally
Install npm packages Install npm packages
@ -88,16 +70,10 @@ To run the code locally, please create a local setting file based on `example.en
cp example.env .env cp example.env .env
``` ```
You need to edit your .env to reflect the postgres exposed port, edit the `DB_URI` to:
```
DB_URI=postgresql://myuser:mypassword@localhost:35432/simplelogin
```
Run the postgres database: Run the postgres database:
```bash ```bash
docker run -e POSTGRES_PASSWORD=mypassword -e POSTGRES_USER=myuser -e POSTGRES_DB=simplelogin -p 15432:5432 postgres:13 docker run -e POSTGRES_PASSWORD=mypassword -e POSTGRES_USER=myuser -e POSTGRES_DB=simplelogin -p 35432:5432 postgres:13
``` ```
To run the server: To run the server:
@ -157,10 +133,10 @@ Here are the small sum-ups of the directory structures and their roles:
## Pull request ## Pull request
The code is formatted using [ruff](https://github.com/astral-sh/ruff), to format the code, simply run The code is formatted using https://github.com/psf/black, to format the code, simply run
``` ```
poetry run ruff format . poetry run black .
``` ```
The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by The code is also checked with `flake8`, make sure to run `flake8` before creating the pull request by
@ -175,12 +151,6 @@ For HTML templates, we use `djlint`. Before creating a pull request, please run
poetry run djlint --check templates poetry run djlint --check templates
``` ```
If some files aren't properly formatted, you can format all files with
```bash
poetry run djlint --reformat .
```
## Test sending email ## Test sending email
[swaks](http://www.jetmore.org/john/code/swaks/) is used for sending test emails to the `email_handler`. [swaks](http://www.jetmore.org/john/code/swaks/) is used for sending test emails to the `email_handler`.
@ -219,35 +189,3 @@ swaks --to e1@sl.local --from hey@google.com --server 127.0.0.1:20381
``` ```
Now open http://localhost:1080/ (or http://localhost:1080/ for MailHog), you should see the forwarded email. Now open http://localhost:1080/ (or http://localhost:1080/ for MailHog), you should see the forwarded email.
## Job runner
Some features require a job handler (such as GDPR data export). To test such feature you need to run the job_runner
```bash
python job_runner.py
```
# Setup for Mac
There are several ways to setup Python and manage the project dependencies on Mac. For info we have successfully used this setup on a Mac silicon:
```bash
# we haven't managed to make python 3.12 work
brew install python3.10
# make sure to update the PATH so python, pip point to Python3
# for us it can be done by adding "export PATH=/opt/homebrew/opt/python@3.10/libexec/bin:$PATH" to .zprofile
# Although pipx is the recommended way to install poetry,
# install pipx via brew will automatically install python 3.12
# and poetry will then use python 3.12
# so we recommend using poetry this way instead
curl -sSL https://install.python-poetry.org | python3 -
poetry install
# activate the virtualenv and you should be good to go!
source .venv/bin/activate
```

View File

@ -2,10 +2,10 @@
FROM node:10.17.0-alpine AS npm FROM node:10.17.0-alpine AS npm
WORKDIR /code WORKDIR /code
COPY ./static/package*.json /code/static/ COPY ./static/package*.json /code/static/
RUN cd /code/static && npm ci RUN cd /code/static && npm install
# Main image # Main image
FROM python:3.10 FROM python:3.7
# Keeps Python from generating .pyc files in the container # Keeps Python from generating .pyc files in the container
ENV PYTHONDONTWRITEBYTECODE 1 ENV PYTHONDONTWRITEBYTECODE 1
@ -13,7 +13,7 @@ ENV PYTHONDONTWRITEBYTECODE 1
ENV PYTHONUNBUFFERED 1 ENV PYTHONUNBUFFERED 1
# Add poetry to PATH # Add poetry to PATH
ENV PATH="${PATH}:/root/.local/bin" ENV PATH="${PATH}:/root/.poetry/bin"
WORKDIR /code WORKDIR /code
@ -23,15 +23,15 @@ COPY poetry.lock pyproject.toml ./
# Install and setup poetry # Install and setup poetry
RUN pip install -U pip \ RUN pip install -U pip \
&& apt-get update \ && apt-get update \
&& apt install -y curl netcat-traditional gcc python3-dev gnupg git libre2-dev cmake ninja-build\ && apt install -y curl netcat gcc python3-dev gnupg git libre2-dev \
&& curl -sSL https://install.python-poetry.org | python3 - \ && curl -sSL https://raw.githubusercontent.com/python-poetry/poetry/master/get-poetry.py | python - \
# Remove curl and netcat from the image # Remove curl and netcat from the image
&& apt-get purge -y curl netcat-traditional \ && apt-get purge -y curl netcat \
# Run poetry # Run poetry
&& poetry config virtualenvs.create false \ && poetry config virtualenvs.create false \
&& poetry install --no-interaction --no-ansi --no-root \ && poetry install --no-interaction --no-ansi --no-root \
# Clear apt cache \ # Clear apt cache \
&& apt-get purge -y libre2-dev cmake ninja-build\ && apt-get purge -y libre2-dev \
&& apt-get clean \ && apt-get clean \
&& rm -rf /var/lib/apt/lists/* && rm -rf /var/lib/apt/lists/*

View File

@ -15,8 +15,8 @@
<img src="https://img.shields.io/github/license/simple-login/app"> <img src="https://img.shields.io/github/license/simple-login/app">
</a> </a>
<a href="https://twitter.com/simplelogin"> <a href="https://twitter.com/simple_login">
<img src="https://img.shields.io/twitter/follow/simplelogin?style=social"> <img src="https://img.shields.io/twitter/follow/simple_login?style=social">
</a> </a>
</p> </p>
@ -74,7 +74,7 @@ Setting up DKIM is highly recommended to reduce the chance your emails ending up
First you need to generate a private and public key for DKIM: First you need to generate a private and public key for DKIM:
```bash ```bash
openssl genrsa -out dkim.key -traditional 1024 openssl genrsa -out dkim.key 1024
openssl rsa -in dkim.key -pubout -out dkim.pub.key openssl rsa -in dkim.key -pubout -out dkim.pub.key
``` ```
@ -334,12 +334,6 @@ smtpd_recipient_restrictions =
permit permit
``` ```
Check that the ssl certificates `/etc/ssl/certs/ssl-cert-snakeoil.pem` and `/etc/ssl/private/ssl-cert-snakeoil.key` exist. Depending on the linux distribution you are using they may or may not be present. If they are not, you will need to generate them with this command:
```bash
openssl req -x509 -nodes -days 3650 -newkey rsa:2048 -keyout /etc/ssl/private/ssl-cert-snakeoil.key -out /etc/ssl/certs/ssl-cert-snakeoil.pem
```
Create the `/etc/postfix/pgsql-relay-domains.cf` file with the following content. Create the `/etc/postfix/pgsql-relay-domains.cf` file with the following content.
Make sure that the database config is correctly set, replace `mydomain.com` with your domain, update 'myuser' and 'mypassword' with your postgres credentials. Make sure that the database config is correctly set, replace `mydomain.com` with your domain, update 'myuser' and 'mypassword' with your postgres credentials.
@ -510,14 +504,11 @@ server {
server_name app.mydomain.com; server_name app.mydomain.com;
location / { location / {
proxy_pass http://localhost:7777; proxy_pass http://localhost:7777;
proxy_set_header Host $host;
} }
} }
``` ```
Note: If `/etc/nginx/sites-enabled/default` exists, delete it or certbot will fail due to the conflict. The `simplelogin` file should be the only file in `sites-enabled`.
Reload Nginx with the command below Reload Nginx with the command below
```bash ```bash
@ -541,7 +532,7 @@ exit
Once you've created all your desired login accounts, add these lines to `/simplelogin.env` to disable further registrations: Once you've created all your desired login accounts, add these lines to `/simplelogin.env` to disable further registrations:
```.env ```
DISABLE_REGISTRATION=1 DISABLE_REGISTRATION=1
DISABLE_ONBOARDING=true DISABLE_ONBOARDING=true
``` ```

View File

@ -5,23 +5,16 @@ from typing import Optional
from arrow import Arrow from arrow import Arrow
from newrelic import agent from newrelic import agent
from sqlalchemy import or_
from app.db import Session from app.db import Session
from app.email_utils import send_welcome_email from app.email_utils import send_welcome_email
from app.utils import sanitize_email, canonicalize_email from app.errors import AccountAlreadyLinkedToAnotherPartnerException
from app.errors import (
AccountAlreadyLinkedToAnotherPartnerException,
AccountIsUsingAliasAsEmail,
AccountAlreadyLinkedToAnotherUserException,
)
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
PartnerSubscription, PartnerSubscription,
Partner, Partner,
PartnerUser, PartnerUser,
User, User,
Alias,
) )
from app.utils import random_string from app.utils import random_string
@ -132,9 +125,8 @@ class ClientMergeStrategy(ABC):
class NewUserStrategy(ClientMergeStrategy): class NewUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult: def process(self) -> LinkResult:
# Will create a new SL User with a random password # Will create a new SL User with a random password
canonical_email = canonicalize_email(self.link_request.email)
new_user = User.create( new_user = User.create(
email=canonical_email, email=self.link_request.email,
name=self.link_request.name, name=self.link_request.name,
password=random_string(20), password=random_string(20),
activated=True, activated=True,
@ -168,8 +160,7 @@ class NewUserStrategy(ClientMergeStrategy):
class ExistingUnlinkedUserStrategy(ClientMergeStrategy): class ExistingUnlinkedUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult: def process(self) -> LinkResult:
# IF it was scheduled to be deleted. Unschedule it.
self.user.delete_on = None
partner_user = ensure_partner_user_exists_for_user( partner_user = ensure_partner_user_exists_for_user(
self.link_request, self.user, self.partner self.link_request, self.user, self.partner
) )
@ -183,7 +174,7 @@ class ExistingUnlinkedUserStrategy(ClientMergeStrategy):
class LinkedWithAnotherPartnerUserStrategy(ClientMergeStrategy): class LinkedWithAnotherPartnerUserStrategy(ClientMergeStrategy):
def process(self) -> LinkResult: def process(self) -> LinkResult:
raise AccountAlreadyLinkedToAnotherUserException() raise AccountAlreadyLinkedToAnotherPartnerException()
def get_login_strategy( def get_login_strategy(
@ -200,37 +191,17 @@ def get_login_strategy(
return ExistingUnlinkedUserStrategy(link_request, user, partner) return ExistingUnlinkedUserStrategy(link_request, user, partner)
def check_alias(email: str) -> bool:
alias = Alias.get_by(email=email)
if alias is not None:
raise AccountIsUsingAliasAsEmail()
def process_login_case( def process_login_case(
link_request: PartnerLinkRequest, partner: Partner link_request: PartnerLinkRequest, partner: Partner
) -> LinkResult: ) -> LinkResult:
# Sanitize email just in case
link_request.email = sanitize_email(link_request.email)
# Try to find a SimpleLogin user registered with that partner user id # Try to find a SimpleLogin user registered with that partner user id
partner_user = PartnerUser.get_by( partner_user = PartnerUser.get_by(
partner_id=partner.id, external_user_id=link_request.external_user_id partner_id=partner.id, external_user_id=link_request.external_user_id
) )
if partner_user is None: if partner_user is None:
canonical_email = canonicalize_email(link_request.email)
# We didn't find any SimpleLogin user registered with that partner user id # We didn't find any SimpleLogin user registered with that partner user id
# Make sure they aren't using an alias as their link email
check_alias(link_request.email)
check_alias(canonical_email)
# Try to find it using the partner's e-mail address # Try to find it using the partner's e-mail address
users = User.filter( user = User.get_by(email=link_request.email)
or_(User.email == link_request.email, User.email == canonical_email)
).all()
if len(users) > 1:
user = [user for user in users if user.email == canonical_email][0]
elif len(users) == 1:
user = users[0]
else:
user = None
return get_login_strategy(link_request, user, partner).process() return get_login_strategy(link_request, user, partner).process()
else: else:
# We found the SL user registered with that partner user id # We found the SL user registered with that partner user id
@ -246,10 +217,6 @@ def process_login_case(
def link_user( def link_user(
link_request: PartnerLinkRequest, current_user: User, partner: Partner link_request: PartnerLinkRequest, current_user: User, partner: Partner
) -> LinkResult: ) -> LinkResult:
# Sanitize email just in case
link_request.email = sanitize_email(link_request.email)
# If it was scheduled to be deleted. Unschedule it.
current_user.delete_on = None
partner_user = ensure_partner_user_exists_for_user( partner_user = ensure_partner_user_exists_for_user(
link_request, current_user, partner link_request, current_user, partner
) )
@ -293,8 +260,6 @@ def process_link_case(
current_user: User, current_user: User,
partner: Partner, partner: Partner,
) -> LinkResult: ) -> LinkResult:
# Sanitize email just in case
link_request.email = sanitize_email(link_request.email)
# Try to find a SimpleLogin user linked with this Partner account # Try to find a SimpleLogin user linked with this Partner account
partner_user = PartnerUser.get_by( partner_user = PartnerUser.get_by(
partner_id=partner.id, external_user_id=link_request.external_user_id partner_id=partner.id, external_user_id=link_request.external_user_id
@ -304,7 +269,7 @@ def process_link_case(
return link_user(link_request, current_user, partner) return link_user(link_request, current_user, partner)
# There is a SL user registered with the partner. Check if is the current one # There is a SL user registered with the partner. Check if is the current one
if partner_user.user_id == current_user.id: if partner_user.id == current_user.id:
# Update plan # Update plan
set_plan_for_partner_user(partner_user, link_request.plan) set_plan_for_partner_user(partner_user, link_request.plan)
# It's the same user. No need to do anything # It's the same user. No need to do anything
@ -313,4 +278,5 @@ def process_link_case(
strategy="Link", strategy="Link",
) )
else: else:
return switch_already_linked_user(link_request, partner_user, current_user) return switch_already_linked_user(link_request, partner_user, current_user)

View File

@ -1,10 +1,7 @@
from __future__ import annotations
from typing import Optional from typing import Optional
import arrow import arrow
import sqlalchemy import sqlalchemy
from flask_admin import BaseView
from flask_admin.form import SecureForm
from flask_admin.model.template import EndpointLinkRowAction from flask_admin.model.template import EndpointLinkRowAction
from markupsafe import Markup from markupsafe import Markup
@ -27,34 +24,12 @@ from app.models import (
ProviderComplaintState, ProviderComplaintState,
Phase, Phase,
ProviderComplaint, ProviderComplaint,
Alias,
Newsletter,
PADDLE_SUBSCRIPTION_GRACE_DAYS,
Mailbox,
DeletedAlias,
DomainDeletedAlias,
PartnerUser,
) )
from app.newsletter_utils import send_newsletter_to_user, send_newsletter_to_address
def _admin_action_formatter(view, context, model, name):
action_name = AuditLogActionEnum.get_name(model.action)
return "{} ({})".format(action_name, model.action)
def _admin_date_formatter(view, context, model, name):
return model.created_at.format()
def _user_upgrade_channel_formatter(view, context, model, name):
return Markup(model.upgrade_channel)
class SLModelView(sqla.ModelView): class SLModelView(sqla.ModelView):
column_default_sort = ("id", True) column_default_sort = ("id", True)
column_display_pk = True column_display_pk = True
page_size = 100
can_edit = False can_edit = False
can_create = False can_create = False
@ -66,8 +41,7 @@ class SLModelView(sqla.ModelView):
def inaccessible_callback(self, name, **kwargs): def inaccessible_callback(self, name, **kwargs):
# redirect to login page if user doesn't have access # redirect to login page if user doesn't have access
flash("You don't have access to the admin page", "error") return redirect(url_for("auth.login", next=request.url))
return redirect(url_for("dashboard.index", next=request.url))
def on_model_change(self, form, model, is_created): def on_model_change(self, form, model, is_created):
changes = {} changes = {}
@ -116,7 +90,6 @@ class SLAdminIndexView(AdminIndexView):
class UserAdmin(SLModelView): class UserAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["email", "id"] column_searchable_list = ["email", "id"]
column_exclude_list = [ column_exclude_list = [
"salt", "salt",
@ -133,40 +106,6 @@ class UserAdmin(SLModelView):
ret.insert(0, "upgrade_channel") ret.insert(0, "upgrade_channel")
return ret return ret
column_formatters = {
"upgrade_channel": _user_upgrade_channel_formatter,
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
@action(
"disable_user",
"Disable user",
"Are you sure you want to disable the selected users?",
)
def action_disable_user(self, ids):
for user in User.filter(User.id.in_(ids)):
user.disabled = True
flash(f"Disabled user {user.id}")
AdminAuditLog.disable_user(current_user.id, user.id)
Session.commit()
@action(
"enable_user",
"Enable user",
"Are you sure you want to enable the selected users?",
)
def action_enable_user(self, ids):
for user in User.filter(User.id.in_(ids)):
user.disabled = False
flash(f"Enabled user {user.id}")
AdminAuditLog.enable_user(current_user.id, user.id)
Session.commit()
@action( @action(
"education_upgrade", "education_upgrade",
"Education upgrade", "Education upgrade",
@ -234,20 +173,6 @@ class UserAdmin(SLModelView):
Session.commit() Session.commit()
@action(
"remove trial",
"Stop trial period",
"Remove trial for this user?",
)
def stop_trial(self, ids):
for user in User.filter(User.id.in_(ids)):
user.trial_end = None
flash(f"Stopped trial for {user}", "success")
AdminAuditLog.stop_trial(current_user.id, user.id)
Session.commit()
@action( @action(
"disable_otp_fido", "disable_otp_fido",
"Disable OTP & FIDO", "Disable OTP & FIDO",
@ -271,36 +196,6 @@ class UserAdmin(SLModelView):
Session.commit() Session.commit()
@action(
"stop_paddle_sub",
"Stop user Paddle subscription",
"This will stop the current user Paddle subscription so if user doesn't have Proton sub, they will lose all SL benefits immediately",
)
def stop_paddle_sub(self, ids):
for user in User.filter(User.id.in_(ids)):
sub: Subscription = user.get_paddle_subscription()
if not sub:
flash(f"No Paddle sub for {user}", "warning")
continue
flash(f"{user} sub will end now, instead of {sub.next_bill_date}", "info")
sub.next_bill_date = (
arrow.now().shift(days=-PADDLE_SUBSCRIPTION_GRACE_DAYS).date()
)
Session.commit()
@action(
"clear_delete_on",
"Remove scheduled deletion of user",
"This will remove the scheduled deletion for this users",
)
def clean_delete_on(self, ids):
for user in User.filter(User.id.in_(ids)):
user.delete_on = None
Session.commit()
# @action( # @action(
# "login_as", # "login_as",
# "Login as this user", # "Login as this user",
@ -363,60 +258,22 @@ def manual_upgrade(way: str, ids: [int], is_giveaway: bool):
class EmailLogAdmin(SLModelView): class EmailLogAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id"] column_searchable_list = ["id"]
column_filters = ["id", "user.email", "mailbox.email", "contact.website_email"] column_filters = ["id", "user.email", "mailbox.email", "contact.website_email"]
can_edit = False can_edit = False
can_create = False can_create = False
column_formatters = {
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
class AliasAdmin(SLModelView): class AliasAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.email", "email", "mailbox.email"] column_searchable_list = ["id", "user.email", "email", "mailbox.email"]
column_filters = ["id", "user.email", "email", "mailbox.email"] column_filters = ["id", "user.email", "email", "mailbox.email"]
column_formatters = {
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
@action(
"disable_email_spoofing_check",
"Disable email spoofing protection",
"Disable email spoofing protection?",
)
def disable_email_spoofing_check_for(self, ids):
for alias in Alias.filter(Alias.id.in_(ids)):
if alias.disable_email_spoofing_check:
flash(
f"Email spoofing protection is already disabled on {alias.email}",
"warning",
)
else:
alias.disable_email_spoofing_check = True
flash(
f"Email spoofing protection is disabled on {alias.email}", "success"
)
Session.commit()
class MailboxAdmin(SLModelView): class MailboxAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.email", "email"] column_searchable_list = ["id", "user.email", "email"]
column_filters = ["id", "user.email", "email"] column_filters = ["id", "user.email", "email"]
column_formatters = {
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
# class LifetimeCouponAdmin(SLModelView): # class LifetimeCouponAdmin(SLModelView):
# can_edit = True # can_edit = True
@ -424,26 +281,14 @@ class MailboxAdmin(SLModelView):
class CouponAdmin(SLModelView): class CouponAdmin(SLModelView):
form_base_class = SecureForm
can_edit = False can_edit = False
can_create = True can_create = True
column_formatters = {
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
class ManualSubscriptionAdmin(SLModelView): class ManualSubscriptionAdmin(SLModelView):
form_base_class = SecureForm
can_edit = True can_edit = True
column_searchable_list = ["id", "user.email"] column_searchable_list = ["id", "user.email"]
column_formatters = {
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
@action( @action(
"extend_1y", "extend_1y",
"Extend for 1 year", "Extend for 1 year",
@ -482,27 +327,15 @@ class ManualSubscriptionAdmin(SLModelView):
class CustomDomainAdmin(SLModelView): class CustomDomainAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["domain", "user.email", "user.id"] column_searchable_list = ["domain", "user.email", "user.id"]
column_exclude_list = ["ownership_txt_token"] column_exclude_list = ["ownership_txt_token"]
can_edit = False can_edit = False
column_formatters = {
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
class ReferralAdmin(SLModelView): class ReferralAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.email", "code", "name"] column_searchable_list = ["id", "user.email", "code", "name"]
column_filters = ["id", "user.email", "code", "name"] column_filters = ["id", "user.email", "code", "name"]
column_formatters = {
"created_at": _admin_date_formatter,
"updated_at": _admin_date_formatter,
}
def scaffold_list_columns(self): def scaffold_list_columns(self):
ret = super().scaffold_list_columns() ret = super().scaffold_list_columns()
ret.insert(0, "nb_user") ret.insert(0, "nb_user")
@ -518,8 +351,16 @@ class ReferralAdmin(SLModelView):
# can_delete = True # can_delete = True
def _admin_action_formatter(view, context, model, name):
action_name = AuditLogActionEnum.get_name(model.action)
return "{} ({})".format(action_name, model.action)
def _admin_created_at_formatter(view, context, model, name):
return model.created_at.format()
class AdminAuditLogAdmin(SLModelView): class AdminAuditLogAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["admin.id", "admin.email", "model_id", "created_at"] column_searchable_list = ["admin.id", "admin.email", "model_id", "created_at"]
column_filters = ["admin.id", "admin.email", "model_id", "created_at"] column_filters = ["admin.id", "admin.email", "model_id", "created_at"]
column_exclude_list = ["id"] column_exclude_list = ["id"]
@ -530,8 +371,7 @@ class AdminAuditLogAdmin(SLModelView):
column_formatters = { column_formatters = {
"action": _admin_action_formatter, "action": _admin_action_formatter,
"created_at": _admin_date_formatter, "created_at": _admin_created_at_formatter,
"updated_at": _admin_date_formatter,
} }
@ -551,7 +391,6 @@ def _transactionalcomplaint_refused_email_id_formatter(view, context, model, nam
class ProviderComplaintAdmin(SLModelView): class ProviderComplaintAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id", "user.id", "created_at"] column_searchable_list = ["id", "user.id", "created_at"]
column_filters = ["user.id", "state"] column_filters = ["user.id", "state"]
column_hide_backrefs = False column_hide_backrefs = False
@ -560,8 +399,8 @@ class ProviderComplaintAdmin(SLModelView):
can_delete = False can_delete = False
column_formatters = { column_formatters = {
"created_at": _admin_date_formatter, "created_at": _admin_created_at_formatter,
"updated_at": _admin_date_formatter, "updated_at": _admin_created_at_formatter,
"state": _transactionalcomplaint_state_formatter, "state": _transactionalcomplaint_state_formatter,
"phase": _transactionalcomplaint_phase_formatter, "phase": _transactionalcomplaint_phase_formatter,
"refused_email": _transactionalcomplaint_refused_email_id_formatter, "refused_email": _transactionalcomplaint_refused_email_id_formatter,
@ -609,217 +448,3 @@ class ProviderComplaintAdmin(SLModelView):
) )
}, },
) )
def _newsletter_plain_text_formatter(view, context, model: Newsletter, name):
# to display newsletter plain_text with linebreaks in the list view
return Markup(model.plain_text.replace("\n", "<br>"))
def _newsletter_html_formatter(view, context, model: Newsletter, name):
# to display newsletter html with linebreaks in the list view
return Markup(model.html.replace("\n", "<br>"))
class NewsletterAdmin(SLModelView):
form_base_class = SecureForm
list_template = "admin/model/newsletter-list.html"
edit_template = "admin/model/newsletter-edit.html"
edit_modal = False
can_edit = True
can_create = True
column_formatters = {
"plain_text": _newsletter_plain_text_formatter,
"html": _newsletter_html_formatter,
}
@action(
"send_newsletter_to_user",
"Send this newsletter to myself or the specified userID",
)
def send_newsletter_to_user(self, newsletter_ids):
user_id = request.form["user_id"]
if user_id:
user = User.get(user_id)
if not user:
flash(f"No such user with ID {user_id}", "error")
return
else:
flash("use the current user", "info")
user = current_user
for newsletter_id in newsletter_ids:
newsletter = Newsletter.get(newsletter_id)
sent, error_msg = send_newsletter_to_user(newsletter, user)
if sent:
flash(f"{newsletter} sent to {user}", "success")
else:
flash(error_msg, "error")
@action(
"send_newsletter_to_address",
"Send this newsletter to a specific address",
)
def send_newsletter_to_address(self, newsletter_ids):
to_address = request.form["to_address"]
if not to_address:
flash("to_address missing", "error")
return
for newsletter_id in newsletter_ids:
newsletter = Newsletter.get(newsletter_id)
# use the current_user for rendering email
sent, error_msg = send_newsletter_to_address(
newsletter, current_user, to_address
)
if sent:
flash(
f"{newsletter} sent to {to_address} with {current_user} context",
"success",
)
else:
flash(error_msg, "error")
@action(
"clone_newsletter",
"Clone this newsletter",
)
def clone_newsletter(self, newsletter_ids):
if len(newsletter_ids) != 1:
flash("you can only select 1 newsletter", "error")
return
newsletter_id = newsletter_ids[0]
newsletter: Newsletter = Newsletter.get(newsletter_id)
new_newsletter = Newsletter.create(
subject=newsletter.subject,
html=newsletter.html,
plain_text=newsletter.plain_text,
commit=True,
)
flash(f"Newsletter {new_newsletter.subject} has been cloned", "success")
class NewsletterUserAdmin(SLModelView):
form_base_class = SecureForm
column_searchable_list = ["id"]
column_filters = ["id", "user.email", "newsletter.subject"]
column_exclude_list = ["created_at", "updated_at", "id"]
can_edit = False
can_create = False
class DailyMetricAdmin(SLModelView):
form_base_class = SecureForm
column_exclude_list = ["created_at", "updated_at", "id"]
can_export = True
class MetricAdmin(SLModelView):
form_base_class = SecureForm
column_exclude_list = ["created_at", "updated_at", "id"]
can_export = True
class InvalidMailboxDomainAdmin(SLModelView):
form_base_class = SecureForm
can_create = True
can_delete = True
class EmailSearchResult:
no_match: bool = True
alias: Optional[Alias] = None
mailbox: list[Mailbox] = []
mailbox_count: int = 0
deleted_alias: Optional[DeletedAlias] = None
deleted_custom_alias: Optional[DomainDeletedAlias] = None
user: Optional[User] = None
@staticmethod
def from_email(email: str) -> EmailSearchResult:
output = EmailSearchResult()
alias = Alias.get_by(email=email)
if alias:
output.alias = alias
output.no_match = False
user = User.get_by(email=email)
if user:
output.user = user
output.no_match = False
mailboxes = (
Mailbox.filter_by(email=email).order_by(Mailbox.id.desc()).limit(10).all()
)
if mailboxes:
output.mailbox = mailboxes
output.mailbox_count = Mailbox.filter_by(email=email).count()
output.no_match = False
deleted_alias = DeletedAlias.get_by(email=email)
if deleted_alias:
output.deleted_alias = deleted_alias
output.no_match = False
domain_deleted_alias = DomainDeletedAlias.get_by(email=email)
if domain_deleted_alias:
output.domain_deleted_alias = domain_deleted_alias
output.no_match = False
return output
class EmailSearchHelpers:
@staticmethod
def mailbox_list(user: User) -> list[Mailbox]:
return (
Mailbox.filter_by(user_id=user.id)
.order_by(Mailbox.id.asc())
.limit(10)
.all()
)
@staticmethod
def mailbox_count(user: User) -> int:
return Mailbox.filter_by(user_id=user.id).order_by(Mailbox.id.desc()).count()
@staticmethod
def alias_list(user: User) -> list[Alias]:
return (
Alias.filter_by(user_id=user.id).order_by(Alias.id.desc()).limit(10).all()
)
@staticmethod
def alias_count(user: User) -> int:
return Alias.filter_by(user_id=user.id).count()
@staticmethod
def partner_user(user: User) -> Optional[PartnerUser]:
return PartnerUser.get_by(user_id=user.id)
class EmailSearchAdmin(BaseView):
def is_accessible(self):
return current_user.is_authenticated and current_user.is_admin
def inaccessible_callback(self, name, **kwargs):
# redirect to login page if user doesn't have access
flash("You don't have access to the admin page", "error")
return redirect(url_for("dashboard.index", next=request.url))
@expose("/", methods=["GET", "POST"])
def index(self):
search = EmailSearchResult()
email = request.args.get("email")
if email is not None and len(email) > 0:
email = email.strip()
search = EmailSearchResult.from_email(email)
return self.render(
"admin/email_search.html",
email=email,
data=search,
helper=EmailSearchHelpers,
)

View File

@ -1,192 +0,0 @@
from __future__ import annotations
import json
from dataclasses import asdict, dataclass
from typing import Optional
import itsdangerous
from app import config
from app.log import LOG
from app.models import User, AliasOptions, SLDomain
signer = itsdangerous.TimestampSigner(config.CUSTOM_ALIAS_SECRET)
@dataclass
class AliasSuffix:
# whether this is a custom domain
is_custom: bool
# Suffix
suffix: str
# Suffix signature
signed_suffix: str
# whether this is a premium SL domain. Not apply to custom domain
is_premium: bool
# can be either Custom or SL domain
domain: str
# if custom domain, whether the custom domain has MX verified, i.e. can receive emails
mx_verified: bool = True
def serialize(self):
return json.dumps(asdict(self))
@classmethod
def deserialize(cls, data: str) -> AliasSuffix:
return AliasSuffix(**json.loads(data))
def check_suffix_signature(signed_suffix: str) -> Optional[str]:
# hypothesis: user will click on the button in the 600 secs
try:
return signer.unsign(signed_suffix, max_age=600).decode()
except itsdangerous.BadSignature:
return None
def verify_prefix_suffix(
user: User, alias_prefix, alias_suffix, alias_options: Optional[AliasOptions] = None
) -> bool:
"""verify if user could create an alias with the given prefix and suffix"""
if not alias_prefix or not alias_suffix: # should be caught on frontend
return False
user_custom_domains = [cd.domain for cd in user.verified_custom_domains()]
# make sure alias_suffix is either .random_word@simplelogin.co or @my-domain.com
alias_suffix = alias_suffix.strip()
# alias_domain_prefix is either a .random_word or ""
alias_domain_prefix, alias_domain = alias_suffix.split("@", 1)
# alias_domain must be either one of user custom domains or built-in domains
if alias_domain not in user.available_alias_domains(alias_options=alias_options):
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
# SimpleLogin domain case:
# 1) alias_suffix must start with "." and
# 2) alias_domain_prefix must come from the word list
available_sl_domains = [
sl_domain.domain
for sl_domain in user.get_sl_domains(alias_options=alias_options)
]
if (
alias_domain in available_sl_domains
and alias_domain not in user_custom_domains
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
and not config.DISABLE_ALIAS_SUFFIX
):
if not alias_domain_prefix.startswith("."):
LOG.e("User %s submits a wrong alias suffix %s", user, alias_suffix)
return False
else:
if alias_domain not in user_custom_domains:
if not config.DISABLE_ALIAS_SUFFIX:
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
if alias_domain not in available_sl_domains:
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
return True
def get_alias_suffixes(
user: User, alias_options: Optional[AliasOptions] = None
) -> [AliasSuffix]:
"""
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
"""
user_custom_domains = user.verified_custom_domains()
alias_suffixes: [AliasSuffix] = []
# put custom domain first
# for each user domain, generate both the domain and a random suffix version
for custom_domain in user_custom_domains:
if custom_domain.random_prefix_generation:
suffix = (
f".{user.get_random_alias_suffix(custom_domain)}@{custom_domain.domain}"
)
alias_suffix = AliasSuffix(
is_custom=True,
suffix=suffix,
signed_suffix=signer.sign(suffix).decode(),
is_premium=False,
domain=custom_domain.domain,
mx_verified=custom_domain.verified,
)
if user.default_alias_custom_domain_id == custom_domain.id:
alias_suffixes.insert(0, alias_suffix)
else:
alias_suffixes.append(alias_suffix)
suffix = f"@{custom_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=True,
suffix=suffix,
signed_suffix=signer.sign(suffix).decode(),
is_premium=False,
domain=custom_domain.domain,
mx_verified=custom_domain.verified,
)
# put the default domain to top
# only if random_prefix_generation isn't enabled
if (
user.default_alias_custom_domain_id == custom_domain.id
and not custom_domain.random_prefix_generation
):
alias_suffixes.insert(0, alias_suffix)
else:
alias_suffixes.append(alias_suffix)
# then SimpleLogin domain
sl_domains = user.get_sl_domains(alias_options=alias_options)
default_domain_found = False
for sl_domain in sl_domains:
prefix = (
"" if config.DISABLE_ALIAS_SUFFIX else f".{user.get_random_alias_suffix()}"
)
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=False,
suffix=suffix,
signed_suffix=signer.sign(suffix).decode(),
is_premium=sl_domain.premium_only,
domain=sl_domain.domain,
mx_verified=True,
)
# No default or this is not the default
if (
user.default_alias_public_domain_id is None
or user.default_alias_public_domain_id != sl_domain.id
):
alias_suffixes.append(alias_suffix)
else:
default_domain_found = True
alias_suffixes.insert(0, alias_suffix)
if not default_domain_found:
domain_conditions = {"id": user.default_alias_public_domain_id, "hidden": False}
if not user.is_premium():
domain_conditions["premium_only"] = False
sl_domain = SLDomain.get_by(**domain_conditions)
if sl_domain:
prefix = (
""
if config.DISABLE_ALIAS_SUFFIX
else f".{user.get_random_alias_suffix()}"
)
suffix = f"{prefix}@{sl_domain.domain}"
alias_suffix = AliasSuffix(
is_custom=False,
suffix=suffix,
signed_suffix=signer.sign(suffix).decode(),
is_premium=sl_domain.premium_only,
domain=sl_domain.domain,
mx_verified=True,
)
alias_suffixes.insert(0, alias_suffix)
return alias_suffixes

View File

@ -1,11 +1,8 @@
import csv
from io import StringIO
import re import re
from typing import Optional, Tuple from typing import Optional, Tuple
from email_validator import validate_email, EmailNotValidError from email_validator import validate_email, EmailNotValidError
from sqlalchemy.exc import IntegrityError, DataError from sqlalchemy.exc import IntegrityError, DataError
from flask import make_response
from app.config import ( from app.config import (
BOUNCE_PREFIX_FOR_REPLY_PHASE, BOUNCE_PREFIX_FOR_REPLY_PHASE,
@ -21,20 +18,11 @@ from app.email_utils import (
send_cannot_create_directory_alias_disabled, send_cannot_create_directory_alias_disabled,
get_email_local_part, get_email_local_part,
send_cannot_create_domain_alias, send_cannot_create_domain_alias,
send_email,
render,
) )
from app.errors import AliasInTrashError from app.errors import AliasInTrashError
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import (
AliasDeleted,
AliasStatusChanged,
EventContent,
)
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
AliasDeleteReason,
CustomDomain, CustomDomain,
Directory, Directory,
User, User,
@ -45,8 +33,6 @@ from app.models import (
EmailLog, EmailLog,
Contact, Contact,
AutoCreateRule, AutoCreateRule,
AliasUsedOn,
ClientUser,
) )
from app.regex_utils import regex_match from app.regex_utils import regex_match
@ -63,17 +49,11 @@ def get_user_if_alias_would_auto_create(
# Prevent addresses with unicode characters (🤯) in them for now. # Prevent addresses with unicode characters (🤯) in them for now.
validate_email(address, check_deliverability=False, allow_smtputf8=False) validate_email(address, check_deliverability=False, allow_smtputf8=False)
except EmailNotValidError: except EmailNotValidError:
LOG.i(f"Not creating alias for {address} because email is invalid")
return None return None
domain_and_rule = check_if_alias_can_be_auto_created_for_custom_domain( domain_and_rule = check_if_alias_can_be_auto_created_for_custom_domain(
address, notify_user=notify_user address, notify_user=notify_user
) )
if DomainDeletedAlias.get_by(email=address):
LOG.i(
f"Not creating alias for {address} because it was previously deleted for this domain"
)
return None
if domain_and_rule: if domain_and_rule:
return domain_and_rule[0].user return domain_and_rule[0].user
directory = check_if_alias_can_be_auto_created_for_a_directory( directory = check_if_alias_can_be_auto_created_for_a_directory(
@ -97,9 +77,6 @@ def check_if_alias_can_be_auto_created_for_custom_domain(
custom_domain: CustomDomain = CustomDomain.get_by(domain=alias_domain) custom_domain: CustomDomain = CustomDomain.get_by(domain=alias_domain)
if not custom_domain: if not custom_domain:
LOG.i(
f"Cannot auto-create custom domain alias for {address} because there's no custom domain for {alias_domain}"
)
return None return None
user: User = custom_domain.user user: User = custom_domain.user
@ -108,16 +85,12 @@ def check_if_alias_can_be_auto_created_for_custom_domain(
return None return None
if not user.can_create_new_alias(): if not user.can_create_new_alias():
LOG.d(f"{user} can't create new custom-domain alias {address}")
if notify_user: if notify_user:
send_cannot_create_domain_alias(custom_domain.user, address, alias_domain) send_cannot_create_domain_alias(custom_domain.user, address, alias_domain)
return None return None
if not custom_domain.catch_all: if not custom_domain.catch_all:
if len(custom_domain.auto_create_rules) == 0: if len(custom_domain.auto_create_rules) == 0:
LOG.i(
f"Cannot create alias {address} for domain {custom_domain} because it has no catch-all and no rules"
)
return None return None
local = get_email_local_part(address) local = get_email_local_part(address)
@ -131,7 +104,7 @@ def check_if_alias_can_be_auto_created_for_custom_domain(
) )
return custom_domain, rule return custom_domain, rule
else: # no rule passes else: # no rule passes
LOG.d(f"No rule matches auto-create {address} for domain {custom_domain}") LOG.d("no rule passed to create %s", local)
return None return None
LOG.d("Create alias via catchall") LOG.d("Create alias via catchall")
@ -158,7 +131,6 @@ def check_if_alias_can_be_auto_created_for_a_directory(
sep = "#" sep = "#"
else: else:
# if there's no directory separator in the alias, no way to auto-create it # if there's no directory separator in the alias, no way to auto-create it
LOG.info(f"Cannot auto-create {address} since it has no directory separator")
return None return None
directory_name = address[: address.find(sep)] directory_name = address[: address.find(sep)]
@ -166,9 +138,6 @@ def check_if_alias_can_be_auto_created_for_a_directory(
directory = Directory.get_by(name=directory_name) directory = Directory.get_by(name=directory_name)
if not directory: if not directory:
LOG.info(
f"Cannot auto-create {address} because there is no directory for {directory_name}"
)
return None return None
user: User = directory.user user: User = directory.user
@ -177,17 +146,11 @@ def check_if_alias_can_be_auto_created_for_a_directory(
return None return None
if not user.can_create_new_alias(): if not user.can_create_new_alias():
LOG.d(
f"{user} can't create new directory alias {address} because user cannot create aliases"
)
if notify_user: if notify_user:
send_cannot_create_directory_alias(user, address, directory_name) send_cannot_create_directory_alias(user, address, directory_name)
return None return None
if directory.disabled: if directory.disabled:
LOG.d(
f"{user} can't create new directory alias {address} bcause directory is disabled"
)
if notify_user: if notify_user:
send_cannot_create_directory_alias_disabled(user, address, directory_name) send_cannot_create_directory_alias_disabled(user, address, directory_name)
return None return None
@ -329,52 +292,36 @@ def try_auto_create_via_domain(address: str) -> Optional[Alias]:
return None return None
def delete_alias( def delete_alias(alias: Alias, user: User):
alias: Alias,
user: User,
reason: AliasDeleteReason = AliasDeleteReason.Unspecified,
commit: bool = False,
):
""" """
Delete an alias and add it to either global or domain trash Delete an alias and add it to either global or domain trash
Should be used instead of Alias.delete, DomainDeletedAlias.create, DeletedAlias.create Should be used instead of Alias.delete, DomainDeletedAlias.create, DeletedAlias.create
""" """
LOG.i(f"User {user} has deleted alias {alias}") # save deleted alias to either global or domain trash
# save deleted alias to either global or domain tra
if alias.custom_domain_id: if alias.custom_domain_id:
if not DomainDeletedAlias.get_by( if not DomainDeletedAlias.get_by(
email=alias.email, domain_id=alias.custom_domain_id email=alias.email, domain_id=alias.custom_domain_id
): ):
domain_deleted_alias = DomainDeletedAlias( LOG.d("add %s to domain %s trash", alias, alias.custom_domain_id)
user_id=user.id, Session.add(
email=alias.email, DomainDeletedAlias(
domain_id=alias.custom_domain_id, user_id=user.id,
reason=reason, email=alias.email,
domain_id=alias.custom_domain_id,
)
) )
Session.add(domain_deleted_alias)
Session.commit() Session.commit()
LOG.i(
f"Moving {alias} to domain {alias.custom_domain_id} trash {domain_deleted_alias}"
)
else: else:
if not DeletedAlias.get_by(email=alias.email): if not DeletedAlias.get_by(email=alias.email):
deleted_alias = DeletedAlias(email=alias.email, reason=reason) LOG.d("add %s to global trash", alias)
Session.add(deleted_alias) Session.add(DeletedAlias(email=alias.email))
Session.commit() Session.commit()
LOG.i(f"Moving {alias} to global trash {deleted_alias}")
alias_id = alias.id LOG.i("delete alias %s", alias)
alias_email = alias.email
Alias.filter(Alias.id == alias.id).delete() Alias.filter(Alias.id == alias.id).delete()
Session.commit() Session.commit()
EventDispatcher.send_event(
user,
EventContent(alias_deleted=AliasDeleted(id=alias_id, email=alias_email)),
)
if commit:
Session.commit()
def aliases_for_mailbox(mailbox: Mailbox) -> [Alias]: def aliases_for_mailbox(mailbox: Mailbox) -> [Alias]:
""" """
@ -415,106 +362,3 @@ def check_alias_prefix(alias_prefix) -> bool:
return False return False
return True return True
def alias_export_csv(user, csv_direct_export=False):
"""
Get user aliases as importable CSV file
Output:
Importable CSV file
"""
data = [["alias", "note", "enabled", "mailboxes"]]
for alias in Alias.filter_by(user_id=user.id).all(): # type: Alias
# Always put the main mailbox first
# It is seen a primary while importing
alias_mailboxes = alias.mailboxes
alias_mailboxes.insert(
0, alias_mailboxes.pop(alias_mailboxes.index(alias.mailbox))
)
mailboxes = " ".join([mailbox.email for mailbox in alias_mailboxes])
data.append([alias.email, alias.note, alias.enabled, mailboxes])
si = StringIO()
cw = csv.writer(si)
cw.writerows(data)
if csv_direct_export:
return si.getvalue()
output = make_response(si.getvalue())
output.headers["Content-Disposition"] = "attachment; filename=aliases.csv"
output.headers["Content-type"] = "text/csv"
return output
def transfer_alias(alias, new_user, new_mailboxes: [Mailbox]):
# cannot transfer alias which is used for receiving newsletter
if User.get_by(newsletter_alias_id=alias.id):
raise Exception("Cannot transfer alias that's used to receive newsletter")
# update user_id
Session.query(Contact).filter(Contact.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update(
{"user_id": new_user.id}
)
# remove existing mailboxes from the alias
Session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete()
# set mailboxes
alias.mailbox_id = new_mailboxes.pop().id
for mb in new_mailboxes:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id)
# alias has never been transferred before
if not alias.original_owner_id:
alias.original_owner_id = alias.user_id
# inform previous owner
old_user = alias.user
send_email(
old_user.email,
f"Alias {alias.email} has been received",
render(
"transactional/alias-transferred.txt",
user=old_user,
alias=alias,
),
render(
"transactional/alias-transferred.html",
user=old_user,
alias=alias,
),
)
# now the alias belongs to the new user
alias.user_id = new_user.id
# set some fields back to default
alias.disable_pgp = False
alias.pinned = False
Session.commit()
def change_alias_status(alias: Alias, enabled: bool, commit: bool = False):
LOG.i(f"Changing alias {alias} enabled to {enabled}")
alias.enabled = enabled
event = AliasStatusChanged(
id=alias.id,
email=alias.email,
enabled=enabled,
created_at=int(alias.created_at.timestamp),
)
EventDispatcher.send_event(alias.user, EventContent(alias_status_change=event))
if commit:
Session.commit()

View File

@ -14,24 +14,4 @@ from .views import (
export, export,
phone, phone,
sudo, sudo,
user,
) )
__all__ = [
"alias_options",
"new_custom_alias",
"custom_domain",
"new_random_alias",
"user_info",
"auth",
"auth_mfa",
"alias",
"apple",
"mailbox",
"notification",
"setting",
"export",
"phone",
"sudo",
"user",
]

View File

@ -19,9 +19,6 @@ def authorize_request() -> Optional[Tuple[str, int]]:
if not api_key: if not api_key:
if current_user.is_authenticated: if current_user.is_authenticated:
# if current_user.is_authenticated and request.headers.get(
# constants.HEADER_ALLOW_API_COOKIES
# ):
g.user = current_user g.user = current_user
else: else:
return jsonify(error="Wrong api key"), 401 return jsonify(error="Wrong api key"), 401
@ -36,9 +33,6 @@ def authorize_request() -> Optional[Tuple[str, int]]:
if g.user.disabled: if g.user.disabled:
return jsonify(error="Disabled account"), 403 return jsonify(error="Disabled account"), 403
if not g.user.is_active():
return jsonify(error="Account does not exist"), 401
g.api_key = api_key g.api_key = api_key
return None return None

View File

@ -201,10 +201,10 @@ def get_alias_infos_with_pagination_v3(
q = q.order_by(Alias.pinned.desc()) q = q.order_by(Alias.pinned.desc())
q = q.order_by(latest_activity.desc()) q = q.order_by(latest_activity.desc())
q = q.limit(page_limit).offset(page_id * page_size) q = list(q.limit(page_limit).offset(page_id * page_size))
ret = [] ret = []
for alias, contact, email_log, nb_reply, nb_blocked, nb_forward in list(q): for alias, contact, email_log, nb_reply, nb_blocked, nb_forward in q:
ret.append( ret.append(
AliasInfo( AliasInfo(
alias=alias, alias=alias,
@ -358,6 +358,7 @@ def construct_alias_query(user: User):
else_=0, else_=0,
) )
).label("nb_forward"), ).label("nb_forward"),
func.max(EmailLog.created_at).label("latest_email_log_created_at"),
) )
.join(EmailLog, Alias.id == EmailLog.alias_id, isouter=True) .join(EmailLog, Alias.id == EmailLog.alias_id, isouter=True)
.filter(Alias.user_id == user.id) .filter(Alias.user_id == user.id)
@ -365,6 +366,14 @@ def construct_alias_query(user: User):
.subquery() .subquery()
) )
alias_contact_subquery = (
Session.query(Alias.id, func.max(Contact.id).label("max_contact_id"))
.join(Contact, Alias.id == Contact.alias_id, isouter=True)
.filter(Alias.user_id == user.id)
.group_by(Alias.id)
.subquery()
)
return ( return (
Session.query( Session.query(
Alias, Alias,
@ -376,7 +385,23 @@ def construct_alias_query(user: User):
) )
.options(joinedload(Alias.hibp_breaches)) .options(joinedload(Alias.hibp_breaches))
.options(joinedload(Alias.custom_domain)) .options(joinedload(Alias.custom_domain))
.join(EmailLog, Alias.last_email_log_id == EmailLog.id, isouter=True) .join(Contact, Alias.id == Contact.alias_id, isouter=True)
.join(Contact, EmailLog.contact_id == Contact.id, isouter=True) .join(EmailLog, Contact.id == EmailLog.contact_id, isouter=True)
.filter(Alias.id == alias_activity_subquery.c.id) .filter(Alias.id == alias_activity_subquery.c.id)
.filter(Alias.id == alias_contact_subquery.c.id)
.filter(
or_(
EmailLog.created_at
== alias_activity_subquery.c.latest_email_log_created_at,
and_(
# no email log yet for this alias
alias_activity_subquery.c.latest_email_log_created_at.is_(None),
# to make sure only 1 contact is returned in this case
or_(
Contact.id == alias_contact_subquery.c.max_contact_id,
alias_contact_subquery.c.max_contact_id.is_(None),
),
),
)
)
) )

View File

@ -24,15 +24,12 @@ from app.errors import (
ErrContactAlreadyExists, ErrContactAlreadyExists,
ErrAddressInvalid, ErrAddressInvalid,
) )
from app.extensions import limiter from app.models import Alias, Contact, Mailbox, AliasMailbox
from app.log import LOG
from app.models import Alias, Contact, Mailbox, AliasMailbox, AliasDeleteReason
@deprecated @deprecated
@api_bp.route("/aliases", methods=["GET", "POST"]) @api_bp.route("/aliases", methods=["GET", "POST"])
@require_api_auth @require_api_auth
@limiter.limit("10/minute", key_func=lambda: g.user.id)
def get_aliases(): def get_aliases():
""" """
Get aliases Get aliases
@ -75,7 +72,6 @@ def get_aliases():
@api_bp.route("/v2/aliases", methods=["GET", "POST"]) @api_bp.route("/v2/aliases", methods=["GET", "POST"])
@require_api_auth @require_api_auth
@limiter.limit("50/minute", key_func=lambda: g.user.id)
def get_aliases_v2(): def get_aliases_v2():
""" """
Get aliases Get aliases
@ -161,7 +157,7 @@ def delete_alias(alias_id):
if not alias or alias.user_id != user.id: if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
alias_utils.delete_alias(alias, user, AliasDeleteReason.ManualAction) alias_utils.delete_alias(alias, user)
return jsonify(deleted=True), 200 return jsonify(deleted=True), 200
@ -185,8 +181,7 @@ def toggle_alias(alias_id):
if not alias or alias.user_id != user.id: if not alias or alias.user_id != user.id:
return jsonify(error="Forbidden"), 403 return jsonify(error="Forbidden"), 403
alias_utils.change_alias_status(alias, enabled=not alias.enabled) alias.enabled = not alias.enabled
LOG.i(f"User {user} changed alias {alias} enabled status to {alias.enabled}")
Session.commit() Session.commit()
return jsonify(enabled=alias.enabled), 200 return jsonify(enabled=alias.enabled), 200
@ -424,7 +419,7 @@ def create_contact_route(alias_id):
contact_address = data.get("contact") contact_address = data.get("contact")
try: try:
contact = create_contact(alias, contact_address) contact = create_contact(g.user, alias, contact_address)
except ErrContactErrorUpgradeNeeded as err: except ErrContactErrorUpgradeNeeded as err:
return jsonify(error=err.error_for_user()), 403 return jsonify(error=err.error_for_user()), 403
except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err: except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err:

View File

@ -2,8 +2,10 @@ import tldextract
from flask import jsonify, request, g from flask import jsonify, request, g
from sqlalchemy import desc from sqlalchemy import desc
from app.alias_suffix import get_alias_suffixes
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.dashboard.views.custom_alias import (
get_available_suffixes,
)
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import AliasUsedOn, Alias, User from app.models import AliasUsedOn, Alias, User
@ -66,7 +68,7 @@ def options_v4():
prefix_suggestion = convert_to_id(prefix_suggestion) prefix_suggestion = convert_to_id(prefix_suggestion)
ret["prefix_suggestion"] = prefix_suggestion ret["prefix_suggestion"] = prefix_suggestion
suffixes = get_alias_suffixes(user) suffixes = get_available_suffixes(user)
# custom domain should be put first # custom domain should be put first
ret["suffixes"] = list([suffix.suffix, suffix.signed_suffix] for suffix in suffixes) ret["suffixes"] = list([suffix.suffix, suffix.signed_suffix] for suffix in suffixes)
@ -137,7 +139,7 @@ def options_v5():
prefix_suggestion = convert_to_id(prefix_suggestion) prefix_suggestion = convert_to_id(prefix_suggestion)
ret["prefix_suggestion"] = prefix_suggestion ret["prefix_suggestion"] = prefix_suggestion
suffixes = get_alias_suffixes(user) suffixes = get_available_suffixes(user)
# custom domain should be put first # custom domain should be put first
ret["suffixes"] = [ ret["suffixes"] = [

View File

@ -9,7 +9,6 @@ from requests import RequestException
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import APPLE_API_SECRET, MACAPP_APPLE_API_SECRET from app.config import APPLE_API_SECRET, MACAPP_APPLE_API_SECRET
from app.subscription_webhook import execute_subscription_webhook
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import PlanEnum, AppleSubscription from app.models import PlanEnum, AppleSubscription
@ -17,14 +16,9 @@ from app.models import PlanEnum, AppleSubscription
_MONTHLY_PRODUCT_ID = "io.simplelogin.ios_app.subscription.premium.monthly" _MONTHLY_PRODUCT_ID = "io.simplelogin.ios_app.subscription.premium.monthly"
_YEARLY_PRODUCT_ID = "io.simplelogin.ios_app.subscription.premium.yearly" _YEARLY_PRODUCT_ID = "io.simplelogin.ios_app.subscription.premium.yearly"
# SL Mac app used to be in SL account
_MACAPP_MONTHLY_PRODUCT_ID = "io.simplelogin.macapp.subscription.premium.monthly" _MACAPP_MONTHLY_PRODUCT_ID = "io.simplelogin.macapp.subscription.premium.monthly"
_MACAPP_YEARLY_PRODUCT_ID = "io.simplelogin.macapp.subscription.premium.yearly" _MACAPP_YEARLY_PRODUCT_ID = "io.simplelogin.macapp.subscription.premium.yearly"
# SL Mac app is moved to Proton account
_MACAPP_MONTHLY_PRODUCT_ID_NEW = "me.proton.simplelogin.macos.premium.monthly"
_MACAPP_YEARLY_PRODUCT_ID_NEW = "me.proton.simplelogin.macos.premium.yearly"
# Apple API URL # Apple API URL
_SANDBOX_URL = "https://sandbox.itunes.apple.com/verifyReceipt" _SANDBOX_URL = "https://sandbox.itunes.apple.com/verifyReceipt"
_PROD_URL = "https://buy.itunes.apple.com/verifyReceipt" _PROD_URL = "https://buy.itunes.apple.com/verifyReceipt"
@ -46,17 +40,15 @@ def apple_process_payment():
LOG.d("request for /apple/process_payment from %s", user) LOG.d("request for /apple/process_payment from %s", user)
data = request.get_json() data = request.get_json()
receipt_data = data.get("receipt_data") receipt_data = data.get("receipt_data")
is_macapp = "is_macapp" in data and data["is_macapp"] is True is_macapp = "is_macapp" in data
if is_macapp: if is_macapp:
LOG.d("Use Macapp secret")
password = MACAPP_APPLE_API_SECRET password = MACAPP_APPLE_API_SECRET
else: else:
password = APPLE_API_SECRET password = APPLE_API_SECRET
apple_sub = verify_receipt(receipt_data, user, password) apple_sub = verify_receipt(receipt_data, user, password)
if apple_sub: if apple_sub:
execute_subscription_webhook(user)
return jsonify(ok=True), 200 return jsonify(ok=True), 200
return jsonify(error="Processing failed"), 400 return jsonify(error="Processing failed"), 400
@ -268,11 +260,7 @@ def apple_update_notification():
plan = ( plan = (
PlanEnum.monthly PlanEnum.monthly
if transaction["product_id"] if transaction["product_id"]
in ( in (_MONTHLY_PRODUCT_ID, _MACAPP_MONTHLY_PRODUCT_ID)
_MONTHLY_PRODUCT_ID,
_MACAPP_MONTHLY_PRODUCT_ID,
_MACAPP_MONTHLY_PRODUCT_ID_NEW,
)
else PlanEnum.yearly else PlanEnum.yearly
) )
@ -293,7 +281,6 @@ def apple_update_notification():
apple_sub.plan = plan apple_sub.plan = plan
apple_sub.product_id = transaction["product_id"] apple_sub.product_id = transaction["product_id"]
Session.commit() Session.commit()
execute_subscription_webhook(user)
return jsonify(ok=True), 200 return jsonify(ok=True), 200
else: else:
LOG.w( LOG.w(
@ -487,7 +474,7 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
# } # }
if data["status"] != 0: if data["status"] != 0:
LOG.e( LOG.w(
"verifyReceipt status !=0, probably invalid receipt. User %s, data %s", "verifyReceipt status !=0, probably invalid receipt. User %s, data %s",
user, user,
data, data,
@ -526,11 +513,7 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
plan = ( plan = (
PlanEnum.monthly PlanEnum.monthly
if latest_transaction["product_id"] if latest_transaction["product_id"]
in ( in (_MONTHLY_PRODUCT_ID, _MACAPP_MONTHLY_PRODUCT_ID)
_MONTHLY_PRODUCT_ID,
_MACAPP_MONTHLY_PRODUCT_ID,
_MACAPP_MONTHLY_PRODUCT_ID_NEW,
)
else PlanEnum.yearly else PlanEnum.yearly
) )
@ -538,10 +521,9 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
if apple_sub: if apple_sub:
LOG.d( LOG.d(
"Update AppleSubscription for user %s, expired at %s (%s), plan %s", "Update AppleSubscription for user %s, expired at %s, plan %s",
user, user,
expires_date, expires_date,
expires_date.humanize(),
plan, plan,
) )
apple_sub.receipt_data = receipt_data apple_sub.receipt_data = receipt_data
@ -570,7 +552,6 @@ def verify_receipt(receipt_data, user, password) -> Optional[AppleSubscription]:
product_id=latest_transaction["product_id"], product_id=latest_transaction["product_id"],
) )
execute_subscription_webhook(user)
Session.commit() Session.commit()
return apple_sub return apple_sub

View File

@ -11,7 +11,7 @@ from itsdangerous import Signer
from app import email_utils from app import email_utils
from app.api.base import api_bp from app.api.base import api_bp
from app.config import FLASK_SECRET, DISABLE_REGISTRATION from app.config import FLASK_SECRET, DISABLE_REGISTRATION
from app.dashboard.views.account_setting import send_reset_password_email from app.dashboard.views.setting import send_reset_password_email
from app.db import Session from app.db import Session
from app.email_utils import ( from app.email_utils import (
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
@ -23,7 +23,7 @@ from app.events.auth_event import LoginEvent, RegisterEvent
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User, ApiKey, SocialAuth, AccountActivation from app.models import User, ApiKey, SocialAuth, AccountActivation
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
@api_bp.route("/auth/login", methods=["POST"]) @api_bp.route("/auth/login", methods=["POST"])
@ -49,13 +49,11 @@ def auth_login():
if not data: if not data:
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
email = sanitize_email(data.get("email"))
password = data.get("password") password = data.get("password")
device = data.get("device") device = data.get("device")
email = sanitize_email(data.get("email")) user = User.filter_by(email=email).first()
canonical_email = canonicalize_email(data.get("email"))
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if not user or not user.check_password(password): if not user or not user.check_password(password):
LoginEvent(LoginEvent.ActionType.failed, LoginEvent.Source.api).send() LoginEvent(LoginEvent.ActionType.failed, LoginEvent.Source.api).send()
@ -63,11 +61,6 @@ def auth_login():
elif user.disabled: elif user.disabled:
LoginEvent(LoginEvent.ActionType.disabled_login, LoginEvent.Source.api).send() LoginEvent(LoginEvent.ActionType.disabled_login, LoginEvent.Source.api).send()
return jsonify(error="Account disabled"), 400 return jsonify(error="Account disabled"), 400
elif user.delete_on is not None:
LoginEvent(
LoginEvent.ActionType.scheduled_to_be_deleted, LoginEvent.Source.api
).send()
return jsonify(error="Account scheduled for deletion"), 400
elif not user.activated: elif not user.activated:
LoginEvent(LoginEvent.ActionType.not_activated, LoginEvent.Source.api).send() LoginEvent(LoginEvent.ActionType.not_activated, LoginEvent.Source.api).send()
return jsonify(error="Account not activated"), 422 return jsonify(error="Account not activated"), 422
@ -96,8 +89,7 @@ def auth_register():
if not data: if not data:
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
dirty_email = data.get("email") email = sanitize_email(data.get("email"))
email = canonicalize_email(dirty_email)
password = data.get("password") password = data.get("password")
if DISABLE_REGISTRATION: if DISABLE_REGISTRATION:
@ -118,7 +110,7 @@ def auth_register():
return jsonify(error="password too long"), 400 return jsonify(error="password too long"), 400
LOG.d("create user %s", email) LOG.d("create user %s", email)
user = User.create(email=email, name=dirty_email, password=password) user = User.create(email=email, name="", password=password)
Session.flush() Session.flush()
# create activation code # create activation code
@ -129,8 +121,8 @@ def auth_register():
send_email( send_email(
email, email,
"Just one more step to join SimpleLogin", "Just one more step to join SimpleLogin",
render("transactional/code-activation.txt.jinja2", user=user, code=code), render("transactional/code-activation.txt.jinja2", code=code),
render("transactional/code-activation.html", user=user, code=code), render("transactional/code-activation.html", code=code),
) )
RegisterEvent(RegisterEvent.ActionType.success, RegisterEvent.Source.api).send() RegisterEvent(RegisterEvent.ActionType.success, RegisterEvent.Source.api).send()
@ -156,10 +148,9 @@ def auth_activate():
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
email = sanitize_email(data.get("email")) email = sanitize_email(data.get("email"))
canonical_email = canonicalize_email(data.get("email"))
code = data.get("code") code = data.get("code")
user = User.get_by(email=email) or User.get_by(email=canonical_email) user = User.get_by(email=email)
# do not use a different message to avoid exposing existing email # do not use a different message to avoid exposing existing email
if not user or user.activated: if not user or user.activated:
@ -205,9 +196,7 @@ def auth_reactivate():
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
email = sanitize_email(data.get("email")) email = sanitize_email(data.get("email"))
canonical_email = canonicalize_email(data.get("email")) user = User.get_by(email=email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
# do not use a different message to avoid exposing existing email # do not use a different message to avoid exposing existing email
if not user or user.activated: if not user or user.activated:
@ -226,8 +215,8 @@ def auth_reactivate():
send_email( send_email(
email, email,
"Just one more step to join SimpleLogin", "Just one more step to join SimpleLogin",
render("transactional/code-activation.txt.jinja2", user=user, code=code), render("transactional/code-activation.txt.jinja2", code=code),
render("transactional/code-activation.html", user=user, code=code), render("transactional/code-activation.html", code=code),
) )
return jsonify(msg="User needs to confirm their account"), 200 return jsonify(msg="User needs to confirm their account"), 200
@ -362,7 +351,7 @@ def auth_payload(user, device) -> dict:
@api_bp.route("/auth/forgot_password", methods=["POST"]) @api_bp.route("/auth/forgot_password", methods=["POST"])
@limiter.limit("2/minute") @limiter.limit("10/minute")
def forgot_password(): def forgot_password():
""" """
User forgot password User forgot password
@ -378,9 +367,8 @@ def forgot_password():
return jsonify(error="request body must contain email"), 400 return jsonify(error="request body must contain email"), 400
email = sanitize_email(data.get("email")) email = sanitize_email(data.get("email"))
canonical_email = canonicalize_email(data.get("email"))
user = User.get_by(email=email) or User.get_by(email=canonical_email) user = User.get_by(email=email)
if user: if user:
send_reset_password_email(user) send_reset_password_email(user)

View File

@ -55,7 +55,7 @@ def auth_mfa():
) )
totp = pyotp.TOTP(user.otp_secret) totp = pyotp.TOTP(user.otp_secret)
if not totp.verify(mfa_token, valid_window=2): if not totp.verify(mfa_token):
send_invalid_totp_login_email(user, "TOTP") send_invalid_totp_login_email(user, "TOTP")
return jsonify(error="Wrong TOTP Token"), 400 return jsonify(error="Wrong TOTP Token"), 400

View File

@ -1,9 +1,12 @@
import csv
from io import StringIO
from flask import g from flask import g
from flask import jsonify from flask import jsonify
from flask import make_response
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.models import Alias, Client, CustomDomain from app.models import Alias, Client, CustomDomain
from app.alias_utils import alias_export_csv
@api_bp.route("/export/data", methods=["GET"]) @api_bp.route("/export/data", methods=["GET"])
@ -46,4 +49,24 @@ def export_aliases():
Importable CSV file Importable CSV file
""" """
return alias_export_csv(g.user) user = g.user
data = [["alias", "note", "enabled", "mailboxes"]]
for alias in Alias.filter_by(user_id=user.id).all(): # type: Alias
# Always put the main mailbox first
# It is seen a primary while importing
alias_mailboxes = alias.mailboxes
alias_mailboxes.insert(
0, alias_mailboxes.pop(alias_mailboxes.index(alias.mailbox))
)
mailboxes = " ".join([mailbox.email for mailbox in alias_mailboxes])
data.append([alias.email, alias.note, alias.enabled, mailboxes])
si = StringIO()
cw = csv.writer(si)
cw.writerows(data)
output = make_response(si.getvalue())
output.headers["Content-Disposition"] = "attachment; filename=aliases.csv"
output.headers["Content-type"] = "text/csv"
return output

View File

@ -1,18 +1,22 @@
from smtplib import SMTPRecipientsRefused from smtplib import SMTPRecipientsRefused
import arrow
from flask import g from flask import g
from flask import jsonify from flask import jsonify
from flask import request from flask import request
from app import mailbox_utils
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import JOB_DELETE_MAILBOX
from app.dashboard.views.mailbox import send_verification_email
from app.dashboard.views.mailbox_detail import verify_mailbox_change from app.dashboard.views.mailbox_detail import verify_mailbox_change
from app.db import Session from app.db import Session
from app.email_utils import ( from app.email_utils import (
mailbox_already_used, mailbox_already_used,
email_can_be_used_as_mailbox, email_can_be_used_as_mailbox,
is_valid_email,
) )
from app.models import Mailbox from app.log import LOG
from app.models import Mailbox, Job
from app.utils import sanitize_email from app.utils import sanitize_email
@ -40,48 +44,66 @@ def create_mailbox():
user = g.user user = g.user
mailbox_email = sanitize_email(request.get_json().get("email")) mailbox_email = sanitize_email(request.get_json().get("email"))
try: if not user.is_premium():
new_mailbox = mailbox_utils.create_mailbox(user, mailbox_email).mailbox return jsonify(error=f"Only premium plan can add additional mailbox"), 400
except mailbox_utils.MailboxError as e:
return jsonify(error=e.msg), 400
return ( if not is_valid_email(mailbox_email):
jsonify(mailbox_to_dict(new_mailbox)), return jsonify(error=f"{mailbox_email} invalid"), 400
201, elif mailbox_already_used(mailbox_email, user):
) return jsonify(error=f"{mailbox_email} already used"), 400
elif not email_can_be_used_as_mailbox(mailbox_email):
return (
jsonify(
error=f"{mailbox_email} cannot be used. Please note a mailbox cannot "
f"be a disposable email address"
),
400,
)
else:
new_mailbox = Mailbox.create(email=mailbox_email, user_id=user.id)
Session.commit()
send_verification_email(user, new_mailbox)
return (
jsonify(mailbox_to_dict(new_mailbox)),
201,
)
@api_bp.route("/mailboxes/<int:mailbox_id>", methods=["DELETE"]) @api_bp.route("/mailboxes/<mailbox_id>", methods=["DELETE"])
@require_api_auth @require_api_auth
def delete_mailbox(mailbox_id): def delete_mailbox(mailbox_id):
""" """
Delete mailbox Delete mailbox
Input: Input:
mailbox_id: in url mailbox_id: in url
(optional) transfer_aliases_to: in body. Id of the new mailbox for the aliases.
If omitted or the value is set to -1,
the aliases of the mailbox will be deleted too.
Output: Output:
200 if deleted successfully 200 if deleted successfully
""" """
user = g.user user = g.user
data = request.get_json() or {} mailbox = Mailbox.get(mailbox_id)
transfer_mailbox_id = data.get("transfer_aliases_to")
if transfer_mailbox_id and int(transfer_mailbox_id) >= 0:
transfer_mailbox_id = int(transfer_mailbox_id)
else:
transfer_mailbox_id = None
try: if not mailbox or mailbox.user_id != user.id:
mailbox_utils.delete_mailbox(user, mailbox_id, transfer_mailbox_id) return jsonify(error="Forbidden"), 403
except mailbox_utils.MailboxError as e:
return jsonify(error=e.msg), 400 if mailbox.id == user.default_mailbox_id:
return jsonify(error="You cannot delete the default mailbox"), 400
# Schedule delete account job
LOG.w("schedule delete mailbox job for %s", mailbox)
Job.create(
name=JOB_DELETE_MAILBOX,
payload={"mailbox_id": mailbox.id},
run_at=arrow.now(),
commit=True,
)
return jsonify(deleted=True), 200 return jsonify(deleted=True), 200
@api_bp.route("/mailboxes/<int:mailbox_id>", methods=["PUT"]) @api_bp.route("/mailboxes/<mailbox_id>", methods=["PUT"])
@require_api_auth @require_api_auth
def update_mailbox(mailbox_id): def update_mailbox(mailbox_id):
""" """

View File

@ -1,8 +1,7 @@
from flask import g from flask import g
from flask import jsonify, request from flask import jsonify, request
from itsdangerous import SignatureExpired
from app import parallel_limiter
from app.alias_suffix import check_suffix_signature, verify_prefix_suffix
from app.alias_utils import check_alias_prefix from app.alias_utils import check_alias_prefix
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.api.serializer import ( from app.api.serializer import (
@ -10,6 +9,7 @@ from app.api.serializer import (
get_alias_info_v2, get_alias_info_v2,
) )
from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT
from app.dashboard.views.custom_alias import verify_prefix_suffix, signer
from app.db import Session from app.db import Session
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
@ -28,7 +28,6 @@ from app.utils import convert_to_id
@api_bp.route("/v2/alias/custom/new", methods=["POST"]) @api_bp.route("/v2/alias/custom/new", methods=["POST"])
@limiter.limit(ALIAS_LIMIT) @limiter.limit(ALIAS_LIMIT)
@require_api_auth @require_api_auth
@parallel_limiter.lock(name="alias_creation")
def new_custom_alias_v2(): def new_custom_alias_v2():
""" """
Create a new custom alias Create a new custom alias
@ -66,11 +65,12 @@ def new_custom_alias_v2():
note = data.get("note") note = data.get("note")
alias_prefix = convert_to_id(alias_prefix) alias_prefix = convert_to_id(alias_prefix)
# hypothesis: user will click on the button in the 600 secs
try: try:
alias_suffix = check_suffix_signature(signed_suffix) alias_suffix = signer.unsign(signed_suffix, max_age=600).decode()
if not alias_suffix: except SignatureExpired:
LOG.w("Alias creation time expired for %s", user) LOG.w("Alias creation time expired for %s", user)
return jsonify(error="Alias creation time is expired, please retry"), 412 return jsonify(error="Alias creation time is expired, please retry"), 412
except Exception: except Exception:
LOG.w("Alias suffix is tampered, user %s", user) LOG.w("Alias suffix is tampered, user %s", user)
return jsonify(error="Tampered suffix"), 400 return jsonify(error="Tampered suffix"), 400
@ -115,7 +115,6 @@ def new_custom_alias_v2():
@api_bp.route("/v3/alias/custom/new", methods=["POST"]) @api_bp.route("/v3/alias/custom/new", methods=["POST"])
@limiter.limit(ALIAS_LIMIT) @limiter.limit(ALIAS_LIMIT)
@require_api_auth @require_api_auth
@parallel_limiter.lock(name="alias_creation")
def new_custom_alias_v3(): def new_custom_alias_v3():
""" """
Create a new custom alias Create a new custom alias
@ -150,7 +149,7 @@ def new_custom_alias_v3():
if not data: if not data:
return jsonify(error="request body cannot be empty"), 400 return jsonify(error="request body cannot be empty"), 400
if not isinstance(data, dict): if type(data) is not dict:
return jsonify(error="request body does not follow the required format"), 400 return jsonify(error="request body does not follow the required format"), 400
alias_prefix = data.get("alias_prefix", "").strip().lower().replace(" ", "") alias_prefix = data.get("alias_prefix", "").strip().lower().replace(" ", "")
@ -168,7 +167,7 @@ def new_custom_alias_v3():
return jsonify(error="alias prefix invalid format or too long"), 400 return jsonify(error="alias prefix invalid format or too long"), 400
# check if mailbox is not tempered with # check if mailbox is not tempered with
if not isinstance(mailbox_ids, list): if type(mailbox_ids) is not list:
return jsonify(error="mailbox_ids must be an array of id"), 400 return jsonify(error="mailbox_ids must be an array of id"), 400
mailboxes = [] mailboxes = []
for mailbox_id in mailbox_ids: for mailbox_id in mailbox_ids:
@ -182,10 +181,10 @@ def new_custom_alias_v3():
# hypothesis: user will click on the button in the 600 secs # hypothesis: user will click on the button in the 600 secs
try: try:
alias_suffix = check_suffix_signature(signed_suffix) alias_suffix = signer.unsign(signed_suffix, max_age=600).decode()
if not alias_suffix: except SignatureExpired:
LOG.w("Alias creation time expired for %s", user) LOG.w("Alias creation time expired for %s", user)
return jsonify(error="Alias creation time is expired, please retry"), 412 return jsonify(error="Alias creation time is expired, please retry"), 412
except Exception: except Exception:
LOG.w("Alias suffix is tampered, user %s", user) LOG.w("Alias suffix is tampered, user %s", user)
return jsonify(error="Tampered suffix"), 400 return jsonify(error="Tampered suffix"), 400

View File

@ -2,14 +2,13 @@ import tldextract
from flask import g from flask import g
from flask import jsonify, request from flask import jsonify, request
from app import parallel_limiter
from app.alias_suffix import get_alias_suffixes
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.api.serializer import ( from app.api.serializer import (
get_alias_info_v2, get_alias_info_v2,
serialize_alias_info_v2, serialize_alias_info_v2,
) )
from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT from app.config import MAX_NB_EMAIL_FREE_PLAN, ALIAS_LIMIT
from app.dashboard.views.custom_alias import get_available_suffixes
from app.db import Session from app.db import Session
from app.errors import AliasInTrashError from app.errors import AliasInTrashError
from app.extensions import limiter from app.extensions import limiter
@ -21,7 +20,6 @@ from app.utils import convert_to_id
@api_bp.route("/alias/random/new", methods=["POST"]) @api_bp.route("/alias/random/new", methods=["POST"])
@limiter.limit(ALIAS_LIMIT) @limiter.limit(ALIAS_LIMIT)
@require_api_auth @require_api_auth
@parallel_limiter.lock(name="alias_creation")
def new_random_alias(): def new_random_alias():
""" """
Create a new random alias Create a new random alias
@ -59,7 +57,7 @@ def new_random_alias():
prefix_suggestion = ext.domain prefix_suggestion = ext.domain
prefix_suggestion = convert_to_id(prefix_suggestion) prefix_suggestion = convert_to_id(prefix_suggestion)
suffixes = get_alias_suffixes(user) suffixes = get_available_suffixes(user)
# use the first suffix # use the first suffix
suggested_alias = prefix_suggestion + suffixes[0].suffix suggested_alias = prefix_suggestion + suffixes[0].suffix
@ -107,9 +105,8 @@ def new_random_alias():
Session.commit() Session.commit()
if hostname and not AliasUsedOn.get_by(alias_id=alias.id, hostname=hostname): if hostname and not AliasUsedOn.get_by(alias_id=alias.id, hostname=hostname):
AliasUsedOn.create( AliasUsedOn.create(alias_id=alias.id, hostname=hostname, user_id=alias.user_id)
alias_id=alias.id, hostname=hostname, user_id=alias.user_id, commit=True Session.commit()
)
return ( return (
jsonify(alias=alias.email, **serialize_alias_info_v2(get_alias_info_v2(alias))), jsonify(alias=alias.email, **serialize_alias_info_v2(get_alias_info_v2(alias))),

View File

@ -60,7 +60,7 @@ def get_notifications():
) )
@api_bp.route("/notifications/<int:notification_id>/read", methods=["POST"]) @api_bp.route("/notifications/<notification_id>/read", methods=["POST"])
@require_api_auth @require_api_auth
def mark_as_read(notification_id): def mark_as_read(notification_id):
""" """

View File

@ -9,7 +9,7 @@ from app.models import (
) )
@api_bp.route("/phone/reservations/<int:reservation_id>", methods=["GET", "POST"]) @api_bp.route("/phone/reservations/<reservation_id>", methods=["GET", "POST"])
@require_api_auth @require_api_auth
def phone_messages(reservation_id): def phone_messages(reservation_id):
""" """

View File

@ -12,7 +12,6 @@ from app.models import (
SenderFormatEnum, SenderFormatEnum,
AliasSuffixEnum, AliasSuffixEnum,
) )
from app.proton.utils import perform_proton_account_unlink
def setting_to_dict(user: User): def setting_to_dict(user: User):
@ -138,11 +137,3 @@ def get_available_domains_for_random_alias_v2():
] ]
return jsonify(ret) return jsonify(ret)
@api_bp.route("/setting/unlink_proton_account", methods=["DELETE"])
@require_api_auth
def unlink_proton_account():
user = g.user
perform_proton_account_unlink(user)
return jsonify({"ok": True})

View File

@ -1,46 +0,0 @@
from flask import jsonify, g
from sqlalchemy_utils.types.arrow import arrow
from app.api.base import api_bp, require_api_sudo, require_api_auth
from app import config
from app.extensions import limiter
from app.log import LOG
from app.models import Job, ApiToCookieToken
@api_bp.route("/user", methods=["DELETE"])
@require_api_sudo
def delete_user():
"""
Delete the user. Requires sudo mode.
"""
# Schedule delete account job
LOG.w("schedule delete account job for %s", g.user)
Job.create(
name=config.JOB_DELETE_ACCOUNT,
payload={"user_id": g.user.id},
run_at=arrow.now(),
commit=True,
)
return jsonify(ok=True)
@api_bp.route("/user/cookie_token", methods=["GET"])
@require_api_auth
@limiter.limit("5/minute")
def get_api_session_token():
"""
Get a temporary token to exchange it for a cookie based session
Output:
200 and a temporary random token
{
token: "asdli3ldq39h9hd3",
}
"""
token = ApiToCookieToken.create(
user=g.user,
api_key_id=g.api_key.id,
commit=True,
)
return jsonify({"token": token.code})

View File

@ -1,44 +1,25 @@
import base64 import base64
import dataclasses
from io import BytesIO from io import BytesIO
from typing import Optional
from flask import jsonify, g, request, make_response from flask import jsonify, g, request, make_response
from flask_login import logout_user
from app import s3, config from app import s3
from app.api.base import api_bp, require_api_auth from app.api.base import api_bp, require_api_auth
from app.config import SESSION_COOKIE_NAME from app.config import SESSION_COOKIE_NAME
from app.dashboard.views.index import get_stats
from app.db import Session from app.db import Session
from app.image_validation import detect_image_format, ImageFormat from app.models import ApiKey, File, User
from app.models import ApiKey, File, PartnerUser, User
from app.proton.utils import get_proton_partner
from app.session import logout_session
from app.utils import random_string from app.utils import random_string
def get_connected_proton_address(user: User) -> Optional[str]:
proton_partner = get_proton_partner()
partner_user = PartnerUser.get_by(user_id=user.id, partner_id=proton_partner.id)
if partner_user is None:
return None
return partner_user.partner_email
def user_to_dict(user: User) -> dict: def user_to_dict(user: User) -> dict:
ret = { ret = {
"name": user.name or "", "name": user.name or "",
"is_premium": user.is_premium(), "is_premium": user.is_premium(),
"email": user.email, "email": user.email,
"in_trial": user.in_trial(), "in_trial": user.in_trial(),
"max_alias_free_plan": user.max_alias_for_free_account(),
"connected_proton_address": None,
"can_create_reverse_alias": user.can_create_contacts(),
} }
if config.CONNECT_WITH_PROTON:
ret["connected_proton_address"] = get_connected_proton_address(user)
if user.profile_picture_id: if user.profile_picture_id:
ret["profile_picture_url"] = user.profile_picture.get_url() ret["profile_picture_url"] = user.profile_picture.get_url()
else: else:
@ -52,15 +33,6 @@ def user_to_dict(user: User) -> dict:
def user_info(): def user_info():
""" """
Return user info given the api-key Return user info given the api-key
Output as json
- name
- is_premium
- email
- in_trial
- max_alias_free
- is_connected_with_proton
- can_create_reverse_alias
""" """
user = g.user user = g.user
@ -74,23 +46,23 @@ def update_user_info():
Input Input
- profile_picture (optional): base64 of the profile picture. Set to null to remove the profile picture - profile_picture (optional): base64 of the profile picture. Set to null to remove the profile picture
- name (optional) - name (optional)
""" """
user = g.user user = g.user
data = request.get_json() or {} data = request.get_json() or {}
if "profile_picture" in data: if "profile_picture" in data:
if user.profile_picture_id: if data["profile_picture"] is None:
file = user.profile_picture if user.profile_picture_id:
user.profile_picture_id = None file = user.profile_picture
Session.flush() user.profile_picture_id = None
if file:
File.delete(file.id)
s3.delete(file.path)
Session.flush() Session.flush()
if file:
File.delete(file.id)
s3.delete(file.path)
Session.flush()
else: else:
raw_data = base64.decodebytes(data["profile_picture"].encode()) raw_data = base64.decodebytes(data["profile_picture"].encode())
if detect_image_format(raw_data) == ImageFormat.Unknown:
return jsonify(error="Unsupported image format"), 400
file_path = random_string(30) file_path = random_string(30)
file = File.create(user_id=user.id, path=file_path) file = File.create(user_id=user.id, path=file_path)
Session.flush() Session.flush()
@ -137,27 +109,8 @@ def logout():
Output: Output:
- 200 - 200
""" """
logout_session() logout_user()
response = make_response(jsonify(msg="User is logged out"), 200) response = make_response(jsonify(msg="User is logged out"), 200)
response.delete_cookie(SESSION_COOKIE_NAME) response.delete_cookie(SESSION_COOKIE_NAME)
return response return response
@api_bp.route("/stats")
@require_api_auth
def user_stats():
"""
Return stats
Output as json
- nb_alias
- nb_forward
- nb_reply
- nb_block
"""
user = g.user
stats = get_stats(user)
return jsonify(dataclasses.asdict(stats))

View File

@ -15,27 +15,4 @@ from .views import (
fido, fido,
social, social,
recovery, recovery,
api_to_cookie,
oidc,
) )
__all__ = [
"login",
"logout",
"register",
"activate",
"resend_activation",
"reset_password",
"forgot_password",
"github",
"google",
"facebook",
"proton",
"change_email",
"mfa",
"fido",
"social",
"recovery",
"api_to_cookie",
"oidc",
]

View File

@ -1,30 +0,0 @@
import arrow
from flask import redirect, url_for, request, flash
from flask_login import login_user
from app.auth.base import auth_bp
from app.models import ApiToCookieToken
from app.utils import sanitize_next_url
@auth_bp.route("/api_to_cookie", methods=["GET"])
def api_to_cookie():
code = request.args.get("token")
if not code:
flash("Missing token", "error")
return redirect(url_for("auth.login"))
token = ApiToCookieToken.get_by(code=code)
if not token or token.created_at < arrow.now().shift(minutes=-5):
flash("Missing token", "error")
return redirect(url_for("auth.login"))
user = token.user
ApiToCookieToken.delete(token.id, commit=True)
login_user(user)
next_url = sanitize_next_url(request.args.get("next"))
if next_url:
return redirect(next_url)
else:
return redirect(url_for("dashboard.index"))

View File

@ -3,13 +3,10 @@ from flask_login import login_user
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.db import Session from app.db import Session
from app.extensions import limiter from app.models import EmailChange
from app.log import LOG
from app.models import EmailChange, ResetPasswordCode
@auth_bp.route("/change_email", methods=["GET", "POST"]) @auth_bp.route("/change_email", methods=["GET", "POST"])
@limiter.limit("3/hour")
def change_email(): def change_email():
code = request.args.get("code") code = request.args.get("code")
@ -25,14 +22,11 @@ def change_email():
return render_template("auth/change_email.html") return render_template("auth/change_email.html")
user = email_change.user user = email_change.user
old_email = user.email
user.email = email_change.new_email user.email = email_change.new_email
EmailChange.delete(email_change.id) EmailChange.delete(email_change.id)
ResetPasswordCode.filter_by(user_id=user.id).delete()
Session.commit() Session.commit()
LOG.i(f"User {user} has changed their email from {old_email} to {user.email}")
flash("Your new email has been updated", "success") flash("Your new email has been updated", "success")
login_user(user) login_user(user)

View File

@ -1,6 +1,5 @@
import json import json
import secrets import secrets
from time import time
import webauthn import webauthn
from flask import ( from flask import (
@ -62,7 +61,7 @@ def fido():
browser = MfaBrowser.get_by(token=request.cookies.get("mfa")) browser = MfaBrowser.get_by(token=request.cookies.get("mfa"))
if browser and not browser.is_expired() and browser.user_id == user.id: if browser and not browser.is_expired() and browser.user_id == user.id:
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
return redirect(next_url or url_for("dashboard.index")) return redirect(next_url or url_for("dashboard.index"))
else: else:
@ -108,9 +107,8 @@ def fido():
Session.commit() Session.commit()
del session[MFA_USER_ID] del session[MFA_USER_ID]
session["sudo_time"] = int(time())
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
response = make_response(redirect(next_url or url_for("dashboard.index"))) response = make_response(redirect(next_url or url_for("dashboard.index")))

View File

@ -1,13 +1,13 @@
from flask import request, render_template, flash, g from flask import request, render_template, redirect, url_for, flash, g
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.dashboard.views.account_setting import send_reset_password_email from app.dashboard.views.setting import send_reset_password_email
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User from app.models import User
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
class ForgotPasswordForm(FlaskForm): class ForgotPasswordForm(FlaskForm):
@ -16,7 +16,7 @@ class ForgotPasswordForm(FlaskForm):
@auth_bp.route("/forgot_password", methods=["GET", "POST"]) @auth_bp.route("/forgot_password", methods=["GET", "POST"])
@limiter.limit( @limiter.limit(
"10/hour", deduct_when=lambda r: hasattr(g, "deduct_limit") and g.deduct_limit "10/minute", deduct_when=lambda r: hasattr(g, "deduct_limit") and g.deduct_limit
) )
def forgot_password(): def forgot_password():
form = ForgotPasswordForm(request.form) form = ForgotPasswordForm(request.form)
@ -25,17 +25,16 @@ def forgot_password():
# Trigger rate limiter # Trigger rate limiter
g.deduct_limit = True g.deduct_limit = True
email = sanitize_email(form.email.data)
flash( flash(
"If your email is correct, you are going to receive an email to reset your password", "If your email is correct, you are going to receive an email to reset your password",
"success", "success",
) )
user = User.get_by(email=email)
email = sanitize_email(form.email.data)
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if user: if user:
LOG.d("Send forgot password email to %s", user) LOG.d("Send forgot password email to %s", user)
send_reset_password_email(user) send_reset_password_email(user)
return redirect(url_for("auth.forgot_password"))
return render_template("auth/forgot_password.html", form=form) return render_template("auth/forgot_password.html", form=form)

View File

@ -7,7 +7,7 @@ from app.config import URL, GOOGLE_CLIENT_ID, GOOGLE_CLIENT_SECRET
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import User, File, SocialAuth from app.models import User, File, SocialAuth
from app.utils import random_string, sanitize_email, sanitize_next_url from app.utils import random_string, sanitize_email
from .login_utils import after_login from .login_utils import after_login
_authorization_base_url = "https://accounts.google.com/o/oauth2/v2/auth" _authorization_base_url = "https://accounts.google.com/o/oauth2/v2/auth"
@ -29,7 +29,7 @@ def google_login():
# to avoid flask-login displaying the login error message # to avoid flask-login displaying the login error message
session.pop("_flashes", None) session.pop("_flashes", None)
next_url = sanitize_next_url(request.args.get("next")) next_url = request.args.get("next")
# Google does not allow to append param to redirect_url # Google does not allow to append param to redirect_url
# we need to pass the next url by session # we need to pass the next url by session

View File

@ -5,12 +5,12 @@ from wtforms import StringField, validators
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login from app.auth.views.login_utils import after_login
from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON, OIDC_CLIENT_ID
from app.events.auth_event import LoginEvent from app.events.auth_event import LoginEvent
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User from app.models import User
from app.utils import sanitize_email, sanitize_next_url, canonicalize_email from app.proton.utils import is_connect_with_proton_enabled
from app.utils import sanitize_email, sanitize_next_url
class LoginForm(FlaskForm): class LoginForm(FlaskForm):
@ -38,9 +38,7 @@ def login():
show_resend_activation = False show_resend_activation = False
if form.validate_on_submit(): if form.validate_on_submit():
email = sanitize_email(form.email.data) user = User.filter_by(email=sanitize_email(form.email.data)).first()
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if not user or not user.check_password(form.password.data): if not user or not user.check_password(form.password.data):
# Trigger rate limiter # Trigger rate limiter
@ -54,12 +52,6 @@ def login():
"error", "error",
) )
LoginEvent(LoginEvent.ActionType.disabled_login).send() LoginEvent(LoginEvent.ActionType.disabled_login).send()
elif user.delete_on is not None:
flash(
f"Your account is scheduled to be deleted on {user.delete_on}",
"error",
)
LoginEvent(LoginEvent.ActionType.scheduled_to_be_deleted).send()
elif not user.activated: elif not user.activated:
show_resend_activation = True show_resend_activation = True
flash( flash(
@ -76,7 +68,5 @@ def login():
form=form, form=form,
next_url=next_url, next_url=next_url,
show_resend_activation=show_resend_activation, show_resend_activation=show_resend_activation,
connect_with_proton=CONNECT_WITH_PROTON, connect_with_proton=is_connect_with_proton_enabled(),
connect_with_oidc=OIDC_CLIENT_ID is not None,
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
) )

View File

@ -1,4 +1,3 @@
from time import time
from typing import Optional from typing import Optional
from flask import session, redirect, url_for, request from flask import session, redirect, url_for, request
@ -9,40 +8,37 @@ from app.log import LOG
from app.models import Referral from app.models import Referral
def after_login(user, next_url, login_from_proton: bool = False): def after_login(user, next_url):
""" """
Redirect to the correct page after login. Redirect to the correct page after login.
If the user is logged in with Proton, do not look at fido nor otp
If user enables MFA: redirect user to MFA page If user enables MFA: redirect user to MFA page
Otherwise redirect to dashboard page if no next_url Otherwise redirect to dashboard page if no next_url
""" """
if not login_from_proton: if user.fido_enabled():
if user.fido_enabled(): # Use the same session for FIDO so that we can easily
# Use the same session for FIDO so that we can easily # switch between these two 2FA option
# switch between these two 2FA option session[MFA_USER_ID] = user.id
session[MFA_USER_ID] = user.id if next_url:
if next_url: return redirect(url_for("auth.fido", next=next_url))
return redirect(url_for("auth.fido", next=next_url)) else:
else: return redirect(url_for("auth.fido"))
return redirect(url_for("auth.fido")) elif user.enable_otp:
elif user.enable_otp: session[MFA_USER_ID] = user.id
session[MFA_USER_ID] = user.id if next_url:
if next_url: return redirect(url_for("auth.mfa", next=next_url))
return redirect(url_for("auth.mfa", next=next_url)) else:
else: return redirect(url_for("auth.mfa"))
return redirect(url_for("auth.mfa"))
LOG.d("log user %s in", user)
login_user(user)
session["sudo_time"] = int(time())
# User comes to login page from another page
if next_url:
LOG.d("redirect user to %s", next_url)
return redirect(next_url)
else: else:
LOG.d("redirect user to dashboard") LOG.d("log user %s in", user)
return redirect(url_for("dashboard.index")) login_user(user)
# User comes to login page from another page
if next_url:
LOG.d("redirect user to %s", next_url)
return redirect(next_url)
else:
LOG.d("redirect user to dashboard")
return redirect(url_for("dashboard.index"))
# name of the cookie that stores the referral code # name of the cookie that stores the referral code

View File

@ -1,13 +1,13 @@
from flask import redirect, url_for, flash, make_response from flask import redirect, url_for, flash, make_response
from flask_login import logout_user
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import SESSION_COOKIE_NAME from app.config import SESSION_COOKIE_NAME
from app.session import logout_session
@auth_bp.route("/logout") @auth_bp.route("/logout")
def logout(): def logout():
logout_session() logout_user()
flash("You are logged out", "success") flash("You are logged out", "success")
response = make_response(redirect(url_for("auth.login"))) response = make_response(redirect(url_for("auth.login")))
response.delete_cookie(SESSION_COOKIE_NAME) response.delete_cookie(SESSION_COOKIE_NAME)

View File

@ -55,7 +55,7 @@ def mfa():
browser = MfaBrowser.get_by(token=request.cookies.get("mfa")) browser = MfaBrowser.get_by(token=request.cookies.get("mfa"))
if browser and not browser.is_expired() and browser.user_id == user.id: if browser and not browser.is_expired() and browser.user_id == user.id:
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
return redirect(next_url or url_for("dashboard.index")) return redirect(next_url or url_for("dashboard.index"))
else: else:
@ -67,13 +67,13 @@ def mfa():
token = otp_token_form.token.data.replace(" ", "") token = otp_token_form.token.data.replace(" ", "")
if totp.verify(token, valid_window=2) and user.last_otp != token: if totp.verify(token) and user.last_otp != token:
del session[MFA_USER_ID] del session[MFA_USER_ID]
user.last_otp = token user.last_otp = token
Session.commit() Session.commit()
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
# Redirect user to correct page # Redirect user to correct page
response = make_response(redirect(next_url or url_for("dashboard.index"))) response = make_response(redirect(next_url or url_for("dashboard.index")))

View File

@ -1,135 +0,0 @@
from flask import request, session, redirect, flash, url_for
from requests_oauthlib import OAuth2Session
import requests
from app import config
from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login
from app.config import (
URL,
OIDC_SCOPES,
OIDC_NAME_FIELD,
)
from app.db import Session
from app.email_utils import send_welcome_email
from app.log import LOG
from app.models import User, SocialAuth
from app.utils import sanitize_email, sanitize_next_url
# need to set explicitly redirect_uri instead of leaving the lib to pre-fill redirect_uri
# when served behind nginx, the redirect_uri is localhost... and not the real url
redirect_uri = URL + "/auth/oidc/callback"
SESSION_STATE_KEY = "oauth_state"
SESSION_NEXT_KEY = "oauth_redirect_next"
@auth_bp.route("/oidc/login")
def oidc_login():
if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None:
return redirect(url_for("auth.login"))
next_url = sanitize_next_url(request.args.get("next"))
auth_url = requests.get(config.OIDC_WELL_KNOWN_URL).json()["authorization_endpoint"]
oidc = OAuth2Session(
config.OIDC_CLIENT_ID, scope=[OIDC_SCOPES], redirect_uri=redirect_uri
)
authorization_url, state = oidc.authorization_url(auth_url)
# State is used to prevent CSRF, keep this for later.
session[SESSION_STATE_KEY] = state
session[SESSION_NEXT_KEY] = next_url
return redirect(authorization_url)
@auth_bp.route("/oidc/callback")
def oidc_callback():
if SESSION_STATE_KEY not in session:
flash("Invalid state, please retry", "error")
return redirect(url_for("auth.login"))
if config.OIDC_CLIENT_ID is None or config.OIDC_CLIENT_SECRET is None:
return redirect(url_for("auth.login"))
# user clicks on cancel
if "error" in request.args:
flash("Please use another sign in method then", "warning")
return redirect("/")
oidc_configuration = requests.get(config.OIDC_WELL_KNOWN_URL).json()
user_info_url = oidc_configuration["userinfo_endpoint"]
token_url = oidc_configuration["token_endpoint"]
oidc = OAuth2Session(
config.OIDC_CLIENT_ID,
state=session[SESSION_STATE_KEY],
scope=[OIDC_SCOPES],
redirect_uri=redirect_uri,
)
oidc.fetch_token(
token_url,
client_secret=config.OIDC_CLIENT_SECRET,
authorization_response=request.url,
)
oidc_user_data = oidc.get(user_info_url)
if oidc_user_data.status_code != 200:
LOG.e(
f"cannot get oidc user data {oidc_user_data.status_code} {oidc_user_data.text}"
)
flash(
"Cannot get user data from OIDC, please use another way to login/sign up",
"error",
)
return redirect(url_for("auth.login"))
oidc_user_data = oidc_user_data.json()
email = oidc_user_data.get("email")
if not email:
LOG.e(f"cannot get email for OIDC user {oidc_user_data} {email}")
flash(
"Cannot get a valid email from OIDC, please another way to login/sign up",
"error",
)
return redirect(url_for("auth.login"))
email = sanitize_email(email)
user = User.get_by(email=email)
if not user and config.DISABLE_REGISTRATION:
flash(
"Sorry you cannot sign up via the OIDC provider. Please sign-up first with your email.",
"error",
)
return redirect(url_for("auth.register"))
elif not user:
user = create_user(email, oidc_user_data)
if not SocialAuth.get_by(user_id=user.id, social="oidc"):
SocialAuth.create(user_id=user.id, social="oidc")
Session.commit()
# The activation link contains the original page, for ex authorize page
next_url = session[SESSION_NEXT_KEY]
session[SESSION_NEXT_KEY] = None
return after_login(user, next_url)
def create_user(email, oidc_user_data):
new_user = User.create(
email=email,
name=oidc_user_data.get(OIDC_NAME_FIELD),
password="",
activated=True,
)
LOG.i(f"Created new user for login request from OIDC. New user {new_user.id}")
Session.commit()
send_welcome_email(new_user)
return new_user

View File

@ -3,7 +3,6 @@ from flask import request, session, redirect, flash, url_for
from flask_limiter.util import get_remote_address from flask_limiter.util import get_remote_address
from flask_login import current_user from flask_login import current_user
from requests_oauthlib import OAuth2Session from requests_oauthlib import OAuth2Session
from typing import Optional
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.login_utils import after_login from app.auth.views.login_utils import after_login
@ -16,15 +15,13 @@ from app.config import (
PROTON_VALIDATE_CERTS, PROTON_VALIDATE_CERTS,
URL, URL,
) )
from app.log import LOG
from app.models import ApiKey, User
from app.proton.proton_client import HttpProtonClient, convert_access_token from app.proton.proton_client import HttpProtonClient, convert_access_token
from app.proton.proton_callback_handler import ( from app.proton.proton_callback_handler import (
ProtonCallbackHandler, ProtonCallbackHandler,
Action, Action,
) )
from app.proton.utils import get_proton_partner from app.proton.utils import get_proton_partner
from app.utils import sanitize_next_url, sanitize_scheme from app.utils import sanitize_next_url
_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize" _authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
_token_url = PROTON_BASE_URL + "/oauth/token" _token_url = PROTON_BASE_URL + "/oauth/token"
@ -33,35 +30,19 @@ _token_url = PROTON_BASE_URL + "/oauth/token"
# when served behind nginx, the redirect_uri is localhost... and not the real url # when served behind nginx, the redirect_uri is localhost... and not the real url
_redirect_uri = URL + "/auth/proton/callback" _redirect_uri = URL + "/auth/proton/callback"
SESSION_ACTION_KEY = "oauth_action"
SESSION_STATE_KEY = "oauth_state"
DEFAULT_SCHEME = "auth.simplelogin"
def extract_action() -> Action:
def get_api_key_for_user(user: User) -> str:
ak = ApiKey.create(
user_id=user.id,
name="Created via Login with Proton on mobile app",
commit=True,
)
return ak.code
def extract_action() -> Optional[Action]:
action = request.args.get("action") action = request.args.get("action")
if action is not None: if action is not None:
if action == "link": if action == "link":
return Action.Link return Action.Link
elif action == "login":
return Action.Login
else: else:
LOG.w(f"Unknown action received: {action}") raise Exception(f"Unknown action: {action}")
return None
return Action.Login return Action.Login
def get_action_from_state() -> Action: def get_action_from_state() -> Action:
oauth_action = session[SESSION_ACTION_KEY] oauth_action = session["oauth_action"]
if oauth_action == Action.Login.value: if oauth_action == Action.Login.value:
return Action.Login return Action.Login
elif oauth_action == Action.Link.value: elif oauth_action == Action.Link.value:
@ -74,44 +55,22 @@ def proton_login():
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
action = extract_action()
if action is None:
return redirect(url_for("auth.login"))
if action == Action.Link and not current_user.is_authenticated:
return redirect(url_for("auth.login"))
next_url = sanitize_next_url(request.args.get("next")) next_url = sanitize_next_url(request.args.get("next"))
if next_url: if next_url:
session["oauth_next"] = next_url session["oauth_next"] = next_url
elif "oauth_next" in session: elif "oauth_next" in session:
del session["oauth_next"] del session["oauth_next"]
scheme = sanitize_scheme(request.args.get("scheme"))
if scheme:
session["oauth_scheme"] = scheme
elif "oauth_scheme" in session:
del session["oauth_scheme"]
mode = request.args.get("mode", "session")
if mode == "apikey":
session["oauth_mode"] = "apikey"
else:
session["oauth_mode"] = "session"
proton = OAuth2Session(PROTON_CLIENT_ID, redirect_uri=_redirect_uri) proton = OAuth2Session(PROTON_CLIENT_ID, redirect_uri=_redirect_uri)
authorization_url, state = proton.authorization_url(_authorization_base_url) authorization_url, state = proton.authorization_url(_authorization_base_url)
# State is used to prevent CSRF, keep this for later. # State is used to prevent CSRF, keep this for later.
session[SESSION_STATE_KEY] = state session["oauth_state"] = state
session[SESSION_ACTION_KEY] = action.value session["oauth_action"] = extract_action().value
return redirect(authorization_url) return redirect(authorization_url)
@auth_bp.route("/proton/callback") @auth_bp.route("/proton/callback")
def proton_callback(): def proton_callback():
if SESSION_STATE_KEY not in session or SESSION_STATE_KEY not in session:
flash("Invalid state, please retry", "error")
return redirect(url_for("auth.login"))
if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None: if PROTON_CLIENT_ID is None or PROTON_CLIENT_SECRET is None:
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
@ -122,7 +81,7 @@ def proton_callback():
proton = OAuth2Session( proton = OAuth2Session(
PROTON_CLIENT_ID, PROTON_CLIENT_ID,
state=session[SESSION_STATE_KEY], state=session["oauth_state"],
redirect_uri=_redirect_uri, redirect_uri=_redirect_uri,
) )
@ -139,21 +98,15 @@ def proton_callback():
if PROTON_EXTRA_HEADER_NAME and PROTON_EXTRA_HEADER_VALUE: if PROTON_EXTRA_HEADER_NAME and PROTON_EXTRA_HEADER_VALUE:
headers = {PROTON_EXTRA_HEADER_NAME: PROTON_EXTRA_HEADER_VALUE} headers = {PROTON_EXTRA_HEADER_NAME: PROTON_EXTRA_HEADER_VALUE}
try: token = proton.fetch_token(
token = proton.fetch_token( _token_url,
_token_url, client_secret=PROTON_CLIENT_SECRET,
client_secret=PROTON_CLIENT_SECRET, authorization_response=request.url,
authorization_response=request.url, verify=PROTON_VALIDATE_CERTS,
verify=PROTON_VALIDATE_CERTS, method="GET",
method="GET", include_client_id=True,
include_client_id=True, headers=headers,
headers=headers, )
)
except Exception as e:
LOG.warning(f"Error fetching Proton token: {e}")
flash("There was an error in the login process", "error")
return redirect(url_for("auth.login"))
credentials = convert_access_token(token["access_token"]) credentials = convert_access_token(token["access_token"])
action = get_action_from_state() action = get_action_from_state()
@ -163,7 +116,6 @@ def proton_callback():
handler = ProtonCallbackHandler(proton_client) handler = ProtonCallbackHandler(proton_client)
proton_partner = get_proton_partner() proton_partner = get_proton_partner()
next_url = session.get("oauth_next")
if action == Action.Login: if action == Action.Login:
res = handler.handle_login(proton_partner) res = handler.handle_login(proton_partner)
elif action == Action.Link: elif action == Action.Link:
@ -174,17 +126,11 @@ def proton_callback():
if res.flash_message is not None: if res.flash_message is not None:
flash(res.flash_message, res.flash_category) flash(res.flash_message, res.flash_category)
oauth_scheme = session.get("oauth_scheme")
if session.get("oauth_mode", "session") == "apikey":
apikey = get_api_key_for_user(res.user)
scheme = oauth_scheme or DEFAULT_SCHEME
return redirect(f"{scheme}:///login?apikey={apikey}")
if res.redirect_to_login: if res.redirect_to_login:
return redirect(url_for("auth.login")) return redirect(url_for("auth.login"))
if next_url and next_url[0] == "/" and oauth_scheme: if res.redirect:
next_url = f"{oauth_scheme}://{next_url}" return after_login(res.user, res.redirect)
redirect_url = next_url or res.redirect next_url = session.get("oauth_next")
return after_login(res.user, redirect_url, login_from_proton=True) return after_login(res.user, next_url)

View File

@ -42,7 +42,7 @@ def recovery_route():
if recovery_form.validate_on_submit(): if recovery_form.validate_on_submit():
code = recovery_form.code.data code = recovery_form.code.data
recovery_code = RecoveryCode.find_by_user_code(user, code) recovery_code = RecoveryCode.get_by(user_id=user.id, code=code)
if recovery_code: if recovery_code:
if recovery_code.used: if recovery_code.used:
@ -53,7 +53,7 @@ def recovery_route():
del session[MFA_USER_ID] del session[MFA_USER_ID]
login_user(user) login_user(user)
flash("Welcome back!", "success") flash(f"Welcome back!", "success")
recovery_code.used = True recovery_code.used = True
recovery_code.used_at = arrow.now() recovery_code.used_at = arrow.now()

View File

@ -6,7 +6,7 @@ from wtforms import StringField, validators
from app import email_utils, config from app import email_utils, config
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.config import CONNECT_WITH_PROTON, CONNECT_WITH_OIDC_ICON from app.config import CONNECT_WITH_PROTON
from app.auth.views.login_utils import get_referral from app.auth.views.login_utils import get_referral
from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY from app.config import URL, HCAPTCHA_SECRET, HCAPTCHA_SITEKEY
from app.db import Session from app.db import Session
@ -16,8 +16,8 @@ from app.email_utils import (
) )
from app.events.auth_event import RegisterEvent from app.events.auth_event import RegisterEvent
from app.log import LOG from app.log import LOG
from app.models import User, ActivationCode, DailyMetric from app.models import User, ActivationCode
from app.utils import random_string, encode_url, sanitize_email, canonicalize_email from app.utils import random_string, encode_url, sanitize_email
class RegisterForm(FlaskForm): class RegisterForm(FlaskForm):
@ -70,22 +70,19 @@ def register():
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY, HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
) )
email = canonicalize_email(form.email.data) email = sanitize_email(form.email.data)
if not email_can_be_used_as_mailbox(email): if not email_can_be_used_as_mailbox(email):
flash("You cannot use this email address as your personal inbox.", "error") flash("You cannot use this email address as your personal inbox.", "error")
RegisterEvent(RegisterEvent.ActionType.email_in_use).send() RegisterEvent(RegisterEvent.ActionType.email_in_use).send()
else: else:
sanitized_email = sanitize_email(form.email.data) if personal_email_already_used(email):
if personal_email_already_used(email) or personal_email_already_used(
sanitized_email
):
flash(f"Email {email} already used", "error") flash(f"Email {email} already used", "error")
RegisterEvent(RegisterEvent.ActionType.email_in_use).send() RegisterEvent(RegisterEvent.ActionType.email_in_use).send()
else: else:
LOG.d("create user %s", email) LOG.d("create user %s", email)
user = User.create( user = User.create(
email=email, email=email,
name=form.email.data, name="",
password=form.password.data, password=form.password.data,
referral=get_referral(), referral=get_referral(),
) )
@ -94,8 +91,6 @@ def register():
try: try:
send_activation_email(user, next_url) send_activation_email(user, next_url)
RegisterEvent(RegisterEvent.ActionType.success).send() RegisterEvent(RegisterEvent.ActionType.success).send()
DailyMetric.get_or_create_today_metric().nb_new_web_non_proton_user += 1
Session.commit()
except Exception: except Exception:
flash("Invalid email, are you sure the email is correct?", "error") flash("Invalid email, are you sure the email is correct?", "error")
RegisterEvent(RegisterEvent.ActionType.invalid_email).send() RegisterEvent(RegisterEvent.ActionType.invalid_email).send()
@ -109,14 +104,11 @@ def register():
next_url=next_url, next_url=next_url,
HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY, HCAPTCHA_SITEKEY=HCAPTCHA_SITEKEY,
connect_with_proton=CONNECT_WITH_PROTON, connect_with_proton=CONNECT_WITH_PROTON,
connect_with_oidc=config.OIDC_CLIENT_ID is not None,
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
) )
def send_activation_email(user, next_url): def send_activation_email(user, next_url):
# the activation code is valid for 1h and delete all previous codes # the activation code is valid for 1h
Session.query(ActivationCode).filter(ActivationCode.user_id == user.id).delete()
activation = ActivationCode.create(user_id=user.id, code=random_string(30)) activation = ActivationCode.create(user_id=user.id, code=random_string(30))
Session.commit() Session.commit()
@ -126,4 +118,4 @@ def send_activation_email(user, next_url):
LOG.d("redirect user to %s after activation", next_url) LOG.d("redirect user to %s after activation", next_url)
activation_link = activation_link + "&next=" + encode_url(next_url) activation_link = activation_link + "&next=" + encode_url(next_url)
email_utils.send_activation_email(user, activation_link) email_utils.send_activation_email(user.email, activation_link)

View File

@ -4,10 +4,9 @@ from wtforms import StringField, validators
from app.auth.base import auth_bp from app.auth.base import auth_bp
from app.auth.views.register import send_activation_email from app.auth.views.register import send_activation_email
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import User from app.models import User
from app.utils import sanitize_email, canonicalize_email from app.utils import sanitize_email
class ResendActivationForm(FlaskForm): class ResendActivationForm(FlaskForm):
@ -15,14 +14,11 @@ class ResendActivationForm(FlaskForm):
@auth_bp.route("/resend_activation", methods=["GET", "POST"]) @auth_bp.route("/resend_activation", methods=["GET", "POST"])
@limiter.limit("10/hour")
def resend_activation(): def resend_activation():
form = ResendActivationForm(request.form) form = ResendActivationForm(request.form)
if form.validate_on_submit(): if form.validate_on_submit():
email = sanitize_email(form.email.data) user = User.filter_by(email=sanitize_email(form.email.data)).first()
canonical_email = canonicalize_email(email)
user = User.get_by(email=email) or User.get_by(email=canonical_email)
if not user: if not user:
flash("There is no such email", "warning") flash("There is no such email", "warning")

View File

@ -60,8 +60,8 @@ def reset_password():
# this can be served to activate user too # this can be served to activate user too
user.activated = True user.activated = True
# remove all reset password codes # remove the reset password code
ResetPasswordCode.filter_by(user_id=user.id).delete() ResetPasswordCode.delete(reset_password_code.id)
# change the alternative_id to log user out on other browsers # change the alternative_id to log user out on other browsers
user.alternative_id = str(uuid.uuid4()) user.alternative_id = str(uuid.uuid4())

View File

@ -3,11 +3,12 @@ import random
import socket import socket
import string import string
from ast import literal_eval from ast import literal_eval
from typing import Callable, List, Optional from typing import Callable, List
from urllib.parse import urlparse from urllib.parse import urlparse
from dotenv import load_dotenv from dotenv import load_dotenv
ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__))) ROOT_DIR = os.path.abspath(os.path.dirname(os.path.dirname(__file__)))
@ -35,33 +36,6 @@ def sl_getenv(env_var: str, default_factory: Callable = None):
return literal_eval(value) return literal_eval(value)
def get_env_dict(env_var: str) -> dict[str, str]:
"""
Get an env variable and convert it into a python dictionary with keys and values as strings.
Args:
env_var (str): env var, example: SL_DB
Syntax is: key1=value1;key2=value2
Components separated by ;
key and value separated by =
"""
value = os.getenv(env_var)
if not value:
return {}
components = value.split(";")
result = {}
for component in components:
if component == "":
continue
parts = component.split("=")
if len(parts) != 2:
raise Exception(f"Invalid config for env var {env_var}")
result[parts[0].strip()] = parts[1].strip()
return result
config_file = os.environ.get("CONFIG") config_file = os.environ.get("CONFIG")
if config_file: if config_file:
config_file = get_abs_path(config_file) config_file = get_abs_path(config_file)
@ -123,8 +97,6 @@ except Exception:
print("MAX_NB_EMAIL_FREE_PLAN is not set, use 5 as default value") print("MAX_NB_EMAIL_FREE_PLAN is not set, use 5 as default value")
MAX_NB_EMAIL_FREE_PLAN = 5 MAX_NB_EMAIL_FREE_PLAN = 5
MAX_NB_EMAIL_OLD_FREE_PLAN = int(os.environ.get("MAX_NB_EMAIL_OLD_FREE_PLAN", 15))
# maximum number of directory a premium user can create # maximum number of directory a premium user can create
MAX_NB_DIRECTORY = 50 MAX_NB_DIRECTORY = 50
MAX_NB_SUBDOMAIN = 5 MAX_NB_SUBDOMAIN = 5
@ -138,16 +110,13 @@ POSTFIX_SERVER = os.environ.get("POSTFIX_SERVER", "240.0.0.1")
DISABLE_REGISTRATION = "DISABLE_REGISTRATION" in os.environ DISABLE_REGISTRATION = "DISABLE_REGISTRATION" in os.environ
# allow using a different postfix port, useful when developing locally # allow using a different postfix port, useful when developing locally
POSTFIX_PORT = 25
if "POSTFIX_PORT" in os.environ:
POSTFIX_PORT = int(os.environ["POSTFIX_PORT"])
# Use port 587 instead of 25 when sending emails through Postfix # Use port 587 instead of 25 when sending emails through Postfix
# Useful when calling Postfix from an external network # Useful when calling Postfix from an external network
POSTFIX_SUBMISSION_TLS = "POSTFIX_SUBMISSION_TLS" in os.environ POSTFIX_SUBMISSION_TLS = "POSTFIX_SUBMISSION_TLS" in os.environ
if POSTFIX_SUBMISSION_TLS:
default_postfix_port = 587
else:
default_postfix_port = 25
POSTFIX_PORT = int(os.environ.get("POSTFIX_PORT", default_postfix_port))
POSTFIX_TIMEOUT = int(os.environ.get("POSTFIX_TIMEOUT", 3))
# ["domain1.com", "domain2.com"] # ["domain1.com", "domain2.com"]
OTHER_ALIAS_DOMAINS = sl_getenv("OTHER_ALIAS_DOMAINS", list) OTHER_ALIAS_DOMAINS = sl_getenv("OTHER_ALIAS_DOMAINS", list)
@ -190,7 +159,6 @@ if "DKIM_PRIVATE_KEY_PATH" in os.environ:
# Database # Database
DB_URI = os.environ["DB_URI"] DB_URI = os.environ["DB_URI"]
DB_CONN_NAME = os.environ.get("DB_CONN_NAME", "webapp")
# Flask secret # Flask secret
FLASK_SECRET = os.environ["FLASK_SECRET"] FLASK_SECRET = os.environ["FLASK_SECRET"]
@ -199,14 +167,12 @@ if not FLASK_SECRET:
SESSION_COOKIE_NAME = "slapp" SESSION_COOKIE_NAME = "slapp"
MAILBOX_SECRET = FLASK_SECRET + "mailbox" MAILBOX_SECRET = FLASK_SECRET + "mailbox"
CUSTOM_ALIAS_SECRET = FLASK_SECRET + "custom_alias" CUSTOM_ALIAS_SECRET = FLASK_SECRET + "custom_alias"
UNSUBSCRIBE_SECRET = FLASK_SECRET + "unsub"
# AWS # AWS
AWS_REGION = os.environ.get("AWS_REGION") or "eu-west-3" AWS_REGION = os.environ.get("AWS_REGION") or "eu-west-3"
BUCKET = os.environ.get("BUCKET") BUCKET = os.environ.get("BUCKET")
AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID") AWS_ACCESS_KEY_ID = os.environ.get("AWS_ACCESS_KEY_ID")
AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY") AWS_SECRET_ACCESS_KEY = os.environ.get("AWS_SECRET_ACCESS_KEY")
AWS_ENDPOINT_URL = os.environ.get("AWS_ENDPOINT_URL", None)
# Paddle # Paddle
try: try:
@ -261,7 +227,7 @@ else:
print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME) print("WARNING: Use a temp directory for GNUPGHOME", GNUPGHOME)
# Github, Google, Facebook, OIDC client id and secrets # Github, Google, Facebook client id and secrets
GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID") GITHUB_CLIENT_ID = os.environ.get("GITHUB_CLIENT_ID")
GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET") GITHUB_CLIENT_SECRET = os.environ.get("GITHUB_CLIENT_SECRET")
@ -271,13 +237,6 @@ GOOGLE_CLIENT_SECRET = os.environ.get("GOOGLE_CLIENT_SECRET")
FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID") FACEBOOK_CLIENT_ID = os.environ.get("FACEBOOK_CLIENT_ID")
FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET") FACEBOOK_CLIENT_SECRET = os.environ.get("FACEBOOK_CLIENT_SECRET")
CONNECT_WITH_OIDC_ICON = os.environ.get("CONNECT_WITH_OIDC_ICON")
OIDC_WELL_KNOWN_URL = os.environ.get("OIDC_WELL_KNOWN_URL")
OIDC_CLIENT_ID = os.environ.get("OIDC_CLIENT_ID")
OIDC_CLIENT_SECRET = os.environ.get("OIDC_CLIENT_SECRET")
OIDC_SCOPES = os.environ.get("OIDC_SCOPES")
OIDC_NAME_FIELD = os.environ.get("OIDC_NAME_FIELD", "name")
PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID") PROTON_CLIENT_ID = os.environ.get("PROTON_CLIENT_ID")
PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET") PROTON_CLIENT_SECRET = os.environ.get("PROTON_CLIENT_SECRET")
PROTON_BASE_URL = os.environ.get( PROTON_BASE_URL = os.environ.get(
@ -287,6 +246,7 @@ PROTON_VALIDATE_CERTS = "PROTON_VALIDATE_CERTS" in os.environ
CONNECT_WITH_PROTON = "CONNECT_WITH_PROTON" in os.environ CONNECT_WITH_PROTON = "CONNECT_WITH_PROTON" in os.environ
PROTON_EXTRA_HEADER_NAME = os.environ.get("PROTON_EXTRA_HEADER_NAME") PROTON_EXTRA_HEADER_NAME = os.environ.get("PROTON_EXTRA_HEADER_NAME")
PROTON_EXTRA_HEADER_VALUE = os.environ.get("PROTON_EXTRA_HEADER_VALUE") PROTON_EXTRA_HEADER_VALUE = os.environ.get("PROTON_EXTRA_HEADER_VALUE")
CONNECT_WITH_PROTON_COOKIE_NAME = os.environ.get("CONNECT_WITH_PROTON_COOKIE_NAME")
# in seconds # in seconds
AVATAR_URL_EXPIRATION = 3600 * 24 * 7 # 1h*24h/d*7d=1week AVATAR_URL_EXPIRATION = 3600 * 24 * 7 # 1h*24h/d*7d=1week
@ -308,7 +268,6 @@ JOB_DELETE_MAILBOX = "delete-mailbox"
JOB_DELETE_DOMAIN = "delete-domain" JOB_DELETE_DOMAIN = "delete-domain"
JOB_SEND_USER_REPORT = "send-user-report" JOB_SEND_USER_REPORT = "send-user-report"
JOB_SEND_PROTON_WELCOME_1 = "proton-welcome-1" JOB_SEND_PROTON_WELCOME_1 = "proton-welcome-1"
JOB_SEND_ALIAS_CREATION_EVENTS = "send-alias-creation-events"
# for pagination # for pagination
PAGE_LIMIT = 20 PAGE_LIMIT = 20
@ -393,7 +352,6 @@ ALERT_COMPLAINT_TRANSACTIONAL_PHASE = "alert_complaint_transactional_phase"
ALERT_QUARANTINE_DMARC = "alert_quarantine_dmarc" ALERT_QUARANTINE_DMARC = "alert_quarantine_dmarc"
ALERT_DUAL_SUBSCRIPTION_WITH_PARTNER = "alert_dual_sub_with_partner" ALERT_DUAL_SUBSCRIPTION_WITH_PARTNER = "alert_dual_sub_with_partner"
ALERT_WARN_MULTIPLE_SUBSCRIPTIONS = "alert_multiple_subscription"
# <<<<< END ALERT EMAIL >>>> # <<<<< END ALERT EMAIL >>>>
@ -456,11 +414,6 @@ try:
except Exception: except Exception:
HIBP_SCAN_INTERVAL_DAYS = 7 HIBP_SCAN_INTERVAL_DAYS = 7
HIBP_API_KEYS = sl_getenv("HIBP_API_KEYS", list) or [] HIBP_API_KEYS = sl_getenv("HIBP_API_KEYS", list) or []
HIBP_MAX_ALIAS_CHECK = 10_000
HIBP_RPM = int(os.environ.get("HIBP_API_RPM", 100))
HIBP_SKIP_PARTNER_ALIAS = os.environ.get("HIBP_SKIP_PARTNER_ALIAS")
KEEP_OLD_DATA_DAYS = 30
POSTMASTER = os.environ.get("POSTMASTER") POSTMASTER = os.environ.get("POSTMASTER")
@ -529,131 +482,7 @@ def setup_nameservers():
NAMESERVERS = setup_nameservers() NAMESERVERS = setup_nameservers()
DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = os.environ.get( DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = False
"DISABLE_CREATE_CONTACTS_FOR_FREE_USERS", False
)
# Expect format hits,seconds:hits,seconds...
# Example 1,10:4,60 means 1 in the last 10 secs or 4 in the last 60 secs
def getRateLimitFromConfig(
env_var: string, default: string = ""
) -> list[tuple[int, int]]:
value = os.environ.get(env_var, default)
if not value:
return []
entries = [entry for entry in value.split(":")]
limits = []
for entry in entries:
fields = entry.split(",")
limit = (int(fields[0]), int(fields[1]))
limits.append(limit)
return limits
ALIAS_CREATE_RATE_LIMIT_FREE = getRateLimitFromConfig(
"ALIAS_CREATE_RATE_LIMIT_FREE", "10,900:50,3600"
)
ALIAS_CREATE_RATE_LIMIT_PAID = getRateLimitFromConfig(
"ALIAS_CREATE_RATE_LIMIT_PAID", "50,900:200,3600"
)
PARTNER_API_TOKEN_SECRET = os.environ.get("PARTNER_API_TOKEN_SECRET") or ( PARTNER_API_TOKEN_SECRET = os.environ.get("PARTNER_API_TOKEN_SECRET") or (
FLASK_SECRET + "partnerapitoken" FLASK_SECRET + "partnerapitoken"
) )
JOB_MAX_ATTEMPTS = 5
JOB_TAKEN_RETRY_WAIT_MINS = 30
# MEM_STORE
MEM_STORE_URI = os.environ.get("MEM_STORE_URI", None)
# Recovery codes hash salt
RECOVERY_CODE_HMAC_SECRET = os.environ.get("RECOVERY_CODE_HMAC_SECRET") or (
FLASK_SECRET + "generatearandomtoken"
)
if not RECOVERY_CODE_HMAC_SECRET or len(RECOVERY_CODE_HMAC_SECRET) < 16:
raise RuntimeError(
"Please define RECOVERY_CODE_HMAC_SECRET in your configuration with a random string at least 16 chars long"
)
# the minimum rspamd spam score above which emails that fail DMARC should be quarantined
if "MIN_RSPAMD_SCORE_FOR_FAILED_DMARC" in os.environ:
MIN_RSPAMD_SCORE_FOR_FAILED_DMARC = float(
os.environ["MIN_RSPAMD_SCORE_FOR_FAILED_DMARC"]
)
else:
MIN_RSPAMD_SCORE_FOR_FAILED_DMARC = None
# run over all reverse alias for an alias and replace them with sender address
ENABLE_ALL_REVERSE_ALIAS_REPLACEMENT = (
"ENABLE_ALL_REVERSE_ALIAS_REPLACEMENT" in os.environ
)
if ENABLE_ALL_REVERSE_ALIAS_REPLACEMENT:
# max number of reverse alias that can be replaced
MAX_NB_REVERSE_ALIAS_REPLACEMENT = int(
os.environ["MAX_NB_REVERSE_ALIAS_REPLACEMENT"]
)
# Only used for tests
SKIP_MX_LOOKUP_ON_CHECK = False
DISABLE_RATE_LIMIT = "DISABLE_RATE_LIMIT" in os.environ
SUBSCRIPTION_CHANGE_WEBHOOK = os.environ.get("SUBSCRIPTION_CHANGE_WEBHOOK", None)
MAX_API_KEYS = int(os.environ.get("MAX_API_KEYS", 30))
UPCLOUD_USERNAME = os.environ.get("UPCLOUD_USERNAME", None)
UPCLOUD_PASSWORD = os.environ.get("UPCLOUD_PASSWORD", None)
UPCLOUD_DB_ID = os.environ.get("UPCLOUD_DB_ID", None)
STORE_TRANSACTIONAL_EMAILS = "STORE_TRANSACTIONAL_EMAILS" in os.environ
EVENT_WEBHOOK = os.environ.get("EVENT_WEBHOOK", None)
# We want it disabled by default, so only skip if defined
EVENT_WEBHOOK_SKIP_VERIFY_SSL = "EVENT_WEBHOOK_SKIP_VERIFY_SSL" in os.environ
EVENT_WEBHOOK_DISABLE = "EVENT_WEBHOOK_DISABLE" in os.environ
def read_webhook_enabled_user_ids() -> Optional[List[int]]:
user_ids = os.environ.get("EVENT_WEBHOOK_ENABLED_USER_IDS", None)
if user_ids is None:
return None
ids = []
for user_id in user_ids.split(","):
try:
ids.append(int(user_id.strip()))
except ValueError:
pass
return ids
EVENT_WEBHOOK_ENABLED_USER_IDS: Optional[List[int]] = read_webhook_enabled_user_ids()
# Allow to define a different DB_URI for the event listener, in case we want to skip the connection pool
# It defaults to the regular DB_URI in case it's needed
EVENT_LISTENER_DB_URI = os.environ.get("EVENT_LISTENER_DB_URI", DB_URI)
def read_partner_dict(var: str) -> dict[int, str]:
partner_value = get_env_dict(var)
if len(partner_value) == 0:
return {}
res: dict[int, str] = {}
for partner_id in partner_value.keys():
try:
partner_id_int = int(partner_id.strip())
res[partner_id_int] = partner_value[partner_id]
except ValueError:
pass
return res
PARTNER_DOMAINS: dict[int, str] = read_partner_dict("PARTNER_DOMAINS")
PARTNER_DOMAIN_VALIDATION_PREFIXES: dict[int, str] = read_partner_dict(
"PARTNER_DOMAIN_VALIDATION_PREFIXES"
)

View File

@ -1,2 +0,0 @@
HEADER_ALLOW_API_COOKIES = "X-Sl-Allowcookies"
DMARC_RECORD = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"

View File

@ -1,113 +0,0 @@
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from sqlalchemy.exc import IntegrityError
from app.db import Session
from app.email_utils import generate_reply_email, parse_full_address
from app.email_validation import is_valid_email
from app.log import LOG
from app.models import Contact, Alias
from app.utils import sanitize_email
class ContactCreateError(Enum):
InvalidEmail = "Invalid email"
NotAllowed = "Your plan does not allow to create contacts"
@dataclass
class ContactCreateResult:
contact: Optional[Contact]
created: bool
error: Optional[ContactCreateError]
def __update_contact_if_needed(
contact: Contact, name: Optional[str], mail_from: Optional[str]
) -> ContactCreateResult:
if name and contact.name != name:
LOG.d(f"Setting {contact} name to {name}")
contact.name = name
Session.commit()
if mail_from and contact.mail_from is None:
LOG.d(f"Setting {contact} mail_from to {mail_from}")
contact.mail_from = mail_from
Session.commit()
return ContactCreateResult(contact, created=False, error=None)
def create_contact(
email: str,
alias: Alias,
name: Optional[str] = None,
mail_from: Optional[str] = None,
allow_empty_email: bool = False,
automatic_created: bool = False,
from_partner: bool = False,
) -> ContactCreateResult:
# If user cannot create contacts, they still need to be created when receiving an email for an alias
if not automatic_created and not alias.user.can_create_contacts():
return ContactCreateResult(
None, created=False, error=ContactCreateError.NotAllowed
)
# Parse emails with form 'name <email>'
try:
email_name, email = parse_full_address(email)
except ValueError:
email = ""
email_name = ""
# If no name is explicitly given try to get it from the parsed email
if name is None:
name = email_name[: Contact.MAX_NAME_LENGTH]
else:
name = name[: Contact.MAX_NAME_LENGTH]
# If still no name is there, make sure the name is None instead of empty string
if not name:
name = None
if name is not None and "\x00" in name:
LOG.w("Cannot use contact name because has \\x00")
name = ""
# Sanitize email and if it's not valid only allow to create a contact if it's explicitly allowed. Otherwise fail
email = sanitize_email(email, not_lower=True)
if not is_valid_email(email):
LOG.w(f"invalid contact email {email}")
if not allow_empty_email:
return ContactCreateResult(
None, created=False, error=ContactCreateError.InvalidEmail
)
LOG.d("Create a contact with invalid email for %s", alias)
# either reuse a contact with empty email or create a new contact with empty email
email = ""
# If contact exists, update name and mail_from if needed
contact = Contact.get_by(alias_id=alias.id, website_email=email)
if contact is not None:
return __update_contact_if_needed(contact, name, mail_from)
# Create the contact
reply_email = generate_reply_email(email, alias)
try:
flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=email,
name=name,
reply_email=reply_email,
mail_from=mail_from,
automatic_created=automatic_created,
flags=flags,
invalid_email=email == "",
commit=True,
)
LOG.d(
f"Created contact {contact} for alias {alias} with email {email} invalid_email={contact.invalid_email}"
)
except IntegrityError:
Session.rollback()
LOG.info(
f"Contact with email {email} for alias_id {alias.id} already existed, fetching from DB"
)
contact = Contact.get_by(alias_id=alias.id, website_email=email)
return __update_contact_if_needed(contact, name, mail_from)
return ContactCreateResult(contact, created=True, error=None)

View File

@ -1,142 +0,0 @@
import arrow
import re
from dataclasses import dataclass
from enum import Enum
from typing import Optional
from app.config import JOB_DELETE_DOMAIN
from app.db import Session
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import User, CustomDomain, SLDomain, Mailbox, Job
_ALLOWED_DOMAIN_REGEX = re.compile(r"^(?!-)[A-Za-z0-9-]{1,63}(?<!-)$")
@dataclass
class CreateCustomDomainResult:
message: str = ""
message_category: str = ""
success: bool = False
instance: Optional[CustomDomain] = None
redirect: Optional[str] = None
class CannotUseDomainReason(Enum):
InvalidDomain = 1
BuiltinDomain = 2
DomainAlreadyUsed = 3
DomainPartOfUserEmail = 4
DomainUserInMailbox = 5
def message(self, domain: str) -> str:
if self == CannotUseDomainReason.InvalidDomain:
return "This is not a valid domain"
elif self == CannotUseDomainReason.BuiltinDomain:
return "A custom domain cannot be a built-in domain."
elif self == CannotUseDomainReason.DomainAlreadyUsed:
return f"{domain} already used"
elif self == CannotUseDomainReason.DomainPartOfUserEmail:
return "You cannot add a domain that you are currently using for your personal email. Please change your personal email to your real email"
elif self == CannotUseDomainReason.DomainUserInMailbox:
return f"{domain} already used in a SimpleLogin mailbox"
else:
raise Exception("Invalid CannotUseDomainReason")
def is_valid_domain(domain: str) -> bool:
"""
Checks that a domain is valid according to RFC 1035
"""
if len(domain) > 255:
return False
if domain.endswith("."):
domain = domain[:-1] # Strip the trailing dot
labels = domain.split(".")
if not labels:
return False
for label in labels:
if not _ALLOWED_DOMAIN_REGEX.match(label):
return False
return True
def sanitize_domain(domain: str) -> str:
new_domain = domain.lower().strip()
if new_domain.startswith("http://"):
new_domain = new_domain[len("http://") :]
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
return new_domain
def can_domain_be_used(user: User, domain: str) -> Optional[CannotUseDomainReason]:
if not is_valid_domain(domain):
return CannotUseDomainReason.InvalidDomain
elif SLDomain.get_by(domain=domain):
return CannotUseDomainReason.BuiltinDomain
elif CustomDomain.get_by(domain=domain):
return CannotUseDomainReason.DomainAlreadyUsed
elif get_email_domain_part(user.email) == domain:
return CannotUseDomainReason.DomainPartOfUserEmail
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{domain}")
).first():
return CannotUseDomainReason.DomainUserInMailbox
else:
return None
def create_custom_domain(
user: User, domain: str, partner_id: Optional[int] = None
) -> CreateCustomDomainResult:
if not user.is_premium():
return CreateCustomDomainResult(
message="Only premium plan can add custom domain",
message_category="warning",
)
new_domain = sanitize_domain(domain)
domain_forbidden_cause = can_domain_be_used(user, new_domain)
if domain_forbidden_cause:
return CreateCustomDomainResult(
message=domain_forbidden_cause.message(new_domain), message_category="error"
)
new_custom_domain = CustomDomain.create(domain=new_domain, user_id=user.id)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in user.custom_domains:
if new_domain.endswith("." + root_cd.domain) and root_cd.ownership_verified:
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
# Add the partner_id in case it's passed
if partner_id is not None:
new_custom_domain.partner_id = partner_id
Session.commit()
return CreateCustomDomainResult(
success=True,
instance=new_custom_domain,
)
def delete_custom_domain(domain: CustomDomain):
# Schedule delete domain job
LOG.w("schedule delete domain job for %s", domain)
domain.pending_deletion = True
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": domain.id},
run_at=arrow.now(),
commit=True,
)

View File

@ -1,157 +0,0 @@
from dataclasses import dataclass
from typing import Optional
from app import config
from app.constants import DMARC_RECORD
from app.db import Session
from app.dns_utils import (
DNSClient,
is_mx_equivalent,
get_network_dns_client,
)
from app.models import CustomDomain
@dataclass
class DomainValidationResult:
success: bool
errors: [str]
class CustomDomainValidation:
def __init__(
self,
dkim_domain: str,
dns_client: DNSClient = get_network_dns_client(),
partner_domains: Optional[dict[int, str]] = None,
partner_domains_validation_prefixes: Optional[dict[int, str]] = None,
):
self.dkim_domain = dkim_domain
self._dns_client = dns_client
self._partner_domains = partner_domains or config.PARTNER_DOMAINS
self._partner_domain_validation_prefixes = (
partner_domains_validation_prefixes
or config.PARTNER_DOMAIN_VALIDATION_PREFIXES
)
def get_ownership_verification_record(self, domain: CustomDomain) -> str:
prefix = "sl"
if (
domain.partner_id is not None
and domain.partner_id in self._partner_domain_validation_prefixes
):
prefix = self._partner_domain_validation_prefixes[domain.partner_id]
return f"{prefix}-verification={domain.ownership_txt_token}"
def get_dkim_records(self, domain: CustomDomain) -> {str: str}:
"""
Get a list of dkim records to set up. Depending on the custom_domain, whether if it's from a partner or not,
it will return the default ones or the partner ones.
"""
# By default use the default domain
dkim_domain = self.dkim_domain
if domain.partner_id is not None:
# Domain is from a partner. Retrieve the partner config and use that domain if exists
dkim_domain = self._partner_domains.get(domain.partner_id, dkim_domain)
return {
f"{key}._domainkey": f"{key}._domainkey.{dkim_domain}"
for key in ("dkim", "dkim02", "dkim03")
}
def validate_dkim_records(self, custom_domain: CustomDomain) -> dict[str, str]:
"""
Check if dkim records are properly set for this custom domain.
Returns empty list if all records are ok. Other-wise return the records that aren't properly configured
"""
correct_records = {}
invalid_records = {}
expected_records = self.get_dkim_records(custom_domain)
for prefix, expected_record in expected_records.items():
custom_record = f"{prefix}.{custom_domain.domain}"
dkim_record = self._dns_client.get_cname_record(custom_record)
if dkim_record == expected_record:
correct_records[prefix] = custom_record
else:
invalid_records[custom_record] = dkim_record or "empty"
# HACK
# As initially we only had one dkim record, we want to allow users that had only the original dkim record and
# the domain validated to continue seeing it as validated (although showing them the missing records).
# However, if not even the original dkim record is right, even if the domain was dkim_verified in the past,
# we will remove the dkim_verified flag.
# This is done in order to give users with the old dkim config (only one) to update their CNAMEs
if custom_domain.dkim_verified:
# Check if at least the original dkim is there
if correct_records.get("dkim._domainkey") is not None:
# Original dkim record is there. Return the missing records (if any) and don't clear the flag
return invalid_records
# Original DKIM record is not there, which means the DKIM config is not finished. Proceed with the
# rest of the code path, returning the invalid records and clearing the flag
custom_domain.dkim_verified = len(invalid_records) == 0
Session.commit()
return invalid_records
def validate_domain_ownership(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
"""
Check if the custom_domain has added the ownership verification records
"""
txt_records = self._dns_client.get_txt_record(custom_domain.domain)
expected_verification_record = self.get_ownership_verification_record(
custom_domain
)
if expected_verification_record in txt_records:
custom_domain.ownership_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
return DomainValidationResult(success=False, errors=txt_records)
def validate_mx_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
mx_domains = self._dns_client.get_mx_domains(custom_domain.domain)
if not is_mx_equivalent(mx_domains, config.EMAIL_SERVERS_WITH_PRIORITY):
return DomainValidationResult(
success=False,
errors=[f"{priority} {domain}" for (priority, domain) in mx_domains],
)
else:
custom_domain.verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
def validate_spf_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
spf_domains = self._dns_client.get_spf_domain(custom_domain.domain)
if config.EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.spf_verified = False
Session.commit()
return DomainValidationResult(
success=False,
errors=self._dns_client.get_txt_record(custom_domain.domain),
)
def validate_dmarc_records(
self, custom_domain: CustomDomain
) -> DomainValidationResult:
txt_records = self._dns_client.get_txt_record("_dmarc." + custom_domain.domain)
if DMARC_RECORD in txt_records:
custom_domain.dmarc_verified = True
Session.commit()
return DomainValidationResult(success=True, errors=[])
else:
custom_domain.dmarc_verified = False
Session.commit()
return DomainValidationResult(success=False, errors=txt_records)

View File

@ -6,7 +6,6 @@ from .views import (
subdomain, subdomain,
billing, billing,
alias_log, alias_log,
alias_export,
unsubscribe, unsubscribe,
api_key, api_key,
custom_domain, custom_domain,
@ -24,6 +23,7 @@ from .views import (
mailbox_detail, mailbox_detail,
refused_email, refused_email,
referral, referral,
recovery_code,
contact_detail, contact_detail,
setup_done, setup_done,
batch_import, batch_import,
@ -32,42 +32,4 @@ from .views import (
delete_account, delete_account,
notification, notification,
support, support,
account_setting,
) )
__all__ = [
"index",
"pricing",
"setting",
"custom_alias",
"subdomain",
"billing",
"alias_log",
"alias_export",
"unsubscribe",
"api_key",
"custom_domain",
"alias_contact_manager",
"enter_sudo",
"mfa_setup",
"mfa_cancel",
"fido_setup",
"coupon",
"fido_manage",
"domain_detail",
"lifetime_licence",
"directory",
"mailbox",
"mailbox_detail",
"refused_email",
"referral",
"contact_detail",
"setup_done",
"batch_import",
"alias_transfer",
"app",
"delete_account",
"notification",
"support",
"account_setting",
]

View File

@ -1,242 +0,0 @@
import arrow
from flask import (
render_template,
request,
redirect,
url_for,
flash,
)
from flask_login import login_required, current_user
from app import email_utils
from app.config import (
URL,
FIRST_ALIAS_DOMAIN,
ALIAS_RANDOM_SUFFIX_LENGTH,
CONNECT_WITH_PROTON,
)
from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required
from app.dashboard.views.mailbox_detail import ChangeEmailForm
from app.db import Session
from app.email_utils import (
email_can_be_used_as_mailbox,
personal_email_already_used,
)
from app.extensions import limiter
from app.jobs.export_user_data_job import ExportUserDataJob
from app.log import LOG
from app.models import (
BlockBehaviourEnum,
PlanEnum,
ResetPasswordCode,
EmailChange,
User,
Alias,
AliasGeneratorEnum,
SenderFormatEnum,
UnsubscribeBehaviourEnum,
)
from app.proton.utils import perform_proton_account_unlink
from app.utils import (
random_string,
CSRFValidationForm,
canonicalize_email,
)
@dashboard_bp.route("/account_setting", methods=["GET", "POST"])
@login_required
@sudo_required
@limiter.limit("5/minute", methods=["POST"])
def account_setting():
change_email_form = ChangeEmailForm()
csrf_form = CSRFValidationForm()
email_change = EmailChange.get_by(user_id=current_user.id)
if email_change:
pending_email = email_change.new_email
else:
pending_email = None
if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(url_for("dashboard.setting"))
if request.form.get("form-name") == "update-email":
if change_email_form.validate():
# whether user can proceed with the email update
new_email_valid = True
new_email = canonicalize_email(change_email_form.email.data)
if new_email != current_user.email and not pending_email:
# check if this email is not already used
if personal_email_already_used(new_email) or Alias.get_by(
email=new_email
):
flash(f"Email {new_email} already used", "error")
new_email_valid = False
elif not email_can_be_used_as_mailbox(new_email):
flash(
"You cannot use this email address as your personal inbox.",
"error",
)
new_email_valid = False
# a pending email change with the same email exists from another user
elif EmailChange.get_by(new_email=new_email):
other_email_change: EmailChange = EmailChange.get_by(
new_email=new_email
)
LOG.w(
"Another user has a pending %s with the same email address. Current user:%s",
other_email_change,
current_user,
)
if other_email_change.is_expired():
LOG.d(
"delete the expired email change %s", other_email_change
)
EmailChange.delete(other_email_change.id)
Session.commit()
else:
flash(
"You cannot use this email address as your personal inbox.",
"error",
)
new_email_valid = False
if new_email_valid:
email_change = EmailChange.create(
user_id=current_user.id,
code=random_string(
60
), # todo: make sure the code is unique
new_email=new_email,
)
Session.commit()
send_change_email_confirmation(current_user, email_change)
flash(
"A confirmation email is on the way, please check your inbox",
"success",
)
return redirect(url_for("dashboard.account_setting"))
elif request.form.get("form-name") == "change-password":
flash(
"You are going to receive an email containing instructions to change your password",
"success",
)
send_reset_password_email(current_user)
return redirect(url_for("dashboard.account_setting"))
elif request.form.get("form-name") == "send-full-user-report":
if ExportUserDataJob(current_user).store_job_in_db():
flash(
"You will receive your SimpleLogin data via email shortly",
"success",
)
else:
flash("An export of your data is currently in progress", "error")
partner_sub = None
partner_name = None
return render_template(
"dashboard/account_setting.html",
csrf_form=csrf_form,
PlanEnum=PlanEnum,
SenderFormatEnum=SenderFormatEnum,
BlockBehaviourEnum=BlockBehaviourEnum,
change_email_form=change_email_form,
pending_email=pending_email,
AliasGeneratorEnum=AliasGeneratorEnum,
UnsubscribeBehaviourEnum=UnsubscribeBehaviourEnum,
partner_sub=partner_sub,
partner_name=partner_name,
FIRST_ALIAS_DOMAIN=FIRST_ALIAS_DOMAIN,
ALIAS_RAND_SUFFIX_LENGTH=ALIAS_RANDOM_SUFFIX_LENGTH,
connect_with_proton=CONNECT_WITH_PROTON,
)
def send_reset_password_email(user):
"""
generate a new ResetPasswordCode and send it over email to user
"""
# the activation code is valid for 1h
reset_password_code = ResetPasswordCode.create(
user_id=user.id, code=random_string(60)
)
Session.commit()
reset_password_link = f"{URL}/auth/reset_password?code={reset_password_code.code}"
email_utils.send_reset_password_email(user, reset_password_link)
def send_change_email_confirmation(user: User, email_change: EmailChange):
"""
send confirmation email to the new email address
"""
link = f"{URL}/auth/change_email?code={email_change.code}"
email_utils.send_change_email(user, email_change.new_email, link)
@dashboard_bp.route("/resend_email_change", methods=["GET", "POST"])
@limiter.limit("5/hour")
@login_required
@sudo_required
def resend_email_change():
form = CSRFValidationForm()
if not form.validate():
flash("Invalid request. Please try again", "warning")
return redirect(url_for("dashboard.setting"))
email_change = EmailChange.get_by(user_id=current_user.id)
if email_change:
# extend email change expiration
email_change.expired = arrow.now().shift(hours=12)
Session.commit()
send_change_email_confirmation(current_user, email_change)
flash("A confirmation email is on the way, please check your inbox", "success")
return redirect(url_for("dashboard.setting"))
else:
flash(
"You have no pending email change. Redirect back to Setting page", "warning"
)
return redirect(url_for("dashboard.setting"))
@dashboard_bp.route("/cancel_email_change", methods=["GET", "POST"])
@login_required
@sudo_required
def cancel_email_change():
form = CSRFValidationForm()
if not form.validate():
flash("Invalid request. Please try again", "warning")
return redirect(url_for("dashboard.setting"))
email_change = EmailChange.get_by(user_id=current_user.id)
if email_change:
EmailChange.delete(email_change.id)
Session.commit()
flash("Your email change is cancelled", "success")
return redirect(url_for("dashboard.setting"))
else:
flash(
"You have no pending email change. Redirect back to Setting page", "warning"
)
return redirect(url_for("dashboard.setting"))
@dashboard_bp.route("/unlink_proton_account", methods=["POST"])
@login_required
@sudo_required
def unlink_proton_account():
csrf_form = CSRFValidationForm()
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(url_for("dashboard.setting"))
perform_proton_account_unlink(current_user)
flash("Your Proton account has been unlinked", "success")
return redirect(url_for("dashboard.setting"))

View File

@ -9,11 +9,14 @@ from sqlalchemy import and_, func, case
from wtforms import StringField, validators, ValidationError from wtforms import StringField, validators, ValidationError
# Need to import directly from config to allow modification from the tests # Need to import directly from config to allow modification from the tests
from app import config, parallel_limiter, contact_utils from app import config
from app.contact_utils import ContactCreateError
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.email_validation import is_valid_email from app.email_utils import (
is_valid_email,
generate_reply_email,
parse_full_address,
)
from app.errors import ( from app.errors import (
CannotCreateContactForReverseAlias, CannotCreateContactForReverseAlias,
ErrContactErrorUpgradeNeeded, ErrContactErrorUpgradeNeeded,
@ -21,8 +24,8 @@ from app.errors import (
ErrContactAlreadyExists, ErrContactAlreadyExists,
) )
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, EmailLog from app.models import Alias, Contact, EmailLog, User
from app.utils import CSRFValidationForm from app.utils import sanitize_email
def email_validator(): def email_validator():
@ -48,7 +51,15 @@ def email_validator():
return _check return _check
def create_contact(alias: Alias, contact_address: str) -> Contact: def user_can_create_contacts(user: User) -> bool:
if user.is_premium():
return True
if user.flags & User.FLAG_FREE_DISABLE_CREATE_ALIAS == 0:
return True
return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS
def create_contact(user: User, alias: Alias, contact_address: str) -> Contact:
""" """
Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS. Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS.
Can throw exceptions: Can throw exceptions:
@ -58,23 +69,37 @@ def create_contact(alias: Alias, contact_address: str) -> Contact:
""" """
if not contact_address: if not contact_address:
raise ErrAddressInvalid("Empty address") raise ErrAddressInvalid("Empty address")
output = contact_utils.create_contact(email=contact_address, alias=alias) try:
if output.error == ContactCreateError.InvalidEmail: contact_name, contact_email = parse_full_address(contact_address)
except ValueError:
raise ErrAddressInvalid(contact_address) raise ErrAddressInvalid(contact_address)
elif output.error == ContactCreateError.NotAllowed:
raise ErrContactErrorUpgradeNeeded()
elif output.error is not None:
raise ErrAddressInvalid("Invalid address")
elif not output.created:
raise ErrContactAlreadyExists(output.contact)
contact = output.contact contact_email = sanitize_email(contact_email)
if not is_valid_email(contact_email):
raise ErrAddressInvalid(contact_email)
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
raise ErrContactAlreadyExists(contact)
if not user_can_create_contacts(user):
raise ErrContactErrorUpgradeNeeded()
contact = Contact.create(
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_email,
name=contact_name,
reply_email=generate_reply_email(contact_email, user),
)
LOG.d( LOG.d(
"create reverse-alias for %s %s, reverse alias:%s", "create reverse-alias for %s %s, reverse alias:%s",
contact_address, contact_address,
alias, alias,
contact.reply_email, contact.reply_email,
) )
Session.commit()
return contact return contact
@ -204,17 +229,12 @@ def delete_contact(alias: Alias, contact_id: int):
flash(f"Reverse-alias for {delete_contact_email} has been deleted", "success") flash(f"Reverse-alias for {delete_contact_email} has been deleted", "success")
@dashboard_bp.route("/alias_contact_manager/<int:alias_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/alias_contact_manager/<alias_id>/", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(name="contact_creation")
def alias_contact_manager(alias_id): def alias_contact_manager(alias_id):
highlight_contact_id = None highlight_contact_id = None
if request.args.get("highlight_contact_id"): if request.args.get("highlight_contact_id"):
try: highlight_contact_id = int(request.args.get("highlight_contact_id"))
highlight_contact_id = int(request.args.get("highlight_contact_id"))
except ValueError:
flash("Invalid contact id", "error")
return redirect(url_for("dashboard.index"))
alias = Alias.get(alias_id) alias = Alias.get(alias_id)
@ -234,17 +254,13 @@ def alias_contact_manager(alias_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
new_contact_form = NewContactForm() new_contact_form = NewContactForm()
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
if new_contact_form.validate(): if new_contact_form.validate():
contact_address = new_contact_form.email.data.strip() contact_address = new_contact_form.email.data.strip()
try: try:
contact = create_contact(alias, contact_address) contact = create_contact(current_user, alias, contact_address)
except ( except (
ErrContactErrorUpgradeNeeded, ErrContactErrorUpgradeNeeded,
ErrAddressInvalid, ErrAddressInvalid,
@ -302,6 +318,5 @@ def alias_contact_manager(alias_id):
last_page=last_page, last_page=last_page,
query=query, query=query,
nb_contact=nb_contact, nb_contact=nb_contact,
can_create_contacts=current_user.can_create_contacts(), can_create_contacts=user_can_create_contacts(current_user),
csrf_form=csrf_form,
) )

View File

@ -1,13 +0,0 @@
from app.dashboard.base import dashboard_bp
from flask_login import login_required, current_user
from app.alias_utils import alias_export_csv
from app.dashboard.views.enter_sudo import sudo_required
from app.extensions import limiter
@dashboard_bp.route("/alias_export", methods=["GET"])
@login_required
@sudo_required
@limiter.limit("2/minute")
def alias_export_route():
return alias_export_csv(current_user)

View File

@ -87,6 +87,6 @@ def get_alias_log(alias: Alias, page_id=0) -> [AliasLog]:
contact=contact, contact=contact,
) )
logs.append(al) logs.append(al)
logs = sorted(logs, key=lambda log: log.when, reverse=True) logs = sorted(logs, key=lambda l: l.when, reverse=True)
return logs return logs

View File

@ -7,17 +7,76 @@ from flask import render_template, redirect, url_for, flash, request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app import config from app import config
from app.alias_utils import transfer_alias
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.email_utils import send_email, render
from app.extensions import limiter from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
Contact,
AliasUsedOn,
AliasMailbox,
User,
ClientUser,
) )
from app.models import Mailbox from app.models import Mailbox
from app.utils import CSRFValidationForm
def transfer(alias, new_user, new_mailboxes: [Mailbox]):
# cannot transfer alias which is used for receiving newsletter
if User.get_by(newsletter_alias_id=alias.id):
raise Exception("Cannot transfer alias that's used to receive newsletter")
# update user_id
Session.query(Contact).filter(Contact.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(AliasUsedOn).filter(AliasUsedOn.alias_id == alias.id).update(
{"user_id": new_user.id}
)
Session.query(ClientUser).filter(ClientUser.alias_id == alias.id).update(
{"user_id": new_user.id}
)
# remove existing mailboxes from the alias
Session.query(AliasMailbox).filter(AliasMailbox.alias_id == alias.id).delete()
# set mailboxes
alias.mailbox_id = new_mailboxes.pop().id
for mb in new_mailboxes:
AliasMailbox.create(alias_id=alias.id, mailbox_id=mb.id)
# alias has never been transferred before
if not alias.original_owner_id:
alias.original_owner_id = alias.user_id
# inform previous owner
old_user = alias.user
send_email(
old_user.email,
f"Alias {alias.email} has been received",
render(
"transactional/alias-transferred.txt",
alias=alias,
),
render(
"transactional/alias-transferred.html",
alias=alias,
),
)
# now the alias belongs to the new user
alias.user_id = new_user.id
# set some fields back to default
alias.disable_pgp = False
alias.pinned = False
Session.commit()
def hmac_alias_transfer_token(transfer_token: str) -> str: def hmac_alias_transfer_token(transfer_token: str) -> str:
@ -46,12 +105,8 @@ def alias_transfer_send_route(alias_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
alias_transfer_url = None alias_transfer_url = None
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
# generate a new transfer_token # generate a new transfer_token
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}" transfer_token = f"{alias.id}.{secrets.token_urlsafe(32)}"
@ -78,7 +133,6 @@ def alias_transfer_send_route(alias_id):
alias_transfer_url=alias_transfer_url, alias_transfer_url=alias_transfer_url,
link_active=alias.transfer_token_expiration is not None link_active=alias.transfer_token_expiration is not None
and alias.transfer_token_expiration > arrow.utcnow(), and alias.transfer_token_expiration > arrow.utcnow(),
csrf_form=csrf_form,
) )
@ -154,13 +208,7 @@ def alias_transfer_receive_route():
mailboxes, mailboxes,
token, token,
) )
transfer_alias(alias, current_user, mailboxes) transfer(alias, current_user, mailboxes)
# reset transfer token
alias.transfer_token = None
alias.transfer_token_expiration = None
Session.commit()
flash(f"You are now owner of {alias.email}", "success") flash(f"You are now owner of {alias.email}", "success")
return redirect(url_for("dashboard.index", highlight_alias_id=alias.id)) return redirect(url_for("dashboard.index", highlight_alias_id=alias.id))

View File

@ -3,47 +3,19 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app import config
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.extensions import limiter
from app.models import ApiKey from app.models import ApiKey
from app.utils import CSRFValidationForm
class NewApiKeyForm(FlaskForm): class NewApiKeyForm(FlaskForm):
name = StringField("Name", validators=[validators.DataRequired()]) name = StringField("Name", validators=[validators.DataRequired()])
def clean_up_unused_or_old_api_keys(user_id: int):
total_keys = ApiKey.filter_by(user_id=user_id).count()
if total_keys <= config.MAX_API_KEYS:
return
# Remove oldest unused
for api_key in (
ApiKey.filter_by(user_id=user_id, last_used=None)
.order_by(ApiKey.created_at.asc())
.all()
):
Session.delete(api_key)
total_keys -= 1
if total_keys <= config.MAX_API_KEYS:
return
# Clean up oldest used
for api_key in (
ApiKey.filter_by(user_id=user_id).order_by(ApiKey.last_used.asc()).all()
):
Session.delete(api_key)
total_keys -= 1
if total_keys <= config.MAX_API_KEYS:
return
@dashboard_bp.route("/api_key", methods=["GET", "POST"]) @dashboard_bp.route("/api_key", methods=["GET", "POST"])
@login_required @login_required
@sudo_required @sudo_required
@limiter.limit("10/hour")
def api_key(): def api_key():
api_keys = ( api_keys = (
ApiKey.filter(ApiKey.user_id == current_user.id) ApiKey.filter(ApiKey.user_id == current_user.id)
@ -51,13 +23,9 @@ def api_key():
.all() .all()
) )
csrf_form = CSRFValidationForm()
new_api_key_form = NewApiKeyForm() new_api_key_form = NewApiKeyForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "delete": if request.form.get("form-name") == "delete":
api_key_id = request.form.get("api-key-id") api_key_id = request.form.get("api-key-id")
@ -77,7 +45,6 @@ def api_key():
elif request.form.get("form-name") == "create": elif request.form.get("form-name") == "create":
if new_api_key_form.validate(): if new_api_key_form.validate():
clean_up_unused_or_old_api_keys(current_user.id)
new_api_key = ApiKey.create( new_api_key = ApiKey.create(
name=new_api_key_form.name.data, user_id=current_user.id name=new_api_key_form.name.data, user_id=current_user.id
) )
@ -95,8 +62,5 @@ def api_key():
return redirect(url_for("dashboard.api_key")) return redirect(url_for("dashboard.api_key"))
return render_template( return render_template(
"dashboard/api_key.html", "dashboard/api_key.html", api_keys=api_keys, new_api_key_form=new_api_key_form
api_keys=api_keys,
new_api_key_form=new_api_key_form,
csrf_form=csrf_form,
) )

View File

@ -1,9 +1,14 @@
from app.db import Session
"""
List of apps that user has used via the "Sign in with SimpleLogin"
"""
from flask import render_template, request, flash, redirect from flask import render_template, request, flash, redirect
from flask_login import login_required, current_user from flask_login import login_required, current_user
from sqlalchemy.orm import joinedload from sqlalchemy.orm import joinedload
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session
from app.models import ( from app.models import (
ClientUser, ClientUser,
) )
@ -12,10 +17,6 @@ from app.models import (
@dashboard_bp.route("/app", methods=["GET", "POST"]) @dashboard_bp.route("/app", methods=["GET", "POST"])
@login_required @login_required
def app_route(): def app_route():
"""
List of apps that user has used via the "Sign in with SimpleLogin"
"""
client_users = ( client_users = (
ClientUser.filter_by(user_id=current_user.id) ClientUser.filter_by(user_id=current_user.id)
.options(joinedload(ClientUser.client)) .options(joinedload(ClientUser.client))

View File

@ -5,18 +5,14 @@ from flask_login import login_required, current_user
from app import s3 from app import s3
from app.config import JOB_BATCH_IMPORT from app.config import JOB_BATCH_IMPORT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import File, BatchImport, Job from app.models import File, BatchImport, Job
from app.utils import random_string, CSRFValidationForm from app.utils import random_string
@dashboard_bp.route("/batch_import", methods=["GET", "POST"]) @dashboard_bp.route("/batch_import", methods=["GET", "POST"])
@login_required @login_required
@sudo_required
@limiter.limit("10/minute", methods=["POST"])
def batch_import_route(): def batch_import_route():
# only for users who have custom domains # only for users who have custom domains
if not current_user.verified_custom_domains(): if not current_user.verified_custom_domains():
@ -29,27 +25,9 @@ def batch_import_route():
) )
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
batch_imports = BatchImport.filter_by( batch_imports = BatchImport.filter_by(user_id=current_user.id).all()
user_id=current_user.id, processed=False
).all()
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if len(batch_imports) > 10:
flash(
"You have too many imports already. Please wait until some get cleaned up",
"error",
)
return render_template(
"dashboard/batch_import.html",
batch_imports=batch_imports,
csrf_form=csrf_form,
)
alias_file = request.files["alias-file"] alias_file = request.files["alias-file"]
file_path = random_string(20) + ".csv" file_path = random_string(20) + ".csv"
@ -77,6 +55,4 @@ def batch_import_route():
return redirect(url_for("dashboard.batch_import_route")) return redirect(url_for("dashboard.batch_import_route"))
return render_template( return render_template("dashboard/batch_import.html", batch_imports=batch_imports)
"dashboard/batch_import.html", batch_imports=batch_imports, csrf_form=csrf_form
)

View File

@ -1,7 +1,5 @@
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from wtforms import StringField, validators
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -9,14 +7,6 @@ from app.models import Contact
from app.pgp_utils import PGPException, load_public_key_and_check from app.pgp_utils import PGPException, load_public_key_and_check
class PGPContactForm(FlaskForm):
action = StringField(
"action",
validators=[validators.DataRequired(), validators.AnyOf(("save", "remove"))],
)
pgp = StringField("pgp", validators=[validators.Optional()])
@dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/contact/<int:contact_id>/", methods=["GET", "POST"])
@login_required @login_required
def contact_detail_route(contact_id): def contact_detail_route(contact_id):
@ -26,41 +16,33 @@ def contact_detail_route(contact_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
alias = contact.alias alias = contact.alias
pgp_form = PGPContactForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "pgp": if request.form.get("form-name") == "pgp":
if not pgp_form.validate(): if request.form.get("action") == "save":
flash("Invalid request", "warning")
return redirect(request.url)
if pgp_form.action.data == "save":
if not current_user.is_premium(): if not current_user.is_premium():
flash("Only premium plan can add PGP Key", "warning") flash("Only premium plan can add PGP Key", "warning")
return redirect( return redirect(
url_for("dashboard.contact_detail_route", contact_id=contact_id) url_for("dashboard.contact_detail_route", contact_id=contact_id)
) )
if not pgp_form.pgp.data:
flash("Invalid pgp key") contact.pgp_public_key = request.form.get("pgp")
try:
contact.pgp_finger_print = load_public_key_and_check(
contact.pgp_public_key
)
except PGPException:
flash("Cannot add the public key, please verify it", "error")
else: else:
contact.pgp_public_key = pgp_form.pgp.data Session.commit()
try: flash(
contact.pgp_finger_print = load_public_key_and_check( f"PGP public key for {contact.email} is saved successfully",
contact.pgp_public_key "success",
) )
except PGPException: return redirect(
flash("Cannot add the public key, please verify it", "error") url_for("dashboard.contact_detail_route", contact_id=contact_id)
else: )
Session.commit() elif request.form.get("action") == "remove":
flash(
f"PGP public key for {contact.email} is saved successfully",
"success",
)
return redirect(
url_for(
"dashboard.contact_detail_route", contact_id=contact_id
)
)
elif pgp_form.action.data == "remove":
# Free user can decide to remove contact PGP key # Free user can decide to remove contact PGP key
contact.pgp_public_key = None contact.pgp_public_key = None
contact.pgp_finger_print = None contact.pgp_finger_print = None
@ -71,5 +53,5 @@ def contact_detail_route(contact_id):
) )
return render_template( return render_template(
"dashboard/contact_detail.html", contact=contact, alias=alias, pgp_form=pgp_form "dashboard/contact_detail.html", contact=contact, alias=alias
) )

View File

@ -4,7 +4,6 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app import parallel_limiter
from app.config import PADDLE_VENDOR_ID, PADDLE_COUPON_ID from app.config import PADDLE_VENDOR_ID, PADDLE_COUPON_ID
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
@ -25,7 +24,6 @@ class CouponForm(FlaskForm):
@dashboard_bp.route("/coupon", methods=["GET", "POST"]) @dashboard_bp.route("/coupon", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock()
def coupon_route(): def coupon_route():
coupon_form = CouponForm() coupon_form = CouponForm()
@ -68,14 +66,9 @@ def coupon_route():
) )
return redirect(request.url) return redirect(request.url)
updated = ( coupon.used_by_user_id = current_user.id
Session.query(Coupon) coupon.used = True
.filter_by(code=code, used=False) Session.commit()
.update({"used_by_user_id": current_user.id, "used": True})
)
if updated != 1:
flash("Coupon is not valid", "error")
return redirect(request.url)
manual_sub: ManualSubscription = ManualSubscription.get_by( manual_sub: ManualSubscription = ManualSubscription.get_by(
user_id=current_user.id user_id=current_user.id
@ -100,7 +93,7 @@ def coupon_route():
commit=True, commit=True,
) )
flash( flash(
"Your account has been upgraded to Premium, thanks for your support!", f"Your account has been upgraded to Premium, thanks for your support!",
"success", "success",
) )

View File

@ -1,16 +1,16 @@
import json
from dataclasses import dataclass, asdict
from email_validator import validate_email, EmailNotValidError from email_validator import validate_email, EmailNotValidError
from flask import render_template, redirect, url_for, flash, request from flask import render_template, redirect, url_for, flash, request
from flask_login import login_required, current_user from flask_login import login_required, current_user
from itsdangerous import TimestampSigner, SignatureExpired
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app import parallel_limiter
from app.alias_suffix import (
get_alias_suffixes,
check_suffix_signature,
verify_prefix_suffix,
)
from app.alias_utils import check_alias_prefix from app.alias_utils import check_alias_prefix
from app.config import ( from app.config import (
DISABLE_ALIAS_SUFFIX,
CUSTOM_ALIAS_SECRET,
ALIAS_LIMIT, ALIAS_LIMIT,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
@ -19,18 +19,180 @@ from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
CustomDomain,
DeletedAlias, DeletedAlias,
Mailbox, Mailbox,
User,
AliasMailbox, AliasMailbox,
DomainDeletedAlias, DomainDeletedAlias,
) )
from app.utils import CSRFValidationForm
signer = TimestampSigner(CUSTOM_ALIAS_SECRET)
@dataclass
class SuffixInfo:
"""
Alias suffix info
WARNING: should use AliasSuffix instead
"""
# whether this is a custom domain
is_custom: bool
suffix: str
signed_suffix: str
# whether this is a premium SL domain. Not apply to custom domain
is_premium: bool
def get_available_suffixes(user: User) -> [SuffixInfo]:
"""
WARNING: should use get_alias_suffixes() instead
"""
user_custom_domains = user.verified_custom_domains()
suffixes: [SuffixInfo] = []
# put custom domain first
# for each user domain, generate both the domain and a random suffix version
for custom_domain in user_custom_domains:
if custom_domain.random_prefix_generation:
suffix = "." + user.get_random_alias_suffix() + "@" + custom_domain.domain
suffix_info = SuffixInfo(True, suffix, signer.sign(suffix).decode(), False)
if user.default_alias_custom_domain_id == custom_domain.id:
suffixes.insert(0, suffix_info)
else:
suffixes.append(suffix_info)
suffix = "@" + custom_domain.domain
suffix_info = SuffixInfo(True, suffix, signer.sign(suffix).decode(), False)
# put the default domain to top
# only if random_prefix_generation isn't enabled
if (
user.default_alias_custom_domain_id == custom_domain.id
and not custom_domain.random_prefix_generation
):
suffixes.insert(0, suffix_info)
else:
suffixes.append(suffix_info)
# then SimpleLogin domain
for sl_domain in user.get_sl_domains():
suffix = (
("" if DISABLE_ALIAS_SUFFIX else "." + user.get_random_alias_suffix())
+ "@"
+ sl_domain.domain
)
suffix_info = SuffixInfo(
False, suffix, signer.sign(suffix).decode(), sl_domain.premium_only
)
# put the default domain to top
if user.default_alias_public_domain_id == sl_domain.id:
suffixes.insert(0, suffix_info)
else:
suffixes.append(suffix_info)
return suffixes
@dataclass
class AliasSuffix:
# whether this is a custom domain
is_custom: bool
suffix: str
# whether this is a premium SL domain. Not apply to custom domain
is_premium: bool
# can be either Custom or SL domain
domain: str
# if custom domain, whether the custom domain has MX verified, i.e. can receive emails
mx_verified: bool = True
def serialize(self):
return json.dumps(asdict(self))
@classmethod
def deserialize(cls, data: str) -> "AliasSuffix":
return AliasSuffix(**json.loads(data))
def get_alias_suffixes(user: User) -> [AliasSuffix]:
"""
Similar to as get_available_suffixes() but also return custom domain that doesn't have MX set up.
"""
user_custom_domains = CustomDomain.filter_by(
user_id=user.id, ownership_verified=True
).all()
alias_suffixes: [AliasSuffix] = []
# put custom domain first
# for each user domain, generate both the domain and a random suffix version
for custom_domain in user_custom_domains:
if custom_domain.random_prefix_generation:
suffix = "." + user.get_random_alias_suffix() + "@" + custom_domain.domain
alias_suffix = AliasSuffix(
is_custom=True,
suffix=suffix,
is_premium=False,
domain=custom_domain.domain,
mx_verified=custom_domain.verified,
)
if user.default_alias_custom_domain_id == custom_domain.id:
alias_suffixes.insert(0, alias_suffix)
else:
alias_suffixes.append(alias_suffix)
suffix = "@" + custom_domain.domain
alias_suffix = AliasSuffix(
is_custom=True,
suffix=suffix,
is_premium=False,
domain=custom_domain.domain,
mx_verified=custom_domain.verified,
)
# put the default domain to top
# only if random_prefix_generation isn't enabled
if (
user.default_alias_custom_domain_id == custom_domain.id
and not custom_domain.random_prefix_generation
):
alias_suffixes.insert(0, alias_suffix)
else:
alias_suffixes.append(alias_suffix)
# then SimpleLogin domain
for sl_domain in user.get_sl_domains():
suffix = (
("" if DISABLE_ALIAS_SUFFIX else "." + user.get_random_alias_suffix())
+ "@"
+ sl_domain.domain
)
alias_suffix = AliasSuffix(
is_custom=False,
suffix=suffix,
is_premium=sl_domain.premium_only,
domain=sl_domain.domain,
mx_verified=True,
)
# put the default domain to top
if user.default_alias_public_domain_id == sl_domain.id:
alias_suffixes.insert(0, alias_suffix)
else:
alias_suffixes.append(alias_suffix)
return alias_suffixes
@dashboard_bp.route("/custom_alias", methods=["GET", "POST"]) @dashboard_bp.route("/custom_alias", methods=["GET", "POST"])
@limiter.limit(ALIAS_LIMIT, methods=["POST"]) @limiter.limit(ALIAS_LIMIT, methods=["POST"])
@login_required @login_required
@parallel_limiter.lock(name="alias_creation")
def custom_alias(): def custom_alias():
# check if user has not exceeded the alias quota # check if user has not exceeded the alias quota
if not current_user.can_create_new_alias(): if not current_user.can_create_new_alias():
@ -49,13 +211,14 @@ def custom_alias():
at_least_a_premium_domain = True at_least_a_premium_domain = True
break break
csrf_form = CSRFValidationForm() alias_suffixes_with_signature = [
(alias_suffix, signer.sign(alias_suffix.serialize()).decode())
for alias_suffix in alias_suffixes
]
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
alias_prefix = request.form.get("prefix").strip().lower().replace(" ", "") alias_prefix = request.form.get("prefix").strip().lower().replace(" ", "")
signed_alias_suffix = request.form.get("signed-alias-suffix") signed_alias_suffix = request.form.get("signed-alias-suffix")
mailbox_ids = request.form.getlist("mailboxes") mailbox_ids = request.form.getlist("mailboxes")
@ -86,19 +249,25 @@ def custom_alias():
flash("At least one mailbox must be selected", "error") flash("At least one mailbox must be selected", "error")
return redirect(request.url) return redirect(request.url)
# hypothesis: user will click on the button in the 600 secs
try: try:
suffix = check_suffix_signature(signed_alias_suffix) signed_alias_suffix_decoded = signer.unsign(
if not suffix: signed_alias_suffix, max_age=600
LOG.w("Alias creation time expired for %s", current_user) ).decode()
flash("Alias creation time is expired, please retry", "warning") alias_suffix: AliasSuffix = AliasSuffix.deserialize(
return redirect(request.url) signed_alias_suffix_decoded
)
except SignatureExpired:
LOG.w("Alias creation time expired for %s", current_user)
flash("Alias creation time is expired, please retry", "warning")
return redirect(request.url)
except Exception: except Exception:
LOG.w("Alias suffix is tampered, user %s", current_user) LOG.w("Alias suffix is tampered, user %s", current_user)
flash("Unknown error, refresh the page", "error") flash("Unknown error, refresh the page", "error")
return redirect(request.url) return redirect(request.url)
if verify_prefix_suffix(current_user, alias_prefix, suffix): if verify_prefix_suffix(current_user, alias_prefix, alias_suffix.suffix):
full_alias = alias_prefix + suffix full_alias = alias_prefix + alias_suffix.suffix
if ".." in full_alias: if ".." in full_alias:
flash("Your alias can't contain 2 consecutive dots (..)", "error") flash("Your alias can't contain 2 consecutive dots (..)", "error")
@ -125,11 +294,18 @@ def custom_alias():
email=full_alias email=full_alias
) )
custom_domain = domain_deleted_alias.domain custom_domain = domain_deleted_alias.domain
flash( if domain_deleted_alias.user_id == current_user.id:
f"You have deleted this alias before. You can restore it on " flash(
f"{custom_domain.domain} 'Deleted Alias' page", f"You have deleted this alias before. You can restore it on "
"error", f"{custom_domain.domain} 'Deleted Alias' page",
) "error",
)
else:
# should never happen as user can only choose their domains
LOG.e(
"Deleted Alias %s does not belong to user %s",
domain_deleted_alias,
)
elif DeletedAlias.get_by(email=full_alias): elif DeletedAlias.get_by(email=full_alias):
flash(general_error_msg, "error") flash(general_error_msg, "error")
@ -166,8 +342,51 @@ def custom_alias():
return render_template( return render_template(
"dashboard/custom_alias.html", "dashboard/custom_alias.html",
user_custom_domains=user_custom_domains, user_custom_domains=user_custom_domains,
alias_suffixes=alias_suffixes, alias_suffixes_with_signature=alias_suffixes_with_signature,
at_least_a_premium_domain=at_least_a_premium_domain, at_least_a_premium_domain=at_least_a_premium_domain,
mailboxes=mailboxes, mailboxes=mailboxes,
csrf_form=csrf_form,
) )
def verify_prefix_suffix(user: User, alias_prefix, alias_suffix) -> bool:
"""verify if user could create an alias with the given prefix and suffix"""
if not alias_prefix or not alias_suffix: # should be caught on frontend
return False
user_custom_domains = [cd.domain for cd in user.verified_custom_domains()]
# make sure alias_suffix is either .random_word@simplelogin.co or @my-domain.com
alias_suffix = alias_suffix.strip()
# alias_domain_prefix is either a .random_word or ""
alias_domain_prefix, alias_domain = alias_suffix.split("@", 1)
# alias_domain must be either one of user custom domains or built-in domains
if alias_domain not in user.available_alias_domains():
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
# SimpleLogin domain case:
# 1) alias_suffix must start with "." and
# 2) alias_domain_prefix must come from the word list
if (
alias_domain in user.available_sl_domains()
and alias_domain not in user_custom_domains
# when DISABLE_ALIAS_SUFFIX is true, alias_domain_prefix is empty
and not DISABLE_ALIAS_SUFFIX
):
if not alias_domain_prefix.startswith("."):
LOG.e("User %s submits a wrong alias suffix %s", user, alias_suffix)
return False
else:
if alias_domain not in user_custom_domains:
if not DISABLE_ALIAS_SUFFIX:
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
if alias_domain not in user.available_sl_domains():
LOG.e("wrong alias suffix %s, user %s", alias_suffix, user)
return False
return True

View File

@ -3,11 +3,12 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators from wtforms import StringField, validators
from app import parallel_limiter
from app.config import EMAIL_SERVERS_WITH_PRIORITY from app.config import EMAIL_SERVERS_WITH_PRIORITY
from app.custom_domain_utils import create_custom_domain
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.models import CustomDomain from app.db import Session
from app.email_utils import get_email_domain_part
from app.log import LOG
from app.models import CustomDomain, Mailbox, DomainMailbox, SLDomain
class NewCustomDomainForm(FlaskForm): class NewCustomDomainForm(FlaskForm):
@ -18,13 +19,15 @@ class NewCustomDomainForm(FlaskForm):
@dashboard_bp.route("/custom_domain", methods=["GET", "POST"]) @dashboard_bp.route("/custom_domain", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def custom_domain(): def custom_domain():
custom_domains = CustomDomain.filter_by( custom_domains = CustomDomain.filter_by(
user_id=current_user.id, is_sl_subdomain=False user_id=current_user.id, is_sl_subdomain=False
).all() ).all()
mailboxes = current_user.mailboxes()
new_custom_domain_form = NewCustomDomainForm() new_custom_domain_form = NewCustomDomainForm()
errors = {}
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
if not current_user.is_premium(): if not current_user.is_premium():
@ -32,25 +35,87 @@ def custom_domain():
return redirect(url_for("dashboard.custom_domain")) return redirect(url_for("dashboard.custom_domain"))
if new_custom_domain_form.validate(): if new_custom_domain_form.validate():
res = create_custom_domain( new_domain = new_custom_domain_form.domain.data.lower().strip()
user=current_user, domain=new_custom_domain_form.domain.data
) if new_domain.startswith("http://"):
if res.success: new_domain = new_domain[len("http://") :]
flash(f"New domain {res.instance.domain} is created", "success")
if new_domain.startswith("https://"):
new_domain = new_domain[len("https://") :]
if SLDomain.get_by(domain=new_domain):
flash("A custom domain cannot be a built-in domain.", "error")
elif CustomDomain.get_by(domain=new_domain):
flash(f"{new_domain} already used", "error")
elif get_email_domain_part(current_user.email) == new_domain:
flash(
"You cannot add a domain that you are currently using for your personal email. "
"Please change your personal email to your real email",
"error",
)
elif Mailbox.filter(
Mailbox.verified.is_(True), Mailbox.email.endswith(f"@{new_domain}")
).first():
flash(
f"{new_domain} already used in a SimpleLogin mailbox", "error"
)
else:
new_custom_domain = CustomDomain.create(
domain=new_domain, user_id=current_user.id
)
# new domain has ownership verified if its parent has the ownership verified
for root_cd in current_user.custom_domains:
if (
new_domain.endswith("." + root_cd.domain)
and root_cd.ownership_verified
):
LOG.i(
"%s ownership verified thanks to %s",
new_custom_domain,
root_cd,
)
new_custom_domain.ownership_verified = True
Session.commit()
mailbox_ids = request.form.getlist("mailbox_ids")
if mailbox_ids:
# check if mailbox is not tempered with
mailboxes = []
for mailbox_id in mailbox_ids:
mailbox = Mailbox.get(mailbox_id)
if (
not mailbox
or mailbox.user_id != current_user.id
or not mailbox.verified
):
flash("Something went wrong, please retry", "warning")
return redirect(url_for("dashboard.custom_domain"))
mailboxes.append(mailbox)
for mailbox in mailboxes:
DomainMailbox.create(
domain_id=new_custom_domain.id, mailbox_id=mailbox.id
)
Session.commit()
flash(
f"New domain {new_custom_domain.domain} is created", "success"
)
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", "dashboard.domain_detail_dns",
custom_domain_id=res.instance.id, custom_domain_id=new_custom_domain.id,
) )
) )
else:
flash(res.message, res.message_category)
if res.redirect:
return redirect(url_for(res.redirect))
return render_template( return render_template(
"dashboard/custom_domain.html", "dashboard/custom_domain.html",
custom_domains=custom_domains, custom_domains=custom_domains,
new_custom_domain_form=new_custom_domain_form, new_custom_domain_form=new_custom_domain_form,
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
errors=errors,
mailboxes=mailboxes,
) )

View File

@ -1,7 +1,6 @@
import arrow import arrow
from flask import flash, redirect, url_for, request, render_template from flask import flash, redirect, url_for, request, render_template
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from app.config import JOB_DELETE_ACCOUNT from app.config import JOB_DELETE_ACCOUNT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
@ -10,21 +9,11 @@ from app.log import LOG
from app.models import Subscription, Job from app.models import Subscription, Job
class DeleteDirForm(FlaskForm):
pass
@dashboard_bp.route("/delete_account", methods=["GET", "POST"]) @dashboard_bp.route("/delete_account", methods=["GET", "POST"])
@login_required @login_required
@sudo_required @sudo_required
def delete_account(): def delete_account():
delete_form = DeleteDirForm()
if request.method == "POST" and request.form.get("form-name") == "delete-account": if request.method == "POST" and request.form.get("form-name") == "delete-account":
if not delete_form.validate():
flash("Invalid request", "warning")
return render_template(
"dashboard/delete_account.html", delete_form=delete_form
)
sub: Subscription = current_user.get_paddle_subscription() sub: Subscription = current_user.get_paddle_subscription()
# user who has canceled can also re-subscribe # user who has canceled can also re-subscribe
if sub and not sub.cancelled: if sub and not sub.cancelled:
@ -47,4 +36,6 @@ def delete_account():
) )
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
return render_template("dashboard/delete_account.html", delete_form=delete_form) return render_template(
"dashboard/delete_account.html",
)

View File

@ -1,15 +1,8 @@
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import ( from wtforms import StringField, validators
StringField,
validators,
SelectMultipleField,
BooleanField,
IntegerField,
)
from app import parallel_limiter
from app.config import ( from app.config import (
EMAIL_DOMAIN, EMAIL_DOMAIN,
ALIAS_DOMAINS, ALIAS_DOMAINS,
@ -28,25 +21,8 @@ class NewDirForm(FlaskForm):
) )
class ToggleDirForm(FlaskForm):
directory_id = IntegerField(validators=[validators.DataRequired()])
directory_enabled = BooleanField(validators=[])
class UpdateDirForm(FlaskForm):
directory_id = IntegerField(validators=[validators.DataRequired()])
mailbox_ids = SelectMultipleField(
validators=[validators.DataRequired()], validate_choice=False, choices=[]
)
class DeleteDirForm(FlaskForm):
directory_id = IntegerField(validators=[validators.DataRequired()])
@dashboard_bp.route("/directory", methods=["GET", "POST"]) @dashboard_bp.route("/directory", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def directory(): def directory():
dirs = ( dirs = (
Directory.filter_by(user_id=current_user.id) Directory.filter_by(user_id=current_user.id)
@ -57,68 +33,54 @@ def directory():
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
new_dir_form = NewDirForm() new_dir_form = NewDirForm()
toggle_dir_form = ToggleDirForm()
update_dir_form = UpdateDirForm()
update_dir_form.mailbox_ids.choices = [
(str(mailbox.id), str(mailbox.id)) for mailbox in mailboxes
]
delete_dir_form = DeleteDirForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "delete": if request.form.get("form-name") == "delete":
if not delete_dir_form.validate(): dir_id = request.form.get("dir-id")
flash("Invalid request", "warning") dir = Directory.get(dir_id)
return redirect(url_for("dashboard.directory"))
dir_obj = Directory.get(delete_dir_form.directory_id.data)
if not dir_obj: if not dir:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
elif dir_obj.user_id != current_user.id: elif dir.user_id != current_user.id:
flash("You cannot delete this directory", "warning") flash("You cannot delete this directory", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
name = dir_obj.name name = dir.name
Directory.delete(dir_obj.id) Directory.delete(dir_id)
Session.commit() Session.commit()
flash(f"Directory {name} has been deleted", "success") flash(f"Directory {name} has been deleted", "success")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
if request.form.get("form-name") == "toggle-directory": if request.form.get("form-name") == "toggle-directory":
if not toggle_dir_form.validate(): dir_id = request.form.get("dir-id")
flash("Invalid request", "warning") dir = Directory.get(dir_id)
return redirect(url_for("dashboard.directory"))
dir_id = toggle_dir_form.directory_id.data
dir_obj = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir or dir.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
if toggle_dir_form.directory_enabled.data: if request.form.get("dir-status") == "on":
dir_obj.disabled = False dir.disabled = False
flash(f"On-the-fly is enabled for {dir_obj.name}", "success") flash(f"On-the-fly is enabled for {dir.name}", "success")
else: else:
dir_obj.disabled = True dir.disabled = True
flash(f"On-the-fly is disabled for {dir_obj.name}", "warning") flash(f"On-the-fly is disabled for {dir.name}", "warning")
Session.commit() Session.commit()
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
elif request.form.get("form-name") == "update": elif request.form.get("form-name") == "update":
if not update_dir_form.validate(): dir_id = request.form.get("dir-id")
flash("Invalid request", "warning") dir = Directory.get(dir_id)
return redirect(url_for("dashboard.directory"))
dir_id = update_dir_form.directory_id.data
dir_obj = Directory.get(dir_id)
if not dir_obj or dir_obj.user_id != current_user.id: if not dir or dir.user_id != current_user.id:
flash("Unknown error. Refresh the page", "warning") flash("Unknown error. Refresh the page", "warning")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
mailbox_ids = update_dir_form.mailbox_ids.data mailbox_ids = request.form.getlist("mailbox_ids")
# check if mailbox is not tempered with # check if mailbox is not tempered with
mailboxes = [] mailboxes = []
for mailbox_id in mailbox_ids: for mailbox_id in mailbox_ids:
@ -137,14 +99,14 @@ def directory():
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
# first remove all existing directory-mailboxes links # first remove all existing directory-mailboxes links
DirectoryMailbox.filter_by(directory_id=dir_obj.id).delete() DirectoryMailbox.filter_by(directory_id=dir.id).delete()
Session.flush() Session.flush()
for mailbox in mailboxes: for mailbox in mailboxes:
DirectoryMailbox.create(directory_id=dir_obj.id, mailbox_id=mailbox.id) DirectoryMailbox.create(directory_id=dir.id, mailbox_id=mailbox.id)
Session.commit() Session.commit()
flash(f"Directory {dir_obj.name} has been updated", "success") flash(f"Directory {dir.name} has been updated", "success")
return redirect(url_for("dashboard.directory")) return redirect(url_for("dashboard.directory"))
elif request.form.get("form-name") == "create": elif request.form.get("form-name") == "create":
@ -219,9 +181,6 @@ def directory():
return render_template( return render_template(
"dashboard/directory.html", "dashboard/directory.html",
dirs=dirs, dirs=dirs,
toggle_dir_form=toggle_dir_form,
update_dir_form=update_dir_form,
delete_dir_form=delete_dir_form,
new_dir_form=new_dir_form, new_dir_form=new_dir_form,
mailboxes=mailboxes, mailboxes=mailboxes,
EMAIL_DOMAIN=EMAIL_DOMAIN, EMAIL_DOMAIN=EMAIL_DOMAIN,

View File

@ -1,16 +1,22 @@
import re import re
import arrow
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import StringField, validators, IntegerField from wtforms import StringField, validators, IntegerField
from app.constants import DMARC_RECORD from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN, JOB_DELETE_DOMAIN
from app.config import EMAIL_SERVERS_WITH_PRIORITY, EMAIL_DOMAIN
from app.custom_domain_utils import delete_custom_domain
from app.custom_domain_validation import CustomDomainValidation
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.dns_utils import (
get_mx_domains,
get_spf_domain,
get_txt_record,
get_cname_record,
is_mx_equivalent,
)
from app.log import LOG
from app.models import ( from app.models import (
CustomDomain, CustomDomain,
Alias, Alias,
@ -19,9 +25,10 @@ from app.models import (
DomainMailbox, DomainMailbox,
AutoCreateRule, AutoCreateRule,
AutoCreateRuleMailbox, AutoCreateRuleMailbox,
Job,
) )
from app.regex_utils import regex_match from app.regex_utils import regex_match
from app.utils import random_string, CSRFValidationForm from app.utils import random_string
@dashboard_bp.route("/domains/<int:custom_domain_id>/dns", methods=["GET", "POST"]) @dashboard_bp.route("/domains/<int:custom_domain_id>/dns", methods=["GET", "POST"])
@ -39,25 +46,25 @@ def domain_detail_dns(custom_domain_id):
spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all" spf_record = f"v=spf1 include:{EMAIL_DOMAIN} ~all"
domain_validator = CustomDomainValidation(EMAIL_DOMAIN) # hardcode the DKIM selector here
csrf_form = CSRFValidationForm() dkim_cname = f"dkim._domainkey.{EMAIL_DOMAIN}"
dmarc_record = "v=DMARC1; p=quarantine; pct=100; adkim=s; aspf=s"
mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True mx_ok = spf_ok = dkim_ok = dmarc_ok = ownership_ok = True
mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = [] mx_errors = spf_errors = dkim_errors = dmarc_errors = ownership_errors = []
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "check-ownership": if request.form.get("form-name") == "check-ownership":
ownership_validation_result = domain_validator.validate_domain_ownership( txt_records = get_txt_record(custom_domain.domain)
custom_domain
) if custom_domain.get_ownership_dns_txt_value() in txt_records:
if ownership_validation_result.success:
flash( flash(
"Domain ownership is verified. Please proceed to the other records setup", "Domain ownership is verified. Please proceed to the other records setup",
"success", "success",
) )
custom_domain.ownership_verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", "dashboard.domain_detail_dns",
@ -68,28 +75,36 @@ def domain_detail_dns(custom_domain_id):
else: else:
flash("We can't find the needed TXT record", "error") flash("We can't find the needed TXT record", "error")
ownership_ok = False ownership_ok = False
ownership_errors = ownership_validation_result.errors ownership_errors = txt_records
elif request.form.get("form-name") == "check-mx": elif request.form.get("form-name") == "check-mx":
mx_validation_result = domain_validator.validate_mx_records(custom_domain) mx_domains = get_mx_domains(custom_domain.domain)
if mx_validation_result.success:
if not is_mx_equivalent(mx_domains, EMAIL_SERVERS_WITH_PRIORITY):
flash("The MX record is not correctly set", "warning")
mx_ok = False
# build mx_errors to show to user
mx_errors = [
f"{priority} {domain}" for (priority, domain) in mx_domains
]
else:
flash( flash(
"Your domain can start receiving emails. You can now use it to create alias", "Your domain can start receiving emails. You can now use it to create alias",
"success", "success",
) )
custom_domain.verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
) )
) )
else:
flash("The MX record is not correctly set", "warning")
mx_ok = False
mx_errors = mx_validation_result.errors
elif request.form.get("form-name") == "check-spf": elif request.form.get("form-name") == "check-spf":
spf_validation_result = domain_validator.validate_spf_records(custom_domain) spf_domains = get_spf_domain(custom_domain.domain)
if spf_validation_result.success: if EMAIL_DOMAIN in spf_domains:
custom_domain.spf_verified = True
Session.commit()
flash("SPF is setup correctly", "success") flash("SPF is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -97,31 +112,39 @@ def domain_detail_dns(custom_domain_id):
) )
) )
else: else:
custom_domain.spf_verified = False
Session.commit()
flash( flash(
f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.", f"SPF: {EMAIL_DOMAIN} is not included in your SPF record.",
"warning", "warning",
) )
spf_ok = False spf_ok = False
spf_errors = spf_validation_result.errors spf_errors = get_txt_record(custom_domain.domain)
elif request.form.get("form-name") == "check-dkim": elif request.form.get("form-name") == "check-dkim":
dkim_errors = domain_validator.validate_dkim_records(custom_domain) dkim_record = get_cname_record("dkim._domainkey." + custom_domain.domain)
if len(dkim_errors) == 0: if dkim_record == dkim_cname:
flash("DKIM is setup correctly.", "success") flash("DKIM is setup correctly.", "success")
custom_domain.dkim_verified = True
Session.commit()
return redirect( return redirect(
url_for( url_for(
"dashboard.domain_detail_dns", custom_domain_id=custom_domain.id "dashboard.domain_detail_dns", custom_domain_id=custom_domain.id
) )
) )
else: else:
dkim_ok = False custom_domain.dkim_verified = False
Session.commit()
flash("DKIM: the CNAME record is not correctly set", "warning") flash("DKIM: the CNAME record is not correctly set", "warning")
dkim_ok = False
dkim_errors = [dkim_record or "[Empty]"]
elif request.form.get("form-name") == "check-dmarc": elif request.form.get("form-name") == "check-dmarc":
dmarc_validation_result = domain_validator.validate_dmarc_records( txt_records = get_txt_record("_dmarc." + custom_domain.domain)
custom_domain if dmarc_record in txt_records:
) custom_domain.dmarc_verified = True
if dmarc_validation_result.success: Session.commit()
flash("DMARC is setup correctly", "success") flash("DMARC is setup correctly", "success")
return redirect( return redirect(
url_for( url_for(
@ -129,21 +152,18 @@ def domain_detail_dns(custom_domain_id):
) )
) )
else: else:
custom_domain.dmarc_verified = False
Session.commit()
flash( flash(
"DMARC: The TXT record is not correctly set", "DMARC: The TXT record is not correctly set",
"warning", "warning",
) )
dmarc_ok = False dmarc_ok = False
dmarc_errors = dmarc_validation_result.errors dmarc_errors = txt_records
return render_template( return render_template(
"dashboard/domain_detail/dns.html", "dashboard/domain_detail/dns.html",
EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY, EMAIL_SERVERS_WITH_PRIORITY=EMAIL_SERVERS_WITH_PRIORITY,
ownership_record=domain_validator.get_ownership_verification_record(
custom_domain
),
dkim_records=domain_validator.get_dkim_records(custom_domain),
dmarc_record=DMARC_RECORD,
**locals(), **locals(),
) )
@ -151,7 +171,6 @@ def domain_detail_dns(custom_domain_id):
@dashboard_bp.route("/domains/<int:custom_domain_id>/info", methods=["GET", "POST"]) @dashboard_bp.route("/domains/<int:custom_domain_id>/info", methods=["GET", "POST"])
@login_required @login_required
def domain_detail(custom_domain_id): def domain_detail(custom_domain_id):
csrf_form = CSRFValidationForm()
custom_domain: CustomDomain = CustomDomain.get(custom_domain_id) custom_domain: CustomDomain = CustomDomain.get(custom_domain_id)
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
@ -160,9 +179,6 @@ def domain_detail(custom_domain_id):
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "switch-catch-all": if request.form.get("form-name") == "switch-catch-all":
custom_domain.catch_all = not custom_domain.catch_all custom_domain.catch_all = not custom_domain.catch_all
Session.commit() Session.commit()
@ -261,8 +277,16 @@ def domain_detail(custom_domain_id):
elif request.form.get("form-name") == "delete": elif request.form.get("form-name") == "delete":
name = custom_domain.domain name = custom_domain.domain
LOG.d("Schedule deleting %s", custom_domain)
delete_custom_domain(custom_domain) # Schedule delete domain job
LOG.w("schedule delete domain job for %s", custom_domain)
Job.create(
name=JOB_DELETE_DOMAIN,
payload={"custom_domain_id": custom_domain.id},
run_at=arrow.now(),
commit=True,
)
flash( flash(
f"{name} scheduled for deletion." f"{name} scheduled for deletion."
@ -283,16 +307,12 @@ def domain_detail(custom_domain_id):
@dashboard_bp.route("/domains/<int:custom_domain_id>/trash", methods=["GET", "POST"]) @dashboard_bp.route("/domains/<int:custom_domain_id>/trash", methods=["GET", "POST"])
@login_required @login_required
def domain_detail_trash(custom_domain_id): def domain_detail_trash(custom_domain_id):
csrf_form = CSRFValidationForm()
custom_domain = CustomDomain.get(custom_domain_id) custom_domain = CustomDomain.get(custom_domain_id)
if not custom_domain or custom_domain.user_id != current_user.id: if not custom_domain or custom_domain.user_id != current_user.id:
flash("You cannot see this page", "warning") flash("You cannot see this page", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "empty-all": if request.form.get("form-name") == "empty-all":
DomainDeletedAlias.filter_by(domain_id=custom_domain.id).delete() DomainDeletedAlias.filter_by(domain_id=custom_domain.id).delete()
Session.commit() Session.commit()
@ -336,7 +356,6 @@ def domain_detail_trash(custom_domain_id):
"dashboard/domain_detail/trash.html", "dashboard/domain_detail/trash.html",
domain_deleted_aliases=domain_deleted_aliases, domain_deleted_aliases=domain_deleted_aliases,
custom_domain=custom_domain, custom_domain=custom_domain,
csrf_form=csrf_form,
) )

View File

@ -6,15 +6,11 @@ from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from wtforms import PasswordField, validators from wtforms import PasswordField, validators
from app.config import CONNECT_WITH_PROTON, OIDC_CLIENT_ID, CONNECT_WITH_OIDC_ICON
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import PartnerUser, SocialAuth
from app.proton.utils import get_proton_partner
from app.utils import sanitize_next_url from app.utils import sanitize_next_url
_SUDO_GAP = 120 _SUDO_GAP = 900
class LoginForm(FlaskForm): class LoginForm(FlaskForm):
@ -22,7 +18,6 @@ class LoginForm(FlaskForm):
@dashboard_bp.route("/enter_sudo", methods=["GET", "POST"]) @dashboard_bp.route("/enter_sudo", methods=["GET", "POST"])
@limiter.limit("3/minute")
@login_required @login_required
def enter_sudo(): def enter_sudo():
password_check_form = LoginForm() password_check_form = LoginForm()
@ -44,26 +39,8 @@ def enter_sudo():
else: else:
flash("Incorrect password", "warning") flash("Incorrect password", "warning")
proton_enabled = CONNECT_WITH_PROTON
if proton_enabled:
# Only for users that have the account linked
partner_user = PartnerUser.get_by(user_id=current_user.id)
if not partner_user or partner_user.partner_id != get_proton_partner().id:
proton_enabled = False
oidc_enabled = OIDC_CLIENT_ID is not None
if oidc_enabled:
oidc_enabled = (
SocialAuth.get_by(user_id=current_user.id, social="oidc") is not None
)
return render_template( return render_template(
"dashboard/enter_sudo.html", "dashboard/enter_sudo.html", password_check_form=password_check_form
password_check_form=password_check_form,
next=request.args.get("next"),
connect_with_proton=proton_enabled,
connect_with_oidc=oidc_enabled,
connect_with_oidc_icon=CONNECT_WITH_OIDC_ICON,
) )

View File

@ -78,10 +78,10 @@ def fido_setup():
) )
flash("Security key has been activated", "success") flash("Security key has been activated", "success")
recovery_codes = RecoveryCode.generate(current_user) if not RecoveryCode.filter_by(user_id=current_user.id).all():
return render_template( return redirect(url_for("dashboard.recovery_code_route"))
"dashboard/recovery_code.html", recovery_codes=recovery_codes else:
) return redirect(url_for("dashboard.fido_manage"))
# Prepare information for key registration process # Prepare information for key registration process
fido_uuid = ( fido_uuid = (

View File

@ -3,7 +3,7 @@ from dataclasses import dataclass
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app import alias_utils, parallel_limiter from app import alias_utils
from app.api.serializer import get_alias_infos_with_pagination_v3, get_alias_info_v3 from app.api.serializer import get_alias_infos_with_pagination_v3, get_alias_info_v3
from app.config import ALIAS_LIMIT, PAGE_LIMIT from app.config import ALIAS_LIMIT, PAGE_LIMIT
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
@ -12,13 +12,11 @@ from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
Alias, Alias,
AliasDeleteReason,
AliasGeneratorEnum, AliasGeneratorEnum,
User, User,
EmailLog, EmailLog,
Contact, Contact,
) )
from app.utils import CSRFValidationForm
@dataclass @dataclass
@ -53,17 +51,12 @@ def get_stats(user: User) -> Stats:
@dashboard_bp.route("/", methods=["GET", "POST"]) @dashboard_bp.route("/", methods=["GET", "POST"])
@login_required
@limiter.limit( @limiter.limit(
ALIAS_LIMIT, ALIAS_LIMIT,
methods=["POST"], methods=["POST"],
exempt_when=lambda: request.form.get("form-name") != "create-random-email", exempt_when=lambda: request.form.get("form-name") != "create-random-email",
) )
@limiter.limit("10/minute", methods=["GET"], key_func=lambda: current_user.id) @login_required
@parallel_limiter.lock(
name="alias_creation",
only_when=lambda: request.form.get("form-name") == "create-random-email",
)
def index(): def index():
query = request.args.get("query") or "" query = request.args.get("query") or ""
sort = request.args.get("sort") or "" sort = request.args.get("sort") or ""
@ -82,12 +75,8 @@ def index():
"highlight_alias_id must be a number, received %s", "highlight_alias_id must be a number, received %s",
request.args.get("highlight_alias_id"), request.args.get("highlight_alias_id"),
) )
csrf_form = CSRFValidationForm()
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if request.form.get("form-name") == "create-custom-email": if request.form.get("form-name") == "create-custom-email":
if current_user.can_create_new_alias(): if current_user.can_create_new_alias():
return redirect(url_for("dashboard.custom_alias")) return redirect(url_for("dashboard.custom_alias"))
@ -142,25 +131,17 @@ def index():
) )
if request.form.get("form-name") == "delete-alias": if request.form.get("form-name") == "delete-alias":
LOG.i(f"User {current_user} requested deletion of alias {alias}") LOG.d("delete alias %s", alias)
email = alias.email email = alias.email
alias_utils.delete_alias( alias_utils.delete_alias(alias, current_user)
alias, current_user, AliasDeleteReason.ManualAction, commit=True
)
flash(f"Alias {email} has been deleted", "success") flash(f"Alias {email} has been deleted", "success")
elif request.form.get("form-name") == "disable-alias": elif request.form.get("form-name") == "disable-alias":
alias_utils.change_alias_status(alias, enabled=False) alias.enabled = False
Session.commit() Session.commit()
flash(f"Alias {alias.email} has been disabled", "success") flash(f"Alias {alias.email} has been disabled", "success")
return redirect( return redirect(
url_for( url_for("dashboard.index", query=query, sort=sort, filter=alias_filter)
"dashboard.index",
query=query,
sort=sort,
filter=alias_filter,
page=page,
)
) )
mailboxes = current_user.mailboxes() mailboxes = current_user.mailboxes()
@ -223,7 +204,6 @@ def index():
sort=sort, sort=sort,
filter=alias_filter, filter=alias_filter,
stats=stats, stats=stats,
csrf_form=csrf_form,
) )

View File

@ -1,21 +1,23 @@
import base64 import arrow
import binascii
import json
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from itsdangerous import TimestampSigner from itsdangerous import Signer
from wtforms import validators, IntegerField from wtforms import validators
from wtforms.fields.html5 import EmailField from wtforms.fields.html5 import EmailField
from app import parallel_limiter, mailbox_utils, user_settings from app.config import MAILBOX_SECRET, URL, JOB_DELETE_MAILBOX
from app.config import MAILBOX_SECRET
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.email_utils import (
email_can_be_used_as_mailbox,
mailbox_already_used,
render,
send_email,
is_valid_email,
)
from app.log import LOG from app.log import LOG
from app.models import Mailbox from app.models import Mailbox, Job
from app.utils import CSRFValidationForm
class NewMailboxForm(FlaskForm): class NewMailboxForm(FlaskForm):
@ -24,16 +26,8 @@ class NewMailboxForm(FlaskForm):
) )
class DeleteMailboxForm(FlaskForm):
mailbox_id = IntegerField(
validators=[validators.DataRequired()],
)
transfer_mailbox_id = IntegerField()
@dashboard_bp.route("/mailbox", methods=["GET", "POST"]) @dashboard_bp.route("/mailbox", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def mailbox_route(): def mailbox_route():
mailboxes = ( mailboxes = (
Mailbox.filter_by(user_id=current_user.id) Mailbox.filter_by(user_id=current_user.id)
@ -42,127 +36,169 @@ def mailbox_route():
) )
new_mailbox_form = NewMailboxForm() new_mailbox_form = NewMailboxForm()
csrf_form = CSRFValidationForm()
delete_mailbox_form = DeleteMailboxForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "delete": if request.form.get("form-name") == "delete":
if not delete_mailbox_form.validate(): mailbox_id = request.form.get("mailbox-id")
flash("Invalid request", "warning") mailbox = Mailbox.get(mailbox_id)
return redirect(request.url)
try: if not mailbox or mailbox.user_id != current_user.id:
mailbox = mailbox_utils.delete_mailbox( flash("Unknown error. Refresh the page", "warning")
current_user,
delete_mailbox_form.mailbox_id.data,
delete_mailbox_form.transfer_mailbox_id.data,
)
except mailbox_utils.MailboxError as e:
flash(e.msg, "warning")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
if mailbox.id == current_user.default_mailbox_id:
flash("You cannot delete default mailbox", "error")
return redirect(url_for("dashboard.mailbox_route"))
# Schedule delete account job
LOG.w("schedule delete mailbox job for %s", mailbox)
Job.create(
name=JOB_DELETE_MAILBOX,
payload={"mailbox_id": mailbox.id},
run_at=arrow.now(),
commit=True,
)
flash( flash(
f"Mailbox {mailbox.email} scheduled for deletion." f"Mailbox {mailbox.email} scheduled for deletion."
f"You will receive a confirmation email when the deletion is finished", f"You will receive a confirmation email when the deletion is finished",
"success", "success",
) )
return redirect(url_for("dashboard.mailbox_route"))
return redirect(url_for("dashboard.mailbox_route"))
if request.form.get("form-name") == "set-default": if request.form.get("form-name") == "set-default":
if not csrf_form.validate(): mailbox_id = request.form.get("mailbox-id")
flash("Invalid request", "warning") mailbox = Mailbox.get(mailbox_id)
return redirect(request.url)
try: if not mailbox or mailbox.user_id != current_user.id:
mailbox_id = request.form.get("mailbox_id") flash("Unknown error. Refresh the page", "warning")
mailbox = user_settings.set_default_mailbox(current_user, mailbox_id)
except user_settings.CannotSetMailbox as e:
flash(e.msg, "warning")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
if mailbox.id == current_user.default_mailbox_id:
flash("This mailbox is already default one", "error")
return redirect(url_for("dashboard.mailbox_route"))
if not mailbox.verified:
flash("Cannot set unverified mailbox as default", "error")
return redirect(url_for("dashboard.mailbox_route"))
current_user.default_mailbox_id = mailbox.id
Session.commit()
flash(f"Mailbox {mailbox.email} is set as Default Mailbox", "success") flash(f"Mailbox {mailbox.email} is set as Default Mailbox", "success")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
elif request.form.get("form-name") == "create": elif request.form.get("form-name") == "create":
if not new_mailbox_form.validate(): if not current_user.is_premium():
flash("Invalid request", "warning") flash("Only premium plan can add additional mailbox", "warning")
return redirect(request.url)
mailbox_email = new_mailbox_form.email.data.lower().strip().replace(" ", "")
try:
mailbox = mailbox_utils.create_mailbox(
current_user, mailbox_email
).mailbox
except mailbox_utils.MailboxError as e:
flash(e.msg, "warning")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
flash( if new_mailbox_form.validate():
f"You are going to receive an email to confirm {mailbox.email}.", mailbox_email = (
"success", new_mailbox_form.email.data.lower().strip().replace(" ", "")
)
return redirect(
url_for(
"dashboard.mailbox_detail_route",
mailbox_id=mailbox.id,
) )
)
if not is_valid_email(mailbox_email):
flash(f"{mailbox_email} invalid", "error")
elif mailbox_already_used(mailbox_email, current_user):
flash(f"{mailbox_email} already used", "error")
elif not email_can_be_used_as_mailbox(mailbox_email):
flash(f"You cannot use {mailbox_email}.", "error")
else:
new_mailbox = Mailbox.create(
email=mailbox_email, user_id=current_user.id
)
Session.commit()
send_verification_email(current_user, new_mailbox)
flash(
f"You are going to receive an email to confirm {mailbox_email}.",
"success",
)
return redirect(
url_for(
"dashboard.mailbox_detail_route", mailbox_id=new_mailbox.id
)
)
return render_template( return render_template(
"dashboard/mailbox.html", "dashboard/mailbox.html",
mailboxes=mailboxes, mailboxes=mailboxes,
new_mailbox_form=new_mailbox_form, new_mailbox_form=new_mailbox_form,
delete_mailbox_form=delete_mailbox_form, )
csrf_form=csrf_form,
def delete_mailbox(mailbox_id: int):
from server import create_light_app
with create_light_app().app_context():
mailbox = Mailbox.get(mailbox_id)
if not mailbox:
return
mailbox_email = mailbox.email
user = mailbox.user
Mailbox.delete(mailbox_id)
Session.commit()
LOG.d("Mailbox %s %s deleted", mailbox_id, mailbox_email)
send_email(
user.email,
f"Your mailbox {mailbox_email} has been deleted",
f"""Mailbox {mailbox_email} along with its aliases are deleted successfully.
Regards,
SimpleLogin team.
""",
)
def send_verification_email(user, mailbox):
s = Signer(MAILBOX_SECRET)
mailbox_id_signed = s.sign(str(mailbox.id)).decode()
verification_url = (
URL + "/dashboard/mailbox_verify" + f"?mailbox_id={mailbox_id_signed}"
)
send_email(
mailbox.email,
f"Please confirm your email {mailbox.email}",
render(
"transactional/verify-mailbox.txt",
user=user,
link=verification_url,
mailbox_email=mailbox.email,
),
render(
"transactional/verify-mailbox.html",
user=user,
link=verification_url,
mailbox_email=mailbox.email,
),
) )
@dashboard_bp.route("/mailbox_verify") @dashboard_bp.route("/mailbox_verify")
@login_required
def mailbox_verify(): def mailbox_verify():
s = Signer(MAILBOX_SECRET)
mailbox_id = request.args.get("mailbox_id") mailbox_id = request.args.get("mailbox_id")
code = request.args.get("code")
if not code:
# Old way
return verify_with_signed_secret(mailbox_id)
try:
mailbox = mailbox_utils.verify_mailbox_code(current_user, mailbox_id, code)
except mailbox_utils.MailboxError as e:
LOG.i(f"Cannot verify mailbox {mailbox_id} because of {e}")
flash(f"Cannot verify mailbox: {e.msg}", "error")
return redirect(url_for("dashboard.mailbox_route"))
LOG.d("Mailbox %s is verified", mailbox)
return render_template("dashboard/mailbox_validation.html", mailbox=mailbox)
def verify_with_signed_secret(request: str):
s = TimestampSigner(MAILBOX_SECRET)
mailbox_verify_request = request.args.get("mailbox_id")
try: try:
mailbox_raw_data = s.unsign(mailbox_verify_request, max_age=900) r_id = int(s.unsign(mailbox_id))
except Exception: except Exception:
flash("Invalid link. Please delete and re-add your mailbox", "error") flash("Invalid link. Please delete and re-add your mailbox", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
try: else:
decoded_data = base64.urlsafe_b64decode(mailbox_raw_data) mailbox = Mailbox.get(r_id)
except binascii.Error: if not mailbox:
flash("Invalid link. Please delete and re-add your mailbox", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route")) return redirect(url_for("dashboard.mailbox_route"))
mailbox_data = json.loads(decoded_data)
if not isinstance(mailbox_data, list) or len(mailbox_data) != 2:
flash("Invalid link. Please delete and re-add your mailbox", "error")
return redirect(url_for("dashboard.mailbox_route"))
mailbox_id = mailbox_data[0]
mailbox = Mailbox.get(mailbox_id)
if not mailbox:
flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route"))
mailbox_email = mailbox_data[1]
if mailbox_email != mailbox.email:
flash("Invalid link", "error")
return redirect(url_for("dashboard.mailbox_route"))
mailbox.verified = True mailbox.verified = True
Session.commit() Session.commit()
LOG.d("Mailbox %s is verified", mailbox) LOG.d("Mailbox %s is verified", mailbox)
return render_template("dashboard/mailbox_validation.html", mailbox=mailbox) return render_template("dashboard/mailbox_validation.html", mailbox=mailbox)

View File

@ -1,26 +1,23 @@
from smtplib import SMTPRecipientsRefused from smtplib import SMTPRecipientsRefused
from email_validator import validate_email, EmailNotValidError
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from itsdangerous import TimestampSigner from itsdangerous import Signer
from wtforms import validators from wtforms import validators
from wtforms.fields.html5 import EmailField from wtforms.fields.html5 import EmailField
from app.config import ENFORCE_SPF, MAILBOX_SECRET from app.config import ENFORCE_SPF, MAILBOX_SECRET
from app.config import URL from app.config import URL
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.email_utils import email_can_be_used_as_mailbox from app.email_utils import email_can_be_used_as_mailbox
from app.email_utils import mailbox_already_used, render, send_email from app.email_utils import mailbox_already_used, render, send_email
from app.extensions import limiter
from app.log import LOG from app.log import LOG
from app.models import Alias, AuthorizedAddress from app.models import Alias, AuthorizedAddress
from app.models import Mailbox from app.models import Mailbox
from app.pgp_utils import PGPException, load_public_key_and_check from app.pgp_utils import PGPException, load_public_key_and_check
from app.utils import sanitize_email, CSRFValidationForm from app.utils import sanitize_email
class ChangeEmailForm(FlaskForm): class ChangeEmailForm(FlaskForm):
@ -31,16 +28,13 @@ class ChangeEmailForm(FlaskForm):
@dashboard_bp.route("/mailbox/<int:mailbox_id>/", methods=["GET", "POST"]) @dashboard_bp.route("/mailbox/<int:mailbox_id>/", methods=["GET", "POST"])
@login_required @login_required
@sudo_required
@limiter.limit("20/minute", methods=["POST"])
def mailbox_detail_route(mailbox_id): def mailbox_detail_route(mailbox_id):
mailbox: Mailbox = Mailbox.get(mailbox_id) mailbox = Mailbox.get(mailbox_id)
if not mailbox or mailbox.user_id != current_user.id: if not mailbox or mailbox.user_id != current_user.id:
flash("You cannot see this page", "warning") flash("You cannot see this page", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
change_email_form = ChangeEmailForm() change_email_form = ChangeEmailForm()
csrf_form = CSRFValidationForm()
if mailbox.new_email: if mailbox.new_email:
pending_email = mailbox.new_email pending_email = mailbox.new_email
@ -48,9 +42,6 @@ def mailbox_detail_route(mailbox_id):
pending_email = None pending_email = None
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
if ( if (
request.form.get("form-name") == "update-email" request.form.get("form-name") == "update-email"
and change_email_form.validate_on_submit() and change_email_form.validate_on_submit()
@ -103,23 +94,16 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("form-name") == "add-authorized-address": elif request.form.get("form-name") == "add-authorized-address":
address = sanitize_email(request.form.get("email")) address = sanitize_email(request.form.get("email"))
try: if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address):
validate_email( flash(f"{address} already added", "error")
address, check_deliverability=False, allow_smtputf8=False
).domain
except EmailNotValidError:
flash(f"invalid {address}", "error")
else: else:
if AuthorizedAddress.get_by(mailbox_id=mailbox.id, email=address): AuthorizedAddress.create(
flash(f"{address} already added", "error") user_id=current_user.id,
else: mailbox_id=mailbox.id,
AuthorizedAddress.create( email=address,
user_id=current_user.id, commit=True,
mailbox_id=mailbox.id, )
email=address, flash(f"{address} added as authorized address", "success")
commit=True,
)
flash(f"{address} added as authorized address", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
@ -148,15 +132,6 @@ def mailbox_detail_route(mailbox_id):
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
if mailbox.is_proton():
flash(
"Enabling PGP for a Proton Mail mailbox is redundant and does not add any security benefit",
"info",
)
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)
mailbox.pgp_public_key = request.form.get("pgp") mailbox.pgp_public_key = request.form.get("pgp")
try: try:
mailbox.pgp_finger_print = load_public_key_and_check( mailbox.pgp_finger_print = load_public_key_and_check(
@ -183,15 +158,8 @@ def mailbox_detail_route(mailbox_id):
elif request.form.get("form-name") == "toggle-pgp": elif request.form.get("form-name") == "toggle-pgp":
if request.form.get("pgp-enabled") == "on": if request.form.get("pgp-enabled") == "on":
if mailbox.is_proton(): mailbox.disable_pgp = False
mailbox.disable_pgp = True flash(f"PGP is enabled on {mailbox.email}", "success")
flash(
"Enabling PGP for a Proton Mail mailbox is redundant and does not add any security benefit",
"info",
)
else:
mailbox.disable_pgp = False
flash(f"PGP is enabled on {mailbox.email}", "info")
else: else:
mailbox.disable_pgp = True mailbox.disable_pgp = True
flash(f"PGP is disabled on {mailbox.email}", "info") flash(f"PGP is disabled on {mailbox.email}", "info")
@ -202,16 +170,25 @@ def mailbox_detail_route(mailbox_id):
) )
elif request.form.get("form-name") == "generic-subject": elif request.form.get("form-name") == "generic-subject":
if request.form.get("action") == "save": if request.form.get("action") == "save":
if not mailbox.pgp_enabled():
flash(
"Generic subject can only be used on PGP-enabled mailbox",
"error",
)
return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
)
mailbox.generic_subject = request.form.get("generic-subject") mailbox.generic_subject = request.form.get("generic-subject")
Session.commit() Session.commit()
flash("Generic subject is enabled", "success") flash("Generic subject for PGP-encrypted email is enabled", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
elif request.form.get("action") == "remove": elif request.form.get("action") == "remove":
mailbox.generic_subject = None mailbox.generic_subject = None
Session.commit() Session.commit()
flash("Generic subject is disabled", "success") flash("Generic subject for PGP-encrypted email is disabled", "success")
return redirect( return redirect(
url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id) url_for("dashboard.mailbox_detail_route", mailbox_id=mailbox_id)
) )
@ -221,7 +198,7 @@ def mailbox_detail_route(mailbox_id):
def verify_mailbox_change(user, mailbox, new_email): def verify_mailbox_change(user, mailbox, new_email):
s = TimestampSigner(MAILBOX_SECRET) s = Signer(MAILBOX_SECRET)
mailbox_id_signed = s.sign(str(mailbox.id)).decode() mailbox_id_signed = s.sign(str(mailbox.id)).decode()
verification_url = ( verification_url = (
f"{URL}/dashboard/mailbox/confirm_change?mailbox_id={mailbox_id_signed}" f"{URL}/dashboard/mailbox/confirm_change?mailbox_id={mailbox_id_signed}"
@ -231,7 +208,7 @@ def verify_mailbox_change(user, mailbox, new_email):
new_email, new_email,
"Confirm mailbox change on SimpleLogin", "Confirm mailbox change on SimpleLogin",
render( render(
"transactional/verify-mailbox-change.txt.jinja2", "transactional/verify-mailbox-change.txt",
user=user, user=user,
link=verification_url, link=verification_url,
mailbox_email=mailbox.email, mailbox_email=mailbox.email,
@ -273,11 +250,11 @@ def cancel_mailbox_change_route(mailbox_id):
@dashboard_bp.route("/mailbox/confirm_change") @dashboard_bp.route("/mailbox/confirm_change")
def mailbox_confirm_change_route(): def mailbox_confirm_change_route():
s = TimestampSigner(MAILBOX_SECRET) s = Signer(MAILBOX_SECRET)
signed_mailbox_id = request.args.get("mailbox_id") signed_mailbox_id = request.args.get("mailbox_id")
try: try:
mailbox_id = int(s.unsign(signed_mailbox_id, max_age=900)) mailbox_id = int(s.unsign(signed_mailbox_id))
except Exception: except Exception:
flash("Invalid link", "error") flash("Invalid link", "error")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))

View File

@ -5,7 +5,6 @@ from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.models import RecoveryCode from app.models import RecoveryCode
from app.utils import CSRFValidationForm
@dashboard_bp.route("/mfa_cancel", methods=["GET", "POST"]) @dashboard_bp.route("/mfa_cancel", methods=["GET", "POST"])
@ -16,13 +15,8 @@ def mfa_cancel():
flash("you don't have MFA enabled", "warning") flash("you don't have MFA enabled", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
csrf_form = CSRFValidationForm()
# user cancels TOTP # user cancels TOTP
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate():
flash("Invalid request", "warning")
return redirect(request.url)
current_user.enable_otp = False current_user.enable_otp = False
current_user.otp_secret = None current_user.otp_secret = None
Session.commit() Session.commit()
@ -34,4 +28,4 @@ def mfa_cancel():
flash("TOTP is now disabled", "warning") flash("TOTP is now disabled", "warning")
return redirect(url_for("dashboard.index")) return redirect(url_for("dashboard.index"))
return render_template("dashboard/mfa_cancel.html", csrf_form=csrf_form) return render_template("dashboard/mfa_cancel.html")

View File

@ -8,7 +8,6 @@ from app.dashboard.base import dashboard_bp
from app.dashboard.views.enter_sudo import sudo_required from app.dashboard.views.enter_sudo import sudo_required
from app.db import Session from app.db import Session
from app.log import LOG from app.log import LOG
from app.models import RecoveryCode
class OtpTokenForm(FlaskForm): class OtpTokenForm(FlaskForm):
@ -40,10 +39,8 @@ def mfa_setup():
current_user.last_otp = token current_user.last_otp = token
Session.commit() Session.commit()
flash("MFA has been activated", "success") flash("MFA has been activated", "success")
recovery_codes = RecoveryCode.generate(current_user)
return render_template( return redirect(url_for("dashboard.recovery_code_route"))
"dashboard/recovery_code.html", recovery_codes=recovery_codes
)
else: else:
flash("Incorrect token", "warning") flash("Incorrect token", "warning")

View File

@ -80,9 +80,8 @@ def pricing():
@dashboard_bp.route("/subscription_success") @dashboard_bp.route("/subscription_success")
@login_required @login_required
def subscription_success(): def subscription_success():
return render_template( flash("Thanks so much for supporting SimpleLogin!", "success")
"dashboard/thank-you.html", return redirect(url_for("dashboard.index"))
)
@dashboard_bp.route("/coinbase_checkout") @dashboard_bp.route("/coinbase_checkout")

View File

@ -0,0 +1,30 @@
from flask import render_template, flash, redirect, url_for, request
from flask_login import login_required, current_user
from app.dashboard.base import dashboard_bp
from app.log import LOG
from app.models import RecoveryCode
@dashboard_bp.route("/recovery_code", methods=["GET", "POST"])
@login_required
def recovery_code_route():
if not current_user.two_factor_authentication_enabled():
flash("you need to enable either TOTP or WebAuthn", "warning")
return redirect(url_for("dashboard.index"))
recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all()
if request.method == "GET" and not recovery_codes:
# user arrives at this page for the first time
LOG.d("%s has no recovery keys, generate", current_user)
RecoveryCode.generate(current_user)
recovery_codes = RecoveryCode.filter_by(user_id=current_user.id).all()
if request.method == "POST":
RecoveryCode.generate(current_user)
flash("New recovery codes generated", "success")
return redirect(url_for("dashboard.recovery_code_route"))
return render_template(
"dashboard/recovery_code.html", recovery_codes=recovery_codes
)

View File

@ -1,5 +1,4 @@
from io import BytesIO from io import BytesIO
from typing import Optional, Tuple
import arrow import arrow
from flask import ( from flask import (
@ -12,40 +11,48 @@ from flask import (
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
from flask_wtf.file import FileField from flask_wtf.file import FileField
from newrelic import agent
from typing import Optional, Tuple
from wtforms import StringField, validators from wtforms import StringField, validators
from wtforms.fields.html5 import EmailField
from app import s3, user_settings from app import s3, email_utils
from app.config import ( from app.config import (
URL,
FIRST_ALIAS_DOMAIN, FIRST_ALIAS_DOMAIN,
ALIAS_RANDOM_SUFFIX_LENGTH, ALIAS_RANDOM_SUFFIX_LENGTH,
CONNECT_WITH_PROTON,
) )
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.email_utils import (
email_can_be_used_as_mailbox,
personal_email_already_used,
)
from app.errors import ProtonPartnerNotSetUp from app.errors import ProtonPartnerNotSetUp
from app.extensions import limiter
from app.image_validation import detect_image_format, ImageFormat from app.image_validation import detect_image_format, ImageFormat
from app.jobs.export_user_data_job import ExportUserDataJob
from app.log import LOG from app.log import LOG
from app.models import ( from app.models import (
BlockBehaviourEnum, BlockBehaviourEnum,
PlanEnum, PlanEnum,
File, File,
ResetPasswordCode,
EmailChange, EmailChange,
User,
Alias,
CustomDomain,
AliasGeneratorEnum, AliasGeneratorEnum,
AliasSuffixEnum, AliasSuffixEnum,
ManualSubscription, ManualSubscription,
SenderFormatEnum, SenderFormatEnum,
SLDomain,
CoinbaseSubscription, CoinbaseSubscription,
AppleSubscription, AppleSubscription,
PartnerUser, PartnerUser,
PartnerSubscription, PartnerSubscription,
UnsubscribeBehaviourEnum,
)
from app.proton.utils import get_proton_partner
from app.utils import (
random_string,
CSRFValidationForm,
) )
from app.proton.utils import is_connect_with_proton_enabled, get_proton_partner
from app.utils import random_string, sanitize_email
class SettingForm(FlaskForm): class SettingForm(FlaskForm):
@ -53,6 +60,12 @@ class SettingForm(FlaskForm):
profile_picture = FileField("Profile Picture") profile_picture = FileField("Profile Picture")
class ChangeEmailForm(FlaskForm):
email = EmailField(
"email", validators=[validators.DataRequired(), validators.Email()]
)
class PromoCodeForm(FlaskForm): class PromoCodeForm(FlaskForm):
code = StringField("Name", validators=[validators.DataRequired()]) code = StringField("Name", validators=[validators.DataRequired()])
@ -86,11 +99,10 @@ def get_partner_subscription_and_name(
@dashboard_bp.route("/setting", methods=["GET", "POST"]) @dashboard_bp.route("/setting", methods=["GET", "POST"])
@login_required @login_required
@limiter.limit("5/minute", methods=["POST"])
def setting(): def setting():
form = SettingForm() form = SettingForm()
promo_form = PromoCodeForm() promo_form = PromoCodeForm()
csrf_form = CSRFValidationForm() change_email_form = ChangeEmailForm()
email_change = EmailChange.get_by(user_id=current_user.id) email_change = EmailChange.get_by(user_id=current_user.id)
if email_change: if email_change:
@ -99,10 +111,67 @@ def setting():
pending_email = None pending_email = None
if request.method == "POST": if request.method == "POST":
if not csrf_form.validate(): if request.form.get("form-name") == "update-email":
flash("Invalid request", "warning") if change_email_form.validate():
return redirect(url_for("dashboard.setting")) # whether user can proceed with the email update
new_email_valid = True
if (
sanitize_email(change_email_form.email.data) != current_user.email
and not pending_email
):
new_email = sanitize_email(change_email_form.email.data)
# check if this email is not already used
if personal_email_already_used(new_email) or Alias.get_by(
email=new_email
):
flash(f"Email {new_email} already used", "error")
new_email_valid = False
elif not email_can_be_used_as_mailbox(new_email):
flash(
"You cannot use this email address as your personal inbox.",
"error",
)
new_email_valid = False
# a pending email change with the same email exists from another user
elif EmailChange.get_by(new_email=new_email):
other_email_change: EmailChange = EmailChange.get_by(
new_email=new_email
)
LOG.w(
"Another user has a pending %s with the same email address. Current user:%s",
other_email_change,
current_user,
)
if other_email_change.is_expired():
LOG.d(
"delete the expired email change %s", other_email_change
)
EmailChange.delete(other_email_change.id)
Session.commit()
else:
flash(
"You cannot use this email address as your personal inbox.",
"error",
)
new_email_valid = False
if new_email_valid:
email_change = EmailChange.create(
user_id=current_user.id,
code=random_string(
60
), # todo: make sure the code is unique
new_email=new_email,
)
Session.commit()
send_change_email_confirmation(current_user, email_change)
flash(
"A confirmation email is on the way, please check your inbox",
"success",
)
return redirect(url_for("dashboard.setting"))
if request.form.get("form-name") == "update-profile": if request.form.get("form-name") == "update-profile":
if form.validate(): if form.validate():
profile_updated = False profile_updated = False
@ -121,16 +190,6 @@ def setting():
) )
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
if current_user.profile_picture_id is not None:
current_profile_file = File.get_by(
id=current_user.profile_picture_id
)
if (
current_profile_file is not None
and current_profile_file.user_id == current_user.id
):
s3.delete(current_profile_file.path)
file_path = random_string(30) file_path = random_string(30)
file = File.create(user_id=current_user.id, path=file_path) file = File.create(user_id=current_user.id, path=file_path)
@ -146,6 +205,15 @@ def setting():
if profile_updated: if profile_updated:
flash("Your profile has been updated", "success") flash("Your profile has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "change-password":
flash(
"You are going to receive an email containing instructions to change your password",
"success",
)
send_reset_password_email(current_user)
return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "notification-preference": elif request.form.get("form-name") == "notification-preference":
choose = request.form.get("notification") choose = request.form.get("notification")
if choose == "on": if choose == "on":
@ -155,6 +223,7 @@ def setting():
Session.commit() Session.commit()
flash("Your notification preference has been updated", "success") flash("Your notification preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "change-alias-generator": elif request.form.get("form-name") == "change-alias-generator":
scheme = int(request.form.get("alias-generator-scheme")) scheme = int(request.form.get("alias-generator-scheme"))
if AliasGeneratorEnum.has_value(scheme): if AliasGeneratorEnum.has_value(scheme):
@ -162,17 +231,46 @@ def setting():
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "change-random-alias-default-domain": elif request.form.get("form-name") == "change-random-alias-default-domain":
default_domain = request.form.get("random-alias-default-domain") default_domain = request.form.get("random-alias-default-domain")
try:
user_settings.set_default_alias_domain(current_user, default_domain) if default_domain:
except user_settings.CannotSetAlias as e: sl_domain: SLDomain = SLDomain.get_by(domain=default_domain)
flash(e.msg, "error") if sl_domain:
return redirect(url_for("dashboard.setting")) if sl_domain.premium_only and not current_user.is_premium():
flash("You cannot use this domain", "error")
return redirect(url_for("dashboard.setting"))
current_user.default_alias_public_domain_id = sl_domain.id
current_user.default_alias_custom_domain_id = None
else:
custom_domain = CustomDomain.get_by(domain=default_domain)
if custom_domain:
# sanity check
if (
custom_domain.user_id != current_user.id
or not custom_domain.verified
):
LOG.w(
"%s cannot use domain %s", current_user, custom_domain
)
flash(f"Domain {default_domain} can't be used", "error")
return redirect(request.url)
else:
current_user.default_alias_custom_domain_id = (
custom_domain.id
)
current_user.default_alias_public_domain_id = None
else:
current_user.default_alias_custom_domain_id = None
current_user.default_alias_public_domain_id = None
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "random-alias-suffix": elif request.form.get("form-name") == "random-alias-suffix":
scheme = int(request.form.get("random-alias-suffix-generator")) scheme = int(request.form.get("random-alias-suffix-generator"))
if AliasSuffixEnum.has_value(scheme): if AliasSuffixEnum.has_value(scheme):
@ -180,6 +278,7 @@ def setting():
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "change-sender-format": elif request.form.get("form-name") == "change-sender-format":
sender_format = int(request.form.get("sender-format")) sender_format = int(request.form.get("sender-format"))
if SenderFormatEnum.has_value(sender_format): if SenderFormatEnum.has_value(sender_format):
@ -189,6 +288,7 @@ def setting():
flash("Your sender format preference has been updated", "success") flash("Your sender format preference has been updated", "success")
Session.commit() Session.commit()
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "replace-ra": elif request.form.get("form-name") == "replace-ra":
choose = request.form.get("replace-ra") choose = request.form.get("replace-ra")
if choose == "on": if choose == "on":
@ -198,21 +298,7 @@ def setting():
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "enable_data_breach_check":
if not current_user.is_premium():
flash("Only premium plan can enable data breach monitoring", "warning")
return redirect(url_for("dashboard.setting"))
choose = request.form.get("enable_data_breach_check")
if choose == "on":
LOG.i("User {current_user} has enabled data breach monitoring")
current_user.enable_data_breach_check = True
flash("Data breach monitoring is enabled", "success")
else:
LOG.i("User {current_user} has disabled data breach monitoring")
current_user.enable_data_breach_check = False
flash("Data breach monitoring is disabled", "info")
Session.commit()
return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "sender-in-ra": elif request.form.get("form-name") == "sender-in-ra":
choose = request.form.get("enable") choose = request.form.get("enable")
if choose == "on": if choose == "on":
@ -222,6 +308,7 @@ def setting():
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "expand-alias-info": elif request.form.get("form-name") == "expand-alias-info":
choose = request.form.get("enable") choose = request.form.get("enable")
if choose == "on": if choose == "on":
@ -241,16 +328,11 @@ def setting():
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "one-click-unsubscribe": elif request.form.get("form-name") == "one-click-unsubscribe":
choose = request.form.get("unsubscribe-behaviour") choose = request.form.get("enable")
if choose == UnsubscribeBehaviourEnum.PreserveOriginal.name: if choose == "on":
current_user.unsub_behaviour = UnsubscribeBehaviourEnum.PreserveOriginal current_user.one_click_unsubscribe_block_sender = True
elif choose == UnsubscribeBehaviourEnum.DisableAlias.name:
current_user.unsub_behaviour = UnsubscribeBehaviourEnum.DisableAlias
elif choose == UnsubscribeBehaviourEnum.BlockContact.name:
current_user.unsub_behaviour = UnsubscribeBehaviourEnum.BlockContact
else: else:
flash("There was an error. Please try again", "warning") current_user.one_click_unsubscribe_block_sender = False
return redirect(url_for("dashboard.setting"))
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
@ -283,6 +365,14 @@ def setting():
Session.commit() Session.commit()
flash("Your preference has been updated", "success") flash("Your preference has been updated", "success")
return redirect(url_for("dashboard.setting")) return redirect(url_for("dashboard.setting"))
elif request.form.get("form-name") == "send-full-user-report":
if ExportUserDataJob(current_user).store_job_in_db():
flash(
"You will receive your SimpleLogin data via email shortly",
"success",
)
else:
flash("An export of your data is currently in progress", "error")
manual_sub = ManualSubscription.get_by(user_id=current_user.id) manual_sub = ManualSubscription.get_by(user_id=current_user.id)
apple_sub = AppleSubscription.get_by(user_id=current_user.id) apple_sub = AppleSubscription.get_by(user_id=current_user.id)
@ -299,15 +389,14 @@ def setting():
return render_template( return render_template(
"dashboard/setting.html", "dashboard/setting.html",
csrf_form=csrf_form,
form=form, form=form,
PlanEnum=PlanEnum, PlanEnum=PlanEnum,
SenderFormatEnum=SenderFormatEnum, SenderFormatEnum=SenderFormatEnum,
BlockBehaviourEnum=BlockBehaviourEnum, BlockBehaviourEnum=BlockBehaviourEnum,
promo_form=promo_form, promo_form=promo_form,
change_email_form=change_email_form,
pending_email=pending_email, pending_email=pending_email,
AliasGeneratorEnum=AliasGeneratorEnum, AliasGeneratorEnum=AliasGeneratorEnum,
UnsubscribeBehaviourEnum=UnsubscribeBehaviourEnum,
manual_sub=manual_sub, manual_sub=manual_sub,
partner_sub=partner_sub, partner_sub=partner_sub,
partner_name=partner_name, partner_name=partner_name,
@ -316,6 +405,81 @@ def setting():
coinbase_sub=coinbase_sub, coinbase_sub=coinbase_sub,
FIRST_ALIAS_DOMAIN=FIRST_ALIAS_DOMAIN, FIRST_ALIAS_DOMAIN=FIRST_ALIAS_DOMAIN,
ALIAS_RAND_SUFFIX_LENGTH=ALIAS_RANDOM_SUFFIX_LENGTH, ALIAS_RAND_SUFFIX_LENGTH=ALIAS_RANDOM_SUFFIX_LENGTH,
connect_with_proton=CONNECT_WITH_PROTON, connect_with_proton=is_connect_with_proton_enabled(),
proton_linked_account=proton_linked_account, proton_linked_account=proton_linked_account,
) )
def send_reset_password_email(user):
"""
generate a new ResetPasswordCode and send it over email to user
"""
# the activation code is valid for 1h
reset_password_code = ResetPasswordCode.create(
user_id=user.id, code=random_string(60)
)
Session.commit()
reset_password_link = f"{URL}/auth/reset_password?code={reset_password_code.code}"
email_utils.send_reset_password_email(user.email, reset_password_link)
def send_change_email_confirmation(user: User, email_change: EmailChange):
"""
send confirmation email to the new email address
"""
link = f"{URL}/auth/change_email?code={email_change.code}"
email_utils.send_change_email(email_change.new_email, user.email, link)
@dashboard_bp.route("/resend_email_change", methods=["GET", "POST"])
@login_required
def resend_email_change():
email_change = EmailChange.get_by(user_id=current_user.id)
if email_change:
# extend email change expiration
email_change.expired = arrow.now().shift(hours=12)
Session.commit()
send_change_email_confirmation(current_user, email_change)
flash("A confirmation email is on the way, please check your inbox", "success")
return redirect(url_for("dashboard.setting"))
else:
flash(
"You have no pending email change. Redirect back to Setting page", "warning"
)
return redirect(url_for("dashboard.setting"))
@dashboard_bp.route("/cancel_email_change", methods=["GET", "POST"])
@login_required
def cancel_email_change():
email_change = EmailChange.get_by(user_id=current_user.id)
if email_change:
EmailChange.delete(email_change.id)
Session.commit()
flash("Your email change is cancelled", "success")
return redirect(url_for("dashboard.setting"))
else:
flash(
"You have no pending email change. Redirect back to Setting page", "warning"
)
return redirect(url_for("dashboard.setting"))
@dashboard_bp.route("/unlink_proton_account", methods=["GET", "POST"])
@login_required
def unlink_proton_account():
proton_partner = get_proton_partner()
partner_user = PartnerUser.get_by(
user_id=current_user.id, partner_id=proton_partner.id
)
if partner_user is not None:
PartnerUser.delete(partner_user.id)
Session.commit()
flash("Your Proton account has been unlinked", "success")
agent.record_custom_event("AccountUnlinked", {"partner": proton_partner.name})
return redirect(url_for("dashboard.setting"))

View File

@ -2,10 +2,7 @@ import re
from flask import render_template, request, redirect, url_for, flash from flask import render_template, request, redirect, url_for, flash
from flask_login import login_required, current_user from flask_login import login_required, current_user
from flask_wtf import FlaskForm
from wtforms import StringField, validators
from app import parallel_limiter
from app.config import MAX_NB_SUBDOMAIN from app.config import MAX_NB_SUBDOMAIN
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.errors import SubdomainInTrashError from app.errors import SubdomainInTrashError
@ -16,18 +13,8 @@ from app.models import CustomDomain, Mailbox, SLDomain
_SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}" _SUBDOMAIN_PATTERN = r"[0-9a-z-]{1,}"
class NewSubdomainForm(FlaskForm):
domain = StringField(
"domain", validators=[validators.DataRequired(), validators.Length(max=64)]
)
subdomain = StringField(
"subdomain", validators=[validators.DataRequired(), validators.Length(max=64)]
)
@dashboard_bp.route("/subdomain", methods=["GET", "POST"]) @dashboard_bp.route("/subdomain", methods=["GET", "POST"])
@login_required @login_required
@parallel_limiter.lock(only_when=lambda: request.method == "POST")
def subdomain_route(): def subdomain_route():
if not current_user.subdomain_is_available(): if not current_user.subdomain_is_available():
flash("Unknown error, redirect to the home page", "error") flash("Unknown error, redirect to the home page", "error")
@ -39,13 +26,9 @@ def subdomain_route():
).all() ).all()
errors = {} errors = {}
new_subdomain_form = NewSubdomainForm()
if request.method == "POST": if request.method == "POST":
if request.form.get("form-name") == "create": if request.form.get("form-name") == "create":
if not new_subdomain_form.validate():
flash("Invalid new subdomain", "warning")
return redirect(url_for("dashboard.subdomain_route"))
if not current_user.is_premium(): if not current_user.is_premium():
flash("Only premium plan can add subdomain", "warning") flash("Only premium plan can add subdomain", "warning")
return redirect(request.url) return redirect(request.url)
@ -56,8 +39,8 @@ def subdomain_route():
) )
return redirect(request.url) return redirect(request.url)
subdomain = new_subdomain_form.subdomain.data.lower().strip() subdomain = request.form.get("subdomain").lower().strip()
domain = new_subdomain_form.domain.data.lower().strip() domain = request.form.get("domain").lower().strip()
if len(subdomain) < 3: if len(subdomain) < 3:
flash("Subdomain must have at least 3 characters", "error") flash("Subdomain must have at least 3 characters", "error")
@ -125,5 +108,4 @@ def subdomain_route():
sl_domains=sl_domains, sl_domains=sl_domains,
errors=errors, errors=errors,
subdomains=subdomains, subdomains=subdomains,
new_subdomain_form=new_subdomain_form,
) )

View File

@ -8,14 +8,11 @@ from app.db import Session
from flask import redirect, url_for, flash, request, render_template from flask import redirect, url_for, flash, request, render_template
from flask_login import login_required, current_user from flask_login import login_required, current_user
from app import alias_utils
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.handler.unsubscribe_encoder import UnsubscribeAction
from app.handler.unsubscribe_handler import UnsubscribeHandler
from app.models import Alias, Contact from app.models import Alias, Contact
@dashboard_bp.route("/unsubscribe/<int:alias_id>", methods=["GET", "POST"]) @dashboard_bp.route("/unsubscribe/<alias_id>", methods=["GET", "POST"])
@login_required @login_required
def unsubscribe(alias_id): def unsubscribe(alias_id):
alias = Alias.get(alias_id) alias = Alias.get(alias_id)
@ -32,7 +29,7 @@ def unsubscribe(alias_id):
# automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058 # automatic unsubscribe, according to https://tools.ietf.org/html/rfc8058
if request.method == "POST": if request.method == "POST":
alias_utils.change_alias_status(alias, False) alias.enabled = False
flash(f"Alias {alias.email} has been blocked", "success") flash(f"Alias {alias.email} has been blocked", "success")
Session.commit() Session.commit()
@ -41,7 +38,7 @@ def unsubscribe(alias_id):
return render_template("dashboard/unsubscribe.html", alias=alias.email) return render_template("dashboard/unsubscribe.html", alias=alias.email)
@dashboard_bp.route("/block_contact/<int:contact_id>", methods=["GET", "POST"]) @dashboard_bp.route("/block_contact/<contact_id>", methods=["GET", "POST"])
@login_required @login_required
def block_contact(contact_id): def block_contact(contact_id):
contact = Contact.get(contact_id) contact = Contact.get(contact_id)
@ -71,43 +68,3 @@ def block_contact(contact_id):
) )
else: # ask user confirmation else: # ask user confirmation
return render_template("dashboard/block_contact.html", contact=contact) return render_template("dashboard/block_contact.html", contact=contact)
@dashboard_bp.route("/unsubscribe/encoded/<encoded_request>", methods=["GET"])
@login_required
def encoded_unsubscribe(encoded_request: str):
unsub_data = UnsubscribeHandler().handle_unsubscribe_from_request(
current_user, encoded_request
)
if not unsub_data:
flash("Invalid unsubscribe request", "error")
return redirect(url_for("dashboard.index"))
if unsub_data.action == UnsubscribeAction.DisableAlias:
alias = Alias.get(unsub_data.data)
flash(f"Alias {alias.email} has been blocked", "success")
return redirect(url_for("dashboard.index", highlight_alias_id=alias.id))
if unsub_data.action == UnsubscribeAction.DisableContact:
contact = Contact.get(unsub_data.data)
flash(f"Emails sent from {contact.website_email} are now blocked", "success")
return redirect(
url_for(
"dashboard.alias_contact_manager",
alias_id=contact.alias_id,
highlight_contact_id=contact.id,
)
)
if unsub_data.action == UnsubscribeAction.UnsubscribeNewsletter:
flash("You've unsubscribed from the newsletter", "success")
return redirect(
url_for(
"dashboard.index",
)
)
if unsub_data.action == UnsubscribeAction.OriginalUnsubscribeMailto:
flash("The original unsubscribe request has been forwarded", "success")
return redirect(
url_for(
"dashboard.index",
)
)
return redirect(url_for("dashboard.index"))

View File

@ -3,12 +3,9 @@ from sqlalchemy import create_engine
from sqlalchemy.orm import scoped_session from sqlalchemy.orm import scoped_session
from sqlalchemy.orm import sessionmaker from sqlalchemy.orm import sessionmaker
from app import config from app.config import DB_URI
engine = create_engine(DB_URI)
engine = create_engine(
config.DB_URI, connect_args={"application_name": config.DB_CONN_NAME}
)
connection = engine.connect() connection = engine.connect()
Session = scoped_session(sessionmaker(bind=connection)) Session = scoped_session(sessionmaker(bind=connection))

View File

@ -1,3 +1 @@
from .views import index, new_client, client_detail from .views import index, new_client, client_detail
__all__ = ["index", "new_client", "client_detail"]

View File

@ -1,5 +1,4 @@
from io import BytesIO from io import BytesIO
from urllib.parse import urlparse
from flask import request, render_template, redirect, url_for, flash from flask import request, render_template, redirect, url_for, flash
from flask_login import current_user, login_required from flask_login import current_user, login_required
@ -12,7 +11,6 @@ from app.config import ADMIN_EMAIL
from app.db import Session from app.db import Session
from app.developer.base import developer_bp from app.developer.base import developer_bp
from app.email_utils import send_email from app.email_utils import send_email
from app.image_validation import detect_image_format, ImageFormat
from app.log import LOG from app.log import LOG
from app.models import Client, RedirectUri, File, Referral from app.models import Client, RedirectUri, File, Referral
from app.utils import random_string from app.utils import random_string
@ -48,25 +46,16 @@ def client_detail(client_id):
approval_form.description.data = client.description approval_form.description.data = client.description
if action == "edit" and form.validate_on_submit(): if action == "edit" and form.validate_on_submit():
parsed_url = urlparse(form.url.data)
if parsed_url.scheme != "https":
flash("Only https urls are allowed", "error")
return redirect(url_for("developer.index"))
client.name = form.name.data client.name = form.name.data
client.home_url = form.url.data client.home_url = form.url.data
if form.icon.data: if form.icon.data:
icon_data = form.icon.data.read(10240) # todo: remove current icon if any
if detect_image_format(icon_data) == ImageFormat.Unknown: # todo: handle remove icon
flash("Unknown file format", "warning")
return redirect(url_for("developer.index"))
if client.icon:
s3.delete(client.icon_id)
File.delete(client.icon)
file_path = random_string(30) file_path = random_string(30)
file = File.create(path=file_path, user_id=client.user_id) file = File.create(path=file_path, user_id=client.user_id)
s3.upload_from_bytesio(file_path, BytesIO(icon_data)) s3.upload_from_bytesio(file_path, BytesIO(form.icon.data.read()))
Session.flush() Session.flush()
LOG.d("upload file %s to s3", file) LOG.d("upload file %s to s3", file)
@ -98,7 +87,7 @@ def client_detail(client_id):
) )
flash( flash(
"Thanks for submitting, we are informed and will come back to you asap!", f"Thanks for submitting, we are informed and will come back to you asap!",
"success", "success",
) )

View File

@ -1,5 +1,3 @@
from urllib.parse import urlparse
from flask import render_template, redirect, url_for, flash from flask import render_template, redirect, url_for, flash
from flask_login import current_user, login_required from flask_login import current_user, login_required
from flask_wtf import FlaskForm from flask_wtf import FlaskForm
@ -22,10 +20,6 @@ def new_client():
if form.validate_on_submit(): if form.validate_on_submit():
client = Client.create_new(form.name.data, current_user.id) client = Client.create_new(form.name.data, current_user.id)
parsed_url = urlparse(form.url.data)
if parsed_url.scheme != "https":
flash("Only https urls are allowed", "error")
return redirect(url_for("developer.new_client"))
client.home_url = form.url.data client.home_url = form.url.data
Session.commit() Session.commit()

View File

@ -1,3 +1 @@
from .views import index from .views import index
__all__ = ["index"]

View File

@ -1,13 +1,100 @@
from abc import ABC, abstractmethod from app import config
from typing import List, Tuple, Optional from typing import Optional, List, Tuple
import dns.resolver import dns.resolver
from app.config import NAMESERVERS
def _get_dns_resolver():
my_resolver = dns.resolver.Resolver()
my_resolver.nameservers = config.NAMESERVERS
return my_resolver
def get_ns(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "NS", search=True)
except Exception:
return []
return [a.to_text() for a in answers]
def get_cname_record(hostname) -> Optional[str]:
"""Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end"""
try:
answers = _get_dns_resolver().resolve(hostname, "CNAME", search=True)
except Exception:
return None
for a in answers:
ret = a.to_text()
return ret[:-1]
return None
def get_mx_domains(hostname) -> [(int, str)]:
"""return list of (priority, domain name).
domain name ends with a "." at the end.
"""
try:
answers = _get_dns_resolver().resolve(hostname, "MX", search=True)
except Exception:
return []
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return ret
_include_spf = "include:" _include_spf = "include:"
def get_spf_domain(hostname) -> [str]:
"""return all domains listed in *include:*"""
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(part[part.find(_include_spf) + len(_include_spf) :])
return ret
def get_txt_record(hostname) -> [str]:
try:
answers = _get_dns_resolver().resolve(hostname, "TXT", search=True)
except Exception:
return []
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
record = record.decode() # record is bytes
ret.append(record)
return ret
def is_mx_equivalent( def is_mx_equivalent(
mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]] mx_domains: List[Tuple[int, str]], ref_mx_domains: List[Tuple[int, str]]
) -> bool: ) -> bool:
@ -18,127 +105,16 @@ def is_mx_equivalent(
The priority order is taken into account but not the priority number. The priority order is taken into account but not the priority number.
For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)] For example, [(1, domain1), (2, domain2)] is equivalent to [(10, domain1), (20, domain2)]
""" """
mx_domains = sorted(mx_domains, key=lambda x: x[0]) mx_domains = sorted(mx_domains, key=lambda priority_domain: priority_domain[0])
ref_mx_domains = sorted(ref_mx_domains, key=lambda x: x[0]) ref_mx_domains = sorted(
ref_mx_domains, key=lambda priority_domain: priority_domain[0]
)
if len(mx_domains) < len(ref_mx_domains): if len(mx_domains) < len(ref_mx_domains):
return False return False
for i in range(len(ref_mx_domains)): for i in range(0, len(ref_mx_domains)):
if mx_domains[i][1] != ref_mx_domains[i][1]: if mx_domains[i][1] != ref_mx_domains[i][1]:
return False return False
return True return True
class DNSClient(ABC):
@abstractmethod
def get_cname_record(self, hostname: str) -> Optional[str]:
pass
@abstractmethod
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
pass
def get_spf_domain(self, hostname: str) -> List[str]:
"""
return all domains listed in *include:*
"""
try:
records = self.get_txt_record(hostname)
ret = []
for record in records:
if record.startswith("v=spf1"):
parts = record.split(" ")
for part in parts:
if part.startswith(_include_spf):
ret.append(
part[part.find(_include_spf) + len(_include_spf) :]
)
return ret
except Exception:
return []
@abstractmethod
def get_txt_record(self, hostname: str) -> List[str]:
pass
class NetworkDNSClient(DNSClient):
def __init__(self, nameservers: List[str]):
self._resolver = dns.resolver.Resolver()
self._resolver.nameservers = nameservers
def get_cname_record(self, hostname: str) -> Optional[str]:
"""
Return the CNAME record if exists for a domain, WITHOUT the trailing period at the end
"""
try:
answers = self._resolver.resolve(hostname, "CNAME", search=True)
for a in answers:
ret = a.to_text()
return ret[:-1]
except Exception:
return None
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
"""
return list of (priority, domain name) sorted by priority (lowest priority first)
domain name ends with a "." at the end.
"""
try:
answers = self._resolver.resolve(hostname, "MX", search=True)
ret = []
for a in answers:
record = a.to_text() # for ex '20 alt2.aspmx.l.google.com.'
parts = record.split(" ")
ret.append((int(parts[0]), parts[1]))
return sorted(ret, key=lambda x: x[0])
except Exception:
return []
def get_txt_record(self, hostname: str) -> List[str]:
try:
answers = self._resolver.resolve(hostname, "TXT", search=True)
ret = []
for a in answers: # type: dns.rdtypes.ANY.TXT.TXT
for record in a.strings:
ret.append(record.decode())
return ret
except Exception:
return []
class InMemoryDNSClient(DNSClient):
def __init__(self):
self.cname_records: dict[str, Optional[str]] = {}
self.mx_records: dict[str, List[Tuple[int, str]]] = {}
self.spf_records: dict[str, List[str]] = {}
self.txt_records: dict[str, List[str]] = {}
def set_cname_record(self, hostname: str, cname: str):
self.cname_records[hostname] = cname
def set_mx_records(self, hostname: str, mx_list: List[Tuple[int, str]]):
self.mx_records[hostname] = mx_list
def set_txt_record(self, hostname: str, txt_list: List[str]):
self.txt_records[hostname] = txt_list
def get_cname_record(self, hostname: str) -> Optional[str]:
return self.cname_records.get(hostname)
def get_mx_domains(self, hostname: str) -> List[Tuple[int, str]]:
mx_list = self.mx_records.get(hostname, [])
return sorted(mx_list, key=lambda x: x[0])
def get_txt_record(self, hostname: str) -> List[str]:
return self.txt_records.get(hostname, [])
def get_network_dns_client() -> NetworkDNSClient:
return NetworkDNSClient(NAMESERVERS)
def get_mx_domains(hostname: str) -> [(int, str)]:
return get_network_dns_client().get_mx_domains(hostname)

View File

@ -20,8 +20,6 @@ X_SPAM_STATUS = "X-Spam-Status"
LIST_UNSUBSCRIBE = "List-Unsubscribe" LIST_UNSUBSCRIBE = "List-Unsubscribe"
LIST_UNSUBSCRIBE_POST = "List-Unsubscribe-Post" LIST_UNSUBSCRIBE_POST = "List-Unsubscribe-Post"
RETURN_PATH = "Return-Path" RETURN_PATH = "Return-Path"
AUTHENTICATION_RESULTS = "Authentication-Results"
SL_QUEUE_ID = "X-SL-Queue-Id"
# headers used to DKIM sign in order of preference # headers used to DKIM sign in order of preference
DKIM_HEADERS = [ DKIM_HEADERS = [
@ -34,7 +32,6 @@ DKIM_HEADERS = [
SL_DIRECTION = "X-SimpleLogin-Type" SL_DIRECTION = "X-SimpleLogin-Type"
SL_EMAIL_LOG_ID = "X-SimpleLogin-EmailLog-ID" SL_EMAIL_LOG_ID = "X-SimpleLogin-EmailLog-ID"
SL_ENVELOPE_FROM = "X-SimpleLogin-Envelope-From" SL_ENVELOPE_FROM = "X-SimpleLogin-Envelope-From"
SL_ORIGINAL_FROM = "X-SimpleLogin-Original-From"
SL_ENVELOPE_TO = "X-SimpleLogin-Envelope-To" SL_ENVELOPE_TO = "X-SimpleLogin-Envelope-To"
SL_CLIENT_IP = "X-SimpleLogin-Client-IP" SL_CLIENT_IP = "X-SimpleLogin-Client-IP"

View File

@ -31,7 +31,11 @@ E402 = "421 SL E402 Encryption failed - Retry later"
# E403 = "421 SL E403 Retry later" # E403 = "421 SL E403 Retry later"
E404 = "421 SL E404 Unexpected error - Retry later" E404 = "421 SL E404 Unexpected error - Retry later"
E405 = "421 SL E405 Mailbox domain problem - Retry later" E405 = "421 SL E405 Mailbox domain problem - Retry later"
E406 = "421 SL E406 Retry later"
E407 = "421 SL E407 Retry later" E407 = "421 SL E407 Retry later"
E408 = "421 SL E408 Retry later"
E409 = "421 SL E409 Retry later"
E410 = "421 SL E410 Retry later"
# endregion # endregion
# region 5** errors # region 5** errors
@ -60,5 +64,4 @@ E522 = (
) )
E523 = "550 SL E523 Unknown error" E523 = "550 SL E523 Unknown error"
E524 = "550 SL E524 Wrong use of reverse-alias" E524 = "550 SL E524 Wrong use of reverse-alias"
E525 = "550 SL E525 Alias loop"
# endregion # endregion

View File

@ -14,7 +14,7 @@ from email.header import decode_header, Header
from email.message import Message, EmailMessage from email.message import Message, EmailMessage
from email.mime.multipart import MIMEMultipart from email.mime.multipart import MIMEMultipart
from email.mime.text import MIMEText from email.mime.text import MIMEText
from email.utils import make_msgid, formatdate, formataddr from email.utils import make_msgid, formatdate
from smtplib import SMTP, SMTPException from smtplib import SMTP, SMTPException
from typing import Tuple, List, Optional, Union from typing import Tuple, List, Optional, Union
@ -33,9 +33,31 @@ from flanker.addresslib import address
from flanker.addresslib.address import EmailAddress from flanker.addresslib.address import EmailAddress
from jinja2 import Environment, FileSystemLoader from jinja2 import Environment, FileSystemLoader
from sqlalchemy import func from sqlalchemy import func
from flask_login import current_user
from app import config from app.config import (
ROOT_DIR,
POSTFIX_SERVER,
DKIM_SELECTOR,
DKIM_PRIVATE_KEY,
ALIAS_DOMAINS,
POSTFIX_SUBMISSION_TLS,
MAX_NB_EMAIL_FREE_PLAN,
MAX_ALERT_24H,
POSTFIX_PORT,
URL,
LANDING_PAGE_URL,
EMAIL_DOMAIN,
ALERT_DIRECTORY_DISABLED_ALIAS_CREATION,
ALERT_SPF,
ALERT_INVALID_TOTP_LOGIN,
TEMP_DIR,
ALIAS_AUTOMATIC_DISABLE,
RSPAMD_SIGN_DKIM,
NOREPLY,
VERP_PREFIX,
VERP_MESSAGE_LIFETIME,
VERP_EMAIL_SECRET,
)
from app.db import Session from app.db import Session
from app.dns_utils import get_mx_domains from app.dns_utils import get_mx_domains
from app.email import headers from app.email import headers
@ -55,7 +77,6 @@ from app.models import (
IgnoreBounceSender, IgnoreBounceSender,
InvalidMailboxDomain, InvalidMailboxDomain,
VerpType, VerpType,
available_sl_email,
) )
from app.utils import ( from app.utils import (
random_string, random_string,
@ -69,42 +90,32 @@ VERP_TIME_START = 1640995200
VERP_HMAC_ALGO = "sha3-224" VERP_HMAC_ALGO = "sha3-224"
def render(template_name: str, user: Optional[User], **kwargs) -> str: def render(template_name, **kwargs) -> str:
templates_dir = os.path.join(config.ROOT_DIR, "templates", "emails") templates_dir = os.path.join(ROOT_DIR, "templates", "emails")
env = Environment(loader=FileSystemLoader(templates_dir)) env = Environment(loader=FileSystemLoader(templates_dir))
template = env.get_template(template_name) template = env.get_template(template_name)
if user is None:
if current_user and current_user.is_authenticated:
user = current_user
use_partner_template = False
if user:
use_partner_template = user.has_used_alias_from_partner()
kwargs["user"] = user
return template.render( return template.render(
MAX_NB_EMAIL_FREE_PLAN=config.MAX_NB_EMAIL_FREE_PLAN, MAX_NB_EMAIL_FREE_PLAN=MAX_NB_EMAIL_FREE_PLAN,
URL=config.URL, URL=URL,
LANDING_PAGE_URL=config.LANDING_PAGE_URL, LANDING_PAGE_URL=LANDING_PAGE_URL,
YEAR=arrow.now().year, YEAR=arrow.now().year,
USE_PARTNER_TEMPLATE=use_partner_template,
**kwargs, **kwargs,
) )
def send_welcome_email(user): def send_welcome_email(user):
comm_email, unsubscribe_link, via_email = user.get_communication_email() to_email, unsubscribe_link, via_email = user.get_communication_email()
if not comm_email: if not to_email:
return return
# whether this email is sent to an alias # whether this email is sent to an alias
alias = comm_email if comm_email != user.email else None alias = to_email if to_email != user.email else None
send_email( send_email(
comm_email, to_email,
"Welcome to SimpleLogin", f"Welcome to SimpleLogin",
render("com/welcome.txt", user=user, alias=alias), render("com/welcome.txt", user=user, alias=alias),
render("com/welcome.html", user=user, alias=alias), render("com/welcome.html", user=user, alias=alias),
unsubscribe_link, unsubscribe_link,
@ -115,66 +126,60 @@ def send_welcome_email(user):
def send_trial_end_soon_email(user): def send_trial_end_soon_email(user):
send_email( send_email(
user.email, user.email,
"Your trial will end soon", f"Your trial will end soon",
render("transactional/trial-end.txt.jinja2", user=user), render("transactional/trial-end.txt.jinja2", user=user),
render("transactional/trial-end.html", user=user), render("transactional/trial-end.html", user=user),
ignore_smtp_error=True, ignore_smtp_error=True,
) )
def send_activation_email(user: User, activation_link): def send_activation_email(email, activation_link):
send_email( send_email(
user.email, email,
"Just one more step to join SimpleLogin", f"Just one more step to join SimpleLogin",
render( render(
"transactional/activation.txt", "transactional/activation.txt",
user=user,
activation_link=activation_link, activation_link=activation_link,
email=user.email, email=email,
), ),
render( render(
"transactional/activation.html", "transactional/activation.html",
user=user,
activation_link=activation_link, activation_link=activation_link,
email=user.email, email=email,
), ),
) )
def send_reset_password_email(user: User, reset_password_link): def send_reset_password_email(email, reset_password_link):
send_email( send_email(
user.email, email,
"Reset your password on SimpleLogin", "Reset your password on SimpleLogin",
render( render(
"transactional/reset-password.txt", "transactional/reset-password.txt",
user=user,
reset_password_link=reset_password_link, reset_password_link=reset_password_link,
), ),
render( render(
"transactional/reset-password.html", "transactional/reset-password.html",
user=user,
reset_password_link=reset_password_link, reset_password_link=reset_password_link,
), ),
) )
def send_change_email(user: User, new_email, link): def send_change_email(new_email, current_email, link):
send_email( send_email(
new_email, new_email,
"Confirm email update on SimpleLogin", "Confirm email update on SimpleLogin",
render( render(
"transactional/change-email.txt", "transactional/change-email.txt",
user=user,
link=link, link=link,
new_email=new_email, new_email=new_email,
current_email=user.email, current_email=current_email,
), ),
render( render(
"transactional/change-email.html", "transactional/change-email.html",
user=user,
link=link, link=link,
new_email=new_email, new_email=new_email,
current_email=user.email, current_email=current_email,
), ),
) )
@ -182,37 +187,33 @@ def send_change_email(user: User, new_email, link):
def send_invalid_totp_login_email(user, totp_type): def send_invalid_totp_login_email(user, totp_type):
send_email_with_rate_control( send_email_with_rate_control(
user, user,
config.ALERT_INVALID_TOTP_LOGIN, ALERT_INVALID_TOTP_LOGIN,
user.email, user.email,
"Unsuccessful attempt to login to your SimpleLogin account", "Unsuccessful attempt to login to your SimpleLogin account",
render( render(
"transactional/invalid-totp-login.txt", "transactional/invalid-totp-login.txt",
user=user,
type=totp_type, type=totp_type,
), ),
render( render(
"transactional/invalid-totp-login.html", "transactional/invalid-totp-login.html",
user=user,
type=totp_type, type=totp_type,
), ),
1, 1,
) )
def send_test_email_alias(user: User, email: str): def send_test_email_alias(email, name):
send_email( send_email(
email, email,
f"This email is sent to {email}", f"This email is sent to {email}",
render( render(
"transactional/test-email.txt", "transactional/test-email.txt",
user=user, name=name,
name=user.name,
alias=email, alias=email,
), ),
render( render(
"transactional/test-email.html", "transactional/test-email.html",
user=user, name=name,
name=user.name,
alias=email, alias=email,
), ),
) )
@ -227,13 +228,11 @@ def send_cannot_create_directory_alias(user, alias_address, directory_name):
f"Alias {alias_address} cannot be created", f"Alias {alias_address} cannot be created",
render( render(
"transactional/cannot-create-alias-directory.txt", "transactional/cannot-create-alias-directory.txt",
user=user,
alias=alias_address, alias=alias_address,
directory=directory_name, directory=directory_name,
), ),
render( render(
"transactional/cannot-create-alias-directory.html", "transactional/cannot-create-alias-directory.html",
user=user,
alias=alias_address, alias=alias_address,
directory=directory_name, directory=directory_name,
), ),
@ -246,18 +245,16 @@ def send_cannot_create_directory_alias_disabled(user, alias_address, directory_n
""" """
send_email_with_rate_control( send_email_with_rate_control(
user, user,
config.ALERT_DIRECTORY_DISABLED_ALIAS_CREATION, ALERT_DIRECTORY_DISABLED_ALIAS_CREATION,
user.email, user.email,
f"Alias {alias_address} cannot be created", f"Alias {alias_address} cannot be created",
render( render(
"transactional/cannot-create-alias-directory-disabled.txt", "transactional/cannot-create-alias-directory-disabled.txt",
user=user,
alias=alias_address, alias=alias_address,
directory=directory_name, directory=directory_name,
), ),
render( render(
"transactional/cannot-create-alias-directory-disabled.html", "transactional/cannot-create-alias-directory-disabled.html",
user=user,
alias=alias_address, alias=alias_address,
directory=directory_name, directory=directory_name,
), ),
@ -273,13 +270,11 @@ def send_cannot_create_domain_alias(user, alias, domain):
f"Alias {alias} cannot be created", f"Alias {alias} cannot be created",
render( render(
"transactional/cannot-create-alias-domain.txt", "transactional/cannot-create-alias-domain.txt",
user=user,
alias=alias, alias=alias,
domain=domain, domain=domain,
), ),
render( render(
"transactional/cannot-create-alias-domain.html", "transactional/cannot-create-alias-domain.html",
user=user,
alias=alias, alias=alias,
domain=domain, domain=domain,
), ),
@ -302,9 +297,8 @@ def send_email(
LOG.d("send email to %s, subject '%s'", to_email, subject) LOG.d("send email to %s, subject '%s'", to_email, subject)
from_name = from_name or config.NOREPLY from_name = from_name or NOREPLY
from_addr = from_addr or config.NOREPLY from_addr = from_addr or NOREPLY
from_domain = get_email_domain_part(from_addr)
if html: if html:
msg = MIMEMultipart("alternative") msg = MIMEMultipart("alternative")
@ -319,14 +313,13 @@ def send_email(
msg[headers.FROM] = f'"{from_name}" <{from_addr}>' msg[headers.FROM] = f'"{from_name}" <{from_addr}>'
msg[headers.TO] = to_email msg[headers.TO] = to_email
msg_id_header = make_msgid(domain=config.EMAIL_DOMAIN) msg_id_header = make_msgid(domain=EMAIL_DOMAIN)
msg[headers.MESSAGE_ID] = msg_id_header msg[headers.MESSAGE_ID] = msg_id_header
date_header = formatdate() date_header = formatdate()
msg[headers.DATE] = date_header msg[headers.DATE] = date_header
if headers.MIME_VERSION not in msg: msg[headers.MIME_VERSION] = "1.0"
msg[headers.MIME_VERSION] = "1.0"
if unsubscribe_link: if unsubscribe_link:
add_or_replace_header(msg, headers.LIST_UNSUBSCRIBE, f"<{unsubscribe_link}>") add_or_replace_header(msg, headers.LIST_UNSUBSCRIBE, f"<{unsubscribe_link}>")
@ -343,7 +336,7 @@ def send_email(
# use a different envelope sender for each transactional email (aka VERP) # use a different envelope sender for each transactional email (aka VERP)
sl_sendmail( sl_sendmail(
generate_verp_email(VerpType.transactional, transaction.id, from_domain), generate_verp_email(VerpType.transactional, transaction.id),
to_email, to_email,
msg, msg,
retries=retries, retries=retries,
@ -358,7 +351,7 @@ def send_email_with_rate_control(
subject, subject,
plaintext, plaintext,
html=None, html=None,
max_nb_alert=config.MAX_ALERT_24H, max_nb_alert=MAX_ALERT_24H,
nb_day=1, nb_day=1,
ignore_smtp_error=False, ignore_smtp_error=False,
retries=0, retries=0,
@ -455,7 +448,7 @@ def get_email_domain_part(address):
def add_dkim_signature(msg: Message, email_domain: str): def add_dkim_signature(msg: Message, email_domain: str):
if config.RSPAMD_SIGN_DKIM: if RSPAMD_SIGN_DKIM:
LOG.d("DKIM signature will be added by rspamd") LOG.d("DKIM signature will be added by rspamd")
msg[headers.SL_WANT_SIGNING] = "yes" msg[headers.SL_WANT_SIGNING] = "yes"
return return
@ -470,9 +463,9 @@ def add_dkim_signature(msg: Message, email_domain: str):
continue continue
# To investigate why some emails can't be DKIM signed. todo: remove # To investigate why some emails can't be DKIM signed. todo: remove
if config.TEMP_DIR: if TEMP_DIR:
file_name = str(uuid.uuid4()) + ".eml" file_name = str(uuid.uuid4()) + ".eml"
with open(os.path.join(config.TEMP_DIR, file_name), "wb") as f: with open(os.path.join(TEMP_DIR, file_name), "wb") as f:
f.write(msg.as_bytes()) f.write(msg.as_bytes())
LOG.w("email saved to %s", file_name) LOG.w("email saved to %s", file_name)
@ -487,12 +480,12 @@ def add_dkim_signature_with_header(
# Specify headers in "byte" form # Specify headers in "byte" form
# Generate message signature # Generate message signature
if config.DKIM_PRIVATE_KEY: if DKIM_PRIVATE_KEY:
sig = dkim.sign( sig = dkim.sign(
message_to_bytes(msg), message_to_bytes(msg),
config.DKIM_SELECTOR, DKIM_SELECTOR,
email_domain.encode(), email_domain.encode(),
config.DKIM_PRIVATE_KEY.encode(), DKIM_PRIVATE_KEY.encode(),
include_headers=dkim_headers, include_headers=dkim_headers,
) )
sig = sig.decode() sig = sig.decode()
@ -521,10 +514,9 @@ def delete_header(msg: Message, header: str):
def sanitize_header(msg: Message, header: str): def sanitize_header(msg: Message, header: str):
"""remove trailing space and remove linebreak from a header""" """remove trailing space and remove linebreak from a header"""
header_lowercase = header.lower()
for i in reversed(range(len(msg._headers))): for i in reversed(range(len(msg._headers))):
header_name = msg._headers[i][0].lower() header_name = msg._headers[i][0].lower()
if header_name == header_lowercase: if header_name == header.lower():
# msg._headers[i] is a tuple like ('From', 'hey@google.com') # msg._headers[i] is a tuple like ('From', 'hey@google.com')
if msg._headers[i][1]: if msg._headers[i][1]:
msg._headers[i] = ( msg._headers[i] = (
@ -545,12 +537,10 @@ def delete_all_headers_except(msg: Message, headers: [str]):
def can_create_directory_for_address(email_address: str) -> bool: def can_create_directory_for_address(email_address: str) -> bool:
"""return True if an email ends with one of the alias domains provided by SimpleLogin""" """return True if an email ends with one of the alias domains provided by SimpleLogin"""
# not allow creating directory with premium domain # not allow creating directory with premium domain
for domain in config.ALIAS_DOMAINS: for domain in ALIAS_DOMAINS:
if email_address.endswith("@" + domain): if email_address.endswith("@" + domain):
return True return True
LOG.i(
f"Cannot create address in directory for {email_address} since it does not belong to a valid directory domain"
)
return False return False
@ -604,7 +594,7 @@ def email_can_be_used_as_mailbox(email_address: str) -> bool:
mx_domains = get_mx_domain_list(domain) mx_domains = get_mx_domain_list(domain)
# if no MX record, email is not valid # if no MX record, email is not valid
if not config.SKIP_MX_LOOKUP_ON_CHECK and not mx_domains: if not mx_domains:
LOG.d("No MX record for domain %s", domain) LOG.d("No MX record for domain %s", domain)
return False return False
@ -613,26 +603,6 @@ def email_can_be_used_as_mailbox(email_address: str) -> bool:
LOG.d("MX Domain %s %s is invalid mailbox domain", mx_domain, domain) LOG.d("MX Domain %s %s is invalid mailbox domain", mx_domain, domain)
return False return False
existing_user = User.get_by(email=email_address)
if existing_user and existing_user.disabled:
LOG.d(
f"User {existing_user} is disabled. {email_address} cannot be used for other mailbox"
)
return False
for existing_user in (
User.query()
.join(Mailbox, User.id == Mailbox.user_id)
.filter(Mailbox.email == email_address)
.group_by(User.id)
.all()
):
if existing_user.disabled:
LOG.d(
f"User {existing_user} is disabled and has a mailbox with {email_address}. Id cannot be used for other mailbox"
)
return False
return True return True
@ -818,7 +788,7 @@ def get_header_unicode(header: Union[str, Header]) -> str:
ret = "" ret = ""
for to_decoded_str, charset in decode_header(header): for to_decoded_str, charset in decode_header(header):
if charset is None: if charset is None:
if isinstance(to_decoded_str, bytes): if type(to_decoded_str) is bytes:
decoded_str = to_decoded_str.decode() decoded_str = to_decoded_str.decode()
else: else:
decoded_str = to_decoded_str decoded_str = to_decoded_str
@ -855,13 +825,13 @@ def to_bytes(msg: Message):
for generator_policy in [None, policy.SMTP, policy.SMTPUTF8]: for generator_policy in [None, policy.SMTP, policy.SMTPUTF8]:
try: try:
return msg.as_bytes(policy=generator_policy) return msg.as_bytes(policy=generator_policy)
except Exception: except:
LOG.w("as_bytes() fails with %s policy", policy, exc_info=True) LOG.w("as_bytes() fails with %s policy", policy, exc_info=True)
msg_string = msg.as_string() msg_string = msg.as_string()
try: try:
return msg_string.encode() return msg_string.encode()
except Exception: except:
LOG.w("as_string().encode() fails", exc_info=True) LOG.w("as_string().encode() fails", exc_info=True)
return msg_string.encode(errors="replace") return msg_string.encode(errors="replace")
@ -878,6 +848,19 @@ def should_add_dkim_signature(domain: str) -> bool:
return False return False
def is_valid_email(email_address: str) -> bool:
"""
Used to check whether an email address is valid
NOT run MX check.
NOT allow unicode.
"""
try:
validate_email(email_address, check_deliverability=False, allow_smtputf8=False)
return True
except EmailNotValidError:
return False
class EmailEncoding(enum.Enum): class EmailEncoding(enum.Enum):
BASE64 = "base64" BASE64 = "base64"
QUOTED = "quoted-printable" QUOTED = "quoted-printable"
@ -891,24 +874,8 @@ def get_encoding(msg: Message) -> EmailEncoding:
- base64 - base64
- 7bit: default if unknown or empty - 7bit: default if unknown or empty
""" """
cte = ( cte = str(msg.get(headers.CONTENT_TRANSFER_ENCODING, "")).lower().strip()
str(msg.get(headers.CONTENT_TRANSFER_ENCODING, "")) if cte in ("", "7bit", "7-bit", "8bit", "binary", "8bit;", "utf-8"):
.lower()
.strip()
.strip('"')
.strip("'")
)
if cte in (
"",
"7bit",
"7-bit",
"7bits",
"8bit",
"8bits",
"binary",
"8bit;",
"utf-8",
):
return EmailEncoding.NO return EmailEncoding.NO
if cte == "base64": if cte == "base64":
@ -948,35 +915,22 @@ def decode_text(text: str, encoding: EmailEncoding = EmailEncoding.NO) -> str:
return text return text
def add_header( def add_header(msg: Message, text_header, html_header) -> Message:
msg: Message, text_header, html_header=None, subject_prefix=None
) -> Message:
if not html_header:
html_header = text_header.replace("\n", "<br>")
if subject_prefix is not None:
subject = msg[headers.SUBJECT]
if not subject:
msg.add_header(headers.SUBJECT, subject_prefix)
else:
subject = f"{subject_prefix} {subject}"
msg.replace_header(headers.SUBJECT, subject)
content_type = msg.get_content_type().lower() content_type = msg.get_content_type().lower()
if content_type == "text/plain": if content_type == "text/plain":
encoding = get_encoding(msg) encoding = get_encoding(msg)
payload = msg.get_payload() payload = msg.get_payload()
if isinstance(payload, str): if type(payload) is str:
clone_msg = copy(msg) clone_msg = copy(msg)
new_payload = f"""{text_header} new_payload = f"""{text_header}
------------------------------ ---
{decode_text(payload, encoding)}""" {decode_text(payload, encoding)}"""
clone_msg.set_payload(encode_text(new_payload, encoding)) clone_msg.set_payload(encode_text(new_payload, encoding))
return clone_msg return clone_msg
elif content_type == "text/html": elif content_type == "text/html":
encoding = get_encoding(msg) encoding = get_encoding(msg)
payload = msg.get_payload() payload = msg.get_payload()
if isinstance(payload, str): if type(payload) is str:
new_payload = f"""<table width="100%" style="width: 100%; -premailer-width: 100%; -premailer-cellpadding: 0; new_payload = f"""<table width="100%" style="width: 100%; -premailer-width: 100%; -premailer-cellpadding: 0;
-premailer-cellspacing: 0; margin: 0; padding: 0;"> -premailer-cellspacing: 0; margin: 0; padding: 0;">
<tr> <tr>
@ -998,8 +952,6 @@ def add_header(
for part in msg.get_payload(): for part in msg.get_payload():
if isinstance(part, Message): if isinstance(part, Message):
new_parts.append(add_header(part, text_header, html_header)) new_parts.append(add_header(part, text_header, html_header))
elif isinstance(part, str):
new_parts.append(MIMEText(part))
else: else:
new_parts.append(part) new_parts.append(part)
clone_msg = copy(msg) clone_msg = copy(msg)
@ -1008,14 +960,7 @@ def add_header(
elif content_type in ("multipart/mixed", "multipart/signed"): elif content_type in ("multipart/mixed", "multipart/signed"):
new_parts = [] new_parts = []
payload = msg.get_payload() parts = list(msg.get_payload())
if isinstance(payload, str):
# The message is badly formatted inject as new
new_parts = [MIMEText(text_header, "plain"), MIMEText(payload, "plain")]
clone_msg = copy(msg)
clone_msg.set_payload(new_parts)
return clone_msg
parts = list(payload)
LOG.d("only add header for the first part for %s", content_type) LOG.d("only add header for the first part for %s", content_type)
for ix, part in enumerate(parts): for ix, part in enumerate(parts):
if ix == 0: if ix == 0:
@ -1031,11 +976,7 @@ def add_header(
return msg return msg
def replace(msg: Union[Message, str], old, new) -> Union[Message, str]: def replace(msg: Message, old, new) -> Message:
if isinstance(msg, str):
msg = msg.replace(old, new)
return msg
content_type = msg.get_content_type() content_type = msg.get_content_type()
if ( if (
@ -1055,7 +996,7 @@ def replace(msg: Union[Message, str], old, new) -> Union[Message, str]:
if content_type in ("text/plain", "text/html"): if content_type in ("text/plain", "text/html"):
encoding = get_encoding(msg) encoding = get_encoding(msg)
payload = msg.get_payload() payload = msg.get_payload()
if isinstance(payload, str): if type(payload) is str:
if encoding == EmailEncoding.QUOTED: if encoding == EmailEncoding.QUOTED:
LOG.d("handle quoted-printable replace %s -> %s", old, new) LOG.d("handle quoted-printable replace %s -> %s", old, new)
# first decode the payload # first decode the payload
@ -1100,7 +1041,7 @@ def replace(msg: Union[Message, str], old, new) -> Union[Message, str]:
return msg return msg
def generate_reply_email(contact_email: str, alias: Alias) -> str: def generate_reply_email(contact_email: str, user: User) -> str:
""" """
generate a reply_email (aka reverse-alias), make sure it isn't used by any contact generate a reply_email (aka reverse-alias), make sure it isn't used by any contact
""" """
@ -1111,7 +1052,6 @@ def generate_reply_email(contact_email: str, alias: Alias) -> str:
include_sender_in_reverse_alias = False include_sender_in_reverse_alias = False
user = alias.user
# user has set this option explicitly # user has set this option explicitly
if user.include_sender_in_reverse_alias is not None: if user.include_sender_in_reverse_alias is not None:
include_sender_in_reverse_alias = user.include_sender_in_reverse_alias include_sender_in_reverse_alias = user.include_sender_in_reverse_alias
@ -1126,28 +1066,22 @@ def generate_reply_email(contact_email: str, alias: Alias) -> str:
contact_email = contact_email.replace(".", "_") contact_email = contact_email.replace(".", "_")
contact_email = convert_to_alphanumeric(contact_email) contact_email = convert_to_alphanumeric(contact_email)
reply_domain = config.EMAIL_DOMAIN
alias_domain = get_email_domain_part(alias.email)
sl_domain = SLDomain.get_by(domain=alias_domain)
if sl_domain and sl_domain.use_as_reverse_alias:
reply_domain = alias_domain
# not use while to avoid infinite loop # not use while to avoid infinite loop
for _ in range(1000): for _ in range(1000):
if include_sender_in_reverse_alias and contact_email: if include_sender_in_reverse_alias and contact_email:
random_length = random.randint(5, 10) random_length = random.randint(5, 10)
reply_email = ( reply_email = (
# do not use the ra+ anymore # do not use the ra+ anymore
# f"ra+{contact_email}+{random_string(random_length)}@{config.EMAIL_DOMAIN}" # f"ra+{contact_email}+{random_string(random_length)}@{EMAIL_DOMAIN}"
f"{contact_email}_{random_string(random_length)}@{reply_domain}" f"{contact_email}_{random_string(random_length)}@{EMAIL_DOMAIN}"
) )
else: else:
random_length = random.randint(20, 50) random_length = random.randint(20, 50)
# do not use the ra+ anymore # do not use the ra+ anymore
# reply_email = f"ra+{random_string(random_length)}@{config.EMAIL_DOMAIN}" # reply_email = f"ra+{random_string(random_length)}@{EMAIL_DOMAIN}"
reply_email = f"{random_string(random_length)}@{reply_domain}" reply_email = f"{random_string(random_length)}@{EMAIL_DOMAIN}"
if available_sl_email(reply_email): if not Contact.get_by(reply_email=reply_email):
return reply_email return reply_email
raise Exception("Cannot generate reply email") raise Exception("Cannot generate reply email")
@ -1158,11 +1092,31 @@ def is_reverse_alias(address: str) -> bool:
if Contact.get_by(reply_email=address): if Contact.get_by(reply_email=address):
return True return True
return address.endswith(f"@{config.EMAIL_DOMAIN}") and ( return address.endswith(f"@{EMAIL_DOMAIN}") and (
address.startswith("reply+") or address.startswith("ra+") address.startswith("reply+") or address.startswith("ra+")
) )
# allow also + and @ that are present in a reply address
_ALLOWED_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.+@"
def normalize_reply_email(reply_email: str) -> str:
"""Handle the case where reply email contains *strange* char that was wrongly generated in the past"""
if not reply_email.isascii():
reply_email = convert_to_id(reply_email)
ret = []
# drop all control characters like shift, separator, etc
for c in reply_email:
if c not in _ALLOWED_CHARS:
ret.append("_")
else:
ret.append(c)
return "".join(ret)
def should_disable(alias: Alias) -> (bool, str): def should_disable(alias: Alias) -> (bool, str):
""" """
Return whether an alias should be disabled and if yes, the reason why Return whether an alias should be disabled and if yes, the reason why
@ -1172,7 +1126,7 @@ def should_disable(alias: Alias) -> (bool, str):
LOG.w("%s cannot be disabled", alias) LOG.w("%s cannot be disabled", alias)
return False, "" return False, ""
if not config.ALIAS_AUTOMATIC_DISABLE: if not ALIAS_AUTOMATIC_DISABLE:
return False, "" return False, ""
yesterday = arrow.now().shift(days=-1) yesterday = arrow.now().shift(days=-1)
@ -1287,24 +1241,22 @@ def spf_pass(
subject = get_header_unicode(msg[headers.SUBJECT]) subject = get_header_unicode(msg[headers.SUBJECT])
send_email_with_rate_control( send_email_with_rate_control(
user, user,
config.ALERT_SPF, ALERT_SPF,
mailbox.email, mailbox.email,
f"SimpleLogin Alert: attempt to send emails from your alias {alias.email} from unknown IP Address", f"SimpleLogin Alert: attempt to send emails from your alias {alias.email} from unknown IP Address",
render( render(
"transactional/spf-fail.txt", "transactional/spf-fail.txt",
user=user,
alias=alias.email, alias=alias.email,
ip=ip, ip=ip,
mailbox_url=config.URL + f"/dashboard/mailbox/{mailbox.id}#spf", mailbox_url=URL + f"/dashboard/mailbox/{mailbox.id}#spf",
to_email=contact_email, to_email=contact_email,
subject=subject, subject=subject,
time=arrow.now(), time=arrow.now(),
), ),
render( render(
"transactional/spf-fail.html", "transactional/spf-fail.html",
user=user,
ip=ip, ip=ip,
mailbox_url=config.URL + f"/dashboard/mailbox/{mailbox.id}#spf", mailbox_url=URL + f"/dashboard/mailbox/{mailbox.id}#spf",
to_email=contact_email, to_email=contact_email,
subject=subject, subject=subject,
time=arrow.now(), time=arrow.now(),
@ -1327,11 +1279,11 @@ def spf_pass(
@cached(cache=TTLCache(maxsize=2, ttl=20)) @cached(cache=TTLCache(maxsize=2, ttl=20))
def get_smtp_server(): def get_smtp_server():
LOG.d("get a smtp server") LOG.d("get a smtp server")
if config.POSTFIX_SUBMISSION_TLS: if POSTFIX_SUBMISSION_TLS:
smtp = SMTP(config.POSTFIX_SERVER, 587) smtp = SMTP(POSTFIX_SERVER, 587)
smtp.starttls() smtp.starttls()
else: else:
smtp = SMTP(config.POSTFIX_SERVER, config.POSTFIX_PORT) smtp = SMTP(POSTFIX_SERVER, POSTFIX_PORT)
return smtp return smtp
@ -1403,12 +1355,12 @@ def save_email_for_debugging(msg: Message, file_name_prefix=None) -> str:
"""Save email for debugging to temporary location """Save email for debugging to temporary location
Return the file path Return the file path
""" """
if config.TEMP_DIR: if TEMP_DIR:
file_name = str(uuid.uuid4()) + ".eml" file_name = str(uuid.uuid4()) + ".eml"
if file_name_prefix: if file_name_prefix:
file_name = "{}-{}".format(file_name_prefix, file_name) file_name = "{}-{}".format(file_name_prefix, file_name)
with open(os.path.join(config.TEMP_DIR, file_name), "wb") as f: with open(os.path.join(TEMP_DIR, file_name), "wb") as f:
f.write(msg.as_bytes()) f.write(msg.as_bytes())
LOG.d("email saved to %s", file_name) LOG.d("email saved to %s", file_name)
@ -1421,12 +1373,12 @@ def save_envelope_for_debugging(envelope: Envelope, file_name_prefix=None) -> st
"""Save envelope for debugging to temporary location """Save envelope for debugging to temporary location
Return the file path Return the file path
""" """
if config.TEMP_DIR: if TEMP_DIR:
file_name = str(uuid.uuid4()) + ".eml" file_name = str(uuid.uuid4()) + ".eml"
if file_name_prefix: if file_name_prefix:
file_name = "{}-{}".format(file_name_prefix, file_name) file_name = "{}-{}".format(file_name_prefix, file_name)
with open(os.path.join(config.TEMP_DIR, file_name), "wb") as f: with open(os.path.join(TEMP_DIR, file_name), "wb") as f:
f.write(envelope.original_content) f.write(envelope.original_content)
LOG.d("envelope saved to %s", file_name) LOG.d("envelope saved to %s", file_name)
@ -1445,22 +1397,19 @@ def generate_verp_email(
# Time is in minutes granularity and start counting on 2022-01-01 to reduce bytes to represent time # Time is in minutes granularity and start counting on 2022-01-01 to reduce bytes to represent time
data = [ data = [
verp_type.value, verp_type.value,
object_id or 0, object_id,
int((time.time() - VERP_TIME_START) / 60), int((time.time() - VERP_TIME_START) / 60),
] ]
json_payload = json.dumps(data).encode("utf-8") json_payload = json.dumps(data).encode("utf-8")
# Signing without itsdangereous because it uses base64 that includes +/= symbols and lower and upper case letters. # Signing without itsdangereous because it uses base64 that includes +/= symbols and lower and upper case letters.
# We need to encode in base32 # We need to encode in base32
payload_hmac = hmac.new( payload_hmac = hmac.new(
config.VERP_EMAIL_SECRET.encode("utf-8"), json_payload, VERP_HMAC_ALGO VERP_EMAIL_SECRET.encode("utf-8"), json_payload, VERP_HMAC_ALGO
).digest()[:8] ).digest()[:8]
encoded_payload = base64.b32encode(json_payload).rstrip(b"=").decode("utf-8") encoded_payload = base64.b32encode(json_payload).rstrip(b"=").decode("utf-8")
encoded_signature = base64.b32encode(payload_hmac).rstrip(b"=").decode("utf-8") encoded_signature = base64.b32encode(payload_hmac).rstrip(b"=").decode("utf-8")
return "{}.{}.{}@{}".format( return "{}.{}.{}@{}".format(
config.VERP_PREFIX, VERP_PREFIX, encoded_payload, encoded_signature, sender_domain or EMAIL_DOMAIN
encoded_payload,
encoded_signature,
sender_domain or config.EMAIL_DOMAIN,
).lower() ).lower()
@ -1473,7 +1422,7 @@ def get_verp_info_from_email(email: str) -> Optional[Tuple[VerpType, int]]:
return None return None
username = email[:idx] username = email[:idx]
fields = username.split(".") fields = username.split(".")
if len(fields) != 3 or fields[0] != config.VERP_PREFIX: if len(fields) != 3 or fields[0] != VERP_PREFIX:
return None return None
try: try:
padding = (8 - (len(fields[1]) % 8)) % 8 padding = (8 - (len(fields[1]) % 8)) % 8
@ -1485,7 +1434,7 @@ def get_verp_info_from_email(email: str) -> Optional[Tuple[VerpType, int]]:
except binascii.Error: except binascii.Error:
return None return None
expected_signature = hmac.new( expected_signature = hmac.new(
config.VERP_EMAIL_SECRET.encode("utf-8"), payload, VERP_HMAC_ALGO VERP_EMAIL_SECRET.encode("utf-8"), payload, VERP_HMAC_ALGO
).digest()[:8] ).digest()[:8]
if expected_signature != signature: if expected_signature != signature:
return None return None
@ -1493,13 +1442,6 @@ def get_verp_info_from_email(email: str) -> Optional[Tuple[VerpType, int]]:
# verp type, object_id, time # verp type, object_id, time
if len(data) != 3: if len(data) != 3:
return None return None
if data[2] > (time.time() + config.VERP_MESSAGE_LIFETIME - VERP_TIME_START) / 60: if data[2] > (time.time() + VERP_MESSAGE_LIFETIME - VERP_TIME_START) / 60:
return None return None
return VerpType(data[0]), data[1] return VerpType(data[0]), data[1]
def sl_formataddr(name_address_tuple: Tuple[str, str]):
"""Same as formataddr but use utf-8 encoding by default and always return str (and never Header)"""
name, addr = name_address_tuple
# formataddr can return Header, make sure to convert to str
return str(formataddr((name, Header(addr, "utf-8"))))

View File

@ -1,38 +0,0 @@
from email_validator import (
validate_email,
EmailNotValidError,
)
from app.utils import convert_to_id
# allow also + and @ that are present in a reply address
_ALLOWED_CHARS = "abcdefghijklmnopqrstuvwxyzABCDEFGHIJKLMNOPQRSTUVWXYZ0123456789_-.+@"
def is_valid_email(email_address: str) -> bool:
"""
Used to check whether an email address is valid
NOT run MX check.
NOT allow unicode.
"""
try:
validate_email(email_address, check_deliverability=False, allow_smtputf8=False)
return True
except EmailNotValidError:
return False
def normalize_reply_email(reply_email: str) -> str:
"""Handle the case where reply email contains *strange* char that was wrongly generated in the past"""
if not reply_email.isascii():
reply_email = convert_to_id(reply_email)
ret = []
# drop all control characters like shift, separator, etc
for c in reply_email:
if c not in _ALLOWED_CHARS:
ret.append("_")
else:
ret.append(c)
return "".join(ret)

View File

@ -71,7 +71,7 @@ class ErrContactErrorUpgradeNeeded(SLException):
"""raised when user cannot create a contact because the plan doesn't allow it""" """raised when user cannot create a contact because the plan doesn't allow it"""
def error_for_user(self) -> str: def error_for_user(self) -> str:
return "Please upgrade to premium to create reverse-alias" return f"Please upgrade to premium to create reverse-alias"
class ErrAddressInvalid(SLException): class ErrAddressInvalid(SLException):
@ -84,14 +84,6 @@ class ErrAddressInvalid(SLException):
return f"{self.address} is not a valid email address" return f"{self.address} is not a valid email address"
class InvalidContactEmailError(SLException):
def __init__(self, website_email: str): # noqa: F821
self.website_email = website_email
def error_for_user(self) -> str:
return f"Cannot create contact with invalid email {self.website_email}"
class ErrContactAlreadyExists(SLException): class ErrContactAlreadyExists(SLException):
"""raised when a contact already exists""" """raised when a contact already exists"""
@ -116,15 +108,3 @@ class AccountAlreadyLinkedToAnotherPartnerException(LinkException):
class AccountAlreadyLinkedToAnotherUserException(LinkException): class AccountAlreadyLinkedToAnotherUserException(LinkException):
def __init__(self): def __init__(self):
super().__init__("This account is linked to another user") super().__init__("This account is linked to another user")
class AccountIsUsingAliasAsEmail(LinkException):
def __init__(self):
super().__init__("Your account has an alias as it's email address")
class ProtonAccountNotVerified(LinkException):
def __init__(self):
super().__init__(
"The Proton account you are trying to use has not been verified"
)

View File

View File

@ -9,7 +9,6 @@ class LoginEvent:
failed = 1 failed = 1
disabled_login = 2 disabled_login = 2
not_activated = 3 not_activated = 3
scheduled_to_be_deleted = 4
class Source(EnumE): class Source(EnumE):
web = 0 web = 0

View File

@ -1,95 +0,0 @@
from abc import ABC, abstractmethod
import newrelic.agent
from app import config
from app.db import Session
from app.errors import ProtonPartnerNotSetUp
from app.events.generated import event_pb2
from app.log import LOG
from app.models import User, PartnerUser, SyncEvent
from app.proton.utils import get_proton_partner
from typing import Optional
NOTIFICATION_CHANNEL = "simplelogin_sync_events"
class Dispatcher(ABC):
@abstractmethod
def send(self, event: bytes):
pass
class PostgresDispatcher(Dispatcher):
def send(self, event: bytes):
instance = SyncEvent.create(content=event, flush=True)
Session.execute(f"NOTIFY {NOTIFICATION_CHANNEL}, '{instance.id}';")
@staticmethod
def get():
return PostgresDispatcher()
class GlobalDispatcher:
__dispatcher: Optional[Dispatcher] = None
@staticmethod
def get_dispatcher() -> Dispatcher:
if not GlobalDispatcher.__dispatcher:
GlobalDispatcher.__dispatcher = PostgresDispatcher.get()
return GlobalDispatcher.__dispatcher
@staticmethod
def set_dispatcher(dispatcher: Optional[Dispatcher]):
GlobalDispatcher.__dispatcher = dispatcher
class EventDispatcher:
@staticmethod
def send_event(
user: User,
content: event_pb2.EventContent,
dispatcher: Optional[Dispatcher] = None,
skip_if_webhook_missing: bool = True,
):
if dispatcher is None:
dispatcher = GlobalDispatcher.get_dispatcher()
if config.EVENT_WEBHOOK_DISABLE:
LOG.i("Not sending events because webhook is disabled")
return
if not config.EVENT_WEBHOOK and skip_if_webhook_missing:
LOG.i(
"Not sending events because webhook is not configured and allowed to be empty"
)
return
partner_user = EventDispatcher.__partner_user(user.id)
if not partner_user:
LOG.i(f"Not sending events because there's no partner user for user {user}")
return
event = event_pb2.Event(
user_id=user.id,
external_user_id=partner_user.external_user_id,
partner_id=partner_user.partner_id,
content=content,
)
serialized = event.SerializeToString()
dispatcher.send(serialized)
event_type = content.WhichOneof("content")
newrelic.agent.record_custom_event("EventStoredToDb", {"type": event_type})
LOG.i("Sent event to the dispatcher")
@staticmethod
def __partner_user(user_id: int) -> Optional[PartnerUser]:
# Check if the current user has a partner_id
try:
proton_partner_id = get_proton_partner().id
except ProtonPartnerNotSetUp:
return None
# It has. Retrieve the information for the PartnerUser
return PartnerUser.get_by(user_id=user_id, partner_id=proton_partner_id)

View File

@ -1,50 +0,0 @@
# -*- coding: utf-8 -*-
# Generated by the protocol buffer compiler. DO NOT EDIT!
# NO CHECKED-IN PROTOBUF GENCODE
# source: event.proto
# Protobuf Python Version: 5.27.0
"""Generated protocol buffer code."""
from google.protobuf import descriptor as _descriptor
from google.protobuf import descriptor_pool as _descriptor_pool
from google.protobuf import runtime_version as _runtime_version
from google.protobuf import symbol_database as _symbol_database
from google.protobuf.internal import builder as _builder
_runtime_version.ValidateProtobufRuntimeVersion(
_runtime_version.Domain.PUBLIC,
5,
27,
0,
'',
'event.proto'
)
# @@protoc_insertion_point(imports)
_sym_db = _symbol_database.Default()
DESCRIPTOR = _descriptor_pool.Default().AddSerializedFile(b'\n\x0b\x65vent.proto\x12\x12simplelogin_events\"(\n\x0fUserPlanChanged\x12\x15\n\rplan_end_time\x18\x01 \x01(\r\"\r\n\x0bUserDeleted\"\\\n\x0c\x41liasCreated\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0c\n\x04note\x18\x03 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x04 \x01(\x08\x12\x12\n\ncreated_at\x18\x05 \x01(\r\"T\n\x12\x41liasStatusChanged\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\x12\x0f\n\x07\x65nabled\x18\x03 \x01(\x08\x12\x12\n\ncreated_at\x18\x04 \x01(\r\")\n\x0c\x41liasDeleted\x12\n\n\x02id\x18\x01 \x01(\r\x12\r\n\x05\x65mail\x18\x02 \x01(\t\"D\n\x10\x41liasCreatedList\x12\x30\n\x06\x65vents\x18\x01 \x03(\x0b\x32 .simplelogin_events.AliasCreated\"\x93\x03\n\x0c\x45ventContent\x12?\n\x10user_plan_change\x18\x01 \x01(\x0b\x32#.simplelogin_events.UserPlanChangedH\x00\x12\x37\n\x0cuser_deleted\x18\x02 \x01(\x0b\x32\x1f.simplelogin_events.UserDeletedH\x00\x12\x39\n\ralias_created\x18\x03 \x01(\x0b\x32 .simplelogin_events.AliasCreatedH\x00\x12\x45\n\x13\x61lias_status_change\x18\x04 \x01(\x0b\x32&.simplelogin_events.AliasStatusChangedH\x00\x12\x39\n\ralias_deleted\x18\x05 \x01(\x0b\x32 .simplelogin_events.AliasDeletedH\x00\x12\x41\n\x11\x61lias_create_list\x18\x06 \x01(\x0b\x32$.simplelogin_events.AliasCreatedListH\x00\x42\t\n\x07\x63ontent\"y\n\x05\x45vent\x12\x0f\n\x07user_id\x18\x01 \x01(\r\x12\x18\n\x10\x65xternal_user_id\x18\x02 \x01(\t\x12\x12\n\npartner_id\x18\x03 \x01(\r\x12\x31\n\x07\x63ontent\x18\x04 \x01(\x0b\x32 .simplelogin_events.EventContentb\x06proto3')
_globals = globals()
_builder.BuildMessageAndEnumDescriptors(DESCRIPTOR, _globals)
_builder.BuildTopDescriptorsAndMessages(DESCRIPTOR, 'event_pb2', _globals)
if not _descriptor._USE_C_DESCRIPTORS:
DESCRIPTOR._loaded_options = None
_globals['_USERPLANCHANGED']._serialized_start=35
_globals['_USERPLANCHANGED']._serialized_end=75
_globals['_USERDELETED']._serialized_start=77
_globals['_USERDELETED']._serialized_end=90
_globals['_ALIASCREATED']._serialized_start=92
_globals['_ALIASCREATED']._serialized_end=184
_globals['_ALIASSTATUSCHANGED']._serialized_start=186
_globals['_ALIASSTATUSCHANGED']._serialized_end=270
_globals['_ALIASDELETED']._serialized_start=272
_globals['_ALIASDELETED']._serialized_end=313
_globals['_ALIASCREATEDLIST']._serialized_start=315
_globals['_ALIASCREATEDLIST']._serialized_end=383
_globals['_EVENTCONTENT']._serialized_start=386
_globals['_EVENTCONTENT']._serialized_end=789
_globals['_EVENT']._serialized_start=791
_globals['_EVENT']._serialized_end=912
# @@protoc_insertion_point(module_scope)

View File

@ -1,84 +0,0 @@
from google.protobuf.internal import containers as _containers
from google.protobuf import descriptor as _descriptor
from google.protobuf import message as _message
from typing import ClassVar as _ClassVar, Iterable as _Iterable, Mapping as _Mapping, Optional as _Optional, Union as _Union
DESCRIPTOR: _descriptor.FileDescriptor
class UserPlanChanged(_message.Message):
__slots__ = ("plan_end_time",)
PLAN_END_TIME_FIELD_NUMBER: _ClassVar[int]
plan_end_time: int
def __init__(self, plan_end_time: _Optional[int] = ...) -> None: ...
class UserDeleted(_message.Message):
__slots__ = ()
def __init__(self) -> None: ...
class AliasCreated(_message.Message):
__slots__ = ("id", "email", "note", "enabled", "created_at")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
NOTE_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int]
CREATED_AT_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
note: str
enabled: bool
created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., note: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasStatusChanged(_message.Message):
__slots__ = ("id", "email", "enabled", "created_at")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
ENABLED_FIELD_NUMBER: _ClassVar[int]
CREATED_AT_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
enabled: bool
created_at: int
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ..., enabled: bool = ..., created_at: _Optional[int] = ...) -> None: ...
class AliasDeleted(_message.Message):
__slots__ = ("id", "email")
ID_FIELD_NUMBER: _ClassVar[int]
EMAIL_FIELD_NUMBER: _ClassVar[int]
id: int
email: str
def __init__(self, id: _Optional[int] = ..., email: _Optional[str] = ...) -> None: ...
class AliasCreatedList(_message.Message):
__slots__ = ("events",)
EVENTS_FIELD_NUMBER: _ClassVar[int]
events: _containers.RepeatedCompositeFieldContainer[AliasCreated]
def __init__(self, events: _Optional[_Iterable[_Union[AliasCreated, _Mapping]]] = ...) -> None: ...
class EventContent(_message.Message):
__slots__ = ("user_plan_change", "user_deleted", "alias_created", "alias_status_change", "alias_deleted", "alias_create_list")
USER_PLAN_CHANGE_FIELD_NUMBER: _ClassVar[int]
USER_DELETED_FIELD_NUMBER: _ClassVar[int]
ALIAS_CREATED_FIELD_NUMBER: _ClassVar[int]
ALIAS_STATUS_CHANGE_FIELD_NUMBER: _ClassVar[int]
ALIAS_DELETED_FIELD_NUMBER: _ClassVar[int]
ALIAS_CREATE_LIST_FIELD_NUMBER: _ClassVar[int]
user_plan_change: UserPlanChanged
user_deleted: UserDeleted
alias_created: AliasCreated
alias_status_change: AliasStatusChanged
alias_deleted: AliasDeleted
alias_create_list: AliasCreatedList
def __init__(self, user_plan_change: _Optional[_Union[UserPlanChanged, _Mapping]] = ..., user_deleted: _Optional[_Union[UserDeleted, _Mapping]] = ..., alias_created: _Optional[_Union[AliasCreated, _Mapping]] = ..., alias_status_change: _Optional[_Union[AliasStatusChanged, _Mapping]] = ..., alias_deleted: _Optional[_Union[AliasDeleted, _Mapping]] = ..., alias_create_list: _Optional[_Union[AliasCreatedList, _Mapping]] = ...) -> None: ...
class Event(_message.Message):
__slots__ = ("user_id", "external_user_id", "partner_id", "content")
USER_ID_FIELD_NUMBER: _ClassVar[int]
EXTERNAL_USER_ID_FIELD_NUMBER: _ClassVar[int]
PARTNER_ID_FIELD_NUMBER: _ClassVar[int]
CONTENT_FIELD_NUMBER: _ClassVar[int]
user_id: int
external_user_id: str
partner_id: int
content: EventContent
def __init__(self, user_id: _Optional[int] = ..., external_user_id: _Optional[str] = ..., partner_id: _Optional[int] = ..., content: _Optional[_Union[EventContent, _Mapping]] = ...) -> None: ...

View File

@ -1,31 +1,12 @@
from flask_limiter import Limiter from flask_limiter import Limiter
from flask_limiter.util import get_remote_address from flask_limiter.util import get_remote_address
from flask_login import current_user, LoginManager from flask_login import LoginManager
from app import config
login_manager = LoginManager() login_manager = LoginManager()
login_manager.session_protection = "strong" login_manager.session_protection = "strong"
# We want to rate limit based on:
# - If the user is not logged in: request source IP
# - If the user is logged in: user_id
def __key_func():
if current_user.is_authenticated:
return f"userid:{current_user.id}"
else:
ip_addr = get_remote_address()
return f"ip:{ip_addr}"
# Setup rate limit facility # Setup rate limit facility
limiter = Limiter(key_func=__key_func) limiter = Limiter(key_func=get_remote_address)
@limiter.request_filter
def disable_rate_limit():
return config.DISABLE_RATE_LIMIT
# @limiter.request_filter # @limiter.request_filter

View File

@ -5,7 +5,7 @@ from typing import Optional, Tuple
from aiosmtpd.handlers import Message from aiosmtpd.handlers import Message
from aiosmtpd.smtp import Envelope from aiosmtpd.smtp import Envelope
from app import s3, config from app import s3
from app.config import ( from app.config import (
DMARC_CHECK_ENABLED, DMARC_CHECK_ENABLED,
ALERT_QUARANTINE_DMARC, ALERT_QUARANTINE_DMARC,
@ -30,44 +30,10 @@ def apply_dmarc_policy_for_forward_phase(
) -> Tuple[Message, Optional[str]]: ) -> Tuple[Message, Optional[str]]:
spam_result = SpamdResult.extract_from_headers(msg, Phase.forward) spam_result = SpamdResult.extract_from_headers(msg, Phase.forward)
if not DMARC_CHECK_ENABLED or not spam_result: if not DMARC_CHECK_ENABLED or not spam_result:
LOG.i("DMARC check disabled")
return msg, None return msg, None
LOG.i(f"Spam check result in {spam_result}")
from_header = get_header_unicode(msg[headers.FROM]) from_header = get_header_unicode(msg[headers.FROM])
warning_plain_text = """This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
More info on https://simplelogin.io/docs/getting-started/anti-phishing/
"""
warning_html = """
<p style="color:red">
This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
More info on <a href="https://simplelogin.io/docs/getting-started/anti-phishing/">anti-phishing measure</a>
</p>
"""
# do not quarantine an email if fails DMARC but has a small rspamd score
if (
config.MIN_RSPAMD_SCORE_FOR_FAILED_DMARC is not None
and spam_result.rspamd_score < config.MIN_RSPAMD_SCORE_FOR_FAILED_DMARC
and spam_result.dmarc
in (
DmarcCheckResult.quarantine,
DmarcCheckResult.reject,
)
):
LOG.w(
f"email fails DMARC but has a small rspamd score, from contact {contact.email} to alias {alias.email}."
f"mail_from:{envelope.mail_from}, from_header: {from_header}"
)
changed_msg = add_header(
msg,
warning_plain_text,
warning_html,
subject_prefix="[Possible phishing attempt]",
)
return changed_msg, None
if spam_result.dmarc == DmarcCheckResult.soft_fail: if spam_result.dmarc == DmarcCheckResult.soft_fail:
LOG.w( LOG.w(
f"dmarc forward: soft_fail from contact {contact.email} to alias {alias.email}." f"dmarc forward: soft_fail from contact {contact.email} to alias {alias.email}."
@ -75,9 +41,15 @@ More info on https://simplelogin.io/docs/getting-started/anti-phishing/
) )
changed_msg = add_header( changed_msg = add_header(
msg, msg,
warning_plain_text, f"""This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
warning_html, More info on https://simplelogin.io/docs/getting-started/anti-phishing/
subject_prefix="[Possible phishing attempt]", """,
f"""
<p style="color:red">
This email failed anti-phishing checks when it was received by SimpleLogin, be careful with its content.
More info on <a href="https://simplelogin.io/docs/getting-started/anti-phishing/">anti-phishing measure</a>
</p>
""",
) )
return changed_msg, None return changed_msg, None
@ -106,14 +78,12 @@ More info on https://simplelogin.io/docs/getting-started/anti-phishing/
f"An email sent to {alias.email} has been quarantined", f"An email sent to {alias.email} has been quarantined",
render( render(
"transactional/message-quarantine-dmarc.txt.jinja2", "transactional/message-quarantine-dmarc.txt.jinja2",
user=user,
from_header=from_header, from_header=from_header,
alias=alias, alias=alias,
refused_email_url=email_log.get_dashboard_url(), refused_email_url=email_log.get_dashboard_url(),
), ),
render( render(
"transactional/message-quarantine-dmarc.html", "transactional/message-quarantine-dmarc.html",
user=user,
from_header=from_header, from_header=from_header,
alias=alias, alias=alias,
refused_email_url=email_log.get_dashboard_url(), refused_email_url=email_log.get_dashboard_url(),
@ -137,7 +107,7 @@ def quarantine_dmarc_failed_forward_email(alias, contact, envelope, msg) -> Emai
refused_email = RefusedEmail.create( refused_email = RefusedEmail.create(
full_report_path=s3_report_path, user_id=alias.user_id, flush=True full_report_path=s3_report_path, user_id=alias.user_id, flush=True
) )
email_log = EmailLog.create( return EmailLog.create(
user_id=alias.user_id, user_id=alias.user_id,
mailbox_id=alias.mailbox_id, mailbox_id=alias.mailbox_id,
contact_id=contact.id, contact_id=contact.id,
@ -148,7 +118,6 @@ def quarantine_dmarc_failed_forward_email(alias, contact, envelope, msg) -> Emai
blocked=True, blocked=True,
commit=True, commit=True,
) )
return email_log
def apply_dmarc_policy_for_reply_phase( def apply_dmarc_policy_for_reply_phase(
@ -156,17 +125,14 @@ def apply_dmarc_policy_for_reply_phase(
) -> Optional[str]: ) -> Optional[str]:
spam_result = SpamdResult.extract_from_headers(msg, Phase.reply) spam_result = SpamdResult.extract_from_headers(msg, Phase.reply)
if not DMARC_CHECK_ENABLED or not spam_result: if not DMARC_CHECK_ENABLED or not spam_result:
LOG.i("DMARC check disabled")
return None return None
LOG.i(f"Spam check result is {spam_result}")
if spam_result.dmarc not in ( if spam_result.dmarc not in (
DmarcCheckResult.quarantine, DmarcCheckResult.quarantine,
DmarcCheckResult.reject, DmarcCheckResult.reject,
DmarcCheckResult.soft_fail, DmarcCheckResult.soft_fail,
): ):
return None return None
LOG.w( LOG.w(
f"dmarc reply: Put email from {alias_from.email} to {contact_recipient} into quarantine. {spam_result.event_data()}, " f"dmarc reply: Put email from {alias_from.email} to {contact_recipient} into quarantine. {spam_result.event_data()}, "
f"mail_from:{envelope.mail_from}, from_header: {msg[headers.FROM]}" f"mail_from:{envelope.mail_from}, from_header: {msg[headers.FROM]}"
@ -178,14 +144,12 @@ def apply_dmarc_policy_for_reply_phase(
f"Attempt to send an email to your contact {contact_recipient.email} from {envelope.mail_from}", f"Attempt to send an email to your contact {contact_recipient.email} from {envelope.mail_from}",
render( render(
"transactional/spoof-reply.txt.jinja2", "transactional/spoof-reply.txt.jinja2",
user=alias_from.user,
contact=contact_recipient, contact=contact_recipient,
alias=alias_from, alias=alias_from,
sender=envelope.mail_from, sender=envelope.mail_from,
), ),
render( render(
"transactional/spoof-reply.html", "transactional/spoof-reply.html",
user=alias_from.user,
contact=contact_recipient, contact=contact_recipient,
alias=alias_from, alias=alias_from,
sender=envelope.mail_from, sender=envelope.mail_from,

View File

@ -3,7 +3,7 @@ from abc import ABC, abstractmethod
from dataclasses import dataclass from dataclasses import dataclass
from io import BytesIO from io import BytesIO
from mailbox import Message from mailbox import Message
from typing import Optional, Union from typing import Optional
from app import s3 from app import s3
from app.config import ( from app.config import (
@ -189,7 +189,7 @@ def handle_yahoo_complaint(message: Message) -> bool:
return handle_complaint(message, ProviderComplaintYahoo()) return handle_complaint(message, ProviderComplaintYahoo())
def find_alias_with_address(address: str) -> Optional[Union[Alias, DomainDeletedAlias]]: def find_alias_with_address(address: str) -> Optional[Alias]:
return Alias.get_by(email=address) or DomainDeletedAlias.get_by(email=address) return Alias.get_by(email=address) or DomainDeletedAlias.get_by(email=address)
@ -221,7 +221,7 @@ def handle_complaint(message: Message, origin: ProviderComplaintOrigin) -> bool:
return True return True
if is_deleted_alias(msg_info.sender_address): if is_deleted_alias(msg_info.sender_address):
LOG.i("Complaint is for deleted alias. Do nothing") LOG.i(f"Complaint is for deleted alias. Do nothing")
return True return True
contact = Contact.get_by(reply_email=msg_info.sender_address) contact = Contact.get_by(reply_email=msg_info.sender_address)
@ -231,7 +231,7 @@ def handle_complaint(message: Message, origin: ProviderComplaintOrigin) -> bool:
alias = find_alias_with_address(msg_info.rcpt_address) alias = find_alias_with_address(msg_info.rcpt_address)
if is_deleted_alias(msg_info.rcpt_address): if is_deleted_alias(msg_info.rcpt_address):
LOG.i("Complaint is for deleted alias. Do nothing") LOG.i(f"Complaint is for deleted alias. Do nothing")
return True return True
if not alias: if not alias:
@ -245,22 +245,16 @@ def handle_complaint(message: Message, origin: ProviderComplaintOrigin) -> bool:
def report_complaint_to_user_in_reply_phase( def report_complaint_to_user_in_reply_phase(
alias: Union[Alias, DomainDeletedAlias], alias: Alias,
to_address: str, to_address: str,
origin: ProviderComplaintOrigin, origin: ProviderComplaintOrigin,
msg_info: OriginalMessageInformation, msg_info: OriginalMessageInformation,
): ):
capitalized_name = origin.name().capitalize() capitalized_name = origin.name().capitalize()
mailbox_email = msg_info.mailbox_address
if not mailbox_email:
if type(alias) is Alias:
mailbox_email = alias.mailbox.email
else:
mailbox_email = alias.domain.mailboxes[0].email
send_email_with_rate_control( send_email_with_rate_control(
alias.user, alias.user,
f"{ALERT_COMPLAINT_REPLY_PHASE}_{origin.name()}", f"{ALERT_COMPLAINT_REPLY_PHASE}_{origin.name()}",
mailbox_email, msg_info.mailbox_address or alias.mailbox.email,
f"Abuse report from {capitalized_name}", f"Abuse report from {capitalized_name}",
render( render(
"transactional/provider-complaint-reply-phase.txt.jinja2", "transactional/provider-complaint-reply-phase.txt.jinja2",
@ -299,19 +293,11 @@ def report_complaint_to_user_in_transactional_phase(
def report_complaint_to_user_in_forward_phase( def report_complaint_to_user_in_forward_phase(
alias: Union[Alias, DomainDeletedAlias], alias: Alias, origin: ProviderComplaintOrigin, msg_info: OriginalMessageInformation
origin: ProviderComplaintOrigin,
msg_info: OriginalMessageInformation,
): ):
capitalized_name = origin.name().capitalize() capitalized_name = origin.name().capitalize()
user = alias.user user = alias.user
mailbox_email = msg_info.mailbox_address or alias.mailbox.email
mailbox_email = msg_info.mailbox_address
if not mailbox_email:
if type(alias) is Alias:
mailbox_email = alias.mailbox.email
else:
mailbox_email = alias.domain.mailboxes[0].email
send_email_with_rate_control( send_email_with_rate_control(
user, user,
f"{ALERT_COMPLAINT_FORWARD_PHASE}_{origin.name()}", f"{ALERT_COMPLAINT_FORWARD_PHASE}_{origin.name()}",
@ -319,13 +305,11 @@ def report_complaint_to_user_in_forward_phase(
f"Abuse report from {capitalized_name}", f"Abuse report from {capitalized_name}",
render( render(
"transactional/provider-complaint-forward-phase.txt.jinja2", "transactional/provider-complaint-forward-phase.txt.jinja2",
user=user,
email=mailbox_email, email=mailbox_email,
provider=capitalized_name, provider=capitalized_name,
), ),
render( render(
"transactional/provider-complaint-forward-phase.html", "transactional/provider-complaint-forward-phase.html",
user=user,
email=mailbox_email, email=mailbox_email,
provider=capitalized_name, provider=capitalized_name,
), ),

View File

@ -4,7 +4,6 @@ from typing import Dict, Optional
import newrelic.agent import newrelic.agent
from app.email import headers from app.email import headers
from app.log import LOG
from app.models import EnumE, Phase from app.models import EnumE, Phase
from email.message import Message from email.message import Message
@ -56,7 +55,6 @@ class SpamdResult:
self.phase: Phase = phase self.phase: Phase = phase
self.dmarc: DmarcCheckResult = DmarcCheckResult.not_available self.dmarc: DmarcCheckResult = DmarcCheckResult.not_available
self.spf: SPFCheckResult = SPFCheckResult.not_available self.spf: SPFCheckResult = SPFCheckResult.not_available
self.rspamd_score = -1
def set_dmarc_result(self, dmarc_result: DmarcCheckResult): def set_dmarc_result(self, dmarc_result: DmarcCheckResult):
self.dmarc = dmarc_result self.dmarc = dmarc_result
@ -87,7 +85,6 @@ class SpamdResult:
spam_entries = [ spam_entries = [
entry.strip() for entry in str(spam_result_header[-1]).split("\n") entry.strip() for entry in str(spam_result_header[-1]).split("\n")
] ]
for entry_pos in range(len(spam_entries)): for entry_pos in range(len(spam_entries)):
sep = spam_entries[entry_pos].find("(") sep = spam_entries[entry_pos].find("(")
if sep > -1: if sep > -1:
@ -104,17 +101,6 @@ class SpamdResult:
spamd_result.set_spf_result(spf_result) spamd_result.set_spf_result(spf_result)
break break
# parse the rspamd score
try:
score_line = spam_entries[0] # e.g. "default: False [2.30 / 13.00];"
spamd_result.rspamd_score = float(
score_line[(score_line.find("[") + 1) : score_line.find("]")]
.split("/")[0]
.strip()
)
except (IndexError, ValueError):
LOG.e("cannot parse rspamd score")
cls._store_in_message(spamd_result, msg) cls._store_in_message(spamd_result, msg)
return spamd_result return spamd_result

View File

@ -1,36 +1,20 @@
import base64
import enum import enum
import hashlib
import json
from dataclasses import dataclass from dataclasses import dataclass
from typing import Optional, Union from typing import Optional
import itsdangerous
from app import config from app import config
from app.log import LOG
UNSUB_PREFIX = "un"
class UnsubscribeAction(enum.Enum): class UnsubscribeAction(enum.Enum):
UnsubscribeNewsletter = 1 UnsubscribeNewsletter = 1
DisableAlias = 2 DisableAlias = 2
DisableContact = 3 DisableContact = 3
OriginalUnsubscribeMailto = 4
@dataclass
class UnsubscribeOriginalData:
alias_id: int
recipient: str
subject: str
@dataclass @dataclass
class UnsubscribeData: class UnsubscribeData:
action: UnsubscribeAction action: UnsubscribeAction
data: Union[UnsubscribeOriginalData, int] data: int
@dataclass @dataclass
@ -41,110 +25,52 @@ class UnsubscribeLink:
class UnsubscribeEncoder: class UnsubscribeEncoder:
@staticmethod @staticmethod
def encode( def encode(action: UnsubscribeAction, data: int) -> UnsubscribeLink:
action: UnsubscribeAction, if config.UNSUBSCRIBER:
data: Union[int, UnsubscribeOriginalData],
force_web: bool = False,
) -> UnsubscribeLink:
if config.UNSUBSCRIBER and not force_web:
return UnsubscribeLink(UnsubscribeEncoder.encode_mailto(action, data), True) return UnsubscribeLink(UnsubscribeEncoder.encode_mailto(action, data), True)
return UnsubscribeLink(UnsubscribeEncoder.encode_url(action, data), False) return UnsubscribeLink(UnsubscribeEncoder.encode_url(action, data), False)
@classmethod @staticmethod
def encode_subject( def encode_subject(action: UnsubscribeAction, data: int) -> str:
cls, action: UnsubscribeAction, data: Union[int, UnsubscribeOriginalData] if action == UnsubscribeAction.DisableAlias:
) -> str: return f"{data}="
if action != UnsubscribeAction.OriginalUnsubscribeMailto and not isinstance( if action == UnsubscribeAction.DisableContact:
data, int return f"{data}_"
): if action == UnsubscribeAction.UnsubscribeNewsletter:
raise ValueError(f"Data has to be an int for an action of type {action}") return f"{data}*"
if action == UnsubscribeAction.OriginalUnsubscribeMailto:
if type(data) is not UnsubscribeOriginalData:
raise ValueError(
f"Data has to be an UnsubscribeOriginalData for an action of type {action}"
)
# Initial 0 is the version number. If we need to add support for extra use-cases we can bump up this number
data = (0, data.alias_id, data.recipient, data.subject)
payload = (action.value, data)
serialized_data = (
base64.urlsafe_b64encode(json.dumps(payload).encode("utf-8"))
.rstrip(b"=")
.decode("utf-8")
)
signed_data = cls._get_signer().sign(serialized_data).decode("utf-8")
encoded_request = f"{UNSUB_PREFIX}.{signed_data}"
if len(encoded_request) > 512:
LOG.w("Encoded request is longer than 512 chars")
return encoded_request
@staticmethod @staticmethod
def encode_mailto( def encode_mailto(action: UnsubscribeAction, data: int) -> str:
action: UnsubscribeAction, data: Union[int, UnsubscribeOriginalData]
) -> str:
subject = UnsubscribeEncoder.encode_subject(action, data) subject = UnsubscribeEncoder.encode_subject(action, data)
return f"mailto:{config.UNSUBSCRIBER}?subject={subject}" return f"mailto:{config.UNSUBSCRIBER}?subject={subject}"
@staticmethod @staticmethod
def encode_url( def encode_url(action: UnsubscribeAction, data: int) -> str:
action: UnsubscribeAction, data: Union[int, UnsubscribeOriginalData]
) -> str:
if action == UnsubscribeAction.DisableAlias: if action == UnsubscribeAction.DisableAlias:
return f"{config.URL}/dashboard/unsubscribe/{data}" return f"{config.URL}/dashboard/unsubscribe/{data}"
if action == UnsubscribeAction.DisableContact: if action == UnsubscribeAction.DisableContact:
return f"{config.URL}/dashboard/block_contact/{data}" return f"{config.URL}/dashboard/block_contact/{data}"
if action in ( if action == UnsubscribeAction.UnsubscribeNewsletter:
UnsubscribeAction.UnsubscribeNewsletter, raise Exception("Cannot encode url to disable newsletter")
UnsubscribeAction.OriginalUnsubscribeMailto,
):
encoded = UnsubscribeEncoder.encode_subject(action, data)
return f"{config.URL}/dashboard/unsubscribe/encoded?data={encoded}"
@staticmethod @staticmethod
def _get_signer() -> itsdangerous.Signer: def decode_subject(data: str) -> Optional[UnsubscribeData]:
return itsdangerous.Signer(
config.UNSUBSCRIBE_SECRET, digest_method=hashlib.sha3_224
)
@classmethod
def decode_subject(cls, data: str) -> Optional[UnsubscribeData]:
if data.find(UNSUB_PREFIX) == -1:
try:
# subject has the format {alias.id}=
if data.endswith("="):
alias_id = int(data[:-1])
return UnsubscribeData(UnsubscribeAction.DisableAlias, alias_id)
# {contact.id}_
elif data.endswith("_"):
contact_id = int(data[:-1])
return UnsubscribeData(UnsubscribeAction.DisableContact, contact_id)
# {user.id}*
elif data.endswith("*"):
user_id = int(data[:-1])
return UnsubscribeData(
UnsubscribeAction.UnsubscribeNewsletter, user_id
)
else:
# some email providers might strip off the = suffix
alias_id = int(data)
return UnsubscribeData(UnsubscribeAction.DisableAlias, alias_id)
except ValueError:
return None
signer = cls._get_signer()
try: try:
verified_data = signer.unsign(data[len(UNSUB_PREFIX) + 1 :]) # subject has the format {alias.id}=
except itsdangerous.BadSignature: if data.endswith("="):
return None alias_id = int(data[:-1])
try: return UnsubscribeData(UnsubscribeAction.DisableAlias, alias_id)
padded_data = verified_data + (b"=" * (-len(verified_data) % 4)) # {contact.id}_
payload = json.loads(base64.urlsafe_b64decode(padded_data)) elif data.endswith("_"):
contact_id = int(data[:-1])
return UnsubscribeData(UnsubscribeAction.DisableContact, contact_id)
# {user.id}*
elif data.endswith("*"):
user_id = int(data[:-1])
return UnsubscribeData(UnsubscribeAction.UnsubscribeNewsletter, user_id)
else:
# some email providers might strip off the = suffix
alias_id = int(data)
return UnsubscribeData(UnsubscribeAction.DisableAlias, alias_id)
except ValueError: except ValueError:
return None return None
action = UnsubscribeAction(payload[0])
action_data = payload[1]
if action == UnsubscribeAction.OriginalUnsubscribeMailto:
# Skip version number in action_data[0] for now it's always 0
action_data = UnsubscribeOriginalData(
action_data[1], action_data[2], action_data[3]
)
return UnsubscribeData(action, action_data)

View File

@ -1,92 +1,30 @@
import urllib
from email.header import Header
from email.message import Message from email.message import Message
from app.email import headers from app.email import headers
from app import config from app.email_utils import add_or_replace_header
from app.email_utils import add_or_replace_header, delete_header
from app.handler.unsubscribe_encoder import ( from app.handler.unsubscribe_encoder import (
UnsubscribeEncoder, UnsubscribeEncoder,
UnsubscribeAction, UnsubscribeAction,
UnsubscribeData,
UnsubscribeOriginalData,
) )
from app.log import LOG from app.models import Alias, Contact
from app.models import Alias, Contact, UnsubscribeBehaviourEnum
class UnsubscribeGenerator: class UnsubscribeGenerator:
def _generate_header_with_original_behaviour( def add_header_to_message(
self, alias: Alias, message: Message self, alias: Alias, contact: Contact, message: Message
) -> Message: ) -> Message:
""" """
Generate a header that will encode the original unsub request. To do so Add List-Unsubscribe header
1. Look if there's an original List_Unsubscribe headers, otherwise do nothing
2. Header has the form <method1>, <method2>, .. where each method is either
- mailto:s@b.c?subject=something
- http(s)://somewhere.com
3. Check if there are http unsub requests in the header. If there are, reserve them and remove all mailto
methods to avoid leaking the real mailbox. We forward the message with only http(s) methods.
4. If there aren't neither https nor mailto methods, strip the header from the message and that's it.
It could happen if the header is malformed.
5. Encode in our unsub request the first original mail and subject to unsub, and use that as our unsub header.
""" """
unsubscribe_data = message[headers.LIST_UNSUBSCRIBE] user = alias.user
if not unsubscribe_data: if user.one_click_unsubscribe_block_sender:
LOG.info("Email has no unsubscribe header") unsub_link = UnsubscribeEncoder.encode(
return message UnsubscribeAction.DisableContact, contact.id
if isinstance(unsubscribe_data, Header):
unsubscribe_data = str(unsubscribe_data.encode())
raw_methods = [method.strip() for method in unsubscribe_data.split(",")]
mailto_unsubs = None
other_unsubs = []
for raw_method in raw_methods:
start = raw_method.find("<")
end = raw_method.rfind(">")
if start == -1 or end == -1 or start >= end:
continue
method = raw_method[start + 1 : end]
url_data = urllib.parse.urlparse(method)
if url_data.scheme == "mailto":
if url_data.path == config.UNSUBSCRIBER:
LOG.debug(
f"Skipping replacing unsubscribe since the original email already points to {config.UNSUBSCRIBER}"
)
return message
query_data = urllib.parse.parse_qs(url_data.query)
mailto_unsubs = (url_data.path, query_data.get("subject", [""])[0])
LOG.debug(f"Unsub is mailto to {mailto_unsubs}")
else:
LOG.debug(f"Unsub has {url_data.scheme} scheme")
other_unsubs.append(method)
# If there are non mailto unsubscribe methods, use those in the header
if other_unsubs:
add_or_replace_header(
message,
headers.LIST_UNSUBSCRIBE,
", ".join([f"<{method}>" for method in other_unsubs]),
) )
add_or_replace_header( else:
message, headers.LIST_UNSUBSCRIBE_POST, "List-Unsubscribe=One-Click" unsub_link = UnsubscribeEncoder.encode(
UnsubscribeAction.DisableAlias, alias.id
) )
LOG.debug(f"Adding click unsub methods to header {other_unsubs}")
return message
elif not mailto_unsubs:
LOG.debug("No unsubs. Deleting all unsub headers")
delete_header(message, headers.LIST_UNSUBSCRIBE)
delete_header(message, headers.LIST_UNSUBSCRIBE_POST)
return message
unsub_data = UnsubscribeData(
UnsubscribeAction.OriginalUnsubscribeMailto,
UnsubscribeOriginalData(alias.id, mailto_unsubs[0], mailto_unsubs[1]),
)
LOG.debug(f"Adding unsub data {unsub_data}")
return self._add_unsubscribe_header(message, unsub_data)
def _add_unsubscribe_header(
self, message: Message, unsub: UnsubscribeData
) -> Message:
unsub_link = UnsubscribeEncoder.encode(unsub.action, unsub.data)
add_or_replace_header(message, headers.LIST_UNSUBSCRIBE, f"<{unsub_link.link}>") add_or_replace_header(message, headers.LIST_UNSUBSCRIBE, f"<{unsub_link.link}>")
if not unsub_link.via_email: if not unsub_link.via_email:
@ -94,19 +32,3 @@ class UnsubscribeGenerator:
message, headers.LIST_UNSUBSCRIBE_POST, "List-Unsubscribe=One-Click" message, headers.LIST_UNSUBSCRIBE_POST, "List-Unsubscribe=One-Click"
) )
return message return message
def add_header_to_message(
self, alias: Alias, contact: Contact, message: Message
) -> Message:
"""
Add List-Unsubscribe header based on the user preference.
"""
unsub_behaviour = alias.user.unsub_behaviour
if unsub_behaviour == UnsubscribeBehaviourEnum.PreserveOriginal:
return self._generate_header_with_original_behaviour(alias, message)
elif unsub_behaviour == UnsubscribeBehaviourEnum.DisableAlias:
unsub = UnsubscribeData(UnsubscribeAction.DisableAlias, alias.id)
return self._add_unsubscribe_header(message, unsub)
else:
unsub = UnsubscribeData(UnsubscribeAction.DisableContact, contact.id)
return self._add_unsubscribe_header(message, unsub)

Some files were not shown because too many files have changed in this diff Show More