{-# LANGUAGE FlexibleContexts #-}

import AoC

import Control.Applicative
import Control.Monad
import Control.Monad.ST
import Control.Monad.Trans.Class
import Control.Monad.Trans.Except
import Control.Monad.Trans.Maybe
import Data.Array.ST
import Data.Maybe
import Text.ParserCombinators.ReadP
import Text.Read.Lex

-- | One instruction of the interpreted program; every constructor carries a
-- signed integer operand.
data Instruction
    = Nop Int  -- ^ Operand is ignored when stepping; the program counter advances by one.
    | Acc Int  -- ^ Adds the operand to the accumulator; the program counter advances by one.
    | Jmp Int  -- ^ Adds the operand to the program counter (relative jump).
    deriving (Show)
-- | A slot of program memory: the instruction stored there together with a
-- flag recording whether it has been run.
data MemoryCell = MemoryCell
    { instruction :: Instruction  -- ^ The instruction held in this slot.
    , executed    :: Bool         -- ^ Set to 'True' by 'stepMachine' when the slot is executed.
    }
    deriving (Show)
-- | Mutable program memory living in the 'ST' state thread @s@: an 'STArray'
-- indexed by 'Int' address, holding one 'MemoryCell' per address.
type Memory s = STArray s Int MemoryCell
-- | Complete interpreter state for one run inside state thread @s@.
--
-- 'pc' and 'acc' are strict: they are rewritten on every step, and a lazy
-- 'acc' would otherwise pile up a chain of unevaluated additions that is only
-- forced when the run finally stops.
data Machine s = Machine
    { pc     :: !Int      -- ^ Address of the next instruction to execute.
    , acc    :: !Int      -- ^ Running accumulator value.
    , memory :: Memory s  -- ^ Mutable program memory (instructions plus executed flags).
    }

-- | Reasons for which a run is stopped before executing another instruction.
data RunError
    = OutOfRange  -- ^ Presumably: the program counter points outside program memory.
    | WouldLoop   -- ^ The instruction at the program counter has already been executed.
    | Breakpoint  -- ^ The program counter reached the address given to 'breakOnAddr'.
    deriving (Show)

-- | A single-step function: given the current machine it either produces the
-- next machine state or aborts with a 'RunError', running in 'ExceptT' over 'ST'.
-- 'runUntilError' applies such a function repeatedly until it aborts.
type MachineRunner s = Machine s -> ExceptT RunError (ST s) (Machine s)

-- | Parser for one instruction: a mnemonic (@nop@, @acc@ or @jmp@), a single
-- space, then a decimal operand with an optional @+@ or @-@ sign.
parseInstruction :: ReadP Instruction
parseInstruction = choice [ ctor <$> operand mnemonic | (mnemonic, ctor) <- opcodes ]
    where opcodes = [("nop", Nop), ("acc", Acc), ("jmp", Jmp)]
          operand mnemonic = string mnemonic *> char ' ' *> signedNumber
          -- A missing sign leaves the number untouched, exactly like an explicit '+'.
          signedNumber = option id signModifier <*> readDecP
          signModifier = (id <$ char '+') +++ (negate <$ char '-')

-- | Allocate program memory for the given instructions, addressed from 0 to
-- @length insns - 1@, with every cell initially marked as not executed.
newMemory :: [Instruction] -> ST s (Memory s)
newMemory insns = newListArray (0, lastAddr) (map freshCell insns)
    where lastAddr = length insns - 1
          freshCell ins = MemoryCell { instruction = ins, executed = False }

-- | Build a machine for the given program: program counter and accumulator
-- both start at 0, and memory is freshly allocated by 'newMemory'.
newMachine :: [Instruction] -> ST s (Machine s)
newMachine insns = boot <$> newMemory insns
    where boot mem = Machine { pc = 0, acc = 0, memory = mem }

-- | Execute the instruction at the current program counter.
--
-- The cell at @pc m@ is marked as executed, then the returned machine has its
-- program counter and accumulator updated according to the instruction.
-- The memory access at @pc m@ is not guarded here.
stepMachine :: Machine s -> ST s (Machine s)
stepMachine m = do
    let addr = pc m
    cell <- readArray (memory m) addr
    writeArray (memory m) addr $ cell { executed = True }
    return $ case instruction cell of
               Nop _      -> m { pc = addr + 1 }
               Acc delta  -> m { pc = addr + 1, acc = acc m + delta }
               Jmp offset -> m { pc = addr + offset }

-- | Guard a stepper with a stop condition.
--
-- The condition @p@ is evaluated on the machine first: if it yields an error,
-- the run is aborted with that error via 'throwE'; if it yields nothing, the
-- wrapped stepper is applied to the machine.
runUnless :: (Machine s -> MaybeT (ST s) RunError) -> MachineRunner s -> MachineRunner s
runUnless p stepper m = lift (runMaybeT (p m)) >>= maybe (stepper m) throwE

-- | Wrap a stepper so that it refuses to execute an instruction twice.
--
-- Before delegating to the wrapped stepper, the cell at the current program
-- counter is inspected:
--
-- * if the program counter lies outside the memory bounds, the run is aborted
--   with 'OutOfRange' (reading the cell would otherwise fail with an array
--   index error and take the whole program down);
-- * if the cell has already been executed, the run is aborted with 'WouldLoop';
-- * otherwise the wrapped stepper runs.
breakOnLoop :: MachineRunner s -> MachineRunner s
breakOnLoop = runUnless loopOrOutOfRange
    where loopOrOutOfRange m =
            do (lo, hi) <- lift $ getBounds (memory m)
               if pc m < lo || pc m > hi
                  -- The bounds check must come first: 'readArray' errors on a bad index.
                  then return OutOfRange
                  else do cell <- lift $ readArray (memory m) (pc m)
                          unless (executed cell) empty
                          return WouldLoop

-- | Wrap a stepper with a breakpoint: when the program counter equals @addr@
-- the run is aborted with 'Breakpoint' instead of invoking the wrapped stepper.
breakOnAddr :: Int -> MachineRunner s -> MachineRunner s
breakOnAddr addr = runUnless atBreakpoint
    where atBreakpoint m
            | pc m == addr = return Breakpoint
            | otherwise    = empty

-- | Apply the stepper over and over until it aborts.
--
-- Returns the error that stopped the run together with the program counter
-- and accumulator of the last machine state, i.e. the state on which the
-- stepper failed.
runUntilError :: MachineRunner s -> Machine s -> ST s (RunError, Int, Int)
runUntilError step = loop
    where loop current = runExceptT (step current) >>= either (stopAt current) loop
          stopAt current err = return (err, pc current, acc current)

-- | Part 1: run the program until an instruction is about to execute a second
-- time and return the accumulator at that moment.
--
-- Any other way of stopping yields 'Nothing' internally, which makes
-- 'fromJust' fail.
part1 :: [Instruction] -> Int
part1 is = fromJust $ runST $ do
    outcome <- newMachine is >>= runUntilError (breakOnLoop (lift . stepMachine))
    return $ case outcome of
               (WouldLoop, _, total) -> Just total
               _                     -> Nothing

-- | Run the machine with loop detection and a breakpoint placed one address
-- past the upper memory bound.
--
-- Yields the accumulator when the run stops at that breakpoint, and fails
-- (in 'MaybeT') when it stops for any other reason.
tryTerminate :: Machine s -> MaybeT (ST s) Int
tryTerminate m = do
    (_, lastAddr) <- lift $ getBounds (memory m)
    let stepper = breakOnAddr (lastAddr + 1) . breakOnLoop $ lift . stepMachine
    outcome <- lift $ runUntilError stepper m
    case outcome of
      (Breakpoint, _, total) -> return total
      _                      -> empty

-- | All programs obtained by swapping exactly one 'Jmp' for a 'Nop' or one
-- 'Nop' for a 'Jmp', keeping the operand. Results are ordered by the position
-- of the swapped instruction; 'Acc' instructions are never changed.
jmpNopFlips :: [Instruction] -> [[Instruction]]
jmpNopFlips [] = []
jmpNopFlips (ins:rest) =
    [ swapped : rest | Just swapped <- [toggle ins] ] ++ map (ins :) (jmpNopFlips rest)
    where toggle (Jmp n) = Just (Nop n)
          toggle (Nop n) = Just (Jmp n)
          toggle (Acc _) = Nothing

-- | Part 2: find the first candidate program that terminates (see
-- 'tryTerminate') and return its final accumulator.
--
-- Candidates are the unmodified program followed by every single jmp/nop
-- swap from 'jmpNopFlips', tried in that order. Fails with an error when no
-- candidate terminates.
part2 :: [Instruction] -> Int
part2 instructions =
    fromMaybe (error "part2: no candidate program terminates") $ msum $ map attempt candidates
    where -- One 'runST' per candidate keeps the result list lazy, so 'msum'
          -- stops simulating as soon as the first candidate succeeds.
          attempt is = runST $ runMaybeT $ lift (newMachine is) >>= tryTerminate
          candidates = instructions : jmpNopFlips instructions

-- | Entry point: hands 'runAoC' a parser that splits the input into lines and
-- parses each one as an instruction, plus the two solvers. A line for which
-- 'oneCompleteResult' gives no result makes 'fromJust' fail.
main = runAoC parseProgram part1 part2
    where parseProgram input = map parseLine (lines input)
          parseLine line = fromJust (oneCompleteResult parseInstruction line)