Rate limiting depending on user authenticated status (#1221)
* Rate limiting depending on user authenticated status * Update app/extensions.py Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com> * Add rate_limiting tests Co-authored-by: Adrià Casajús <acasajus@users.noreply.github.com>
This commit is contained in:
parent
a88a8ff2be
commit
a9549c11d7
|
@ -1,13 +1,24 @@
|
||||||
from flask_limiter import Limiter
|
from flask_limiter import Limiter
|
||||||
from flask_limiter.util import get_remote_address
|
from flask_limiter.util import get_remote_address
|
||||||
from flask_login import LoginManager
|
from flask_login import current_user, LoginManager
|
||||||
|
|
||||||
login_manager = LoginManager()
|
login_manager = LoginManager()
|
||||||
login_manager.session_protection = "strong"
|
login_manager.session_protection = "strong"
|
||||||
|
|
||||||
# Setup rate limit facility
|
|
||||||
limiter = Limiter(key_func=get_remote_address)
|
|
||||||
|
|
||||||
|
# We want to rate limit based on:
|
||||||
|
# - If the user is not logged in: request source IP
|
||||||
|
# - If the user is logged in: user_id
|
||||||
|
def __key_func():
|
||||||
|
if current_user.is_authenticated:
|
||||||
|
return f"userid:{current_user.id}"
|
||||||
|
else:
|
||||||
|
ip_addr = get_remote_address()
|
||||||
|
return f"ip:{ip_addr}"
|
||||||
|
|
||||||
|
|
||||||
|
# Setup rate limit facility
|
||||||
|
limiter = Limiter(key_func=__key_func)
|
||||||
|
|
||||||
# @limiter.request_filter
|
# @limiter.request_filter
|
||||||
# def ip_whitelist():
|
# def ip_whitelist():
|
||||||
|
|
|
@ -0,0 +1,104 @@
|
||||||
|
from flask import g
|
||||||
|
from http import HTTPStatus
|
||||||
|
from random import Random
|
||||||
|
|
||||||
|
from app.extensions import limiter
|
||||||
|
from tests.utils import login
|
||||||
|
from tests.conftest import app as test_app
|
||||||
|
|
||||||
|
# IMPORTANT NOTICE
|
||||||
|
# ----------------
|
||||||
|
# This test file has a special behaviour. After each request, a call to fix_rate_limit_after_request must
|
||||||
|
# be performed, in order for the rate_limiting process to work appropriately in test time.
|
||||||
|
# If you want to see why, feel free to refer to the source of the "hack":
|
||||||
|
# https://github.com/alisaifee/flask-limiter/issues/147#issuecomment-642683820
|
||||||
|
|
||||||
|
_ENDPOINT = "/tests/internal/rate_limited"
|
||||||
|
_MAX_PER_MINUTE = 3
|
||||||
|
|
||||||
|
|
||||||
|
@test_app.route(
|
||||||
|
_ENDPOINT,
|
||||||
|
methods=["GET"],
|
||||||
|
)
|
||||||
|
@limiter.limit(f"{_MAX_PER_MINUTE}/minute")
|
||||||
|
def rate_limited_endpoint_1():
|
||||||
|
return "Working", HTTPStatus.OK
|
||||||
|
|
||||||
|
|
||||||
|
def random_ip() -> str:
|
||||||
|
rand = Random()
|
||||||
|
octets = [str(rand.randint(0, 255)) for _ in range(4)]
|
||||||
|
return ".".join(octets)
|
||||||
|
|
||||||
|
|
||||||
|
def fix_rate_limit_after_request():
|
||||||
|
g._rate_limiting_complete = False
|
||||||
|
|
||||||
|
|
||||||
|
def request_headers(source_ip: str) -> dict:
|
||||||
|
return {"X-Forwarded-For": source_ip}
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_limits_by_source_ip(flask_client):
|
||||||
|
source_ip = random_ip()
|
||||||
|
|
||||||
|
for _ in range(_MAX_PER_MINUTE):
|
||||||
|
res = flask_client.get(_ENDPOINT, headers=request_headers(source_ip))
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
|
res = flask_client.get(_ENDPOINT, headers=request_headers(source_ip))
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||||
|
|
||||||
|
# Check that changing the "X-Forwarded-For" allows the request to succeed
|
||||||
|
res = flask_client.get(_ENDPOINT, headers=request_headers(random_ip()))
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_limits_by_user_id(flask_client):
|
||||||
|
# Login with a user
|
||||||
|
login(flask_client)
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
|
||||||
|
# Run the N requests with a different source IP but with the same user
|
||||||
|
for _ in range(_MAX_PER_MINUTE):
|
||||||
|
res = flask_client.get(_ENDPOINT, headers=request_headers(random_ip()))
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
|
res = flask_client.get(_ENDPOINT, headers=request_headers(random_ip()))
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||||
|
|
||||||
|
|
||||||
|
def test_rate_limit_limits_by_user_id_ignoring_ip(flask_client):
|
||||||
|
source_ip = random_ip()
|
||||||
|
|
||||||
|
# Login with a user
|
||||||
|
login(flask_client)
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
|
||||||
|
# Run the N requests with a different source IP but with the same user
|
||||||
|
for _ in range(_MAX_PER_MINUTE):
|
||||||
|
res = flask_client.get(_ENDPOINT, headers=request_headers(source_ip))
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.OK
|
||||||
|
|
||||||
|
res = flask_client.get(_ENDPOINT)
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.TOO_MANY_REQUESTS
|
||||||
|
|
||||||
|
# Log out
|
||||||
|
flask_client.cookie_jar.clear()
|
||||||
|
|
||||||
|
# Log in with another user
|
||||||
|
login(flask_client)
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
|
||||||
|
# Run the request again, reusing the same IP as before
|
||||||
|
res = flask_client.get(_ENDPOINT, headers=request_headers(source_ip))
|
||||||
|
fix_rate_limit_after_request()
|
||||||
|
assert res.status_code == HTTPStatus.OK
|
Loading…
Reference in New Issue