diff --git a/src/utils.rs b/src/utils.rs index d88a42f..078a6f0 100644 --- a/src/utils.rs +++ b/src/utils.rs @@ -93,33 +93,60 @@ pub fn detect_media_type_by_file_name(filename: &str) -> String { } pub fn domain_is_within_domain(domain: &str, domain_to_match_against: &str) -> bool { - let domain_partials: Vec<&str> = domain.split(".").collect(); + if domain_to_match_against.len() == 0 { + return false; + } + + if domain_to_match_against == "." { + return true; + } + + let domain_partials: Vec<&str> = domain.trim_end_matches(".").rsplit(".").collect(); let domain_to_match_against_partials: Vec<&str> = domain_to_match_against - .trim_start_matches(".") - .split(".") + .trim_end_matches(".") + .rsplit(".") .collect(); + let domain_to_match_against_starts_with_a_dot = domain_to_match_against.starts_with("."); - let mut i: usize = domain_partials.len(); - let mut j: usize = domain_to_match_against_partials.len(); + let mut i: usize = 0; + let l: usize = std::cmp::max( + domain_partials.len(), + domain_to_match_against_partials.len(), + ); + let mut ok: bool = true; - if i >= j { - while j > 0 { - if !domain_partials - .get(i - 1) - .unwrap() - .eq_ignore_ascii_case(&domain_to_match_against_partials.get(j - 1).unwrap()) - { - break; - } - - i -= 1; - j -= 1; + while i < l { + // Exit and return false if went out of bounds of domain to match against, and it didn't start with a dot + if domain_to_match_against_partials.len() < i + 1 + && !domain_to_match_against_starts_with_a_dot + { + ok = false; + break; } - j == 0 - } else { - false + let domain_partial = if domain_partials.len() < i + 1 { + "" + } else { + domain_partials.get(i).unwrap() + }; + let domain_to_match_against_partial = if domain_to_match_against_partials.len() < i + 1 { + "" + } else { + domain_to_match_against_partials.get(i).unwrap() + }; + + let parts_match = domain_to_match_against_starts_with_a_dot + || domain_to_match_against_partial.eq_ignore_ascii_case(domain_partial); + + if !parts_match { + ok = false; + break; + } + + i += 1; } + + ok } pub fn indent(level: u32) -> String { diff --git a/tests/utils/domain_is_within_domain.rs b/tests/utils/domain_is_within_domain.rs index 7eaeae3..fcd0840 100644 --- a/tests/utils/domain_is_within_domain.rs +++ b/tests/utils/domain_is_within_domain.rs @@ -25,14 +25,6 @@ mod passing { )); } - #[test] - fn dotted_domain_is_within_domain() { - assert!(utils::domain_is_within_domain( - ".ycombinator.com", - "ycombinator.com" - )); - } - #[test] fn sub_domain_is_within_dotted_domain() { assert!(utils::domain_is_within_domain( @@ -67,18 +59,12 @@ mod passing { #[test] fn domain_with_trailing_dot_is_within_single_dot() { - assert!(utils::domain_is_within_domain( - "ycombinator.com.", - "." - )); + assert!(utils::domain_is_within_domain("ycombinator.com.", ".")); } #[test] fn domain_matches_single_dot() { - assert!(utils::domain_is_within_domain( - "ycombinator.com", - "." - )); + assert!(utils::domain_is_within_domain("ycombinator.com", ".")); } #[test] @@ -140,6 +126,14 @@ mod failing { )); } + #[test] + fn dotted_domain_is_not_within_domain() { + assert!(!utils::domain_is_within_domain( + ".ycombinator.com", + "ycombinator.com" + )); + } + #[test] fn no_domain_can_be_within_empty_domain() { assert!(!utils::domain_is_within_domain("ycombinator.com", ""));