advent-of-code/common/rust/src/section_range.rs

345 lines
10 KiB
Rust

use std::{
error::Error,
fmt,
num::NonZeroUsize,
ops::{Add, BitAnd, BitOr, RangeInclusive, Sub},
str::FromStr,
};
use num_traits::{one, One};
use crate::UpToTwo;
/// Error produced when a `SectionRange` would be empty, i.e. its start lies after its end.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct EmptyRange;

impl fmt::Display for EmptyRange {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        f.write_str("Range is empty")
    }
}

impl Error for EmptyRange {}
/// Error produced when a string cannot be parsed as a `SectionRange`. It carries the
/// offending input verbatim.
#[derive(Debug, Clone, PartialEq, Eq)]
pub struct InvalidSectionString(String);

impl fmt::Display for InvalidSectionString {
    fn fmt(&self, f: &mut fmt::Formatter<'_>) -> fmt::Result {
        let Self(input) = self;
        write!(f, "Invalid section range: {input}")
    }
}

impl Error for InvalidSectionString {}
/// An inclusive range of sections that is never empty.
///
/// Invariant: `start <= end`, so every value contains at least one section.
#[derive(Debug, Clone, Copy, PartialEq, Eq)]
pub struct SectionRange<T> {
    /// First section of the range (inclusive).
    start: T,
    /// Last section of the range (inclusive).
    end: T,
}
impl<T: Ord + Copy> SectionRange<T> {
    /// Creates a section range that runs from `start` to `end`, both ends included.
    ///
    /// # Errors
    ///
    /// Fails with [`EmptyRange`] when `start > end`, because such a range would hold no
    /// sections.
    ///
    /// # Examples
    ///
    /// ```rust
    /// # use aoc::{EmptyRange, SectionRange};
    /// let range = SectionRange::try_new(3, 10);
    /// assert!(range.is_ok());
    ///
    /// let range = SectionRange::try_new(3, 2);
    /// assert_eq!(range, Err(EmptyRange));
    /// ```
    pub fn try_new(start: T, end: T) -> Result<Self, EmptyRange> {
        if start > end {
            return Err(EmptyRange);
        }
        Ok(Self { start, end })
    }

    /// Reports whether `section` lies between the range's start and end, inclusive.
    #[must_use]
    pub fn contains(&self, section: T) -> bool {
        debug_assert!(self.start <= self.end);
        self.start <= section && section <= self.end
    }

    /// Reports whether every section of `other` also lies within this range.
    #[must_use]
    pub fn encompasses(&self, other: &Self) -> bool {
        // `other` is non-empty, so it fits inside `self` exactly when both of its bounds do.
        self.start <= other.start && other.end <= self.end
    }
}
impl<T: Ord + Copy + Sub<Output = L>, L: TryInto<usize>> SectionRange<T> {
    /// Returns the number of sections contained by the range.
    ///
    /// Since the range always contains at least one element, the length is never zero.
    ///
    /// # Panics
    ///
    /// Panics if the section count cannot be represented as a `usize`. This happens when
    /// `end - start` does not convert into a `usize`, or when it equals `usize::MAX`, so that
    /// adding one for the inclusive end overflows (e.g. `0..=u64::MAX` on a 64-bit target).
    /// Depending on `T`, computing `end - start` itself may also panic on overflow.
    #[must_use]
    pub fn len(&self) -> NonZeroUsize {
        debug_assert!(self.start <= self.end);
        // `.ok()` rather than `?`-style handling: the conversion error type is not
        // required to implement `Debug`, so it cannot be passed to `expect` directly.
        let span: usize = (self.end - self.start)
            .try_into()
            .ok()
            .expect("section range span does not fit in usize");
        // Checked add so the overflow fails identically in debug and release builds,
        // instead of wrapping to 0 in release.
        span.checked_add(1)
            .and_then(NonZeroUsize::new)
            .expect("section range length overflows usize")
    }
}
impl<T: Ord + Copy + FromStr> FromStr for SectionRange<T> {
    type Err = InvalidSectionString;

