diff --git a/cron.py b/cron.py index fc8caade..b84aba79 100644 --- a/cron.py +++ b/cron.py @@ -5,7 +5,7 @@ from typing import List, Tuple import arrow import requests -from sqlalchemy import func, desc, or_, and_, nullsfirst +from sqlalchemy import func, desc, or_, and_ from sqlalchemy.ext.compiler import compiles from sqlalchemy.orm import joinedload from sqlalchemy.orm.exc import ObjectDeletedError @@ -1033,6 +1033,60 @@ async def _hibp_check(api_key, queue): await asyncio.sleep(rate_sleep) +def get_alias_to_check_hibp( + oldest_hibp_allowed: arrow.Arrow, + user_ids_to_skip: list[int], + min_alias_id: int, + max_alias_id: int, +): + now = arrow.now() + alias_query = ( + Session.query(Alias) + .join(User, User.id == Alias.user_id) + .join(Subscription, User.id == Subscription.user_id, isouter=True) + .join(ManualSubscription, User.id == ManualSubscription.user_id, isouter=True) + .join(AppleSubscription, User.id == AppleSubscription.user_id, isouter=True) + .join( + CoinbaseSubscription, + User.id == CoinbaseSubscription.user_id, + isouter=True, + ) + .join(PartnerUser, User.id == PartnerUser.user_id, isouter=True) + .join( + PartnerSubscription, + PartnerSubscription.partner_user_id == PartnerUser.id, + isouter=True, + ) + .filter( + or_( + Alias.hibp_last_check.is_(None), + Alias.hibp_last_check < oldest_hibp_allowed, + ), + Alias.user_id.notin_(user_ids_to_skip), + Alias.enabled, + Alias.id >= min_alias_id, + Alias.id < max_alias_id, + User.disabled == False, # noqa: E712 + or_( + User.lifetime, + ManualSubscription.end_at > now, + Subscription.next_bill_date > now.date(), + AppleSubscription.expires_date > now, + CoinbaseSubscription.end_at > now, + PartnerSubscription.end_at > now, + ), + ) + ) + if config.HIBP_SKIP_PARTNER_ALIAS: + alias_query = alias_query.filter( + Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0 + ) + for alias in ( + alias_query.order_by(Alias.id.asc()).enable_eagerloads(False).yield_per(500) + ): + yield alias + + async def check_hibp(): """ Check all aliases on the HIBP (Have I Been Pwned) API @@ -1063,28 +1117,20 @@ async def check_hibp(): queue = asyncio.Queue() min_alias_id = 0 max_alias_id = Session.query(func.max(Alias.id)).scalar() - step = 500 - max_date = arrow.now().shift(days=-config.HIBP_SCAN_INTERVAL_DAYS) + step = 10000 + now = arrow.now() + oldest_hibp_allowed = now.shift(days=-config.HIBP_SCAN_INTERVAL_DAYS) alias_checked = 0 for alias_batch_id in range(min_alias_id, max_alias_id, step): - alias_query = Alias.filter( - or_(Alias.hibp_last_check.is_(None), Alias.hibp_last_check < max_date), - Alias.user_id.notin_(user_ids), - Alias.enabled, - Alias.id >= alias_batch_id, - Alias.id < alias_batch_id + step, - ) - if config.HIBP_SKIP_PARTNER_ALIAS: - alias_query = alias_query( - Alias.flags.op("&")(Alias.FLAG_PARTNER_CREATED) == 0 - ) - for alias in alias_query.order_by( - nullsfirst(Alias.hibp_last_check.asc()), Alias.id.asc() - ).enable_eagerloads(False): + for alias in get_alias_to_check_hibp( + oldest_hibp_allowed, user_ids, alias_batch_id, alias_batch_id + step + ): await queue.put(alias.id) alias_checked += queue.qsize() - LOG.d("Need to check about %s aliases in this loop", queue.qsize()) + LOG.d( + f"Need to check about {queue.qsize()} aliases in this loop {alias_batch_id}/{max_alias_id}" + ) # Start one checking process per API key # Each checking process will take one alias from the queue, get the info diff --git a/tests/cron/__init__.py b/tests/cron/__init__.py new file mode 100644 index 00000000..e69de29b diff --git a/tests/test_cron.py b/tests/cron/test_cron.py similarity index 100% rename from tests/test_cron.py rename to tests/cron/test_cron.py diff --git a/tests/cron/test_get_alias_for_hibp.py b/tests/cron/test_get_alias_for_hibp.py new file mode 100644 index 00000000..ca7d82ae --- /dev/null +++ b/tests/cron/test_get_alias_for_hibp.py @@ -0,0 +1,116 @@ +import arrow +import pytest + +import cron +from app.db import Session +from app.models import ( + Alias, + AppleSubscription, + PlanEnum, + CoinbaseSubscription, + ManualSubscription, + Subscription, + PartnerUser, + PartnerSubscription, + User, +) +from app.proton.utils import get_proton_partner +from tests.utils import create_new_user, random_token + + +def test_get_alias_for_free_user_has_no_alias(): + user = create_new_user() + alias_id = Alias.create_new_random(user).id + Session.commit() + aliases = list( + cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1) + ) + assert len(aliases) == 0 + + +def test_get_alias_for_lifetime(): + user = create_new_user() + user.lifetime = True + alias_id = Alias.create_new_random(user).id + Session.commit() + aliases = list( + cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1) + ) + assert alias_id == aliases[0].id + + +def create_partner_sub(user: User): + pu = PartnerUser.create( + partner_id=get_proton_partner().id, + partner_email=user.email, + external_user_id=random_token(10), + user_id=user.id, + flush=True, + ) + PartnerSubscription.create( + partner_user_id=pu.id, end_at=arrow.utcnow().shift(days=15) + ) + + +sub_generator_list = [ + lambda u: AppleSubscription.create( + user_id=u.id, + expires_date=arrow.now().shift(days=15), + original_transaction_id=random_token(10), + receipt_data=random_token(10), + plan=PlanEnum.monthly, + ), + lambda u: CoinbaseSubscription.create( + user_id=u.id, + end_at=arrow.now().shift(days=15), + ), + lambda u: ManualSubscription.create( + user_id=u.id, + end_at=arrow.now().shift(days=15), + ), + lambda u: Subscription.create( + user_id=u.id, + cancel_url="", + update_url="", + subscription_id=random_token(10), + event_time=arrow.now(), + next_bill_date=arrow.now().shift(days=15).date(), + plan=PlanEnum.monthly, + ), + create_partner_sub, +] + + +@pytest.mark.parametrize("sub_generator", sub_generator_list) +def test_get_alias_for_sub(sub_generator): + user = create_new_user() + sub_generator(user) + alias_id = Alias.create_new_random(user).id + Session.commit() + aliases = list( + cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1) + ) + assert alias_id == aliases[0].id + + +def test_disabled_user_is_not_checked(): + user = create_new_user() + user.lifetime = True + user.disabled = True + alias_id = Alias.create_new_random(user).id + Session.commit() + aliases = list( + cron.get_alias_to_check_hibp(arrow.now(), [], alias_id, alias_id + 1) + ) + assert len(aliases) == 0 + + +def test_skipped_user_is_not_checked(): + user = create_new_user() + user.lifetime = True + alias_id = Alias.create_new_random(user).id + Session.commit() + aliases = list( + cron.get_alias_to_check_hibp(arrow.now(), [user.id], alias_id, alias_id + 1) + ) + assert len(aliases) == 0