{-# LANGUAGE RebindableSyntax #-}
-- |
-- Module      : Data.Array.Accelerate.Control.Monad
-- Copyright   : [2018..2020] The Accelerate Team
-- License     : BSD3
--
-- Maintainer  : Trevor L. McDonell <trevor.mcdonell@gmail.com>
-- Stability   : experimental
-- Portability : non-portable (GHC extensions)
--
-- A monad sequences actions over a parametrised type.
--
-- This is essentially the same as the standard Haskell 'Monad' class from
-- "Control.Monad", lifted to Accelerate 'Exp' terms.
--
-- @since 1.4.0.0
--

module Data.Array.Accelerate.Control.Monad (

  -- * Monad class
  Monad(..),

  -- * Functions
  -- ** Basic functions
  (=<<), (>>),
  (>=>), (<=<),

  -- ** Conditional execution of monadic expressions
  when, unless,

  -- ** Monadic lifting operations
  liftM, liftM2, liftM3, liftM4, liftM5,

) where

import Data.Array.Accelerate.Data.Functor
import Data.Array.Accelerate.Language

import Data.Array.Accelerate.Sugar.Elt
import Data.Array.Accelerate.Smart

import Prelude                                                      ( Bool, flip )


-- | Scalar type constructors whose 'Exp'-level values can be sequenced.
--
-- Every instance is expected to obey the usual monad laws:
--
-- [Left identity]  @'return' a '>>=' k  =  k a@
-- [Right identity] @m '>>=' 'return'  =  m@
-- [Associativity]  @m '>>=' (\\x -> k x '>>=' h)  =  (m '>>=' k) '>>=' h@
--
-- It should additionally agree with its 'Functor' superclass:
--
-- * @'fmap' f xs  =  xs '>>=' 'return' . f@
--
class Functor m => Monad m where
  -- | Lift a plain scalar value into the monadic type.
  --
  return :: (Elt a, Elt (m a)) => Exp a -> Exp (m a)

  -- | Sequential composition: run the first action and hand whatever it
  -- produces to the second.
  --
  -- \'@as '>>=' bs@\' corresponds to the @do@ block
  --
  -- @
  -- do a <- as
  --    bs a
  -- @
  --
  (>>=) :: (Elt a, Elt b, Elt (m a), Elt (m b)) =>
           Exp (m a) -> (Exp a -> Exp (m b)) -> Exp (m b)
  infixl 1 >>=


-- | Same as '>>=', but with the arguments interchanged.
--
-- @k '=<<' m  =  m '>>=' k@
--
infixr 1 =<<
(=<<) :: (Monad m, Elt a, Elt b, Elt (m a), Elt (m b))
      => (Exp a -> Exp (m b))
      -> Exp (m a)
      -> Exp (m b)
(=<<) = flip (>>=)

-- | Sequentially compose two actions, discarding any value produced by the
-- first, like sequencing operators (such as the semicolon) in imperative
-- languages.
--
-- \'@as '>>' bs@\' can be understood as the @do@ expression
--
-- @
-- do as
--    bs
-- @
--
-- It is defined as @m '>>=' \\_ -> k@: the result of the first action is
-- bound and then ignored.
--
infixl 1 >>
(>>) :: (Monad m, Elt a, Elt b, Elt (m a), Elt (m b))
     => Exp (m a)
     -> Exp (m b)
     -> Exp (m b)
m >> k = m >>= \_ -> k


-- | Left-to-right composition of Kleisli arrows.
--
-- \'@(bs '>=>' cs) a@\' can be understood as the @do@ expression
--
-- @
-- do b <- bs a
--    cs b
-- @
--
infixr 1 >=>
(>=>) :: (Monad m, Elt a, Elt b, Elt c, Elt (m b), Elt (m c))
      => (Exp a -> Exp (m b))
      -> (Exp b -> Exp (m c))
      -> (Exp a -> Exp (m c))
f >=> g = \x -> f x >>= g

-- | Right-to-left composition of Kleisli arrows: @('>=>')@ with the
-- arguments flipped.
--
-- Note how this operator resembles function composition @('.')@:
--
-- > (.)   ::            (b ->   c) -> (a ->   b) -> a ->   c
-- > (<=<) :: Monad m => (b -> m c) -> (a -> m b) -> a -> m c
--
infixr 1 <=<
(<=<) :: (Monad m, Elt a, Elt b, Elt c, Elt (m b), Elt (m c))
      => (Exp b -> Exp (m c))
      -> (Exp a -> Exp (m b))
      -> (Exp a -> Exp (m c))
(<=<) = flip (>=>)


-- | Conditional execution of a monadic expression.
--
-- @'when' p s@ selects @s@ when @p@ holds, and the trivial action
-- @'return' ('constant' ())@ otherwise (chosen with 'cond').
--
when :: (Monad m, Elt (m ())) => Exp Bool -> Exp (m ()) -> Exp (m ())
when p s = cond p s (return (constant ()))

-- | The reverse of 'when'.
--
-- @'unless' p s@ selects the trivial action @'return' ('constant' ())@ when
-- @p@ holds, and @s@ otherwise (chosen with 'cond').
--
unless :: (Monad m, Elt (m ())) => Exp Bool -> Exp (m ()) -> Exp (m ())
unless p s = cond p (return (constant ())) s

-- | Promote a function to a monad.
--
-- @'liftM' f m1@ binds the result of @m1@ and 'return's @f@ applied to it.
-- The @do@ block below is desugared (via RebindableSyntax) using this
-- module's '>>='.
--
liftM :: (Monad m, Elt a, Elt b, Elt (m a), Elt (m b)) => (Exp a -> Exp b) -> Exp (m a) -> Exp (m b)
liftM f m1 = do
  x1 <- m1
  return (f x1)

-- | Promote a function to a monad, scanning the monadic arguments from
-- left to right.
--
-- The results of @m1@ and then @m2@ are bound in that order, and @f@ is
-- applied to them in the final 'return'.
--
liftM2 :: (Monad m, Elt a, Elt b, Elt c, Elt (m a), Elt (m b), Elt (m c))
       => (Exp a -> Exp b -> Exp c)
       -> Exp (m a)
       -> Exp (m b)
       -> Exp (m c)
liftM2 f m1 m2 = do
  x1 <- m1
  x2 <- m2
  return (f x1 x2)

-- | Promote a function to a monad, scanning the monadic arguments from
-- left to right (cf. 'liftM2').
--
liftM3 :: (Monad m, Elt a, Elt b, Elt c, Elt d, Elt (m a), Elt (m b), Elt (m c), Elt (m d))
       => (Exp a -> Exp b -> Exp c -> Exp d)
       -> Exp (m a)
       -> Exp (m b)
       -> Exp (m c)
       -> Exp (m d)
liftM3 f m1 m2 m3 = do
  x1 <- m1
  x2 <- m2
  x3 <- m3
  return (f x1 x2 x3)

-- | Promote a function to a monad, scanning the monadic arguments from
-- left to right (cf. 'liftM2').
--
liftM4 :: (Monad m, Elt a, Elt b, Elt c, Elt d, Elt e, Elt (m a), Elt (m b), Elt (m c), Elt (m d), Elt (m e))
       => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e)
       -> Exp (m a)
       -> Exp (m b)
       -> Exp (m c)
       -> Exp (m d)
       -> Exp (m e)
liftM4 f m1 m2 m3 m4 = do
  x1 <- m1
  x2 <- m2
  x3 <- m3
  x4 <- m4
  return (f x1 x2 x3 x4)

-- | Promote a function to a monad, scanning the monadic arguments from
-- left to right (cf. 'liftM2').
--
liftM5 :: ( Monad m
          , Elt a, Elt b, Elt c, Elt d, Elt e, Elt f
          , Elt (m a), Elt (m b), Elt (m c), Elt (m d), Elt (m e), Elt (m f)
          )
       => (Exp a -> Exp b -> Exp c -> Exp d -> Exp e -> Exp f)
       -> Exp (m a)
       -> Exp (m b)
       -> Exp (m c)
       -> Exp (m d)
       -> Exp (m e)
       -> Exp (m f)
liftM5 f m1 m2 m3 m4 m5 = do
  x1 <- m1
  x2 <- m2
  x3 <- m3
  x4 <- m4
  x5 <- m5
  return (f x1 x2 x3 x4 x5)