    /// Parses a range written as `start-end`, e.g. `"2-4"`.
    ///
    /// The separator is the first `-` after the first character. For a signed `T`, a negative
    /// start bound such as `"-5--3"` is therefore accepted.
    ///
    /// # Errors
    ///
    /// Returns [`InvalidSectionString`] holding the input in any of these cases:
    ///
    /// - there is no separator;
    /// - either bound fails to parse as `T`;
    /// - the range would be empty (`start > end`).
    fn from_str(s: &str) -> Result<Self, Self::Err> {
        // poor man's try block
        fn inner<T: Ord + Copy + FromStr>(s: &str) -> Option<SectionRange<T>> {
            // A '-' in the first position would leave an empty start bound, so it can only be
            // the start's sign. Begin the separator search at the second character.
            let (sep, _) = s.char_indices().skip(1).find(|&(_, c)| c == '-')?;
            // '-' is a single byte, so `sep + 1` is always a char boundary.
            let start = s[..sep].parse().ok()?;
            let end = s[sep + 1..].parse().ok()?;
            SectionRange::try_from(start..=end).ok()
        }
        inner(s).ok_or_else(|| InvalidSectionString(s.to_owned()))
    }
}
impl<T> From<SectionRange<T>> for RangeInclusive<T> {
fn from(r: SectionRange<T>) -> Self {
r.start..=r.end
}
}
impl<T: Ord + Copy> TryFrom<RangeInclusive<T>> for SectionRange<T> {
    type Error = EmptyRange;

    /// Converts a `start..=end` range, failing with [`EmptyRange`] if it is empty.
    fn try_from(range: RangeInclusive<T>) -> Result<Self, Self::Error> {
        let (start, end) = range.into_inner();
        Self::try_new(start, end)
    }
}
impl<T: Ord + Copy> BitAnd for SectionRange<T> {
    type Output = Option<Self>;

    /// Intersection: the sections present in both ranges, or `None` when they are disjoint.
    fn bitand(self, other: Self) -> Self::Output {
        let lo = T::max(self.start, other.start);
        let hi = T::min(self.end, other.end);
        // Disjoint inputs give `lo > hi`, which `try_new` rejects as empty.
        Self::try_new(lo, hi).ok()
    }
}
impl<T: Ord + Copy> BitOr for SectionRange<T> {
    type Output = UpToTwo<Self>;

    /// Union of two ranges.
    ///
    /// If the ranges share at least one section, the result is a single merged range.
    /// Otherwise both ranges are returned, lower one first. Ranges that merely touch
    /// (e.g. `1..=3` and `4..=6`) are not merged.
    fn bitor(self, other: Self) -> Self::Output {
        let (lo_start, hi_start) = (self.start.min(other.start), self.start.max(other.start));
        let (lo_end, hi_end) = (self.end.min(other.end), self.end.max(other.end));

        if lo_end >= hi_start {
            return UpToTwo::One(Self {
                start: lo_start,
                end: hi_end,
            });
        }

        // Disjoint: the range that starts first also ends first.
        UpToTwo::Two(
            Self {
                start: lo_start,
                end: lo_end,
            },
            Self {
                start: hi_start,
                end: hi_end,
            },
        )
    }
}
impl<T: Ord + Copy + Sub<Output = T> + Add<Output = T> + One> Sub for SectionRange<T> {
    type Output = UpToTwo<Self>;

    /// Difference: the sections of `self` that are not in `other`, returned as zero, one
    /// or two ranges in ascending order.
    fn sub(self, other: Self) -> Self::Output {
        // Does `other` reach our first section, our last section, or both?
        let covers_front = other.start <= self.start;
        let covers_back = other.end >= self.end;

        // Built lazily. `other.start - 1` is only evaluated when `other.start > self.start`,
        // and `other.end + 1` only when `other.end < self.end`. Neither can step outside
        // the values of `T` that the ranges already use.
        let head = || Self {
            start: self.start,
            end: self.end.min(other.start - T::one()),
        };
        let tail = || Self {
            start: self.start.max(other.end + T::one()),
            end: self.end,
        };

        match (covers_front, covers_back) {
            (true, true) => UpToTwo::Zero,
            (false, true) => UpToTwo::One(head()),
            (true, false) => UpToTwo::One(tail()),
            (false, false) => UpToTwo::Two(head(), tail()),
        }
    }
}
#[cfg(test)]
mod tests {
    use super::*;

    #[test]
    fn section_range_construction() {
        let range = SectionRange::try_from(3..=10).unwrap();
        assert_eq!(range, SectionRange { start: 3, end: 10 });
        assert_eq!(range.len().get(), 8);

        let range = SectionRange::try_from(3..=3).unwrap();
        assert_eq!(range, SectionRange { start: 3, end: 3 });
        assert_eq!(range.len().get(), 1);

        #[allow(clippy::reversed_empty_ranges)]
        let range = SectionRange::try_from(3..=2);
        assert_eq!(range, Err(EmptyRange));
    }

