From c16fd25b2ec8948d9e269143ad99923b6d87038e Mon Sep 17 00:00:00 2001 From: =?UTF-8?q?Adri=C3=A0=20Casaj=C3=BAs?= Date: Mon, 11 Apr 2022 15:52:31 +0200 Subject: [PATCH] Added fix for parts that are not messages --- app/email_utils.py | 5 ++- pytest.ini | 2 +- tests/example_emls/multipart_alternative.eml | 25 +++++++++++++ tests/test_email_handler.py | 37 ++++++++++---------- tests/test_email_utils.py | 12 +++++++ 5 files changed, 60 insertions(+), 21 deletions(-) create mode 100644 tests/example_emls/multipart_alternative.eml diff --git a/app/email_utils.py b/app/email_utils.py index a0ce7a0a..26194158 100644 --- a/app/email_utils.py +++ b/app/email_utils.py @@ -970,7 +970,10 @@ def add_header(msg: Message, text_header, html_header) -> Message: elif content_type in ("multipart/alternative", "multipart/related"): new_parts = [] for part in msg.get_payload(): - new_parts.append(add_header(part, text_header, html_header)) + if isinstance(part, Message): + new_parts.append(add_header(part, text_header, html_header)) + else: + new_parts.append(part) clone_msg = copy(msg) clone_msg.set_payload(new_parts) return clone_msg diff --git a/pytest.ini b/pytest.ini index 3d362baf..c0f5472c 100644 --- a/pytest.ini +++ b/pytest.ini @@ -1,5 +1,5 @@ [pytest] -addopts = +xaddopts = --cov --cov-config coverage.ini --cov-report=html:htmlcov diff --git a/tests/example_emls/multipart_alternative.eml b/tests/example_emls/multipart_alternative.eml new file mode 100644 index 00000000..27fa7d67 --- /dev/null +++ b/tests/example_emls/multipart_alternative.eml @@ -0,0 +1,25 @@ +Content-Type: multipart/alternative; boundary="===============5006593052976639648==" +MIME-Version: 1.0 +Subject: My subject +From: foo@example.org +To: bar@example.net + +--===============5006593052976639648== +Content-Type: text/plain; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + +This is HTML +--===============5006593052976639648== +Content-Type: text/html; charset="us-ascii" +MIME-Version: 1.0 +Content-Transfer-Encoding: 7bit + + + + This is HTML + + + +--===============5006593052976639648==-- + diff --git a/tests/test_email_handler.py b/tests/test_email_handler.py index 1f7cdd59..77517c5a 100644 --- a/tests/test_email_handler.py +++ b/tests/test_email_handler.py @@ -98,25 +98,24 @@ def test_dmarc_quarantine(flask_client): assert f"{alias.email} has a new mail in quarantine" == notifications[0].title -# todo: re-enable test when softfail is quarantined -# def test_gmail_dmarc_softfail(flask_client): -# user = create_random_user() -# alias = Alias.create_new_random(user) -# msg = load_eml_file("dmarc_gmail_softfail.eml", {"alias_email": alias.email}) -# envelope = Envelope() -# envelope.mail_from = msg["from"] -# envelope.rcpt_tos = [msg["to"]] -# result = email_handler.handle(envelope, msg) -# assert result == status.E215 -# email_logs = ( -# EmailLog.filter_by(user_id=user.id, alias_id=alias.id) -# .order_by(EmailLog.id.desc()) -# .all() -# ) -# assert len(email_logs) == 1 -# email_log = email_logs[0] -# assert email_log.blocked -# assert email_log.refused_email_id +def test_gmail_dmarc_softfail(flask_client): + user = create_random_user() + alias = Alias.create_new_random(user) + msg = load_eml_file("dmarc_gmail_softfail.eml", {"alias_email": alias.email}) + envelope = Envelope() + envelope.mail_from = msg["from"] + envelope.rcpt_tos = [msg["to"]] + result = email_handler.handle(envelope, msg) + assert result == status.E215 + email_logs = ( + EmailLog.filter_by(user_id=user.id, alias_id=alias.id) + .order_by(EmailLog.id.desc()) + .all() + ) + assert len(email_logs) == 1 + email_log = email_logs[0] + assert email_log.blocked + assert email_log.refused_email_id def test_prevent_5xx_from_spf(flask_client): diff --git a/tests/test_email_utils.py b/tests/test_email_utils.py index e5d16744..30c95ae6 100644 --- a/tests/test_email_utils.py +++ b/tests/test_email_utils.py @@ -823,3 +823,15 @@ def test_dmarc_result_na(): def test_dmarc_result_bad_policy(): msg = load_eml_file("dmarc_bad_policy.eml") assert DmarcCheckResult.bad_policy == get_spamd_result(msg).dmarc + + +def test_add_header_multipart_with_invalid_part(): + msg = load_eml_file("multipart_alternative.eml") + parts = msg.get_payload() + ["invalid"] + msg.set_payload(parts) + msg = add_header(msg, "INJECT", "INJECT") + for i, part in enumerate(msg.get_payload()): + if i < 2: + assert part.get_payload().index("INJECT") > -1 + else: + assert part == "invalid"