{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}

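-- |
-- Module: Control.RTree
--
-- Reduction trees ('ReduceT'), reducers that build such a tree for an input,
-- and 'reduce', which searches a tree for a value accepted by a check running
-- in a 'MonadPlus'.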
module Control.RTree where

import Control.Monad (MonadPlus (..), ap, liftM2)
import Data.Functor

-- | The reduction tree, parameterized by a generative functor @f@.
--
-- A tree is either a leaf @'Done' i@ holding a value, or a node
-- @lhs ':<|' rhs@ where @lhs@ is an @f@-wrapped alternative subtree and @rhs@
-- is the rest of the tree. 'extract' follows the @rhs@ spine, while 'reduce'
-- searches @lhs@ before @rhs@.
data ReduceT f i
  = Done i
  | f (ReduceT f i) :<| ReduceT f i

-- Maps over every leaf, including the leaves nested inside each left
-- alternative (reached through the underlying functor @f@).
instance (Functor f) => Functor (ReduceT f) where
  fmap g (Done i) = Done (g i)
  fmap g (mi :<| ri) = fmap (fmap g) mi :<| fmap g ri

-- A pure value is a bare leaf; application is derived from the 'Monad'
-- instance below.
instance (Functor f) => Applicative (ReduceT f) where
  pure i = Done i
  mf <*> mx = mf `ap` mx

-- Binding replaces every leaf with the tree produced by the continuation;
-- subtrees inside a left alternative are rebound under @f@ with 'fmap'.
instance (Functor f) => Monad (ReduceT f) where
  Done i >>= k = k i
  (mi :<| ri) >>= k = fmap (>>= k) mi :<| (ri >>= k)

-- | Change the underlying functor of a tree using a natural transformation.
--
-- The subtrees inside each left alternative are converted first, then @nat@
-- is applied to the alternative itself; leaves are left unchanged.
liftR :: (Functor f) => (forall a. f a -> g a) -> ReduceT f i -> ReduceT g i
liftR nat = go
  where
    go (Done i) = Done i
    go (lhs :<| rhs) = nat (fmap go lhs) :<| go rhs

-- | Extract the input from a reduction tree.
--
-- Follows the right-hand spine down to its 'Done' leaf and returns the value
-- stored there, ignoring every left alternative.
extract :: ReduceT f i -> i
extract (Done i) = i
extract (_ :<| rest) = extract rest

-- | Reduce an input by searching a reduction tree for a value accepted by a
-- check.
--
-- At a leaf @'Done' i@ the check @fn i@ is run and, if it succeeds, @i@ is
-- returned. At a node @lhs ':<|' rhs@ the alternatives inside @lhs@ are
-- searched first, and the result is combined with the search of @rhs@ using
-- 'mplus' (for 'Maybe', @rhs@ is therefore only searched when @lhs@ yields
-- nothing).
reduce :: (MonadPlus m) => (i -> m ()) -> ReduceT m i -> m i
reduce fn rt = case rt of
  Done i -> fn i $> i
  lhs :<| rhs -> (lhs >>= reduce fn) `mplus` reduce fn rhs

infixr 3 <|

-- Combinators

-- | @lhs <| rhs@ builds a node whose left alternative is the tree @lhs@
-- (wrapped with 'pure') and whose remaining tree is @rhs@.
--
-- Right-associative at precedence 3, so it binds more loosely than '<$>'.
(<|) :: (Applicative f) => ReduceT f i -> ReduceT f i -> ReduceT f i
lhs <| rhs = (:<|) (pure lhs) rhs

-- | Split the world on a fact.
--
-- The left alternative, which 'reduce' searches first, is 'False': the fact
-- does not happen. The right alternative, which is also what 'extract'
-- returns, is 'True': the fact does happen.
given :: (Applicative f) => ReduceT f Bool
given = Done False <| Done True

-- | A reducer takes an input and returns a reduction tree for it.
--
-- The tree is expected to 'extract' back to the input; see 'lawReduceId'.
type Reducer m i = i -> ReduceT m i

-- | The identity law for reducers: extracting from the tree that a reducer
-- builds for an input gives back that same input.
--
-- @
--  extract . red = id
-- @
lawReduceId :: (Eq i) => Reducer m i -> i -> Bool
lawReduceId red i =
  let tree = red i
   in extract tree == i

-- | Reduce a list one element at a time.
--
-- For each element, from the head onward, the tree first offers dropping the
-- element and then keeping it; 'extract' yields the unchanged list.
rList :: (Applicative m) => Reducer m [a]
rList [] = Done []
rList (a : as) = rList as <| fmap (a :) (rList as)

-- | Reduce a list by repeatedly splitting it in half.
--
-- A non-empty list is first offered as removed entirely (@[]@). Otherwise it
-- is split into a front half (the first @length `div` 2@ elements) and a back
-- half. The tree first offers dropping the front half and reducing only the
-- back. It then offers keeping a reduced front while the back is either
-- dropped (tried first) or reduced in turn. 'extract' yields the unchanged
-- list.
rBinaryList :: (Applicative m) => Reducer m [a]
rBinaryList = \case
  [] -> Done []
  as -> Done [] <| go as
  where
    -- Removing the whole list is offered once, above, so for non-empty input
    -- 'go' never yields an empty list. Both halves of a split of a list with
    -- at least two elements are non-empty, so the @[]@ case is unreachable
    -- from here. It returns @Done []@ to keep 'go' total instead of calling
    -- 'error'.
    go = \case
      [] -> Done []
      [a] -> Done [a]
      as ->
        let (front, back) = splitAt (length as `div` 2) as
         in go back <| liftM2 (<>) (go front) (Done [] <| go back)