{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE GADTs                 #-}
{-# LANGUAGE LambdaCase            #-}
{-# LANGUAGE OverloadedStrings     #-}
{-# LANGUAGE QuantifiedConstraints #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE StandaloneDeriving    #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Type
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Type
  where

import Data.Array.Accelerate.Type
import Data.Primitive.Vec

import Formatting
import Language.Haskell.TH.Extra


-- | Both arrays (Acc) and expressions (Exp) are represented as nested
-- pairs built from three cases:
--
--   * unit ('TupRunit'): the empty tuple @()@ (void)
--
--   * pairs ('TupRpair'): compound values (i.e. tuples), where each
--     component will be stored in a separate array.
--
--   * singles ('TupRsingle'): a single array or scalar type. In the case
--     of expressions these are values which go in registers. They may be
--     single-value types such as int and float, or SIMD vectors of
--     single-value types such as \<4 * float\>. Vectors-of-vectors are
--     not allowed.
--
-- The parameter @s@ is the type of the leaves stored at each 'TupRsingle'.
--
data TupR s a where
  -- | The empty tuple.
  TupRunit   ::                         TupR s ()
  -- | A single leaf.
  TupRsingle :: s a                  -> TupR s a
  -- | A pair of two (possibly nested) representations.
  TupRpair   :: TupR s a -> TupR s b -> TupR s (a, b)

-- 'Show' is available whenever every leaf type @s a@ can be shown; this
-- is expressed with a quantified constraint (QuantifiedConstraints).
deriving instance (forall a. Show (s a)) => Show (TupR s t)

-- | A 'Format' directive that renders a 'TypeR':
--
--   * 'TupRunit' is rendered as @()@
--
--   * a single scalar type is rendered with 'formatScalarType'
--
--   * a pair is rendered recursively as its two components separated by a
--     comma and wrapped with 'parenthesised', i.e. @(a,b)@
--
formatTypeR :: Format r (TypeR a -> r)
formatTypeR = later $ \case
  TupRunit     -> "()"
  TupRsingle t -> bformat formatScalarType t
  TupRpair a b -> bformat (parenthesised (formatTypeR % "," % formatTypeR)) a b

-- | Representation of the type of a (possibly nested) tuple of scalar
-- values: a 'TupR' whose leaves are 'ScalarType's.
type TypeR = TupR ScalarType

-- | Reduce a 'TupR' to normal form, using the supplied (rank-2) function
-- to force each leaf. Both halves of a pair are forced, combined with
-- 'seq'; 'TupRunit' needs no work.
rnfTupR :: (forall b. s b -> ()) -> TupR s a -> ()
rnfTupR _ TupRunit       = ()
rnfTupR f (TupRsingle s) = f s
rnfTupR f (TupRpair a b) = rnfTupR f a `seq` rnfTupR f b

-- | Reduce a 'TypeR' to normal form, forcing every leaf with
-- 'rnfScalarType'.
rnfTypeR :: TypeR t -> ()
rnfTypeR = rnfTupR rnfScalarType

-- | Lift a 'TupR' into a typed Template Haskell expression that rebuilds
-- the same structure. The supplied function lifts each leaf; unit and pair
-- nodes are reconstructed with their own constructors.
liftTupR :: (forall b. s b -> CodeQ (s b)) -> TupR s a -> CodeQ (TupR s a)
liftTupR _ TupRunit       = [|| TupRunit ||]
liftTupR f (TupRsingle s) = [|| TupRsingle $$(f s) ||]
liftTupR f (TupRpair a b) = [|| TupRpair $$(liftTupR f a) $$(liftTupR f b) ||]

-- | Lift a 'TypeR' into a typed Template Haskell expression that rebuilds
-- it, lifting each scalar leaf with 'liftScalarType'.
liftTypeR :: TypeR t -> CodeQ (TypeR t)
liftTypeR TupRunit         = [|| TupRunit ||]
liftTypeR (TupRsingle t)   = [|| TupRsingle $$(liftScalarType t) ||]
liftTypeR (TupRpair ta tb) = [|| TupRpair $$(liftTypeR ta) $$(liftTypeR tb) ||]

-- | Build the Template Haskell type that a 'TypeR' describes:
--
--   * 'TupRunit' becomes @()@
--
--   * 'TupRpair' becomes a 2-tuple of the two component types
--
--   * a scalar leaf becomes the corresponding Haskell type (e.g. 'Int8',
--     'Half'), and a SIMD vector of @n@ elements of type @a@ becomes
--     @Vec n a@
--
liftTypeQ :: TypeR t -> TypeQ
liftTypeQ = tuple
  where
    tuple :: TypeR t -> TypeQ
    tuple TupRunit         = [t| () |]
    tuple (TupRpair t1 t2) = [t| ($(tuple t1), $(tuple t2)) |]
    tuple (TupRsingle t)   = scalar t

    scalar :: ScalarType t -> TypeQ
    scalar (SingleScalarType t) = single t
    scalar (VectorScalarType t) = vector t

    -- The vector width is held as a value-level Int in 'VectorType', so it
    -- is reflected back into a type-level natural literal here.
    vector :: VectorType (Vec n a) -> TypeQ
    vector (VectorType n t) = [t| Vec $(litT (numTyLit (toInteger n))) $(single t) |]

    single :: SingleType t -> TypeQ
    single (NumSingleType t) = num t

    num :: NumType t -> TypeQ
    num (IntegralNumType t) = integral t
    num (FloatingNumType t) = floating t

    integral :: IntegralType t -> TypeQ
    integral TypeInt    = [t| Int |]
    integral TypeInt8   = [t| Int8 |]
    integral TypeInt16  = [t| Int16 |]
    integral TypeInt32  = [t| Int32 |]
    integral TypeInt64  = [t| Int64 |]
    integral TypeWord   = [t| Word |]
    integral TypeWord8  = [t| Word8 |]
    integral TypeWord16 = [t| Word16 |]
    integral TypeWord32 = [t| Word32 |]
    integral TypeWord64 = [t| Word64 |]

    floating :: FloatingType t -> TypeQ
    floating TypeHalf   = [t| Half |]
    floating TypeFloat  = [t| Float |]
    floating TypeDouble = [t| Double |]

-- Declare the left-nested tuple synonyms Tup2 .. Tup16, for example
--
--   type Tup3 x0 x1 x2 = ((((), x0), x1), x2)
--
$(let
    -- Build the declaration  type TupN x0 .. x(N-1) = (((), x0), ..., x(N-1))
    tupleSynonym :: Int -> Q Dec
    tupleSynonym arity =
      tySynD (mkName ("Tup" ++ show arity)) (map plainTV vars) synonymRhs
      where
        vars       = [ mkName ('x' : show i) | i <- [0 .. arity - 1] ]
        synonymRhs = foldl (\acc v -> [t| ($acc, $v) |]) [t| () |] (map varT vars)
  in
  runQ (traverse tupleSynonym [2 .. 16]))