Only check HIBP alias of paid users (#2065)

This commit is contained in:
Adrià Casajús 2024-03-15 10:13:06 +01:00 committed by GitHub
parent 30ddd4c807
commit aa2c676b5e
No known key found for this signature in database
GPG Key ID: B5690EEEBB952194
4 changed files with 180 additions and 18 deletions

82
cron.py
View File

@ -5,7 +5,7 @@ from typing import List, Tuple
import arrow
import requests
from sqlalchemy import func, desc, or_, and_, nullsfirst
from sqlalchemy import func, desc, or_, and_
from sqlalchemy.ext.compiler import compiles
from sqlalchemy.orm import joinedload
from sqlalchemy.orm.exc import ObjectDeletedError
@ -1033,6 +1033,60 @@ async def _hibp_check(api_key, queue):
await asyncio.sleep(rate_sleep)
def get_alias_to_check_hibp(
oldest_hibp_allowed: arrow.Arrow,
user_ids_to_skip: list[int],
min_alias_id: int,
max_alias_id: int,
):
now = arrow.now()
alias_query = (
Session.query(Alias)
.join(User, User.id == Alias.user_id)
.join(Subscription, User.id == Subscription.user_id, isouter=True)
.join(ManualSubscription, User.id == ManualSubscription.user_id, isouter=True)
.join(AppleSubscription, User.id == AppleSubscription.user_id, isouter=True)
.join(
CoinbaseSubscription,
User.id == CoinbaseSubscription.user_id,
isouter=True,
)
.join(PartnerUser, User.id == PartnerUser.user_id, isouter=True)
.join(
PartnerSubscription,
PartnerSubscription.partner_user_id == PartnerUser.id,
isouter=True,
)
.filter(
or_(
Alias.hibp_last_check.is_(None),
Alias.hibp_last_check < oldest_hibp_allowed,
),
Alias.user_id.notin_(user_ids_to_skip),
Alias.enabled,
Alias.id >= min_alias_id,
Alias.id < max_alias_id,
User.disabled == False, # noqa: E712
or_(
User.lifetime,
ManualSubscription.end_at > now,
Subscription.next_bill_date > now.date(),
AppleSubscription.expires_date > now,
CoinbaseSubscription.end_at > now,
PartnerSubscription.end_at > now,
),
)
)
if config.HIBP_SKIP_PARTNER_ALIAS:
alias_query = alias_query.filter(
Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0
)
for alias in (
alias_query.order_by(Alias.id.asc()).enable_eagerloads(False).yield_per(500)
):
yield alias
async def check_hibp():
"""
Check all aliases on the HIBP (Have I Been Pwned) API
@ -1063,28 +1117,20 @@ async def check_hibp():
queue = asyncio.Queue()
min_alias_id = 0
max_alias_id = Session.query(func.max(Alias.id)).scalar()
step = 500
max_date = arrow.now().shift(days=-config.HIBP_SCAN_INTERVAL_DAYS)
step = 10000
now = arrow.now()
oldest_hibp_allowed = now.shift(days=-config.HIBP_SCAN_INTERVAL_DAYS)
alias_checked = 0
for alias_batch_id in range(min_alias_id, max_alias_id, step):
alias_query = Alias.filter(
or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date),
Alias.user_id.notin_(user_ids),
Alias.enabled,
Alias.id >= alias_batch_id,
Alias.id < alias_batch_id + step,
)
if config.HIBP_SKIP_PARTNER_ALIAS:
alias_query = alias_query(
Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0
)
for alias in alias_query.order_by(
nullsfirst(Alias.hibp_last_check.asc()), Alias.id.asc()
).enable_eagerloads(False):
for alias in get_alias_to_check_hibp(
oldest_hibp_allowed, user_ids, alias_batch_id, alias_batch_id + step
):
await queue.put(alias.id)
alias_checked += queue.qsize()
LOG.d("Need to check about %s aliases in this loop", queue.qsize())
LOG.d(
f"Need to check about {queue.qsize()} aliases in this loop {alias_batch_id}/{max_alias_id}"
)
# Start one checking process per API key
# Each checking process will take one alias from the queue, get the info

0
tests/cron/__init__.py Normal file
View File

View File

@ -0,0 +1,116 @@
import arrow
import pytest
import cron
from app.db import Session
from app.models import (
Alias,
AppleSubscription,
PlanEnum,
CoinbaseSubscription,
ManualSubscription,
Subscription,
PartnerUser,
PartnerSubscription,
User,
)
from app.proton.utils import get_proton_partner
from tests.utils import create_new_user, random_token
def test_get_alias_for_free_user_has_no_alias():
user = create_new_user()
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert len(aliases) == 0
def test_get_alias_for_lifetime():
user = create_new_user()
user.lifetime = True
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert alias_id == aliases[0].id
def create_partner_sub(user: User):
pu = PartnerUser.create(
partner_id=get_proton_partner().id,
partner_email=user.email,
external_user_id=random_token(10),
user_id=user.id,
flush=True,
)
PartnerSubscription.create(
partner_user_id=pu.id, end_at=arrow.utcnow().shift(days=15)
)
sub_generator_list = [
lambda u: AppleSubscription.create(
user_id=u.id,
expires_date=arrow.now().shift(days=15),
original_transaction_id=random_token(10),
receipt_data=random_token(10),
plan=PlanEnum.monthly,
),
lambda u: CoinbaseSubscription.create(
user_id=u.id,
end_at=arrow.now().shift(days=15),
),
lambda u: ManualSubscription.create(
user_id=u.id,
end_at=arrow.now().shift(days=15),
),
lambda u: Subscription.create(
user_id=u.id,
cancel_url="",
update_url="",
subscription_id=random_token(10),
event_time=arrow.now(),
next_bill_date=arrow.now().shift(days=15).date(),
plan=PlanEnum.monthly,
),
create_partner_sub,
]
@pytest.mark.parametrize("sub_generator", sub_generator_list)
def test_get_alias_for_sub(sub_generator):
user = create_new_user()
sub_generator(user)
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert alias_id == aliases[0].id
def test_disabled_user_is_not_checked():
user = create_new_user()
user.lifetime = True
user.disabled = True
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1)
)
assert len(aliases) == 0
def test_skipped_user_is_not_checked():
user = create_new_user()
user.lifetime = True
alias_id = Alias.create_new_random(user).id
Session.commit()
aliases = list(
cron.get_alias_to_check_hibp(arrow.now(), [user.id], alias_id, alias_id + 1)
)
assert len(aliases) == 0