mirror of
https://github.com/simple-login/app.git
synced 2024-11-14 08:01:13 +01:00
Merge pull request #1019 from simple-login/feature/proton-callback-receive-partner_id-as-param
Receive partner as param in ProtonCallbackHandler
This commit is contained in:
commit
687b51be0f
4 changed files with 80 additions and 63 deletions
|
@ -14,7 +14,11 @@ from app.config import (
|
|||
URL,
|
||||
)
|
||||
from app.proton.proton_client import HttpProtonClient, convert_access_token
|
||||
from app.proton.proton_callback_handler import ProtonCallbackHandler, Action
|
||||
from app.proton.proton_callback_handler import (
|
||||
ProtonCallbackHandler,
|
||||
Action,
|
||||
get_proton_partner,
|
||||
)
|
||||
from app.utils import sanitize_next_url
|
||||
|
||||
_authorization_base_url = PROTON_BASE_URL + "/oauth/authorize"
|
||||
|
@ -100,11 +104,12 @@ def proton_callback():
|
|||
PROTON_BASE_URL, credentials, get_remote_address(), verify=PROTON_VALIDATE_CERTS
|
||||
)
|
||||
handler = ProtonCallbackHandler(proton_client)
|
||||
proton_partner = get_proton_partner()
|
||||
|
||||
if action == Action.Login:
|
||||
res = handler.handle_login()
|
||||
res = handler.handle_login(proton_partner)
|
||||
elif action == Action.Link:
|
||||
res = handler.handle_link(current_user)
|
||||
res = handler.handle_link(current_user, proton_partner)
|
||||
else:
|
||||
raise Exception(f"Unknown Action: {action.name}")
|
||||
|
||||
|
|
|
@ -48,7 +48,7 @@ from app.models import (
|
|||
AppleSubscription,
|
||||
PartnerUser,
|
||||
)
|
||||
from app.proton.proton_callback_handler import get_proton_partner_id
|
||||
from app.proton.proton_callback_handler import get_proton_partner
|
||||
from app.utils import random_string, sanitize_email
|
||||
|
||||
|
||||
|
@ -70,7 +70,7 @@ class PromoCodeForm(FlaskForm):
|
|||
def get_proton_linked_account() -> Optional[str]:
|
||||
# Check if the current user has a partner_id
|
||||
try:
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
proton_partner_id = get_proton_partner().id
|
||||
except ProtonPartnerNotSetUp:
|
||||
return None
|
||||
|
||||
|
@ -444,7 +444,7 @@ def unlink_proton_account():
|
|||
current_user.partner_id = None
|
||||
current_user.partner_user_id = None
|
||||
partner_user = PartnerUser.get_by(
|
||||
user_id=current_user.id, partner_id=get_proton_partner_id()
|
||||
user_id=current_user.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
if partner_user is not None:
|
||||
PartnerUser.delete(partner_user.id)
|
||||
|
|
|
@ -11,18 +11,18 @@ from app.proton.proton_client import ProtonClient, ProtonUser
|
|||
from app.utils import random_string
|
||||
|
||||
PROTON_PARTNER_NAME = "Proton"
|
||||
_PROTON_PARTNER_ID: Optional[int] = None
|
||||
_PROTON_PARTNER: Optional[Partner] = None
|
||||
|
||||
|
||||
def get_proton_partner_id() -> int:
|
||||
global _PROTON_PARTNER_ID
|
||||
if _PROTON_PARTNER_ID is None:
|
||||
def get_proton_partner() -> Partner:
|
||||
global _PROTON_PARTNER
|
||||
if _PROTON_PARTNER is None:
|
||||
partner = Partner.get_by(name=PROTON_PARTNER_NAME)
|
||||
if partner is None:
|
||||
raise ProtonPartnerNotSetUp
|
||||
_PROTON_PARTNER_ID = partner.id
|
||||
|
||||
return _PROTON_PARTNER_ID
|
||||
Session.expunge(partner)
|
||||
_PROTON_PARTNER = partner
|
||||
return _PROTON_PARTNER
|
||||
|
||||
|
||||
class Action(enum.Enum):
|
||||
|
@ -39,23 +39,27 @@ class ProtonCallbackResult:
|
|||
user: Optional[User]
|
||||
|
||||
|
||||
def ensure_partner_user_exists(proton_user: ProtonUser, sl_user: User):
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=proton_partner_id):
|
||||
def ensure_partner_user_exists(
|
||||
proton_user: ProtonUser, sl_user: User, partner: Partner
|
||||
):
|
||||
if not PartnerUser.get_by(user_id=sl_user.id, partner_id=partner.id):
|
||||
PartnerUser.create(
|
||||
user_id=sl_user.id,
|
||||
partner_id=proton_partner_id,
|
||||
partner_id=partner.id,
|
||||
partner_email=proton_user.email,
|
||||
)
|
||||
Session.commit()
|
||||
|
||||
|
||||
class ClientMergeStrategy(ABC):
|
||||
def __init__(self, proton_user: ProtonUser, sl_user: Optional[User]):
|
||||
def __init__(
|
||||
self, proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
|
||||
):
|
||||
if self.__class__ == ClientMergeStrategy:
|
||||
raise RuntimeError("Cannot directly instantiate a ClientMergeStrategy")
|
||||
self.proton_user = proton_user
|
||||
self.sl_user = sl_user
|
||||
self.partner = partner
|
||||
|
||||
@abstractmethod
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
|
@ -65,17 +69,16 @@ class ClientMergeStrategy(ABC):
|
|||
class UnexistantSlClientStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
# Will create a new SL User with a random password
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
new_user = User.create(
|
||||
email=self.proton_user.email,
|
||||
name=self.proton_user.name,
|
||||
partner_user_id=self.proton_user.id,
|
||||
partner_id=proton_partner_id,
|
||||
partner_id=self.partner.id,
|
||||
password=random_string(20),
|
||||
)
|
||||
PartnerUser.create(
|
||||
user_id=new_user.id,
|
||||
partner_id=proton_partner_id,
|
||||
partner_id=self.partner.id,
|
||||
partner_email=self.proton_user.email,
|
||||
)
|
||||
# TODO: Adjust plans
|
||||
|
@ -92,7 +95,7 @@ class UnexistantSlClientStrategy(ClientMergeStrategy):
|
|||
|
||||
class ExistingSlClientStrategy(ClientMergeStrategy):
|
||||
def process(self) -> ProtonCallbackResult:
|
||||
ensure_partner_user_exists(self.proton_user, self.sl_user)
|
||||
ensure_partner_user_exists(self.proton_user, self.sl_user, self.partner)
|
||||
# TODO: Adjust plans
|
||||
|
||||
return ProtonCallbackResult(
|
||||
|
@ -127,52 +130,54 @@ class AlreadyLinkedUserStrategy(ClientMergeStrategy):
|
|||
|
||||
|
||||
def get_login_strategy(
|
||||
proton_user: ProtonUser, sl_user: Optional[User]
|
||||
proton_user: ProtonUser, sl_user: Optional[User], partner: Partner
|
||||
) -> ClientMergeStrategy:
|
||||
if sl_user is None:
|
||||
# We couldn't find any SimpleLogin user with the requested e-mail
|
||||
return UnexistantSlClientStrategy(proton_user, sl_user)
|
||||
return UnexistantSlClientStrategy(proton_user, sl_user, partner)
|
||||
# There is a SimpleLogin user with the proton_user's e-mail
|
||||
# Try to find if it has been registered via a partner
|
||||
if sl_user.partner_id is None:
|
||||
# It has not been registered via a Partner
|
||||
return ExistingSlClientStrategy(proton_user, sl_user)
|
||||
return ExistingSlClientStrategy(proton_user, sl_user, partner)
|
||||
# It has been registered via a partner
|
||||
# Check if the partner_user_id matches
|
||||
if sl_user.partner_user_id != proton_user.id:
|
||||
# It doesn't match. That means that the SimpleLogin user has a different Proton account linked
|
||||
return ExistingSlUserLinkedWithDifferentProtonAccountStrategy(
|
||||
proton_user, sl_user
|
||||
proton_user, sl_user, partner
|
||||
)
|
||||
# This case means that the sl_user is already linked, so nothing to do
|
||||
return AlreadyLinkedUserStrategy(proton_user, sl_user)
|
||||
return AlreadyLinkedUserStrategy(proton_user, sl_user, partner)
|
||||
|
||||
|
||||
def process_login_case(proton_user: ProtonUser) -> ProtonCallbackResult:
|
||||
def process_login_case(
|
||||
proton_user: ProtonUser, partner: Partner
|
||||
) -> ProtonCallbackResult:
|
||||
# Try to find a SimpleLogin user registered with that proton user id
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
sl_user_with_external_id = User.get_by(
|
||||
partner_id=proton_partner_id, partner_user_id=proton_user.id
|
||||
partner_id=partner.id, partner_user_id=proton_user.id
|
||||
)
|
||||
if sl_user_with_external_id is None:
|
||||
# We didn't find any SimpleLogin user registered with that proton user id
|
||||
# Try to find it using the proton's e-mail address
|
||||
sl_user = User.get_by(email=proton_user.email)
|
||||
return get_login_strategy(proton_user, sl_user).process()
|
||||
return get_login_strategy(proton_user, sl_user, partner).process()
|
||||
else:
|
||||
# We found the SL user registered with that proton user id
|
||||
# We're done
|
||||
return AlreadyLinkedUserStrategy(
|
||||
proton_user, sl_user_with_external_id
|
||||
proton_user, sl_user_with_external_id, partner
|
||||
).process()
|
||||
|
||||
|
||||
def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResult:
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
def link_user(
|
||||
proton_user: ProtonUser, current_user: User, partner: Partner
|
||||
) -> ProtonCallbackResult:
|
||||
current_user.partner_user_id = proton_user.id
|
||||
current_user.partner_id = proton_partner_id
|
||||
current_user.partner_id = partner.id
|
||||
|
||||
ensure_partner_user_exists(proton_user, current_user)
|
||||
ensure_partner_user_exists(proton_user, current_user, partner)
|
||||
|
||||
Session.commit()
|
||||
return ProtonCallbackResult(
|
||||
|
@ -185,16 +190,17 @@ def link_user(proton_user: ProtonUser, current_user: User) -> ProtonCallbackResu
|
|||
|
||||
|
||||
def process_link_case(
|
||||
proton_user: ProtonUser, current_user: User
|
||||
proton_user: ProtonUser,
|
||||
current_user: User,
|
||||
partner: Partner,
|
||||
) -> ProtonCallbackResult:
|
||||
# Try to find a SimpleLogin user linked with this Proton account
|
||||
proton_partner_id = get_proton_partner_id()
|
||||
sl_user_linked_to_proton_account = User.get_by(
|
||||
partner_id=proton_partner_id, partner_user_id=proton_user.id
|
||||
partner_id=partner.id, partner_user_id=proton_user.id
|
||||
)
|
||||
if sl_user_linked_to_proton_account is None:
|
||||
# There is no SL user linked with the proton email. Proceed with linking
|
||||
return link_user(proton_user, current_user)
|
||||
return link_user(proton_user, current_user, partner)
|
||||
else:
|
||||
# There is a SL user registered with the proton email. Check if is the current one
|
||||
if sl_user_linked_to_proton_account.id == current_user.id:
|
||||
|
@ -212,25 +218,27 @@ def process_link_case(
|
|||
sl_user_linked_to_proton_account.partner_user_id = None
|
||||
other_partner_user = PartnerUser.get_by(
|
||||
user_id=sl_user_linked_to_proton_account.id,
|
||||
partner_id=proton_partner_id,
|
||||
partner_id=partner.id,
|
||||
)
|
||||
if other_partner_user is not None:
|
||||
PartnerUser.delete(other_partner_user.id)
|
||||
|
||||
return link_user(proton_user, current_user)
|
||||
return link_user(proton_user, current_user, partner)
|
||||
|
||||
|
||||
class ProtonCallbackHandler:
|
||||
def __init__(self, proton_client: ProtonClient):
|
||||
self.proton_client = proton_client
|
||||
|
||||
def handle_login(self) -> ProtonCallbackResult:
|
||||
return process_login_case(self.__get_proton_user())
|
||||
def handle_login(self, partner: Partner) -> ProtonCallbackResult:
|
||||
return process_login_case(self.__get_proton_user(), partner)
|
||||
|
||||
def handle_link(self, current_user: Optional[User]) -> ProtonCallbackResult:
|
||||
def handle_link(
|
||||
self, current_user: Optional[User], partner: Partner
|
||||
) -> ProtonCallbackResult:
|
||||
if current_user is None:
|
||||
raise Exception("Cannot link account with current_user being None")
|
||||
return process_link_case(self.__get_proton_user(), current_user)
|
||||
return process_link_case(self.__get_proton_user(), current_user, partner)
|
||||
|
||||
def __get_proton_user(self) -> ProtonUser:
|
||||
user = self.proton_client.get_user()
|
||||
|
|
|
@ -4,7 +4,7 @@ from app.db import Session
|
|||
from app.proton.proton_client import ProtonClient, UserInformation, ProtonPlan
|
||||
from app.proton.proton_callback_handler import (
|
||||
ProtonCallbackHandler,
|
||||
get_proton_partner_id,
|
||||
get_proton_partner,
|
||||
get_login_strategy,
|
||||
process_link_case,
|
||||
ProtonUser,
|
||||
|
@ -65,11 +65,11 @@ def create_user(email: str = None) -> User:
|
|||
def create_user_for_partner(partner_user_id: str, email: str = None) -> User:
|
||||
email = email if email is not None else random_email()
|
||||
user = User.create(email=email)
|
||||
user.partner_id = get_proton_partner_id()
|
||||
user.partner_id = get_proton_partner().id
|
||||
user.partner_user_id = partner_user_id
|
||||
|
||||
PartnerUser.create(
|
||||
user_id=user.id, partner_id=get_proton_partner_id(), partner_email=email
|
||||
user_id=user.id, partner_id=get_proton_partner().id, partner_email=email
|
||||
)
|
||||
Session.commit()
|
||||
return user
|
||||
|
@ -84,7 +84,7 @@ def test_proton_callback_handler_unexistant_sl_user():
|
|||
user=user, plan=ProtonPlan.Professional, organization={}
|
||||
)
|
||||
handler = ProtonCallbackHandler(mock_client)
|
||||
res = handler.handle_login()
|
||||
res = handler.handle_login(get_proton_partner())
|
||||
|
||||
assert res.user is not None
|
||||
assert res.user.email == email
|
||||
|
@ -102,12 +102,12 @@ def test_proton_callback_handler_existant_sl_user():
|
|||
user=user, plan=ProtonPlan.Professional, organization={}
|
||||
)
|
||||
handler = ProtonCallbackHandler(mock_client)
|
||||
res = handler.handle_login()
|
||||
res = handler.handle_login(get_proton_partner())
|
||||
|
||||
assert res.user is not None
|
||||
assert res.user.id == sl_user.id
|
||||
|
||||
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner_id())
|
||||
sa = PartnerUser.get_by(user_id=sl_user.id, partner_id=get_proton_partner().id)
|
||||
assert sa is not None
|
||||
assert sa.partner_email == user.email
|
||||
|
||||
|
@ -116,6 +116,7 @@ def test_get_strategy_unexistant_sl_user():
|
|||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(),
|
||||
sl_user=None,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, UnexistantSlClientStrategy)
|
||||
|
||||
|
@ -126,6 +127,7 @@ def test_get_strategy_existing_sl_user():
|
|||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(email=email),
|
||||
sl_user=sl_user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, ExistingSlClientStrategy)
|
||||
|
||||
|
@ -137,6 +139,7 @@ def test_get_strategy_already_linked_user():
|
|||
strategy = get_login_strategy(
|
||||
proton_user=random_proton_user(user_id=proton_user_id, email=email),
|
||||
sl_user=sl_user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, AlreadyLinkedUserStrategy)
|
||||
|
||||
|
@ -159,6 +162,7 @@ def test_get_strategy_existing_sl_user_linked_with_different_proton_account():
|
|||
strategy = get_login_strategy(
|
||||
proton_user=proton_user_1,
|
||||
sl_user=sl_user,
|
||||
partner=get_proton_partner(),
|
||||
)
|
||||
assert isinstance(strategy, ExistingSlUserLinkedWithDifferentProtonAccountStrategy)
|
||||
|
||||
|
@ -179,14 +183,14 @@ def test_link_account_with_proton_account_same_address(flask_client):
|
|||
proton_user = random_proton_user(user_id=proton_user_id, email=email)
|
||||
sl_user = create_user(email)
|
||||
|
||||
res = process_link_case(proton_user, sl_user)
|
||||
res = process_link_case(proton_user, sl_user, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user = User.get(sl_user.id)
|
||||
assert updated_user.partner_id == get_proton_partner_id()
|
||||
assert updated_user.partner_id == get_proton_partner().id
|
||||
assert updated_user.partner_user_id == proton_user_id
|
||||
|
||||
|
||||
|
@ -199,14 +203,14 @@ def test_link_account_with_proton_account_different_address(flask_client):
|
|||
proton_user = random_proton_user(user_id=proton_user_id, email=random_email())
|
||||
sl_user = create_user()
|
||||
|
||||
res = process_link_case(proton_user, sl_user)
|
||||
res = process_link_case(proton_user, sl_user, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user = User.get(sl_user.id)
|
||||
assert updated_user.partner_id == get_proton_partner_id()
|
||||
assert updated_user.partner_id == get_proton_partner().id
|
||||
assert updated_user.partner_user_id == proton_user_id
|
||||
|
||||
|
||||
|
@ -226,14 +230,14 @@ def test_link_account_with_proton_account_same_address_but_linked_to_other_user(
|
|||
proton_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(proton_user, sl_user_1)
|
||||
res = process_link_case(proton_user, sl_user_1, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user_1 = User.get(sl_user_1.id)
|
||||
assert updated_user_1.partner_id == get_proton_partner_id()
|
||||
assert updated_user_1.partner_id == get_proton_partner().id
|
||||
assert updated_user_1.partner_user_id == proton_user_id
|
||||
|
||||
updated_user_2 = User.get(sl_user_2.id)
|
||||
|
@ -256,17 +260,17 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_
|
|||
proton_user_id, email=random_email()
|
||||
) # User already linked with the proton account
|
||||
|
||||
res = process_link_case(proton_user, sl_user_1)
|
||||
res = process_link_case(proton_user, sl_user_1, get_proton_partner())
|
||||
assert res.redirect_to_login is False
|
||||
assert res.redirect is not None
|
||||
assert res.flash_category == "success"
|
||||
assert res.flash_message is not None
|
||||
|
||||
updated_user_1 = User.get(sl_user_1.id)
|
||||
assert updated_user_1.partner_id == get_proton_partner_id()
|
||||
assert updated_user_1.partner_id == get_proton_partner().id
|
||||
assert updated_user_1.partner_user_id == proton_user_id
|
||||
partner_user_1 = PartnerUser.get_by(
|
||||
user_id=sl_user_1.id, partner_id=get_proton_partner_id()
|
||||
user_id=sl_user_1.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
assert partner_user_1 is not None
|
||||
assert partner_user_1.partner_email == proton_user.email
|
||||
|
@ -275,11 +279,11 @@ def test_link_account_with_proton_account_different_address_and_linked_to_other_
|
|||
assert updated_user_2.partner_id is None
|
||||
assert updated_user_2.partner_user_id is None
|
||||
partner_user_2 = PartnerUser.get_by(
|
||||
user_id=sl_user_2.id, partner_id=get_proton_partner_id()
|
||||
user_id=sl_user_2.id, partner_id=get_proton_partner().id
|
||||
)
|
||||
assert partner_user_2 is None
|
||||
|
||||
|
||||
def test_cannot_create_instance_of_base_strategy():
|
||||
with pytest.raises(Exception):
|
||||
ClientMergeStrategy(random_proton_user(), None)
|
||||
ClientMergeStrategy(random_proton_user(), None, get_proton_partner())
|
||||
|
|
Loading…
Reference in a new issue