diff --git a/2022/day4/rust/src/main.rs b/2022/day4/rust/src/main.rs index e004d1f..165f04d 100644 --- a/2022/day4/rust/src/main.rs +++ b/2022/day4/rust/src/main.rs @@ -10,7 +10,7 @@ fn main() { let pairs: Vec<_> = data .lines() - .map(|line| -> (SectionRange, SectionRange) { + .map(|line| -> (SectionRange, SectionRange) { let (left, right) = line.split_once(',').unwrap(); let left = left.parse().unwrap(); let right = right.parse().unwrap(); diff --git a/Cargo.lock b/Cargo.lock index 3819068..a9a0cc7 100644 --- a/Cargo.lock +++ b/Cargo.lock @@ -14,6 +14,9 @@ dependencies = [ [[package]] name = "aoc" version = "0.1.0" +dependencies = [ + "num-traits", +] [[package]] name = "arrayvec" diff --git a/common/rust/Cargo.toml b/common/rust/Cargo.toml index 706d2e1..00a20b8 100644 --- a/common/rust/Cargo.toml +++ b/common/rust/Cargo.toml @@ -6,3 +6,4 @@ edition = "2021" # See more keys and their definitions at https://doc.rust-lang.org/cargo/reference/manifest.html [dependencies] +num-traits = "0.2.15" diff --git a/common/rust/src/section_range.rs b/common/rust/src/section_range.rs index 83f3993..57293d9 100644 --- a/common/rust/src/section_range.rs +++ b/common/rust/src/section_range.rs @@ -1,11 +1,13 @@ use std::{ error::Error, fmt, - num::NonZeroU64, - ops::{BitAnd, BitOr, RangeInclusive, Sub}, + num::NonZeroUsize, + ops::{Add, BitAnd, BitOr, RangeInclusive, Sub}, str::FromStr, }; +use num_traits::{one, One}; + use crate::UpToTwo; /// Error returned when an attempt is made to construct an empty `SectionRange`. @@ -34,12 +36,12 @@ impl Error for InvalidSectionString {} /// A range of sections. Always contains at least one element. #[derive(Debug, Clone, Copy, PartialEq, Eq)] -pub struct SectionRange { - start: u64, - end: u64, +pub struct SectionRange { + start: T, + end: T, } -impl SectionRange { +impl SectionRange { /// Constructs a new section range from a start section and an end section, both inclusive. /// /// # Errors @@ -56,7 +58,7 @@ impl SectionRange { /// let range = SectionRange::try_new(3, 2); /// assert_eq!(range, Err(EmptyRange)); /// ``` - pub fn try_new(start: u64, end: u64) -> Result { + pub fn try_new(start: T, end: T) -> Result { if start <= end { Ok(Self { start, end }) } else { @@ -64,25 +66,16 @@ impl SectionRange { } } - /// Returns the number of sections contained by the range. Since the range always contains at - /// least one element, the length is never zero. - #[allow(clippy::missing_panics_doc)] - #[must_use] - pub fn len(&self) -> NonZeroU64 { - debug_assert!(self.start <= self.end); - NonZeroU64::new(self.end - self.start + 1).unwrap() - } - /// Returns true if and only if the range contains the given section. #[must_use] - pub fn contains(&self, section: u64) -> bool { + pub fn contains(&self, section: T) -> bool { debug_assert!(self.start <= self.end); section >= self.start && section <= self.end } /// Returns true if and only if the range contains the entirety of `other`. #[must_use] - pub fn encompasses(&self, other: &SectionRange) -> bool { + pub fn encompasses(&self, other: &Self) -> bool { let Some(intersection) = *self & *other else { return false; }; @@ -90,12 +83,23 @@ impl SectionRange { } } -impl FromStr for SectionRange { +impl, L: TryInto> SectionRange { + /// Returns the number of sections contained by the range. Since the range always contains at + /// least one element, the length is never zero. + #[allow(clippy::missing_panics_doc)] + #[must_use] + pub fn len(&self) -> NonZeroUsize { + debug_assert!(self.start <= self.end); + NonZeroUsize::new((self.end - self.start).try_into().ok().unwrap() + 1).unwrap() + } +} + +impl FromStr for SectionRange { type Err = InvalidSectionString; fn from_str(s: &str) -> Result { // poor man's try block - fn inner(s: &str) -> Option { + fn inner(s: &str) -> Option> { let (start, end) = s.split_once('-')?; let start = start.parse().ok()?; @@ -108,21 +112,21 @@ impl FromStr for SectionRange { } } -impl From for RangeInclusive { - fn from(r: SectionRange) -> Self { +impl From> for RangeInclusive { + fn from(r: SectionRange) -> Self { r.start..=r.end } } -impl TryFrom> for SectionRange { +impl TryFrom> for SectionRange { type Error = EmptyRange; - fn try_from(range: RangeInclusive) -> Result { + fn try_from(range: RangeInclusive) -> Result { Self::try_new(*range.start(), *range.end()) } } -impl BitAnd for SectionRange { +impl BitAnd for SectionRange { type Output = Option; fn bitand(self, other: Self) -> Self::Output { @@ -133,15 +137,15 @@ impl BitAnd for SectionRange { } } -impl BitOr for SectionRange { +impl BitOr for SectionRange { type Output = UpToTwo; fn bitor(self, other: Self) -> Self::Output { - let first_start = u64::min(self.start, other.start); - let first_end = u64::min(self.end, other.end); + let first_start = T::min(self.start, other.start); + let first_end = T::min(self.end, other.end); - let second_start = u64::max(self.start, other.start); - let second_end = u64::max(self.end, other.end); + let second_start = T::max(self.start, other.start); + let second_end = T::max(self.end, other.end); if first_end < second_start { let first = Self { @@ -162,7 +166,7 @@ impl BitOr for SectionRange { } } -impl Sub for SectionRange { +impl + Add + One> Sub for SectionRange { type Output = UpToTwo; fn sub(self, other: Self) -> Self::Output { @@ -173,10 +177,10 @@ impl Sub for SectionRange { // Closures to prevent integer overflow let first = || Self { start: self.start, - end: u64::min(self.end, other.start - 1), + end: T::min(self.end, other.start - one()), }; let second = || Self { - start: u64::max(self.start, other.end + 1), + start: T::max(self.start, other.end + one()), end: self.end, }; @@ -213,7 +217,11 @@ mod tests { #[test] fn section_range_intersection() { - fn check_intersection(a: SectionRange, b: SectionRange, expected: Option) { + fn check_intersection( + a: SectionRange, + b: SectionRange, + expected: Option>, + ) { let x = a & b; let y = b & a; @@ -247,11 +255,11 @@ mod tests { #[test] fn section_range_operations() { - fn elements(r: SectionRange) -> Vec { + fn elements(r: SectionRange) -> Vec { RangeInclusive::<_>::from(r).collect() } - fn elements_multi(r: UpToTwo) -> Vec { + fn elements_multi(r: UpToTwo>) -> Vec { let is_empty = r.is_empty(); let sections: Vec<_> = r.into_iter().flat_map(elements).collect(); @@ -288,7 +296,7 @@ mod tests { // len() assert_eq!( left.len().get(), - *range_left.end() - *range_left.start() + 1 + usize::try_from(*range_left.end() - *range_left.start() + 1).unwrap() ); // contains()