{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE DeriveFunctor #-}
{-# LANGUAGE DerivingVia #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE KindSignatures #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE PackageImports #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeSynonymInstances #-}
{-# LANGUAGE ViewPatterns #-}

{- |
Module: Control.RTree
-}
module Control.RTree where

import Control.Applicative (Alternative ((<|>)))
import Data.Coerce (coerce)
import Data.Functor.Classes
import qualified Data.List.NonEmpty as NE
import Data.Void

import qualified Data.List as L

import Control.Monad.Reader
import "free" Control.Monad.Free.Church

-- | The base functor for the reduction tree.
--
-- @Split ml lhs rhs@ is a choice between the two children @lhs@ and @rhs@,
-- optionally tagged with a label @ml@.
data RTreeF l f
  = Split (Maybe l) f f
  deriving (Show, Eq, Functor)

-- | Renders a node exactly like the derived 'Show' instance, using the
-- supplied printer for the two children. (Previously this was 'undefined',
-- so any 'showsPrec1'-based rendering crashed.)
instance (Show l) => Show1 (RTreeF l) where
  liftShowsPrec showsChild _ d (Split ml lhs rhs) =
    -- Constructor application has precedence 10; arguments are shown at 11
    -- so that compound arguments get parenthesised, as in derived 'Show'.
    showParen (d > 10) $
      showString "Split "
        . showsPrec 11 ml
        . showChar ' '
        . showsChild 11 lhs
        . showChar ' '
        . showsChild 11 rhs

-- | A reduction tree: a Church-encoded free monad over 'RTreeF', with labels
-- of type @l@ and inputs of type @i@ at the leaves. All three instances are
-- borrowed from the underlying 'F'.
newtype RTree l i = RTree {rtreeFree :: F (RTreeF l) i}
  deriving (Functor) via (F (RTreeF l))
  deriving (Applicative) via (F (RTreeF l))
  deriving (Monad) via (F (RTreeF l))

-- | Wrapping a node unwraps its children to the underlying 'F', wraps there,
-- and rewraps the result in 'RTree'.
instance MonadFree (RTreeF l) (RTree l) where
  wrap = RTree . wrap . fmap rtreeFree

-- '<|' associates to the right and '|>' to the left, so a chain reads the
-- same from either end: @a |> b |> c@ is @c <| b <| a@. Both have
-- precedence 3, so mixing them without parentheses is a parse error.
infixr 3 <|
infixl 3 |>

{-# INLINE (<|) #-}
-- | Build an unlabelled choice node. The left operand is the alternative the
-- search tries first; the right operand is the fallback.
(<|) :: (MonadFree (RTreeF l) r) => r i -> r i -> r i
(<|) preferred fallback = wrap $ Split Nothing preferred fallback

{-# INLINE splitOn #-}
-- | Build a choice node tagged with a label; otherwise the same as '(<|)'.
splitOn :: (MonadFree (RTreeF l) r) => l -> r i -> r i -> r i
splitOn lbl preferred fallback = wrap $ Split (Just lbl) preferred fallback

{-# INLINE split #-}
-- | Build a choice node with an optional label: @split Nothing@ is '(<|)'
-- and @split (Just l)@ is @splitOn l@.
split :: (MonadFree (RTreeF l) r) => Maybe l -> r i -> r i -> r i
split mlabel preferred fallback = wrap $ Split mlabel preferred fallback

{-# INLINE (|>) #-}
-- | Flipped '(<|)': the right operand is tried first and the left operand
-- is the fallback.
(|>) :: (MonadFree (RTreeF l) r) => r i -> r i -> r i
fallback |> preferred = wrap (Split Nothing preferred fallback)

{-# INLINE foldR #-}
-- | Collapse a reduction tree with an algebra for its nodes; a leaf is
-- already a value of the result type.
foldR :: (RTreeF l a -> a) -> RTree l a -> a
foldR alg = iter alg . rtreeFree

-- | Like 'foldR', but the algebra combines monadic actions computed for the
-- children.
foldRM :: (Monad m) => (RTreeF l (m a) -> m a) -> RTree l a -> m a
foldRM alg = iterM alg . rtreeFree

-- | Extract the input from the reducer: the value reached by always taking
-- the right-hand (fallback) child of every choice node.
extract :: RTree l i -> i
extract = foldR pickFallback
 where
  pickFallback (Split _ _ fallback) = fallback

-- | Remove all labels from an RTree by expanding all choices.
--
-- The fold threads an association list that records, for each label already
-- split on along the current path, whether the left ('False') or right
-- ('True') branch was taken. Returns 'Nothing' when every alternative has
-- been discarded.
flatten :: forall i l. (Eq l) => RTree l i -> Maybe (RTree Void i)
flatten t = foldR go (fmap (const . Just . pure) t) []
 where
  -- Each subtree becomes a function from the current label assignments to an
  -- optional unlabelled tree; leaves ignore the assignments.
  go (Split ml lhs rhs) lst =
    case ml of
      Just l -> case l `L.lookup` lst of
        -- First split on this label: record which side each branch took.
        Nothing -> do
          join' (lhs $ (l, False) : lst) (rhs $ (l, True) : lst)
        -- NOTE(review): for a repeated label, a 'True' assignment keeps both
        -- branches and a 'False' assignment discards the node entirely,
        -- instead of following the previously chosen side. Confirm this is
        -- the intended semantics.
        Just True ->
          join' (lhs lst) (rhs lst)
        Just False ->
          Nothing
      Nothing -> join' (lhs lst) (rhs lst)

  -- Combine two optional subtrees: keep the choice when both survive,
  -- otherwise keep whichever survived ('Nothing' if neither did).
  join' mlhs mrhs = do
    case (mlhs, mrhs) of
      (Just lhs', Just rhs') -> pure (lhs' <| rhs')
      _ -> mlhs <|> mrhs

-- | Reduce an input by searching the tree with an 'Alternative' predicate.
--
-- The predicate @p@ is first run on the unreduced input, which is the value
-- reached by always taking the right-hand alternative. At each choice node,
-- @p@ is run on the left alternative's extracted value followed by a search
-- inside it. That is combined with '<|>' with a search of the right-hand
-- alternative.
reduce
  :: forall m i
   . (Alternative m)
  => (i -> m ())
  -> RTree Void i
  -> m i
reduce p t = p original *> search
 where
  -- Fold every subtree into (its search action, its extracted input).
  (search, original) = foldR step (fmap (\i -> (pure i, i)) t)

  step :: RTreeF l (m i, i) -> (m i, i)
  step (Split _ (lhsSearch, lhsValue) (rhsSearch, rhsValue)) =
    ((p lhsValue *> lhsSearch) <|> rhsSearch, rhsValue)
{-# INLINE reduce #-}

-- | A directly recursive variant of 'RTree' that stores its nodes as
-- ordinary constructors.
data RTree' l i
  = -- | A choice node whose children are again trees.
    RTree' (RTreeF l (RTree' l i))
  | -- | A leaf holding an input.
    Done i

-- | Extract the input from an 'RTree'' by always descending into the
-- right-hand (fallback) child.
extract' :: RTree' l i -> i
extract' (Done v) = v
extract' (RTree' (Split _ _ fallback)) = extract' fallback

-- | Maps over every leaf, leaving the choice structure unchanged.
instance Functor (RTree' l) where
  fmap f = \case
    Done i -> Done (f i)
    RTree' node -> RTree' (fmap f <$> node)

-- | 'pure' is a leaf; '<*>' is defined through the 'Monad' instance.
instance Applicative (RTree' l) where
  pure x = Done x
  mf <*> mx = do
    f <- mf
    x <- mx
    return (f x)

-- | Binding substitutes the continuation's result for every leaf.
instance Monad (RTree' l) where
  Done i >>= k = k i
  RTree' node >>= k = RTree' (fmap (>>= k) node)

-- | Choice nodes are stored directly in the 'RTree'' constructor, so the
-- polymorphic builders ('<|', 'splitOn', 'split') also produce 'RTree''.
instance MonadFree (RTreeF l) (RTree' l) where
  wrap = RTree'
  {-# INLINE wrap #-}

-- | Reduce an input by searching an 'RTree'' with an 'Alternative'
-- predicate; labels are ignored.
--
-- The predicate is run on the whole input first. At a choice node, the left
-- child is searched only after the predicate accepts that child's extracted
-- input; this is combined with '<|>' with a search of the right child.
reduce'
  :: forall m l i
   . (Alternative m)
  => (i -> m ())
  -> RTree' l i
  -> m i
reduce' p = attempt
 where
  -- Check a subtree's extracted input, then search inside it.
  attempt, search :: RTree' l i -> m i
  attempt tree = p (extract' tree) *> search tree
  search (Done i) = pure i
  search (RTree' (Split _ preferred fallback)) =
    attempt preferred <|> search fallback

-- newtype I l i = I ([(l, Bool)] -> RTreeI l i)
--
-- data RTreeI l i
--   = RTreeI (RTreeF l (I l i))
--   | DoneI !i

-- -- This is not a great definition, as the i does not depend on
-- -- the current i, but instead on the final I.
-- data RTreeIO j i = RTreeIO ((j -> IO Bool) -> IO i) j
--
-- extractIO :: RTreeIO j i -> j
-- extractIO (RTreeIO _ i) = i

-- instance Functor (RTreeIO j) where
--   fmap f (RTreeIO mf i) = RTreeIO (\h -> f <$> mf (h . f)) (f i)
--
-- instance Applicative (RTreeIO j) where
--   pure i = RTreeIO (\_ -> pure i) i
--   (<*>) = ap
--
-- -- RTreeIO f fi <*> RTreeIO a ai = RTreeIO (f <*> a) (fi ai)
--
-- instance Monad (RTreeIO j) where
--   RTreeIO (ma :: ((a -> IO Bool) -> IO a)) a >>= (f :: (a -> RTreeIO b)) =
--     RTreeIO undefined (extractIO $ f a)
--
-- instance MonadFree (RTreeF Void) (RTreeIO j) where
--   wrap (Split Nothing (RTreeIO lhs le) (RTreeIO rhs re)) =
--     RTreeIO
--       ( \p ->
--           p le >>= \case
--             True -> lhs p
--             False -> rhs p
--       )
--       re
--   wrap (Split (Just x) _ _) = absurd x

-- reduceIO
--   :: forall i
--    . (i -> IO Bool)
--   -> RTreeIO j i
--   -> IO (Maybe i)
-- reduceIO p (RTreeIO rt i) = runMaybeT do
--   let (mi, i') = foldR go $ fmap (\i -> (pure i, i)) t
--   p i' *> mi
--  where
--   go :: RTreeF l (IO i, i) -> (IO i, i)
--   go (Split _ (lhs, le) (rhs, re)) =
--     ((p le *> lhs) <|> rhs, re)

-- | Split the world on a fact: the alternative tried first is the world
-- where the fact does not happen ('False'); the fallback is the world where
-- it does ('True').
given :: RTree Void Bool
given = pure True |> pure False

{- | Law: a reducer must extract back to its own input, i.e.

@
 extract . red = id
@

'lawReduceId' checks this law at a single input.
-}
lawReduceId :: (Eq i) => (i -> RTree l i) -> i -> Bool
lawReduceId red input =
  let roundTripped = extract (red input)
   in roundTripped == input

-- | Reduce a list one element at a time. For each element, the tree first
-- offers the reduced rest of the list without it, falling back to keeping it.
rList :: [a] -> RTree l [a]
rList = foldr (\a rest -> rest <| fmap (a :) rest) (pure [])

{- | Reduce a list by dropping prefixes, assuming its suffixes form a chain in
which each one contains the next:
@[] < [c] < [b, c] < [a,b,c]@

'exponentialSearch' chooses among the suffixes, with the whole list as the
default. The head of the chosen suffix is kept and its tail is reduced
recursively.
-}
rSuffixList :: [a] -> RTree l [a]
rSuffixList xs =
  exponentialSearch (NE.tails xs) >>= \case
    [] -> pure []
    y : ys -> fmap (y :) (rSuffixList ys)

{- | Binary search over a non-empty sequence of candidates.

The first candidate is the default (the value 'extract' returns), and later
candidates are preferred. Each node offers the second half of the current
range before falling back to the first half. When the candidates get
progressively smaller, as with 'NE.tails', this picks the smallest candidate
the predicate accepts, assuming acceptance is monotone along the sequence.
-}
binarySearch :: NE.NonEmpty i -> RTree l i
binarySearch candidates@(x NE.:| rest)
  | null rest = pure x
  | otherwise =
      binarySearch (NE.fromList later) <| binarySearch (NE.fromList earlier)
 where
  -- Both halves are non-empty here because there are at least two candidates.
  (earlier, later) = NE.splitAt (NE.length candidates `div` 2) candidates

{- | Exponential search over candidates ordered as for 'binarySearch'. The
first candidate is the default returned by 'extract', and later candidates
are preferred.

Candidates are probed at offsets 1, 3, 7, ..., with the step doubling each
time. When a probe is rejected, the block of candidates from the last
accepted position up to the rejected probe is binary-searched. Once the step
reaches the number of remaining candidates, all of them are binary-searched.
-}
exponentialSearch :: NE.NonEmpty i -> RTree l i
exponentialSearch = gallop 1
 where
  gallop step candidates
    | step >= NE.length candidates = binarySearch candidates
    | otherwise =
        gallop (step * 2) (NE.fromList later)
          <| binarySearch (NE.fromList earlier)
   where
    -- step < length, so both parts are non-empty.
    (earlier, later) = NE.splitAt step candidates

-- | Convert a list to a 'NE.NonEmpty', calling 'error' with the given message
-- when the list is empty.
nonEmptyOr :: String -> [a] -> NE.NonEmpty a
nonEmptyOr msg = maybe (error msg) id . NE.nonEmpty

-- | Linear search over options ordered from most to least preferred. Each
-- option is offered in turn; the last option is the fallback that 'extract'
-- returns.
linearSearch :: NE.NonEmpty i -> RTree l i
linearSearch (x NE.:| rest) = case rest of
  [] -> pure x
  y : ys -> pure x <| linearSearch (y NE.:| ys)

-- | Like 'linearSearch', but over a possibly empty list. Every option is
-- wrapped in 'Just', and 'Nothing' is appended as the final fallback.
linearSearch' :: [i] -> RTree l (Maybe i)
linearSearch' options =
  linearSearch (foldr (\x acc -> Just x NE.<| acc) (Nothing NE.:| []) options)

-- | Delta-debugging style minimisation of a list.
--
-- Starting with 2 partitions, the list is cut into chunks of size
-- @length lst `div` n@. The search then offers, in order, each chunk on its
-- own and then each complement (the list with one chunk removed). If one is
-- accepted, minimisation restarts on it. Otherwise the number of partitions
-- is doubled, until the chunk size reaches 0. Lists of length 0 or 1 are
-- returned unchanged, and a non-empty list is never reduced to the empty list.
ddmin :: [i] -> RTree l [i]
ddmin = \case
  [] -> pure []
  [a] -> pure [a]
  as -> go 2 as
 where
  -- n is the number of partitions requested; n' is the resulting chunk size.
  go n lst
    | n' <= 0 = pure lst
    | otherwise = do
        r <- linearSearch' (partitions n' lst ++ composites n' lst)
        case r of
          -- NOTE(review): extract of 'go (n * 2) lst' is lst itself, so under
          -- 'reduce' this extra node appears to re-check the predicate on lst
          -- (already accepted) before trying the finer partitioning. Verify
          -- that this is intended.
          Nothing -> go (n * 2) lst <| pure lst -- (for efficiency :D)
          Just lst' -> ddmin lst'
   where
    n' = length lst `div` n
  -- Split a list into consecutive chunks of size n (the last may be shorter).
  partitions n lst =
    case lst of
      [] -> []
      _ -> let (h, r) = splitAt n lst in h : partitions n r
  -- For each chunk produced by 'partitions n', the list with that chunk removed.
  composites n lst =
    case lst of
      [] -> []
      _ -> let (h, r) = splitAt n lst in r : fmap (h ++) (composites n r)