mirror of
https://github.com/simple-login/app.git
synced 2024-11-16 17:08:30 +01:00
Have the dead letter also take events to avoid race conditions (#2267)
* Have the dead letter also take events to avoid race conditions * Ensure we take the event * Tests for event taken * Rename
This commit is contained in:
parent
2cd6ee777f
commit
3c13f1ce20
5 changed files with 172 additions and 23 deletions
|
@ -3771,15 +3771,14 @@ class SyncEvent(Base, ModelMixin):
|
||||||
sa.Index("ix_sync_event_taken_time", "taken_time"),
|
sa.Index("ix_sync_event_taken_time", "taken_time"),
|
||||||
)
|
)
|
||||||
|
|
||||||
def mark_as_taken(self) -> bool:
|
def mark_as_taken(self, allow_taken_older_than: Optional[Arrow] = None) -> bool:
|
||||||
sql = """
|
taken_condition = ["taken_time IS NULL"]
|
||||||
UPDATE sync_event
|
|
||||||
SET taken_time = :taken_time
|
|
||||||
WHERE id = :sync_event_id
|
|
||||||
AND taken_time IS NULL
|
|
||||||
"""
|
|
||||||
args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id}
|
args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id}
|
||||||
|
if allow_taken_older_than:
|
||||||
|
taken_condition.append("taken_time < :taken_older_than")
|
||||||
|
args["taken_older_than"] = allow_taken_older_than.datetime
|
||||||
|
sql_taken_condition = "({})".format(" OR ".join(taken_condition))
|
||||||
|
sql = f"UPDATE sync_event SET taken_time = :taken_time WHERE id = :sync_event_id AND {sql_taken_condition}"
|
||||||
res = Session.execute(sql, args)
|
res = Session.execute(sql, args)
|
||||||
Session.commit()
|
Session.commit()
|
||||||
|
|
||||||
|
|
|
@ -85,24 +85,28 @@ class DeadLetterEventSource(EventSource):
|
||||||
def __init__(self, max_retries: int):
|
def __init__(self, max_retries: int):
|
||||||
self.__max_retries = max_retries
|
self.__max_retries = max_retries
|
||||||
|
|
||||||
|
def execute_loop(
|
||||||
|
self, on_event: Callable[[SyncEvent], NoReturn]
|
||||||
|
) -> list[SyncEvent]:
|
||||||
|
threshold = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES)
|
||||||
|
events = SyncEvent.get_dead_letter(
|
||||||
|
older_than=threshold, max_retries=self.__max_retries
|
||||||
|
)
|
||||||
|
if events:
|
||||||
|
LOG.info(f"Got {len(events)} dead letter events")
|
||||||
|
newrelic.agent.record_custom_metric(
|
||||||
|
"Custom/dead_letter_events_to_process", len(events)
|
||||||
|
)
|
||||||
|
for event in events:
|
||||||
|
if event.mark_as_taken(allow_taken_older_than=threshold):
|
||||||
|
on_event(event)
|
||||||
|
return events
|
||||||
|
|
||||||
@newrelic.agent.background_task()
|
@newrelic.agent.background_task()
|
||||||
def run(self, on_event: Callable[[SyncEvent], NoReturn]):
|
def run(self, on_event: Callable[[SyncEvent], NoReturn]):
|
||||||
while True:
|
while True:
|
||||||
try:
|
try:
|
||||||
threshold = arrow.utcnow().shift(
|
events = self.execute_loop(on_event)
|
||||||
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
|
|
||||||
)
|
|
||||||
events = SyncEvent.get_dead_letter(
|
|
||||||
older_than=threshold, max_retries=self.__max_retries
|
|
||||||
)
|
|
||||||
if events:
|
|
||||||
LOG.info(f"Got {len(events)} dead letter events")
|
|
||||||
if events:
|
|
||||||
newrelic.agent.record_custom_metric(
|
|
||||||
"Custom/dead_letter_events_to_process", len(events)
|
|
||||||
)
|
|
||||||
for event in events:
|
|
||||||
on_event(event)
|
|
||||||
Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
|
Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
|
||||||
if not events:
|
if not events:
|
||||||
LOG.debug("No dead letter events")
|
LOG.debug("No dead letter events")
|
||||||
|
|
63
oneshot/send_plan_change_events.py
Normal file
63
oneshot/send_plan_change_events.py
Normal file
|
@ -0,0 +1,63 @@
|
||||||
|
#!/usr/bin/env python3
|
||||||
|
import argparse
|
||||||
|
import time
|
||||||
|
|
||||||
|
from sqlalchemy import func
|
||||||
|
|
||||||
|
from app.events.event_dispatcher import EventDispatcher
|
||||||
|
from app.events.generated.event_pb2 import UserPlanChanged, EventContent
|
||||||
|
from app.models import PartnerUser
|
||||||
|
from app.db import Session
|
||||||
|
|
||||||
|
parser = argparse.ArgumentParser(
|
||||||
|
prog="Backfill alias", description="Update alias notes and backfill flag"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
|
||||||
|
)
|
||||||
|
parser.add_argument(
|
||||||
|
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
|
||||||
|
)
|
||||||
|
|
||||||
|
args = parser.parse_args()
|
||||||
|
pu_id_start = args.start_pu_id
|
||||||
|
max_pu_id = args.end_pu_id
|
||||||
|
if max_pu_id == 0:
|
||||||
|
max_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
|
||||||
|
|
||||||
|
print(f"Checking partner user {pu_id_start} to {max_pu_id}")
|
||||||
|
step = 100
|
||||||
|
updated = 0
|
||||||
|
start_time = time.time()
|
||||||
|
with_premium = 0
|
||||||
|
for batch_start in range(pu_id_start, max_pu_id, step):
|
||||||
|
partner_users = (
|
||||||
|
Session.query(PartnerUser).filter(
|
||||||
|
PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step
|
||||||
|
)
|
||||||
|
).all()
|
||||||
|
for partner_user in partner_users:
|
||||||
|
subscription_end = partner_user.user.get_active_subscription_end(
|
||||||
|
include_partner_subscription=False
|
||||||
|
)
|
||||||
|
end_timestamp = None
|
||||||
|
if subscription_end:
|
||||||
|
with_premium += 1
|
||||||
|
end_timestamp = subscription_end.timestamp
|
||||||
|
event = UserPlanChanged(plan_end_time=end_timestamp)
|
||||||
|
EventDispatcher.send_event(
|
||||||
|
partner_user.user, EventContent(user_plan_change=event)
|
||||||
|
)
|
||||||
|
Session.flush()
|
||||||
|
updated += 1
|
||||||
|
Session.commit()
|
||||||
|
elapsed = time.time() - start_time
|
||||||
|
last_batch_id = batch_start + step
|
||||||
|
time_per_alias = elapsed / (last_batch_id)
|
||||||
|
remaining = max_pu_id - last_batch_id
|
||||||
|
time_remaining = remaining / time_per_alias
|
||||||
|
hours_remaining = time_remaining / 60.0
|
||||||
|
print(
|
||||||
|
f"\PartnerUser {batch_start}/{max_pu_id} {updated} {hours_remaining:.2f} mins remaining"
|
||||||
|
)
|
||||||
|
print(f"With SL premium {with_premium}")
|
54
tests/events/test_dead_letter_event_source.py
Normal file
54
tests/events/test_dead_letter_event_source.py
Normal file
|
@ -0,0 +1,54 @@
|
||||||
|
import arrow
|
||||||
|
|
||||||
|
from app.db import Session
|
||||||
|
from app.models import SyncEvent
|
||||||
|
from events.event_source import DeadLetterEventSource, _DEAD_LETTER_THRESHOLD_MINUTES
|
||||||
|
|
||||||
|
|
||||||
|
class EventCounter:
|
||||||
|
def __init__(self):
|
||||||
|
self.processed_events = 0
|
||||||
|
|
||||||
|
def on_event(self, event: SyncEvent):
|
||||||
|
self.processed_events += 1
|
||||||
|
|
||||||
|
|
||||||
|
def setup_function(func):
|
||||||
|
Session.query(SyncEvent).delete()
|
||||||
|
|
||||||
|
|
||||||
|
def test_dead_letter_does_not_take_untaken_events():
|
||||||
|
source = DeadLetterEventSource(1)
|
||||||
|
counter = EventCounter()
|
||||||
|
threshold_time = arrow.utcnow().shift(minutes=-(_DEAD_LETTER_THRESHOLD_MINUTES) + 1)
|
||||||
|
SyncEvent.create(
|
||||||
|
content="test".encode("utf-8"), created_at=threshold_time, flush=True
|
||||||
|
)
|
||||||
|
SyncEvent.create(
|
||||||
|
content="test".encode("utf-8"), taken_time=threshold_time, flush=True
|
||||||
|
)
|
||||||
|
events_processed = source.execute_loop(on_event=counter.on_event)
|
||||||
|
assert len(events_processed) == 0
|
||||||
|
assert counter.processed_events == 0
|
||||||
|
|
||||||
|
|
||||||
|
def test_dead_letter_takes_untaken_events_created_older_than_threshold():
|
||||||
|
source = DeadLetterEventSource(1)
|
||||||
|
counter = EventCounter()
|
||||||
|
old_create = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
|
||||||
|
SyncEvent.create(content="test".encode("utf-8"), created_at=old_create, flush=True)
|
||||||
|
events_processed = source.execute_loop(on_event=counter.on_event)
|
||||||
|
assert len(events_processed) == 1
|
||||||
|
assert events_processed[0].taken_time > old_create
|
||||||
|
assert counter.processed_events == 1
|
||||||
|
|
||||||
|
|
||||||
|
def test_dead_letter_takes_taken_events_created_older_than_threshold():
|
||||||
|
source = DeadLetterEventSource(1)
|
||||||
|
counter = EventCounter()
|
||||||
|
old_taken = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
|
||||||
|
SyncEvent.create(content="test".encode("utf-8"), taken_time=old_taken, flush=True)
|
||||||
|
events_processed = source.execute_loop(on_event=counter.on_event)
|
||||||
|
assert len(events_processed) == 1
|
||||||
|
assert events_processed[0].taken_time > old_taken
|
||||||
|
assert counter.processed_events == 1
|
|
@ -1,7 +1,9 @@
|
||||||
|
import arrow
|
||||||
|
|
||||||
from app import config, alias_utils
|
from app import config, alias_utils
|
||||||
from app.db import Session
|
from app.db import Session
|
||||||
from app.events.event_dispatcher import GlobalDispatcher
|
from app.events.event_dispatcher import GlobalDispatcher
|
||||||
from app.models import Alias
|
from app.models import Alias, SyncEvent
|
||||||
from tests.utils import random_token
|
from tests.utils import random_token
|
||||||
from .event_test_utils import (
|
from .event_test_utils import (
|
||||||
OnMemoryDispatcher,
|
OnMemoryDispatcher,
|
||||||
|
@ -26,6 +28,33 @@ def setup_function(func):
|
||||||
on_memory_dispatcher.clear()
|
on_memory_dispatcher.clear()
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_taken_updates():
|
||||||
|
event = SyncEvent.create(content="test".encode("utf-8"), flush=True)
|
||||||
|
assert event.taken_time is None
|
||||||
|
assert event.mark_as_taken()
|
||||||
|
assert event.taken_time is not None
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_mark_as_taken_does_nothing_for_taken_events():
|
||||||
|
now = arrow.utcnow()
|
||||||
|
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
|
||||||
|
assert not event.mark_as_taken()
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_mark_as_taken_does_nothing_for_not_before_events():
|
||||||
|
now = arrow.utcnow()
|
||||||
|
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
|
||||||
|
older_than = now.shift(minutes=-1)
|
||||||
|
assert not event.mark_as_taken(allow_taken_older_than=older_than)
|
||||||
|
|
||||||
|
|
||||||
|
def test_event_mark_as_taken_works_for_before_events():
|
||||||
|
now = arrow.utcnow()
|
||||||
|
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
|
||||||
|
older_than = now.shift(minutes=+1)
|
||||||
|
assert event.mark_as_taken(allow_taken_older_than=older_than)
|
||||||
|
|
||||||
|
|
||||||
def test_fire_event_on_alias_creation():
|
def test_fire_event_on_alias_creation():
|
||||||
(user, pu) = _create_linked_user()
|
(user, pu) = _create_linked_user()
|
||||||
alias = Alias.create_new_random(user)
|
alias = Alias.create_new_random(user)
|
||||||
|
|
Loading…
Reference in a new issue