{-# LANGUAGE CPP                   #-}
{-# LANGUAGE ConstraintKinds       #-}
{-# LANGUAGE FlexibleContexts      #-}
{-# LANGUAGE FlexibleInstances     #-}
{-# LANGUAGE MultiParamTypeClasses #-}
{-# LANGUAGE RankNTypes            #-}
{-# LANGUAGE ScopedTypeVariables   #-}
{-# LANGUAGE TemplateHaskell       #-}
{-# LANGUAGE TypeApplications      #-}
{-# LANGUAGE TypeFamilies          #-}
{-# LANGUAGE TypeOperators         #-}
#if __GLASGOW_HASKELL__ >= 806
{-# LANGUAGE UndecidableInstances  #-}
#endif
-- |
-- Module      : Data.Array.Accelerate.Lift
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- Lifting and lowering surface expressions through constructors.
--

module Data.Array.Accelerate.Lift (

  -- * Lifting and unlifting
  Lift(..), Unlift(..),

  lift1, lift2, lift3,
  ilift1, ilift2, ilift3,

) where

import Data.Array.Accelerate.AST.Idx
import Data.Array.Accelerate.Pattern
import Data.Array.Accelerate.Smart
import Data.Array.Accelerate.Sugar.Array
import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Sugar.Shape
import Data.Array.Accelerate.Type

import Language.Haskell.TH.Extra                                    hiding ( Exp )


-- | Lift a unary function into 'Exp'.
--
-- The argument is unlifted through its outermost constructor, the function
-- is applied to the resulting surface value, and the result is lifted back
-- into 'Exp'. This lets, for example, a function over a pair of 'Exp' values
-- be applied to a single @'Exp' (a, b)@.
--
lift1 :: (Unlift Exp a, Lift Exp b)
      => (a -> b)
      -> Exp (Plain a)
      -> Exp (Plain b)
lift1 f = lift . f . unlift

-- | Lift a binary function into 'Exp'.
--
-- Both arguments are unlifted through their outermost constructors, the
-- function is applied to the resulting surface values, and the result is
-- lifted back into 'Exp'.
--
lift2 :: (Unlift Exp a, Unlift Exp b, Lift Exp c)
      => (a -> b -> c)
      -> Exp (Plain a)
      -> Exp (Plain b)
      -> Exp (Plain c)
lift2 f x y = lift $ f (unlift x) (unlift y)

-- | Lift a ternary function into 'Exp'.
--
-- All three arguments are unlifted through their outermost constructors,
-- the function is applied to the resulting surface values, and the result
-- is lifted back into 'Exp'.
--
lift3 :: (Unlift Exp a, Unlift Exp b, Unlift Exp c, Lift Exp d)
      => (a -> b -> c -> d)
      -> Exp (Plain a)
      -> Exp (Plain b)
      -> Exp (Plain c)
      -> Exp (Plain d)
lift3 f x y z = lift $ f (unlift x) (unlift y) (unlift z)

-- | Lift a unary function to a computation over rank-1 indices.
--
-- The single 'Int' component of the index is passed to the function, and
-- the result is rebuilt as a rank-1 index.
--
ilift1 :: (Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1
ilift1 f = lift1 (\(Z :. i) -> Z :. f i)

-- | Lift a binary function to a computation over rank-1 indices.
--
-- The 'Int' components of both indices are passed to the function, and the
-- result is rebuilt as a rank-1 index.
--
ilift2 :: (Exp Int -> Exp Int -> Exp Int) -> Exp DIM1 -> Exp DIM1 -> Exp DIM1
ilift2 f = lift2 (\(Z :. i) (Z :. j) -> Z :. f i j)

-- | Lift a ternary function to a computation over rank-1 indices.
--
-- The 'Int' components of all three indices are passed to the function,
-- and the result is rebuilt as a rank-1 index.
--
ilift3 :: (Exp Int -> Exp Int -> Exp Int -> Exp Int)
       -> Exp DIM1
       -> Exp DIM1
       -> Exp DIM1
       -> Exp DIM1
ilift3 f = lift3 (\(Z :. i) (Z :. j) (Z :. k) -> Z :. f i j k)


-- | The class of types @e@ which can be lifted into @c@.
--
class Lift c e where
  -- | An associated type (i.e. a type-level function) that strips all
  --   instances of surface type constructors @c@ from the input type @e@.
  --
  --   For example, the tuple types @(Exp Int, Int)@ and @(Int, Exp
  --   Int)@ have the same \"Plain\" representation.  That is, the
  --   following type equality holds:
  --
  --    @Plain (Exp Int, Int) ~ (Int,Int) ~ Plain (Int, Exp Int)@
  --
  --   Note that 'Plain' is indexed by @e@ alone, so instances for the same
  --   @e@ at different @c@ (for example @()@ under both 'Exp' and 'Acc')
  --   must agree on the 'Plain' type.
  --
  type Plain e

  -- | Lift the given value into a surface type @c@ --- either 'Exp' for scalar
  -- expressions or 'Acc' for array computations. The value may already contain
  -- subexpressions in @c@.
  --
  lift :: e -> c (Plain e)

-- | Only a limited subset of the types which can be lifted can also be
-- unlifted.
class Lift c e => Unlift c e where

  -- | Unlift the outermost constructor through the surface type. This is only
  -- possible if the constructor is fully determined by its type - i.e., the
  -- type has exactly one constructor (as tuples do).
  --
  unlift :: c (Plain e) -> e


-- Identity instances
-- ------------------

-- A value that is already an 'Exp' is lifted unchanged.
instance Lift Exp (Exp e) where
  type Plain (Exp e) = e
  lift = id

-- Unlifting to an 'Exp' is the identity.
instance Unlift Exp (Exp e) where
  unlift = id

-- A value that is already an 'Acc' is lifted unchanged.
instance Lift Acc (Acc a) where
  type Plain (Acc a) = a
  lift = id

-- Unlifting to an 'Acc' is the identity.
instance Unlift Acc (Acc a) where
  unlift = id

-- instance Lift Seq (Seq a) where
--   type Plain (Seq a) = a
--   lift = id

-- instance Unlift Seq (Seq a) where
--   unlift = id


-- Instances for indices
-- ---------------------

-- The argument is ignored: lifting 'Z' always yields the 'Z_' expression.
instance Lift Exp Z where
  type Plain Z = Z
  lift _ = Z_

-- The expression is not inspected: unlifting always yields 'Z'.
instance Unlift Exp Z where
  unlift _ = Z

-- Lift the tail index recursively and the 'Int' head as a constant, then
-- rebuild the index expression with the '::.' pattern.
instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Int) where
  type Plain (ix :. Int) = Plain ix :. Int
  lift (ix :. i) = lift ix ::. lift i

-- Lift the tail index recursively and embed the 'All' head with 'constant',
-- then rebuild the index expression with the '::.' pattern.
instance (Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. All) where
  type Plain (ix :. All) = Plain ix :. All
  lift (ix :. i) = lift ix ::. constant i

-- The head is already an 'Exp', so only the tail index is lifted before the
-- two are combined with the '::.' pattern.
instance (Elt e, Elt (Plain ix), Lift Exp ix) => Lift Exp (ix :. Exp e) where
  type Plain (ix :. Exp e) = Plain ix :. e
  lift (ix :. i) = lift ix ::. i

-- Split the index expression with '::.', unlifting the tail recursively and
-- keeping the head as an 'Exp'. The instance is OVERLAPPABLE so that the more
-- specific @Exp ix :. Exp e@ instance wins when the tail is itself an 'Exp'.
instance {-# OVERLAPPABLE #-} (Elt e, Elt (Plain ix), Unlift Exp ix) => Unlift Exp (ix :. Exp e) where
  unlift (ix ::. i) = unlift ix :. i

-- Split the index expression with '::.', leaving both the tail and the head
-- as 'Exp' values.
instance {-# OVERLAPPABLE #-} (Elt e, Elt ix) => Unlift Exp (Exp ix :. Exp e) where
  unlift (ix ::. i) = ix :. i

-- 'Any' is matched as a nullary constructor and re-embedded with 'constant'.
instance (Shape sh, Elt (Any sh)) => Lift Exp (Any sh) where
  type Plain (Any sh) = Any sh
  lift Any = constant Any

-- Instances for numeric types
-- ---------------------------

{-# INLINE expConst #-}
-- | Embed a constant whose representation type is a single scalar.
--
-- The value is converted to its representation with 'fromElt' and wrapped in
-- a 'Const' node tagged with the 'ScalarType' of that representation.
--
expConst :: forall e. (Elt e, IsScalar (EltR e)) => e -> Exp e
expConst = Exp . SmartExp . Const (scalarType @(EltR e)) . fromElt

-- 'Int' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Int where
  type Plain Int = Int
  lift = expConst

-- 'Int8' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Int8 where
  type Plain Int8 = Int8
  lift = expConst

-- 'Int16' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Int16 where
  type Plain Int16 = Int16
  lift = expConst

-- 'Int32' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Int32 where
  type Plain Int32 = Int32
  lift = expConst

-- 'Int64' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Int64 where
  type Plain Int64 = Int64
  lift = expConst

-- 'Word' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Word where
  type Plain Word = Word
  lift = expConst

-- 'Word8' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Word8 where
  type Plain Word8 = Word8
  lift = expConst

-- 'Word16' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Word16 where
  type Plain Word16 = Word16
  lift = expConst

-- 'Word32' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Word32 where
  type Plain Word32 = Word32
  lift = expConst

-- 'Word64' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Word64 where
  type Plain Word64 = Word64
  lift = expConst

-- 'CShort' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CShort where
  type Plain CShort = CShort
  lift = expConst

-- 'CUShort' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CUShort where
  type Plain CUShort = CUShort
  lift = expConst

-- 'CInt' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CInt where
  type Plain CInt = CInt
  lift = expConst

-- 'CUInt' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CUInt where
  type Plain CUInt = CUInt
  lift = expConst

-- 'CLong' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CLong where
  type Plain CLong = CLong
  lift = expConst

-- 'CULong' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CULong where
  type Plain CULong = CULong
  lift = expConst

-- 'CLLong' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CLLong where
  type Plain CLLong = CLLong
  lift = expConst

-- 'CULLong' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CULLong where
  type Plain CULLong = CULLong
  lift = expConst

-- 'Half' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Half where
  type Plain Half = Half
  lift = expConst

-- 'Float' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Float where
  type Plain Float = Float
  lift = expConst

-- 'Double' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Double where
  type Plain Double = Double
  lift = expConst

-- 'CFloat' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CFloat where
  type Plain CFloat = CFloat
  lift = expConst

-- 'CDouble' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CDouble where
  type Plain CDouble = CDouble
  lift = expConst

-- 'Bool' is not a single scalar, so 'expConst' cannot be used. Its
-- representation is a one-byte tag paired with unit: the tag is 1 for 'True'
-- and 0 for 'False'.
instance Lift Exp Bool where
  type Plain Bool = Bool
  lift b = Exp . SmartExp $ SmartExp (Const scalarType tag) `Pair` SmartExp Nil
    where
      tag = if b then 1 else 0

-- 'Char' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp Char where
  type Plain Char = Char
  lift = expConst

-- 'CChar' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CChar where
  type Plain CChar = CChar
  lift = expConst

-- 'CSChar' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CSChar where
  type Plain CSChar = CSChar
  lift = expConst

-- 'CUChar' has a scalar representation, so it is embedded with 'expConst'.
instance Lift Exp CUChar where
  type Plain CUChar = CUChar
  lift = expConst

-- Instances for tuples
-- --------------------

-- The unit value is represented by the empty 'Nil' scalar expression.
instance Lift Exp () where
  type Plain () = ()
  lift _ = Exp (SmartExp Nil)

-- The expression is not inspected: unlifting always yields '()'.
instance Unlift Exp () where
  unlift _ = ()

-- The unit value is represented by the empty 'Anil' array computation.
instance Lift Acc () where
  type Plain () = ()
  lift _ = Acc (SmartAcc Anil)

-- The computation is not inspected: unlifting always yields '()'.
instance Unlift Acc () where
  unlift _ = ()

-- Embed a concrete 'Array' into an 'Acc' computation as a 'Use' node. The
-- underlying representation array is unwrapped and tagged with the array
-- representation type for @sh@ and @e@.
instance (Shape sh, Elt e) => Lift Acc (Array sh e) where
  type Plain (Array sh e) = Array sh e
  lift (Array arr) = Acc $ SmartAcc $ Use (arrayR @sh @e) arr

-- Lift and Unlift instances for tuples
--
runQ $ do
    let
        -- Generate the 'Lift' and 'Unlift' instances for @n@-tuples over the
        -- surface type constructor named by @con@:
        --
        --   * @cst@   : class required of each component's element type
        --               (applied to @Plain t@ in the 'Lift' context and to
        --               @t@ in the 'Unlift' context)
        --   * @smart@ : constructor wrapping a smart AST node
        --   * @prj@   : projection constructor, applied to a pair index
        --   * @nil@   : empty-tuple node
        --   * @pair@  : pair node
        --
        -- For components x0 .. x(n-1), 'lift' builds a left-nested chain of
        -- pairs starting from @nil@: (((nil, x0), x1), ..., x(n-1)).
        mkInstances :: Name -> TypeQ -> ExpQ -> ExpQ -> ExpQ -> ExpQ -> Int -> Q [Dec]
        mkInstances con cst smart prj nil pair n = do
          let
              xs      = [ mkName ('x' : show i) | i <- [0 .. n-1] ]
              ts      = map varT xs
              res1    = tupT ts
              res2    = tupT (map (conT con `appT`) ts)
              plain   = tupT (map (\t -> [t| Plain $t |]) ts)
              ctx1    = tupT (map (\t -> [t| Lift $(conT con) $t |]) ts)
              ctx2    = tupT (map (\t -> [t| $cst (Plain $t) |]) ts)
              ctx3    = tupT (map (appT cst) ts)
              --
              -- Project the element @i@ positions from the right end of the
              -- pair chain: take @i@ left projections, then one right
              -- projection, and rewrap the result in @con@. The 'Unlift'
              -- instance uses i = n-1 for x0 down to i = 0 for x(n-1).
              get x 0 = [| $(conE con) ($smart ($prj PairIdxRight $x)) |]
              get x i = get [| $smart ($prj PairIdxLeft $x) |] (i-1)
          --
          -- Fresh name for the wrapped value matched in 'unlift'.
          _x <- newName "_x"
          -- In 'lift', each component is lifted, unwrapped from @con@ with a
          -- fresh pattern variable, and appended to the chain via @pair@.
          [d| instance ($ctx1, $ctx2) => Lift $(conT con) $res1 where
                type Plain $res1 = $plain
                lift $(tupP (map varP xs)) =
                  $(conE con)
                  $(foldl (\vs v -> do _v <- newName "_v"
                                       [| let $(conP con [varP _v]) = lift $(varE v)
                                           in $smart ($pair $vs $(varE _v)) |]) [| $smart $nil |] xs)

              instance $ctx3 => Unlift $(conT con) $res2 where
                unlift $(conP con [varP _x]) =
                  $(tupE (map (get (varE _x)) [(n-1), (n-2) .. 0]))
            |]

        -- Specialisations for array computations and for scalar expressions.
        mkAccInstances = mkInstances (mkName "Acc") [t| Arrays |] [| SmartAcc |] [| Aprj |] [| Anil |] [| Apair |]
        mkExpInstances = mkInstances (mkName "Exp") [t| Elt    |] [| SmartExp |] [| Prj  |] [| Nil  |] [| Pair  |]
    --
    -- Instances for tuples of arity 2 through 16, for both 'Acc' and 'Exp'.
    as <- mapM mkAccInstances [2..16]
    es <- mapM mkExpInstances [2..16]
    return $ concat (as ++ es)