{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE UndecidableInstances #-}
{-# LANGUAGE ViewPatterns #-}

{- |
Module: Control.RTree
-}
module Control.RTree (
  -- * RTree
  RTree (..),
  extract,
  reduce,
  -- * IRTree
  IRTree,
  iextract,
  ireduce,
  ireduceExp,
  IRTreeT (..),
  iextractT,
  ireduceT,
  ireduceExpT,
  ReState (..),
  -- * Valuation
  Valuation,
) where

import Control.Applicative
import Control.Monad.Reader
import Control.Monad.State.Strict
import Data.Functor.Identity

import Control.Monad.Reduce
import qualified Data.Valuation as Val

-- | Local shorthand for a truth value attached to a label, as used by
-- labeled splits.
type Truth = Val.Truth

-- | Local shorthand for the valuation type: the truth values recorded for
-- labels (queried with 'Val.truthValue', extended with 'Val.withTruth').
type Valuation = Val.Valuation

-- | A reduction tree producing an @i@, whose splits may be labeled with a
-- truth about a label of type @l@.
data RTree l i
  = -- | A choice between a left-hand and a right-hand subtree, optionally
    -- labeled. The right-hand subtree is a strict field.
    SplitWith (Maybe (Truth l)) (RTree l i) !(RTree l i)
  | -- | A finished computation.
    Done i
  deriving stock (Functor)

-- | Read the output of a tree under a valuation, without reducing anything.
--
-- At a labeled split the valuation is conditioned on the label with
-- 'Val.condition'. When that yields a valuation, the right-hand branch is
-- followed with it; otherwise the left-hand branch is followed (presumably
-- the label contradicts the valuation; confirm against "Data.Valuation").
--
-- An unlabeled split always follows the right-hand branch. 'reduce' relies on
-- this when it descends into a right-hand branch without re-checking its
-- predicate. It is also what an 'IRTreeT' run does by default, because an
-- exhausted choice list answers 'True'.
extract :: (Ord l) => Valuation l -> RTree l i -> i
extract v = \case
  SplitWith (Just l) lhs rhs -> case Val.condition v l of
    Just v' -> extract v' rhs
    Nothing -> extract v lhs
  SplitWith Nothing _ rhs -> extract v rhs
  Done i -> i

-- | 'pure' builds a leaf. Application runs the function tree, then the
-- argument tree under each of its leaves (the monadic 'ap', spelled out).
instance Applicative (RTree l) where
  pure = Done
  tf <*> tx = do
    f <- tf
    x <- tx
    pure (f x)

-- | Binding substitutes the continuation into every leaf and keeps every
-- split, with its label, in place.
instance Monad (RTree l) where
  Done i >>= k = k i
  SplitWith ml lhs rhs >>= k = SplitWith ml (lhs >>= k) (rhs >>= k)

-- | In a plain 'RTree', splitting just records the split as a tree node.
instance MonadReduce l (RTree l) where
  splitWith ml lhs rhs = SplitWith ml lhs rhs

-- | Reduce a tree against a predicate, exploring alternatives in @m@.
--
-- The predicate @p@ accepts an output by succeeding and rejects it by failing
-- in the 'Alternative'. The output of the whole tree is checked first. At each
-- split the left-hand branch is tried, with @p@ checked on its extracted
-- output. If that fails, the right-hand branch is entered without a new check.
--
-- For a labeled split whose label already has a truth value, only the branch
-- agreeing with it is taken (and checked). For an unknown label, the left-hand
-- branch is explored with the negated label recorded, and the right-hand
-- branch with the label itself.
reduce
  :: forall m l i
   . (Alternative m, Ord l)
  => (i -> m ())
  -> Valuation l
  -> RTree l i
  -> m i
reduce p = checked
 where
  -- Run the predicate on the extracted output, then keep descending.
  checked val tree = p (extract val tree) *> descend val tree

  descend _ (Done i) = pure i
  descend val (SplitWith Nothing lhs rhs) =
    checked val lhs <|> descend val rhs
  descend val (SplitWith (Just lit) lhs rhs) =
    case Val.truthValue val (Val.label lit) of
      Nothing ->
        checked (Val.withTruth val (Val.not lit)) lhs
          <|> descend (Val.withTruth val lit) rhs
      Just t
        | t == Val.truth lit -> checked val rhs
        | otherwise -> checked val lhs
{-# INLINE reduce #-}

-- | The state threaded through an 'IRTreeT' run.
data ReState l = ReState
  { choices :: [Bool]
  -- ^ Pending choices, consumed from the front by each check that needs one.
  , valuation :: !(Valuation l)
  -- ^ Truth values recorded for labels so far (strict field).
  }

-- | An interpreted reduction tree with no base effects.
type IRTree l = IRTreeT l Identity

-- | An interpreted reduction tree: a 'StateT' computation over 'ReState' in
-- the base monad @m@. The splits are not materialised. Each check is
-- answered from the pending choices in the state, so exploring a different
-- branch means re-running the computation with different choices.
-- All instances are borrowed from the underlying 'StateT'.
newtype IRTreeT l m i = IRTreeT {runIRTreeT :: StateT (ReState l) m i}
  deriving (Functor, Applicative, Monad) via (StateT (ReState l) m)
  deriving (MonadTrans) via (StateT (ReState l))

-- | Interprets each check against the 'ReState' threaded through the run.
--
-- * An unlabeled check consumes the next pending choice, answering 'True'
--   once the choices are exhausted.
-- * A labeled check whose label already has a truth value in the valuation
--   consumes nothing. It answers whether the recorded value agrees with the
--   label's truth, the same comparison 'reduce' makes.
-- * Any other labeled check consumes the next choice and records it in the
--   valuation: the label itself for 'True', its negation for 'False'.
instance (Monad m, Ord l) => MonadReduce l (IRTreeT l m) where
  checkWith ml = IRTreeT . StateT $ \(ReState cs v) -> case ml of
    Nothing -> case nextChoice cs of
      (a, cs') -> pure (a, ReState cs' v)
    Just l -> case Val.truthValue v (Val.label l) of
      -- Compare with the label's own truth instead of returning the raw
      -- recorded value. Otherwise a negated label would get an answer that
      -- disagrees with the choice that recorded it.
      Just t -> pure (t == Val.truth l, ReState cs v)
      Nothing -> case nextChoice cs of
        (a, cs') ->
          pure (a, ReState cs' (Val.withTruth v (if a then l else Val.not l)))
   where
    -- Pop the next pending choice; an exhausted list answers 'True'.
    nextChoice [] = (True, [])
    nextChoice (a : as) = (a, as)
  {-# INLINE checkWith #-}

-- | Pure counterpart of 'iextractT': run an effect-free interpreted tree once
-- from the given valuation and return its output.
iextract :: (Ord l) => Valuation l -> IRTree l a -> a
iextract v = runIdentity . iextractT v
{-# INLINE iextract #-}

-- | Run an interpreted tree once, starting from the given valuation with no
-- pending choices, and keep only the produced value (the final state is
-- dropped).
iextractT :: (Ord l, Monad m) => Valuation l -> IRTreeT l m i -> m i
iextractT v tree = evalStateT (runIRTreeT tree) (ReState [] v)
{-# INLINE iextractT #-}

-- | 'ireduceT' for an effect-free 'IRTree': the tree's 'Identity' layer is
-- unwrapped straight into the predicate's monad @m@.
ireduce
  :: forall m l i
   . (Monad m, Ord l)
  => (Valuation l -> i -> m Bool)
  -> Valuation l
  -> IRTree l i
  -> m i
ireduce = ireduceT \(Identity a) -> pure a
{-# INLINE ireduce #-}

-- | Interpreted reduction with an @m@ base monad, deciding one choice per
-- predicate call.
--
-- Each step re-runs the tree with the committed choices followed by a single
-- explorative 'False'. Any further choice the tree consumes defaults to
-- 'True'. If the tree consumed the explorative choice, the predicate is
-- evaluated on that output and its verdict fixes the choice: 'False' when the
-- predicate holds, 'True' otherwise. Once the tree no longer consumes the
-- explorative choice, nothing is left to decide and the output of the
-- committed choices is returned.
ireduceT
  :: forall t m l i
   . (Monad m, Monad t, Ord l)
  => (forall a. m a -> t a)
  -- ^ a lift of monad m into t (normally @id@ or @lift@)
  -> (Valuation l -> i -> t Bool)
  -> Valuation l
  -> IRTreeT l m i
  -> t i
ireduceT lift_ p v (IRTreeT m) = go []
 where
  -- @pth@ holds the committed choices.
  go pth =
    -- The predicate must judge the output with the probed choice set to
    -- False. Running with @pth@ alone would leave that choice at its default
    -- True, and the verdict would wrongly be recorded for it.
    lift_ (runStateT m (ReState (pth <> [False]) v)) >>= \case
      (i, ReState [] v') -> do
        t <- p v' i
        -- if the predicate is true, we can reduce to the false branch.
        go (pth <> [not t])
      -- The probe was not consumed: no decisions remain after @pth@.
      (i, _) -> pure i
{-# INLINE ireduceT #-}

-- | 'ireduceExpT' for an effect-free 'IRTree': the tree's 'Identity' layer is
-- unwrapped straight into the predicate's monad @m@.
ireduceExp
  :: forall m l i
   . (Monad m, Ord l)
  => (Valuation l -> i -> m Bool)
  -> Valuation l
  -> IRTree l i
  -> m i
ireduceExp = ireduceExpT \(Identity a) -> pure a
{-# INLINE ireduceExp #-}

-- | Interpreted reduction with an @m@ base monad, running in exponential
-- (galloping) mode.
--
-- Instead of probing one choice per predicate call, it probes a run of @n@
-- consecutive 'False' choices after the committed ones. Any further choice
-- the tree consumes defaults to 'True'.
--
-- * If the predicate holds, the run is committed and the next run is twice
--   as long.
-- * If it fails, the run is halved; a failing run of length one is committed
--   as 'True'.
-- * When the tree consumes none of the probed choices, nothing is left to
--   decide and the output of the committed choices is returned.
ireduceExpT
  :: forall t m l i
   . (Monad m, Monad t, Ord l)
  => (forall a. m a -> t a)
  -- ^ a lift of monad m into t (normally @id@ or @lift@)
  -> (Valuation l -> i -> t Bool)
  -> Valuation l
  -> IRTreeT l m i
  -> t i
ireduceExpT lift_ p v (IRTreeT m) = go 1 []
 where
  -- @pth@ holds the committed choices, and @n@ (always at least 1) is the
  -- length of the explorative run of False choices probed after them.
  -- Every step either extends @pth@ or strictly shortens the run, so the
  -- search terminates.
  go :: Int -> [Bool] -> t i
  go n pth =
    lift_ (runStateT m (ReState (pth <> replicate n False) v)) >>= \case
      (i, ReState rest v') -> do
        -- The tree may stop before using the whole run. Only the consumed
        -- part (the first k probes) actually shaped this output.
        let k = n - length rest
        if k <= 0
          then pure i
          else
            p v' i >>= \case
              True
                | null rest -> go (2 * k) (pth <> replicate k False)
                -- The tree ran out of decisions and this output was accepted.
                | otherwise -> pure i
              False
                | k == 1 -> go 1 (pth <> [True])
                | otherwise -> go (k `quot` 2) pth