From b2d4ac8e657574e932002536bec3c2bc712c00ce Mon Sep 17 00:00:00 2001 From: Son NK Date: Sun, 11 Aug 2019 11:56:10 +0200 Subject: [PATCH] add get_response_types_from_str, response_types_to_str --- app/oauth_models.py | 15 ++++++++++++++- tests/test_oauth_models.py | 27 ++++++++++++++++++++++++++- 2 files changed, 40 insertions(+), 2 deletions(-) diff --git a/app/oauth_models.py b/app/oauth_models.py index 4dde6baf..86aba04c 100644 --- a/app/oauth_models.py +++ b/app/oauth_models.py @@ -26,7 +26,20 @@ def get_scopes(request: flask.Request) -> Set[Scope]: def get_response_types(request: flask.Request) -> Set[ResponseType]: response_type_strs = _split_arg(request.args.getlist("response_type")) - return set([ResponseType(r) for r in response_type_strs]) + return set([ResponseType(r) for r in response_type_strs if r]) + + +def get_response_types_from_str(response_type_str) -> Set[ResponseType]: + response_type_strs = _split_arg(response_type_str) + + return set([ResponseType(r) for r in response_type_strs if r]) + + +def response_types_to_str(response_types: [ResponseType]) -> str: + """return a string representing a list of response type, for ex + *code*, *id_token,token*,... + """ + return ",".join([r.value for r in response_types]) def _split_arg(arg_input: Union[str, list]) -> Set[str]: diff --git a/tests/test_oauth_models.py b/tests/test_oauth_models.py index 619a5f2d..425db221 100644 --- a/tests/test_oauth_models.py +++ b/tests/test_oauth_models.py @@ -1,7 +1,14 @@ import flask import pytest -from app.oauth_models import get_scopes, Scope, get_response_types, ResponseType +from app.oauth_models import ( + get_scopes, + Scope, + get_response_types, + ResponseType, + response_types_to_str, + get_response_types_from_str, +) def test_get_scopes(flask_app): @@ -52,3 +59,21 @@ def test_get_response_types(flask_app): with flask_app.test_request_context("/?response_type=abcd"): with pytest.raises(ValueError): get_response_types(flask.request) + + +def test_response_types_to_str(): + assert response_types_to_str([]) == "" + assert response_types_to_str([ResponseType.CODE]) == "code" + assert ( + response_types_to_str([ResponseType.CODE, ResponseType.ID_TOKEN]) + == "code,id_token" + ) + + +def test_get_response_types_from_str(): + assert get_response_types_from_str("") == set() + assert get_response_types_from_str("token") == {ResponseType.TOKEN} + assert get_response_types_from_str("token id_token") == { + ResponseType.TOKEN, + ResponseType.ID_TOKEN, + }