{-# LANGUAGE GADTs #-}
{-# LANGUAGE OverloadedStrings #-}
{-# LANGUAGE TemplateHaskell #-}
{-# LANGUAGE TupleSections #-}
{-# OPTIONS_HADDOCK hide #-}
module Data.Array.Accelerate.Representation.Shape
where
import Data.Array.Accelerate.Error
import Data.Array.Accelerate.Type
import Data.Array.Accelerate.Representation.Type
import Language.Haskell.TH.Extra
import Prelude hiding ( zip )
import GHC.Base ( quotInt, remInt )
-- | Representation of the shape (index space) of an array.
--
-- A shape of rank @n@ is a left-nested tuple of @n@ 'Int' extents terminated
-- by @()@; for example rank 2 is @(((), Int), Int)@. 'ShapeRz' describes the
-- rank-0 shape and each 'ShapeRsnoc' appends one more 'Int' dimension on the
-- right.
data ShapeR sh where
  ShapeRz    :: ShapeR ()
  ShapeRsnoc :: ShapeR sh -> ShapeR (sh, Int)
-- | Render a shape in @Z :. d0 :. d1 ...@ notation.
--
-- 'shapeToList' yields the rightmost (innermost) extent first. Each element
-- is appended to the end of the accumulated string by 'foldr', so the
-- leftmost extent is printed first. For example, @(((), 3), 4)@ at rank 2
-- renders as @"Z :. 3 :. 4"@.
showShape :: ShapeR sh -> sh -> String
showShape shr = foldr (\sz str -> str ++ " :. " ++ show sz) "Z" . shapeToList shr
-- | Rank-0 (scalar) index space.
type DIM0 = ()
-- | Rank-1 index space: 'DIM0' plus one 'Int' extent.
type DIM1 = (DIM0, Int)
-- | Rank-2 index space: 'DIM1' plus one 'Int' extent.
type DIM2 = (DIM1, Int)
-- | Rank-3 index space: 'DIM2' plus one 'Int' extent.
type DIM3 = (DIM2, Int)
-- | Shape representation for 'DIM0' (no dimensions).
dim0 :: ShapeR DIM0
dim0 = ShapeRz
-- | Shape representation for 'DIM1': 'dim0' extended by one dimension.
dim1 :: ShapeR DIM1
dim1 = ShapeRsnoc dim0
-- | Shape representation for 'DIM2': 'dim1' extended by one dimension.
dim2 :: ShapeR DIM2
dim2 = ShapeRsnoc dim1
-- | Shape representation for 'DIM3': 'dim2' extended by one dimension.
dim3 :: ShapeR DIM3
dim3 = ShapeRsnoc dim2
-- | Number of dimensions described by a shape representation.
rank :: ShapeR sh -> Int
rank ShapeRz          = 0
rank (ShapeRsnoc shr) = rank shr + 1
-- | Number of elements in an index space: the product of its extents.
--
-- Any extent that is zero or negative makes the whole space empty, so the
-- result is 0 rather than a negative or sign-flipped product. The rank-0
-- shape has exactly one element.
size :: ShapeR sh -> sh -> Int
size ShapeRz          ()       = 1
size (ShapeRsnoc shr) (sh, sz)
  | sz <= 0   = 0
  | otherwise = size shr sh * sz
-- | The shape of the given rank whose every extent is zero.
empty :: ShapeR sh -> sh
empty ShapeRz          = ()
empty (ShapeRsnoc shr) = (empty shr, 0)
-- | Component-wise minimum of the extents of two shapes of the same rank.
intersect :: ShapeR sh -> sh -> sh -> sh
intersect = zip min
-- | Component-wise maximum of the extents of two shapes of the same rank.
union :: ShapeR sh -> sh -> sh -> sh
union = zip max
-- | Combine two shapes of the same rank extent-by-extent with @f@.
--
-- Note: this shadows 'Prelude.zip', which the module hides on import.
zip :: (Int -> Int -> Int) -> ShapeR sh -> sh -> sh -> sh
zip _ ShapeRz          ()      ()      = ()
zip f (ShapeRsnoc shr) (as, a) (bs, b) = (zip f shr as bs, f a b)
-- | Structural equality of two shapes of the same rank.
--
-- The rightmost extents are compared first, and a mismatch there
-- short-circuits the comparison of the remaining dimensions.
eq :: ShapeR sh -> sh -> sh -> Bool
eq ShapeRz          ()      ()        = True
eq (ShapeRsnoc shr) (sh, i) (sh', i') = i == i' && eq shr sh sh'
-- | Convert a multidimensional index into a linear row-major offset within
-- the given shape; the rightmost component varies fastest.
--
-- Each index component is passed through 'indexCheck' together with its
-- extent, presumably failing when the component is out of range (see
-- "Data.Array.Accelerate.Error").
toIndex :: HasCallStack => ShapeR sh -> sh -> sh -> Int
toIndex ShapeRz          ()       ()      = 0
toIndex (ShapeRsnoc shr) (sh, sz) (ix, i) =
  indexCheck i sz $ toIndex shr sh ix * sz + i
-- | Inverse of 'toIndex': turn a linear row-major offset back into a
-- multidimensional index of the given shape.
fromIndex :: HasCallStack => ShapeR sh -> sh -> Int -> sh
fromIndex ShapeRz          ()       _ = ()
fromIndex (ShapeRsnoc shr) (sh, sz) i = (fromIndex shr sh (i `quotInt` sz), r)
  where
    -- When the remaining shape is rank 0 this is the outermost dimension.
    -- An in-range offset then already satisfies i < sz, so the remainder
    -- is not computed and only the bounds check is applied. Inner
    -- dimensions take the remainder modulo their extent.
    r :: Int
    r = case shr of
      ShapeRz -> indexCheck i sz i
      _       -> i `remInt` sz
-- | Fold over every index of a shape in row-major order (rightmost
-- component varying fastest).
--
-- @iter shr sh f c z@ applies @f@ to each index and combines results
-- left-to-right with @c@. Every run along every dimension starts its
-- accumulator from @z@, so @z@ is folded in several times. It should be a
-- neutral element of @c@ for the result to be independent of the shape.
-- For the rank-0 shape the result is just @f ()@.
iter :: ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a -> a
iter ShapeRz          ()       f _ _ = f ()
iter (ShapeRsnoc shr) (sh, sz) f c z = iter shr sh (\ix -> iter' (ix, 0) z) c z
  where
    -- Walk this dimension from 0 to sz-1, threading an accumulator.
    -- NOTE(review): the accumulator is not forced, so a strict @c@ over a
    -- large extent builds a chain of thunks before it is evaluated.
    iter' (ix, i) r
      | i >= sz   = r
      | otherwise = iter' (ix, i + 1) (r `c` f (ix, i))
-- | Seedless variant of 'iter': apply @f@ to every index of a non-empty
-- shape and combine the results with @c@.
--
-- Within each dimension the results are combined right-associatively in
-- row-major order: @f i0 `c` (f i1 `c` (... `c` f iLast))@. For the rank-0
-- shape the result is @f ()@.
--
-- Fails with 'boundsError' if any extent is zero or negative. This matches
-- 'size', which treats such extents as an empty space.
iter1 :: HasCallStack => ShapeR sh -> sh -> (sh -> a) -> (a -> a -> a) -> a
iter1 ShapeRz          ()       f _ = f ()
iter1 (ShapeRsnoc shr) (sh, sz) f c
  -- Checking only sz == 0 is not enough. With a negative extent, iter1'
  -- counts up from 0 and never reaches its stop condition i == sz - 1,
  -- so the iteration would diverge.
  | sz <= 0   = boundsError "empty iteration space"
  | otherwise = iter1 shr sh (\ix -> iter1' (ix, 0)) c
  where
    iter1' (ix, i)
      | i == sz - 1 = f (ix, i)
      | otherwise   = f (ix, i) `c` iter1' (ix, i + 1)
-- | Shape covered by an inclusive index range @(lo, hi)@. Each extent is
-- @hi - lo + 1@.
rangeToShape :: ShapeR sh -> (sh, sh) -> sh
rangeToShape ShapeRz          ((), ())                 = ()
rangeToShape (ShapeRsnoc shr) ((sh1, sz1), (sh2, sz2)) =
  (rangeToShape shr (sh1, sh2), sz2 - sz1 + 1)
-- | Inclusive index range covered by a shape. The lower corner is all
-- zeros and the upper corner is each extent minus one.
shapeToRange :: ShapeR sh -> sh -> (sh, sh)
shapeToRange ShapeRz          ()       = ((), ())
shapeToRange (ShapeRsnoc shr) (sh, sz) =
  let (low, high) = shapeToRange shr sh
  in  ((low, 0), (high, sz - 1))
-- | Extents of a shape as a list, rightmost (innermost) dimension first.
-- 'listToShape'' consumes lists in the same order.
shapeToList :: ShapeR sh -> sh -> [Int]
shapeToList ShapeRz          ()       = []
shapeToList (ShapeRsnoc shr) (sh, sz) = sz : shapeToList shr sh
-- | Partial version of 'listToShape''. It calls 'error' if the list length
-- does not equal the rank of the requested shape.
listToShape :: HasCallStack => ShapeR sh -> [Int] -> sh
listToShape shr ds =
  case listToShape' shr ds of
    Just sh -> sh
    Nothing -> error "listToShape: unable to convert list to a shape at the specified type"
-- | Build a shape from a list of extents given rightmost dimension first
-- (the order produced by 'shapeToList'). Returns 'Nothing' when the list is
-- longer or shorter than the rank.
listToShape' :: ShapeR sh -> [Int] -> Maybe sh
listToShape' ShapeRz          []       = Just ()
listToShape' (ShapeRsnoc shr) (x : xs) = (, x) <$> listToShape' shr xs
listToShape' _                _        = Nothing
-- | Tuple type representation of a shape. Every extent is a scalar 'Int'.
shapeType :: ShapeR sh -> TypeR sh
shapeType ShapeRz          = TupRunit
shapeType (ShapeRsnoc shr) =
  shapeType shr
    `TupRpair` TupRsingle (SingleScalarType (NumSingleType (IntegralNumType TypeInt)))
-- | Fully evaluate a shape value by forcing each 'Int' extent in turn.
rnfShape :: ShapeR sh -> sh -> ()
rnfShape ShapeRz          ()      = ()
rnfShape (ShapeRsnoc shr) (sh, s) = s `seq` rnfShape shr sh
-- | Fully evaluate a shape representation by walking it down to 'ShapeRz'.
rnfShapeR :: ShapeR sh -> ()
rnfShapeR ShapeRz          = ()
rnfShapeR (ShapeRsnoc shr) = rnfShapeR shr
-- | Lift a shape representation into a typed Template Haskell expression
-- that rebuilds the same chain of 'ShapeRsnoc' / 'ShapeRz' constructors.
liftShapeR :: ShapeR sh -> CodeQ (ShapeR sh)
liftShapeR ShapeRz         = [|| ShapeRz ||]
liftShapeR (ShapeRsnoc sh) = [|| ShapeRsnoc $$(liftShapeR sh) ||]