{-# LANGUAGE GADTs             #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell   #-}
{-# LANGUAGE TupleSections     #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Shape
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Shape
  where

import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Type

import Language.Haskell.TH.Extra
import Prelude                                                      hiding ( zip )

import GHC.Base                                                     ( quotInt, remInt )


-- | Shape and index representations as nested pairs.
--
-- A value of type @'ShapeR' sh@ describes how the shape type @sh@ is built:
-- 'ShapeRz' is the rank-0 shape @()@, and each 'ShapeRsnoc' extends a shape
-- with one more 'Int' component on the right, giving @(sh, Int)@.
--
data ShapeR sh where
  -- | Rank 0: the unit shape.
  ShapeRz
    :: ShapeR ()
  -- | Add one 'Int' dimension to the right of an existing shape.
  ShapeRsnoc
    :: ShapeR sh
    -> ShapeR (sh, Int)

-- | Nicely format a shape as a string.
--
-- The output starts with @Z@ and then lists the extents from the leftmost
-- component to the rightmost, e.g. @((((), 3), 4), 5)@ is shown as
-- @"Z :. 3 :. 4 :. 5"@; a rank-0 shape is shown as @"Z"@.
--
showShape :: ShapeR sh -> sh -> String
showShape shr sh =
  "Z" ++ concatMap (\extent -> " :. " ++ show extent) leftToRight
  where
    -- 'shapeToList' yields the rightmost component first.
    leftToRight = reverse (shapeToList shr sh)

-- Synonyms for common shape types. Each rank is the previous one extended
-- on the right by a single 'Int' component.
--
type DIM0 = ()
type DIM1 = (DIM0, Int)
type DIM2 = (DIM1, Int)
type DIM3 = (DIM2, Int)

-- | Shape representation of 'DIM0'.
dim0 :: ShapeR DIM0
dim0 = ShapeRz

-- | Shape representation of 'DIM1'.
dim1 :: ShapeR DIM1
dim1 = ShapeRsnoc ShapeRz

-- | Shape representation of 'DIM2'.
dim2 :: ShapeR DIM2
dim2 = ShapeRsnoc (ShapeRsnoc ShapeRz)

-- | Shape representation of 'DIM3'.
dim3 :: ShapeR DIM3
dim3 = ShapeRsnoc (ShapeRsnoc (ShapeRsnoc ShapeRz))

-- | Number of dimensions of a /shape/ or /index/ (>= 0): the number of
-- 'ShapeRsnoc' constructors in the representation.
--
rank :: ShapeR sh -> Int
rank shr = case shr of
  ShapeRz         -> 0
  ShapeRsnoc rest -> 1 + rank rest

-- | Total number of elements in an array of the given shape.
--
-- The product of all extents, except that any extent that is zero or
-- negative makes the result 0. A rank-0 shape has exactly one element.
--
size :: ShapeR sh -> sh -> Int
size ShapeRz          ()         = 1
size (ShapeRsnoc shr) (outer, n) =
  if n > 0
    then size shr outer * n
    else 0

-- | The empty shape: every 'Int' component is set to 0.
--
empty :: ShapeR sh -> sh
empty shr = case shr of
  ShapeRz         -> ()
  ShapeRsnoc rest -> (empty rest, 0)

-- | Yield the intersection of two shapes: the component-wise minimum of
-- their extents.
--
intersect :: ShapeR sh -> sh -> sh -> sh
intersect shr xs ys = zip min shr xs ys

-- | Yield the union of two shapes: the component-wise maximum of their
-- extents.
--
union :: ShapeR sh -> sh -> sh -> sh
union shr xs ys = zip max shr xs ys

-- | Combine two shapes of the same representation component by component
-- with the given function.
--
zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh
zip _       ShapeRz          ()          ()          = ()
zip combine (ShapeRsnoc shr) (outerA, a) (outerB, b) = (outer, combine a b)
  where
    outer = zip combine shr outerA outerB

-- | Structural equality of two shapes (or indices) of the same
-- representation. The rightmost components are compared first.
--
eq :: ShapeR sh -> sh -> sh -> Bool
eq ShapeRz          ()      ()      = True
eq (ShapeRsnoc shr) (xs, x) (ys, y)
  | x /= y    = False
  | otherwise = eq shr xs ys


-- | Map a multi-dimensional index into one in a linear, row-major
-- representation of the array. After the representation, the first argument
-- is the /shape/ and the second is the /index/.
--
-- Each index component is passed through 'indexCheck' together with the
-- corresponding extent (presumably a bounds check; its exact behaviour is
-- defined in "Data.Array.Accelerate.Error").
--
toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int
toIndex ShapeRz          ()       ()      = 0
toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) = indexCheck i sz linear
  where
    -- Row-major: the rightmost component varies fastest.
    linear = toIndex shr sh ix * sz + i

-- | Inverse of 'toIndex': turn a linear, row-major offset back into a
-- multi-dimensional index for the given shape.
--
fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh
fromIndex ShapeRz          ()       _ = ()
fromIndex (ShapeRsnoc shr) (sh, sz) i = (outer, inner)
  where
    outer = fromIndex shr sh (i `quotInt` sz)
    -- When the remaining shape has rank 0, this is the outermost dimension.
    -- If the offset is in range, i < sz must already hold, so the remainder
    -- is skipped and only 'indexCheck' is applied.
    inner = case shr of
      ShapeRz -> indexCheck i sz i
      _       -> i `remInt` sz

-- | Iterate through the entire shape in row-major order.
--
-- After the representation, the arguments are: the shape, the function
-- applied to each index, the function combining results, and an initial
-- value. Each innermost row is folded left-to-right starting from the
-- initial value. The row results are then combined in the same way by the
-- traversal of the outer dimensions, which also starts from the initial
-- value. So unless the initial value is an identity of the combining
-- function, it contributes more than once.
--
iter :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a -> a
iter ShapeRz          ()       f _       _ = f ()
iter (ShapeRsnoc shr) (sh, sz) f combine z = iter shr sh row combine z
  where
    -- Fold the innermost dimension for a fixed outer index @ix@.
    row ix = go 0 z
      where
        go i acc
          | i >= sz   = acc
          | otherwise = go (i + 1) (acc `combine` f (ix, i))

-- | Variant of 'iter' without an initial value.
--
-- Results are visited in row-major order and combined right-nested with the
-- combining function. Because there is no initial value, every dimension
-- must be non-empty. An extent of zero /or less/ is reported through
-- 'boundsError', consistent with 'size', which treats non-positive extents
-- as empty.
--
iter1 :: HasCallStack => ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a
iter1 ShapeRz          ()       f _ = f ()
iter1 (ShapeRsnoc shr) (sh, sz) f c
  -- Reject every non-positive extent, not only 0. With a negative extent
  -- the row loop below starts at 0, counts upward, and never reaches
  -- i == sz - 1, so it would not terminate.
  | sz <= 0   = boundsError "empty iteration space"
  | otherwise = iter1 shr sh (\ix -> iter1' (ix, 0)) c
  where
    -- Fold one innermost row. Here sz > 0, so i == sz - 1 is always reached.
    iter1' (ix, i)
      | i == sz - 1 = f (ix, i)
      | otherwise   = f (ix, i) `c` iter1' (ix, i + 1)

-- Operations to facilitate conversion with IArray

-- | Convert a minpoint-maxpoint index into a shape. Both bounds are
-- inclusive, so each extent is @max - min + 1@ for the matching components.
--
rangeToShape :: ShapeR sh -> (sh, sh) -> sh
rangeToShape ShapeRz          ((), ())           = ()
rangeToShape (ShapeRsnoc shr) ((lo, l), (hi, h)) = (rangeToShape shr (lo, hi), extent)
  where
    extent = h - l + 1

-- | Converse of 'rangeToShape': the inclusive range of valid indices for a
-- shape, with every lower bound 0 and every upper bound @extent - 1@.
--
shapeToRange :: ShapeR sh -> sh -> (sh, sh)
shapeToRange ShapeRz          ()       = ((), ())
shapeToRange (ShapeRsnoc shr) (sh, sz) = ((low, 0), (high, sz - 1))
  where
    (low, high) = shapeToRange shr sh

-- | Convert a shape or index into its list of dimensions. The rightmost
-- component comes first in the list.
--
shapeToList :: ShapeR sh -> sh -> [Int]
shapeToList ShapeRz           ()           = []
shapeToList (ShapeRsnoc rest) (outer, ext) = ext : shapeToList rest outer

-- | Convert a list of dimensions into a shape. This is the partial version
-- of 'listToShape'': it calls 'error' when the list length does not match
-- the rank of the representation.
--
listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh
listToShape shr ds =
  maybe
    (error "listToShape: unable to convert list to a shape at the specified type")
    id
    (listToShape' shr ds)

-- | Attempt to convert a list of dimensions into a shape. The head of the
-- list becomes the rightmost component, mirroring 'shapeToList'. Returns
-- 'Nothing' when the list is longer or shorter than the rank.
--
listToShape' :: ShapeR sh -> [Int] -> Maybe sh
listToShape' ShapeRz          []     = Just ()
listToShape' (ShapeRsnoc shr) (x:xs) = fmap (\outer -> (outer, x)) (listToShape' shr xs)
listToShape' _                _      = Nothing

-- | The tuple type representation of a shape: unit for rank 0, and one
-- 'Int' scalar paired on the right for each further dimension.
--
shapeType :: ShapeR sh -> TypeR sh
shapeType ShapeRz          = TupRunit
shapeType (ShapeRsnoc shr) = TupRpair (shapeType shr) (TupRsingle intType)
  where
    intType = SingleScalarType (NumSingleType (IntegralNumType TypeInt))

-- | Fully evaluate a shape value: every 'Int' component is forced with
-- 'seq', rightmost first.
--
rnfShape :: ShapeR sh -> sh -> ()
rnfShape ShapeRz          ()         = ()
rnfShape (ShapeRsnoc shr) (outer, n) = seq n (rnfShape shr outer)

-- | Fully evaluate a shape representation by walking all of its
-- constructors.
--
rnfShapeR :: ShapeR sh -> ()
rnfShapeR shr = case shr of
  ShapeRz         -> ()
  ShapeRsnoc rest -> rnfShapeR rest

-- | Typed Template Haskell code that rebuilds the given shape
-- representation constructor by constructor.
--
liftShapeR :: ShapeR sh -> CodeQ (ShapeR sh)
liftShapeR shr = case shr of
  ShapeRz         -> [|| ShapeRz ||]
  ShapeRsnoc rest -> [|| ShapeRsnoc $$(liftShapeR rest) ||]