{-# LANGUAGE GADTs #-}

{-|
Module      : Defunc
Description : A simple example of defunctionalisation.
Copyright   : © Frank Jung, 2024, 2026
License     : GPL-3.0-only

== Defunctionalisation

Defunctionalisation is a technique that replaces higher-order function
values with a first-order representation. Instead of passing lambdas or
closures directly, you encode the possible operations as constructors of an
algebraic data type and interpret them with an application function.

The code below shows this transformation for a small list fold example.
The higher-order @fold@ accepts a combining function, while @fold'@ accepts
an @Arrow@ value whose behaviour is defined by @apply@.

=== Why Use It?

* __Explicit representation:__ The set of supported operations becomes a
closed data type that can be inspected and manipulated.

* __Captured values become data:__ The @n@ in @FPlusCons n@ is stored as a
constructor field rather than hidden inside a closure.

* __Type safety with GADTs:__ The @Arrow@ GADT records the input and output
types for each encoded operation.

* __A first-order core:__ Higher-order control flow is replaced by a small
interpreter, which can simplify analysis or later compilation passes.

=== Key Principle

The main trade-off is that first-class functions are replaced with a closed
set of function tags plus an interpreter. In this example, @(+)@ and
@\\x xs -> x + n : xs@ become the constructors @FPlus@ and @FPlusCons@,
and @apply@ performs the work that the original function values did.

== References

- [Compiling higher order functions with GADTs](https://injuly.in/blog/defunct/)
- [Lightweight higher-kinded polymorphism](https://www.cl.cam.ac.uk/~jdy22/papers/lightweight-higher-kinded-polymorphism.pdf)

-}
module Defunc
  (
  -- * Types
  Arrow (..)
  -- * Functions
  , fold
  , sum
  , add
  , apply
  , fold'
  , sum'
  , add'
  ) where

import           Prelude hiding (sum)

-- |
-- == Motivation
-- The motivating example consists of the following functions.

-- | Fold a list from the right using explicit recursion.
--
-- The combining function receives each element together with the result of
-- folding the remainder of the list; the seed @z@ is returned for the empty
-- list. This is the higher-order version that 'fold'' defunctionalises.
--
-- For example, @fold (+) 0 [1, 2, 3] == 6@.
fold :: (a -> b -> b) -> b -> [a] -> b
fold _ z []       = z
fold f z (x : xs) = f x (fold f z xs)

-- | Sum a list of 'Int's with 'fold', using @(+)@ as the combining function
-- and @0@ as the seed.
--
-- For example, @sum [1, 2, 3] == 6@.
sum :: [Int] -> Int
sum = fold (+) 0

-- | Add @n@ to every element of a list using 'fold'.
--
-- The combining function @\\x xs -> x + n : xs@ closes over @n@; this is the
-- closure that the 'FPlusCons' constructor later replaces with plain data.
--
-- For example, @add 1 [1, 2, 3] == [2, 3, 4]@.
add :: Int -> [Int] -> [Int]
add n = fold (\x xs -> x + n : xs) []

-- |
-- == Defunctionalisation
-- Defunctionalisation of lambda expressions from the motivating example.

-- | First-order encoding of the two function values used in the motivating
-- example.
--
-- The first type index is the input an encoded function consumes and the
-- second is the result it produces. Each constructor fixes both indices, so
-- pattern matching on a constructor refines them to concrete types.
data Arrow input output where
  -- | Stands in for @(+)@, taking both summands as a pair.
  FPlus
    :: Arrow (Int, Int) Int
  -- | Stands in for @\\x xs -> x + n : xs@; the captured @n@ is kept as a
  -- constructor field rather than a closure variable.
  FPlusCons
    :: Int
    -> Arrow (Int, [Int]) [Int]

-- | Interpret an encoded 'Arrow' on its input, performing the work the
-- original function value would have done.
--
-- Matching on a constructor refines @p@ and @r@, so each equation works with
-- the concrete input and output types that constructor declares.
apply :: Arrow p r -> p -> r
apply FPlus         (x, y)  = x + y
apply (FPlusCons n) (x, xs) = (n + x) : xs

-- | Fold a list from the right, like 'fold', but with the combining function
-- supplied as a first-order 'Arrow' that 'apply' interprets.
--
-- The arrow consumes a pair of the current element and the folded remainder,
-- mirroring the two curried arguments of the function passed to 'fold'. The
-- seed @z@ is returned for the empty list.
fold' :: Arrow (a, b) b -> b -> [a] -> b
fold' _ z []       = z
fold' f z (x : xs) = apply f (x, fold' f z xs)

-- | Sum a list of 'Int's with 'fold'', encoding @(+)@ as 'FPlus' and using
-- @0@ as the seed.
--
-- For example, @sum' [1, 2, 3] == 6@.
sum' :: [Int] -> Int
sum' = fold' FPlus 0

-- | Add @n@ to every element of a list with 'fold''.
--
-- The captured @n@ is stored in the 'FPlusCons' constructor instead of a
-- closure.
--
-- For example, @add' 1 [1, 2, 3] == [2, 3, 4]@.
add' :: Int -> [Int] -> [Int]
add' n = fold' (FPlusCons n) []