{-# LANGUAGE CPP               #-}
{-# LANGUAGE OverloadedStrings #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Trafo
-- Copyright   : [2012..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Trafo (

  -- * HOAS -> de Bruijn conversion
  -- ** Array computations
  convertAcc, convertAccWith,

  -- ** Array functions
  Afunction, ArraysFunctionR,
  convertAfun, convertAfunWith,

  -- ** Sequence computations
  -- convertSeq, convertSeqWith,

  -- ** Scalar expressions
  Function, EltFunctionR,
  convertExp, convertFun,

) where

import Data.Array.Accelerate.Sugar.Array                            ( ArraysR )
import Data.Array.Accelerate.Sugar.Elt                              ( EltR )
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Trafo.Config
import Data.Array.Accelerate.Trafo.Delayed
import Data.Array.Accelerate.Trafo.Sharing                          ( Afunction, ArraysFunctionR, Function, EltFunctionR )
import qualified Data.Array.Accelerate.AST                          as AST
import qualified Data.Array.Accelerate.Trafo.Fusion                 as Fusion
import qualified Data.Array.Accelerate.Trafo.LetSplit               as LetSplit
import qualified Data.Array.Accelerate.Trafo.Simplify               as Rewrite
import qualified Data.Array.Accelerate.Trafo.Sharing                as Sharing
-- import qualified Data.Array.Accelerate.Trafo.Vectorise              as Vectorise

import Control.DeepSeq
import Data.Text.Lazy.Builder

#ifdef ACCELERATE_DEBUG
import Formatting
import System.IO.Unsafe
import Data.Array.Accelerate.Debug.Internal.Flags                   hiding ( when )
import Data.Array.Accelerate.Debug.Internal.Timed
#endif


-- HOAS -> de Bruijn conversion
-- ----------------------------

-- | Convert a closed array expression to de Bruijn form while also
--   incorporating sharing observation and array fusion.
--
--   This is 'convertAccWith' applied to the 'defaultOptions' configuration.
--
convertAcc :: Acc arrs -> DelayedAcc (ArraysR arrs)
convertAcc = convertAccWith defaultOptions

-- | As 'convertAcc', but with an explicit 'Config' which is passed to both
--   the sharing-recovery and the array-fusion phases.
--
--   The pipeline is a right-to-left composition:
--
--   1. @sharing-recovery@: convert the HOAS term into the de Bruijn AST
--   2. @array-split-lets@: run the 'LetSplit' pass over that AST
--   3. @array-fusion@: fuse array operations, producing the delayed form
--
--   Each stage is wrapped in 'phase' so that it can be timed in debug builds.
--
convertAccWith :: Config -> Acc arrs -> DelayedAcc (ArraysR arrs)
convertAccWith config
  = phase "array-fusion"           (Fusion.convertAccWith config)
  . phase "array-split-lets"       LetSplit.convertAcc
  -- phase "vectorise-sequences"    Vectorise.vectoriseSeqAcc `when` vectoriseSequences
  . phase "sharing-recovery"       (Sharing.convertAccWith config)


-- | Convert a unary function over array computations, incorporating sharing
--   observation and array fusion.
--
--   This is 'convertAfunWith' applied to the 'defaultOptions' configuration.
--
convertAfun :: Afunction f => f -> DelayedAfun (ArraysFunctionR f)
convertAfun = convertAfunWith defaultOptions

-- | As 'convertAfun', but with an explicit 'Config' which is passed to both
--   the sharing-recovery and the array-fusion phases.
--
--   The stages mirror 'convertAccWith', operating on array functions instead
--   of closed array terms. Read right to left: sharing recovery, then the
--   'LetSplit' pass, then array fusion. Each stage is wrapped in 'phase'.
--
convertAfunWith :: Afunction f => Config -> f -> DelayedAfun (ArraysFunctionR f)
convertAfunWith config
  = phase "array-fusion"           (Fusion.convertAfunWith config)
  . phase "array-split-lets"       LetSplit.convertAfun
  -- phase "vectorise-sequences"    Vectorise.vectoriseSeqAfun  `when` vectoriseSequences
  . phase "sharing-recovery"       (Sharing.convertAfunWith config)


-- | Convert a closed scalar expression, incorporating sharing observation and
--   optimisation.
--
--   Sharing recovery runs first and produces the de Bruijn 'AST.Exp'.
--   The result is then passed through 'Rewrite.simplifyExp'.
--
convertExp :: Exp e -> AST.Exp () (EltR e)
convertExp
  = phase "exp-simplify"     Rewrite.simplifyExp
  . phase "sharing-recovery" Sharing.convertExp


-- | Convert closed scalar functions, incorporating sharing observation and
--   optimisation.
--
--   Sharing recovery runs first and produces the de Bruijn 'AST.Fun'.
--   The result is then passed through 'Rewrite.simplifyFun'.
--
convertFun :: Function f => f -> AST.Fun () (EltFunctionR f)
convertFun
  = phase "exp-simplify"     Rewrite.simplifyFun
  . phase "sharing-recovery" Sharing.convertFun

{--
-- | Convert a closed sequence computation, incorporating sharing observation and
--   optimisation.
--
convertSeq :: Typeable s => Seq s -> DelayedSeq s
convertSeq = convertSeqWith phases

convertSeqWith :: Typeable s => Phase -> Seq s -> DelayedSeq s
convertSeqWith Phase{..} s
  = phase "array-fusion"           (Fusion.convertSeq enableAccFusion)
  -- $ phase "vectorise-sequences"    Vectorise.vectoriseSeq     `when` vectoriseSequences
  -- $ phase "rewrite-segment-offset" Rewrite.convertSegmentsSeq `when` convertOffsetOfSegment
  $ phase "sharing-recovery"       (Sharing.convertSeq recoverAccSharing recoverExpSharing recoverSeqSharing floatOutAccFromExp)
  $ s
--}


-- when :: (a -> a) -> Bool -> a -> a
-- when f True  = f
-- when _ False = id

-- Debugging
-- ---------

-- | Execute a phase of the compiler and (possibly) print some timing/gc
--   statistics.
--
--   The 'Builder' argument names the phase for diagnostic output; the
--   function argument is the phase itself.
--
--   * With @ACCELERATE_DEBUG@ defined: if the @dump_phases@ flag is set when
--     the phase runs, the result is forced to normal form with '$!!' (hence
--     the 'NFData' constraint). The evaluation is wrapped in 'timed',
--     labelled @"phase <name>: "@. If the flag is not set, the result is
--     returned unforced.
--   * Otherwise the name is ignored and the phase function is returned as-is.
--
phase :: NFData b => Builder -> (a -> b) -> a -> b
#ifdef ACCELERATE_DEBUG
-- 'unsafePerformIO' lets the pure conversion pipeline be instrumented without
-- threading IO through it. The only effects are reading the debug flag and
-- the timing output; the value returned is 'f x' in either branch.
phase n f x = unsafePerformIO $ do
  enabled <- getFlag dump_phases
  if enabled
    then timed dump_phases (now ("phase " <> n <> ": ") % elapsed) (return $!! f x)
    else return (f x)
#else
phase _ f = f
#endif