{-# LANGUAGE FlexibleInstances   #-}
{-# LANGUAGE GADTs               #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# LANGUAGE TypeApplications    #-}
{-# LANGUAGE TypeFamilies        #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Representation.Slice
-- Copyright   : [2008..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Representation.Slice
  where

import Data.Array.Accelerate.Representation.Shape

import Language.Haskell.TH.Extra


-- | Class of slice specifiers. A specifier is a nested pair (snoc list)
-- terminated by @()@; each component is either @()@, which keeps the whole
-- corresponding dimension, or 'Int', which fixes it.
--
class Slice slix where
  -- shape of the sub-array that remains after slicing
  type SliceShape    slix
  -- shape of the fixed dimensions (the complement of the slice)
  type CoSliceShape  slix
  -- shape of the full array, i.e. the combination of the two above
  type FullShape     slix
  -- value-level witness tying the specifier to the three shapes
  sliceIndex :: SliceIndex slix (SliceShape slix) (CoSliceShape slix) (FullShape slix)

-- | The empty specifier: the slice, co-slice and full shapes are all rank 0.
--
instance Slice () where
  type SliceShape    () = ()
  type CoSliceShape  () = ()
  type FullShape     () = ()
  sliceIndex = SliceNil

-- | A @()@ component keeps the corresponding dimension. It contributes an
-- extent to both the slice shape and the full shape, but not to the co-slice.
--
instance Slice sl => Slice (sl, ()) where
  type SliceShape   (sl, ()) = (SliceShape  sl, Int)
  type CoSliceShape (sl, ()) = CoSliceShape sl
  type FullShape    (sl, ()) = (FullShape   sl, Int)
  -- 'sl' is bound by the instance head (ScopedTypeVariables)
  sliceIndex = SliceAll (sliceIndex @sl)

-- | An 'Int' component fixes the corresponding dimension. It contributes an
-- extent to the co-slice and the full shape, but not to the slice shape.
--
instance Slice sl => Slice (sl, Int) where
  type SliceShape   (sl, Int) = SliceShape sl
  type CoSliceShape (sl, Int) = (CoSliceShape sl, Int)
  type FullShape    (sl, Int) = (FullShape    sl, Int)
  -- 'sl' is bound by the instance head (ScopedTypeVariables)
  sliceIndex = SliceFixed (sliceIndex @sl)

-- | Generalised array index, which may index only in a subset of the
-- dimensions of a shape.
--
-- The type indices are, in order: the slice specifier, the shape of the
-- resulting slice, the shape of the fixed (complementary) dimensions, and the
-- full shape. Each non-nil constructor adds one dimension as the right-hand
-- component of those nested pairs.
--
data SliceIndex ix slice coSlice sliceDim where
  -- | No dimensions at all.
  SliceNil
    :: SliceIndex () () () ()
  -- | Keep this dimension: it appears in the slice and in the full shape.
  SliceAll
    :: SliceIndex i s c d
    -> SliceIndex (i, ()) (s, Int) c (d, Int)
  -- | Fix this dimension: it appears in the co-slice and in the full shape.
  SliceFixed
    :: SliceIndex i s c d
    -> SliceIndex (i, Int) s (c, Int) (d, Int)

-- | Renders the index as nested constructor applications, for example
-- @SliceFixed (SliceAll (SliceNil))@. Output is built with 'ShowS' so each
-- level is emitted once rather than re-copied by nested @(++)@.
--
instance Show (SliceIndex ix slice coSlice sliceDim) where
  showsPrec _ SliceNil          = showString "SliceNil"
  showsPrec _ (SliceAll rest)   = showString "SliceAll (" . shows rest . showString ")"
  showsPrec _ (SliceFixed rest) = showString "SliceFixed (" . shows rest . showString ")"

-- | Project the shape of a slice from the full shape.
--
-- Extents of dimensions marked 'SliceAll' are kept; extents of dimensions
-- marked 'SliceFixed' are dropped.
--
sliceShape :: SliceIndex slix sl co dim -> dim -> sl
sliceShape SliceNil          ()      = ()
sliceShape (SliceAll   slix) (sh, n) = (sliceShape slix sh, n)
sliceShape (SliceFixed slix) (sh, _) = sliceShape slix sh

-- | Project the full shape of the slice.
--
-- Rebuilds a full shape from a specifier value @slix@ and a slice shape @sl@.
-- Extents for 'SliceAll' dimensions are taken from @sl@. Extents for
-- 'SliceFixed' dimensions are taken from the 'Int' stored in @slix@.
--
sliceDomain :: SliceIndex slix sl co dim -> slix -> sl -> dim
sliceDomain SliceNil          ()        ()       = ()
sliceDomain (SliceAll slix)   (slx, ()) (sl, sz) = (sliceDomain slix slx sl, sz)
sliceDomain (SliceFixed slix) (slx, sz) sl       = (sliceDomain slix slx sl, sz)

-- | Shape representation of the slice: one 'ShapeRsnoc' per 'SliceAll'.
-- 'SliceFixed' dimensions are skipped.
--
sliceShapeR :: SliceIndex slix sl co dim -> ShapeR sl
sliceShapeR SliceNil        = ShapeRz
sliceShapeR (SliceAll sl)   = ShapeRsnoc $ sliceShapeR sl
sliceShapeR (SliceFixed sl) = sliceShapeR sl

-- | Shape representation of the full shape: every dimension, whether kept
-- ('SliceAll') or fixed ('SliceFixed'), contributes one 'ShapeRsnoc'.
--
sliceDomainR :: SliceIndex slix sl co dim -> ShapeR dim
sliceDomainR SliceNil        = ShapeRz
sliceDomainR (SliceAll sl)   = ShapeRsnoc $ sliceDomainR sl
sliceDomainR (SliceFixed sl) = ShapeRsnoc $ sliceDomainR sl

-- | Enumerate all slices within a given bound. The innermost dimension changes
-- most rapidly.
--
-- 'SliceAll' dimensions always contribute @()@. 'SliceFixed' dimensions range
-- over @[0 .. n-1]@, where @n@ is the extent of that dimension in the bound;
-- a non-positive extent yields no slices.
--
-- See 'Data.Array.Accelerate.Sugar.Slice.enumSlices' for an example.
--
enumSlices
    :: forall slix co sl dim.
       SliceIndex slix sl co dim
    -> dim
    -> [slix]
enumSlices SliceNil        ()      = [()]
enumSlices (SliceAll   sl) (sh, _) = [ (sh', ()) | sh' <- enumSlices sl sh ]
enumSlices (SliceFixed sl) (sh, n) = [ (sh', i)  | sh' <- enumSlices sl sh, i <- [0 .. n-1] ]

-- | Force a 'SliceIndex' by walking its whole constructor spine. The
-- constructors carry no fields other than the nested index.
--
rnfSliceIndex :: SliceIndex ix slice co sh -> ()
rnfSliceIndex SliceNil        = ()
rnfSliceIndex (SliceAll sh)   = rnfSliceIndex sh
rnfSliceIndex (SliceFixed sh) = rnfSliceIndex sh

-- | Lift a 'SliceIndex' into typed Template Haskell code that rebuilds the
-- same constructor chain.
--
liftSliceIndex :: SliceIndex ix slice co sh -> CodeQ (SliceIndex ix slice co sh)
liftSliceIndex SliceNil          = [|| SliceNil ||]
liftSliceIndex (SliceAll rest)   = [|| SliceAll $$(liftSliceIndex rest) ||]
liftSliceIndex (SliceFixed rest) = [|| SliceFixed $$(liftSliceIndex rest) ||]