mirror of
https://github.com/simple-login/app.git
synced 2024-11-14 08:01:13 +01:00
73 lines
2.2 KiB
Python
73 lines
2.2 KiB
Python
import uuid
|
|
from datetime import timedelta
|
|
from functools import wraps
|
|
from typing import Callable, Any, Optional
|
|
|
|
from flask import request
|
|
from flask_login import current_user
|
|
from limits.storage import RedisStorage
|
|
from werkzeug import exceptions
|
|
|
|
lock_redis: Optional[RedisStorage] = None
|
|
|
|
|
|
def set_redis_concurrent_lock(redis: RedisStorage) -> None:
    """Register the storage backend used by the concurrency lock.

    The given object is stored in the module-level ``lock_redis``. While that
    global is still unset (``None``), decorated functions run without taking
    any lock.

    Args:
        redis: storage object whose ``.storage`` attribute is used for the
            ``set`` / ``get`` / ``delete`` calls made by the lock.
    """
    global lock_redis
    lock_redis = redis
|
|
|
|
|
|
class _InnerLock:
    """Decorator that rejects concurrent calls from the same user or client IP.

    Before the wrapped function runs, a key is set in ``lock_redis`` only if
    it does not already exist; if it does exist the call is rejected with
    ``TooManyRequests``. The key is removed again once the wrapped function
    returns or raises.

    Args:
        lock_suffix: last component of the lock key. Defaults to the wrapped
            function's ``__name__``.
        max_wait_secs: expiry, in seconds, set on the lock key. Despite the
            name this is not a wait time: callers never block, and the lock
            simply disappears after this long even if the wrapped function is
            still running.
        only_when: optional zero-argument callable; when it returns a falsy
            value the wrapped function is called without any locking.
    """

    def __init__(
        self,
        lock_suffix: Optional[str] = None,
        max_wait_secs: int = 5,
        only_when: Optional[Callable[..., bool]] = None,
    ):
        self.lock_suffix = lock_suffix
        self.max_wait_secs = max_wait_secs
        self.only_when = only_when

    def acquire_lock(self, lock_name: str, lock_value: str):
        """Take the lock or fail immediately.

        Args:
            lock_name: key to set.
            lock_value: token stored under the key; ``release_lock`` uses it
                to recognise a lock it owns.

        Raises:
            werkzeug.exceptions.TooManyRequests: the key already exists, i.e.
                another call currently holds the lock.
        """
        # nx=True: only set when the key is absent; ex=...: expire the key so
        # a crashed worker cannot hold the lock forever.
        if not lock_redis.storage.set(
            lock_name, lock_value, ex=timedelta(seconds=self.max_wait_secs), nx=True
        ):
            raise exceptions.TooManyRequests()

    def release_lock(self, lock_name: str, lock_value: str):
        """Delete the lock key, but only if it still holds our own token.

        The token check prevents deleting a lock that expired and was then
        re-acquired by another call.

        Args:
            lock_name: key to release.
            lock_value: token that was passed to ``acquire_lock``.
        """
        current_lock_value = lock_redis.storage.get(lock_name)
        # The client may return bytes (default) or str (when configured to
        # decode responses); compare in whichever type we received so the
        # lock is released in both configurations.
        if isinstance(current_lock_value, bytes):
            expected_value = lock_value.encode("utf-8")
        else:
            expected_value = lock_value
        # NOTE: the get + delete pair is not atomic. If the key expires and is
        # re-acquired between the two calls, the delete removes the new
        # holder's lock. The window is small; closing it would require an
        # atomic compare-and-delete on the server side.
        if current_lock_value == expected_value:
            lock_redis.storage.delete(lock_name)

    def __call__(self, f: Callable[..., Any]):
        """Wrap ``f`` so that each call is guarded by the lock."""
        if self.lock_suffix is None:
            lock_suffix = f.__name__
        else:
            lock_suffix = self.lock_suffix

        @wraps(f)
        def decorated(*args, **kwargs):
            # Locking is skipped when the caller-supplied predicate says so...
            if self.only_when and not self.only_when():
                return f(*args, **kwargs)
            # ...or when no storage backend has been registered.
            if not lock_redis:
                return f(*args, **kwargs)

            lock_value = str(uuid.uuid4())[:10]
            # Key the lock on the user id when the current user has one,
            # otherwise on the client address.
            if hasattr(current_user, "id"):
                lock_name = f"cl:{current_user.id}:{lock_suffix}"
            else:
                lock_name = f"cl:{request.remote_addr}:{lock_suffix}"
            self.acquire_lock(lock_name, lock_value)
            try:
                return f(*args, **kwargs)
            finally:
                # Always release, even when f raises.
                self.release_lock(lock_name, lock_value)

        return decorated
|
|
|
|
|
|
def lock(
    name: Optional[str] = None,
    max_wait_secs: int = 5,
    only_when: Optional[Callable[..., bool]] = None,
):
    """Build a decorator that prevents concurrent calls per user or client IP.

    Args:
        name: suffix of the lock key; when ``None`` the decorated function's
            ``__name__`` is used.
        max_wait_secs: expiry, in seconds, applied to the lock key.
        only_when: optional zero-argument callable; when it returns a falsy
            value the decorated function runs without locking.

    Returns:
        An ``_InnerLock`` instance, to be applied to the function to guard.
    """
    return _InnerLock(
        lock_suffix=name,
        max_wait_secs=max_wait_secs,
        only_when=only_when,
    )
|