diff --git a/app/auth/views/mfa.py b/app/auth/views/mfa.py index 6bd0ab12..0a6bfc72 100644 --- a/app/auth/views/mfa.py +++ b/app/auth/views/mfa.py @@ -15,7 +15,6 @@ from wtforms import BooleanField, StringField, validators from app.auth.base import auth_bp from app.config import MFA_USER_ID, URL from app.extensions import db -from app.log import LOG from app.models import User, MfaBrowser @@ -47,21 +46,21 @@ def mfa(): if request.cookies.get("mfa"): browser = MfaBrowser.get_by(token=request.cookies.get("mfa")) - if browser and not browser.is_expired(): + if browser and not browser.is_expired() and browser.user_id == user.id: login_user(user) flash(f"Welcome back {user.name}!", "success") # Redirect user to correct page return redirect(next_url or url_for("dashboard.index")) - MfaBrowser.delete(browser.token) - if otp_token_form.validate_on_submit(): totp = pyotp.TOTP(user.otp_secret) - token = otp_token_form.token.data + token = otp_token_form.token.data.replace(" ", "") - if totp.verify(token): + if totp.verify(token) and user.last_otp != token: del session[MFA_USER_ID] + user.last_otp = token + db.session.commit() login_user(user) flash(f"Welcome back {user.name}!", "success") diff --git a/app/dashboard/views/mfa_setup.py b/app/dashboard/views/mfa_setup.py index 2b7ac48c..afb639e4 100644 --- a/app/dashboard/views/mfa_setup.py +++ b/app/dashboard/views/mfa_setup.py @@ -32,10 +32,11 @@ def mfa_setup(): totp = pyotp.TOTP(current_user.otp_secret) if otp_token_form.validate_on_submit(): - token = otp_token_form.token.data + token = otp_token_form.token.data.replace(" ", "") - if totp.verify(token): + if totp.verify(token) and current_user.last_otp != token: current_user.enable_otp = True + current_user.last_otp = token db.session.commit() flash("MFA has been activated", "success") diff --git a/app/models.py b/app/models.py index bf6df8f4..9e3ef46f 100644 --- a/app/models.py +++ b/app/models.py @@ -76,9 +76,9 @@ class ModelMixin(object): def save(self): db.session.add(self) - @classmethod - def delete(cls, obj_id): - cls.query.filter(cls.id == obj_id).delete() + @classmethod + def delete(cls, obj_id): + cls.query.filter(cls.id == obj_id).delete() def __repr__(self): values = ", ".join( @@ -161,6 +161,7 @@ class User(db.Model, ModelMixin, UserMixin): enable_otp = db.Column( db.Boolean, nullable=False, default=False, server_default="0" ) + last_otp = db.Column(db.String(12), nullable=True, default=False) # Fields for WebAuthn fido_uuid = db.Column(db.String(), nullable=True, unique=True) @@ -510,17 +511,22 @@ def generate_oauth_client_id(client_name) -> str: class MfaBrowser(db.Model, ModelMixin): user_id = db.Column(db.ForeignKey(User.id, ondelete="cascade"), nullable=False) - token = db.Column(db.String(64), default=False, nullable=False) + token = db.Column(db.String(64), default=False, unique=True, nullable=False) expires = db.Column(ArrowType, default=False, nullable=False) user = db.relationship(User) @classmethod def create_new(cls, user, token_length=64) -> "MfaBrowser": + found = False + while not found: + token = random_string(token_length) + + if not cls.get_by(token=token): + found = True + return MfaBrowser.create( - user_id=user.id, - token=random_string(token_length), - expires=arrow.now().shift(days=30), + user_id=user.id, token=token, expires=arrow.now().shift(days=30), ) @classmethod diff --git a/migrations/versions/2020_052216_ea50319ea811_.py b/migrations/versions/2020_052216_ea50319ea811_.py new file mode 100644 index 00000000..faad2228 --- /dev/null +++ b/migrations/versions/2020_052216_ea50319ea811_.py @@ -0,0 +1,29 @@ +"""empty message + +Revision ID: ea50319ea811 +Revises: 95599239860a +Create Date: 2020-05-22 16:49:25.613344 + +""" +import sqlalchemy_utils +from alembic import op +import sqlalchemy as sa + + +# revision identifiers, used by Alembic. +revision = 'ea50319ea811' +down_revision = '95599239860a' +branch_labels = None +depends_on = None + + +def upgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.add_column('users', sa.Column('last_otp', sa.String(length=12), nullable=True)) + # ### end Alembic commands ### + + +def downgrade(): + # ### commands auto generated by Alembic - please adjust! ### + op.drop_column('users', 'last_otp') + # ### end Alembic commands ###