{-# LANGUAGE BlockArguments #-}
{-# LANGUAGE ConstraintKinds #-}
{-# LANGUAGE FlexibleContexts #-}
{-# LANGUAGE FlexibleInstances #-}
{-# LANGUAGE GADTs #-}
{-# LANGUAGE LambdaCase #-}
{-# LANGUAGE RankNTypes #-}
{-# LANGUAGE RecordWildCards #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TypeApplications #-}
{-# LANGUAGE TypeFamilies #-}
{-# LANGUAGE NoMonomorphismRestriction #-}

module ReduceC where

import Control.Monad.Reader
import Control.Monad.Reduce
import Control.Monad.Trans.Maybe (MaybeT (runMaybeT))
import Data.Data
import Data.Foldable
import Data.Functor
import qualified Data.Map.Strict as Map
import qualified Data.Set as Set
import Data.Vector.Internal.Check (HasCallStack)
import qualified Language.C as C
import qualified Language.C.Data.Ident as C

-- | Environment read (via 'MonadReader') during a reduction. 'local' is used
-- to extend it for the declarations that follow a typedef.
data Context = Context
  { keywords :: !(Set.Set Keyword)
  -- ^ Keywords that switch optional kinds of reduction on (see 'Keyword').
  , typeDefs :: !(Map.Map C.Ident [C.CDeclarationSpecifier C.NodeInfo])
  -- ^ Typedef names seen so far, mapped to the declaration specifiers that
  -- 'inlineTypeDefs' substitutes for each use of the name.
  }

-- | Switches stored in 'keywords' that control which reductions are offered.
data Keyword
  = KeepMain
  -- ^ Never offer to remove the function named @main@.
  | DoNoops
  -- ^ Offer to turn an expression statement into the empty statement.
  | NoSemantics
  -- ^ Offer to replace expressions by @0@ or by one of their operands
  -- (presumably not meaning-preserving, hence the name).
  deriving (Show, Read, Enum, Eq, Ord)

-- | Capabilities every reduction step relies on: offering choices labelled
-- with a description and a source position ('MonadReduce'), reading the
-- 'Context', and failing ('MonadFail').
type CM m = (MonadReduce (String, C.Position) m, MonadReader Context m, MonadFail m)

-- | Run 'reduceC' starting from 'defaultContext'.
--
-- The reduction runs in @ReaderT Context (MaybeT m)@, so the result is
-- 'Nothing' when the reduction calls 'fail' and @Just@ the reduced value
-- otherwise.
defaultReduceC :: (CReducible a, MonadReduce (String, C.Position) m) => a -> m (Maybe a)
defaultReduceC = runMaybeT . flip runReaderT defaultContext . reduceC

-- | Record every identifier in @ids@ as a typedef name standing for the
-- declaration specifiers @cs@. Existing entries for those names are
-- replaced; the keywords are left untouched.
addTypeDefs :: [C.Ident] -> [C.CDeclarationSpecifier C.NodeInfo] -> Context -> Context
addTypeDefs ids cs ctx =
  -- 'Map.union' is left-biased, so the new bindings take precedence.
  ctx{typeDefs = Map.union (Map.fromList [(i, cs) | i <- ids]) (typeDefs ctx)}

-- | The context a reduction starts from: only 'KeepMain' is switched on and
-- no typedef names are known yet.
defaultContext :: Context
defaultContext = Context (Set.singleton KeepMain) Map.empty

-- | Whether the given keyword is switched on in the current 'Context'.
keyword :: (MonadReader Context m) => Keyword -> m Bool
keyword k = Set.member k <$> asks keywords

-- | AST nodes that can be reduced. 'reduceC' produces a (possibly) smaller
-- version of its argument, offering the individual reduction steps as
-- labelled choices in @m@.
class CReducible a where
  reduceC :: (CM m) => a -> m a

-- | Reduce a whole translation unit, one external declaration at a time.
--
-- The declarations are processed with a right fold: every declaration gets
-- the continuation @cont@ that reduces all remaining declarations, so a
-- declaration can change the 'Context' in which the rest are reduced.
--
-- * Function definitions are either removed or kept and reduced. When
--   'KeepMain' is switched on, @main@ is never offered for removal.
-- * A declaration whose first specifier is @typedef@ and which declares a
--   single identifier is either removed (its name is then inlined at every
--   later use) or kept (its name then maps to itself).
-- * Every other external declaration is reported via 'don'tHandle'.
instance CReducible C.CTranslUnit where
  reduceC (C.CTranslUnit es ni) = do
    es' <- foldr reduceCExternalDeclaration (pure []) es
    pure $ C.CTranslUnit es' ni
   where
    reduceCExternalDeclaration r cont = do
      shouldKeepMain <- keyword KeepMain
      case r of
        C.CFDefExt fun
          | shouldKeepMain && isMain fun -> do
              r' <- C.CFDefExt <$> reduceC fun
              (r' :) <$> cont
          | otherwise ->
              split ("remove function " <> maybe "" C.identToString (functionName fun), C.posOf r) cont do
                r' <- C.CFDefExt <$> reduceC fun
                (r' :) <$> cont
        C.CDeclExt result ->
          case result of
            -- A typedef
            C.CDecl (C.CStorageSpec (C.CTypedef n) : rst) decl _ ->
              -- Only typedefs introducing exactly one identifier are
              -- supported; anything else is reported instead of hitting a
              -- partial pattern.
              case identifiers decl of
                [ids] ->
                  split
                    ("inline typedef " <> C.identToString ids, C.posOf r)
                    -- When inlining, first resolve typedef names used in the
                    -- definition itself against the typedefs seen so far.
                    -- Otherwise inlining @typedef a b;@ would substitute a
                    -- reference to @a@, whose declaration may itself have
                    -- been removed. Fails (via 'MonadFail') on an unknown name.
                    (inlineTypeDefs rst >>= \rst' -> local (addTypeDefs [ids] rst') cont)
                    ((r :) <$> local (addTypeDefs [ids] [C.CTypeSpec (C.CTypeDef ids n)]) cont)
                _ -> don'tHandle result
            a -> don'tHandle a
        _r -> don'tHandle r

-- | Render an identifier together with the position of its node, as
-- @name at position@.
prettyIdent :: C.Identifier C.NodeInfo -> [Char]
prettyIdent (C.Ident name _ info) = concat [name, " at ", show (C.posOfNode info)]

-- | Reduce a function definition: first replace every known typedef name
-- anywhere in the definition, then reduce the body statement. All parts
-- other than the body are kept as they are after typedef inlining.
instance CReducible C.CFunDef where
  reduceC fundef = do
    inlined <- inlineTypeDefs fundef
    case inlined of
      C.CFunDef specs declr decls body info -> do
        body' <- reduceC body
        pure (C.CFunDef specs declr decls body' info)

-- | Reduce one item of a compound statement, given the action @rest@ that
-- reduces the items after it (used with 'foldr').
--
-- A statement is either removed (only @rest@ runs) or kept and reduced.
-- Declarations are kept unchanged. Static assertions and nested function
-- definitions are reported via 'don'tHandle'.
reduceCCompoundBlockItem
  :: (CM m)
  => C.CCompoundBlockItem C.NodeInfo
  -> m [C.CCompoundBlockItem C.NodeInfo]
  -> m [C.CCompoundBlockItem C.NodeInfo]
reduceCCompoundBlockItem item rest = case item of
  C.CBlockStmt stmt ->
    split ("remove statement", C.posOf item) rest $
      (:) <$> (C.CBlockStmt <$> reduceC stmt) <*> rest
  C.CBlockDecl C.CDecl{} ->
    (item :) <$> rest
  C.CBlockDecl other ->
    don'tHandle other
  other ->
    don'tHandle other

-- | Reduce a statement. Supported forms:
--
-- * compound statements: the block items are reduced in order with
--   'reduceCCompoundBlockItem', which may drop statements;
-- * @while@ / @do@-@while@ loops: the condition may become @0@ (under
--   'NoSemantics') and the body is reduced;
-- * expression statements: under 'DoNoops' the statement may become the
--   empty statement, otherwise its expression is reduced;
-- * @return@: the returned expression, if any, may become @0@ (under
--   'NoSemantics') or is reduced.
--
-- Every other statement form is reported via 'don'tHandle'.
instance CReducible (C.CStatement C.NodeInfo) where
  reduceC smt = case smt of
    C.CCompound labels items ni ->
      (\items' -> C.CCompound labels items' ni)
        <$> foldr reduceCCompoundBlockItem (pure []) items
    C.CWhile cond body isDoWhile ni ->
      -- Applicative sequencing is left to right: condition first, then body.
      C.CWhile
        <$> reduceCExprOrZero cond
        <*> reduceC body
        <*> pure isDoWhile
        <*> pure ni
    C.CExpr Nothing ni ->
      pure (C.CExpr Nothing ni)
    C.CExpr (Just e) ni ->
      splitOn DoNoops ("change to noop", C.posOf smt) (pure (C.CExpr Nothing ni)) $
        (\e' -> C.CExpr (Just e') ni) <$> reduceC e
    C.CReturn me ni ->
      (\me' -> C.CReturn me' ni) <$> traverse reduceCExprOrZero me
    other ->
      don'tHandle other

-- | @splitIf enabled label smaller original@ offers the choice between
-- @smaller@ and @original@ via 'split' when @enabled@ holds; otherwise it
-- simply runs @original@. (In this module the first alternative is always
-- the simplified one.)
splitIf :: (MonadReduce l m) => Bool -> l -> m a -> m a -> m a
splitIf enabled label smaller original
  | enabled = split label smaller original
  | otherwise = original

-- | Like 'splitIf', with the condition being whether keyword @k@ is switched
-- on in the current 'Context'.
splitOn :: (MonadReduce l m, MonadReader Context m) => Keyword -> l -> m a -> m a -> m a
splitOn k label smaller original =
  keyword k >>= \enabled -> splitIf enabled label smaller original

--     C.CCompound is cbi ni -> do
--       cbi' <- collect (reduce @C.CCompoundBlockItem) cbi
--       pure $ C.CCompound is cbi' ni
--     C.CExpr e ni -> do
--       e' <- optional do
--         e' <- liftMaybe e
--         reduce @C.CExpression e'
--       pure $ C.CExpr e' ni
--     C.CIf e s els ni -> do
--       s' <- reduce s
--       e' <- optional do
--         reduce @C.CExpression e
--       els' <- optional do
--         els' <- liftMaybe els
--         given >> reduce els'
--       case (e', els') of
--         (Nothing, Nothing) -> pure s'
--         (Just e'', Nothing) -> pure $ C.CIf e'' s' Nothing ni
--         (Nothing, Just x) -> pure $ C.CIf zeroExp s' (Just x) ni
--         (Just e'', Just x) -> pure $ C.CIf e'' s' (Just x) ni
--     C.CFor e1 e2 e3 s ni -> do
--       reduce s <| do
--         e1' <- reduce @C.CForInit e1
--         e2' <- optional $ liftMaybe e2 >>= reduce @C.CExpression
--         e3' <- optional $ liftMaybe e3 >>= reduce @C.CExpression
--         s' <- reduce s
--         pure $ C.CFor e1' e2' e3' s' ni
--     C.CReturn e ni -> do
--       e' <- traverse (fmap orZero reduce) e
--       pure $ C.CReturn e' ni
--     C.CBreak ni -> pure (C.CBreak ni)
--     C.CCont ni -> pure (C.CCont ni)
--     C.CLabel i s [] ni -> do
--       -- todo fix attrs
--       s' <- reduce s
--       withFallback s' do
--         givenThat (Val.is i)
--         pure $ C.CLabel i s' [] ni
--     C.CGoto i ni ->
--       withFallback (C.CExpr Nothing ni) do
--         givenThat (Val.is i)
--         pure $ C.CGoto i ni
--     C.CWhile e s dow ni -> do
--       e' <- orZero (reduce @C.CExpression e)
--       s' <- reduce s
--       pure $ C.CWhile e' s' dow ni

-- | The integer literal @0@, annotated with 'C.undefNode' (no source
-- position).
zeroExpr :: C.CExpression C.NodeInfo
zeroExpr = C.CConst $ C.CIntConst zero C.undefNode
 where
  zero = C.cInteger 0

-- | Reduce an expression. When 'NoSemantics' is switched on, also offer to
-- replace the whole expression by the literal @0@.
reduceCExprOrZero :: (CM m) => C.CExpr -> m C.CExpr
reduceCExprOrZero expr =
  splitOn NoSemantics ("replace by zero", C.posOf expr) (pure zeroExpr) (reduceC expr)

-- | Reduce an expression.
--
-- * Binary expressions: under 'NoSemantics', may be replaced by (the
--   reduction of) the left or the right operand; otherwise both operands are
--   reduced.
-- * Unary expressions: the operand is reduced; under 'NoSemantics' the
--   operator may then be dropped.
-- * Variables and constants are returned unchanged.
--
-- Any other expression form is not supported yet and is reported via
-- 'don'tHandle', like the other instances in this module.
instance CReducible C.CExpr where
  reduceC expr = case expr of
    C.CBinary o elhs erhs ni ->
      splitOn NoSemantics ("reduce to left", C.posOf elhs) (reduceC elhs) do
        splitOn NoSemantics ("reduce to right", C.posOf erhs) (reduceC erhs) do
          elhs' <- reduceC elhs
          erhs' <- reduceC erhs
          pure $ C.CBinary o elhs' erhs' ni
    C.CVar i ni ->
      pure $ C.CVar i ni
    C.CConst x ->
      pure $ C.CConst x
    C.CUnary o elhs ni -> do
      -- The operand is reduced before the split, so both alternatives use
      -- the same reduced operand.
      elhs' <- reduceC elhs
      splitOn NoSemantics ("reduce to operant", C.posOf expr) (pure elhs') do
        pure $ C.CUnary o elhs' ni
    -- Previously @error (show a)@, which printed the full NodeInfo-annotated
    -- tree and carried no call stack.
    a -> don'tHandle a

--     C.CCall e es ni -> do
--       e' <- reduce e
--       es' <- traverse (fmap orZero reduce) es
--       pure $ C.CCall e' es' ni
--     C.CCond ec et ef ni -> do
--       ec' <- reduce ec
--       ef' <- reduce ef
--       et' <- optional do
--         et' <- liftMaybe et
--         reduce et'
--       pure $ C.CCond ec' et' ef' ni
--     C.CBinary o elhs erhs ni -> onBothExpr elhs erhs \lhs rhs ->
--       pure $ C.CBinary o lhs rhs ni
--     C.CUnary o elhs ni -> do
--       lhs <- reduce elhs
--       pure $ C.CUnary o lhs ni
--     C.CConst c -> do
--       -- TODO fix
--       pure $ C.CConst c
--     C.CCast cd e ni -> do
--       -- TODO fix
--       cd' <- reduce @C.CDeclaration cd
--       e' <- reduce e
--       pure $ C.CCast cd' e' ni
--     C.CAssign op e1 e2 ni -> onBothExpr e1 e2 \e1' e2' ->
--       pure $ C.CAssign op e1' e2' ni
--     C.CIndex e1 e2 ni -> do
--       e1' <- reduce e1
--       e2' <- orZero (reduce e2)
--       pure $ C.CIndex e1' e2' ni
--     C.CMember e i b ni -> do
--       givenThat (Val.is i)
--       e' <- reduce e
--       pure $ C.CMember e' i b ni
--     C.CComma items ni -> do
--       C.CComma <$> collectNonEmpty' reduce items <*> pure ni
--     e -> error (show e)
--    where
--     onBothExpr elhs erhs = onBoth (reduce elhs) (reduce erhs)

-- | Replace every typedef name inside @r@ by the declaration specifiers
-- recorded for it in 'typeDefs'.
--
-- The traversal is generic ('Data'). Whenever it reaches a list of
-- declaration specifiers, each @CTypeDef@ specifier is spliced out and
-- replaced by its recorded specifiers. Every other specifier is traversed
-- further, so typedef names nested inside it are replaced as well (e.g. in
-- the member declarations of an inline @struct@, or inside @typeof@). Any
-- other value is traversed child by child with 'gmapM'.
--
-- The recorded replacement specifiers are deliberately not traversed again.
-- A kept typedef is recorded as mapping to itself, so re-traversal would not
-- terminate.
--
-- Fails (via 'MonadFail') when a typedef name has no entry in 'typeDefs'.
inlineTypeDefs :: forall d m. (Data d, MonadFail m, MonadReader Context m) => d -> m d
inlineTypeDefs r = do
  case eqT @d @[C.CDeclarationSpecifier C.NodeInfo] of
    Just Refl -> do
      res' :: [[C.CDeclarationSpecifier C.NodeInfo]] <- forM r \case
        C.CTypeSpec (C.CTypeDef idx _) -> do
          res <- asks (Map.lookup idx . typeDefs)
          case res of
            Just args -> pure args
            Nothing -> fail ("could not find typedef:" <> show idx)
        -- Previously returned untouched, which left typedef names nested in
        -- the specifier (struct members, typeof, ...) unreplaced.
        a -> (: []) <$> gmapM inlineTypeDefs a
      pure (fold res')
    Nothing ->
      gmapM inlineTypeDefs r

-- instance CReducible C.CExtDecl where
--  reduceC (C.CFunDef spc dec cdecls smt ni) = do
--    pure $ C.CFunDef spc dec cdecls smt ni

-- | Collect every 'C.Ident' occurring anywhere inside a value, in the
-- left-to-right order of a generic ('Data') traversal. The search does not
-- descend into an identifier once it has found one.
identifiers :: forall a. (Data a) => a -> [C.Ident]
identifiers d = maybe (concat (gmapQ identifiers d)) (: []) (cast d)

-- | The name given by a function definition's declarator, if it has one.
functionName :: C.CFunctionDef C.NodeInfo -> Maybe C.Ident
functionName (C.CFunDef _ (C.CDeclr name _ _ _ _) _ _ _) = name

-- | Whether a function definition is named @main@. Unnamed definitions are
-- never @main@.
isMain :: C.CFunctionDef C.NodeInfo -> Bool
isMain = maybe False ((== "main") . C.identToString) . functionName

-- | Abort with an error showing an AST node this reducer does not support.
-- The node's annotations are replaced by @()@ before showing it, so the
-- message contains the node's structure but not the attached 'C.NodeInfo'.
-- The 'HasCallStack' constraint lets the error report where it was called.
don'tHandle :: (HasCallStack, Functor f, Show (f ())) => f C.NodeInfo -> b
don'tHandle node = error (show (void node))

-- instance CReducible C.CDeclaration where
--   reduce = \case
--     C.CDecl spc@(C.CStorageSpec (C.CTypedef _) : rst) decl ni -> do
--       decl' <-
--         decl & collectNonEmpty' \case
--           C.CDeclarationItem d Nothing Nothing -> do
--             let (x, _) = cDeclaratorIdentifiers d
--             case x of
--               Just x' ->
--                 splitOn
--                   (Val.is x')
--                   ( do
--                       modify (Map.insert x' (Type rst))
--                       mzero
--                   )
--                   (pure $ C.CDeclarationItem d Nothing Nothing)
--               Nothing ->
--                 pure $ C.CDeclarationItem d Nothing Nothing
--           a -> error (show a)
--       pure (C.CDecl spc decl' ni)
--     C.CDecl spc@[C.CTypeSpec (C.CTypeDef i ni')] decl ni -> do
--       x <- gets (Map.lookup i)
--       case x of
--         Just (Type rst) -> do
--           decl' <- collectNonEmpty' (reduceCDeclarationItem $ identifiers rst) decl
--           pure $ C.CDecl rst decl' ni
--         Nothing -> do
--           decl' <- collectNonEmpty' (reduceCDeclarationItem $ identifiers spc) decl
--           pure $ C.CDecl spc decl' ni
--     C.CDecl spc decl ni -> do
--       decl' <- collectNonEmpty' (reduceCDeclarationItem $ identifiers spc) decl
--       pure $ C.CDecl spc decl' ni
--     a -> error (show a)
--    where
--     reduceCDeclarationItem rq' = \case
--       C.CDeclarationItem d i e -> do
--         let (fn, reqs) = cDeclaratorIdentifiers d
--         case fn of
--           Just fn' ->
--             conditionalGivenThat (rq' <> reqs) (Val.is fn')
--           Nothing ->
--             mapM_ (givenThat . Val.is) (rq' <> reqs)
--
--         i' <- optional do
--           liftMaybe i >>= reduce @C.CInitializer
--         e' <- optional do
--           liftMaybe e >>= reduce @C.CExpression
--
--         pure (C.CDeclarationItem d i' e')
--       a -> error (show a)

-- import Control.Monad.Reduce
--
-- import qualified Data.Valuation as Val
--
-- import Control.Applicative
-- import Control.Monad.State
-- import Control.Monad.Trans.Maybe
-- import Data.Data
-- import Data.Function
-- import Data.Functor
-- import qualified Data.Map.Strict as Map
-- import Data.Maybe (catMaybes)
-- import qualified Language.C as C

-- type Lab = C.Ident
--
-- data LabInfo
--   = Type [C.CDeclarationSpecifier C.NodeInfo]
--
-- type CState = Map.Map Lab LabInfo
--
-- reduceC :: (MonadReduce Lab m, MonadState CState m) => C.CTranslUnit -> m C.CTranslUnit
-- reduceC (C.CTranslUnit es ni) = do
--   es' <- collect reduceCExternalDeclaration es
--   pure $ C.CTranslUnit es' ni
--  where
--   reduceCExternalDeclaration = \case
--     C.CFDefExt fun -> do
--       C.CFDefExt <$> reduce @C.CFunctionDef fun
--     C.CDeclExt decl ->
--       C.CDeclExt <$> reduce @C.CDeclaration decl
--     a -> error (show a)
--
-- identifiers :: forall a. (Data a) => a -> [Lab]
-- identifiers d = case cast d of
--   Just l -> [l]
--   Nothing -> concat $ gmapQ identifiers d
--
-- type Reducer m a = a -> m a
--
-- class CReducible c where
--   reduce :: (MonadReducePlus Lab m, MonadState CState m) => Reducer m (c C.NodeInfo)
--
-- cDeclaratorIdentifiers :: C.CDeclarator C.NodeInfo -> (Maybe Lab, [Lab])
-- cDeclaratorIdentifiers (C.CDeclr mi dd _ la _) =
--   (mi, identifiers dd <> identifiers la)
--
-- instance CReducible C.CFunctionDef where
--   reduce (C.CFunDef spc dec cdecls smt ni) = do
--     let (fn, ids) = cDeclaratorIdentifiers dec
--     let requirements = identifiers spc <> identifiers cdecls <> ids
--     case fn of
--       Just fn' ->
--         conditionalGivenThat requirements (Val.is fn')
--       Nothing ->
--         mapM_ (givenThat . Val.is) requirements
--     smt' <- reduce @C.CStatement smt
--     pure $ C.CFunDef spc dec cdecls smt' ni
--
-- instance CReducible C.CDeclaration where
--   reduce = \case
--     C.CDecl spc@(C.CStorageSpec (C.CTypedef _) : rst) decl ni -> do
--       decl' <-
--         decl & collectNonEmpty' \case
--           C.CDeclarationItem d Nothing Nothing -> do
--             let (x, _) = cDeclaratorIdentifiers d
--             case x of
--               Just x' ->
--                 splitOn
--                   (Val.is x')
--                   ( do
--                       modify (Map.insert x' (Type rst))
--                       mzero
--                   )
--                   (pure $ C.CDeclarationItem d Nothing Nothing)
--               Nothing ->
--                 pure $ C.CDeclarationItem d Nothing Nothing
--           a -> error (show a)
--       pure (C.CDecl spc decl' ni)
--     C.CDecl spc@[C.CTypeSpec (C.CTypeDef i ni')] decl ni -> do
--       x <- gets (Map.lookup i)
--       case x of
--         Just (Type rst) -> do
--           decl' <- collectNonEmpty' (reduceCDeclarationItem $ identifiers rst) decl
--           pure $ C.CDecl rst decl' ni
--         Nothing -> do
--           decl' <- collectNonEmpty' (reduceCDeclarationItem $ identifiers spc) decl
--           pure $ C.CDecl spc decl' ni
--     C.CDecl spc decl ni -> do
--       decl' <- collectNonEmpty' (reduceCDeclarationItem $ identifiers spc) decl
--       pure $ C.CDecl spc decl' ni
--     a -> error (show a)
--    where
--     reduceCDeclarationItem rq' = \case
--       C.CDeclarationItem d i e -> do
--         let (fn, reqs) = cDeclaratorIdentifiers d
--         case fn of
--           Just fn' ->
--             conditionalGivenThat (rq' <> reqs) (Val.is fn')
--           Nothing ->
--             mapM_ (givenThat . Val.is) (rq' <> reqs)
--
--         i' <- optional do
--           liftMaybe i >>= reduce @C.CInitializer
--         e' <- optional do
--           liftMaybe e >>= reduce @C.CExpression
--
--         pure (C.CDeclarationItem d i' e')
--       a -> error (show a)
--
-- instance CReducible C.CInitializer where
--   reduce = \case
--     C.CInitExpr e ni -> reduce @C.CExpression e <&> \e' -> C.CInitExpr e' ni
--     C.CInitList (C.CInitializerList items) ni -> do
--       collectNonEmpty' rmCInitializerListItem items <&> \items' ->
--         C.CInitList (C.CInitializerList items') ni
--    where
--     rmCInitializerListItem (pds, is) = do
--       pds' <- collect rmCPartDesignator pds
--       is' <- reduce is
--       pure (pds', is')
--
--     rmCPartDesignator = \case
--       a -> error (show a)
--
-- instance CReducible C.CStatement where
--   reduce = \case
--     C.CCompound is cbi ni -> do
--       cbi' <- collect (reduce @C.CCompoundBlockItem) cbi
--       pure $ C.CCompound is cbi' ni
--     C.CExpr e ni -> do
--       e' <- optional do
--         e' <- liftMaybe e
--         reduce @C.CExpression e'
--       pure $ C.CExpr e' ni
--     C.CIf e s els ni -> do
--       s' <- reduce s
--       e' <- optional do
--         reduce @C.CExpression e
--       els' <- optional do
--         els' <- liftMaybe els
--         given >> reduce els'
--       case (e', els') of
--         (Nothing, Nothing) -> pure s'
--         (Just e'', Nothing) -> pure $ C.CIf e'' s' Nothing ni
--         (Nothing, Just x) -> pure $ C.CIf zeroExp s' (Just x) ni
--         (Just e'', Just x) -> pure $ C.CIf e'' s' (Just x) ni
--     C.CFor e1 e2 e3 s ni -> do
--       reduce s <| do
--         e1' <- reduce @C.CForInit e1
--         e2' <- optional $ liftMaybe e2 >>= reduce @C.CExpression
--         e3' <- optional $ liftMaybe e3 >>= reduce @C.CExpression
--         s' <- reduce s
--         pure $ C.CFor e1' e2' e3' s' ni
--     C.CReturn e ni -> do
--       e' <- traverse (fmap orZero reduce) e
--       pure $ C.CReturn e' ni
--     C.CBreak ni -> pure (C.CBreak ni)
--     C.CCont ni -> pure (C.CCont ni)
--     C.CLabel i s [] ni -> do
--       -- todo fix attrs
--       s' <- reduce s
--       withFallback s' do
--         givenThat (Val.is i)
--         pure $ C.CLabel i s' [] ni
--     C.CGoto i ni ->
--       withFallback (C.CExpr Nothing ni) do
--         givenThat (Val.is i)
--         pure $ C.CGoto i ni
--     C.CWhile e s dow ni -> do
--       e' <- orZero (reduce @C.CExpression e)
--       s' <- reduce s
--       pure $ C.CWhile e' s' dow ni
--     a -> error (show a)
--
-- instance CReducible C.CForInit where
--   reduce = \case
--     C.CForDecl decl -> withFallback (C.CForInitializing Nothing) do
--       C.CForDecl <$> reduce @C.CDeclaration decl
--     C.CForInitializing n -> do
--       C.CForInitializing <$> optional do
--         n' <- liftMaybe n
--         reduce @C.CExpression n'
--
--
-- zeroExp :: C.CExpression C.NodeInfo
-- zeroExp = C.CConst (C.CIntConst (C.cInteger 0) C.undefNode)
--
-- withFallback :: (Alternative m) => a -> m a -> m a
-- withFallback a ma = ma <|> pure a
--
-- orZero :: (Alternative m) => m (C.CExpression C.NodeInfo) -> m (C.CExpression C.NodeInfo)
-- orZero = withFallback zeroExp
--
-- instance CReducible C.CCompoundBlockItem where
--   reduce = \case
--     C.CBlockStmt s ->
--       C.CBlockStmt <$> do
--         given >> reduce @C.CStatement s
--     C.CBlockDecl d ->
--       C.CBlockDecl <$> do
--         reduce @C.CDeclaration d
--     a -> error (show a)