advent-of-code/2021/day16/day16_rs/src/main.rs

200 lines
5.7 KiB
Rust
Raw Normal View History

2021-12-16 18:48:50 +01:00
#![warn(clippy::pedantic)]
use std::{
io::{stdin, Read},
ops::ControlFlow,
};
use nom::{
bits::complete as bits,
combinator::flat_map,
multi::{many0, many_m_n},
};
use nom::{combinator::map, sequence::pair};
use parsers::fold_till;
mod parsers;
/// Bit-level parser input: the remaining bytes plus a bit offset (0..8) into the
/// first of those bytes.
type Input<'a> = (&'a [u8], usize);
/// `nom::IResult` specialised to the bit-level [`Input`] above.
type IResult<'a, T> = nom::IResult<Input<'a>, T>;
/// One BITS packet: a 3-bit version number followed by a type-specific body.
#[derive(Clone, Debug, PartialEq, Eq)]
struct Packet {
    /// The 3-bit version field read at the start of the packet.
    version: u8,
    /// The packet body: a literal value or an operator with sub-packets.
    typ: PacketType,
}

impl Packet {
    /// Parses a single packet: three version bits, then a [`PacketType`] body.
    pub fn parse(i: Input) -> IResult<Packet> {
        let version_then_body = pair(bits::take(3_usize), PacketType::parse);
        map(version_then_body, |(version, typ)| Self { version, typ })(i)
    }

    /// Returns the sum of the version numbers of this packet and of every
    /// packet nested inside it, recursively.
    pub fn version_sum(&self) -> usize {
        let nested: usize = match &self.typ {
            PacketType::Operation { sub_packets, .. } => {
                sub_packets.iter().map(Self::version_sum).sum()
            }
            PacketType::Literal(_) => 0,
        };
        nested + usize::from(self.version)
    }
}
#[derive(Clone, Debug, PartialEq, Eq)]
enum PacketType {
Literal(usize),
Operation {
operator: Operator,
sub_packets: Vec<Packet>,
},
}
impl PacketType {
pub fn parse(input: Input) -> IResult<PacketType> {
flat_map(bits::take(3_usize), |type_id: u8| {
move |i| match Operator::try_from(type_id) {
Ok(operator) => map(Self::operation_inner, |sub_packets| PacketType::Operation {
operator,
sub_packets,
})(i),
Err(_) => map(Self::literal_inner, PacketType::Literal)(i),
}
})(input)
}
fn operation_inner(i: Input) -> IResult<Vec<Packet>> {
enum LengthType {
Bits(u16),
Packets(u16),
}
flat_map(
flat_map(bits::take(1_usize), |length_type_id: u8| {
move |i| match length_type_id {
0 => map(bits::take(15_usize), LengthType::Bits)(i),
1 => map(bits::take(11_usize), LengthType::Packets)(i),
_ => unreachable!(),
}
}),
|length_type| {
move |i| match length_type {
LengthType::Packets(n) => {
many_m_n(usize::from(n), usize::from(n), Packet::parse)(i)
}
LengthType::Bits(n) => {
// map_parser(recognize(bits::take(n)), many1(Packet::parse))(i)
let extra_bytes_required = (usize::from(n) + i.1) / 8;
let extra_bits_required = (usize::from(n) + i.1) % 8;
let subpackets_slice = (&i.0[..=extra_bytes_required], i.1);
let (subpackets_end, subpackets) = many0(Packet::parse)(subpackets_slice)?;
if subpackets_end.0.len() > 1 {
todo!()
}
Ok((
(&i.0[extra_bytes_required..], extra_bits_required),
subpackets,
))
}
}
},
)(i)
}
fn literal_inner(i: Input) -> IResult<usize> {
fold_till(
pair(bits::take(1_usize), bits::take(4_usize)),
|| 0,
|acc, (marker, bits): (u8, usize)| {
(if marker == 1 {
ControlFlow::Continue
} else {
ControlFlow::Break
})((acc << 4) | bits)
},
)(i)
}
pub fn evaluate(&self) -> usize {
match self {
Self::Literal(n) => *n,
Self::Operation {
ref operator,
ref sub_packets,
} => operator.evaluate(sub_packets.iter().map(|p| p.typ.evaluate())),
}
}
}
/// The operation an operator packet applies to the values of its sub-packets.
#[derive(Clone, Copy, Debug, PartialEq, Eq)]
enum Operator {
    /// Type ID 0: sum of all operands.
    Sum,
    /// Type ID 1: product of all operands.
    Product,
    /// Type ID 2: smallest operand.
    Minimum,
    /// Type ID 3: largest operand.
    Maximum,
    /// Type ID 5: 1 if the first operand exceeds the second, otherwise 0.
    GreaterThan,
    /// Type ID 6: 1 if the first operand is below the second, otherwise 0.
    LessThan,
    /// Type ID 7: 1 if the first two operands are equal, otherwise 0.
    EqualTo,
}

impl Operator {
    /// Applies this operator to `operands` and returns the result.
    ///
    /// # Panics
    ///
    /// `Minimum`/`Maximum` panic on an empty iterator; the comparison
    /// operators panic if fewer than two operands are supplied.
    pub fn evaluate(self, mut operands: impl Iterator<Item = usize>) -> usize {
        /// Takes the next two operands (left first) and encodes `test` as 0/1.
        fn binary<I: Iterator<Item = usize>>(
            operands: &mut I,
            test: fn(usize, usize) -> bool,
        ) -> usize {
            let lhs = operands.next().unwrap();
            let rhs = operands.next().unwrap();
            usize::from(test(lhs, rhs))
        }

        match self {
            Self::Sum => operands.sum(),
            Self::Product => operands.product(),
            Self::Minimum => operands.reduce(usize::min).unwrap(),
            Self::Maximum => operands.reduce(usize::max).unwrap(),
            Self::GreaterThan => binary(&mut operands, |a, b| a > b),
            Self::LessThan => binary(&mut operands, |a, b| a < b),
            Self::EqualTo => binary(&mut operands, |a, b| a == b),
        }
    }
}

impl TryFrom<u8> for Operator {
    type Error = ();

    /// Maps a packet type ID to its operator; ID 4 (literal) and any value
    /// above 7 yield `Err(())`.
    fn try_from(value: u8) -> Result<Self, Self::Error> {
        // Indexed by type ID; the gap at 4 is the literal packet type.
        let by_type_id = [
            Some(Self::Sum),
            Some(Self::Product),
            Some(Self::Minimum),
            Some(Self::Maximum),
            None,
            Some(Self::GreaterThan),
            Some(Self::LessThan),
            Some(Self::EqualTo),
        ];
        by_type_id
            .get(usize::from(value))
            .copied()
            .flatten()
            .ok_or(())
    }
}
fn main() {
let bytes: Vec<u8> = stdin()
.lock()
.bytes()
.filter_map(|c| char::from(c.unwrap()).to_digit(16))
.map(|c| {
#[allow(clippy::cast_possible_truncation)] // a hex digit always fits in a u8
let c = c as u8;
c
})
.scan(None, |prev, n| {
Some(match *prev {
Some(i) => {
*prev = None;
Some(i | n)
}
None => prev.replace(n << 4),
})
})
.flatten()
.collect();
let packet = Packet::parse((&bytes, 0)).unwrap().1;
println!("{}", packet.version_sum());
println!("{}", packet.typ.evaluate());
}