{-# LANGUAGE DeriveFunctor        #-}
{-# LANGUAGE DerivingStrategies   #-}
{-# LANGUAGE UndecidableInstances #-}

module MyFreeMonad ( ArithM
                   , ArithF (..)
                   , addA
                   , subA
                   , mulA
                   , divA
                   , evalArith
                   , example
                   , exampleDo
                   ) where

import           Control.Monad.Free (Free (..), liftF)

-- | The instruction functor for the arithmetic language.
--
-- Every constructor carries an 'Int' operand and a continuation @x@ (the
-- rest of the program). 'evalArith' interprets an instruction by first
-- evaluating its continuation and then combining that result with the
-- operand, as described per constructor below.
data ArithF x
  = Add Int x -- ^ Add the operand to the continuation's result.
  | Sub Int x -- ^ Subtract the operand from the continuation's result.
  | Mul Int x -- ^ Multiply the continuation's result by the operand.
  | Div Int x -- ^ Divide the continuation's result by the operand ('div').
  deriving (Eq, Show, Functor)

-- | The free monad over 'ArithF'.
--
-- A value of type @ArithM a@ is a program: a chain of arithmetic
-- instructions (built with 'addA', 'subA', 'mulA' and 'divA') that ends in
-- a 'Pure' result of type @a@. An @ArithM Int@ program is run with
-- 'evalArith'.
--
-- The synonym is written without a type parameter so that @ArithM@ on its
-- own can be used wherever a type constructor of kind @* -> *@ is expected.
type ArithM = Free ArithF

-- | Evaluate an arithmetic program to its final 'Int'.
--
-- The value in the terminating 'Pure' is the starting value. Each
-- instruction is applied to the result of evaluating the instructions that
-- follow it, so the /last/ instruction of a program is applied /first/:
-- @divA 2 >> addA 10 >> pure n@ evaluates to @div (n + 10) 2@.
--
-- Division uses 'div' (rounding toward negative infinity) and therefore
-- throws a divide-by-zero 'ArithException' when a 'Div' operand is 0.
--
-- The recursion is not tail-recursive: each operator is applied only after
-- the recursive call on the continuation returns.
evalArith :: ArithM Int -> Int
evalArith (Free (Add x next)) = evalArith next + x
evalArith (Free (Sub x next)) = evalArith next - x
evalArith (Free (Mul x next)) = evalArith next * x
evalArith (Free (Div x next)) = evalArith next `div` x
evalArith (Pure x)            = x

-- | A one-instruction program that adds the operand; 'evalArith' applies
-- it to the result of the instructions sequenced after it.
addA :: Int -> ArithM ()
addA x = liftF (Add x ())

-- | A one-instruction program that subtracts the operand; 'evalArith'
-- applies it to the result of the instructions sequenced after it.
subA :: Int -> ArithM ()
subA x = liftF (Sub x ())

-- | A one-instruction program that multiplies by the operand; 'evalArith'
-- applies it to the result of the instructions sequenced after it.
mulA :: Int -> ArithM ()
mulA x = liftF (Mul x ())

-- | A one-instruction program that divides (with 'div') by the operand;
-- 'evalArith' applies it to the result of the instructions sequenced after
-- it. Evaluating a program that contains @divA 0@ throws a divide-by-zero
-- exception.
divA :: Int -> ArithM ()
divA x = liftF (Div x ())

-- | A sample program written with '>>'.
--
-- 'evalArith' applies instructions last-first, so this computes
-- @div (((n + 10) * 2) - 10) 2@:
--
-- >>> evalArith (example 0)
-- 5
example :: Int -> ArithM Int
example n =
  divA 2
    >> subA 10
    >> mulA 2
    >> addA 10
    >> pure n

-- | The same program as 'example', written in @do@ notation.
--
-- 'evalArith' applies instructions last-first, so this computes
-- @div (((n + 10) * 2) - 10) 2@:
--
-- >>> evalArith (exampleDo 1)
-- 6
exampleDo :: Int -> ArithM Int
exampleDo n = do
  divA 2
  subA 10
  mulA 2
  addA 10
  pure n