{-# LANGUAGE GADTs #-}

{-|
Module      : Defunc
Description : A simple example of defunctionalisation.
Copyright   : © Frank Jung, 2024
License     : GPL-3.0-only

== Defunctionalisation

A small example showing how to defunctionalise lambda expressions.

== References

- [Compiling higher order functions with GADTs](https://injuly.in/blog/defunct/)
- [Lightweight higher-kinded polymorphism](https://www.cl.cam.ac.uk/~jdy22/papers/lightweight-higher-kinded-polymorphism.pdf)

-}
module Defunc
  (
  -- * Types
  Arrow (..)
  -- * Functions
  , fold
  , sum
  , add
  , apply
  , fold'
  , sum'
  , add'
  ) where

import           Prelude hiding (sum)

-- |
-- == Motivation
-- The motivating examples are the following functions.

-- | Fold a list from the right using explicit recursion.
--
-- @fold f z [x1, x2, ..., xn] == f x1 (f x2 (... (f xn z)))@, i.e. this is
-- 'foldr' specialised to lists. The combining function @f@ is a
-- higher-order argument; 'fold'' replaces it with a first-order 'Arrow'.
--
-- >>> fold (+) 0 [1, 2, 3]
-- 6
fold :: (a -> b -> b) -> b -> [a] -> b
fold _ z []       = z
fold f z (x : xs) = f x (fold f z xs)

-- | Sum a list of 'Int's using 'fold', passing the higher-order
-- function @(+)@ as the combining function.
--
-- >>> sum [1, 2, 3]
-- 6
sum :: [Int] -> Int
sum = fold (+) 0

-- | Add @n@ to each element of a list using 'fold'.
--
-- The lambda passed to 'fold' captures @n@ from the enclosing scope, so a
-- first-order replacement for it has to carry @n@ as data.
--
-- >>> add 1 [1, 2, 3]
-- [2,3,4]
add :: Int -> [Int] -> [Int]
add n = fold (\x xs -> x + n : xs) []

-- |
-- == Defunctionalisation
-- Defunctionalisation of lambda expressions from the motivating example.

-- | First-order stand-ins for the lambdas of the motivating example.
--
-- A value of type @Arrow i o@ represents a function from @i@ to @o@; the
-- GADT index records the input and output types. Arguments are taken as a
-- single tuple rather than curried. Use 'apply' to interpret a value.
data Arrow i o where
  -- | Represents uncurried addition of two 'Int's.
  FPlus     :: Arrow (Int, Int) Int
  -- | @FPlusCons n@ represents adding @n@ to a head element and consing it
  -- onto a list. The captured @n@ is stored in the constructor.
  FPlusCons :: Int -> Arrow (Int, [Int]) [Int]

-- | Interpret an 'Arrow' as the function it stands for and apply it to the
-- input.
--
-- Matching on a GADT constructor refines @p@ and @r@ to concrete types, so
-- each equation can use the 'Int' and list operations directly.
apply :: Arrow p r -> p -> r
apply FPlus         (x, y)  = x + y
apply (FPlusCons n) (x, xs) = (n + x) : xs

-- | Right fold over a list where the combining step is a first-order
-- 'Arrow' value, interpreted by 'apply', rather than a Haskell function.
--
-- The arrow receives its two arguments as a pair, so
-- @fold' f z (x : xs) == apply f (x, fold' f z xs)@.
fold' :: Arrow (a, b) b -> b -> [a] -> b
fold' _ z []       = z
fold' f z (x : xs) = apply f (x, fold' f z xs)

-- | Sum a list of 'Int's using 'fold'' with the 'FPlus' arrow.
--
-- >>> sum' [1, 2, 3]
-- 6
sum' :: [Int] -> Int
sum' = fold' FPlus 0

-- | Add @n@ to each element of a list using 'fold''.
--
-- The captured @n@ is stored in the 'FPlusCons' constructor instead of
-- being closed over by a lambda.
--
-- >>> add' 1 [1, 2, 3]
-- [2,3,4]
add' :: Int -> [Int] -> [Int]
add' n = fold' (FPlusCons n) []