Store session in redis if redis is enabled (#1288)
* Store sesions in redis to prevent saving old cookies * Format * Rename sid to session_id * Logout session completely Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
This commit is contained in:
parent
2760b149ff
commit
b5aff490ef
|
@ -3,7 +3,6 @@ from io import BytesIO
|
||||||
from typing import Optional
|
from typing import Optional
|
||||||
|
|
||||||
from flask import jsonify, g, request, make_response
|
from flask import jsonify, g, request, make_response
|
||||||
from flask_login import logout_user
|
|
||||||
|
|
||||||
from app import s3, config
|
from app import s3, config
|
||||||
from app.api.base import api_bp, require_api_auth
|
from app.api.base import api_bp, require_api_auth
|
||||||
|
@ -11,6 +10,7 @@ from app.config import SESSION_COOKIE_NAME
|
||||||
from app.db import Session
|
from app.db import Session
|
||||||
from app.models import ApiKey, File, PartnerUser, User
|
from app.models import ApiKey, File, PartnerUser, User
|
||||||
from app.proton.utils import get_proton_partner
|
from app.proton.utils import get_proton_partner
|
||||||
|
from app.session import logout_session
|
||||||
from app.utils import random_string
|
from app.utils import random_string
|
||||||
|
|
||||||
|
|
||||||
|
@ -131,7 +131,7 @@ def logout():
|
||||||
Output:
|
Output:
|
||||||
- 200
|
- 200
|
||||||
"""
|
"""
|
||||||
logout_user()
|
logout_session()
|
||||||
response = make_response(jsonify(msg="User is logged out"), 200)
|
response = make_response(jsonify(msg="User is logged out"), 200)
|
||||||
response.delete_cookie(SESSION_COOKIE_NAME)
|
response.delete_cookie(SESSION_COOKIE_NAME)
|
||||||
|
|
||||||
|
|
|
@ -1,13 +1,13 @@
|
||||||
from flask import redirect, url_for, flash, make_response
|
from flask import redirect, url_for, flash, make_response
|
||||||
from flask_login import logout_user
|
|
||||||
|
|
||||||
from app.auth.base import auth_bp
|
from app.auth.base import auth_bp
|
||||||
from app.config import SESSION_COOKIE_NAME
|
from app.config import SESSION_COOKIE_NAME
|
||||||
|
from app.session import logout_session
|
||||||
|
|
||||||
|
|
||||||
@auth_bp.route("/logout")
|
@auth_bp.route("/logout")
|
||||||
def logout():
|
def logout():
|
||||||
logout_user()
|
logout_session()
|
||||||
flash("You are logged out", "success")
|
flash("You are logged out", "success")
|
||||||
response = make_response(redirect(url_for("auth.login")))
|
response = make_response(redirect(url_for("auth.login")))
|
||||||
response.delete_cookie(SESSION_COOKIE_NAME)
|
response.delete_cookie(SESSION_COOKIE_NAME)
|
||||||
|
|
|
@ -0,0 +1,117 @@
|
||||||
|
import uuid
|
||||||
|
from typing import Optional
|
||||||
|
|
||||||
|
import flask
|
||||||
|
import redis
|
||||||
|
from flask import current_app, session
|
||||||
|
from flask_login import logout_user
|
||||||
|
|
||||||
|
try:
|
||||||
|
import cPickle as pickle
|
||||||
|
except ImportError:
|
||||||
|
import pickle
|
||||||
|
|
||||||
|
import itsdangerous
|
||||||
|
from flask.sessions import SessionMixin, SessionInterface
|
||||||
|
from werkzeug.datastructures import CallbackDict
|
||||||
|
|
||||||
|
SESSION_PREFIX = "session"
|
||||||
|
|
||||||
|
|
||||||
|
class ServerSession(CallbackDict, SessionMixin):
|
||||||
|
def __init__(self, initial=None, session_id=None):
|
||||||
|
def on_update(self):
|
||||||
|
self.modified = True
|
||||||
|
|
||||||
|
super(ServerSession, self).__init__(initial, on_update)
|
||||||
|
self.session_id = session_id
|
||||||
|
self.modified = False
|
||||||
|
|
||||||
|
|
||||||
|
class RedisSessionStore(SessionInterface):
|
||||||
|
def __init__(self, redis, app):
|
||||||
|
self._redis = redis
|
||||||
|
self._app = app
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_signer(cls, app) -> itsdangerous.Signer:
|
||||||
|
return itsdangerous.Signer(
|
||||||
|
app.secret_key, salt="session", key_derivation="hmac"
|
||||||
|
)
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def _get_key(cls, session_Id: str) -> str:
|
||||||
|
return f"{SESSION_PREFIX}:{session_Id}"
|
||||||
|
|
||||||
|
@classmethod
|
||||||
|
def extract_and_validate_session_id(
|
||||||
|
cls, app: flask.Flask, request: flask.Request
|
||||||
|
) -> Optional[str]:
|
||||||
|
unverified_session_Id = request.cookies.get(app.session_cookie_name)
|
||||||
|
if not unverified_session_Id:
|
||||||
|
return None
|
||||||
|
signer = cls._get_signer(app)
|
||||||
|
try:
|
||||||
|
sid_as_bytes = signer.unsign(unverified_session_Id)
|
||||||
|
return sid_as_bytes.decode()
|
||||||
|
except itsdangerous.BadSignature:
|
||||||
|
return None
|
||||||
|
|
||||||
|
def purge_session(self, session: ServerSession):
|
||||||
|
try:
|
||||||
|
self._redis.delete(self._get_key(session.session_id))
|
||||||
|
session.session_id = str(uuid.uuid4())
|
||||||
|
except AttributeError:
|
||||||
|
pass
|
||||||
|
|
||||||
|
def open_session(self, app: flask.Flask, request: flask.Request):
|
||||||
|
session_id = self.extract_and_validate_session_id(app, request)
|
||||||
|
if not session_id:
|
||||||
|
return ServerSession(session_id=str(uuid.uuid4()))
|
||||||
|
|
||||||
|
val = self._redis.get(self._get_key(session_id))
|
||||||
|
if val is not None:
|
||||||
|
try:
|
||||||
|
data = pickle.loads(val)
|
||||||
|
return ServerSession(data, session_id=session_id)
|
||||||
|
except:
|
||||||
|
pass
|
||||||
|
return ServerSession(session_id=str(uuid.uuid4()))
|
||||||
|
|
||||||
|
def save_session(
|
||||||
|
self, app: flask.Flask, session: ServerSession, response: flask.Response
|
||||||
|
):
|
||||||
|
domain = self.get_cookie_domain(app)
|
||||||
|
path = self.get_cookie_path(app)
|
||||||
|
httponly = self.get_cookie_httponly(app)
|
||||||
|
secure = self.get_cookie_secure(app)
|
||||||
|
expires = self.get_expiration_time(app, session)
|
||||||
|
val = pickle.dumps(dict(session))
|
||||||
|
self._redis.setex(
|
||||||
|
name=self._get_key(session.session_id),
|
||||||
|
value=val,
|
||||||
|
time=int(app.permanent_session_lifetime.total_seconds()),
|
||||||
|
)
|
||||||
|
signed_session_id = self._get_signer(app).sign(
|
||||||
|
itsdangerous.want_bytes(session.session_id)
|
||||||
|
)
|
||||||
|
response.set_cookie(
|
||||||
|
app.session_cookie_name,
|
||||||
|
signed_session_id,
|
||||||
|
expires=expires,
|
||||||
|
httponly=httponly,
|
||||||
|
domain=domain,
|
||||||
|
path=path,
|
||||||
|
secure=secure,
|
||||||
|
)
|
||||||
|
|
||||||
|
|
||||||
|
def set_redis_session(app: flask.Flask, redis_url: str):
|
||||||
|
app.session_interface = RedisSessionStore(redis.from_url(redis_url), app)
|
||||||
|
|
||||||
|
|
||||||
|
def logout_session():
|
||||||
|
logout_user()
|
||||||
|
purge_fn = getattr(current_app.session_interface, "purge_session", None)
|
||||||
|
if callable(purge_fn):
|
||||||
|
purge_fn(session)
|
|
@ -108,6 +108,7 @@ from app.newsletter_utils import send_newsletter_to_user
|
||||||
from app.oauth.base import oauth_bp
|
from app.oauth.base import oauth_bp
|
||||||
from app.onboarding.base import onboarding_bp
|
from app.onboarding.base import onboarding_bp
|
||||||
from app.phone.base import phone_bp
|
from app.phone.base import phone_bp
|
||||||
|
from app.session import set_redis_session
|
||||||
from app.utils import random_string
|
from app.utils import random_string
|
||||||
|
|
||||||
if SENTRY_DSN:
|
if SENTRY_DSN:
|
||||||
|
@ -163,6 +164,7 @@ def create_app() -> Flask:
|
||||||
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"
|
app.config["SESSION_COOKIE_SAMESITE"] = "Lax"
|
||||||
if MEM_STORE_URI:
|
if MEM_STORE_URI:
|
||||||
app.config[flask_limiter.extension.C.STORAGE_URL] = MEM_STORE_URI
|
app.config[flask_limiter.extension.C.STORAGE_URL] = MEM_STORE_URI
|
||||||
|
set_redis_session(app, MEM_STORE_URI)
|
||||||
|
|
||||||
limiter.init_app(app)
|
limiter.init_app(app)
|
||||||
|
|
||||||
|
|
Loading…
Reference in New Issue