{-# LANGUAGE BangPatterns        #-}
{-# LANGUAGE ScopedTypeVariables #-}
{-# LANGUAGE TemplateHaskell     #-}
{-# OPTIONS_HADDOCK hide #-}
-- |
-- Module      : Data.Array.Accelerate.Array.Unique
-- Copyright   : [2016..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--

module Data.Array.Accelerate.Array.Unique
  where

import Data.Array.Accelerate.Lifetime

import Control.Applicative
import Control.Concurrent.Unique
import Control.DeepSeq
import Data.Word
import Foreign.ForeignPtr
import Foreign.ForeignPtr.Unsafe
import Foreign.Marshal.Array
import Foreign.Ptr
import Foreign.Storable
import Language.Haskell.TH.Extra
import System.IO.Unsafe
import Prelude


-- | A uniquely identifiable array.
--
-- For the purposes of memory management, we use arrays as keys in a table. For
-- this reason we need a way to uniquely identify each array we create. We do
-- this by attaching a unique identifier to each array.
--
-- Note: [Unique array strictness]
--
-- The actual array data is in many cases unnecessary. For discrete memory
-- backends such as for GPUs, we require the unique identifier to track the data
-- in the remote memory space, but the data will in most cases never be copied
-- back to the host. Thus, the array payload field is only lazily allocated, and
-- we should be careful not to make this field overly strict.
--
data UniqueArray e = UniqueArray
    { uniqueArrayId   :: {-# UNPACK #-} !Unique
      -- ^ Identifier that distinguishes this array from every other one; it
      -- is what memory-management tables use as the key.
    , uniqueArrayData :: {-# UNPACK #-} !(Lifetime (ForeignPtr e))
      -- ^ The host-side payload, wrapped in a 'Lifetime'. See
      -- Note: [Unique array strictness] before forcing the value inside.
    }

-- Full evaluation is delegated to 'rnfUniqueArray', which evaluates the
-- payload 'ForeignPtr' to weak head normal form.
instance NFData (UniqueArray e) where
  rnf ua = rnfUniqueArray ua

-- | Wrap a 'ForeignPtr' as a 'UniqueArray'.
--
-- A fresh 'Unique' identifier is generated first, and then a new 'Lifetime'
-- is created around the supplied pointer.
--
{-# INLINE newUniqueArray #-}
newUniqueArray :: ForeignPtr e -> IO (UniqueArray e)
newUniqueArray fp = do
  ident    <- newUnique
  lifetime <- newLifetime fp
  return (UniqueArray ident lifetime)

-- | Run an action with the raw pointer that backs the unique array.
--
-- The array's 'Lifetime' and the underlying 'ForeignPtr' are both held for
-- the duration of the action, so the payload stays alive even if the action
-- never dereferences the pointer. The pointer must not escape the action:
-- once the action returns, nothing guarantees that the memory is still
-- valid.
--
{-# INLINE withUniqueArrayPtr #-}
withUniqueArrayPtr :: UniqueArray a -> (Ptr a -> IO b) -> IO b
withUniqueArrayPtr ua action = withLifetime payload usePtr
  where
    payload    = uniqueArrayData ua
    usePtr fp  = withForeignPtr fp action

-- | Pure indexing into an array that is treated as immutable. No bounds
-- checking is performed.
--
-- Both the array and the index are forced. The read itself goes through
-- 'unsafeReadArray' and runs under 'unsafePerformIO', so the result is only
-- meaningful if the element is not written concurrently.
--
{-# INLINE unsafeIndexArray #-}
unsafeIndexArray :: Storable e => UniqueArray e -> Int -> e
unsafeIndexArray !ua !i =
  let readElem = unsafeReadArray ua i
  in  readElem `seq` unsafePerformIO readElem

-- | Read the element at the given index of a (possibly mutable) array. No
-- bounds checking is performed.
--
-- The array is kept alive for the duration of the read via
-- 'withUniqueArrayPtr'.
--
{-# INLINE unsafeReadArray #-}
unsafeReadArray :: Storable e => UniqueArray e -> Int -> IO e
unsafeReadArray !ua !i = withUniqueArrayPtr ua (`peekElemOff` i)

-- | Store an element at the given index of a mutable array. No bounds
-- checking is performed.
--
-- The array, the index, and the element are all forced. The array is kept
-- alive for the duration of the write via 'withUniqueArrayPtr'.
--
{-# INLINE unsafeWriteArray #-}
unsafeWriteArray :: Storable e => UniqueArray e -> Int -> e -> IO ()
unsafeWriteArray !ua !i !e = withUniqueArrayPtr ua store
  where
    store ptr = pokeElemOff ptr i e


-- | Get the raw pointer behind a unique array, without any liveness guarantee.
--
-- This is unsafe. If this argument is the last reference to the array, its
-- finalisers may run and invalidate the pointer that was just returned.
-- Prefer 'withUniqueArrayPtr' whenever the pointer is dereferenced.
--
-- See also: 'unsafeGetValue', 'unsafeForeignPtrToPtr'.
--
{-# INLINE unsafeUniqueArrayPtr #-}
unsafeUniqueArrayPtr :: UniqueArray a -> Ptr a
unsafeUniqueArrayPtr ua =
  unsafeForeignPtrToPtr (unsafeGetValue (uniqueArrayData ua))


-- | Mark the unique array as live at this point in a sequence of IO actions.
--
-- Only the array's 'Lifetime' is touched; the array payload itself is not
-- forced.
--
-- See: [Unique array strictness]
--
{-# INLINE touchUniqueArray #-}
touchUniqueArray :: UniqueArray a -> IO ()
touchUniqueArray ua = touchLifetime (uniqueArrayData ua)


-- | Evaluate a unique array to normal form.
--
-- This forces the 'UniqueArray' constructor and evaluates the payload
-- 'ForeignPtr' (obtained with 'unsafeGetValue') to weak head normal form.
-- The array elements behind the pointer are not inspected.
--
rnfUniqueArray :: UniqueArray a -> ()
rnfUniqueArray ua = unsafeGetValue (uniqueArrayData ua) `seq` ()

-- | Embed the first @sz@ elements of a unique array into a typed Template
-- Haskell expression.
--
-- At compile time, the raw bytes of the payload are copied out and embedded
-- as a primitive string literal. When the generated code runs, it wraps the
-- literal's address in a 'ForeignPtr' that has no finalisers
-- ('newForeignPtr_') and gives it a fresh unique identifier via
-- 'newUniqueArray'.
--
-- TODO: Make sure that the data is correctly aligned... The address of the
-- embedded literal is not checked against the 'alignment' of @a@.
--
liftUniqueArray :: forall a. Storable a => Int -> UniqueArray a -> CodeQ (UniqueArray a)
liftUniqueArray sz ua = unsafeCodeCoerce $ do
  -- Copy the payload while the array is guaranteed to be alive. Reading it
  -- through 'unsafeUniqueArrayPtr' instead would allow the finalisers to
  -- run part-way through the copy if @ua@ has no other live reference.
  bytes <- runIO $ withUniqueArrayPtr ua $ \p ->
    peekArray (sizeOf (undefined :: a) * sz) (castPtr p :: Ptr Word8)
  [| unsafePerformIO $ do
       fp <- newForeignPtr_ (Ptr $(litE (StringPrimL bytes)))
       newUniqueArray (castForeignPtr fp)
   |]