app-MAIL-temp/app/parallel_limiter.py
Adrià Casajús a5056b3fcc
Fix: Use source ip if user is not authenticated (#1379)
* Fix: Use source ip if user is not authenticated

* Fix lint

Co-authored-by: Adrià Casajús <adria.casajus@proton.ch>
2022-10-27 13:37:45 +02:00

74 lines
2.2 KiB
Python

import uuid
from datetime import timedelta
from functools import wraps
from typing import Callable, Any, Optional
from flask import request
from flask_login import current_user
from limits.storage import RedisStorage
from werkzeug import exceptions
# Redis storage shared by every concurrency lock in this module. It stays None
# until set_redis_concurrent_lock() is called.
lock_redis: Optional[RedisStorage] = None
def set_redis_concurrent_lock(redis: RedisStorage):
    """Install ``redis`` as the storage backing the concurrency locks.

    Rebinds the module-level ``lock_redis``. Decorated functions run without
    any locking for as long as that global is unset.

    Args:
        redis: Storage whose ``storage`` attribute is used to set, read and
            delete the lock keys.
    """
    global lock_redis
    lock_redis = redis
class _InnerLock:
    """Decorator that allows one in-flight call per caller and lock suffix.

    Before running the wrapped function, a Redis key named
    ``cl:<user id or remote address>:<suffix>`` is created holding a random
    token. If the key already exists the call is rejected with
    ``TooManyRequests``. Otherwise the function runs and the key is deleted
    afterwards, provided it still holds this call's token.
    """

    def __init__(
        self,
        lock_suffix: Optional[str] = None,
        max_wait_secs: int = 5,
        only_when: Optional[Callable[..., bool]] = None,
    ):
        """Store the lock configuration.

        Args:
            lock_suffix: Last component of the Redis key. When ``None`` the
                decorated function's ``__name__`` is used instead.
            max_wait_secs: Expiry, in seconds, set on the Redis key, so the
                key goes away even if it is never released.
            only_when: Optional zero-argument predicate. When it returns a
                falsy value the call runs without taking the lock.
        """
        self.lock_suffix = lock_suffix
        self.max_wait_secs = max_wait_secs
        self.only_when = only_when

    def acquire_lock(self, lock_name: str, lock_value: str):
        """Create the lock key, rejecting the call if it already exists.

        Args:
            lock_name: Redis key identifying the lock.
            lock_value: Token stored in the key to identify this holder.

        Raises:
            werkzeug.exceptions.TooManyRequests: The key could not be set
                because another call already holds it.
        """
        acquired = lock_redis.storage.set(
            lock_name,
            lock_value,
            ex=timedelta(seconds=self.max_wait_secs),
            nx=True,  # only create the key; never overwrite an existing lock
        )
        if not acquired:
            raise exceptions.TooManyRequests()

    def release_lock(self, lock_name: str, lock_value: str):
        """Delete the lock key if it still holds ``lock_value``.

        The token check avoids deleting a lock that expired and was taken by
        another call in the meantime.

        Args:
            lock_name: Redis key identifying the lock.
            lock_value: Token that was stored by ``acquire_lock``.
        """
        stored = lock_redis.storage.get(lock_name)
        # The client may return bytes or, when configured to decode responses,
        # str. Compare in whichever type came back so the lock is released in
        # both configurations instead of lingering until it expires.
        if isinstance(stored, str):
            expected = lock_value
        else:
            expected = lock_value.encode("utf-8")
        # NOTE(review): the read and the delete are two separate commands, so
        # this check-then-delete is not atomic.
        if stored == expected:
            lock_redis.storage.delete(lock_name)

    def __call__(self, f: Callable[..., Any]):
        """Wrap ``f`` so that each call is guarded by the lock.

        Args:
            f: Function to protect.

        Returns:
            A wrapper with ``f``'s metadata. It calls ``f`` directly when
            ``only_when`` returns a falsy value or no Redis storage has been
            configured. Otherwise it holds the lock for the duration of the
            call and releases it even if ``f`` raises.
        """
        lock_suffix = f.__name__ if self.lock_suffix is None else self.lock_suffix

        @wraps(f)
        def decorated(*args, **kwargs):
            if self.only_when and not self.only_when():
                return f(*args, **kwargs)
            if lock_redis is None:
                return f(*args, **kwargs)

            # Random token so that only this call can release its own lock.
            lock_value = str(uuid.uuid4())[:10]
            # Key the lock on the user when one with an id is available, and
            # on the client address otherwise.
            if hasattr(current_user, "id"):
                owner = current_user.id
            else:
                owner = request.remote_addr
            lock_name = f"cl:{owner}:{lock_suffix}"

            self.acquire_lock(lock_name, lock_value)
            try:
                return f(*args, **kwargs)
            finally:
                self.release_lock(lock_name, lock_value)

        return decorated
def lock(
    name: Optional[str] = None,
    max_wait_secs: int = 5,
    only_when: Optional[Callable[..., bool]] = None,
):
    """Build a decorator that guards a function with a per-caller lock.

    Args:
        name: Suffix of the lock key; forwarded as ``lock_suffix``.
        max_wait_secs: Expiry of the lock key, in seconds.
        only_when: Optional predicate deciding whether the lock applies to a
            given call.

    Returns:
        An ``_InnerLock`` instance configured with the given arguments, to be
        applied to the function that should be protected.
    """
    return _InnerLock(
        lock_suffix=name,
        max_wait_secs=max_wait_secs,
        only_when=only_when,
    )