create get_spam_score() as a sync function, use a simpler version for running MailHandler. Remove async/await

This commit is contained in:
Son NK 2020-09-30 11:05:21 +02:00
parent 91e3cc5dcb
commit abc42df0fb
1 changed files with 75 additions and 40 deletions

View File

@ -49,6 +49,7 @@ import aiosmtpd
import aiospamc
import arrow
import spf
from aiosmtpd.controller import Controller
from aiosmtpd.smtp import Envelope
from sqlalchemy.exc import IntegrityError
@ -109,6 +110,7 @@ from app.models import (
Mailbox,
)
from app.pgp_utils import PGPException
from app.spamassassin_utils import SpamAssassin
from app.utils import random_string
from init_app import load_pgp_public_keys
from server import create_app, create_light_app
@ -436,9 +438,7 @@ def handle_email_sent_to_ourself(alias, mailbox, msg: Message, user):
)
async def handle_forward(
envelope, msg: Message, rcpt_to: str
) -> List[Tuple[bool, str]]:
def handle_forward(envelope, msg: Message, rcpt_to: str) -> List[Tuple[bool, str]]:
"""return an array of SMTP status (is_success, smtp_status)
is_success indicates whether an email has been delivered and
smtp_status is the SMTP Status ("250 Message accepted", "550 Non-existent email address", etc)
@ -490,7 +490,7 @@ async def handle_forward(
return [(False, "550 SL E18 unverified mailbox")]
else:
ret.append(
await forward_email_to_mailbox(
forward_email_to_mailbox(
alias, msg, email_log, contact, envelope, mailbox, user
)
)
@ -502,7 +502,7 @@ async def handle_forward(
ret.append((False, "550 SL E19 unverified mailbox"))
else:
ret.append(
await forward_email_to_mailbox(
forward_email_to_mailbox(
alias,
copy(msg),
email_log,
@ -516,7 +516,7 @@ async def handle_forward(
return ret
async def forward_email_to_mailbox(
def forward_email_to_mailbox(
alias,
msg: Message,
email_log: EmailLog,
@ -566,7 +566,7 @@ async def forward_email_to_mailbox(
if SPAMASSASSIN_HOST:
start = time.time()
spam_score = await get_spam_score(msg)
spam_score = get_spam_score(msg)
LOG.d(
"%s -> %s - spam score %s in %s seconds",
contact,
@ -684,7 +684,7 @@ async def forward_email_to_mailbox(
return True, "250 Message accepted for delivery"
async def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
"""
return whether an email has been delivered and
the smtp status ("250 Message accepted", "550 Non-existent email address", etc)
@ -762,7 +762,7 @@ async def handle_reply(envelope, msg: Message, rcpt_to: str) -> (bool, str):
# do not use user.max_spam_score here
if SPAMASSASSIN_HOST:
start = time.time()
spam_score = await get_spam_score(msg)
spam_score = get_spam_score(msg)
LOG.d(
"%s -> %s - spam score %s in %s seconds",
alias,
@ -1418,7 +1418,7 @@ def handle_sender_email(envelope: Envelope):
return "250 email to sender accepted"
async def handle(envelope: Envelope) -> str:
def handle(envelope: Envelope) -> str:
"""Return SMTP status"""
# sanitize mail_from, rcpt_tos
@ -1455,7 +1455,7 @@ async def handle(envelope: Envelope) -> str:
# recipient starts with "reply+" or "ra+" (ra=reverse-alias) prefix
if rcpt_to.startswith("reply+") or rcpt_to.startswith("ra+"):
LOG.debug(">>> Reply phase %s(%s) -> %s", mail_from, msg["From"], rcpt_to)
is_delivered, smtp_status = await handle_reply(envelope, msg, rcpt_to)
is_delivered, smtp_status = handle_reply(envelope, msg, rcpt_to)
res.append((is_delivered, smtp_status))
else: # Forward case
LOG.debug(
@ -1464,9 +1464,7 @@ async def handle(envelope: Envelope) -> str:
msg["From"],
rcpt_to,
)
for is_delivered, smtp_status in await handle_forward(
envelope, msg, rcpt_to
):
for is_delivered, smtp_status in handle_forward(envelope, msg, rcpt_to):
res.append((is_delivered, smtp_status))
for (is_success, smtp_status) in res:
@ -1478,7 +1476,7 @@ async def handle(envelope: Envelope) -> str:
return res[0][1]
async def get_spam_score(message: Message) -> float:
async def get_spam_score_async(message: Message) -> float:
LOG.debug("get spam score for %s", message[_MESSAGE_ID])
sa_input = to_bytes(message)
@ -1502,6 +1500,24 @@ async def get_spam_score(message: Message) -> float:
return -999
def get_spam_score(message: Message) -> float:
LOG.debug("get spam score for %s", message[_MESSAGE_ID])
sa_input = to_bytes(message)
# Spamassassin requires to have an ending linebreak
if not sa_input.endswith(b"\n"):
LOG.d("add linebreak to spamassassin input")
sa_input += b"\n"
try:
sa = SpamAssassin(sa_input, host=SPAMASSASSIN_HOST)
return sa.get_score()
except Exception:
LOG.exception("SpamAssassin exception")
# return a negative score so the message is always considered as ham
return -999
def sl_sendmail(from_addr, to_addr, msg: Message, mail_options, rcpt_options):
"""replace smtp.sendmail"""
if POSTFIX_SUBMISSION_TLS:
@ -1522,12 +1538,9 @@ def sl_sendmail(from_addr, to_addr, msg: Message, mail_options, rcpt_options):
class MailHandler:
def __init__(self, lock):
self.lock = lock
async def handle_DATA(self, server, session, envelope: Envelope):
try:
ret = await self._handle(envelope)
ret = self._handle(envelope)
return ret
except Exception:
LOG.exception(
@ -1537,31 +1550,42 @@ class MailHandler:
)
return "421 SL Retry later"
async def _handle(self, envelope: Envelope):
async with self.lock:
start = time.time()
LOG.info(
"===>> New message, mail from %s, rctp tos %s ",
envelope.mail_from,
envelope.rcpt_tos,
)
def _handle(self, envelope: Envelope):
start = time.time()
LOG.info(
"===>> New message, mail from %s, rctp tos %s ",
envelope.mail_from,
envelope.rcpt_tos,
)
app = new_app()
with app.app_context():
ret = await handle(envelope)
LOG.info("takes %s seconds <<===", time.time() - start)
return ret
app = new_app()
with app.app_context():
ret = handle(envelope)
LOG.info("takes %s seconds <<===", time.time() - start)
return ret
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--port", help="SMTP port to listen for", type=int, default=20381
)
args = parser.parse_args()
def main(port: int):
"""Use aiosmtpd Controller"""
controller = Controller(MailHandler(), hostname="0.0.0.0", port=port)
LOG.info("Listen for port %s", args.port)
controller.start()
LOG.d("Start mail controller %s %s", controller.hostname, controller.port)
if LOAD_PGP_EMAIL_HANDLER:
LOG.warning("LOAD PGP keys")
app = create_app()
with app.app_context():
load_pgp_public_keys()
while True:
time.sleep(2)
def asyncio_main(port: int):
"""
Main entrypoint using asyncio directly without passing by aiosmtpd Controller
"""
if LOAD_PGP_EMAIL_HANDLER:
LOG.warning("LOAD PGP keys")
app = create_app()
@ -1577,7 +1601,7 @@ if __name__ == "__main__":
return aiosmtpd.smtp.SMTP(handler, enable_SMTPUTF8=True)
server = loop.run_until_complete(
loop.create_server(factory, host="0.0.0.0", port=args.port)
loop.create_server(factory, host="0.0.0.0", port=port)
)
try:
@ -1590,3 +1614,14 @@ if __name__ == "__main__":
server.close()
loop.run_until_complete(server.wait_closed())
loop.close()
if __name__ == "__main__":
parser = argparse.ArgumentParser()
parser.add_argument(
"-p", "--port", help="SMTP port to listen for", type=int, default=20381
)
args = parser.parse_args()
LOG.info("Listen for port %s", args.port)
main(port=args.port)