    #[test]
    fn section_range_parsing() {
        assert_eq!(
            "3-10".parse::<SectionRange<u64>>(),
            Ok(SectionRange { start: 3, end: 10 })
        );
        assert_eq!(
            "7-7".parse::<SectionRange<u64>>(),
            Ok(SectionRange { start: 7, end: 7 })
        );
        // Missing separators, unparsable bounds and empty ranges are all rejected. The error
        // carries the original input.
        for bad in ["", "3", "3-", "-3", "a-5", "5-3", "3-5-7"] {
            assert_eq!(
                bad.parse::<SectionRange<u64>>(),
                Err(InvalidSectionString(bad.to_owned()))
            );
        }
    }

    #[test]
    fn section_range_intersection() {
        fn check_intersection(
            a: SectionRange<u64>,
            b: SectionRange<u64>,
            expected: Option<SectionRange<u64>>,
        ) {
            let x = a & b;
            let y = b & a;
            assert_eq!(x, y, "intersection not commutative");
            assert_eq!(x, expected);
        }

        let range1 = SectionRange::try_from(3..=5).unwrap();
        let range2 = SectionRange::try_from(4..=10).unwrap();
        let range3 = SectionRange::try_from(6..=9).unwrap();
        check_intersection(range1, range2, Some(SectionRange { start: 4, end: 5 }));
        check_intersection(range1, range3, None);
        check_intersection(range2, range3, Some(range3));

        // self-intersection is always encompassing
        assert!(range1.encompasses(&range1));
        assert!(range2.encompasses(&range2));
        assert!(range3.encompasses(&range3));

        // only 2 includes all of 3
        assert!(!range1.encompasses(&range2));
        assert!(!range1.encompasses(&range3));
        assert!(!range2.encompasses(&range1));
        assert!(range2.encompasses(&range3));
        assert!(!range3.encompasses(&range1));
        assert!(!range3.encompasses(&range2));
    }

    #[test]
    fn section_range_operations() {
        fn elements(r: SectionRange<u64>) -> Vec<u64> {
            RangeInclusive::<_>::from(r).collect()
        }

        // Flattens a multi-range result and checks its pieces are in strictly ascending order.
        fn elements_multi(r: UpToTwo<SectionRange<u64>>) -> Vec<u64> {
            let is_empty = r.is_empty();
            let sections: Vec<_> = r.into_iter().flat_map(elements).collect();
            if !is_empty {
                let mut it = sections.iter().copied();
                let mut prev_section = it.next().unwrap();
                for section in it {
                    assert!(section > prev_section);
                    prev_section = section;
                }
            }
            sections
        }

        // Exhaustively compare every pair of non-empty ranges within 0..=5 against a
        // brute-force set model.
        let ranges = (0..=5).flat_map(|start| (start..=5).map(move |end| start..=end));
        for range_left in ranges.clone() {
            for range_right in ranges.clone() {
                // SETUP
                println!("testing {:?} against {:?}", range_left, range_right);
                let left = SectionRange::try_from(range_left.clone()).unwrap();
                let right = SectionRange::try_from(range_right.clone()).unwrap();
                let e_left: Vec<_> = elements(left);
                let e_right: Vec<_> = elements(right);

                // TESTS
                assert_eq!(RangeInclusive::<_>::from(left), range_left);
                assert_eq!(RangeInclusive::<_>::from(right), range_right);

                // len()
                assert_eq!(
                    left.len().get(),
                    usize::try_from(*range_left.end() - *range_left.start() + 1).unwrap()
                );

                // contains()
                for i in 0..=10 {
                    assert_eq!(
                        left.contains(i),
                        i >= *range_left.start() && i <= *range_left.end()
                    );
                }

                // encompasses()
                assert!(left.encompasses(&left));
                assert_eq!(
                    left.encompasses(&right),
                    e_right.iter().all(|e| e_left.contains(e))
                );

                // intersection
                let e_intersection: Vec<_> = e_right
                    .iter()
                    .copied()
                    .filter(|e| e_left.contains(e))
                    .collect();
                match left & right {
                    None => assert_eq!(e_intersection, []),
                    Some(intersection) => assert_eq!(elements(intersection), e_intersection),
                }

                // union
                let mut e_union: Vec<_> = e_left.iter().chain(e_right.iter()).copied().collect();
                e_union.sort_unstable();
                e_union.dedup();
                assert_eq!(elements_multi(left | right), e_union);

                // difference
                let e_difference: Vec<_> = e_left
                    .iter()
                    .copied()
                    .filter(|e| !e_right.contains(e))
                    .collect();
                assert_eq!(elements_multi(left - right), e_difference);
            }
        }
    }
}