Move more contact creation logic to a single function (#2234)

* Move more contact creation logic to a single function

* Reordered parameters

* Fix invalid arguments
This commit is contained in:
Adrià Casajús 2024-09-27 16:04:32 +02:00 committed by GitHub
parent 4762dffd96
commit b59ca3e47c
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
8 changed files with 148 additions and 60 deletions

View file

@ -424,7 +424,7 @@ def create_contact_route(alias_id):
contact_address = data.get("contact") contact_address = data.get("contact")
try: try:
contact = create_contact(g.user, alias, contact_address) contact = create_contact(alias, contact_address)
except ErrContactErrorUpgradeNeeded as err: except ErrContactErrorUpgradeNeeded as err:
return jsonify(error=err.error_for_user()), 403 return jsonify(error=err.error_for_user()), 403
except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err: except (ErrAddressInvalid, CannotCreateContactForReverseAlias) as err:

View file

@ -5,7 +5,7 @@ from typing import Optional
from sqlalchemy.exc import IntegrityError from sqlalchemy.exc import IntegrityError
from app.db import Session from app.db import Session
from app.email_utils import generate_reply_email from app.email_utils import generate_reply_email, parse_full_address
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
from app.log import LOG from app.log import LOG
from app.models import Contact, Alias from app.models import Contact, Alias
@ -14,11 +14,13 @@ from app.utils import sanitize_email
class ContactCreateError(Enum): class ContactCreateError(Enum):
InvalidEmail = "Invalid email" InvalidEmail = "Invalid email"
NotAllowed = "Your plan does not allow to create contacts"
@dataclass @dataclass
class ContactCreateResult: class ContactCreateResult:
contact: Optional[Contact] contact: Optional[Contact]
created: bool
error: Optional[ContactCreateError] error: Optional[ContactCreateError]
@ -33,34 +35,56 @@ def __update_contact_if_needed(
LOG.d(f"Setting {contact} mail_from to {mail_from}") LOG.d(f"Setting {contact} mail_from to {mail_from}")
contact.mail_from = mail_from contact.mail_from = mail_from
Session.commit() Session.commit()
return ContactCreateResult(contact, None) return ContactCreateResult(contact, created=False, error=None)
def create_contact( def create_contact(
email: str, email: str,
name: Optional[str],
alias: Alias, alias: Alias,
name: Optional[str] = None,
mail_from: Optional[str] = None, mail_from: Optional[str] = None,
allow_empty_email: bool = False, allow_empty_email: bool = False,
automatic_created: bool = False, automatic_created: bool = False,
from_partner: bool = False, from_partner: bool = False,
) -> ContactCreateResult: ) -> ContactCreateResult:
if name is not None: # If user cannot create contacts, they still need to be created when receiving an email for an alias
if not automatic_created and not alias.user.can_create_contacts():
return ContactCreateResult(
None, created=False, error=ContactCreateError.NotAllowed
)
# Parse emails with form 'name <email>'
try:
email_name, email = parse_full_address(email)
except ValueError:
email = ""
email_name = ""
# If no name is explicitly given try to get it from the parsed email
if name is None:
name = email_name[: Contact.MAX_NAME_LENGTH]
else:
name = name[: Contact.MAX_NAME_LENGTH] name = name[: Contact.MAX_NAME_LENGTH]
# If still no name is there, make sure the name is None instead of empty string
if not name:
name = None
if name is not None and "\x00" in name: if name is not None and "\x00" in name:
LOG.w("Cannot use contact name because has \\x00") LOG.w("Cannot use contact name because has \\x00")
name = "" name = ""
# Sanitize email and if it's not valid only allow to create a contact if it's explicitly allowed. Otherwise fail
email = sanitize_email(email, not_lower=True)
if not is_valid_email(email): if not is_valid_email(email):
LOG.w(f"invalid contact email {email}") LOG.w(f"invalid contact email {email}")
if not allow_empty_email: if not allow_empty_email:
return ContactCreateResult(None, ContactCreateError.InvalidEmail) return ContactCreateResult(
None, created=False, error=ContactCreateError.InvalidEmail
)
LOG.d("Create a contact with invalid email for %s", alias) LOG.d("Create a contact with invalid email for %s", alias)
# either reuse a contact with empty email or create a new contact with empty email # either reuse a contact with empty email or create a new contact with empty email
email = "" email = ""
email = sanitize_email(email, not_lower=True) # If contact exists, update name and mail_from if needed
contact = Contact.get_by(alias_id=alias.id, website_email=email) contact = Contact.get_by(alias_id=alias.id, website_email=email)
if contact is not None: if contact is not None:
return __update_contact_if_needed(contact, name, mail_from) return __update_contact_if_needed(contact, name, mail_from)
# Create the contact
reply_email = generate_reply_email(email, alias) reply_email = generate_reply_email(email, alias)
try: try:
flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0 flags = Contact.FLAG_PARTNER_CREATED if from_partner else 0
@ -86,4 +110,4 @@ def create_contact(
) )
contact = Contact.get_by(alias_id=alias.id, website_email=email) contact = Contact.get_by(alias_id=alias.id, website_email=email)
return __update_contact_if_needed(contact, name, mail_from) return __update_contact_if_needed(contact, name, mail_from)
return ContactCreateResult(contact, None) return ContactCreateResult(contact, created=True, error=None)

View file

@ -9,13 +9,10 @@ from sqlalchemy import and_, func, case
from wtforms import StringField, validators, ValidationError from wtforms import StringField, validators, ValidationError
# Need to import directly from config to allow modification from the tests # Need to import directly from config to allow modification from the tests
from app import config, parallel_limiter from app import config, parallel_limiter, contact_utils
from app.contact_utils import ContactCreateError
from app.dashboard.base import dashboard_bp from app.dashboard.base import dashboard_bp
from app.db import Session from app.db import Session
from app.email_utils import (
generate_reply_email,
parse_full_address,
)
from app.email_validation import is_valid_email from app.email_validation import is_valid_email
from app.errors import ( from app.errors import (
CannotCreateContactForReverseAlias, CannotCreateContactForReverseAlias,
@ -24,8 +21,8 @@ from app.errors import (
ErrContactAlreadyExists, ErrContactAlreadyExists,
) )
from app.log import LOG from app.log import LOG
from app.models import Alias, Contact, EmailLog, User from app.models import Alias, Contact, EmailLog
from app.utils import sanitize_email, CSRFValidationForm from app.utils import CSRFValidationForm
def email_validator(): def email_validator():
@ -51,7 +48,7 @@ def email_validator():
return _check return _check
def create_contact(user: User, alias: Alias, contact_address: str) -> Contact: def create_contact(alias: Alias, contact_address: str) -> Contact:
""" """
Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS. Create a contact for a user. Can be restricted for new free users by enabling DISABLE_CREATE_CONTACTS_FOR_FREE_USERS.
Can throw exceptions: Can throw exceptions:
@ -61,37 +58,23 @@ def create_contact(user: User, alias: Alias, contact_address: str) -> Contact:
""" """
if not contact_address: if not contact_address:
raise ErrAddressInvalid("Empty address") raise ErrAddressInvalid("Empty address")
try: output = contact_utils.create_contact(email=contact_address, alias=alias)
contact_name, contact_email = parse_full_address(contact_address) if output.error == ContactCreateError.InvalidEmail:
except ValueError:
raise ErrAddressInvalid(contact_address) raise ErrAddressInvalid(contact_address)
elif output.error == ContactCreateError.NotAllowed:
contact_email = sanitize_email(contact_email)
if not is_valid_email(contact_email):
raise ErrAddressInvalid(contact_email)
contact = Contact.get_by(alias_id=alias.id, website_email=contact_email)
if contact:
raise ErrContactAlreadyExists(contact)
if not user.can_create_contacts():
raise ErrContactErrorUpgradeNeeded() raise ErrContactErrorUpgradeNeeded()
elif output.error is not None:
raise ErrAddressInvalid("Invalid address")
elif not output.created:
raise ErrContactAlreadyExists(output.contact)
contact = Contact.create( contact = output.contact
user_id=alias.user_id,
alias_id=alias.id,
website_email=contact_email,
name=contact_name,
reply_email=generate_reply_email(contact_email, alias),
)
LOG.d( LOG.d(
"create reverse-alias for %s %s, reverse alias:%s", "create reverse-alias for %s %s, reverse alias:%s",
contact_address, contact_address,
alias, alias,
contact.reply_email, contact.reply_email,
) )
Session.commit()
return contact return contact
@ -261,7 +244,7 @@ def alias_contact_manager(alias_id):
if new_contact_form.validate(): if new_contact_form.validate():
contact_address = new_contact_form.email.data.strip() contact_address = new_contact_form.email.data.strip()
try: try:
contact = create_contact(current_user, alias, contact_address) contact = create_contact(alias, contact_address)
except ( except (
ErrContactErrorUpgradeNeeded, ErrContactErrorUpgradeNeeded,
ErrAddressInvalid, ErrAddressInvalid,

View file

@ -336,7 +336,7 @@ class Fido(Base, ModelMixin):
class User(Base, ModelMixin, UserMixin, PasswordOracle): class User(Base, ModelMixin, UserMixin, PasswordOracle):
__tablename__ = "users" __tablename__ = "users"
FLAG_FREE_DISABLE_CREATE_ALIAS = 1 << 0 FLAG_DISABLE_CREATE_CONTACTS = 1 << 0
FLAG_CREATED_FROM_PARTNER = 1 << 1 FLAG_CREATED_FROM_PARTNER = 1 << 1
FLAG_FREE_OLD_ALIAS_LIMIT = 1 << 2 FLAG_FREE_OLD_ALIAS_LIMIT = 1 << 2
FLAG_CREATED_ALIAS_FROM_PARTNER = 1 << 3 FLAG_CREATED_ALIAS_FROM_PARTNER = 1 << 3
@ -543,7 +543,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
# bitwise flags. Allow for future expansion # bitwise flags. Allow for future expansion
flags = sa.Column( flags = sa.Column(
sa.BigInteger, sa.BigInteger,
default=FLAG_FREE_DISABLE_CREATE_ALIAS, default=FLAG_DISABLE_CREATE_CONTACTS,
server_default="0", server_default="0",
nullable=False, nullable=False,
) )
@ -1168,7 +1168,7 @@ class User(Base, ModelMixin, UserMixin, PasswordOracle):
def can_create_contacts(self) -> bool: def can_create_contacts(self) -> bool:
if self.is_premium(): if self.is_premium():
return True return True
if self.flags & User.FLAG_FREE_DISABLE_CREATE_ALIAS == 0: if self.flags & User.FLAG_DISABLE_CREATE_CONTACTS == 0:
return True return True
return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS return not config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS

View file

@ -197,8 +197,8 @@ def get_or_create_contact(from_header: str, mail_from: str, alias: Alias) -> Con
contact_email = mail_from contact_email = mail_from
contact_result = contact_utils.create_contact( contact_result = contact_utils.create_contact(
email=contact_email, email=contact_email,
name=contact_name,
alias=alias, alias=alias,
name=contact_name,
mail_from=mail_from, mail_from=mail_from,
allow_empty_email=True, allow_empty_email=True,
automatic_created=True, automatic_created=True,
@ -229,7 +229,7 @@ def get_or_create_reply_to_contact(
) )
return None return None
return contact_utils.create_contact(contact_address, contact_name, alias).contact return contact_utils.create_contact(contact_address, alias, contact_name).contact
def replace_header_when_forward(msg: Message, alias: Alias, header: str): def replace_header_when_forward(msg: Message, alias: Alias, header: str):

View file

@ -536,7 +536,7 @@ def test_create_contact_route_free_users(flask_client):
assert r.status_code == 201 assert r.status_code == 201
# End trial and disallow for new free users. Config should allow it # End trial and disallow for new free users. Config should allow it
user.flags = User.FLAG_FREE_DISABLE_CREATE_ALIAS user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.commit() Session.commit()
r = flask_client.post( r = flask_client.post(
url_for("api.create_contact_route", alias_id=alias.id), url_for("api.create_contact_route", alias_id=alias.id),

View file

@ -4,7 +4,7 @@ from app.models import (
Alias, Alias,
Contact, Contact,
) )
from tests.utils import login from tests.utils import login, random_email
def test_add_contact_success(flask_client): def test_add_contact_success(flask_client):
@ -13,26 +13,28 @@ def test_add_contact_success(flask_client):
assert Contact.filter_by(user_id=user.id).count() == 0 assert Contact.filter_by(user_id=user.id).count() == 0
email = random_email()
# <<< Create a new contact >>> # <<< Create a new contact >>>
flask_client.post( flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id), url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={ data={
"form-name": "create", "form-name": "create",
"email": "abcd@gmail.com", "email": email,
}, },
follow_redirects=True, follow_redirects=True,
) )
# a new contact is added # a new contact is added
assert Contact.filter_by(user_id=user.id).count() == 1 assert Contact.filter_by(user_id=user.id).count() == 1
contact = Contact.filter_by(user_id=user.id).first() contact = Contact.filter_by(user_id=user.id).first()
assert contact.website_email == "abcd@gmail.com" assert contact.website_email == email
# <<< Create a new contact using a full email format >>> # <<< Create a new contact using a full email format >>>
email = random_email()
flask_client.post( flask_client.post(
url_for("dashboard.alias_contact_manager", alias_id=alias.id), url_for("dashboard.alias_contact_manager", alias_id=alias.id),
data={ data={
"form-name": "create", "form-name": "create",
"email": "First Last <another@gmail.com>", "email": f"First Last <{email}>",
}, },
follow_redirects=True, follow_redirects=True,
) )
@ -41,7 +43,7 @@ def test_add_contact_success(flask_client):
contact = ( contact = (
Contact.filter_by(user_id=user.id).filter(Contact.id != contact.id).first() Contact.filter_by(user_id=user.id).filter(Contact.id != contact.id).first()
) )
assert contact.website_email == "another@gmail.com" assert contact.website_email == email
assert contact.name == "First Last" assert contact.name == "First Last"
# <<< Create a new contact with invalid email address >>> # <<< Create a new contact with invalid email address >>>

View file

@ -1,15 +1,26 @@
from typing import Optional from typing import Optional
import pytest import pytest
from app import config
from app.contact_utils import create_contact, ContactCreateError from app.contact_utils import create_contact, ContactCreateError
from app.db import Session from app.db import Session
from app.models import ( from app.models import (
Alias, Alias,
Contact, Contact,
User,
) )
from tests.utils import create_new_user, random_email, random_token from tests.utils import create_new_user, random_email, random_token
def setup_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = True
def teardown_module(module):
config.DISABLE_CREATE_CONTACTS_FOR_FREE_USERS = False
def create_provider(): def create_provider():
# name auto_created from_partner # name auto_created from_partner
yield ["name", "a@b.c", True, True] yield ["name", "a@b.c", True, True]
@ -34,8 +45,8 @@ def test_create_contact(
email = random_email() email = random_email()
contact_result = create_contact( contact_result = create_contact(
email, email,
name,
alias, alias,
name=name,
mail_from=mail_from, mail_from=mail_from,
automatic_created=automatic_created, automatic_created=automatic_created,
from_partner=from_partner, from_partner=from_partner,
@ -57,7 +68,7 @@ def test_create_contact_email_email_not_allowed():
user = create_new_user() user = create_new_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
contact_result = create_contact("", "", alias) contact_result = create_contact("", alias)
assert contact_result.contact is None assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail assert contact_result.error == ContactCreateError.InvalidEmail
@ -66,21 +77,84 @@ def test_create_contact_email_email_allowed():
user = create_new_user() user = create_new_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
contact_result = create_contact("", "", alias, allow_empty_email=True) contact_result = create_contact("", alias, allow_empty_email=True)
assert contact_result.error is None assert contact_result.error is None
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.website_email == "" assert contact_result.contact.website_email == ""
assert contact_result.contact.invalid_email assert contact_result.contact.invalid_email
def test_create_contact_name_overrides_email_name():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"superseeded <{email}>", alias, name=name)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_name_taken_from_email():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
name = random_token()
contact_result = create_contact(f"{name} <{email}>", alias)
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name == name
def test_create_contact_empty_name_is_none():
user = create_new_user()
alias = Alias.create_new_random(user)
Session.commit()
email = random_email()
contact_result = create_contact(email, alias, name="")
assert contact_result.error is None
assert contact_result.contact is not None
assert contact_result.contact.website_email == email
assert contact_result.contact.name is None
def test_create_contact_free_user():
user = create_new_user()
user.trial_end = None
user.flags = 0
alias = Alias.create_new_random(user)
Session.flush()
# Free users without the FREE_DISABLE_CREATE_CONTACTS
result = create_contact(random_email(), alias)
assert result.error is None
assert result.created
assert result.contact is not None
assert not result.contact.automatic_created
# Free users with the flag should be able to still create automatic emails
user.flags = User.FLAG_DISABLE_CREATE_CONTACTS
Session.flush()
result = create_contact(random_email(), alias, automatic_created=True)
assert result.error is None
assert result.created
assert result.contact is not None
assert result.contact.automatic_created
# Free users with the flag cannot create non-automatic emails
result = create_contact(random_email(), alias)
assert result.error == ContactCreateError.NotAllowed
def test_do_not_allow_invalid_email(): def test_do_not_allow_invalid_email():
user = create_new_user() user = create_new_user()
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
contact_result = create_contact("potato", "", alias) contact_result = create_contact("potato", alias)
assert contact_result.contact is None assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail assert contact_result.error == ContactCreateError.InvalidEmail
contact_result = create_contact("asdf\x00@gmail.com", "", alias) contact_result = create_contact("asdf\x00@gmail.com", alias)
assert contact_result.contact is None assert contact_result.contact is None
assert contact_result.error == ContactCreateError.InvalidEmail assert contact_result.error == ContactCreateError.InvalidEmail
@ -90,13 +164,15 @@ def test_update_name_for_existing():
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
email = random_email() email = random_email()
contact_result = create_contact(email, "", alias) contact_result = create_contact(email, alias)
assert contact_result.error is None assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.name == "" assert contact_result.contact.name is None
name = random_token() name = random_token()
contact_result = create_contact(email, name, alias) contact_result = create_contact(email, alias, name=name)
assert contact_result.error is None assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.name == name assert contact_result.contact.name == name
@ -106,12 +182,15 @@ def test_update_mail_from_for_existing():
alias = Alias.create_new_random(user) alias = Alias.create_new_random(user)
Session.commit() Session.commit()
email = random_email() email = random_email()
contact_result = create_contact(email, "", alias) contact_result = create_contact(email, alias)
assert contact_result.error is None assert contact_result.error is None
assert contact_result.created
assert contact_result.contact is not None
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.mail_from is None assert contact_result.contact.mail_from is None
mail_from = random_email() mail_from = random_email()
contact_result = create_contact(email, "", alias, mail_from=mail_from) contact_result = create_contact(email, alias, mail_from=mail_from)
assert contact_result.error is None assert contact_result.error is None
assert not contact_result.created
assert contact_result.contact is not None assert contact_result.contact is not None
assert contact_result.contact.mail_from == mail_from assert contact_result.contact.mail_from == mail_from