{-# LANGUAGE NamedFieldPuns #-}

module Day14 where

import AoC

import Data.Bits
import Data.List
import Data.Map (Map)
import qualified Data.Map as M
import Data.Maybe
import Data.Tuple
import Data.Word
import Text.ParserCombinators.ReadP
import Text.Read.Lex

-- | A parsed @mask = ...@ line, split into three bit sets. As built by the
-- parser, each bit position belongs to exactly one of the three sets.
data Mask = Mask
    { set      :: Word64  -- ^ positions written as @1@ in the mask string
    , clear    :: Word64  -- ^ positions written as @0@ in the mask string
    , dontcare :: Word64  -- ^ positions written as @X@ in the mask string
    }
    deriving (Eq, Show)

-- | A single @mem[address] = value@ write instruction.
data Store = Store
    { address :: Word64  -- ^ the bracketed number after @mem@
    , value   :: Word64  -- ^ the number after @=@
    }
    deriving (Show)

-- | Parse a @mask = ...@ line (without its trailing newline) into a 'Mask'.
--
-- Bit positions are counted from the right-hand end of the mask string, so
-- the last character of the mask controls bit 0.
parseMaskDecl :: ReadP Mask
parseMaskDecl = do
    _ <- string "mask = "
    -- munch1 takes the whole run of mask characters in one go; many1 would
    -- also offer every shorter prefix as an alternative parse that can only
    -- be rejected later by the caller.
    bits <- munch1 (`elem` "01X")
    -- foldl' avoids accumulating a chain of lazy setBit thunks.
    let positionsOf c = foldl' setBit zeroBits . findIndices (== c) . reverse $ bits
    return Mask { set = positionsOf '1'
                , clear = positionsOf '0'
                , dontcare = positionsOf 'X'
                }

-- | Parse a @mem[address] = value@ line (without its trailing newline); both
-- numbers are read with 'readDecP'.
parseStoreInstruction :: ReadP Store
parseStoreInstruction =
    Store <$> (string "mem" *> between (char '[') (char ']') readDecP)
          <*> (string " = " *> readDecP)

-- | Parse one newline-terminated mask line followed by zero or more
-- newline-terminated write instructions governed by that mask.
parseGroup :: ReadP (Mask, [Store])
parseGroup =
    (,) <$> (parseMaskDecl <* char '\n')
        <*> (parseStoreInstruction `endBy` char '\n')

-- | Perform the writes in order against an empty memory and sum the values
-- left behind. When an address is written more than once, 'M.fromList' keeps
-- the last value, so later writes win.
runStores :: [Store] -> Word64
runStores stores = sum memory
    where memory = M.fromList [(address s, value s) | s <- stores]

-- | Force the mask's 'set' bits to 1 and its 'clear' bits to 0, leaving all
-- other bits of the input unchanged. 'dontcare' is ignored.
applyMask :: Mask -> Word64 -> Word64
applyMask m n = complement (clear m) .&. (set m .|. n)

-- | Part 1 decoder: run every written value through the mask; addresses are
-- left untouched.
maskValues :: Mask -> [Store] -> [Store]
maskValues mask stores = [s { value = applyMask mask (value s) } | s <- stores]

-- | Part 2 decoder: replace every write with writes to all addresses the mask
-- can produce. The written values are left as-is.
--
-- @1@ bits force the address bit to 1. @0@ bits leave it unchanged, which is
-- why 'clear' is zeroed before expansion. Each @X@ bit is tried both as 1 and
-- as 0 via 'evalDontcare'.
maskAddresses :: Mask -> [Store] -> [Store]
maskAddresses mask = concatMap expand
    where
        -- The concrete masks do not depend on the store. Binding them here
        -- builds them once per call instead of once per store.
        variants = evalDontcare (mask { clear = 0 })
        expand store@Store{address} =
            [store { address = applyMask m address } | m <- variants]

-- | Resolve every 'dontcare' bit both ways. Each X bit is moved into 'set' in
-- one copy and into 'clear' in another, giving @2^k@ masks for @k@ X bits,
-- all with 'dontcare' equal to 0.
--
-- The lowest X bit varies slowest in the output. The first mask has every X
-- bit set and the last has every X bit cleared. A mask with no X bits is
-- returned unchanged as a singleton.
evalDontcare :: Mask -> [Mask]
evalDontcare start = foldl' expand [start] floating
    where
        floatingBits = dontcare start
        -- Positions of the X bits, lowest first.
        floating = filter (testBit floatingBits) [0 .. finiteBitSize floatingBits - 1]
        expand masks i = concatMap (resolve i) masks
        resolve i m =
            [ settled { set = set settled `setBit` i }
            , settled { clear = clear settled `setBit` i }
            ]
            where settled = m { dontcare = dontcare m `clearBit` i }

-- | Part 1: each group's mask rewrites the values it governs; then the final
-- memory contents are summed.
part1 :: [(Mask, [Store])] -> Word64
part1 groups = runStores [s | (mask, stores) <- groups, s <- maskValues mask stores]

-- | Part 2: each group's mask expands the addresses it governs; then the final
-- memory contents are summed.
part2 :: [(Mask, [Store])] -> Word64
part2 groups = runStores [s | (mask, stores) <- groups, s <- maskAddresses mask stores]

-- | Parse the puzzle input as a sequence of mask groups and hand it, together
-- with both part solvers, to 'runAoC'.
main = runAoC readGroups part1 part2
    where
        -- Fail with a descriptive message instead of fromJust's generic
        -- "Maybe.fromJust: Nothing" when oneCompleteResult yields Nothing.
        readGroups = fromMaybe (error "Day14: could not parse input")
                   . oneCompleteResult (many parseGroup)