{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

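{- |
Module: Control.RTree

Reduction trees and reducers over them. 'RTree' records every choice as an
explicit binary node. 'IRTree' reads the outcome of each choice from a path
of booleans instead. The reducers ('reduce', 'reduceL', 'reduceI') search
these structures for a value accepted by a user-supplied predicate, trying
the left alternative of each choice first.
-}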
module Control.RTree where

import Control.Applicative
import Control.Monad.Reader
import Control.Monad.Trans.Maybe
import Data.Functor.Identity
import qualified Data.Map.Strict as Map
import Data.Maybe
import GHC.IORef

import Control.Monad.Reduce
import Control.Monad.State.Strict

-- | A reduction tree: a binary tree of choices with values at its leaves.
--
-- In both split constructors the left subtree is a lazy field and the right
-- subtree a strict one. Presumably this is so that left alternatives are only
-- built when a reducer explores them (confirm).
data RTree lbl a
  = -- | A choice between two subtrees, without a label.
    Split (RTree lbl a) !(RTree lbl a)
  | -- | A choice between two subtrees, tagged with a (strict) label.
    SplitOn !lbl (RTree lbl a) !(RTree lbl a)
  | -- | A leaf holding a value.
    Done a
  deriving (Functor)

-- | Read off the value at the end of the right spine of a tree. Every left
-- subtree and every label is ignored.
extract :: RTree l i -> i
extract (Done i) = i
extract (Split _ rsub) = extract rsub
extract (SplitOn _ _ rsub) = extract rsub

-- | Applicative structure derived from the 'Monad' instance. Each leaf of the
-- function tree is replaced by the argument tree with that leaf's function
-- mapped over it.
instance Applicative (RTree l) where
  pure = Done
  tf <*> tx = tf >>= \f -> fmap f tx

-- | Substitution. Binding replaces every 'Done' leaf with the tree produced
-- for its value, and keeps all splits (and their labels) in place.
instance Monad (RTree l) where
  Done i >>= k = k i
  Split lsub rsub >>= k = Split (lsub >>= k) (rsub >>= k)
  SplitOn lbl lsub rsub >>= k = SplitOn lbl (lsub >>= k) (rsub >>= k)

-- | Choices are recorded as tree nodes. An unlabelled choice becomes a
-- 'Split'; a labelled one becomes a 'SplitOn' carrying its label.
instance MonadReduce l (RTree l) where
  splitWith = maybe Split SplitOn

-- | Search an 'RTree' for a value accepted by @p@, trying left subtrees first.
--
-- @p@ first runs on the value 'extract'ed from the whole tree, and the search
-- is sequenced after it with '*>'. After that, at every split, @p@ runs on the
-- left subtree's extracted value; if it succeeds, the search continues inside
-- the left subtree. Otherwise ('<|>') the search continues in the right
-- subtree. The right subtree is not re-tested, because it extracts to the
-- same value as the split itself, which has already been tested.
reduce
  :: forall m l i
   . (Alternative m)
  => (i -> m ())
  -> RTree l i
  -> m i
reduce p = tested
 where
  -- Test the value a tree extracts to, then search inside it.
  tested tree = p (extract tree) *> search tree
  -- Precondition: the value extracted from the tree has already been tested.
  search (Done i) = pure i
  search (Split lsub rsub) = tested lsub <|> search rsub
  search (SplitOn _ lsub rsub) = tested lsub <|> search rsub
{-# SPECIALIZE reduce :: (i -> MaybeT IO ()) -> RTree l i -> MaybeT IO i #-}

-- | A partial assignment of truth values to labels. A label missing from the
-- map has not been decided yet.
type Valuation lbl = Map.Map lbl Bool

-- | Like 'extract', except at a labelled split whose label is mapped to
-- @False@ in the valuation, where the left subtree is followed instead.
-- Labels mapped to @True@, labels missing from the valuation, and unlabelled
-- splits all follow the right subtree.
extractL :: (Ord l) => Valuation l -> RTree l i -> i
extractL val = walk
 where
  walk (Done i) = i
  walk (Split _ rsub) = walk rsub
  walk (SplitOn lbl lsub rsub)
    | Map.lookup lbl val == Just False = walk lsub
    | otherwise = walk rsub

-- | Like 'reduce', but tracks a 'Valuation' of labels and passes it to the
-- predicate together with the extracted value (see 'extractL').
--
-- At a labelled split whose label is already decided, only the matching
-- subtree is searched. At an undecided label, the left subtree is tried with
-- the label set to @False@. If that fails ('<|>'), the right subtree is
-- searched with the label set to @True@. Unlabelled splits behave as in
-- 'reduce'.
reduceL
  :: forall m l i
   . (Alternative m, Ord l)
  => (Valuation l -> i -> m ())
  -> Valuation l
  -> RTree l i
  -> m i
reduceL p = tested
 where
  -- Test the value the tree extracts to under @val@, then search inside it.
  tested val tree = p val (extractL val tree) *> search val tree
  search _ (Done i) = pure i
  search val (Split lsub rsub) = tested val lsub <|> search val rsub
  search val (SplitOn lbl lsub rsub) =
    case Map.lookup lbl val of
      -- NOTE(review): the chosen branch extracts to the same value as this
      -- split under @val@, so this re-test may repeat a call to @p@ that has
      -- already been made. Confirm whether 'search' was intended here.
      Just True -> tested val rsub
      Just False -> tested val lsub
      Nothing ->
        tested (Map.insert lbl False val) lsub
          <|> search (Map.insert lbl True val) rsub
{-# INLINE reduceL #-}

-- | The state threaded through an 'IRTree'. It holds the choice bits not yet
-- consumed, and the valuation of labels: the caller's initial valuation,
-- extended as labelled choices are made.
data ReState lbl = ReState ![Bool] !(Valuation lbl)

-- | A counterpart to 'RTree' in which choices do not build nodes. Instead,
-- each choice reads its outcome from a 'ReState'. Running one yields a value
-- together with the leftover state. The 'Functor', 'Applicative' and 'Monad'
-- instances are borrowed from @'State' ('ReState' lbl)@.
newtype IRTree lbl i = IRTree {runIRTree :: ReState lbl -> (i, ReState lbl)}
  deriving (Functor, Applicative, Monad) via (State (ReState lbl))

-- | Each choice takes its outcome from the front of the path of bits:
--
-- * An unlabelled choice pops the next bit, or yields @False@ once the path
--   is exhausted.
-- * A labelled choice whose label is already in the valuation yields the
--   negation of the stored value, without consuming a bit.
-- * Any other labelled choice pops the next bit @b@, yields it, and records
--   the label as @not b@. If the path is exhausted, it yields @False@ and
--   records the label as @True@.
instance (Ord l) => MonadReduce l (IRTree l) where
  checkWith Nothing = IRTree \(ReState bits val) ->
    case bits of
      b : rest -> (b, ReState rest val)
      [] -> (False, ReState [] val)
  checkWith (Just lbl) = IRTree \(ReState bits val) ->
    case Map.lookup lbl val of
      Just known -> (not known, ReState bits val)
      Nothing -> case bits of
        b : rest -> (b, ReState rest (Map.insert lbl (not b) val))
        [] -> (False, ReState [] (Map.insert lbl True val))

-- | Greedily reduce an 'IRTree', deciding its choices one at a time.
--
-- The decisions made so far form a path of bits. A choice that reads @True@
-- records its label (if any) as @False@, which is the setting 'reduceL' tries
-- first. A choice read past the end of the path gets @False@. Each step
-- probes the next undecided choice by running the tree on @path ++ [True]@:
--
-- * If the probe bit is consumed, @p@ is given the resulting valuation and
--   value. Its answer becomes that choice's decision, so the choice stays
--   @True@ only when @p@ accepted the probed value.
-- * If the probe bit is left unread, every choice has been decided. The value
--   just computed depends only on the current path, and it is returned.
--
-- The value for the empty path is never passed to @p@. If @p@ rejects every
-- probe, that value (every choice read as @False@) is the result.
-- Assumes the tree reads only finitely many bits on every path; otherwise
-- this does not terminate (TODO confirm for callers).
reduceI
  :: forall m l i
   . (Monad m, Ord l)
  => (Valuation l -> i -> m Bool)
  -> Valuation l
  -> IRTree l i
  -> m i
reduceI p v (IRTree m) = go []
 where
  go pth =
    -- Probe with an explicit True bit, so @p@'s verdict is about the choice
    -- being appended rather than about the previously decided one.
    case m (ReState (pth <> [True]) v) of
      -- The probe bit was consumed: there is one more choice to decide.
      (i, ReState [] v') -> do
        t <- p v' i
        go (pth <> [t])
      -- The probe bit was never read, so @i@ is the value for @pth@ itself.
      (i, _) -> pure i
{-# INLINE reduceI #-}