#![warn(clippy::pedantic)] use std::{ io::{stdin, Read}, ops::ControlFlow, }; use nom::{ bits::complete as bits, multi::{many0, many_m_n}, error::ParseError, Offset, }; use nom::{combinator::map, sequence::pair}; use parsers::fold_till; mod parsers; type Input<'a> = (&'a [u8], usize); type IResult<'a, T, E> = nom::IResult, T, E>; #[derive(Clone, Debug, PartialEq, Eq)] struct Packet { version: u8, typ: PacketType, } impl Packet { pub fn parse<'a, E: ParseError>>(i: Input<'a>) -> IResult<'a, Packet, E> { map( pair(bits::take(3_usize), PacketType::parse), |(version, typ)| Packet { version, typ }, )(i) } pub fn version_sum(&self) -> usize { usize::from(self.version) + match self.typ { PacketType::Literal(_) => 0, PacketType::Operation { ref sub_packets, .. } => sub_packets.iter().map(Packet::version_sum).sum(), } } } #[derive(Clone, Debug, PartialEq, Eq)] enum PacketType { Literal(usize), Operation { operator: Operator, sub_packets: Vec, }, } impl PacketType { pub fn parse<'a, E: ParseError>>(i: Input<'a>) -> IResult<'a, PacketType, E> { let (i, operator) = map(bits::take(3_usize), |type_id: u8| { Operator::try_from(type_id) })(i)?; match operator { Ok(operator) => map(Self::parse_sub_packets, |sub_packets| { PacketType::Operation { operator, sub_packets, } })(i), Err(_) => map(Self::parse_literal_value, PacketType::Literal)(i), } } fn parse_sub_packets<'a, E: ParseError>>(i: Input<'a>) -> IResult<'a, Vec, E> { enum LengthType { Bits(usize), Packets(usize), } impl LengthType { pub fn parse<'a, E: ParseError>>(i: Input<'a>) -> IResult<'a, Self, E> { let (i, length_type_id) = bits::take(1_usize)(i)?; match length_type_id { 0 => map(bits::take(15_usize), LengthType::Bits)(i), 1 => map(bits::take(11_usize), LengthType::Packets)(i), _ => unreachable!(), } } } let (i, length_type) = LengthType::parse(i)?; match length_type { LengthType::Packets(n) => many_m_n(n, n, Packet::parse)(i), LengthType::Bits(n) => { // manual implementation of something like the following: // map_parser(recognize(bits::take(n)), many0(Packet::parse))(i) let new_byte_offset = (n + i.1) / 8; let new_bit_offset = (n + i.1) % 8; let subpackets_input = (&i.0[..=new_byte_offset], i.1); let (subpackets_end, subpackets) = many0(Packet::parse)(subpackets_input)?; let new_input = (&i.0[new_byte_offset..], new_bit_offset); assert_eq!(i.0.offset(subpackets_end.0), i.0.offset(new_input.0)); assert_eq!(subpackets_end.1, new_input.1); Ok(( new_input, subpackets, )) } } } fn parse_literal_value<'a, E: ParseError>>(i: Input<'a>) -> IResult<'a, usize, E> { fold_till( pair(bits::take(1_usize), bits::take(4_usize)), || 0, |acc, (marker, bits): (u8, usize)| { (if marker == 1 { ControlFlow::Continue } else { ControlFlow::Break })((acc << 4) | bits) }, )(i) } pub fn evaluate(&self) -> usize { match self { Self::Literal(n) => *n, Self::Operation { operator, sub_packets, } => operator.evaluate(sub_packets.iter().map(|p| p.typ.evaluate())), } } } #[derive(Clone, Copy, Debug, PartialEq, Eq)] enum Operator { Sum, Product, Minimum, Maximum, GreaterThan, LessThan, EqualTo, } impl Operator { pub fn evaluate(self, mut operands: impl Iterator) -> usize { match self { Self::Sum => operands.sum(), Self::Product => operands.product(), Self::Minimum => operands.min().unwrap(), Self::Maximum => operands.max().unwrap(), Self::GreaterThan => usize::from(operands.next().unwrap() > operands.next().unwrap()), Self::LessThan => usize::from(operands.next().unwrap() < operands.next().unwrap()), Self::EqualTo => usize::from(operands.next().unwrap() == operands.next().unwrap()), } } } impl TryFrom for Operator { type Error = (); fn try_from(value: u8) -> Result { match value { 0 => Ok(Self::Sum), 1 => Ok(Self::Product), 2 => Ok(Self::Minimum), 3 => Ok(Self::Maximum), 5 => Ok(Self::GreaterThan), 6 => Ok(Self::LessThan), 7 => Ok(Self::EqualTo), _ => Err(()), } } } fn main() { let bytes: Vec = stdin() .lock() .bytes() .filter_map(|c| char::from(c.unwrap()).to_digit(16)) .map(|c| { #[allow(clippy::cast_possible_truncation)] // a hex digit always fits in a u8 let c = c as u8; c }) .scan(None, |prev, n| { Some(match *prev { Some(i) => { *prev = None; Some(i | n) } None => prev.replace(n << 4), }) }) .flatten() .collect(); let packet = Packet::parse::>((&bytes, 0)).unwrap().1; println!("{}", packet.version_sum()); println!("{}", packet.typ.evaluate()); }