Have the dead letter also take events to avoid race conditions (#2267)

* Have the dead letter also take events to avoid race conditions

* Ensure we take the event

* Tests for event taken

* Rename
This commit is contained in:
Adrià Casajús 2024-10-17 11:16:33 +02:00 committed by GitHub
parent 2cd6ee777f
commit 3c13f1ce20
No known key found for this signature in database
GPG key ID: B5690EEEBB952194
5 changed files with 172 additions and 23 deletions

View file

@ -3771,15 +3771,14 @@ class SyncEvent(Base, ModelMixin):
sa.Index("ix_sync_event_taken_time", "taken_time"),
)
def mark_as_taken(self) -> bool:
sql = """
UPDATE sync_event
SET taken_time = :taken_time
WHERE id = :sync_event_id
AND taken_time IS NULL
"""
def mark_as_taken(self, allow_taken_older_than: Optional[Arrow] = None) -> bool:
taken_condition = ["taken_time IS NULL"]
args = {"taken_time": arrow.now().datetime, "sync_event_id": self.id}
if allow_taken_older_than:
taken_condition.append("taken_time < :taken_older_than")
args["taken_older_than"] = allow_taken_older_than.datetime
sql_taken_condition = "({})".format(" OR ".join(taken_condition))
sql = f"UPDATE sync_event SET taken_time = :taken_time WHERE id = :sync_event_id AND {sql_taken_condition}"
res = Session.execute(sql, args)
Session.commit()

View file

@ -85,24 +85,28 @@ class DeadLetterEventSource(EventSource):
def __init__(self, max_retries: int):
self.__max_retries = max_retries
def execute_loop(
self, on_event: Callable[[SyncEvent], NoReturn]
) -> list[SyncEvent]:
threshold = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES)
events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries
)
if events:
LOG.info(f"Got {len(events)} dead letter events")
newrelic.agent.record_custom_metric(
"Custom/dead_letter_events_to_process", len(events)
)
for event in events:
if event.mark_as_taken(allow_taken_older_than=threshold):
on_event(event)
return events
@newrelic.agent.background_task()
def run(self, on_event: Callable[[SyncEvent], NoReturn]):
while True:
try:
threshold = arrow.utcnow().shift(
minutes=-_DEAD_LETTER_THRESHOLD_MINUTES
)
events = SyncEvent.get_dead_letter(
older_than=threshold, max_retries=self.__max_retries
)
if events:
LOG.info(f"Got {len(events)} dead letter events")
if events:
newrelic.agent.record_custom_metric(
"Custom/dead_letter_events_to_process", len(events)
)
for event in events:
on_event(event)
events = self.execute_loop(on_event)
Session.close() # Ensure that we have a new connection and we don't have a dangling tx with a lock
if not events:
LOG.debug("No dead letter events")

View file

@ -0,0 +1,63 @@
#!/usr/bin/env python3
import argparse
import time
from sqlalchemy import func
from app.events.event_dispatcher import EventDispatcher
from app.events.generated.event_pb2 import UserPlanChanged, EventContent
from app.models import PartnerUser
from app.db import Session
parser = argparse.ArgumentParser(
prog="Backfill alias", description="Update alias notes and backfill flag"
)
parser.add_argument(
"-s", "--start_pu_id", default=0, type=int, help="Initial partner_user_id"
)
parser.add_argument(
"-e", "--end_pu_id", default=0, type=int, help="Last partner_user_id"
)
args = parser.parse_args()
pu_id_start = args.start_pu_id
max_pu_id = args.end_pu_id
if max_pu_id == 0:
max_pu_id = Session.query(func.max(PartnerUser.id)).scalar()
print(f"Checking partner user {pu_id_start} to {max_pu_id}")
step = 100
updated = 0
start_time = time.time()
with_premium = 0
for batch_start in range(pu_id_start, max_pu_id, step):
partner_users = (
Session.query(PartnerUser).filter(
PartnerUser.id >= batch_start, PartnerUser.id < batch_start + step
)
).all()
for partner_user in partner_users:
subscription_end = partner_user.user.get_active_subscription_end(
include_partner_subscription=False
)
end_timestamp = None
if subscription_end:
with_premium += 1
end_timestamp = subscription_end.timestamp
event = UserPlanChanged(plan_end_time=end_timestamp)
EventDispatcher.send_event(
partner_user.user, EventContent(user_plan_change=event)
)
Session.flush()
updated += 1
Session.commit()
elapsed = time.time() - start_time
last_batch_id = batch_start + step
time_per_alias = elapsed / (last_batch_id)
remaining = max_pu_id - last_batch_id
time_remaining = remaining / time_per_alias
hours_remaining = time_remaining / 60.0
print(
f"\PartnerUser {batch_start}/{max_pu_id} {updated} {hours_remaining:.2f} mins remaining"
)
print(f"With SL premium {with_premium}")

View file

@ -0,0 +1,54 @@
import arrow
from app.db import Session
from app.models import SyncEvent
from events.event_source import DeadLetterEventSource, _DEAD_LETTER_THRESHOLD_MINUTES
class EventCounter:
def __init__(self):
self.processed_events = 0
def on_event(self, event: SyncEvent):
self.processed_events += 1
def setup_function(func):
Session.query(SyncEvent).delete()
def test_dead_letter_does_not_take_untaken_events():
source = DeadLetterEventSource(1)
counter = EventCounter()
threshold_time = arrow.utcnow().shift(minutes=-(_DEAD_LETTER_THRESHOLD_MINUTES) + 1)
SyncEvent.create(
content="test".encode("utf-8"), created_at=threshold_time, flush=True
)
SyncEvent.create(
content="test".encode("utf-8"), taken_time=threshold_time, flush=True
)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 0
assert counter.processed_events == 0
def test_dead_letter_takes_untaken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_create = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), created_at=old_create, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_create
assert counter.processed_events == 1
def test_dead_letter_takes_taken_events_created_older_than_threshold():
source = DeadLetterEventSource(1)
counter = EventCounter()
old_taken = arrow.utcnow().shift(minutes=-_DEAD_LETTER_THRESHOLD_MINUTES - 1)
SyncEvent.create(content="test".encode("utf-8"), taken_time=old_taken, flush=True)
events_processed = source.execute_loop(on_event=counter.on_event)
assert len(events_processed) == 1
assert events_processed[0].taken_time > old_taken
assert counter.processed_events == 1

View file

@ -1,7 +1,9 @@
import arrow
from app import config, alias_utils
from app.db import Session
from app.events.event_dispatcher import GlobalDispatcher
from app.models import Alias
from app.models import Alias, SyncEvent
from tests.utils import random_token
from .event_test_utils import (
OnMemoryDispatcher,
@ -26,6 +28,33 @@ def setup_function(func):
on_memory_dispatcher.clear()
def test_event_taken_updates():
event = SyncEvent.create(content="test".encode("utf-8"), flush=True)
assert event.taken_time is None
assert event.mark_as_taken()
assert event.taken_time is not None
def test_event_mark_as_taken_does_nothing_for_taken_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
assert not event.mark_as_taken()
def test_event_mark_as_taken_does_nothing_for_not_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=-1)
assert not event.mark_as_taken(allow_taken_older_than=older_than)
def test_event_mark_as_taken_works_for_before_events():
now = arrow.utcnow()
event = SyncEvent.create(content="test".encode("utf-8"), taken_time=now, flush=True)
older_than = now.shift(minutes=+1)
assert event.mark_as_taken(allow_taken_older_than=older_than)
def test_fire_event_on_alias_creation():
(user, pu) = _create_linked_user()
alias = Alias.create_new_random(user)