diff --git a/app/auth/views/proton.py b/app/auth/views/proton.py index 5150f945..53ab5dbd 100644 --- a/app/auth/views/proton.py +++ b/app/auth/views/proton.py @@ -10,6 +10,8 @@ from app.config import ( PROTON_BASE_URL, PROTON_CLIENT_ID, PROTON_CLIENT_SECRET, + PROTON_EXTRA_HEADER_NAME, + PROTON_EXTRA_HEADER_VALUE, PROTON_VALIDATE_CERTS, URL, ) @@ -89,6 +91,11 @@ def proton_callback(): return response proton.register_compliance_hook("access_token_response", check_status_code) + + headers = None + if PROTON_EXTRA_HEADER_NAME and PROTON_EXTRA_HEADER_VALUE: + headers = {PROTON_EXTRA_HEADER_NAME: PROTON_EXTRA_HEADER_VALUE} + token = proton.fetch_token( _token_url, client_secret=PROTON_CLIENT_SECRET, @@ -96,6 +103,7 @@ def proton_callback(): verify=PROTON_VALIDATE_CERTS, method="GET", include_client_id=True, + headers=headers, ) credentials = convert_access_token(token["access_token"]) action = get_action_from_state() diff --git a/app/config.py b/app/config.py index a228ece9..a7ef9e37 100644 --- a/app/config.py +++ b/app/config.py @@ -244,6 +244,8 @@ PROTON_BASE_URL = os.environ.get( ) PROTON_VALIDATE_CERTS = "PROTON_VALIDATE_CERTS" in os.environ CONNECT_WITH_PROTON = "CONNECT_WITH_PROTON" in os.environ +PROTON_EXTRA_HEADER_NAME = os.environ.get("PROTON_EXTRA_HEADER_NAME") +PROTON_EXTRA_HEADER_VALUE = os.environ.get("PROTON_EXTRA_HEADER_VALUE") # in seconds AVATAR_URL_EXPIRATION = 3600 * 24 * 7 # 1h*24h/d*7d=1week diff --git a/app/proton/proton_client.py b/app/proton/proton_client.py index 82282d0d..9f4beac5 100644 --- a/app/proton/proton_client.py +++ b/app/proton/proton_client.py @@ -6,6 +6,7 @@ from requests import Response, Session from typing import Optional from app.account_linking import SLPlan, SLPlanType +from app.config import PROTON_EXTRA_HEADER_NAME, PROTON_EXTRA_HEADER_VALUE from app.log import LOG _APP_VERSION = "OauthClient_1.0.0" @@ -82,6 +83,10 @@ class HttpProtonClient(ProtonClient): "accept": "application/vnd.protonmail.v1+json", "user-agent": "ProtonOauthClient", } + + if PROTON_EXTRA_HEADER_NAME and PROTON_EXTRA_HEADER_VALUE: + headers[PROTON_EXTRA_HEADER_NAME] = PROTON_EXTRA_HEADER_VALUE + if original_ip is not None: headers["x-forwarded-for"] = original_ip client.headers.update(headers)