{-# LANGUAGE DeriveFunctor        #-}
{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE UndecidableInstances #-}

{-|

Module      : MyFreeMonad
Description : A simple arithmetic language implemented using a free monad.

Free monads in Haskell are a powerful abstraction that allows for the
creation of monadic structures without imposing additional constraints
beyond those required by the monad definition. They are "free" in the sense
that they are unrestricted, meaning they do not add any extra laws or
structure beyond what is necessary for a monad.

A free monad satisfies all the Monad laws, but does not do any computation.
It just builds up a nested series of contexts. The user who creates such a
free monadic value is responsible for doing something with those nested
contexts, so that the meaning of such a composition can be deferred until
after the monadic value has been created.

Example usage:

>>> evalArith (example 0)
5

-}
module MyFreeMonad ( ArithM
                   , ArithF (..)
                   , addA
                   , subA
                   , mulA
                   , divA
                   , evalArith
                   , example
                   , exampleDo
                   ) where

import           Control.Monad.Free (Free (..), liftF)

-- | The base functor of the arithmetic language.
--
-- Every constructor carries the 'Int' operand of one operation together with
-- a continuation slot @x@ that holds the rest of the program. The derived
-- 'Functor' instance maps over that continuation slot only; the operand is
-- left untouched.
data ArithF x
  = Add Int x -- ^ Addition operation
  | Sub Int x -- ^ Subtraction operation
  | Mul Int x -- ^ Multiplication operation
  | Div Int x -- ^ Division operation
  deriving (Show, Functor)

-- | The free monad for the arithmetic language, built over the 'ArithF' functor.
--
-- A value of type @ArithM a@ is either @Pure a@ (a finished program carrying
-- its result) or a @Free@ layer wrapping a single 'ArithF' operation whose
-- continuation slot holds the rest of the program.
type ArithM = Free ArithF

-- | Evaluate an arithmetic program to a single 'Int'.
--
-- The value carried by the final 'Pure' is the starting number. Each
-- operation is applied to the result of evaluating /the rest of the program/,
-- so the operation that was sequenced last is applied first and the one that
-- was sequenced first is applied last.
--
-- 'Div' is integer division via 'div'; a @Div 0@ step therefore raises a
-- divide-by-zero exception.
--
-- >>> evalArith (Pure 10)
-- 10
--
-- >>> evalArith (addA 1 >> mulA 3 >> pure 2)
-- 7
--
-- >>> evalArith (example 1)
-- 6
evalArith :: Free ArithF Int -> Int
evalArith (Pure x)         = x
evalArith (Free (Add x n)) = evalArith n + x
evalArith (Free (Sub x n)) = evalArith n - x
evalArith (Free (Mul x n)) = evalArith n * x
evalArith (Free (Div x n)) = evalArith n `div` x

-- | Lift an addition operation into the 'ArithM' monad.
--
-- @addA x@ is a one-step program consisting of @Add x@ with a unit
-- continuation. Under 'evalArith' it adds @x@ to the value produced by
-- whatever is sequenced after it.
--
-- >>> evalArith (addA 3 >> pure 4)
-- 7
addA :: Int -> ArithM ()
addA x = liftF (Add x ())

-- | Lift a subtraction operation into the 'ArithM' monad.
--
-- @subA x@ is a one-step program consisting of @Sub x@ with a unit
-- continuation. Under 'evalArith' it subtracts @x@ from the value produced by
-- whatever is sequenced after it.
--
-- >>> evalArith (subA 3 >> pure 4)
-- 1
subA :: Int -> ArithM ()
subA x = liftF (Sub x ())

-- | Lift a multiplication operation into the 'ArithM' monad.
--
-- @mulA x@ is a one-step program consisting of @Mul x@ with a unit
-- continuation. Under 'evalArith' it multiplies the value produced by
-- whatever is sequenced after it by @x@.
--
-- >>> evalArith (mulA 3 >> pure 4)
-- 12
mulA :: Int -> ArithM ()
mulA x = liftF (Mul x ())

-- | Lift a division operation into the 'ArithM' monad.
--
-- @divA x@ is a one-step program consisting of @Div x@ with a unit
-- continuation. Under 'evalArith' it divides the value produced by whatever
-- is sequenced after it by @x@ using integer division ('div'), so @divA 0@
-- makes evaluation fail with a divide-by-zero exception.
--
-- >>> evalArith (divA 2 >> pure 9)
-- 4
divA :: Int -> ArithM ()
divA x = liftF (Div x ())

-- | An example arithmetic computation, written with explicit '>>' chaining.
--
-- The program sequences @divA 2@, @subA 10@, @mulA 2@, @addA 10@ and finally
-- returns @n@. Because 'evalArith' applies the last-sequenced operation
-- first, evaluating @example n@ computes @((n + 10) * 2 - 10) \`div\` 2@.
--
-- >>> evalArith (example 0)
-- 5
example :: Int -> ArithM Int
example n =
  divA 2
    >> subA 10
    >> mulA 2
    >> addA 10
    >> pure n

-- | An example arithmetic computation using do-notation.
--
-- Builds the same program as the '>>' chain @divA 2 >> subA 10 >> mulA 2 >>
-- addA 10 >> pure n@, so evaluating @exampleDo n@ with 'evalArith' computes
-- @((n + 10) * 2 - 10) \`div\` 2@.
--
-- >>> evalArith (exampleDo 1)
-- 6
exampleDo :: Int -> ArithM Int
exampleDo n = do
  divA 2
  subA 10
  mulA 2
  addA 10
  pure n