Posts

Playing with Probability in Scala 3

Given a fair dice, how many rolls are needed, on average, to observe all the sides? From this simple question, this article explores basic probability theory and statistics in Scala 3, the upcoming major release of the Scala programming language. It shows the new features, their use and benefits.


Les GADTs Par l'Exemple (GADTs by Example)

Discover generalized ADTs while having fun.


GADTs By Use Cases

Discover Generalized ADT through lots of practical examples.


Proving Primality with GADTs

Let's do a bit of logic in Scala's type system.


Recursion Schemes: the high-school introduction

Presentation of recursion schemes from simple examples without the complex vocabulary in the way.


Demystifying GADTs

Introduction to catamorphisms on Algebraic Data Types


Let's meet the charming fold family

Introduction to catamorphisms on Algebraic Data Types


How to make a game in the browser thanks to ScalaJS

ScalaIO.2018 Workshop


JSON to XML: the probably a tiny bit over engineered way

Conversion from/to JSON and XML using advanced concepts


F-Algebra talk at ScalaIO 2017: Modéliser astucieusement vos données (Model Your Data Cleverly)

ScalaIO.2017 Talk



Playing with Probability in Scala 3

Here is a simple experiment: take a fair dice (or a coin) and roll it until every side of the dice has been observed at least once. Write down the number of rolls and repeat the operation several times. On average, how many rolls does it take to observe every side of the dice at least once? This is precisely the question we are about to answer together.

Solving Maths puzzles is fun. Solving them using the shiniest features of Scala 3 is even more fun! If you do not have Scala 3 installed yet:

You should now be able to run the Scala 3 REPL via the command:

[shell prompt]$ dotr -new-syntax -Yexplicit-nulls -Yerased-terms -indent -version
Starting dotty REPL...
Dotty compiler version 0.25.0-RC2 -- Copyright 2002-2020, LAMP/EPFL
scala>

An Important Note: Please refrain from starting your favourite IDE or creating a new project. It will only make experimenting harder and more painful. All you need is the Scala 3 REPL, a basic text editor and knowing how to copy-paste on your system.

Understanding the Problem

Let us start by modelling a dice. The sides of a dice will be numbered starting from 1. We consider a coin as a 2-sided dice whose sides are 1 and 2. The sides of a usual 6-sided dice are 1, 2, 3, 4, 5 and 6.

final case class Dice(sides: Int) {
  def roll(): Int =
    scala.util.Random.nextInt(sides) + 1
}

val d2   = Dice(2)
val d6   = Dice(6)
val d10  = Dice(10)
val d20  = Dice(20)
val d100 = Dice(100)

d2 models a coin, d6 models a usual 6-sided dice, etc. The method roll, as its name suggests, simulates rolling the dice. On each invocation it gives a random side (its number). The first question to answer is: is such a dice fair? Remember that a fair dice is one for which every side is equally likely to be observed. For a coin, it means getting 1 is as likely as getting 2. To check empirically that a dice is fair, or at least not loaded, we will roll it many times and count how often we observe each of its sides:

def (dice: Dice) frequencies(rolls: Long): Map[Int, Double] =
  // Stores how many times we observed each side
  // (Long counters, since a side may be observed more than Int.MaxValue times)
  val arr = Array.fill(dice.sides)(0L)
  for i <- 1L to rolls do
    arr(dice.roll() - 1) += 1

  // Transforms counters into ratios
  val probability: IndexedSeq[(Int, Double)] =
    for i <- 1 to dice.sides yield
      i -> arr(i-1).toDouble / rolls
  probability.toMap

scala> d2.frequencies(1000000000L)
val res0: Map[Int, Double] = Map(1 -> 0.499985517, 2 -> 0.500014483)

scala> d6.frequencies(1000000000L)
val res1: Map[Int, Double] = HashMap(
  1 -> 0.166669596,
  2 -> 0.166660131,
  3 -> 0.166664591,
  4 -> 0.166654524,
  5 -> 0.166665811,
  6 -> 0.166685347)
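To make the fairness check reproducible, here is a minimal sketch that uses a seeded `scala.util.Random` instead of the global generator. The `FairnessCheck` object, its `maxDeviation` method and the seeds are hypothetical, introduced only for illustration; the sketch reports the largest gap between an observed frequency and the ideal value 1/sides:

```scala
import scala.util.Random

// Hypothetical helper for illustration; not part of the original article.
object FairnessCheck {
  // Rolls a `sides`-sided dice `rolls` times with a seeded generator and
  // returns the largest absolute gap between an observed frequency and 1/sides.
  def maxDeviation(sides: Int, rolls: Long, seed: Long): Double = {
    val rng    = new Random(seed)
    val counts = Array.fill(sides)(0L) // Long counters, like in frequencies
    var i = 0L
    while (i < rolls) {
      counts(rng.nextInt(sides)) += 1 // nextInt(sides) is in 0 until sides
      i += 1
    }
    val ideal = 1.0 / sides
    counts.map(c => math.abs(c.toDouble / rolls - ideal)).max
  }
}
```

With more rolls, the returned gap typically shrinks, roughly like 1/√rolls for a fair dice.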

The frequencies extension method can be called like any method of Dice. As you can see, the frequencies are very close to each other. In addition, the more rolls we perform, the closer they get. We can conclude that these dice are fair enough. We are finally ready for our big adventure: finding the much-desired average! We call an experiment the action of rolling the dice until every side has been observed at least once, and the length of the experiment its number of rolls. The method rollUntilAllSeen simulates an experiment and returns its length.

def (dice: Dice) rollUntilAllSeen(): Int =
  var rolls     = 0
  val seen      = Array.fill(dice.sides)(false)
  var remaining = dice.sides

  while remaining > 0 do
    val outcome = dice.roll()
    rolls += 1
    if !seen(outcome-1) then
      seen(outcome-1) = true
      remaining -= 1
  rolls

scala> d6.rollUntilAllSeen()
val res2: Int = 12

scala> d6.rollUntilAllSeen()
val res3: Int = 15

scala> d6.rollUntilAllSeen()
val res4: Int = 8

scala> d6.rollUntilAllSeen()
val res5: Int = 9

Based on the four experiments above, we get the impression that the average should be close to 11, but four experiments are not enough to get an accurate estimate of the real average. Fortunately, the more experiments we run, the closer we get to it. We need to compute the average over a large number of experiments. We will actually be a bit smarter: instead of limiting ourselves to computing the average, we will count, for every observed length, the number of experiments of that length. It will tell us how often each length is observed, i.e. its frequency.

final case class Histogram[A](values: Map[A, BigInt]) {

  def average(using Accumulable[A], Ratio[A]): A =
    var acc   : A      = summon[Accumulable[A]].zero
    var total : BigInt = BigInt(0)
    for (a,count) <- values do
      acc += a*count
      total += count
    acc / total

  def frequencies: Map[A, BigDecimal] =
    val total = values.values.sum
    values.map { case (a, count) => a -> BigDecimal(count) / BigDecimal(total) }

  def map[B](f: A => B): Histogram[B] =
    val histoB = scala.collection.mutable.Map.empty[B, BigInt]
    for (a, count) <- values do
      val b = f(a)
      histoB.update(b, histoB.getOrElse(b, BigInt(0)) + count)
    new Histogram[B](histoB.toMap)
}

object Histogram {
  def apply[A](iterations: Long)(value : => A): Histogram[A] =
    val histo = scala.collection.mutable.Map.empty[A, BigInt]
    for i <- 1L to iterations do
      val a = value
      histo.update(a, histo.getOrElse(a, BigInt(0)) + 1)
    new Histogram(histo.toMap)
}

The class Histogram[A] is essentially a key-value store where the value is the number of times the key has been observed, also known as its multiplicity. You may also wonder how Scala can accept adding two values of type A and multiplying/dividing a value of type A by a BigInt in average. It works thanks to the magic of Type Classes in Scala 3. Accumulable and Ratio are two type classes defined by:

trait Accumulable[A] {
  def zero: A
  def (x:A)+(y:A): A
  def (x:A)*(m: BigInt): A
}

trait Ratio[A] {
  def (x:A) / (m: BigInt): A
}

Note that, unlike Scala 2, no weird implicit conversion is required to support infix syntax for +, * and /. These methods are just defined as extension methods.

scala> val h6 = Histogram(100000000L)(d6.rollUntilAllSeen())
val h6: Histogram[Int] = Histogram(HashMap(73 -> 202, 69 -> 385, 88 -> 17,
  10 -> 8298014, 56 -> 4403, 42 -> 56557, 24 -> 1462064, 37 -> 140975, ...))

scala> h6.average
1 |h6.average
  |          ^
  |no implicit argument of type Accumulable[Int] was found for parameter x$1 of method average in class Histogram

If your first reaction is to implement an instance of Accumulable for Int, ask yourself how you could be confident that the computed values are correct when adding two positive numbers can result in a negative one:

scala> val x = 1990000000
val x: Int = 1990000000

scala> x + x
val res0: Int = -314967296

I am well aware that in most use cases Int is perfectly fine, because they never involve numbers big enough to reach this limit. After all, 10 digits ought to be enough for anybody, right? In the next sections, you will see that we will exceed this limit very often! Writing an instance of Accumulable for Int is a catastrophic idea. Instead, we will write instances for BigInt and BigDecimal.
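The JVM can also turn this silent wrap-around into an error: `java.lang.Math.addExact` throws an `ArithmeticException` on overflow, while `BigInt` simply grows. A small sketch contrasting the three behaviours on the value used above (the `OverflowDemo` object is hypothetical, for illustration only):

```scala
// Hypothetical helper for illustration; not part of the original article.
object OverflowDemo {
  val x: Int = 1990000000

  // Plain Int addition silently wraps around.
  val wrapped: Int = x + x

  // BigInt addition is exact, whatever the size of the operands.
  val exact: BigInt = BigInt(x) + BigInt(x)

  // Math.addExact detects the overflow and throws instead of wrapping.
  def checked: Either[String, Int] =
    try Right(Math.addExact(x, x))
    catch { case e: ArithmeticException => Left(e.getMessage) }
}
```

Failing loudly is better than a wrong result, but only `BigInt` (or `BigDecimal`) actually gives the right answer, which is why the instances below use them.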

given Accumulable[BigInt] {
  def zero: BigInt = BigInt(0)
  def (x:BigInt)+(y:BigInt): BigInt = x+y
  def (x:BigInt)*(m:BigInt): BigInt = x*m
}

given Accumulable[BigDecimal] with Ratio[BigDecimal] {
  def zero: BigDecimal = BigDecimal(0)
  def (x:BigDecimal)+(y:BigDecimal): BigDecimal = x+y
  def (x:BigDecimal)*(m:BigInt): BigDecimal = x*BigDecimal(m)
  def (x:BigDecimal)/(m:BigInt): BigDecimal = x/BigDecimal(m)
}

Now we can get our much-awaited average:

scala> h6.map(BigDecimal(_)).average
val res0: BigDecimal = 14.69830127

scala> Histogram(10000000L)(BigDecimal(d6.rollUntilAllSeen())).average
val res1: BigDecimal = 14.6949955

As you can see, the average is never far from 14.69. Knowing the average is nice but it does not tell us much about how the length is distributed among the experiments. This is precisely the reason why we kept counters! To visualize this data, we can export the histogram as a CSV file.

def (histogram: Histogram[Int]) toCSV(fileName: String): Unit =
  import java.io._
  val pw = new PrintWriter(new File(fileName))
  pw.println("value,count")
  for length <- 0 to histogram.values.keySet.max do
    pw.println(s"$length,${histogram.values.getOrElse(length, BigInt(0))}")
  pw.close()

scala> h6.toCSV("d6.csv")

Opening the file d6.csv with LibreOffice and plotting the data as an XY chart, using the value column as X and count as Y, gives this chart:

[Figure: d6 distribution, length/count]

As you can see, after length 15, there is a huge decrease in the number of experiments, and after length 50, the number of experiments is negligible. The situation is similar for other dice. For example, here is the curve for d100:

[Figure: d100 distribution, length/count]

By running enough experiments, we can get a pretty close estimate of the average. But an experiment is random by nature: every measurement we perform is very likely to give a (close but still) different estimate of the average. We need a more reliable way to approximate the average.

Modelling the Problem

To get a more reliable approximation of the average, or its exact value, we cannot rely on random experiments. We need to use maths! Remember that an experiment is a sequence of dice rolls such that, as soon as every side of the dice has been observed at least once, the sequence is over. Given a dice, we will call a sequence of sides valid when it follows these rules.

Using a 3-sided dice:

  • The sequence 2→2→1→2 is invalid because the side 3 has not been observed.
  • The sequence 2→2→1→2→3→3 is invalid because the sequence needs to stop as soon as every side has been observed so the last roll is not required.
  • The sequence 2→2→1→2→3 is valid: every side has been observed and it was not possible to stop earlier.
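These rules can be checked mechanically. A minimal sketch (the `Validity` object and its `isValid` method are hypothetical, for illustration only): a sequence is valid when it contains every side, and its last roll is the first occurrence of a side, so the experiment could not have stopped earlier.

```scala
// Hypothetical helper for illustration; not part of the original article.
object Validity {
  // Valid for a `sides`-sided dice: the sequence uses exactly the sides 1 to sides,
  // and its last side does not appear before, so no shorter prefix was complete.
  def isValid(sides: Int, seq: List[Int]): Boolean =
    seq.nonEmpty &&
      seq.toSet == (1 to sides).toSet &&
      !seq.init.contains(seq.last)
}
```

The three examples above, plus the same valid sequence checked against a 4-sided dice, behave as expected.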

Note that the validity depends on the dice used! The sequence 2→2→1→2→3 is valid for a 3-sided dice but invalid for a 4-sided dice. To compute the average, we will: (1) enumerate all valid sequences (up to a certain length), then (2) sum their lengths and finally (3) divide by the number of sequences.

def enumerate(sides: Int, depth: Int): LazyList[List[Int]] =
  def aux(revSeq: List[Int]): LazyList[List[Int]] =
    if revSeq.length > depth
    then LazyList.empty
    else if revSeq.toSet.size == sides
          then LazyList(revSeq.reverse)
          else LazyList.range(1, sides+1).flatMap { next => aux(next :: revSeq) }

  aux(List.empty)

def average(sides: Int, depth: Int): Double =
  val validSequences = enumerate(sides, depth)
  validSequences.map(_.length.toDouble).sum / validSequences.size

For a 3-sided dice, the list of all valid sequences up to length 5 is:

scala> enumerate(3,5).map(_.mkString("→")).toList
val res26: List[String] = List(
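  1→1→1→2→3, 1→1→1→3→2, 1→1→2→1→3, 1→1→2→2→3, 1→1→2→3,   1→1→3→1→2, 1→1→3→2,   1→1→3→3→2,
  1→2→1→1→3, 1→2→1→2→3, 1→2→1→3,   1→2→2→1→3, 1→2→2→2→3, 1→2→2→3,   1→2→3,
  1→3→1→1→2, 1→3→1→2,   1→3→1→3→2, 1→3→2,     1→3→3→1→2, 1→3→3→2,   1→3→3→3→2,
  2→1→1→1→3, 2→1→1→2→3, 2→1→1→3,   2→1→2→1→3, 2→1→2→2→3, 2→1→2→3,   2→1→3,
  2→2→1→1→3, 2→2→1→2→3, 2→2→1→3,   2→2→2→1→3, 2→2→2→3→1, 2→2→3→1,   2→2→3→2→1, 2→2→3→3→1,
  2→3→1,     2→3→2→1,   2→3→2→2→1, 2→3→2→3→1, 2→3→3→1,   2→3→3→2→1, 2→3→3→3→1,
  3→1→1→1→2, 3→1→1→2,   3→1→1→3→2, 3→1→2,     3→1→3→1→2, 3→1→3→2,   3→1→3→3→2,
  3→2→1,     3→2→2→1,   3→2→2→2→1, 3→2→2→3→1, 3→2→3→1,   3→2→3→2→1, 3→2→3→3→1,
  3→3→1→1→2, 3→3→1→2,   3→3→1→3→2, 3→3→2→1,   3→3→2→2→1, 3→3→2→3→1, 3→3→3→1→2, 3→3→3→2→1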
  )
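The listing contains 66 sequences: 6 of length 3, 18 of length 4 and 42 of length 5. For a 3-sided dice they can also be counted directly: the last roll is the side seen last (3 choices), and the previous l−1 rolls must use both of the other two sides, which leaves 2^(l−1) − 2 possibilities, a count that roughly doubles with every extra roll. A sketch comparing a brute-force count with that formula (the `CountValid` object is hypothetical, for illustration only):

```scala
// Hypothetical helper for illustration; not part of the original article.
object CountValid {
  // All sequences of `length` rolls of a `sides`-sided dice (sides^length of them).
  def allSequences(sides: Int, length: Int): List[List[Int]] =
    if (length == 0) List(Nil)
    else for {
      prefix <- allSequences(sides, length - 1)
      side   <- (1 to sides).toList
    } yield prefix :+ side

  // Every side seen, and the last roll is the first occurrence of its side.
  def isValid(sides: Int, seq: List[Int]): Boolean =
    seq.nonEmpty && seq.toSet.size == sides && !seq.init.contains(seq.last)

  def bruteForce(sides: Int, length: Int): Int =
    allSequences(sides, length).count(isValid(sides, _))

  // Closed form for a 3-sided dice, valid for length >= 2: 3 * (2^(length-1) - 2).
  def formulaD3(length: Int): BigInt =
    BigInt(3) * (BigInt(2).pow(length - 1) - BigInt(2))
}
```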

That’s awesome! We just have to average all the lengths:

scala> average(3,5)
val res27: Double = 4.545454545454546

scala> average(3,10)
val res32: Double = 9.071713147410359

scala> average(3,14)
val res36: Double = 13.00953778429934

scala> average(3,16)
val res39: Double = 15.003205911089399

Apparently, the average computed over sequences up to length 16 has not converged yet. Unfortunately, our implementation is too slow for long sequences: the number of valid sequences grows exponentially with the length. We need a much faster algorithm.

def aggregate[A:Accumulable](sides: Int, depth: Int)(f: Int => A): A =
  var current : Array[BigInt] = Array.fill(sides+1)(BigInt(0))
  var next    : Array[BigInt] = Array.fill(sides+1)(BigInt(0))
  var agg     : A             = summon[Accumulable[A]].zero
  var length  : Int           = 0

  current(0) = BigInt(1) // The empty sequence is the unique sequence where 0 sides have been seen

  while length <= depth do
    agg += f(length) * current(sides)

    for seen <- 0 to sides - 1 do
      next(seen)     += current(seen) * seen
      next(seen + 1) += current(seen) * (sides - seen)

    length += 1
    current = next
    next    = Array.fill(sides+1)(BigInt(0))

  agg

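This is a generic aggregation function that, given a sides-sided dice, assigns to every valid sequence a value that depends only on its length (via f) and aggregates these values over all valid sequences up to a certain length called the depth. Instead of enumerating sequences, it only counts them: current(k) is the number of sequences of the current length that have observed exactly k distinct sides, and each roll either repeats one of the k observed sides (k ways) or reveals one of the sides - k unobserved ones. We can use it to compute the average for sequences up to length 100000.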
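To convince ourselves that this recurrence counts the same sequences as enumerate, here is a stripped-down, BigInt-only sketch of it (the `CountByDP` object is hypothetical, for illustration only). With sides = 3 and depth = 5 it should recover the 66 valid sequences listed earlier and their total length of 300, and 300/66 ≈ 4.545 is indeed the value of average(3,5):

```scala
// Hypothetical helper for illustration; not part of the original article.
object CountByDP {
  // Returns (number of valid sequences, sum of their lengths) up to `depth`.
  def countAndSum(sides: Int, depth: Int): (BigInt, BigInt) = {
    // current(k) = number of sequences of the current length having seen exactly
    // k distinct sides; current(sides) counts the sequences that just became valid.
    var current = Array.fill(sides + 1)(BigInt(0))
    current(0) = BigInt(1) // the empty sequence
    var count = BigInt(0)
    var sum   = BigInt(0)
    for (length <- 0 to depth) {
      count += current(sides)
      sum   += current(sides) * BigInt(length)
      val next = Array.fill(sides + 1)(BigInt(0))
      for (seen <- 0 until sides) {
        next(seen)     += current(seen) * BigInt(seen)         // repeat a seen side
        next(seen + 1) += current(seen) * BigInt(sides - seen) // reveal a new side
      }
      current = next // valid sequences are not extended any further
    }
    (count, sum)
  }
}
```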

scala> val depth = 100000
val depth: Int = 100000

scala> val sides = 3
val sides: Int = 3

scala> val sumOfLengths = aggregate(sides,depth)(length => BigInt(length))
val sumOfLengths: BigInt = <a 30109-digit positive integer>

scala> val numberOfSeqs = aggregate(sides,depth)(length => BigInt(1))
val numberOfSeqs: BigInt = <a 30104-digit positive integer>

scala> val averageUpTo100000 = (BigDecimal(sumOfLengths)/numberOfSeqs).toDouble
val averageUpTo100000: Double = 99999.0

The average does not seem to converge. Have a look at the previous estimates of the average for depths 5, 10, 14 and 16: the average seems very close to depth - 1. It seems to indicate that, on average, you need to roll a 3-sided dice an infinite number of times to observe every side at least once. It would mean that, regardless of the number of rolls you perform, it is almost certain that you will never see at least one side. Let’s confirm that using the methods of the previous section:

scala> val d3 = Dice(3)
val d3: Dice = Dice(3)

scala> val h3 = Histogram(100000000L)(d3.rollUntilAllSeen())
val h3: Histogram[Int] = Histogram(HashMap(...))

scala> h3.map(BigDecimal(_)).average
val res6: BigDecimal = 5.5003517

The experiment shows that, on average, 5.5 rolls are enough to see every side of a 3-sided dice. The only possible conclusion is that our modeling is very wrong. The problem is that we consider every sequence to be equally likely. But the sequence 1→2→3 is much likelier to happen than 1→2→2→1→2→1→2→1→1→2→1→3. We can plot the h3 histogram to check that the longer a sequence is, the less likely it is to happen:

[Figure: d3 distribution, length/count]

Our real mistake is that our mathematical model does not model the real problem. This is a very important rule in modeling: models must closely match the things they are supposed to model.

Understanding the Experiment

To get a mathematical model that closely matches the experiment, we need a deeper understanding of the problem. When we perform 10000000 experiments, we get as many valid sequences of sides. But assuming that these sequences are all distinct is wrong: an experiment is a random process, so you may get the same sequence several times. We need to take into account how likely each sequence is.

Given an \(n\)-sided fair dice, by definition of fairness, every time we roll the dice, each side has exactly a \(\frac{1}{n}\) chance to be observed. Since rolls are independent of each other, every sequence of \(m\) rolls has a \((\frac{1}{n})^m\) chance to be observed.
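As an illustration, here is a sketch computing, as an exact fraction, the probability of observing a given sequence of m rolls of a fair dice. Representing fractions as a (numerator, denominator) pair of BigInt is an assumption of this sketch, and the `SequenceProbability` object is hypothetical. It quantifies the earlier example: with a 3-sided dice, 1→2→3 (3 rolls) is 3^9 = 19683 times likelier than the 12-roll sequence 1→2→2→1→2→1→2→1→1→2→1→3.

```scala
// Hypothetical helper for illustration; not part of the original article.
object SequenceProbability {
  // Exact probability, as (numerator, denominator), of observing exactly `seq`
  // in seq.length rolls of a fair `sides`-sided dice: each roll contributes an
  // independent factor 1/sides, so only the length of the sequence matters.
  def probability(sides: Int, seq: Seq[Int]): (BigInt, BigInt) =
    if (seq.forall(s => 1 <= s && s <= sides)) (BigInt(1), BigInt(sides).pow(seq.length))
    else (BigInt(0), BigInt(1)) // a side that does not exist is never observed

  // How many times likelier `a` is than `b`; assumes both are possible
  // and that `b` is at least as long as `a`, so the division is exact.
  def ratio(sides: Int, a: Seq[Int], b: Seq[Int]): BigInt =
    probability(sides, b)._2 / probability(sides, a)._2
}
```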

Do not jump to the conclusion that the probability of a valid sequence of length \(m\) in our problem is \((\frac{1}{n})^m\) yet! Suppose we change the problem slightly by requiring that every side is observed exactly once (ignoring sequences where a side is observed several times). Then, for a coin, there are only 2 valid sequences, 1→2 and 2→1, each equally likely, so they both have probability \(\frac{1}{2}\), not \(\frac{1}{4}\). The right way to proceed is to find a probability space that correctly models the problem.

Remember that the dice used has been fixed as an \(n\)-sided fair dice. The first step in defining a probability space is defining the outcomes. Outcomes are the results of statistical experiments. In our problem, outcomes are the valid sequences of dice rolls. Then we need to define the set of events. Events are the things whose likelihood we want to measure! For example: what is the probability that a valid sequence starts with 1, or what is the probability that a valid sequence is a palindrome (i.e. the reverse of the sequence is the sequence itself), etc. It feels natural, in our situation, to consider any set of valid sequences as an event. Last but not least, we need the probability function. Its purpose is to give, for any event, the likelihood of this event. Informally, a probability function must satisfy 3 properties:

  1. the likelihood of any event must be positive or null, but never negative!
  2. the likelihood of a union of pairwise disjoint events is the sum of the likelihoods of these events.
  3. the likelihood of the set of all valid sequences must be 1.

This is where things get complicated. We can decide to give any valid sequence of size \(m\) the probability \((\frac{1}{n})^m\), but we need to prove that this function satisfies all the conditions above to be a probability function. In addition, the set of valid sequences is not that trivial to work with (at least for me!). Fortunately, working in this probability space is not mandatory. We can work in a more comfortable probability space as long as we are able to transpose results into this one.

Remember that the dice being used is an \(n\)-sided fair dice. Let us start with some definitions:

  • Let \(\mathcal{C}_n=\{1,\dots,n\}\) be the set of the dice’s sides.

  • The set of countably infinite sequences of sides is written \(\Omega\).

  • The set of finite sequences of sides is written \(\mathcal{F}\).

  • For any finite sequence of sides \(f \in \mathcal{F}\), its length is written \(|f|\).

  • For any sequence \(s \in \mathcal{F} \cup \Omega\) of sides (finite or infinite), let \(\diamond s\) be the set of sides observed in \(s\) and \(\sharp s\) be the number of distinct sides observed in \(s\), i.e. \(\sharp s = |\diamond s|\).

  • For any sequence \(s \in \mathcal{F} \cup \Omega\) of sides (finite or infinite), and any \(i \in \{1,\dots,|s|\}\), let \(s_{[i]}\) be the side observed at the \(i\)-th roll of \(s\), i.e. \(s=(s_{[1]},\dots,s_{[|s|]})\).

  • For any \(f \in \mathcal{F}\) and \(f' \in \mathcal{F}\cup\Omega\), where \(f = (s_1,\dots,s_i)\) and \(f' = (s'_1,\dots)\), we write \(f \cdot f' \in \mathcal{F}\cup\Omega\) the concatenation of \(f\) and \(f'\), i.e. the sequence \((s_1,\dots,s_i,s'_1,\dots)\). Furthermore, for any set of prefixes \(F \subset \mathcal{F}\), and any set of (finite or infinite) sequences \(S\subset \mathcal{F}\cup\Omega\), we write \(F\cdot S = \{f\cdot s\mid f\in F,\ s\in S\}\) the set of sequences made of concatenations of \(F\) and \(S\).

For the new probability space, we can take as outcomes \(\Omega\), the set of all infinite (but countable) sequences of sides. Given a finite sequence of sides \(f \in \mathcal{F}\) (possibly empty), the set of all outcomes (infinite sequences) that start with \(f\) is called a prefix event and written \(\mathcal{E}(f)\). The finite sequence \(f \in \mathcal{F}\) is called a prefix. Note that the set of all outcomes, \(\Omega\), is an event because it is the prefix event of the empty sequence \(\epsilon\). The set of all prefix events is written \(\mathcal{E}\). We will take as events the σ-field \(\sigma(\mathcal{E})\) generated by prefix events, i.e. the smallest σ-field containing prefix events, which is closed under complement, countable unions and countable intersections. It means that any countable union or intersection of events is an event and the complement of any event is an event. Let \(F\subset \mathcal{F}\) be a finite or countable set of prefixes; we write \(\mathcal{E}(F)\) the event \(\bigcup_{f\in F} \mathcal{E}(f)\).

The class of sets \(\mathcal{R} = \mathcal{E} \cup \{\emptyset\} \) is a semiring of sets. It comes from two facts. Let \(f_1, f_2\in \mathcal{F}\) be two prefixes. Either \(\mathcal{E}(f_1)\) and \(\mathcal{E}(f_2)\) are disjoint, or one is contained in the other. It proves that \(\mathcal{R}\) is stable under finite intersection. If \(\mathcal{E}(f_2) \subset \mathcal{E}(f_1)\) then there exists \(f_3\in \mathcal{F}\) such that \(f_2 = f_1 \cdot f_3\) and \(\mathcal{E}(f_1) = \bigcup_{f_4\in \mathcal{F}, |f_4|=|f_3|} \mathcal{E}(f_1 \cdot f_4)\), a union of disjoint prefix events, one of which is \(\mathcal{E}(f_2)\). It proves that \(\mathcal{E}(f_1) \setminus \mathcal{E}(f_2)\) can be written as a finite union of disjoint elements of \(\mathcal{R}\).

Instead of defining the probability function \(p\) directly over \(\sigma(\mathcal{E})\), \(p\) is defined over \(\mathcal{R}\) first and then extended to \(\sigma(\mathcal{E})\) using Carathéodory’s extension theorem. \(p\) is defined on \(\mathcal{R}\) by

$$\begin{aligned} p(\emptyset) & = 0 \\ \forall f \in \mathcal{F}\quad p(\mathcal{E}(f)) & = (\frac{1}{n})^{|f|} \end{aligned}$$

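\(p\) is additive on \(\mathcal{R}\): a prefix event \(\mathcal{E}(f)\) is the disjoint union of the \(n\) prefix events \(\mathcal{E}(f\cdot c)\), \(c\in\mathcal{C}_n\), each of probability \((\frac{1}{n})^{|f|+1}\), and these probabilities add up to \((\frac{1}{n})^{|f|} = p(\mathcal{E}(f))\); any finite decomposition of a prefix event into disjoint prefix events can be refined this way, one roll at a time, down to prefixes of a common length. It is also σ-subadditive: every prefix event is both open and closed in \(\Omega\) for the product topology, and \(\Omega\) is compact, so any countable cover of a prefix event by prefix events has a finite subcover, which reduces σ-subadditivity to finite subadditivity. Finally, \(p\) is finite, hence σ-finite, because \(p(\Omega) = p(\mathcal{E}(\epsilon)) = (\frac{1}{n})^{0} = 1\). The function \(p\) can then be uniquely extended into a probability function over \(\sigma(\mathcal{E})\).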
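As a small worked example, with a coin (\(n=2\)) the prefix event \(\mathcal{E}((1))\) is the disjoint union of \(\mathcal{E}((1,1))\) and \(\mathcal{E}((1,2))\), and \(p\) assigns consistent values to both sides. Splitting on the next roll gives the same identity for any prefix \(f\):

$$\begin{aligned} p(\mathcal{E}((1))) = \frac{1}{2} & = \frac{1}{4} + \frac{1}{4} = p(\mathcal{E}((1,1))) + p(\mathcal{E}((1,2))) \\ \sum_{c\in\mathcal{C}_n} p(\mathcal{E}(f\cdot c)) & = n \times \Bigl(\frac{1}{n}\Bigr)^{|f|+1} = \Bigl(\frac{1}{n}\Bigr)^{|f|} = p(\mathcal{E}(f)) \end{aligned}$$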

Note that:

$$\begin{aligned} \forall f_1,f_2\in \mathcal{F},\quad p(\mathcal{E}(f_1\cdot f_2)) & = p(\mathcal{E}(f_1)) \times p(\mathcal{E}(f_2)) \\ \forall f\in \mathcal{F},\quad p(\mathcal{E}(f)) & = \prod_{i=1}^{|f|}p(\mathcal{E}(f_{[i]})) \end{aligned}$$

The probability space \((\Omega,\sigma(\mathcal{E}),p)\) has the very pleasing property that the probability of a prefix event \(\mathcal{E}(f)\) is exactly the probability of getting the sequence \(f\) with \(|f|\) dice rolls. There is a problem though: there are outcomes for which not every side has been observed at least once. The infinite sequence \((1,1,1,\dots)\) is such an outcome.

Let \(M\) be the subset of outcomes such that at least one side has not been observed, i.e. \( M = \{ \omega \in \Omega \mid \sharp \omega < n \} \). We want to know how likely \(M\) is. For any \(i \in \mathbb{N} \), let \(M_i\) be the set of outcomes such that at least one side has not been observed up to the \(i\)-th roll, i.e. \(M_i = \mathcal{E}(\{f \in \mathcal{F}\mid |f|=i,\ \sharp f < n\})\).

For any \(i\in \mathbb{N}\), \(M_i\) is an event because it is a finite union of events. An outcome misses at least one side over its infinite sequence of rolls if and only if, for every \(i\), its first \(i\) rolls miss at least one side, so \(M = \bigcap_{i\in\mathbb{N}} M_i\). From the last equation, we can conclude that \(M\) is an event because it is a countable intersection of events. Furthermore, given \(i\) dice rolls, the probability of not observing a given side is \((\frac{n-1}{n})^i\), so \(p(M_i) \le n \times (\frac{n-1}{n})^i\). Note that \(M_{i+1} \subset M_i\), so \(\bigcap_{j=0}^i M_j = M_i\). We can conclude that the probability, given an infinity of dice rolls, of never observing one side (any one) is \(0\):

$$ p(M) = p(\bigcap_{i\in\mathbb{N}} M_i) = \lim_{i\rightarrow\infty} p(M_i) \le \lim_{i\rightarrow\infty} n \times \Bigl(\frac{n-1}{n}\Bigr)^i = 0 $$
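The bound is easy to check numerically. In the sketch below (the `MissingSide` object is hypothetical, for illustration only), \(p(M_i)\), the probability that some side is still missing after \(i\) rolls, is computed exactly by tracking the probability of having seen exactly \(k\) sides after each roll, and compared with the bound \(n \times (\frac{n-1}{n})^i\). For a 6-sided dice, \(p(M_{50})\) is below 0.001, which is consistent with the d6 histogram where experiments longer than 50 rolls are negligible.

```scala
// Hypothetical helper for illustration; not part of the original article.
object MissingSide {
  // Exact p(M_i): probability that, after i rolls of a fair n-sided dice,
  // at least one side has not been observed yet.
  def exact(n: Int, i: Int): Double = {
    // dist(k) = probability of having seen exactly k distinct sides so far
    var dist = Array.tabulate(n + 1)(k => if (k == 0) 1.0 else 0.0)
    for (_ <- 1 to i) {
      val next = Array.fill(n + 1)(0.0)
      for (k <- 0 to n) {
        next(k) += dist(k) * k / n                      // one of the k seen sides again
        if (k < n) next(k + 1) += dist(k) * (n - k) / n // one of the n - k new sides
      }
      dist = next
    }
    dist.take(n).sum // summing k < n avoids computing 1 - dist(n)
  }

  // Union bound used in the text: n * ((n - 1) / n)^i.
  def bound(n: Int, i: Int): Double = n * math.pow((n - 1).toDouble / n, i)
}
```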

Note that \(p(M) = 0\) does not mean these outcomes are impossible. In theory, if you flip a coin endlessly, it is possible to always get heads (resp. tails), but this is incredibly unlikely. Let \(\Omega' = \overline{M}\) be the set of outcomes such that every side has been observed at least once, i.e. the complement of \(M\). Its probability is then \(p(\Omega') = 1\). So for any event \(E\in\sigma(\mathcal{E})\), \(p(E) = p(E\cap\Omega') + p(E\cap M)\), but \(p(E\cap M) \le p(M) = 0\), so

$$\forall E\in\sigma(\mathcal{E})\quad p(E) = p(E\cap\Omega')$$

Informally, it means we can assume that, in every outcome, every side of the dice is observed at least once. More precisely, we take as probability space the restriction of \((\Omega,\sigma(\mathcal{E}),p)\) to \(\Omega'\), written \((\Omega',\sigma(\mathcal{E})|_{\Omega'},p)\).

How does the problem translate into this probability space? Remember that an outcome \(\omega \in \Omega'\) is an infinite sequence \((s_1,s_2,\dots)\) of sides \(s_i \in \mathcal{C}_n\) such that every side is observed at some point. For any \(m \in \{0,\dots,n\}\), we define the random variable \(X_m\) as the function, from \(\Omega'\) to \(\mathbb{R}\), that maps any outcome \(\omega = (s_1,s_2,\dots) \in \Omega'\) to the first \(i\) such that \(m\) sides have been observed at least once in \((s_1,\dots,s_i)\).

$$\forall \omega = (s_i)_{i\ge 1}\in \Omega', \quad X_m(\omega) = \inf_i \bigl\{ i \mid \sharp (s_1,\dots,s_i) = m \bigr\}$$
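On a finite prefix of an outcome, \(X_m\) is straightforward to compute. A minimal sketch (the `FirstTimes` object and `firstRollWithSides` are hypothetical names): it returns `None` when the prefix is too short for \(m\) sides to appear, which mirrors the fact that \(X_m\) is really defined on infinite outcomes. For 2→2→1→2→3, we get \(X_1 = 1\), \(X_2 = 3\) and \(X_3 = 5\).

```scala
// Hypothetical helper for illustration; not part of the original article.
object FirstTimes {
  // Smallest i such that the first i rolls contain exactly m distinct sides,
  // or None if this does not happen within the given prefix.
  def firstRollWithSides(m: Int, rolls: Seq[Int]): Option[Int] = {
    // distinctSoFar(i) = number of distinct sides among the first i rolls
    val distinctSoFar = rolls.scanLeft(Set.empty[Int])(_ + _).map(_.size)
    val i = distinctSoFar.indexWhere(_ == m)
    if (i >= 0) Some(i) else None
  }
}
```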

Note that \(X_n(\omega)\) is the number of rolls needed to observe every side of the chosen dice at least once. The average we are looking for is actually the expected value of \(X_n\). But for the expected value to be defined, \(X_m\) has to be a measurable function from \((\Omega',\sigma(\mathcal{E})|_{\Omega'})\) to \((\mathbb{R},\mathcal{B}(\mathbb{R}))\). Let \(F_{m,l}\) be the set of prefixes \(f=(s_1,\dots,s_l)\) of length \(l\) such that \(l\) is the first roll for which exactly \(m\) sides have been observed at least once, i.e. \(l = \inf_i\{i \mid \sharp (s_1,\dots,s_i) = m \}\). Then

$$ X_m = \sum_{l\in\mathbb{N}} \sum_{f \in F_{m,l}} l \times \mathbb{1}_{\mathcal{E}(f)\cap \Omega'} $$

and so \(X_m\) is indeed measurable. The expected value is:

$$\begin{aligned} \mathbb{E}(X_m) & = \sum_{l\in\mathbb{N}} l\times p(X_m^{-1}(l)) \\ & = \sum_{l\in\mathbb{N}} \sum_{f \in F_{m,l}} l \times p(\mathcal{E}(f)) \\ & = \sum_{l\in\mathbb{N}} \sum_{f \in F_{m,l}} l \times \Bigl(\frac{1}{n}\Bigr)^l \\ & = \sum_{l\in\mathbb{N}} |F_{m,l}| \times l \times \Bigl(\frac{1}{n}\Bigr)^l \end{aligned}$$

The mistake we made in the previous section is now clear: we gave every valid sequence the same weight, computing \(\sum_{l\in\mathbb{N}} |F_{m,l}| \times l\) (divided by the number of valid sequences) instead of \(\sum_{l\in\mathbb{N}} |F_{m,l}| \times l \times (\frac{1}{n})^l\). We just have to fix the function used in the aggregation to compute the right value:

def forOneValidPrefix(sides: Int)(length: Int): BigDecimal =
  // Contribution of one valid sequence: its probability (1/sides)^length times its length
  (BigDecimal(1) / sides).pow(length) * length

scala> val depth = 100000
val depth: Int = 100000

scala> val sides = 3
val sides: Int = 3

scala> val expectedValueFor100000 = aggregate(sides,depth)(forOneValidPrefix(sides))
val expectedValueFor100000: BigDecimal = 5.499999999999999999999999999999999

This time the computed value matches what we observed with random experiments, a value around \(5.5\). It also matches the average we got for a 6-sided fair dice:

scala> val depth = 10000
val depth: Int = 10000

scala> val sides = 6
val sides: Int = 6

scala> val expectedValueFor10000 = aggregate(sides,depth)(forOneValidPrefix(sides))
val expectedValueFor10000: BigDecimal = 14.70000000000000000000000000000001

We can actually go a bit deeper by observing that for any \(m \in \{1,\dots,n\}\), \(\mathbb{E}(X_m) = \mathbb{E}(X_{m-1}) + \mathbb{E}(X_m - X_{m-1})\):

$$\begin{aligned} \mathbb{E}(X_m) & = \sum_{i=0}^{m-1} \mathbb{E}(X_{i+1} - X_i) \\ \mathbb{E}(X_{i+1}-X_i) & = \sum_{d\in\mathbb{N}^*} d \times p((X_{i+1}-X_i)^{-1}(d)) \\ p((X_{i+1}-X_i)^{-1}(d)) & = \sum_{k\in\mathbb{N}} p(\mathcal{E}(F_{i+1,k+d}) \cap \mathcal{E}(F_{i,k})) \\p(\mathcal{E}(F_{i+1,k+d})\cap \mathcal{E}(F_{i,k})) & = p(\mathcal{E}(\{f \cdot f' \cdot c \mid f\in F_{i,k},\quad |f'|=d-1,\ \diamond f' \subset \diamond f,\quad c\in \mathcal{C}_n\setminus\diamond f \})) \\& = \sum_{f\in F_{i,k}} \Biggl( \sum_{|f'|=d-1,\ \diamond f' \subset \diamond f} \Biggl( \sum_{c\in \mathcal{C}_n\setminus \diamond f} p(\mathcal{E}(f \cdot f' \cdot c)) \Biggr)\Biggr) \\& = \sum_{f\in F_{i,k}} \Biggl( p(\mathcal{E}(f)) \times \sum_{|f'|=d-1,\ \diamond f' \subset \diamond f} \Biggl( p(\mathcal{E}(f')) \times \sum_{c\in \mathcal{C}_n\setminus \diamond f} p(\mathcal{E}(c)) \Biggr)\Biggr) \\& = \sum_{f\in F_{i,k}} \Biggl( p(\mathcal{E}(f)) \times |\{f'\in \mathcal{F} \mid |f'|=d-1,\ \diamond f' \subset \diamond f\}| \times \Bigl(\frac{1}{n}\Bigr)^{d-1} \times \frac{n-i}{n} \Biggr) \\& = \Bigl(\frac{i}{n}\Bigr)^{d-1} \times \frac{n-i}{n} \times p(\mathcal{E}(F_{i,k})) \end{aligned}$$

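Summing over \(k\), and since the events \(\mathcal{E}(F_{i,k})\) for \(k\in\mathbb{N}\) are disjoint with total probability \(1\) (in \(\Omega'\), \(X_i\) is always finite), we get \(p((X_{i+1}-X_i)^{-1}(d)) = (\frac{i}{n})^{d-1}\times \frac{n-i}{n}\). Equivalently, \(p(\mathcal{E}(F_{i+1,k+d}) \mid \mathcal{E}(F_{i,k})) = (\frac{i}{n})^{d-1}\times \frac{n-i}{n} \). So \(\mathbb{E}(X_{i+1} - X_i) = \sum_{d\in\mathbb{N}^*} d \times \frac{n-i}{n}\times (\frac{i}{n})^{d-1} \).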

We recognize a Geometric Distribution whose probability of success is \(p' = \frac{n-i}{n}\). Its expected value is known to be \(\frac{1}{p'} = \frac{n}{n-i}\). It can be computed by

$$\begin{aligned} \mathbb{E}(X_{i+1} - X_i) & = \sum_{d\in\mathbb{N}^*} d \times \frac{n-i}{n}\times \Bigl(\frac{i}{n}\Bigr)^{d-1} \\& = \frac{n-i}{n} \times \sum_{d\in\mathbb{N}^*} d \times \Bigl(\frac{i}{n}\Bigr)^{d-1} \end{aligned}$$

But

$$\begin{aligned} \sum_{d\in\mathbb{N}^\star} d \times x^{d-1} & = \sum_{d\in\mathbb{N}^\star} (x^d)^\prime \\& = \Bigl(\sum_{d\in\mathbb{N}^\star} x^d\Bigr)^\prime \\& = \Bigl(\frac{x}{1 - x}\Bigr)^\prime \\& = \frac{1}{(1 - x)^2} \end{aligned}$$

So

$$\begin{aligned} \mathbb{E}(X_{i+1} - X_i) & = \frac{n-i}{n} \times \frac{1}{(1 - \frac{i}{n})^2} \\& = \frac{n-i}{n} \times \Bigl(\frac{n}{n - i}\Bigr)^2 \\& = \frac{n}{n - i} \end{aligned}$$
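The closed form can be cross-checked by truncating the series. In this sketch (the `GeometricCheck` object is hypothetical, and the cutoff `maxD` is an assumption of the sketch), the partial sum of \(d \times \frac{n-i}{n} \times (\frac{i}{n})^{d-1}\) should approach \(\frac{n}{n-i}\), e.g. \(2\) for \(n = 6\) and \(i = 3\), and the closed forms for \(i = 0,\dots,5\) should add up to \(14.7\):

```scala
// Hypothetical helper for illustration; not part of the original article.
object GeometricCheck {
  // Partial sum of E(X_{i+1} - X_i) = sum over d >= 1 of d * q * (1 - q)^(d - 1),
  // where q = (n - i) / n is the probability of revealing a new side on a roll.
  def truncatedWait(n: Int, i: Int, maxD: Int): Double = {
    val q = (n - i).toDouble / n
    (1 to maxD).map(d => d * q * math.pow(1 - q, d - 1)).sum
  }

  // Closed form derived above: n / (n - i).
  def closedForm(n: Int, i: Int): Double = n.toDouble / (n - i)
}
```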

Finally, we get a closed formula for the expected value, and we can check that it gives the values computed above:

$$ \mathbb{E}(X_n) = \sum_{i=0}^{n-1} \frac{n}{n-i} $$

def expectedValue(sides: Int): Double =
  // sum_{i=0}^{n-1} n/(n-i) is the same as sum_{j=1}^{n} n/j
  (1 to sides).map(sides.toDouble / _).sum

scala> expectedValue(6)
val res0: Double = 14.7

scala> expectedValue(3)
val res1: Double = 5.5
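In the same spirit as the BigInt-based code above, the formula can also be evaluated without any rounding. A sketch (the `ExactExpectation` object and `expectedValueExact` are hypothetical names) that keeps n/1 + n/2 + … + n/n as a reduced fraction; it should give 147/10 for a 6-sided dice and 11/2 for a 3-sided one:

```scala
// Hypothetical helper for illustration; not part of the original article.
object ExactExpectation {
  // E(X_n) = sum_{j=1}^{n} n / j, kept as a reduced fraction (num, den).
  def expectedValueExact(sides: Int): (BigInt, BigInt) = {
    var num = BigInt(0)
    var den = BigInt(1)
    for (j <- 1 to sides) {
      // num/den + sides/j = (num * j + sides * den) / (den * j)
      val newNum = num * BigInt(j) + BigInt(sides) * den
      val newDen = den * BigInt(j)
      val g = newNum.gcd(newDen)
      num = newNum / g
      den = newDen / g
    }
    (num, den)
  }
}
```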

Understanding the Probability

We are still given an \(n\)-sided fair dice, with \(n>0\). Let \(C \subset \mathcal{C}_n\) be a subset of the sides of the dice and \(l \in \mathbb{N}\) a non-negative integer. The event of all outcomes whose prefix of length \(l\) does not contain any side in \(C\) is written

$$M_{n,l,C} = \mathcal{E}(\{ f\in \mathcal{F} \mid |f|=l,\quad \diamond f \cap C = \emptyset \})$$

Note that for any subsets \(C_1\) and \(C_2\) of \(\mathcal{C}_n\) and \(l\in\mathbb{N}\), the property \(M_{n,l,C_1}\cap M_{n,l,C_2} = M_{n,l,C_1\cup C_2}\) holds, and \(p(M_{n,l,C}) = (\frac{n - |C|}{n})^l = (1 - \frac{|C|}{n})^l\). Let \(A_{n,l}\) be the event of all outcomes whose prefix of length \(l\) contains every side of the dice at least once:

$$\begin{aligned} A_{n,l} & = \mathcal{E}(\{ f\in \mathcal{F} \mid |f|=l,\quad \diamond f = \mathcal{C}_n \}) \\ & = \bigcap_{c\in \mathcal{C}_n} \overline{M_{n,l,\{c\}}} \\ & = \overline{\bigcup_{c\in \mathcal{C}_n} M_{n,l,\{c\}}} \end{aligned}$$

By the inclusion-exclusion principle:

$$\begin{aligned} p(A_{n,l}) & = 1 - p\Bigl(\bigcup_{c \in \mathcal{C}_n} M_{n,l,\{c\}}\Bigr) \\ & = 1 - \Biggl[\sum_{C \subset \mathcal{C}_n,C\neq\emptyset} -(-1)^{|C|} \times p\Bigl(\bigcap_{c\in C} M_{n,l,\{c\}}\Bigr)\Biggr] \\ & = 1 - \Biggl[\sum_{C \subset \mathcal{C}_n,C\neq\emptyset} -(-1)^{|C|} \times p(M_{n,l,C})\Biggr] \\ & = 1 + \sum_{C \subset \mathcal{C}_n,C\neq\emptyset} (-1)^{|C|} \times \biggl(1 - \frac{|C|}{n}\biggr)^{l} \\ & = 1 + \sum_{k=1}^{n} \binom{n}{k} \times (-1)^k \times \biggl(1 - \frac{k}{n}\biggr)^{l} \end{aligned}$$
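The last formula can be checked by brute force on small dice. Multiplied by \(n^l\), it counts the sequences of \(l\) rolls that show every side: \(n^l \times p(A_{n,l}) = \sum_{k=0}^{n} \binom{n}{k} (-1)^k (n-k)^l\), an integer that BigInt computes exactly. In this sketch, `binomial`, `coveringCount`, `bruteForceCount` and `probabilityAllSeen` are illustrative helpers, and the enumeration is only practical for very small \(n\) and \(l\):

```scala
// Binomial coefficient C(n, k); every intermediate value is an exact integer
def binomial(n: Int, k: Int): BigInt =
  (1 to k).foldLeft(BigInt(1))((acc, i) => acc * BigInt(n - k + i) / BigInt(i))

// Inclusion-exclusion: number of sequences of l rolls of an n-sided dice showing every side
def coveringCount(n: Int, l: Int): BigInt =
  (0 to n).map { k =>
    val sign = if (k % 2 == 0) BigInt(1) else BigInt(-1)
    sign * binomial(n, k) * BigInt(n - k).pow(l)
  }.sum

// Brute force: enumerate all n^l sequences and keep those containing every side
def bruteForceCount(n: Int, l: Int): BigInt = {
  val sequences = (1 to l).foldLeft(List(List.empty[Int])) { (acc, _) =>
    for (s <- acc; side <- (0 until n).toList) yield side :: s
  }
  BigInt(sequences.count(_.toSet.size == n))
}

// p(A_{n,l}) as a Double: exact count divided by n^l
def probabilityAllSeen(n: Int, l: Int): Double =
  (BigDecimal(coveringCount(n, l)) / BigDecimal(BigInt(n).pow(l))).toDouble

println(coveringCount(3, 4))   // 36 (out of 3^4 = 81 sequences)
println(bruteForceCount(3, 4)) // 36
```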

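We can generalize this result: given an \(n\)-sided fair dice and a subset \(C \subset \mathcal{C}_n\) of its sides, let \(A_{n,l,C}\) be the event of all outcomes whose prefixes of length \(l\) contain only sides of \(C\), and every side of \(C\). Conditioned on \(M_{n,l,\overline{C}}\), each of the first \(l\) rolls is uniformly distributed over \(C\), so these rolls behave like those of a \(|C|\)-sided fair dice: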

$$\begin{aligned} A_{n,l,C} & = \mathcal{E}(\{f\in \mathcal{F} \mid |f|=l, \quad \diamond f = C\}) \\p(A_{n,l,C}\mid M_{n,l,\overline{C}}) & = p(A_{|C|,l}) \\ & = 1 + \sum_{k=1}^{|C|} \binom{|C|}{k} \times (-1)^k \times \biggl(1 - \frac{k}{|C|}\biggr)^{l} \end{aligned}$$

We now know the probability of observing every side of a given subset within the first \(l\) rolls. But \(A_{n,l}\) is not, in general, \(X_n^{-1}(\{l\})\): outcomes in \(X_n^{-1}(\{l\})\) observe their last new side exactly at roll \(l\), while outcomes in \(A_{n,l}\) may have observed every side well before the \(l\)-th roll. We can nevertheless relate the two. For any side \(c \in \mathcal{C}_n\) and any non-negative integer \(i\in\mathbb{N}\), let \(R_{n,i,c}\) be the event of observing the side \(c\) at roll \(i\). Let \(l\in\mathbb{N}^*\). The union below is disjoint (its terms differ by the side observed last), and roll \(l\) is independent of the first \(l-1\) rolls, hence:

$$\begin{aligned} X_n^{-1}(\{l\}) & = \bigcup_{c=1}^n A_{n,l-1,\mathcal{C}_n\setminus\{c\}} \cap M_{n,l-1,\{c\}} \cap R_{n,l,c} \\p(X_n^{-1}(\{l\})) & = \sum_{c=1}^n p(A_{n,l-1,\mathcal{C}_n\setminus\{c\}} \cap M_{n,l-1,\{c\}} \cap R_{n,l,c}) \\ & = \sum_{c=1}^n p(A_{n,l-1,\mathcal{C}_n\setminus\{c\}} \mid M_{n,l-1,\{c\}}) \times p(M_{n,l-1,\{c\}}) \times \frac{1}{n} \\ & = \sum_{c=1}^n p(A_{n-1,l-1}) \times \biggl(\frac{n-1}{n}\biggr)^{l-1} \times \frac{1}{n} \\ & = \biggl(\frac{n-1}{n}\biggr)^{l-1} \times p(A_{n-1,l-1}) \\ & = \biggl(\frac{n - 1}{n}\biggr)^{l-1} \times \Biggl[ 1 + \sum_{k=1}^{n-1} \binom{n-1}{k} \times (-1)^k \times \biggl(\frac{n - 1 - k}{n-1}\biggr)^{l-1} \Biggr] \\ & = \biggl(1 - \frac{1}{n}\biggr)^{l-1} + \sum_{k=1}^{n-1} \binom{n-1}{k} \times (-1)^k \times \biggl(1 - \frac{k+1}{n}\biggr)^{l-1} \end{aligned}$$
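This distribution can be sanity-checked against the expected value computed earlier. Summed over \(l \ge 1\), the probabilities should add up to 1, and \(\sum_l l \times p(X_n = l)\) should equal \(\sum_{k=1}^{n} \frac{n}{k}\) (5.5 for \(n = 3\), 14.7 for \(n = 6\)). Multiplying the last line of the derivation by \(n^{l-1}\) gives the integer \(\sum_{k=0}^{n-1} \binom{n-1}{k} (-1)^k (n-1-k)^{l-1}\). The sketch below computes that integer exactly and then performs a single division. The helper names and the truncation lengths (300 and 600 rolls, beyond which the remaining terms are negligible) are illustrative assumptions:

```scala
// Binomial coefficient C(n, k); every intermediate value is an exact integer
def binomial(n: Int, k: Int): BigInt =
  (1 to k).foldLeft(BigInt(1))((acc, i) => acc * BigInt(n - k + i) / BigInt(i))

// n^(l-1) * p(X_n = l) = sum over k of C(n-1, k) (-1)^k (n-1-k)^(l-1), for n >= 1 and l >= 1
def scaledProbability(n: Int, l: Int): BigInt = {
  require(n >= 1 && l >= 1)
  (0 until n).map { k =>
    val sign = if (k % 2 == 0) BigInt(1) else BigInt(-1)
    sign * binomial(n - 1, k) * BigInt(n - 1 - k).pow(l - 1)
  }.sum
}

// p(X_n = l): exact numerator, then a single rounded division
def probabilityExactly(n: Int, l: Int): Double =
  (BigDecimal(scaledProbability(n, l)) / BigDecimal(BigInt(n).pow(l - 1))).toDouble

// Total probability and mean of X_n, truncated at maxLength rolls
def totalAndMean(n: Int, maxLength: Int): (Double, Double) = {
  val probabilities = (1 to maxLength).map(l => (l, probabilityExactly(n, l)))
  (probabilities.map(_._2).sum, probabilities.map { case (l, p) => l * p }.sum)
}

println(totalAndMean(3, 300))
println(totalAndMean(6, 600))
```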

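We need to be careful when translating this formula into code, so that rounding errors do not lead to very wrong answers. The alternating sum is computed exactly with BigInt, and BigDecimal is only used for the divisions: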

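def probabilityForLength(sides: Int, length: Int): BigDecimal =
  // Binomial coefficient C(f, k), computed exactly with BigInt
  def cnk(k: Int, f: Int): BigInt =
    // Falling factorial: n * (n-1) * ... * (n-k+1), a product of k factors
    def dfact(n: Int, k: Int, res: BigInt = BigInt(1L)): BigInt =
      if k <= 0
      then res
      else dfact(n-1, k-1, res*n)

    // f!/(f-k)! divided by k! (dfact(k, k-1) = k * (k-1) * ... * 2 = k!)
    dfact(f,k) / dfact(k,k-1)

  // Degenerate cases: probability 1 only when both sides and length are 0
  if sides == 0 || length == 0 then
    if sides == 0 && length == 0 then
      BigDecimal(1)
    else
      BigDecimal(0)
  else
    // (1 - 1/n)^(l-1) + [sum for k = 1..n-1 of C(n-1,k) (-1)^k (n-1-k)^(l-1)] / n^(l-1)
    ((BigDecimal(sides - 1)/sides).pow(length - 1) +
      BigDecimal(
        (1 to (sides-1)).map { (k: Int) =>
          ( cnk(k,sides-1)
            * (if k % 2 == 0 then 1 else -1)
            * BigInt(sides - 1 - k).pow(length-1)
          )
        }.sum
      ) / BigDecimal(sides).pow(length - 1)
    )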

We can check that the probability of observing all the sides of a 120-sided dice in fewer than 120 rolls is indeed 0, up to rounding:

scala> probabilityForLength(120, 13)
val res0: BigDecimal = 4E-34

scala> probabilityForLength(120, 42)
val res1: BigDecimal = 8.751991852311394833964673845157515E-34

scala> probabilityForLength(120, 99)
val res2: BigDecimal = 1.468941911574859178966522092677385E-33

scala> probabilityForLength(120, 119)
val res3: BigDecimal = 1.529470021201154499656736868919480E-33

Note that the very small non-zero numbers we get for 13, 42, 99 and 119 rolls, instead of 0, are due to rounding errors when computing with such big and small numbers. To get a better idea of how the probability behaves we can, as usual, export it as a CSV file. Let us start by defining the probability for every length using a stream (a LazyList):

scala> val probaD120 =
          LazyList
            .iterate(0)(_ + 1)
            .map(length => length -> probabilityForLength(120, length))
val probaD120: LazyList[(Int, BigDecimal)] = LazyList(<not computed>)

And write this stream up to some length into a CSV file:

def (self: LazyList[(Int, BigDecimal)]) plot(depth: Int, file: String): Unit =
  import java.io._
  val pw = new PrintWriter(new File(file))
  pw.println("length;probability")
  for (length, proba) <- self.take(depth) do
    pw.printf("%d;%f\n", length, proba.toDouble)
  pw.close
scala> probaD120.plot(1500, "/tmp/d120.csv")

Figure: d120 density (length / probability)
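Coming back to the tiny non-zero values obtained for lengths below 120: they appear because two rounded BigDecimal values almost cancel each other. The first is \(\bigl(\frac{sides-1}{sides}\bigr)^{length-1}\) and the second is the BigInt sum divided by \(sides^{length-1}\). One possible fix, sketched below, is to start the integer sum at \(k = 0\). That term, \(\binom{sides-1}{0} (sides-1)^{length-1}\), is exactly the numerator of the first term, so the whole numerator becomes an exact BigInt and only one division is rounded. `probabilityForLengthExact` and its helpers are hypothetical names, not part of the code above:

```scala
// Binomial coefficient C(n, k); every intermediate value is an exact integer
def binomial(n: Int, k: Int): BigInt =
  (1 to k).foldLeft(BigInt(1))((acc, i) => acc * BigInt(n - k + i) / BigInt(i))

// sides^(length-1) * p(X_sides = length), computed exactly (sides >= 1, length >= 1).
// The k = 0 term, (sides-1)^(length-1), is the numerator of the first term of the formula.
def exactNumerator(sides: Int, length: Int): BigInt =
  (0 until sides).map { k =>
    val sign = if (k % 2 == 0) BigInt(1) else BigInt(-1)
    sign * binomial(sides - 1, k) * BigInt(sides - 1 - k).pow(length - 1)
  }.sum

// Same degenerate cases as probabilityForLength, but a single rounded division
def probabilityForLengthExact(sides: Int, length: Int): BigDecimal =
  if (sides == 0 || length == 0) {
    if (sides == 0 && length == 0) BigDecimal(1) else BigDecimal(0)
  } else
    BigDecimal(exactNumerator(sides, length)) / BigDecimal(BigInt(sides).pow(length - 1))

println(probabilityForLengthExact(120, 13))
println(probabilityForLengthExact(120, 842))
```

For every length below 120, the numerator is exactly zero: it counts the ways the first \(length - 1\) rolls can show 119 distinct sides, which is impossible with fewer than 119 rolls.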

There are many things we can do with this probability, like asking how many rolls are needed to observe all the sides of the dice 9 times out of 10. The probability of observing all the sides of the dice in at most \(l\) rolls is given by

$$p(X_n \le l) = \sum_{l'=0}^{l} p(X_n = l')$$

All we need to do is accumulate the stream probaD120 into running sums:

def [A,B,C](ll: LazyList[A])foldP(z: B)(f: (B,A) => (B,C)): LazyList[C] =
  LazyList.unfold((ll,z)) { case (l0, z0) =>
      if l0.isEmpty then
        None
      else
        val (z1,c0) = f(z0,l0.head)
        Some((c0, (l0.tail,z1)))
  }

def (self: LazyList[(Int, BigDecimal)]) cumul: LazyList[(Int, BigDecimal)] =
  self.foldP(BigDecimal(0)) { case (s, (l,p)) => (s+p, l -> (s+p)) }
scala> val distribD120 = probaD120.cumul
val distribD120: LazyList[(Int, BigDecimal)] = LazyList(<not computed>)

scala> distribD120.dropWhile(_._2 < 0.9).head._1
val res5: Int = 842

There is a 90% chance of observing all sides of the 120-sided fair dice within 842 rolls. Once again, we can get a better idea of how the distribution behaves by plotting it:

scala> distribD120.plot(1500, "/tmp/d120_cumul.csv")

Figure: d120 distribution (length / probability)
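The cumulative distribution also has a closed form: \(X_n \le l\) holds exactly when the first \(l\) rolls show every side, so \(p(X_n \le l) = p(A_{n,l})\), which was computed earlier. This probability never decreases as \(l\) grows, so the smallest number of rolls reaching a target probability can be found by binary search, comparing exact integers instead of rounded decimals. The following sketch can be compared with the 842 obtained above from the cumulated stream. The helper names and the search bound of 2000 rolls are illustrative:

```scala
// Binomial coefficient C(n, k); every intermediate value is an exact integer
def binomial(n: Int, k: Int): BigInt =
  (1 to k).foldLeft(BigInt(1))((acc, i) => acc * BigInt(n - k + i) / BigInt(i))

// n^l * p(A_{n,l}): number of sequences of l rolls showing every side (inclusion-exclusion)
def coveringCount(n: Int, l: Int): BigInt =
  (0 to n).map { k =>
    val sign = if (k % 2 == 0) BigInt(1) else BigInt(-1)
    sign * binomial(n, k) * BigInt(n - k).pow(l)
  }.sum

// Smallest l in [0, maxLength] such that p(X_n <= l) = p(A_{n,l}) >= num / den
def quantile(n: Int, num: Int, den: Int, maxLength: Int): Int = {
  // Exact test of coveringCount(n, l) / n^l >= num / den
  def reached(l: Int): Boolean =
    coveringCount(n, l) * BigInt(den) >= BigInt(num) * BigInt(n).pow(l)

  // Binary search; invariant: the answer lies in [lo, hi] and reached(hi) holds
  @annotation.tailrec
  def search(lo: Int, hi: Int): Int =
    if (lo >= hi) hi
    else {
      val mid = (lo + hi) / 2
      if (reached(mid)) search(lo, mid) else search(mid + 1, hi)
    }

  require(reached(maxLength), "maxLength is too small for this target")
  search(0, maxLength)
}

println(quantile(120, 9, 10, 2000))
```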

Conclusion

The initial question was simple:

how many rolls are needed, on average, with an \(n\)-sided fair dice, to observe all of its sides?

But the answer was not! We have seen how to check, empirically, that the Scala Random.nextInt function correctly simulates a fair dice. From there we ran (many) experiments to get an approximation of the answer. We experienced how easy, but disastrous, it can be to build a model disconnected from reality. We learned that building a valid model requires a deep understanding of the problem and the experiments. We had to put a lot of care into the construction of the model to be sure it is a valid formalization of the problem. The maths were not easy, but they were right. And in the end, maths led us to a very simple formula. Was all this formal brutality useful? Yes, it was. The simple formula gives the exact answer and is by far the most efficient implementation. Going deeper, we found how to answer more questions, like the chance of observing all sides in \(l\) rolls or less. We were even able to get a precise idea of how the probability behaves by plotting it.

All along this journey, we used many of the new Scala 3 features, among which extension methods and Type Classes. We saw how easy they are to use and the great benefits they offer: extension methods let us add methods to objects without any boilerplate, Type Classes let us write generic functions, etc.

I hope you enjoyed this journey as much as I loved writing it. Have a look at all the Scala 3 features. Many of them are not covered here but are truly amazing (Polymorphic Functions, Dependent Function Types, Match Types, Intersection and Union Types, etc.). The list is pretty long.

Les GADTs Par l'Exemple

Welcome! This session aims to introduce you to a very powerful programming tool. While most introductions to the subject start with a very formal presentation of its theoretical foundations, we have chosen to present it through short examples and concrete use cases.

This workshop consists of three parts. The last one presents three of the most useful use cases, which make up the main uses in practice. But do not venture there unprepared! This part comes last for a good reason: it relies heavily on the lessons of the previous parts. Start with First Contact, which shows you the key ideas through the simplest examples. Its goal is to open your mind to ways of using types and data that you have probably never suspected. Then go through Simple and Useful Use Cases: Relations on Types, for a first challenge on a practical use. Only after that will you be ready for More Advanced Use Cases.

Make sure to read the README: this section contains valuable tips to make your journey easier.

Acknowledgements

We would like to thank Laure Juglaret for her many reviews, her valuable remarks and her corrections.

README

Throughout this presentation, we will assume that:

  • null does not exist!
  • Runtime reflection does not exist! (i.e. isInstanceOf, getClass, etc.)

This presentation assumes that these features do not exist at all!

Using them will never lead to a correct answer to the questions.

To do this workshop, you need a way to quickly write, compile and run Scala code. The best way is to open an interactive session (R.E.P.L.). If you have Scala installed on your system, you can easily start one from the command line by running the scala program:

system-command-line# scala
Welcome to Scala 2.13.1 (OpenJDK 64-Bit Server VM, Java 1.8.0_222).
Type in expressions for evaluation. Or try :help.

scala>

As a reminder, in an interactive session (R.E.P.L.), the :paste command lets you paste code into the session and the :reset command starts over from a clean environment.

If you do not have Scala installed, you can use the website https://scastie.scala-lang.org/ .

Warm-up

This section is a brief reminder of a few definitions and properties of types and values.

Values and Types?

Values are the concrete data your programs manipulate, such as the integer 5, the boolean true, the string "Hello World!", the function (x: Double) => x / 7.5, the list List(1,2,3), etc. It is often convenient to classify values into groups. These groups are called types. For example:

  • Int is the group of integer values, i.e. values such as 1, -7, 19, etc.
  • Boolean is the group containing exactly the values true and false (no more, no less!).
  • String is the group whose values are "Hello World!", "", "J' ❤️ les GADTs", etc.
  • Double => Double is the group whose values are the functions taking any Double as argument and also returning a Double.

To indicate that the value v belongs to the type (i.e. group of values) T, the notation is v : T. In Scala, testing whether a value v belongs to the type T is very simple: just type v : T in the interactive session (REPL):

scala> 5 : Int
res7: Int = 5

If Scala accepts it, then v does belong to the type T. If Scala complains, it most likely does not:

scala> 5 : String
       ^
       error: type mismatch;
        found   : Int(5)
        required: String

How Many Types?

Let us now create a few types and some of their values (when possible!).

class UnType
  • Question 1: How many types does the line class UnType define?

    Solution

    As its name suggests, the line class UnType defines just one type, named UnType.

Now let us move on to:

class UnTypePourChaque[A]
  • Question 2: How many types does the line class UnTypePourChaque[A] define?

    Solution

    As its name suggests, each concrete type A gives rise to a distinct type UnTypePourChaque[A].

    For example, a list of integers is neither a list of booleans, nor a list of strings, nor a list of functions, nor … Indeed the types List[Int], List[Boolean], List[Int => Int], etc. are all distinct types.

    The line class UnTypePourChaque[A] defines a distinct type for each concrete type A. There are infinitely many concrete types A, hence infinitely many distinct types UnTypePourChaque[A].

  • Question 3: Give a value that belongs to both types UnTypePourChaque[Int] and UnTypePourChaque[Boolean].

    Remember, null does not exist!

    Solution

    It is actually impossible. Each concrete type A gives rise to a distinct type UnTypePourChaque[A] which has no value in common with the other types of the form UnTypePourChaque[B] with B ≠ A.

How Many Values?

Consider the following type:

final abstract class PasDeValeurPourCeType
  • Question 1: Give a value belonging to the type PasDeValeurPourCeType. How many values belong to the type PasDeValeurPourCeType?

    Hint
    • What is a final class? How does it differ from a normal (non-final) class?
    • What is an abstract class? How does it differ from a concrete class?
    Solution

    The class PasDeValeurPourCeType is declared abstract. This means that creating direct instances of this class is forbidden:

    scala> new PasDeValeurPourCeType
           ^
           error: class PasDeValeurPourCeType is abstract; cannot be instantiated

    The only way to create an instance of an abstract class is to create a concrete subclass. But the final keyword forbids creating such subclasses:

    scala> class SousClasseConcrete extends PasDeValeurPourCeType
                                            ^
            error: illegal inheritance from final class PasDeValeurPourCeType

    There is no way to create an instance of PasDeValeurPourCeType.

Let us take another example:

sealed trait ExactementUneValeur
case object LaSeuleValeur extends ExactementUneValeur
  • Question 2: Give a value belonging to the type ExactementUneValeur.

    Solution

    By definition, LaSeuleValeur is a value of type ExactementUneValeur.

  • Question 3: How many values belong to ExactementUneValeur?

    Solution

    As above, ExactementUneValeur, being a trait, is abstract. Being sealed, it cannot be extended outside its source file. So LaSeuleValeur is the only value of type ExactementUneValeur.
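Because ExactementUneValeur is sealed, the compiler also knows all of its values: a pattern match with the single case LaSeuleValeur is exhaustive and compiles without any warning. A small illustration, where describe is a hypothetical function:

```scala
sealed trait ExactementUneValeur
case object LaSeuleValeur extends ExactementUneValeur

// A single case covers every possible value of the sealed trait
def describe(v: ExactementUneValeur): String =
  v match {
    case LaSeuleValeur => "the one and only value"
  }

println(describe(LaSeuleValeur))
```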

First Contact

This part presents the key ideas. There are actually only two ideas! Here you will find stripped-down examples illustrating each of them.

Use Case: Proof of a Property

Let us define a simple sealed trait:

sealed trait ATrait[A]
case object AValue extends ATrait[Char]
  • Question 1: Give a value of type ATrait[Char].

    Solution

    By definition, AValue is a value of type ATrait[Char].

  • Question 2: Give a value of type ATrait[Double].

    Solution

    There is no way to obtain an instance of the type ATrait[Double]. In fact, there is no way to obtain an instance of ATrait[B] for B ≠ Char, because the only possible value is AValue, which is of type ATrait[Char].

  • Question 3: What can you conclude about the type A if you have a value ev of type ATrait[A] (i.e. ev: ATrait[A])?

    Solution

    The only possible value is AValue, so ev == AValue. Moreover AValue is of type ATrait[Char], so A = Char.

  • Question 4: In the interactive session (REPL), enter the following code:

    def f[A](x: A, ev: ATrait[A]): Char =
      x
  • Question 5: Now try using pattern matching on ev: ATrait[A]

    def f[A](x: A, ev: ATrait[A]): Char =
      ev match {
        case AValue => x
      }

    Is the pattern matching exhaustive?

    Solution

    The pattern matching is exhaustive because the one and only possible value for ev is AValue. Moreover AValue is of type ATrait[Char], which means that ev : ATrait[Char] since ev == AValue. So A = Char and x : Char.

  • Question 6: Call f with x = 'w' : Char.

    Solution
    scala> f[Char]('w', AValue)
    res0: Char = w
  • Question 7: Call f with x = 5.2 : Double.

    Solution

    It is impossible, because it would require providing a value ev : ATrait[Double], which does not exist!

    scala> f[Double](5, AValue)
                        ^
              error: type mismatch;
                found   : AValue.type
                required: ATrait[Double]
Note for readers comfortable with Scala

Using all of Scala's nice syntactic features, a production-ready version of the code above is:

sealed trait IsChar[A]
object IsChar {
  implicit case object Evidence extends IsChar[Char]

  def apply[A](implicit evidence: IsChar[A]): IsChar[A] =
    evidence
}

def f[A: IsChar](x: A): Char =
  IsChar[A] match {
    case IsChar.Evidence => x
  }

Use Case: The Only Thing I Know Is That It Exists

What would you do if you wanted your application to keep a log of events, while making sure it does not depend on any implementation detail of the logging method (i.e. the logger), so that you can change its implementation without risking breaking your application?

Consider the following type, UnknownLogger, of logging methods:

sealed trait UnknownLogger
final case class LogWith[X](logs : X, appendMessage: (X, String) => X) extends UnknownLogger

The first logging method (i.e. logger) we create stores the messages in a String:

val loggerStr : UnknownLogger =
  LogWith[String]("", (logs: String, message: String) => logs ++ message)

The second method stores them in a List[String]:

val loggerList : UnknownLogger =
  LogWith[List[String]](Nil, (logs: List[String], message: String) => message :: logs)

The third logging method prints the messages directly on the standard output:

val loggerStdout : UnknownLogger =
  LogWith[Unit]((), (logs: Unit, message: String) => println(message))

Note that these three logging methods all have the same type (i.e. UnknownLogger) but store the messages using different types X (String, List[String] and Unit).

  • Question 1: Let v be a value of type UnknownLogger. Clearly v must be an instance of the class LogWith[X] for some X. What can you say about the type X? Can you guess which concrete type X is?

    Remember, runtime reflection is forbidden! (i.e. isInstanceOf, getClass, etc.)

    Solution

    We know almost nothing about X. The only thing we have is that there exists at least one value (v.logs) of type X. Apart from that, X can be any type.

    Keeping X unknown is very useful: it guarantees that the code using v : UnknownLogger cannot depend on the nature of X. If this code knew that X was String, for example, it could perform operations we want to forbid, such as reversing it, keeping only its first n characters, etc. By hiding the nature of X, we force our application not to depend on the concrete type behind X, and to use only the provided function v.appendMessage. That way, changing the actual implementation of the logging method will not break any code.

  • Question 2: Write the function def log(message: String, v: UnknownLogger): UnknownLogger which uses v.appendMessage to append the message to the log v.logs, and returns a new UnknownLogger containing the new log.

    As a reminder, in Scala, the pattern case ac : AClass[t] => can be used in match/case expressions as an alternative to the pattern case AClass(v) =>:

    final case class AClass[A](value : A)
    
    def f[A](v: AClass[A]): A =
      v match {
        case ac : AClass[t] =>
          // `t` is a type variable
          // here the type `t` is `A`
          val r : t = ac.value
          r
      }

    Its main advantage is that it introduces the type variable t. Type variables behave like ordinary pattern variables, except that they stand for types rather than values. Having t at hand lets us help the compiler by giving some types explicitly (as above, stating explicitly that r is of type t).

    Solution
    def log(message: String, v: UnknownLogger): UnknownLogger =
      v match {
        case vx : LogWith[x] => LogWith[x](vx.appendMessage(vx.logs, message), vx.appendMessage)
      }
  • Question 3: Run log("Hello World", loggerStr), log("Hello World", loggerList) and log("Hello World", loggerStdout)

    Solution
    scala> log("Hello World", loggerStr)
    res0: UnknownLogger = LogWith(Hello World,$$Lambda$988/1455466014@421ead7e)
    
    scala> log("Hello World", loggerList)
    res1: UnknownLogger = LogWith(List(Hello World),$$Lambda$989/1705282731@655621fd)
    
    scala> log("Hello World", loggerStdout)
    Hello World
    res2: UnknownLogger = LogWith((),$$Lambda$990/1835105031@340c57e0)
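Since log returns another UnknownLogger, several messages can be appended in a row by folding over them. In the sketch below, logAll and appendToString are hypothetical helpers. logAll only relies on log from the solution above and, like log, never learns the concrete type X:

```scala
sealed trait UnknownLogger
final case class LogWith[X](logs: X, appendMessage: (X, String) => X) extends UnknownLogger

def log(message: String, v: UnknownLogger): UnknownLogger =
  v match {
    case vx: LogWith[x] => LogWith[x](vx.appendMessage(vx.logs, message), vx.appendMessage)
  }

// Appends every message, in order, to whatever logger it is given
def logAll(messages: List[String], v: UnknownLogger): UnknownLogger =
  messages.foldLeft(v)((logger, message) => log(message, logger))

val appendToString: (String, String) => String = (logs, message) => logs ++ message
val loggerStr: UnknownLogger = LogWith[String]("", appendToString)

println(logAll(List("Hello", " ", "World"), loggerStr))
```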
Note for readers comfortable with Scala

Once again, using all of Scala's nice syntactic features, a production-ready version of the code above is:

sealed trait UnknownLogger {
  type LogsType
  val logs : LogsType
  def appendMessage(presentLogs: LogsType, message: String): LogsType
}

object UnknownLogger {

  final case class LogWith[X](logs : X, appendMessage_ : (X, String) => X) extends UnknownLogger {
    type LogsType = X
    def appendMessage(presentLogs: LogsType, message: String): LogsType =
      appendMessage_(presentLogs, message)
  }

  def apply[X](logs: X, appendMessage: (X, String) => X): UnknownLogger =
    LogWith[X](logs, appendMessage)
}
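With this encoding, log no longer needs a pattern match: the path-dependent type v.LogsType names the hidden type without revealing it. A possible implementation, reusing the definitions above, with a hypothetical usage on a List-based logger:

```scala
sealed trait UnknownLogger {
  type LogsType
  val logs : LogsType
  def appendMessage(presentLogs: LogsType, message: String): LogsType
}

object UnknownLogger {
  final case class LogWith[X](logs : X, appendMessage_ : (X, String) => X) extends UnknownLogger {
    type LogsType = X
    def appendMessage(presentLogs: LogsType, message: String): LogsType =
      appendMessage_(presentLogs, message)
  }

  def apply[X](logs: X, appendMessage: (X, String) => X): UnknownLogger =
    LogWith[X](logs, appendMessage)
}

// `v.LogsType` is the hidden type: values of it are passed around without knowing what it is
def log(message: String, v: UnknownLogger): UnknownLogger =
  UnknownLogger[v.LogsType](
    v.appendMessage(v.logs, message),
    (presentLogs: v.LogsType, msg: String) => v.appendMessage(presentLogs, msg)
  )

val loggerList: UnknownLogger =
  UnknownLogger[List[String]](Nil, (logs, message) => message :: logs)

println(log("World", log("Hello", loggerList)).logs) // List(World, Hello)
```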

Intermediate Conclusion

GADTs are really just this: simple sealed traits with a few case objects (possibly none) and a few final case classes (possibly none as well!). In the following parts, we will explore some of the major use cases of GADTs.

Simple and Useful Use Cases: Relations on Types

A simple but very useful ability of GADTs is expressing relations on types, such as:

  • Is the type A equal to the type B?
  • Is the type A a subtype of B?

Note that, by definition, every type A is a subtype of itself (i.e. A <: A), just as every integer x is less than or equal to itself (x ≤ x).

Use Case: Type Equality Witness

sealed trait EqT[A,B]
final case class Evidence[X]() extends EqT[X,X]
  • Question 1: Give a value of type EqT[Int, Int]

    Solution
    scala> Evidence[Int]() : EqT[Int, Int]
    res0: EqT[Int,Int] = Evidence()
  • Question 2: Give a value of type EqT[String, Int]

    Solution

    The class Evidence is the only concrete subclass of the trait EqT, and it is impossible to create another one because EqT is sealed. So a value v : EqT[A,B] can only be an instance of Evidence[X] for some type X, which is itself of type EqT[X,X]. Hence there is no way to obtain a value of type EqT[String, Int].

  • Question 3: Let A and B be two (unknown) types. If I give you a value of type EqT[A,B], what can you deduce about A and B?

    Solution

    If I give you a value v : EqT[A,B], then you know that v is an instance of Evidence[X] for some (unknown) type X. Indeed the class Evidence is the one and only concrete subclass of the sealed trait EqT. In fact, Evidence[X] is a subtype of EqT[X,X]. So v : EqT[X,X]. The types EqT[A,B] and EqT[X,X] have no value in common if A ≠ X or B ≠ X, so A = X and B = X. Hence A = B. QED.

Note for readers comfortable with Scala

In production, it is convenient to define EqT as follows, which is of course equivalent:

sealed trait EqT[A,B]
object EqT {
  final case class Evidence[X]() extends EqT[X,X]

  implicit def evidence[X] : EqT[X,X] = Evidence[X]()

  def apply[A,B](implicit ev: EqT[A,B]): ev.type = ev
}

Switching from One Equal Type to the Other

If A and B are actually the same type, then List[A] is also the same type as List[B], Option[A] is also the same type as Option[B], etc. More generally, for any F[_], F[A] is also the same type as F[B].

  • Question 4: Write the function def coerce[F[_],A,B](eqT: EqT[A,B])(fa: F[A]): F[B].

    Solution
    def coerce[F[_],A,B](eqT: EqT[A,B])(fa: F[A]): F[B] =
      eqT match {
        case _ : Evidence[x] => fa
      }

The Scala standard library already defines a class, named =:=[A,B] (its name really is =:=), representing type equality. I strongly recommend having a look at its documentation. Fortunately, for readability, Scala lets us write A =:= B for the type =:=[A,B].

Given two types A and B, having an instance (i.e. object) of A =:= B proves that A and B are actually the same type, just like EqT[A,B]. As a reminder, A =:= B is only syntactic sugar for the type =:=[A,B].

Instances of A =:= B can be created using the function (<:<).refl[X]: X =:= X. The “symbol” <:< is indeed a valid object name.

  • Question 5: Using the function coerce above, write the function def toScalaEq[A,B](eqT: EqT[A,B]): A =:= B.

    Hint
    def toScalaEq[A,B](eqT: EqT[A,B]): A =:= B = {
      /* Find a definition for:
            - the type constructor `F`
            - the value `fa : F[A]`
      */
      type F[X] = ???
      val fa : F[A] = ???
    
      /* Such that this call: */
      coerce[F,A,B](eqT)(fa) // has type `F[B]`
    }
    Solution
    def toScalaEq[A,B](eqT: EqT[A,B]): A =:= B = {
      type F[X] = A =:= X
      val fa: F[A] = (<:<).refl[A]
      coerce[F,A,B](eqT)(fa)
    }
  • Question 6: Using the method substituteCo[F[_]](ff: F[A]): F[B] of objects of the class A =:= B, write the function def fromScalaEq[A,B](scala: A =:= B): EqT[A,B].

    Hint
    def fromScalaEq[A,B](scala: A =:= B): EqT[A,B] = {
      /* Find a definition for:
            - the type constructor `F`
            - the value `fa : F[A]`
      */
      type F[X] = ???
      val fa : F[A] = ???
    
      /* Such that this call: */
      scala.substituteCo[F](fa) // has type `F[B]`
    }
    Solution
    def fromScalaEq[A,B](scala: A =:= B): EqT[A,B] = {
      type F[X] = EqT[A,X]
      val fa: F[A] = Evidence[A]()
      scala.substituteCo[F](fa)
    }
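Both conversions can be exercised together. The sketch below gathers the definitions of this section, renaming the parameter scala of fromScalaEq to ev for readability, and converts witnesses in both directions. The values at the end are only illustrations:

```scala
sealed trait EqT[A, B]
final case class Evidence[X]() extends EqT[X, X]

def coerce[F[_], A, B](eqT: EqT[A, B])(fa: F[A]): F[B] =
  eqT match {
    case _: Evidence[x] => fa
  }

def toScalaEq[A, B](eqT: EqT[A, B]): A =:= B = {
  type F[X] = A =:= X
  val fa: F[A] = (<:<).refl[A]
  coerce[F, A, B](eqT)(fa)
}

def fromScalaEq[A, B](ev: A =:= B): EqT[A, B] = {
  type F[X] = EqT[A, X]
  val fa: F[A] = Evidence[A]()
  ev.substituteCo[F](fa)
}

// `A =:= B` is also a function A => B, so the converted witness can cast values
val intToInt: Int =:= Int = toScalaEq(Evidence[Int]())
println(intToInt(42)) // 42

// From the standard library witness back to EqT
println(fromScalaEq(implicitly[String =:= String])) // Evidence()
```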

Use Case: Subtyping Witness

In this section, we want to create the types SubTypeOf[A,B] whose values prove that the type A is a subtype of B (i.e. A <: B). A similar, but different, class is already defined in the Scala standard library: the class <:<[A,B], most often written A <:< B. Since this section is dedicated to implementing a variant of this class, please do not use <:<[A,B] to implement SubTypeOf.

  • Question 1: Using only upper bounds (i.e. A <: B) or lower bounds (i.e. A >: B) and no variance annotations (i.e. [+A] and [-A]), create the trait SubTypeOf[A,B] (and everything needed) such that:

    There exists a value of type SubTypeOf[A,B] if and only if A is a subtype of B (i.e. A <: B).

    Remember that, by definition, a type A is a subtype of itself (i.e. A <: A).

    Remember, do not use the class <:<[A,B].

    Solution
    sealed trait SubTypeOf[A,B]
    final case class SubTypeEvidence[A <: B, B]() extends SubTypeOf[A,B]
    Note for readers comfortable with Scala

    In production, it is convenient to define SubTypeOf in the following equivalent way:

    sealed trait SubTypeOf[A,B]
    object SubTypeOf {
      final case class Evidence[A <: B, B]() extends SubTypeOf[A,B]
    
      implicit def evidence[A <: B, B]: SubTypeOf[A,B] = Evidence[A,B]()
    
      def apply[A,B](implicit ev: SubTypeOf[A,B]): ev.type = ev
    }

Use Case: Avoiding scalac Error Messages About Unsatisfied Bounds

In this example, we want to model the diet of some animals. Let us start by defining the type Food and a few of its subtypes:

trait Food
class Vegetable extends Food
class Fruit extends Food

and now the class representing animals eating food of type A (i.e. Vegetable, Fruit, etc.):

class AnimalEating[A <: Food]

val elephant : AnimalEating[Vegetable] =
  new AnimalEating[Vegetable]

Let us define a function like so many found in Functional Programming, and pass elephant to it as argument:

def dummy[F[_],A](fa: F[A]): String = "Ok!"
scala> dummy[List, Int](List(1,2,3))
res0: String = Ok!

scala> dummy[Option, Boolean](Some(true))
res1: String = Ok!

scala> dummy[AnimalEating, Vegetable](elephant)
            ^
       error: kinds of the type arguments (AnimalEating,Vegetable)
       do not conform to the expected kinds of the type parameters (type F,type A).

       AnimalEating's type parameters do not match type F's expected parameters:
       type A's bounds <: Food are stricter than
       type _'s declared bounds >: Nothing <: Any
  • Question 1: Why does scalac complain?

    Solution

    The function dummy requires its argument F, which is a type constructor like List, Option, Future, etc., to accept any type as argument, so that F[A] can always be written for any type A. But AnimalEating requires its argument to be a subtype of Food. So AnimalEating cannot be used as the argument F of dummy.

The problem is that, by defining class AnimalEating[A <: Food], we required A to be a subtype of Food. So Scala, just like Java, forbids us from giving AnimalEating, as argument A, anything other than a subtype of Food (Food itself included):

scala> type T1 = AnimalEating[Int]
                 ^
       error: type arguments [Int] do not conform
       to class AnimalEating's type parameter bounds [A <: Food]

scala> type T2 = AnimalEating[Food]
defined type alias T2

Nous sommes face à un dilemme: afin d’utiliser la fonction dummy, que nous tenons beaucoup à utiliser parce c’est une fonction très utile, il nous faut supprimer la contrainte A <: Food de la définition class AnimalEating[A <: Food]. Mais nous tenons également au fait que les animaux ne mangent que de la nourriture (Food) et pas des entiers, ni des booléens et encore moins des chaînes de caractères!

  • Question 2: How can you adapt the definition of AnimalEating so that:

    • It is possible to call dummy with elephant as argument! We want:

      scala> dummy[AnimalEating, Vegetable](elephant)
      res0: String = Ok!
    • If A is not a subtype of Food (Food itself included), then it must be impossible to create an instance of AnimalEating[A].

    • The class AnimalEating must remain an open class (i.e. neither sealed nor final)! Anyone must still be able, at any time, to freely create subclasses of AnimalEating. Of course, these subclasses must respect the two constraints above.

      Hint

      In Scala, Nothing is a type with no values. Can you create a value of type (Nothing, Int)? Why?

      Solution

      If, to create an instance of AnimalEating[A], we force every method creating values to take an extra parameter of type SubTypeOf[A, Food], then it will only be possible to create an instance of AnimalEating[A] when A is a subtype of Food:

      class AnimalEating[A](ev : SubTypeOf[A, Food])
      
      val elephant : AnimalEating[Vegetable] =
        new AnimalEating[Vegetable](SubTypeEvidence[Vegetable, Food])

      To create a value of type AnimalEating[A], we need to call AnimalEating's constructor. To call this constructor, we must provide ev : SubTypeOf[A, Food].

      We can now call the function dummy on elephant:

      scala> dummy[AnimalEating, Vegetable](elephant)
      res0: String = Ok!

      In practice, with implicits, the compiler can supply the parameter ev : SubTypeOf[A, Food] by itself.

      Note that the type AnimalEating[Int] can now be written, but you will never be able to create a value of that type.
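Here is a self-contained sketch of this solution in which the evidence is supplied implicitly. The definitions of Food, Vegetable, SubTypeOf and SubTypeEvidence below are assumptions standing in for the article's own, which are not shown in this section. The implicit method evidence is also hypothetical. AnimalEating has no type bound, so it fits dummy, yet constructing one still requires proof that A <: Food.

```scala
object AnimalEatingSketch {
  // Assumed food hierarchy, standing in for the article's Food and Vegetable
  class Food
  class Vegetable extends Food

  // Assumed encoding: a SubTypeOf[A, B] value can only be built when A <: B
  sealed trait SubTypeOf[A, B]
  final case class SubTypeEvidence[A <: B, B]() extends SubTypeOf[A, B]

  object SubTypeOf {
    // Hypothetical helper: lets the compiler supply the evidence whenever A <: B holds
    implicit def evidence[A <: B, B]: SubTypeOf[A, B] = SubTypeEvidence[A, B]()
  }

  // No bound on A: AnimalEating[Int] is a valid type, but it has no values
  class AnimalEating[A](implicit val ev: SubTypeOf[A, Food])

  def dummy[F[_], A](fa: F[A]): String = "Ok!"

  // Compiles: SubTypeOf[Vegetable, Food] is found implicitly
  val elephant: AnimalEating[Vegetable] = new AnimalEating[Vegetable]

  // Does not compile: no SubTypeOf[Int, Food] can be found
  // val impossible = new AnimalEating[Int]
}
```

Uncommenting the last line should make scalac report that no implicit SubTypeOf[Int, Food] is available.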

Use Case: Giving the Right Data to the Right Chart

This use case deals with ways to guarantee, at compile time, that only values of the right type can be passed to a given function. The chosen example is the design of a charting library. To keep the example simple, we will assume our library only implements two kinds of chart: pie charts and XY charts. In Scala, this is written as the following enumeration:

sealed trait ChartType
case object PieChart extends ChartType
case object XYChart extends ChartType

Of course, pie charts and XY charts rely on different kinds of data sets. Again, to keep things simple, we will assume the two data types are PieData for pie charts and XYData for XY charts:

class PieData
class XYData

A pie chart (PieChart) only displays PieData, whereas an XY chart (XYChart) only displays XYData. Here is the drawing function draw, greatly simplified:

def draw[A](chartType: ChartType)(data: A): Unit =
  chartType match {
    case PieChart =>
      val pieData = data.asInstanceOf[PieData]
      // Do the work of drawing pieData
      ()
    case XYChart =>
      val xyData = data.asInstanceOf[XYData]
      // Do the work of drawing xyData
      ()
  }

This function relies on the assumption that users will only call draw on the right kind of data. When chartType is PieChart, the function assumes, via data.asInstanceOf[PieData], that data is actually of type PieData. And when chartType is XYChart, it assumes that data is actually of type XYData.

The problem is that these assumptions rely on users and/or developers always making sure they hold. But nothing prevents anyone from calling draw on a pie chart (PieChart) with data of type XYData (or the other way around), making the system crash miserably in production!

scala> draw(PieChart)(new XYData)
java.lang.ClassCastException: XYData cannot be cast to PieData
  at .draw(<pastie>:11)
  ... 28 elided

As developers, we know mistakes happen! We want a way to prevent these annoying bugs from reaching production! We want to enforce, at compile time, that only two scenarios are possible:

  • When draw is called with chartType == PieChart: the argument data must be of type PieData.
  • When draw is called with chartType == XYChart: the argument data must be of type XYData.

Remember, both constraints must be checked at compile time!

  • Question 1: Adapt the definitions of ChartType, PieChart, XYChart and draw so that:

    • Any scenario other than the two above makes compilation fail with a type error.

    • ChartType must still be a sealed trait. But it is allowed to take type parameters (i.e. generics).

    • PieChart and XYChart must still be case objects, and they must still extend ChartType.

    • The declarations of ChartType, PieChart and XYChart must not have any body (i.e. there must be no braces { ... } in their declarations).

    Hint

    The code looks like this:

    sealed trait ChartType[/*PUT THE GENERICS HERE*/]
    case object PieChart extends ChartType[/*There is something to write here*/]
    case object XYChart extends ChartType[/*There is something to write here too*/]
    
    def draw[A](chartType: ChartType[/*Write something here*/])(data: A): Unit =
      chartType match {
        case PieChart =>
          val pieData : PieData = data
          ()
        case XYChart =>
          val xyData: XYData = data
          ()
      }
    Solution
    sealed trait ChartType[A]
    case object PieChart extends ChartType[PieData]
    case object XYChart extends ChartType[XYData]
    
    def draw[A](chartType: ChartType[A])(data: A): Unit =
      chartType match {
        case PieChart =>
          val pieData : PieData = data
          ()
        case XYChart =>
          val xyData: XYData = data
          ()
      }

You can now sleep soundly, knowing that your production code will not crash here because of a non-conforming input 😉

More Advanced Use Cases

Now that you have seen what GADTs are and how to use them in everyday code, you are ready for the more substantial use cases below. There are three of them, and each illustrates a different way of using the power of GADTs. The first deals with expressing effects, a technique used very widely by every popular IO monad and by algebraic effects. Don't worry if you don't know what these are: this section explains them. The second focuses on guaranteeing properties in the type system, illustrated by adapting functional programming techniques to the constraints imposed by databases. The third offers a simpler way to work with implicits.

Use Case: Effects!

What is called an effect is sometimes just an interface declaring a few functions with no implementation. For example, we can define the trait below. Note that none of these functions has an implementation.

trait ExampleEffectSig {
  def echo[A](value: A): A
  def randomInt : Int
  def ignore[A](value: A): Unit
}

The implementations of these interfaces (traits) are given elsewhere, and there can be many of them! This is useful when you want to switch implementations easily:

object ExampleEffectImpl extends ExampleEffectSig {
  def echo[A](value: A): A = value
  def randomInt : Int = scala.util.Random.nextInt()
  def ignore[A](value: A): Unit = ()
}

An equivalent way to define ExampleEffectSig is via a sealed trait with a few final case classes (possibly none!) and/or a few case objects (possibly none!):

sealed trait ExampleEffect[A]
final case class  Echo[A](value: A) extends ExampleEffect[A]
final case object RandomInt extends ExampleEffect[Int]
final case class  Ignore[A](value: A) extends ExampleEffect[Unit]

Again, we have declarations that provide no implementation! Again, their implementations can be provided elsewhere, and there can be many of them:

def runExampleEffect[A](effect: ExampleEffect[A]): A =
  effect match {
    case Echo(value) => value
    case RandomInt   => scala.util.Random.nextInt()
    case Ignore(_)   => ()
  }
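Because these declarations carry no behaviour, nothing stops us from writing a second interpreter alongside runExampleEffect. The sketch below uses a hypothetical describe function that turns an effect into a textual description instead of executing it. Every branch returns a String, so it does not depend on the compiler refining A in each branch. The declarations are repeated, without the redundant final on the case object, so the block compiles on its own.

```scala
object ExampleEffectDescribe {
  sealed trait ExampleEffect[A]
  final case class Echo[A](value: A) extends ExampleEffect[A]
  case object RandomInt extends ExampleEffect[Int]
  final case class Ignore[A](value: A) extends ExampleEffect[Unit]

  // Hypothetical second interpreter: describes the effect instead of running it
  def describe[A](effect: ExampleEffect[A]): String =
    effect match {
      case Echo(value)   => s"echo($value)"
      case RandomInt     => "randomInt"
      case Ignore(value) => s"ignore($value)"
    }
}
```

The same ExampleEffect value can now be passed either to runExampleEffect, to perform it, or to describe, for example to log or test it.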

Let's take a more realistic effect, along with one of its possible implementations:

trait EffectSig {
  def currentTimeMillis: Long
  def printLn(msg: String): Unit
  def mesure[X,A](fun: X => A, arg: X): A
}

object EffectImpl extends EffectSig {
  def currentTimeMillis: Long =
    System.currentTimeMillis()

  def printLn(msg: String): Unit =
    println(msg)

  def mesure[X,A](fun: X => A, arg: X): A = {
    val t0 = System.currentTimeMillis()
    val r  = fun(arg)
    val t1 = System.currentTimeMillis()
    println(s"Took ${t1 - t0} milli-seconds")
    r
  }
}
  • Question 1: ExampleEffect is the equivalent of ExampleEffectSig, defined as a sealed trait with a few final case classes and case objects. Write the equivalent of EffectSig in the same way, and call this trait Effect.

    Solution
    sealed trait Effect[A]
    final case object CurrentTimeMillis extends Effect[Long]
    final case class  PrintLn(msg: String) extends Effect[Unit]
    final case class  Mesure[X,A](fun: X => A, arg: X) extends Effect[A]
  • Question 2: Write the function def run[A](effect: Effect[A]): A that reproduces the implementation of EffectImpl, just as runExampleEffect reproduces that of ExampleEffectImpl.

    Solution
    def run[A](effect: Effect[A]): A =
      effect match {
        case CurrentTimeMillis =>
          System.currentTimeMillis()
    
        case PrintLn(msg) =>
          println(msg)
    
        case Mesure(fun, arg) =>
          val t0 = System.currentTimeMillis()
          val r  = fun(arg)
          val t1 = System.currentTimeMillis()
          println(s"Took ${t1 - t0} milli-seconds")
          r
      }

The type Effect[A] declares interesting effects (CurrentTimeMillis, PrintLn and Mesure), but to be really useful, it must be possible to chain these effects! To do so, we want the following two functions:

  • def pure[A](value: A): Effect[A]
  • def flatMap[X,A](fx: Effect[X], f: X => Effect[A]): Effect[A]

Again, we are not interested in their implementations. All we want, for now, is to declare these two operations the same way we declared CurrentTimeMillis, PrintLn and Mesure.

  • Question 3: Add two final case classes, Pure and FlatMap, to Effect[A] to declare these two operations.

    Solution
    sealed trait Effect[A]
    final case object CurrentTimeMillis extends Effect[Long]
    final case class  PrintLn(msg: String) extends Effect[Unit]
    final case class  Mesure[X,A](fun: X => A, arg: X) extends Effect[A]
    final case class  Pure[A](value: A) extends Effect[A]
    final case class  FlatMap[X,A](fx: Effect[X], f: X => Effect[A]) extends Effect[A]
  • Question 4: Adapt the function run to handle these two new cases.

    Solution
    def run[A](effect: Effect[A]): A =
      effect match {
        case CurrentTimeMillis =>
          System.currentTimeMillis()
    
        case PrintLn(msg) =>
          println(msg)
    
        case Mesure(fun, arg) =>
          val t0 = System.currentTimeMillis()
          val r  = fun(arg)
          val t1 = System.currentTimeMillis()
          println(s"Took ${t1 - t0} milli-seconds")
          r
    
        case Pure(a) =>
          a
    
        case FlatMap(fx, f) =>
          val x  = run(fx)
          val fa : Effect[A] = f(x)
          run(fa)
      }
  • Question 5: Add the following two methods to the trait Effect[A] to get:

    sealed trait Effect[A] {
      final def flatMap[B](f: A => Effect[B]): Effect[B] = FlatMap(this, f)
      final def map[B](f: A => B): Effect[B] = flatMap[B]((a:A) => Pure(f(a)))
    }

    Then run the following code to check that it works:

    val effect1: Effect[Unit] =
      for {
        t0 <- CurrentTimeMillis
        _  <- PrintLn(s"The current time is $t0")
      } yield ()
    
    run(effect1)
    Solution
    sealed trait Effect[A] {
      final def flatMap[B](f: A => Effect[B]): Effect[B] = FlatMap(this, f)
      final def map[B](f: A => B): Effect[B] = flatMap[B]((a:A) => Pure(f(a)))
    }
    final case object CurrentTimeMillis extends Effect[Long]
    final case class  PrintLn(msg: String) extends Effect[Unit]
    final case class  Mesure[X,A](fun: X => A, arg: X) extends Effect[A]
    final case class  Pure[A](value: A) extends Effect[A]
    final case class  FlatMap[X,A](fx: Effect[X], f: X => Effect[A]) extends Effect[A]
    
    def run[A](effect: Effect[A]): A =
      effect match {
        case CurrentTimeMillis =>
          System.currentTimeMillis()
    
        case PrintLn(msg) =>
          println(msg)
    
        case Mesure(fun, arg) =>
          val t0 = System.currentTimeMillis()
          val r  = fun(arg)
          val t1 = System.currentTimeMillis()
          println(s"Took ${t1 - t0} milli-seconds")
          r
    
        case Pure(a) =>
          a
    
        case FlatMap(fx, f) =>
          val x  = run(fx)
          val fa : Effect[A] = f(x)
          run(fa)
      }
    
    val effect1: Effect[Unit] =
      for {
        t0 <- CurrentTimeMillis
        _  <- PrintLn(s"The current time is $t0")
      } yield ()

    Running run(effect1) prints:

    scala> run(effect1)
    The current time is 1569773175010

Congratulations! You have just written your first IO monad! The sealed trait Effect[A] goes by many scientific names: you can call it an algebraic effect, a free monad, an IO, etc. But in the end, it is just a plain, ordinary sealed trait for which we defined a few final case classes and case objects (CurrentTimeMillis, PrintLn, Mesure, Pure and FlatMap) to represent the functions we wanted, without providing their implementations. You can call them virtual methods if you like. What really matters is that we have separated the definition of these functions from their implementations. Remember that a trait is just an interface, after all.

Use Case: Making Sure Types Are Supported by the Database

Databases are great. We can store tables, documents, key/value pairs, graphs, etc. in them. But any database, unfortunately, supports only a limited number of types. Pick any database you like: I am sure I can find types it does not support.

In this section, we will look at data structures and code that do not work for all types, but only for some! This use case is not limited to databases: it concerns every programming interface that supports only a limited number of types (the vast majority of programming interfaces). How can we make sure these constraints are respected? How can we adapt the techniques we love so that they work under these constraints? That is what this section is about.

We will consider a fictional database that only supports the following types:

  1. String
  2. Double
  3. (A,B) where A and B are also types supported by the database.

This means that the values stored in the database (in tables, key/value pairs, etc.) must follow the rules above. It can store "Hello World" because it is a String, which is a type supported by the database by rule 1. For the same reason, it can store 5.2 because it is a Double, but it cannot store the integer 5 because it is an Int. It can store ("Hello World", 5.2) thanks to rule 3, as well as (("Hello World", 5.2), 8.9), again thanks to rule 3.

  • Question 1: Define the type DBType[A] such that:

    There exists a value of type DBType[A] if and only if A is a type supported by the database.

    Solution

    The simple version is:

    sealed trait DBType[A]
    final case object DBString extends DBType[String]
    final case object DBDouble extends DBType[Double]
    final case class  DBPair[A,B](first: DBType[A], second: DBType[B]) extends DBType[(A,B)]
    Note for readers comfortable with Scala

    Using all of Scala's nice syntactic features, the production-ready version of the code above is:

    sealed trait DBType[A]
    object DBType {
      final case object DBString extends DBType[String]
      final case object DBDouble extends DBType[Double]
      final case class DBPair[A,B](first: DBType[A], second: DBType[B]) extends DBType[(A,B)]
    
      implicit val dbString : DBType[String] =
        DBString
    
      implicit val dbDouble : DBType[Double] =
        DBDouble
    
      implicit def dbPair[A,B](implicit first: DBType[A], second: DBType[B]): DBType[(A,B)] =
        DBPair(first, second)
    
      def apply[A](implicit ev: DBType[A]): ev.type = ev
    }

Using DBType, we can pair a value of type A with a value of type DBType[A], thereby providing evidence that the type A is supported by the database:

final case class DBValue[A](value: A)(implicit val dbType: DBType[A])

Note that the parameter dbType does not need to be implicit at all! What matters is that, to create a value of type DBValue[A], we must provide a value of type DBType[A], which forces A to be a type supported by the database.
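Below is a condensed, self-contained sketch of DBType and DBValue. It shows the compiler assembling, by implicit resolution, the evidence for a nested pair. The render helper is hypothetical. It is added only to show that a DBType value mirrors the structure of the type it vouches for.

```scala
object DBValueSketch {
  sealed trait DBType[A]
  object DBType {
    case object DBString extends DBType[String]
    case object DBDouble extends DBType[Double]
    final case class DBPair[A, B](first: DBType[A], second: DBType[B]) extends DBType[(A, B)]

    implicit val dbString: DBType[String] = DBString
    implicit val dbDouble: DBType[Double] = DBDouble
    implicit def dbPair[A, B](implicit first: DBType[A], second: DBType[B]): DBType[(A, B)] =
      DBPair(first, second)

    def apply[A](implicit ev: DBType[A]): ev.type = ev
  }

  final case class DBValue[A](value: A)(implicit val dbType: DBType[A])

  // Hypothetical helper: prints the supported type a DBType value stands for
  def render[A](t: DBType[A]): String =
    t match {
      case DBType.DBString     => "String"
      case DBType.DBDouble     => "Double"
      case DBType.DBPair(a, b) => s"(${render(a)}, ${render(b)})"
    }

  // The evidence for ((String, Double), Double) is built by the compiler (rule 3, twice)
  val nested: DBValue[((String, Double), Double)] = DBValue((("Hello World", 5.2), 8.9))

  // Does not compile: there is no DBType[Int]
  // val five = DBValue(5)
}
```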

A functor is, informally and roughly, a type constructor F (List, Option and DBValue are all type constructors) for which an instance of the following trait can be provided:

trait Functor[F[_]] {
  def map[A,B](fa: F[A])(f: A => B): F[B]
}

map(fa)(f) applies the function f to every value of type A contained in fa. For example:

implicit object OptionFunctor extends Functor[Option] {
  def map[A,B](fa: Option[A])(f: A => B): Option[B] =
    fa match {
      case Some(a) => Some(f(a))
      case None => None
    }
}
  • Question 2: Write an instance of Functor[DBValue].

    Solution

    It is actually impossible! If we tried to compile the following code:

    object DBValueFunctor extends Functor[DBValue] {
      def map[A,B](fa: DBValue[A])(f: A => B): DBValue[B] =
        DBValue[B](f(fa.value))
    }

    Scala would complain: could not find implicit value for parameter dbType: DBType[B]. Indeed, B can be any type, for example Boolean, and booleans are not a type supported by the database: they are neither strings, nor floating-point numbers, nor pairs of supported types.

    Suppose we could define a Functor instance for DBValue (i.e. we could define a map function for DBValue). Then we could write:

    val dbValueString  : DBValue[String]  = DBValue("A")(DBString)
    val dbValueBoolean : DBValue[Boolean] = dbValueString.map(_ => true)
    val dbTypeBoooean  : DBType[Boolean]  = dbValueBoolean.dbType

    We would obtain a value (dbTypeBoooean) of type DBType[Boolean], which would mean that the type Boolean is supported by the database. But it is not! Yet, by definition:

    There exists a value of type DBType[A] if and only if A is a type supported by the database.

    So it is impossible to obtain a value of type DBType[Boolean], and therefore impossible to write a map function for DBValue. Hence there is no way to define a Functor instance for DBValue. QED.

A Generalized Functor is very similar to a regular Functor. The difference is that its map function does not have to work for every pair of types A and B: it may only work for some specific types A and B:

trait GenFunctor[P[_],F[_]] {
  def map[A,B](fa: F[A])(f: A => B)(implicit evA: P[A], evB: P[B]): F[B]
}

For example, Set (more precisely TreeSet) is not a functor! Indeed, there is no way to write a map function that works for any type B, because an ordering on B is required. But if we restrict map to those types B that have an ordering, then we can write:

import scala.collection.immutable._
object TreeSetFunctor extends GenFunctor[Ordering, TreeSet] {
  def map[A,B](fa: TreeSet[A])(f: A => B)(implicit evA: Ordering[A], evB: Ordering[B]): TreeSet[B] =
    TreeSet.empty[B](evB) ++ fa.toSeq.map(f)
}
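Here is a usage sketch for TreeSetFunctor, with GenFunctor and TreeSetFunctor repeated so the block stands alone. The result is rebuilt with the Ordering[B] received as evB. As a consequence, elements that f maps to the same value collapse into one, and passing a different ordering explicitly changes the order of the result:

```scala
import scala.collection.immutable.TreeSet

object TreeSetSketch {
  trait GenFunctor[P[_], F[_]] {
    def map[A, B](fa: F[A])(f: A => B)(implicit evA: P[A], evB: P[B]): F[B]
  }

  object TreeSetFunctor extends GenFunctor[Ordering, TreeSet] {
    def map[A, B](fa: TreeSet[A])(f: A => B)(implicit evA: Ordering[A], evB: Ordering[B]): TreeSet[B] =
      // Rebuild a TreeSet sorted by evB: duplicates produced by f collapse
      TreeSet.empty[B](evB) ++ fa.toSeq.map(f)
  }

  // 1 and 3 both map to 1, so only two elements remain
  val parities: TreeSet[Int] = TreeSetFunctor.map(TreeSet(1, 2, 3))(_ % 2)

  // Descending order on the target type, passed explicitly as evB
  val labels: TreeSet[String] =
    TreeSetFunctor.map(TreeSet(1, 2, 3))(_.toString)(Ordering.Int, Ordering.String.reverse)
}
```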
  • Question 3: Write an instance of GenFunctor[DBType, DBValue].

    Solution
    object DBValueGenFunctor extends GenFunctor[DBType, DBValue] {
      def map[A,B](fa: DBValue[A])(f: A => B)(implicit evA: DBType[A], evB: DBType[B]): DBValue[B] =
        DBValue[B](f(fa.value))(evB)
    }
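Here is a usage sketch for DBValueGenFunctor, again with condensed copies of the definitions so it compiles on its own. Mapping a DBValue[String] to a pair (String, Double) type-checks because DBType[(String, Double)] can be derived. Mapping to a Boolean, as in Question 2, is rejected because no DBType[Boolean] exists:

```scala
object DBGenFunctorSketch {
  sealed trait DBType[A]
  object DBType {
    case object DBString extends DBType[String]
    case object DBDouble extends DBType[Double]
    final case class DBPair[A, B](first: DBType[A], second: DBType[B]) extends DBType[(A, B)]

    implicit val dbString: DBType[String] = DBString
    implicit val dbDouble: DBType[Double] = DBDouble
    implicit def dbPair[A, B](implicit first: DBType[A], second: DBType[B]): DBType[(A, B)] =
      DBPair(first, second)
  }

  final case class DBValue[A](value: A)(implicit val dbType: DBType[A])

  trait GenFunctor[P[_], F[_]] {
    def map[A, B](fa: F[A])(f: A => B)(implicit evA: P[A], evB: P[B]): F[B]
  }

  object DBValueGenFunctor extends GenFunctor[DBType, DBValue] {
    def map[A, B](fa: DBValue[A])(f: A => B)(implicit evA: DBType[A], evB: DBType[B]): DBValue[B] =
      DBValue[B](f(fa.value))(evB)
  }

  val hello: DBValue[String] = DBValue("hello")

  // String => (String, Double): DBType[(String, Double)] is derived implicitly
  val withLength: DBValue[(String, Double)] =
    DBValueGenFunctor.map(hello)(s => (s, s.length.toDouble))

  // Does not compile: there is no DBType[Boolean]
  // val flag = DBValueGenFunctor.map(hello)(_ => true)
}
```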

What we did here with Functor can be done with many data structures and programming techniques. It is often possible to restrict the range of types a data structure or a type class can operate on by adding an extra parameter, such as ev : DBType[A], to constructors and methods.

Use Case: Simplifying Implicits

This use case is one of the most interesting, but unfortunately not one of the simplest. It shows how GADTs can be used to simplify the creation of implicit values.

Lists whose elements can be of different types are called heterogeneous lists. In Scala, they are usually defined almost like regular lists:

final case class HNil() // The empty list
final case class HCons[Head,Tail](head: Head, tail: Tail) // The operation: `head :: tail`

val empty : HNil =
  HNil()

val oneTrueToto : HCons[Int, HCons[Boolean, HCons[String, HNil]]] =
  HCons(1, HCons(true, HCons("toto", HNil())))

val falseTrueFive: HCons[Boolean, HCons[Boolean, HCons[Int, HNil]]] =
  HCons(false, HCons(true, HCons(5, HNil())))

As you can see, there is nothing really special about these lists. We want to define orderings on heterogeneous lists. An ordering is a way to compare two values (of the same type!): they can be equal, or one can be strictly smaller than the other. An ordering on the type A can be defined in Scala as an instance of Order[A], defined as follows:

trait Order[A] {
  // true if and only if a1 < a2
  def lesserThan(a1: A, a2: A): Boolean

  /* a1 and a2 are equal if and only if
     neither of them is strictly smaller than the other
  */
  final def areEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2) && !lesserThan(a2, a1)

  // a1 > a2 if and only if a2 < a1
  final def greaterThan(a1: A, a2: A): Boolean = lesserThan(a2, a1)

  final def lesserThanOrEqual(a1: A, a2: A): Boolean = !lesserThan(a2, a1)

  final def greaterThanOrEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2)
}

object Order {
  def apply[A](implicit ev: Order[A]): ev.type = ev

  def make[A](lg_ : (A,A) => Boolean): Order[A] =
    new Order[A] {
      def lesserThan(a1: A, a2: A): Boolean = lg_(a1,a2)
    }
}

implicit val orderInt: Order[Int]       = Order.make[Int](_ < _)
implicit val orderString: Order[String] = Order.make[String](_ < _)
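Here is a quick check that every comparison of Order is derived from lesserThan alone. The definitions are repeated so the block stands alone. Placing the Int and String instances in the companion object is an assumption of this sketch, not what the text above does: it puts them in implicit scope wherever Order is used.

```scala
object OrderSketch {
  trait Order[A] {
    // true if and only if a1 < a2
    def lesserThan(a1: A, a2: A): Boolean

    // Everything below is derived from lesserThan
    final def areEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2) && !lesserThan(a2, a1)
    final def greaterThan(a1: A, a2: A): Boolean = lesserThan(a2, a1)
    final def lesserThanOrEqual(a1: A, a2: A): Boolean = !lesserThan(a2, a1)
    final def greaterThanOrEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2)
  }

  object Order {
    def apply[A](implicit ev: Order[A]): ev.type = ev

    def make[A](lt: (A, A) => Boolean): Order[A] =
      new Order[A] {
        def lesserThan(a1: A, a2: A): Boolean = lt(a1, a2)
      }

    // Assumption of this sketch: instances live in the companion object
    implicit val orderInt: Order[Int] = make[Int](_ < _)
    implicit val orderString: Order[String] = make[String](_ < _)
  }

  // All four comparisons hold
  val intChecks: List[Boolean] = List(
    Order[Int].lesserThan(1, 2),
    Order[Int].areEqual(3, 3),
    Order[Int].greaterThan(2, 1),
    Order[Int].lesserThanOrEqual(2, 2)
  )

  // "b" > "a", so "b" <= "a" is false
  val stringCheck: Boolean = Order[String].lesserThanOrEqual("b", "a")
}
```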

Remember that we will only compare lists of the same type:

  • Lists of type HNil will only be compared to other lists of type HNil.
  • Lists of type HCons[H,T] will only be compared to other lists of type HCons[H,T].

Comparing lists of type HNil is trivial because there is one and only one value of type HNil (the empty list HNil()). But there are many ways to compare lists of type HCons[H,T]. Here are two possible orderings (there are many others!):

  • The lexicographic order (i.e. dictionary order: from left to right)

    HCons(h1,t1) < HCons(h2,t2) if and only if h1 < h2 or (h1 == h2 and t1 < t2 in lexicographic order).

    sealed trait Lex[A] {
      val order : Order[A]
    }
    
    object Lex {
      def apply[A](implicit ev: Lex[A]): ev.type = ev
    
      implicit val lexHNil: Lex[HNil] =
        new Lex[HNil] {
          val order = Order.make[HNil]((_,_) => false)
        }
    
      implicit def lexHCons[Head,Tail](implicit
          orderHead: Order[Head],
          lexTail: Lex[Tail]
        ): Lex[HCons[Head, Tail]] =
        new Lex[HCons[Head, Tail]] {
          val orderTail: Order[Tail] = lexTail.order
    
          val order = Order.make[HCons[Head, Tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderHead.lesserThan(h1,h2) || (orderHead.areEqual(h1,h2) && orderTail.lesserThan(t1,t2))
          }
        }
    }
  • The reverse lexicographic order, which is the backwards version of the lexicographic order (i.e. from right to left)

    HCons(h1,t1) < HCons(h2,t2) if and only if (t1 < t2 in reverse lexicographic order) or (t1 == t2 and h1 < h2).

    sealed trait RevLex[A] {
      val order : Order[A]
    }
    
    object RevLex {
      def apply[A](implicit ev: RevLex[A]): ev.type = ev
    
      implicit val revLexHNil: RevLex[HNil] =
        new RevLex[HNil] {
          val order = Order.make[HNil]((_,_) => false)
        }
    
      implicit def revLexHCons[Head,Tail](implicit
          orderHead: Order[Head],
          revLexTail: RevLex[Tail]
        ): RevLex[HCons[Head, Tail]] =
        new RevLex[HCons[Head, Tail]] {
          val orderTail: Order[Tail] = revLexTail.order
    
          val order = Order.make[HCons[Head, Tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderTail.lesserThan(t1,t2) || (orderTail.areEqual(t1,t2) && orderHead.lesserThan(h1,h2))
          }
        }
    }
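To see that Lex and RevLex really are different orderings, here is a condensed, self-contained copy of both, together with two lists on which they disagree. Compared with the code above, the parameter names are shortened and the pattern-matching lambdas are replaced by field accessors. Neither change alters the logic.

```scala
object LexSketch {
  final case class HNil()
  final case class HCons[Head, Tail](head: Head, tail: Tail)

  trait Order[A] {
    def lesserThan(a1: A, a2: A): Boolean
    final def areEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2) && !lesserThan(a2, a1)
  }
  object Order {
    def make[A](lt: (A, A) => Boolean): Order[A] =
      new Order[A] { def lesserThan(a1: A, a2: A): Boolean = lt(a1, a2) }
    implicit val orderInt: Order[Int] = make[Int](_ < _)
    implicit val orderString: Order[String] = make[String](_ < _)
  }

  // Left to right: compare heads first, then tails
  sealed trait Lex[A] { val order: Order[A] }
  object Lex {
    def apply[A](implicit ev: Lex[A]): ev.type = ev
    implicit val lexHNil: Lex[HNil] =
      new Lex[HNil] { val order: Order[HNil] = Order.make[HNil]((_, _) => false) }
    implicit def lexHCons[H, T](implicit oh: Order[H], lt: Lex[T]): Lex[HCons[H, T]] =
      new Lex[HCons[H, T]] {
        val order: Order[HCons[H, T]] = Order.make[HCons[H, T]] { (l1, l2) =>
          oh.lesserThan(l1.head, l2.head) ||
            (oh.areEqual(l1.head, l2.head) && lt.order.lesserThan(l1.tail, l2.tail))
        }
      }
  }

  // Right to left: compare tails first, then heads
  sealed trait RevLex[A] { val order: Order[A] }
  object RevLex {
    def apply[A](implicit ev: RevLex[A]): ev.type = ev
    implicit val revLexHNil: RevLex[HNil] =
      new RevLex[HNil] { val order: Order[HNil] = Order.make[HNil]((_, _) => false) }
    implicit def revLexHCons[H, T](implicit oh: Order[H], rt: RevLex[T]): RevLex[HCons[H, T]] =
      new RevLex[HCons[H, T]] {
        val order: Order[HCons[H, T]] = Order.make[HCons[H, T]] { (l1, l2) =>
          rt.order.lesserThan(l1.tail, l2.tail) ||
            (rt.order.areEqual(l1.tail, l2.tail) && oh.lesserThan(l1.head, l2.head))
        }
      }
  }

  type L = HCons[Int, HCons[String, HNil]]
  val a: L = HCons(1, HCons("b", HNil()))
  val b: L = HCons(2, HCons("a", HNil()))
}
```

Lex compares the heads first (1 < 2), so a comes before b. RevLex compares the tails first ("a" < "b"), so b comes before a.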

As said above, more orderings can be defined:

  • Question 1: The Alternate order is defined by:

    HCons(h1,t1) < HCons(h2,t2) if and only if h1 < h2 or (h1 == h2 and t1 > t2 in Alternate order).

    Following the approach used for Lex and RevLex, implement the Alternate order.

    Solution
    sealed trait Alternate[A] {
      val order : Order[A]
    }
    
    object Alternate {
      def apply[A](implicit ev: Alternate[A]): ev.type = ev
    
      implicit val alternateHNil: Alternate[HNil] =
        new Alternate[HNil] {
          val order = Order.make[HNil]((_,_) => false)
        }
    
      implicit def alternateHCons[Head,Tail](implicit
          orderHead: Order[Head],
          alternateTail: Alternate[Tail]
        ): Alternate[HCons[Head, Tail]] =
        new Alternate[HCons[Head, Tail]] {
          val orderTail: Order[Tail] = alternateTail.order
    
          val order = Order.make[HCons[Head, Tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderHead.lesserThan(h1,h2) || (orderHead.areEqual(h1,h2) && orderTail.greaterThan(t1,t2))
          }
        }
    }

There are many ways to define a valid ordering on heterogeneous lists! Creating a type class like Lex, RevLex and Alternate for every ordering we want is tiresome and error-prone. We can do much better … with a GADT 😉

sealed trait HListOrder[A]
object HListOrder {
  final case object HNilOrder extends HListOrder[HNil]

  final case class HConsOrder[Head,Tail](
      orderHead: Order[Head],
      hlistOrderTail: HListOrder[Tail]
    ) extends HListOrder[HCons[Head,Tail]]

  // Implicit definitions

  implicit val hnilOrder : HListOrder[HNil] =
    HNilOrder

  implicit def hconsOrder[Head,Tail](implicit
      orderHead: Order[Head],
      hlistOrderTail: HListOrder[Tail]
    ): HListOrder[HCons[Head,Tail]] =
    HConsOrder(orderHead, hlistOrderTail)

  def apply[A](implicit ev: HListOrder[A]): ev.type = ev
}

Note that the definition of these implicits is pure boilerplate. Their only purpose is to pass their arguments to the corresponding constructor (i.e. final case class or case object): hnilOrder to HNilOrder (0 arguments) and hconsOrder to HConsOrder (2 arguments).

  • Question 2: Write a function def lex[A](implicit v : HListOrder[A]): Order[A] that returns the lexicographic order from a value of type HListOrder[A].

    Solution
    def lex[A](implicit v : HListOrder[A]): Order[A] =
      v match {
        case HListOrder.HNilOrder =>
          Order.make[HNil]((_,_) => false)
    
        case hc : HListOrder.HConsOrder[head,tail] =>
          val orderHead: Order[head] = hc.orderHead
          val orderTail: Order[tail] = lex(hc.hlistOrderTail)
    
          Order.make[HCons[head, tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderHead.lesserThan(h1,h2) || (orderHead.areEqual(h1,h2) && orderTail.lesserThan(t1,t2))
          }
      }
  • Question 3: Write a function def revLex[A](implicit v : HListOrder[A]): Order[A] that returns the reverse lexicographic order from a value of type HListOrder[A].

    Solution
    def revLex[A](implicit v : HListOrder[A]): Order[A] =
      v match {
        case HListOrder.HNilOrder =>
          Order.make[HNil]((_,_) => false)
    
        case hc : HListOrder.HConsOrder[head,tail] =>
          val orderHead: Order[head] = hc.orderHead
          val orderTail: Order[tail] = revLex(hc.hlistOrderTail)
    
          Order.make[HCons[head, tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderTail.lesserThan(t1,t2) || (orderTail.areEqual(t1,t2) && orderHead.lesserThan(h1,h2))
          }
      }

This approach has many advantages. The initial approach had to perform an implicit search for each ordering, whereas the GADT approach only needs to perform it once! Since implicit resolution is a costly operation, doing less of it means shorter compile times. Reading the code of the functions lex and revLex is also simpler than understanding how implicit resolution works for the traits Lex and RevLex. Moreover, they are just functions: you can use anything you can program inside them to build the instances of Order[A].
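To illustrate the last point, here is a sketch of the Alternate order from Question 1, rewritten in the style of lex and revLex as a plain recursive function over HListOrder. The definitions are condensed copies so the block compiles on its own. There is one small deviation. The HNil branch builds an Order[A] directly, which is valid because all HNil values are equal, so the sketch does not rely on the compiler refining A from the case object pattern.

```scala
object HListOrderSketch {
  final case class HNil()
  final case class HCons[Head, Tail](head: Head, tail: Tail)

  trait Order[A] {
    def lesserThan(a1: A, a2: A): Boolean
    final def areEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2) && !lesserThan(a2, a1)
    final def greaterThan(a1: A, a2: A): Boolean = lesserThan(a2, a1)
  }
  object Order {
    def make[A](lt: (A, A) => Boolean): Order[A] =
      new Order[A] { def lesserThan(a1: A, a2: A): Boolean = lt(a1, a2) }
    implicit val orderInt: Order[Int] = make[Int](_ < _)
  }

  sealed trait HListOrder[A]
  object HListOrder {
    case object HNilOrder extends HListOrder[HNil]
    final case class HConsOrder[Head, Tail](orderHead: Order[Head], hlistOrderTail: HListOrder[Tail])
      extends HListOrder[HCons[Head, Tail]]

    implicit val hnilOrder: HListOrder[HNil] = HNilOrder
    implicit def hconsOrder[Head, Tail](implicit oh: Order[Head], ot: HListOrder[Tail]): HListOrder[HCons[Head, Tail]] =
      HConsOrder(oh, ot)
  }

  // The Alternate order of Question 1, as a plain recursive function
  def alternate[A](implicit v: HListOrder[A]): Order[A] =
    v match {
      case HListOrder.HNilOrder =>
        // All HNil values are equal, so "never lesser" is a valid order here
        Order.make[A]((_, _) => false)
      case hc: HListOrder.HConsOrder[head, tail] =>
        val orderHead: Order[head] = hc.orderHead
        val orderTail: Order[tail] = alternate(hc.hlistOrderTail)
        Order.make[HCons[head, tail]] { (l1, l2) =>
          orderHead.lesserThan(l1.head, l2.head) ||
            (orderHead.areEqual(l1.head, l2.head) && orderTail.greaterThan(l1.tail, l2.tail))
        }
    }

  type P = HCons[Int, HCons[Int, HNil]]
  val oneOne: P = HCons(1, HCons(1, HNil()))
  val oneTwo: P = HCons(1, HCons(2, HNil()))
}
```

When the heads are equal, Alternate compares the tails in reverse. So HCons(1, HCons(2, HNil())) comes before HCons(1, HCons(1, HNil())).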

Conclusion

Not so trivial, is it? 😉 In fact, much of the complexity you have just faced comes from the sad fact that techniques for reasoning about types and values are almost never taught in programming courses. What you find simple now (Web APIs, streaming, databases, etc.) would probably have terrified the young programmer you were at your first “Hello World!”. You probably did not learn everything you know about programming in three hours, so do not expect techniques for reasoning about programs to be magically simpler.

This workshop aimed to inspire you and to open your mind to this new universe of possibilities. If you find these use cases interesting, then take the time to understand the techniques.

Have fun and take good care of yourselves ❤️

GADTs By Use Cases

This workshop will be presented at ScalaIO 2019, in Lyon (France), on Tuesday the 29th of October at 9am.

Welcome. This session will introduce you to a very powerful programming tool. Whereas most introductions start by presenting its theoretical foundations in a very formal way, we chose to present it through short examples and practical use cases.

This workshop is made of three parts. The last one presents three of the most valuable use cases: the really big, powerful ones. But do not go there unprepared! It is the last part for a reason: it relies heavily on lessons you will learn in the previous parts. Start with First Contact: through the simplest examples, it shows you the core ideas. Its goal is to open your mind to ways of using types and data you may never have imagined possible. Then go to Easy Useful Use Cases: Relations on Types for your first real-life challenge. Then you are ready for More Advanced Use Cases.

Be sure to read the README: it contains valuable tips to ease your journey.

Acknowledgements

We would like to thank Laure Juglaret for having reviewed this presentation many times and for her valuable remarks and corrections.

README

In this presentation we will assume that:

  • null does not exist!
  • runtime reflection does not exist! (i.e. isInstanceOf, getClass, etc.)

This presentation considers that these features do not exist at all.

Using these features will never lead to a valid answer.

This presentation expects you to have access to something where you can easily write, compile and run Scala code. The best way to do so is to open a REPL session. If you have Scala installed on your system, you can easily start one from the command line by executing scala:

system-command-line# scala
Welcome to Scala 2.13.1 (OpenJDK 64-Bit Server VM, Java 1.8.0_222).
Type in expressions for evaluation. Or try :help.

scala>

Remember that the REPL command :paste lets you paste code and :reset clears the environment.

If you don’t have Scala installed, you can use the online REPL Scastie: https://scastie.scala-lang.org/.

Stretching

This section is a brief reminder of some definitions and properties about values and types.

Values and Types?

Values are actual pieces of data your program manipulates, like the integer 5, the boolean true, the string "Hello World!", the function (x: Double) => x / 7.5, the list List(1,2,3), etc. It is convenient to classify values into groups. These groups are called types. For example:

  • Int is the group of integer values, i.e. values like 1, -7, 19, etc.
  • Boolean is the group containing exactly the values true and false (no more, no less!).
  • String is the group whose values are "Hello World!", "", "I ❤️ GADTs", etc.
  • Double => Double is the group whose values are functions taking any Double as argument and returning some Double.

To indicate that the value v belongs to the type (i.e. group of values) T, we write v : T. In Scala, testing if a value v belongs to a type T is very simple: just type v : T in the REPL:

scala> 5 : Int
res7: Int = 5

If Scala accepts it, then v belongs to T. If Scala complains, it most probably does not:

scala> 5 : String
       ^
       error: type mismatch;
        found   : Int(5)
        required: String

How many types?

Let’s now create some types and some of their values (when possible!).

class OneType
  • Question 1: How many types does the line class OneType define?

    Solution

    As the name suggests, class OneType defines only one type which is named OneType.

Let’s now consider:

class OneTypeForEvery[A]
  • Question 2: How many types does the line class OneTypeForEvery[A] define?

    Solution

    As the name suggests, every concrete type A gives rise to a distinct type OneTypeForEvery[A].

    For example, a list of integers is neither a list of booleans, nor a list of strings, nor a list of functions, nor … It means the types List[Int], List[Boolean], List[Int => Int], etc., are all distinct types.

    The line class OneTypeForEvery[A] defines a distinct type for every concrete type A. There are infinitely many concrete types A, so there are infinitely many distinct types OneTypeForEvery[A].

  • Question 3: Give a value that belongs to both OneTypeForEvery[Int] and OneTypeForEvery[Boolean].

    Remember that null does not exist!

    Solution

    This is actually impossible. Every concrete type A gives rise to a distinct type OneTypeForEvery[A] that has no values in common with the other types OneTypeForEvery[B] for B ≠ A.

How many values?

Considering the following type:

final abstract class NoValueForThisType
  • Question 1: Give a value belonging to the type NoValueForThisType. How many values belong to NoValueForThisType?

    Hint
    • What is a final class? How does it differ from a normal non-final class?
    • What is an abstract class? How does it differ from a concrete class?
    Solution

    The class NoValueForThisType is declared abstract. It is therefore forbidden to create a direct instance of this class:

    scala> new NoValueForThisType
            ^
            error: class NoValueForThisType is abstract; cannot be instantiated

    The only way to create an instance of an abstract class is creating a concrete sub-class. But the keyword final forbids creating such sub-classes:

    scala> class ConcreteSubClass extends NoValueForThisType
                                            ^
            error: illegal inheritance from final class NoValueForThisType

    There is no way to create an instance of NoValueForThisType.

Let’s take another example:

sealed trait ExactlyOneValue
case object TheOnlyValue extends ExactlyOneValue
  • Question 2: Give a value belonging to the type ExactlyOneValue.

    Solution

    By definition, TheOnlyValue is a value of type ExactlyOneValue.

  • Question 3: How many values belong to ExactlyOneValue?

    Solution

    Just like above, ExactlyOneValue, being a trait, is abstract. And since it is sealed, it cannot be extended outside of its defining file. So TheOnlyValue is the only value of type ExactlyOneValue.

First Contact

This part presents the core ideas. There are actually only two ideas! What you will find here are stripped-down examples illustrating each of these ideas.

Use Case: Evidence of some property

Let’s define a simple sealed trait:

sealed trait ATrait[A]
case object AValue extends ATrait[Char]
  • Question 1: Give a value of type ATrait[Char].

    Solution

    By definition, AValue is a value of type ATrait[Char].

  • Question 2: Give a value of type ATrait[Double].

    Solution

    There is no way to have an instance of type ATrait[Double]. There is actually no way to have an instance of type ATrait[B] for B ≠ Char, because the only possible value is AValue which is of type ATrait[Char].

  • Question 3: What can you conclude about a type A if you have a value ev of type ATrait[A] (i.e. ev: ATrait[A])?

    Solution

    The only possible value is AValue, so ev == AValue. Furthermore AValue is of type ATrait[Char] so A = Char.

  • Question 4: In the REPL, enter the following code. Does it compile? Why?

    def f[A](x: A, ev: ATrait[A]): Char =
      x
  • Question 5: Now try pattern matching on ev: ATrait[A]:

    def f[A](x: A, ev: ATrait[A]): Char =
      ev match {
        case AValue => x
      }

    Is the pattern matching exhaustive?

    Solution

    The pattern matching is exhaustive because the only possible actual value for ev is AValue. Furthermore AValue is of type ATrait[Char] which means ev : ATrait[Char] because ev == AValue. So A = Char and x : Char.

  • Question 6: Call f with x = 'w' : Char.

    Solution
    scala> f[Char]('w', AValue)
    res0: Char = w
  • Question 7: Call f with x = 5.2 : Double.

    Solution

    This is impossible because it would require a value ev : ATrait[Double], which does not exist!

    scala> f[Double](5.2, AValue)
                          ^
              error: type mismatch;
                found   : AValue.type
                required: ATrait[Double]
Remark for people fluent in Scala

Using all the nice features of Scala, the production-ready version of the code above is:

sealed trait IsChar[A]
object IsChar {
  implicit case object Evidence extends IsChar[Char]

  def apply[A](implicit evidence: IsChar[A]): IsChar[A] =
    evidence
}

def f[A: IsChar](x: A): Char =
  IsChar[A] match {
    case IsChar.Evidence => x
  }
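To see the evidence pattern at work, here is a small usage sketch of the IsChar version above. It is wrapped in an object named IsCharDemo (a name chosen only for this example) so it can be pasted into the REPL in one go with :paste. The compiler finds IsChar.Evidence on its own, so callers only pass x.

```scala
object IsCharDemo {
  sealed trait IsChar[A]
  object IsChar {
    // The only value that can ever exist: it proves that A = Char
    implicit case object Evidence extends IsChar[Char]

    // Summons the implicit evidence for A, when there is one
    def apply[A](implicit evidence: IsChar[A]): IsChar[A] =
      evidence
  }

  // Matching on the evidence tells the compiler that A = Char,
  // so x can be returned as a Char
  def f[A: IsChar](x: A): Char =
    IsChar[A] match {
      case IsChar.Evidence => x
    }

  // f('w') compiles: the compiler supplies IsChar.Evidence.
  // f(5.2) would not compile: no IsChar[Double] can be found.
}
```

Calling IsCharDemo.f('w') returns 'w', while IsCharDemo.f(5.2) is rejected at compile time.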

Use Case: The only thing I know is it exists

What would you do if you wanted your codebase to log messages, but also wanted to be sure that it does not rely on any implementation details of the logger, so that you can change the logger’s implementation without risking breaking the codebase?

Take the following logger type UnknownLogger:

sealed trait UnknownLogger
final case class LogWith[X](logs : X, appendMessage: (X, String) => X) extends UnknownLogger

The first logger we will create stores the logs in a String:

val loggerStr : UnknownLogger =
  LogWith[String]("", (logs: String, message: String) => logs ++ message)

The second logger stores them in a List[String]:

val loggerList : UnknownLogger =
  LogWith[List[String]](Nil, (logs: List[String], message: String) => message :: logs)

The third logger directly prints the messages to the standard output:

val loggerStdout : UnknownLogger =
  LogWith[Unit]((), (logs: Unit, message: String) => println(message))

Note that these three loggers all have the same type (i.e. UnknownLogger) but they store the logs using different types X (String, List[String] and Unit).

  • Question 1: Let v be a value of type UnknownLogger. Clearly v has to be an instance of the class LogWith[X] for some type X. What can you say about the type X? Can you figure out which type X actually is?

    Remember that we refuse to use runtime-reflection! (i.e. isInstanceOf, getClass, etc)

    Solution

    We know almost nothing about X. The only thing we know is that there exists at least one value (v.logs) of type X. Apart from that, X can be any type.

    Not knowing the actual type X is very useful to guarantee that the code using v : UnknownLogger will never rely on the nature of X. If the code knew X was String, for example, it could perform some operations we want to forbid, like reversing the string, keeping only its first n characters, etc. By hiding the real type X, we force our codebase not to depend on what X is, but to use the provided v.appendMessage instead. So changing the real implementation of the logger won’t break any code.

  • Question 2: Write the function def log(message: String, v: UnknownLogger): UnknownLogger that uses v.appendMessage to append message to v.logs and returns a new UnknownLogger containing the new logs.

    Remember that in Scala, the pattern case ac : AClass[t] => (see below) can be used in match/case instead of the pattern case AClass(v) =>:

    final case class AClass[A](value : A)
    
    def f[A](v: AClass[A]): A =
      v match {
        case ac : AClass[t] =>
          // The variable `t` is a type variable
          // The type `t` is equal to `A`
          val r : t = ac.value
          r
      }

    Its main benefit is that it introduces the type variable t. Type variables work like normal pattern variables, except that they represent types instead of values. Having t enables us to help the compiler by giving explicit types (like just above, saying r is of type t).

    Solution
    def log(message: String, v: UnknownLogger): UnknownLogger =
      v match {
        case vx : LogWith[x] => LogWith[x](vx.appendMessage(vx.logs, message), vx.appendMessage)
      }
  • Question 3: Execute log("Hello World", loggerStr), log("Hello World", loggerList) and log("Hello World", loggerStdout).

    Solution
    scala> log("Hello World", loggerStr)
    res0: UnknownLogger = LogWith(Hello World,$$Lambda$988/1455466014@421ead7e)
    
    scala> log("Hello World", loggerList)
    res1: UnknownLogger = LogWith(List(Hello World),$$Lambda$989/1705282731@655621fd)
    
    scala> log("Hello World", loggerStdout)
    Hello World
    res2: UnknownLogger = LogWith((),$$Lambda$990/1835105031@340c57e0)
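Since log returns a new UnknownLogger, it composes with ordinary folds. The sketch below adds a hypothetical helper, logAll, which is not part of the workshop: it appends a list of messages in order and never needs to know which X hides inside the logger. Everything is wrapped in an illustrative object LoggerDemo.

```scala
object LoggerDemo {
  sealed trait UnknownLogger
  final case class LogWith[X](logs: X, appendMessage: (X, String) => X) extends UnknownLogger

  // Same definition as the solution above
  def log(message: String, v: UnknownLogger): UnknownLogger =
    v match {
      case vx: LogWith[x] => LogWith[x](vx.appendMessage(vx.logs, message), vx.appendMessage)
    }

  // Hypothetical helper: append every message, in order, to any logger
  def logAll(messages: List[String], v: UnknownLogger): UnknownLogger =
    messages.foldLeft(v)((logger, message) => log(message, logger))

  val loggerList: UnknownLogger =
    LogWith[List[String]](Nil, (logs: List[String], message: String) => message :: logs)

  // loggerList prepends, so the most recent message ends up first
  val result: UnknownLogger = logAll(List("first", "second"), loggerList)
}
```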
Remark for people fluent in Scala

Once again, using all the nice features of Scala, the production-ready version of the code above is:

sealed trait UnknownLogger {
  type LogsType
  val logs : LogsType
  def appendMessage(presentLogs: LogsType, message: String): LogsType
}

object UnknownLogger {

  final case class LogWith[X](logs : X, appendMessage_ : (X, String) => X) extends UnknownLogger {
    type LogsType = X
    def appendMessage(presentLogs: LogsType, message: String): LogsType =
      appendMessage_(presentLogs, message)
  }

  def apply[X](logs: X, appendMessage: (X, String) => X): UnknownLogger =
    LogWith[X](logs, appendMessage)
}
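With this type-member encoding, log does not even need pattern matching: the path-dependent type v.LogsType names the hidden type directly. Here is a minimal sketch, assuming the definitions above are pasted together inside an illustrative object TypeMemberLoggerDemo so that the trait and its companion compile as one unit.

```scala
object TypeMemberLoggerDemo {
  sealed trait UnknownLogger {
    type LogsType
    val logs: LogsType
    def appendMessage(presentLogs: LogsType, message: String): LogsType
  }

  object UnknownLogger {
    final case class LogWith[X](logs: X, appendMessage_ : (X, String) => X) extends UnknownLogger {
      type LogsType = X
      def appendMessage(presentLogs: LogsType, message: String): LogsType =
        appendMessage_(presentLogs, message)
    }

    def apply[X](logs: X, appendMessage: (X, String) => X): UnknownLogger =
      LogWith[X](logs, appendMessage)
  }

  // v.LogsType refers to the hidden type of this particular logger,
  // so no match/case is needed to append a message
  def log(message: String, v: UnknownLogger): UnknownLogger =
    UnknownLogger[v.LogsType](
      v.appendMessage(v.logs, message),
      (logs: v.LogsType, msg: String) => v.appendMessage(logs, msg)
    )
}
```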

Intermediary Conclusion

GADTs are actually only this: a simple sealed trait with some case objects (possibly none) and some final case classes (possibly none too!). In the following parts, we will explore some major use cases of GADTs.

Easy Useful Use Cases: Relations on Types

One easy, but very useful, benefit of GADTs is expressing relations about types, such as:

  • Is type A equal to type B?
  • Is type A a sub-type of B?

Note that, by definition, a type A is a sub-type of itself (i.e. A <: A), very much like an integer x is less than or equal to itself (x ≤ x).

Use Case: Witnessing Type Equality

sealed trait EqT[A,B]
final case class Evidence[X]() extends EqT[X,X]
  • Question 1: Give a value of type EqT[Int, Int]

    Solution
    scala> Evidence[Int]() : EqT[Int, Int]
    res0: EqT[Int,Int] = Evidence()
  • Question 2: Give a value of type EqT[String, Int]

    Solution

    The class Evidence is the only concrete sub-class of trait EqT and we cannot create another one because EqT is sealed. So any value v : EqT[A,B] has to be an instance of Evidence[X] for some type X, which is of type EqT[X,X]. Thus there is no way to get a value of type EqT[String, Int].

  • Question 3: Given two (unknown) types A and B, what can you conclude if I give you a value of type EqT[A,B]?

    Solution

    If I give you a value v : EqT[A,B], then you know that v is an instance of Evidence[X] for some (unknown) type X because the class Evidence is the only concrete sub-class of the sealed trait EqT. Actually Evidence[X] is a sub-type of EqT[X,X]. Thus v : EqT[X,X]. Types EqT[A,B] and EqT[X,X] have no value in common if A ≠ X or B ≠ X, so A = X and B = X. Thus A = B.

Remark for people fluent in Scala

In production, it is convenient to write the following equivalent code:

sealed trait EqT[A,B]
object EqT {
  final case class Evidence[X]() extends EqT[X,X]

  implicit def evidence[X] : EqT[X,X] = Evidence[X]()

  def apply[A,B](implicit ev: EqT[A,B]): ev.type = ev
}

Switching between equal types

If A and B are actually the same type, then List[A] is also the same type as List[B], Option[A] is also the same type as Option[B], etc. More generally, for any F[_], F[A] is also the same type as F[B].

  • Question 4: Write the function def coerce[F[_],A,B](eqT: EqT[A,B])(fa: F[A]): F[B].

    Solution
    def coerce[F[_],A,B](eqT: EqT[A,B])(fa: F[A]): F[B] =
      eqT match {
        case _ : Evidence[x] => fa
      }
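Here is a short sketch of coerce in use, assuming the EqT and Evidence definitions above (repeated inside an illustrative object EqTDemo). The helper cast and the type alias Id are hypothetical additions: choosing the identity type constructor for F turns coerce into a plain conversion from A to B.

```scala
object EqTDemo {
  sealed trait EqT[A, B]
  final case class Evidence[X]() extends EqT[X, X]

  // Same as the solution above: matching on Evidence[x] tells the
  // compiler that A = x = B, so F[A] can be returned as F[B]
  def coerce[F[_], A, B](eqT: EqT[A, B])(fa: F[A]): F[B] =
    eqT match {
      case _: Evidence[x] => fa
    }

  // Hypothetical identity type constructor: Id[X] is just X
  type Id[X] = X

  // With F = Id, coerce converts an A directly into a B
  def cast[A, B](eqT: EqT[A, B])(a: A): B =
    coerce[Id, A, B](eqT)(a)
}
```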

The Scala standard library already defines a class named =:=[A,B] (yes, its name really is =:=) representing type equality. You are strongly encouraged to have a quick look at its documentation. Thankfully, Scala lets you write A =:= B instead of =:=[A,B].

Given two types A and B, having an instance (i.e. object) of A =:= B means that A and B are actually the same type, just like with EqT[A,B]. Remember that A =:= B is just syntactic sugar for =:=[A,B].

Instances of A =:= B can be created with the function (<:<).refl[X]: X =:= X. The “symbol” <:< is indeed a valid name for an object.

  • Question 5: Using the function coerce above, write the function def toScalaEq[A,B](eqT: EqT[A,B]): A =:= B.

    Hint
    def toScalaEq[A,B](eqT: EqT[A,B]): A =:= B = {
      /* Find a definition for:
            - the type constructor `F`
            - the value `fa : F[A]`
      */
      type F[X] = ???
      val fa : F[A] = ???
    
      /* Such that this call: */
      coerce[F,A,B](eqT)(fa) // is of type `F[B]`
    }
    Solution
    def toScalaEq[A,B](eqT: EqT[A,B]): A =:= B = {
      type F[X] = A =:= X
      val fa: F[A] = (<:<).refl[A]
      coerce[F,A,B](eqT)(fa)
    }
  • Question 6: Using the method substituteCo[F[_]](ff: F[A]): F[B] of objects of class A =:= B (see the =:= API documentation), write the function def fromScalaEq[A,B](scala: A =:= B): EqT[A,B].

    Hint
    def fromScalaEq[A,B](scala: A =:= B): EqT[A,B] = {
      /* Find a definition for:
            - the type constructor `F`
            - the value `fa : F[A]`
      */
      type F[X] = ???
      val fa : F[A] = ???
    
      /* Such that this call: */
      scala.substituteCo[F](fa) // is of type `F[B]`
    }
    Solution
    def fromScalaEq[A,B](scala: A =:= B): EqT[A,B] = {
      type F[X] = EqT[A,X]
      val fa: F[A] = Evidence[A]()
      scala.substituteCo[F](fa)
    }
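The two conversions can be checked against each other. This sketch repeats the definitions above inside an illustrative object ScalaEqDemo, converts an EqT to =:= and back, and relies on the fact that, in Scala 2.13's standard library, A =:= B also extends the function type A => B.

```scala
object ScalaEqDemo {
  sealed trait EqT[A, B]
  final case class Evidence[X]() extends EqT[X, X]

  def coerce[F[_], A, B](eqT: EqT[A, B])(fa: F[A]): F[B] =
    eqT match {
      case _: Evidence[x] => fa
    }

  def toScalaEq[A, B](eqT: EqT[A, B]): A =:= B = {
    type F[X] = A =:= X
    val fa: F[A] = (<:<).refl[A]
    coerce[F, A, B](eqT)(fa)
  }

  def fromScalaEq[A, B](scala: A =:= B): EqT[A, B] = {
    type F[X] = EqT[A, X]
    val fa: F[A] = Evidence[A]()
    scala.substituteCo[F](fa)
  }

  // Round trip: EqT -> =:= -> EqT
  val roundTrip: EqT[Int, Int] = fromScalaEq(toScalaEq(Evidence[Int]()))

  // A value of type Int =:= Int can be applied like a function Int => Int
  val asFunction: Int = toScalaEq(Evidence[Int]()).apply(42)
}
```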

Use Case: Witnessing Sub Typing

In this section, we want to create types SubTypeOf[A,B] whose values prove that the type A is a sub-type of B (i.e. A <: B). A similar but different class already exists in the Scala library. It is named <:<[A,B], which is often written A <:< B; see its documentation in the Scala API reference. Because this section is about implementing a variant of this class, please do not use <:<[A,B] to implement SubTypeOf.

  • Question 1: Using only upper bounds (i.e. A <: B) or lower bounds (i.e. A >: B) and no variance annotation (i.e. [+A] and [-A]), create the trait SubTypeOf[A,B] (and all that is necessary) such that:

    There exists a value of type SubTypeOf[A,B] if and only if A is a sub-type of B (i.e. A <: B).

    Remember that, by definition, a type A is a sub-type of itself (i.e. A <: A).

    Remember that you should not use the class <:<[A,B].

    Solution
    sealed trait SubTypeOf[A,B]
    final case class SubTypeEvidence[A <: B, B]() extends SubTypeOf[A,B]
    Remark for people fluent in Scala

    In production, it is convenient to write the following equivalent code:

    sealed trait SubTypeOf[A,B]
    object SubTypeOf {
      final case class Evidence[A <: B, B]() extends SubTypeOf[A,B]
    
      implicit def evidence[A <: B, B]: SubTypeOf[A,B] = Evidence[A,B]()
    
      def apply[A,B](implicit ev: SubTypeOf[A,B]): ev.type = ev
    }

Use Case: Avoiding annoying scalac error messages about bounds not respected

In this example, we want to model the diet of some animals. We start by defining the Food type and some of its subtypes:

trait Food
class Vegetable extends Food
class Fruit extends Food

and then the class representing animals eating food of type A (e.g. Vegetable, Fruit, etc.):

class AnimalEating[A <: Food]

val elephant : AnimalEating[Vegetable] =
  new AnimalEating[Vegetable]

Let’s define a function like the many found in functional programming, and apply it to elephant:

def dummy[F[_],A](fa: F[A]): String = "Ok!"
scala> dummy[List, Int](List(1,2,3))
res0: String = Ok!

scala> dummy[Option, Boolean](Some(true))
res1: String = Ok!

scala> dummy[AnimalEating, Vegetable](elephant)
            ^
       error: kinds of the type arguments (AnimalEating,Vegetable)
       do not conform to the expected kinds of the type parameters (type F,type A).

       AnimalEating's type parameters do not match type F's expected parameters:
       type A's bounds <: Food are stricter than
       type _'s declared bounds >: Nothing <: Any
  • Question 1: Why does scalac complain?

    Solution

    The function dummy requires its argument F, which is a type constructor like List, Option, Future, etc., to accept any type as argument, so that we can write F[A] for any type A. In contrast, AnimalEating requires its argument to be a sub-type of Food. Thus AnimalEating cannot be used as dummy’s argument F.

The problem is that, when we defined class AnimalEating[A <: Food], we gave the restriction that A <: Food. So Scala, like Java, forbids us to give AnimalEating anything but a sub-type of Food (including Food itself):

scala> type T1 = AnimalEating[Int]
                 ^
       error: type arguments [Int] do not conform
       to class AnimalEating's type parameter bounds [A <: Food]

scala> type T2 = AnimalEating[Food]
defined type alias T2

We face a dilemma: to use the function dummy, which we really want to use because it’s a very nice function, we need to remove the constraint A <: Food from the definition class AnimalEating[A <: Food]. But we still want to say that animals eat food, not integers, booleans or strings!

  • Question 2: How can you adapt the definition of AnimalEating so that:

    • We can call dummy on elephant! We want:

      scala> dummy[AnimalEating, Vegetable](elephant)
      res0: String = Ok!
    • If A is not a sub-type of Food (including Food itself), then it is impossible to create an instance of AnimalEating[A].

    • The class AnimalEating must remain an open class (i.e. neither sealed nor final)! Anyone should be able, at any time, to freely create sub-classes of AnimalEating. Obviously, those sub-classes must satisfy the constraints above.

      Hint

      In Scala, Nothing is a type having no value. Can you create a value of type (Nothing, Int)? Why?

      Solution

      If, to create an instance of AnimalEating[A], we force every creation method to take an extra parameter of type SubTypeOf[A, Food], then it will only be possible to create an instance of AnimalEating[A] when A is a sub-type of Food:

      class AnimalEating[A](ev : SubTypeOf[A, Food])
      
      val elephant : AnimalEating[Vegetable] =
        new AnimalEating[Vegetable](SubTypeEvidence[Vegetable, Food]())

      To create a value of type AnimalEating[A], we need to call AnimalEating’s constructor. And to call this constructor, we need to provide ev : SubTypeOf[A, Food].

      Now we can apply the dummy function on elephant:

      scala> dummy[AnimalEating, Vegetable](elephant)
      res0: String = Ok!

      In practice, using implicits, we let the compiler fill the parameter ev : SubTypeOf[A, Food] itself.

      Note that you can now express the type AnimalEating[Int] but you won’t be able to create a value of this type.
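The last remarks can be sketched as follows, reusing the implicit-based SubTypeOf from the remark for people fluent in Scala. The object name DietDemo and the sub-class Rabbit are illustrative only. The class AnimalEating stays open, and the compiler fills ev both for direct instances and for sub-classes.

```scala
object DietDemo {
  sealed trait SubTypeOf[A, B]
  object SubTypeOf {
    final case class Evidence[A <: B, B]() extends SubTypeOf[A, B]

    // Only applicable when A <: B, so the compiler can only find
    // a SubTypeOf[A, B] when A really is a sub-type of B
    implicit def evidence[A <: B, B]: SubTypeOf[A, B] = Evidence[A, B]()
  }

  trait Food
  class Vegetable extends Food
  class Fruit extends Food

  // No bound on A: the implicit evidence carries the constraint instead.
  // The class is neither sealed nor final, so it stays open.
  class AnimalEating[A](implicit val ev: SubTypeOf[A, Food])

  // The compiler fills ev with SubTypeOf.evidence[Vegetable, Food]
  val elephant: AnimalEating[Vegetable] = new AnimalEating[Vegetable]

  // Anyone can still create sub-classes; ev is filled in the same way
  class Rabbit extends AnimalEating[Vegetable]

  def dummy[F[_], A](fa: F[A]): String = "Ok!"

  // new AnimalEating[Int] would not compile: no SubTypeOf[Int, Food] can be found
}
```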

Use Case: Give the right data to the right chart

This use case is about enforcing at compile-time that only values of the right type can be given to a function. In this example, we consider the design of a chart library. For simplicity’s sake, we will assume that our library only supports two kinds of charts: Pie charts and XY charts. This is written in Scala via the enumeration:

sealed trait ChartType
case object PieChart extends ChartType
case object XYChart extends ChartType

Obviously Pie and XY charts rely on different kinds of data. Once again for simplicity’s sake, we will assume that our two kinds of data are:

class PieData
class XYData

A pie chart (PieChart) plots only PieData, whereas an XY chart (XYChart) plots only XYData. Here is our drawing function draw:

def draw[A](chartType: ChartType)(data: A): Unit =
  chartType match {
    case PieChart =>
      val pieData = data.asInstanceOf[PieData]
      // Do some stuff to draw pieData
      ()
    case XYChart =>
      val xyData = data.asInstanceOf[XYData]
      // Do some stuff to draw xyData
      ()
  }

It assumes that the user will only call draw on the right data. When chartType is PieChart, the function assumes, via data.asInstanceOf[PieData] that data is actually of type PieData. And when chartType is XYChart, it assumes that data is actually of type XYData.

The problem is that these assumptions rely on the hypothesis that users and developers will always make sure they call draw with the right data type. But nothing stops someone from calling draw on a PieChart with XYData (or the opposite), crashing the system miserably at runtime!

scala> draw(PieChart)(new XYData)
java.lang.ClassCastException: XYData cannot be cast to PieData
  at .draw(<pastie>:11)
  ... 28 elided

As developers, we know mistakes do happen! We want a way to prevent these annoying bugs from reaching production! We want to enforce at compile-time that only these two scenarios are possible:

  • When draw is called with chartType == PieChart: the argument data can only be of type PieData
  • When draw is called with chartType == XYChart: the argument data can only be of type XYData.

Remember that these two constraints have to be enforced at compile-time!

  • Question 1: Adapt the definition of ChartType, PieChart, XYChart and draw such that:

    • Any scenario other than the two above will make the compilation fail on a type error.

    • ChartType must still be a sealed trait. It is now allowed to have type parameters (i.e. generics).

    • PieChart and XYChart must still be two case objects. They should still extend ChartType.

    • ChartType, PieChart and XYChart declarations must have no body at all (i.e. there should be no braces { ... } in their declaration).

    Hint

    The code looks like this:

    sealed trait ChartType[/*PUT SOME GENERICS HERE*/]
    case object PieChart extends ChartType[/*There is something to write here*/]
    case object XYChart extends ChartType[/*There is something to write here too*/]
    
    def draw[A](chartType: ChartType[/*Write something here*/])(data: A): Unit =
      chartType match {
        case PieChart =>
          val pieData : PieData = data
          ()
        case XYChart =>
          val xyData: XYData = data
          ()
      }
    Solution
    sealed trait ChartType[A]
    case object PieChart extends ChartType[PieData]
    case object XYChart extends ChartType[XYData]
    
    def draw[A](chartType: ChartType[A])(data: A): Unit =
      chartType match {
        case PieChart =>
          val pieData : PieData = data
          ()
        case XYChart =>
          val xyData: XYData = data
          ()
      }
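To observe the refinement of A in each branch, here is a hypothetical variant of draw, named describe, that returns a String instead of Unit. It is only an illustration, wrapped in an object ChartDemo; the data classes are the same empty placeholders as above.

```scala
object ChartDemo {
  class PieData
  class XYData

  sealed trait ChartType[A]
  case object PieChart extends ChartType[PieData]
  case object XYChart extends ChartType[XYData]

  // Hypothetical variant of draw that returns a description instead of Unit
  def describe[A](chartType: ChartType[A])(data: A): String =
    chartType match {
      case PieChart =>
        val pieData: PieData = data // compiles because A = PieData here
        "pie chart"
      case XYChart =>
        val xyData: XYData = data // compiles because A = XYData here
        "XY chart"
    }

  // describe(PieChart)(new XYData) would be rejected at compile time
}
```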

You can now sleep well knowing your production will not crash because of some bad inputs here 😉

More Advanced Use Cases

Now that you have seen what GADTs are about and how to use them in real life, you are ready for the big use cases below. There are three of them, each illustrating a different way to use the power of GADTs. The first one is about expressing effects, a technique widely used by popular IO monads and algebraic effect libraries. Do not worry if you do not know what these are: the section will clarify them. The second one is about enforcing properties. It is illustrated by the real-life use case of enabling functional programming techniques to support constructions that only work for a limited set of types (in the example, the types supported by our fictional database). The third one is about simplifying the creation of implicits.

Use Case: Effects!

What we call an effect is sometimes just an interface declaring some functions with no implementation. For example we can define the trait below. Note that none of its functions has an implementation.

trait ExampleEffectSig {
  def echo[A](value: A): A
  def randomInt : Int
  def ignore[A](value: A): Unit
}

Implementations of these interfaces are given elsewhere and there can be many of them! This is useful to switch between implementations easily:

object ExampleEffectImpl extends ExampleEffectSig {
  def echo[A](value: A): A = value
  def randomInt : Int = scala.util.Random.nextInt()
  def ignore[A](value: A): Unit = ()
}

Another equivalent way to define ExampleEffectSig is via a sealed trait with some final case classes (possibly none!) and/or some case objects (possibly none too!):

sealed trait ExampleEffect[A]
final case class  Echo[A](value: A) extends ExampleEffect[A]
final case object RandomInt extends ExampleEffect[Int]
final case class  Ignore[A](value: A) extends ExampleEffect[Unit]

Once again this is a declaration with no implementation! Once again implementations can be written elsewhere and there can also be many of them:

def runExampleEffect[A](effect: ExampleEffect[A]): A =
  effect match {
    case Echo(value) => value
    case RandomInt   => scala.util.Random.nextInt()
    case Ignore(_)   => ()
  }
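To illustrate that the same declarations can have several implementations, here is a second, hypothetical interpreter for ExampleEffect in which RandomInt always yields a caller-supplied number, which makes results reproducible in tests. The names ExampleEffectDemo and runDeterministic are illustrative only.

```scala
object ExampleEffectDemo {
  sealed trait ExampleEffect[A]
  final case class Echo[A](value: A) extends ExampleEffect[A]
  case object RandomInt extends ExampleEffect[Int]
  final case class Ignore[A](value: A) extends ExampleEffect[Unit]

  // Hypothetical second interpreter for the same declarations:
  // RandomInt returns `fixed` instead of a random number
  def runDeterministic[A](effect: ExampleEffect[A], fixed: Int): A =
    effect match {
      case Echo(value) => value
      case RandomInt   => fixed
      case Ignore(_)   => ()
    }
}
```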

Let’s consider a more realistic effect and one possible implementation:

trait EffectSig {
  def currentTimeMillis: Long
  def printLn(msg: String): Unit
  def mesure[X,A](fun: X => A, arg: X): A
}

object EffectImpl extends EffectSig {
  def currentTimeMillis: Long =
    System.currentTimeMillis()

  def printLn(msg: String): Unit =
    println(msg)

  def mesure[X,A](fun: X => A, arg: X): A = {
    val t0 = System.currentTimeMillis()
    val r  = fun(arg)
    val t1 = System.currentTimeMillis()
    println(s"Took ${t1 - t0} milli-seconds")
    r
  }
}
  • Question 1: ExampleEffect is the equivalent of ExampleEffectSig, but using a sealed trait with some final case class and case object. Write the equivalent of EffectSig in the same way. Call this trait Effect.

    Solution
    sealed trait Effect[A]
    final case object CurrentTimeMillis extends Effect[Long]
    final case class  PrintLn(msg: String) extends Effect[Unit]
    final case class  Mesure[X,A](fun: X => A, arg: X) extends Effect[A]
  • Question 2: Write the function def run[A](effect: Effect[A]): A that mimics the implementation of EffectImpl just like runExampleEffect mimics ExampleEffectImpl.

    Solution
    def run[A](effect: Effect[A]): A =
      effect match {
        case CurrentTimeMillis =>
          System.currentTimeMillis()
    
        case PrintLn(msg) =>
          println(msg)
    
        case Mesure(fun, arg) =>
          val t0 = System.currentTimeMillis()
          val r  = fun(arg)
          val t1 = System.currentTimeMillis()
          println(s"Took ${t1 - t0} milli-seconds")
          r
      }
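Because Effect only declares operations, run is just one possible interpreter. As a sketch, here is a hypothetical alternative (runCollecting, inside an illustrative object EffectDemo) that uses a frozen clock and records printed lines in a buffer instead of writing to the console, which can be handy in tests.

```scala
object EffectDemo {
  import scala.collection.mutable.ListBuffer

  sealed trait Effect[A]
  case object CurrentTimeMillis extends Effect[Long]
  final case class PrintLn(msg: String) extends Effect[Unit]
  final case class Mesure[X, A](fun: X => A, arg: X) extends Effect[A]

  // Hypothetical alternative interpreter: the clock is frozen at `clock`
  // and printed lines are appended to `output` instead of the console
  def runCollecting[A](effect: Effect[A], clock: Long, output: ListBuffer[String]): A =
    effect match {
      case CurrentTimeMillis =>
        clock

      case PrintLn(msg) =>
        output += msg
        ()

      case Mesure(fun, arg) =>
        val t0 = clock
        val r  = fun(arg)
        val t1 = clock // frozen clock: the measured duration is always 0
        output += s"Took ${t1 - t0} milli-seconds"
        r
    }
}
```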

The type Effect[A] declares interesting effects (CurrentTimeMillis, PrintLn and Mesure) but to be really useful, we need to be able to chain effects! To do so, we want to have these two functions:

  • def pure[A](value: A): Effect[A]
  • def flatMap[X,A](fx: Effect[X], f: X => Effect[A]): Effect[A]

Once again, we do not care about the implementation yet. For now, all we want is to declare these two operations, just like we declared CurrentTimeMillis, PrintLn and Mesure.

  • Question 3: Add two final case classes, Pure and FlatMap, to Effect[A] declaring these operations.

    Solution
    sealed trait Effect[A]
    final case object CurrentTimeMillis extends Effect[Long]
    final case class  PrintLn(msg: String) extends Effect[Unit]
    final case class  Mesure[X,A](fun: X => A, arg: X) extends Effect[A]
    final case class  Pure[A](value: A) extends Effect[A]
    final case class  FlatMap[X,A](fx: Effect[X], f: X => Effect[A]) extends Effect[A]
  • Question 4: Adapt the function run to handle these two new cases.

    Solution
    def run[A](effect: Effect[A]): A =
      effect match {
        case CurrentTimeMillis =>
          System.currentTimeMillis()
    
        case PrintLn(msg) =>
          println(msg)
    
        case Mesure(fun, arg) =>
          val t0 = System.currentTimeMillis()
          val r  = fun(arg)
          val t1 = System.currentTimeMillis()
          println(s"Took ${t1 - t0} milli-seconds")
          r
    
        case Pure(a) =>
          a
    
        case FlatMap(fx, f) =>
          val x  = run(fx)
          val fa : Effect[A] = f(x)
          run(fa)
      }
  • Question 5: Add the two following methods to trait Effect[A] to get:

    sealed trait Effect[A] {
      final def flatMap[B](f: A => Effect[B]): Effect[B] = FlatMap(this, f)
      final def map[B](f: A => B): Effect[B] = flatMap[B]((a:A) => Pure(f(a)))
    }

    And run the following code to see if it works:

    val effect1: Effect[Unit] =
      for {
        t0 <- CurrentTimeMillis
        _  <- PrintLn(s"The current time is $t0")
      } yield ()
    
    run(effect1)
    Solution:
    sealed trait Effect[A] {
      final def flatMap[B](f: A => Effect[B]): Effect[B] = FlatMap(this, f)
      final def map[B](f: A => B): Effect[B] = flatMap[B]((a:A) => Pure(f(a)))
    }
    final case object CurrentTimeMillis extends Effect[Long]
    final case class  PrintLn(msg: String) extends Effect[Unit]
    final case class  Mesure[X,A](fun: X => A, arg: X) extends Effect[A]
    final case class  Pure[A](value: A) extends Effect[A]
    final case class  FlatMap[X,A](fx: Effect[X], f: X => Effect[A]) extends Effect[A]
    
    def run[A](effect: Effect[A]): A =
      effect match {
        case CurrentTimeMillis =>
          System.currentTimeMillis()
    
        case PrintLn(msg) =>
          println(msg)
    
        case Mesure(fun, arg) =>
          val t0 = System.currentTimeMillis()
          val r  = fun(arg)
          val t1 = System.currentTimeMillis()
          println(s"Took ${t1 - t0} milli-seconds")
          r
    
        case Pure(a) =>
          a
    
        case FlatMap(fx, f) =>
          val x  = run(fx)
          val fa : Effect[A] = f(x)
          run(fa)
      }
    
    val effect1: Effect[Unit] =
      for {
        t0 <- CurrentTimeMillis
        _  <- PrintLn(s"The current time is $t0")
      } yield ()

    When we run run(effect1):

    scala> run(effect1)
    The current time is 1569773175010

Congratulations! You just wrote your first IO monad! There are a lot of scientific names for the sealed trait Effect[A]: you can call it an algebraic effect, a free monad, an IO, etc. But in the end, it is just a plain, simple sealed trait with some final case classes and a case object that represent the functions we wanted to have (CurrentTimeMillis, PrintLn, Mesure, Pure and FlatMap), without providing their implementation. You can call them virtual methods if you like. What really matters is that we isolated the declaration of the functions from their implementation. Remember that a trait is just an interface after all.
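To see what the for-comprehension actually builds, here is a self-contained sketch that repeats the final definitions (using case object instead of final case object, and a more compact run), writes out by hand the flatMap/map calls the compiler generates for effect1, and times an illustrative computation with Mesure. The names effect1Desugared, effect2 and result, as well as the summation being timed, are our own assumptions and not part of the exercises.

```scala
sealed trait Effect[A] {
  final def flatMap[B](f: A => Effect[B]): Effect[B] = FlatMap(this, f)
  final def map[B](f: A => B): Effect[B] = flatMap[B]((a: A) => Pure(f(a)))
}
case object CurrentTimeMillis extends Effect[Long]
final case class PrintLn(msg: String) extends Effect[Unit]
final case class Mesure[X, A](fun: X => A, arg: X) extends Effect[A]
final case class Pure[A](value: A) extends Effect[A]
final case class FlatMap[X, A](fx: Effect[X], f: X => Effect[A]) extends Effect[A]

// The interpreter: the only place where effects are actually performed.
def run[A](effect: Effect[A]): A =
  effect match {
    case CurrentTimeMillis => System.currentTimeMillis()
    case PrintLn(msg)      => println(msg)
    case Mesure(fun, arg) =>
      val t0 = System.currentTimeMillis()
      val r  = fun(arg)
      val t1 = System.currentTimeMillis()
      println(s"Took ${t1 - t0} milli-seconds")
      r
    case Pure(a)        => a
    case FlatMap(fx, f) => run(f(run(fx)))
  }

// What the compiler generates for the effect1 for-comprehension.
// Expanding flatMap and map gives the plain data structure
// FlatMap(CurrentTimeMillis, t0 => FlatMap(PrintLn(...), _ => Pure(())))
val effect1Desugared: Effect[Unit] =
  CurrentTimeMillis.flatMap(t0 => PrintLn(s"The current time is $t0").map(_ => ()))

// Timing an illustrative computation with Mesure, then printing its result.
val effect2: Effect[Int] =
  for {
    total <- Mesure((n: Int) => (1 to n).sum, 1000)
    _     <- PrintLn(s"The sum is $total")
  } yield total

// Nothing is timed or printed until run interprets the tree.
val result: Int = run(effect2)
```

Building effect2 only assembles a tree of Mesure, PrintLn, FlatMap and Pure values. The timing and printing happen when run walks that tree.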

Use Case: Ensuring types are supported by the Database

Databases are great. We can store tables, documents, key/value pairs, graphs, etc. But for any database, there is unfortunately only a limited set of supported types. Take any database you like: I am sure I can find some types it does not support.

In this section we consider the use case of data structures and code that do not work for (values of) any type, but only for some! This problem is not limited to databases: it concerns any API that only supports a limited set of types (the vast majority of APIs). How can we enforce this constraint? How can we adapt the patterns we like to work under this constraint? This is what this section is all about.

We consider a fictional database that only supports the following types:

  1. String
  2. Double
  3. (A,B) where A and B are also types supported by the database.

It means that the values stored by the database (in tables, key/value pairs, etc.) must follow the rules above. It can store "Hello World" because it is a String, which is a type supported by rule 1. Likewise, it can store 5.2 because it is a Double, but it cannot store 5 : Int because Int is not a supported type. It can store ("Hello World", 5.2) thanks to rule 3, and also (("Hello World", 5.2), 8.9), once again by rule 3.

  • Question 1: Define the type DBType[A] such that:

    There exists a value of type DBType[A] if and only if A is a type supported by the database.

    Solution:

    Transposing the rules above in code, we get:

    sealed trait DBType[A]
    final case object DBString extends DBType[String]
    final case object DBDouble extends DBType[Double]
    final case class  DBPair[A,B](first: DBType[A], second: DBType[B]) extends DBType[(A,B)]
    Remark for people fluent in Scala:

    Using all the nice features of Scala, the production-ready version of the code above is:

    sealed trait DBType[A]
    object DBType {
      final case object DBString extends DBType[String]
      final case object DBDouble extends DBType[Double]
      final case class DBPair[A,B](first: DBType[A], second: DBType[B]) extends DBType[(A,B)]
    
      implicit val dbString : DBType[String] =
        DBString
    
      implicit val dbDouble : DBType[Double] =
        DBDouble
    
      implicit def dbPair[A,B](implicit first: DBType[A], second: DBType[B]): DBType[(A,B)] =
        DBPair(first, second)
    
      def apply[A](implicit ev: DBType[A]): ev.type = ev
    }

Using DBType, we can pair a value of type A with a value of type DBType[A], which provides an evidence that the type A is supported by the database:

final case class DBValue[A](value: A)(implicit val dbType: DBType[A])

Note that the parameter dbType does not need to be implicit! All that matters is that to create a value of type DBValue[A], we need to provide a value of type DBType[A] which forces A to be a supported type.
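As a quick check of how the implicit definitions assemble evidence from the three rules, here is a self-contained sketch reusing the production-ready DBType (with case object instead of final case object) together with DBValue. The value names evidence and stored are illustrative. The commented-out lines are rejected at compile time, because no DBType[Int] or DBType[Boolean] can be built.

```scala
sealed trait DBType[A]
object DBType {
  case object DBString extends DBType[String]
  case object DBDouble extends DBType[Double]
  final case class DBPair[A, B](first: DBType[A], second: DBType[B]) extends DBType[(A, B)]

  // Rules 1, 2 and 3 as implicit definitions.
  implicit val dbString: DBType[String] = DBString
  implicit val dbDouble: DBType[Double] = DBDouble
  implicit def dbPair[A, B](implicit first: DBType[A], second: DBType[B]): DBType[(A, B)] =
    DBPair(first, second)

  def apply[A](implicit ev: DBType[A]): ev.type = ev
}

final case class DBValue[A](value: A)(implicit val dbType: DBType[A])

// The compiler builds the proof DBPair(DBPair(DBString, DBDouble), DBDouble).
val evidence: DBType[((String, Double), Double)] = DBType[((String, Double), Double)]

// The evidence is found implicitly when building a DBValue.
val stored: DBValue[(String, Double)] = DBValue(("Hello World", 5.2))

// Neither line compiles: there is no DBType[Int] nor DBType[Boolean].
// val bad1 = DBValue(5)
// val bad2 = DBType[(String, Boolean)]
```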

A functor is, approximately, a type constructor F like List, Option, DBValue, … for which you can write an instance of the trait

trait Functor[F[_]] {
  def map[A,B](fa: F[A])(f: A => B): F[B]
}

where map(fa)(f) applies the function f to any value of type A contained in fa. For example:

implicit object OptionFunctor extends Functor[Option] {
  def map[A,B](fa: Option[A])(f: A => B): Option[B] =
    fa match {
      case Some(a) => Some(f(a))
      case None => None
    }
}
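For readers less familiar with the pattern, here is a minimal sketch of how a Functor instance is used. The ListFunctor instance and the generic helper double are hypothetical additions: the helper is written once and works for any F having a Functor instance.

```scala
trait Functor[F[_]] {
  def map[A, B](fa: F[A])(f: A => B): F[B]
}

implicit object OptionFunctor extends Functor[Option] {
  def map[A, B](fa: Option[A])(f: A => B): Option[B] =
    fa match {
      case Some(a) => Some(f(a))
      case None    => None
    }
}

// Hypothetical instance for List, delegating to the standard List#map.
implicit object ListFunctor extends Functor[List] {
  def map[A, B](fa: List[A])(f: A => B): List[B] = fa.map(f)
}

// Hypothetical generic helper: it only relies on the Functor interface.
def double[F[_]](fa: F[Int])(implicit F: Functor[F]): F[Int] =
  F.map(fa)(_ * 2)
```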
  • Question 2: Write an instance of Functor[DBValue].

    Solution:

    We actually cannot! If we try to compile the following code:

    object DBValueFunctor extends Functor[DBValue] {
      def map[A,B](fa: DBValue[A])(f: A => B): DBValue[B] =
        DBValue[B](f(fa.value))
    }

    Scala complains: could not find implicit value for parameter dbType: DBType[B]. Indeed, map must work for every type B, including types the database does not support, such as Boolean: booleans are neither strings, nor doubles, nor pairs. If we could write a Functor instance for DBValue (i.e. if we could write a map function for DBValue), then we could write:

    val dbValueString  : DBValue[String]  = DBValue("A")(DBString)
    val dbValueBoolean : DBValue[Boolean] = DBValueFunctor.map(dbValueString)(_ => true)
    val dbTypeBoolean  : DBType[Boolean]  = dbValueBoolean.dbType

    We would get a value (dbTypeBoolean) of type DBType[Boolean], which would mean that the type Boolean is supported by the database. But it is not! Furthermore, by definition:

    There exists a value of type DBType[A] if and only if A is a type supported by the database.

    So it is impossible to have a value of type DBType[Boolean], and thus impossible to write a function map for DBValue: there is no way to write a Functor instance for DBValue.

A Generalized Functor is very much like a regular Functor. But whereas the map function of a functor has to work for all types A and B, the map function of a generalized functor can be narrowed to only operate on a limited set of types A and B:

trait GenFunctor[P[_],F[_]] {
  def map[A,B](fa: F[A])(f: A => B)(implicit evA: P[A], evB: P[B]): F[B]
}

For example, Set (more precisely TreeSet) is not a functor! Indeed, there is no way to write a function map that works for any type B (because B needs to have an ordering). But if we narrow map to the only types B having an ordering, we can write it.

import scala.collection.immutable._
object TreeSetFunctor extends GenFunctor[Ordering, TreeSet] {
  def map[A,B](fa: TreeSet[A])(f: A => B)(implicit evA: Ordering[A], evB: Ordering[B]): TreeSet[B] =
    TreeSet.empty[B](evB) ++ fa.toSeq.map(f)
}
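Here is a small usage sketch of TreeSetFunctor (the sample values are arbitrary). Note that map on a TreeSet may shrink the set, since values that f sends to the same result collapse into a single element, and the result is re-sorted with Ordering[B].

```scala
import scala.collection.immutable.TreeSet

trait GenFunctor[P[_], F[_]] {
  def map[A, B](fa: F[A])(f: A => B)(implicit evA: P[A], evB: P[B]): F[B]
}

object TreeSetFunctor extends GenFunctor[Ordering, TreeSet] {
  def map[A, B](fa: TreeSet[A])(f: A => B)(implicit evA: Ordering[A], evB: Ordering[B]): TreeSet[B] =
    TreeSet.empty[B](evB) ++ fa.toSeq.map(f)
}

val numbers: TreeSet[Int] = TreeSet(5, 3, 1, 4, 2)

// Remainders modulo 3: duplicates collapse, and the result is sorted by Ordering[Int].
val remainders: TreeSet[Int] = TreeSetFunctor.map(numbers)(_ % 3)

// Mapping to String is accepted because an Ordering[String] exists.
val labels: TreeSet[String] = TreeSetFunctor.map(numbers)(n => s"n$n")
```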
  • Question 3: Write an instance of GenFunctor[DBType, DBValue].

    Solution:
    object DBValueGenFunctor extends GenFunctor[DBType, DBValue] {
      def map[A,B](fa: DBValue[A])(f: A => B)(implicit evA: DBType[A], evB: DBType[B]): DBValue[B] =
        DBValue[B](f(fa.value))(evB)
    }
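A usage sketch of DBValueGenFunctor, repeating the needed definitions so that the block stands alone. The values greeting, greetingLength and tagged are illustrative; the point is that map is accepted exactly when the target type has a DBType instance.

```scala
sealed trait DBType[A]
object DBType {
  case object DBString extends DBType[String]
  case object DBDouble extends DBType[Double]
  final case class DBPair[A, B](first: DBType[A], second: DBType[B]) extends DBType[(A, B)]

  implicit val dbString: DBType[String] = DBString
  implicit val dbDouble: DBType[Double] = DBDouble
  implicit def dbPair[A, B](implicit first: DBType[A], second: DBType[B]): DBType[(A, B)] =
    DBPair(first, second)
}

final case class DBValue[A](value: A)(implicit val dbType: DBType[A])

trait GenFunctor[P[_], F[_]] {
  def map[A, B](fa: F[A])(f: A => B)(implicit evA: P[A], evB: P[B]): F[B]
}

object DBValueGenFunctor extends GenFunctor[DBType, DBValue] {
  def map[A, B](fa: DBValue[A])(f: A => B)(implicit evA: DBType[A], evB: DBType[B]): DBValue[B] =
    DBValue[B](f(fa.value))(evB)
}

val greeting: DBValue[String] = DBValue("Hello World")

// String => Double: both types are supported, so the evidence is found.
val greetingLength: DBValue[Double] = DBValueGenFunctor.map(greeting)(_.length.toDouble)

// String => (String, Double): supported thanks to rule 3.
val tagged: DBValue[(String, Double)] =
  DBValueGenFunctor.map(greeting)(s => (s, s.length.toDouble))

// Does NOT compile: there is no DBType[Boolean].
// DBValueGenFunctor.map(greeting)(_ => true)
```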

What we have done to Functor can be done for many data-structures and patterns. We can often limit the types on which a data-structure or a type-class can operate by adding an extra parameter like ev : DBType[A] to constructors and methods.

Use Case: Simplifying Implicits

This use case is one of the most interesting but, unfortunately, not one of the easiest. It illustrates how GADTs can simplify the creation of implicit values. In this example we consider lists of values whose items can be of different types. These lists are called heterogeneous lists. They are usually defined in Scala almost like normal lists:

final case class HNil() // The empty list
final case class HCons[Head,Tail](head: Head, tail: Tail) // The `head :: tail` operation

val empty : HNil =
  HNil()

val oneTrueToto : HCons[Int, HCons[Boolean, HCons[String, HNil]]] =
  HCons(1, HCons(true, HCons("toto", HNil())))

val falseTrueFive: HCons[Boolean, HCons[Boolean, HCons[Int, HNil]]] =
  HCons(false, HCons(true, HCons(5, HNil())))

As you can see, there is nothing special about it. We want to define orderings on heterogeneous lists. An ordering is a way to compare two values (of the same type!): they can be equal, or one may be less than the other. In Scala we can define the trait Order:

trait Order[A] {
  // true if and only if a1 < a2
  def lesserThan(a1: A, a2: A): Boolean

  // a1 and a2 are equal if and only if none of them is lesser than the other.
  final def areEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2) && !lesserThan(a2, a1)

  // a1 > a2 if and only if a2 < a1
  final def greaterThan(a1: A, a2: A): Boolean = lesserThan(a2, a1)

  final def lesserThanOrEqual(a1: A, a2: A): Boolean = !lesserThan(a2, a1)

  final def greaterThanOrEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2)
}

object Order {
  def apply[A](implicit ev: Order[A]): ev.type = ev

  def make[A](lg_ : (A,A) => Boolean): Order[A] =
    new Order[A] {
      def lesserThan(a1: A, a2: A): Boolean = lg_(a1,a2)
    }
}

implicit val orderInt   : Order[Int]    = Order.make[Int](_ < _)
implicit val orderString: Order[String] = Order.make[String](_ < _)

Remember that we will only compare lists of the same type:

  • Lists of type HNil will only be compared to lists of type HNil.
  • Lists of type HCons[H,T] will only be compared to lists of type HCons[H,T].

Comparing lists of type HNil is trivial because there is only one value of type HNil (the empty list HNil()). But there are many ways of comparing lists of type HCons[H,T]. Here are two possible orderings (there exist many more!):

  • The lexicographic ordering (i.e. dictionary order: from left to right)

    HCons(h1,t1) < HCons(h2,t2) if and only if h1 < h2 or (h1 == h2 and t1 < t2 by lexicographic ordering).

    sealed trait Lex[A] {
      val order : Order[A]
    }
    
    object Lex {
      def apply[A](implicit ev: Lex[A]): ev.type = ev
    
      implicit val lexHNil: Lex[HNil] =
        new Lex[HNil] {
          val order = Order.make[HNil]((_,_) => false)
        }
    
      implicit def lexHCons[Head,Tail](implicit
          orderHead: Order[Head],
          lexTail: Lex[Tail]
        ): Lex[HCons[Head, Tail]] =
        new Lex[HCons[Head, Tail]] {
          val orderTail: Order[Tail] = lexTail.order
    
          val order = Order.make[HCons[Head, Tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderHead.lesserThan(h1,h2) || (orderHead.areEqual(h1,h2) && orderTail.lesserThan(t1,t2))
          }
        }
    }
  • The reverse-lexicographic ordering, which is the reverse version of the lexicographic ordering (i.e. from right to left):

    HCons(h1,t1) < HCons(h2,t2) if and only if t1 < t2 by reverse-lexicographic ordering or (t1 == t2 and h1 < h2).

    sealed trait RevLex[A] {
      val order : Order[A]
    }
    
    object RevLex {
      def apply[A](implicit ev: RevLex[A]): ev.type = ev
    
      implicit val revLexHNil: RevLex[HNil] =
        new RevLex[HNil] {
          val order = Order.make[HNil]((_,_) => false)
        }
    
      implicit def revLexHCons[Head,Tail](implicit
          orderHead: Order[Head],
          revLexTail: RevLex[Tail]
        ): RevLex[HCons[Head, Tail]] =
        new RevLex[HCons[Head, Tail]] {
          val orderTail: Order[Tail] = revLexTail.order
    
          val order = Order.make[HCons[Head, Tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderTail.lesserThan(t1,t2) || (orderTail.areEqual(t1,t2) && orderHead.lesserThan(h1,h2))
          }
        }
    }
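To compare the two orderings on concrete values, here is a self-contained sketch. It repeats a trimmed Order (only lesserThan and areEqual, which is all Lex and RevLex need) and writes the comparisons with head/tail accessors instead of a pattern-matching lambda. The list type L and the values x and y are illustrative.

```scala
final case class HNil()
final case class HCons[Head, Tail](head: Head, tail: Tail)

trait Order[A] {
  def lesserThan(a1: A, a2: A): Boolean
  final def areEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2) && !lesserThan(a2, a1)
}
object Order {
  def make[A](lt: (A, A) => Boolean): Order[A] =
    new Order[A] { def lesserThan(a1: A, a2: A): Boolean = lt(a1, a2) }
}

implicit val orderInt: Order[Int]       = Order.make[Int](_ < _)
implicit val orderString: Order[String] = Order.make[String](_ < _)

sealed trait Lex[A] { val order: Order[A] }
object Lex {
  def apply[A](implicit ev: Lex[A]): ev.type = ev
  implicit val lexHNil: Lex[HNil] =
    new Lex[HNil] { val order: Order[HNil] = Order.make[HNil]((_, _) => false) }
  implicit def lexHCons[Head, Tail](implicit orderHead: Order[Head], lexTail: Lex[Tail]): Lex[HCons[Head, Tail]] =
    new Lex[HCons[Head, Tail]] {
      // Heads first, then tails.
      val order: Order[HCons[Head, Tail]] = Order.make[HCons[Head, Tail]] { (l1, l2) =>
        orderHead.lesserThan(l1.head, l2.head) ||
          (orderHead.areEqual(l1.head, l2.head) && lexTail.order.lesserThan(l1.tail, l2.tail))
      }
    }
}

sealed trait RevLex[A] { val order: Order[A] }
object RevLex {
  def apply[A](implicit ev: RevLex[A]): ev.type = ev
  implicit val revLexHNil: RevLex[HNil] =
    new RevLex[HNil] { val order: Order[HNil] = Order.make[HNil]((_, _) => false) }
  implicit def revLexHCons[Head, Tail](implicit orderHead: Order[Head], revLexTail: RevLex[Tail]): RevLex[HCons[Head, Tail]] =
    new RevLex[HCons[Head, Tail]] {
      // Tails first, then heads.
      val order: Order[HCons[Head, Tail]] = Order.make[HCons[Head, Tail]] { (l1, l2) =>
        revLexTail.order.lesserThan(l1.tail, l2.tail) ||
          (revLexTail.order.areEqual(l1.tail, l2.tail) && orderHead.lesserThan(l1.head, l2.head))
      }
    }
}

type L = HCons[Int, HCons[String, HNil]]
val x: L = HCons(1, HCons("b", HNil()))
val y: L = HCons(2, HCons("a", HNil()))
```

Lex decides on the heads (1 < 2), so x comes first; RevLex decides on the last items ("a" < "b"), so y comes first.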

As said above, it is possible to define more orderings:

  • Question 1: The Alternate ordering is defined by:

    HCons(h1,t1) < HCons(h2,t2) if and only if h1 < h2 or (h1 == h2 and t1 > t2 by alternate ordering).

    Just like what was done for Lex and RevLex, implement the Alternate ordering.

    Solution:
    sealed trait Alternate[A] {
      val order : Order[A]
    }
    
    object Alternate {
      def apply[A](implicit ev: Alternate[A]): ev.type = ev
    
      implicit val alternateHNil: Alternate[HNil] =
        new Alternate[HNil] {
          val order = Order.make[HNil]((_,_) => false)
        }
    
      implicit def alternateHCons[Head,Tail](implicit
          orderHead: Order[Head],
          alternateTail: Alternate[Tail]
        ): Alternate[HCons[Head, Tail]] =
        new Alternate[HCons[Head, Tail]] {
          val orderTail: Order[Tail] = alternateTail.order
    
          val order = Order.make[HCons[Head, Tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderHead.lesserThan(h1,h2) || (orderHead.areEqual(h1,h2) && orderTail.greaterThan(t1,t2))
          }
        }
    }

There are lots of ways to define a valid ordering on heterogeneous lists! Defining type classes like Lex, RevLex and Alternate for every ordering we want to implement is clunky and messy. We can do much better than that … with a GADT 😉

sealed trait HListOrder[A]
object HListOrder {
  final case object HNilOrder extends HListOrder[HNil]

  final case class HConsOrder[Head,Tail](
      orderHead: Order[Head],
      hlistOrderTail: HListOrder[Tail]
    ) extends HListOrder[HCons[Head,Tail]]

  // Implicit definitions

  implicit val hnilOrder : HListOrder[HNil] =
    HNilOrder

  implicit def hconsOrder[Head,Tail](implicit
      orderHead: Order[Head],
      hlistOrderTail: HListOrder[Tail]
    ): HListOrder[HCons[Head,Tail]] =
    HConsOrder(orderHead, hlistOrderTail)

  def apply[A](implicit ev: HListOrder[A]): ev.type = ev
}

Note that these implicit definitions are boilerplate. Their only purpose is to pass arguments to their corresponding constructor (i.e. final case class or case object): hnilOrder to HNilOrder (0 arguments) and hconsOrder to HConsOrder (2 arguments).

  • Question 2: Write the function def lex[A](implicit v : HListOrder[A]): Order[A] that computes the lexicographic ordering from a value of type HListOrder[A].

    Solution:
    def lex[A](implicit v : HListOrder[A]): Order[A] =
      v match {
        case HListOrder.HNilOrder =>
          Order.make[HNil]((_,_) => false)
    
        case hc : HListOrder.HConsOrder[head,tail] =>
          val orderHead: Order[head] = hc.orderHead
          val orderTail: Order[tail] = lex(hc.hlistOrderTail)
    
          Order.make[HCons[head, tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderHead.lesserThan(h1,h2) || (orderHead.areEqual(h1,h2) && orderTail.lesserThan(t1,t2))
          }
      }
  • Question 3: Write the function def revLex[A](implicit v : HListOrder[A]): Order[A] that computes the reverse-lexicographic ordering from a value of type HListOrder[A].

    Solution:
    def revLex[A](implicit v : HListOrder[A]): Order[A] =
      v match {
        case HListOrder.HNilOrder =>
          Order.make[HNil]((_,_) => false)
    
        case hc : HListOrder.HConsOrder[head,tail] =>
          val orderHead: Order[head] = hc.orderHead
          val orderTail: Order[tail] = revLex(hc.hlistOrderTail)
    
          Order.make[HCons[head, tail]] {
            case (HCons(h1,t1), HCons(h2,t2)) =>
              orderTail.lesserThan(t1,t2) || (orderTail.areEqual(t1,t2) && orderHead.lesserThan(h1,h2))
          }
      }
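To illustrate the claim that each new ordering becomes a plain function, here is a self-contained sketch. It repeats HListOrder and lex, and adds alternate, our own transcription of the Alternate ordering from Question 1 into the GADT style. The type alias L and the values a and b are illustrative. A single implicit HListOrder[L] serves both orderings.

```scala
final case class HNil()
final case class HCons[Head, Tail](head: Head, tail: Tail)

trait Order[A] {
  def lesserThan(a1: A, a2: A): Boolean
  final def areEqual(a1: A, a2: A): Boolean = !lesserThan(a1, a2) && !lesserThan(a2, a1)
  final def greaterThan(a1: A, a2: A): Boolean = lesserThan(a2, a1)
}
object Order {
  def make[A](lt: (A, A) => Boolean): Order[A] =
    new Order[A] { def lesserThan(a1: A, a2: A): Boolean = lt(a1, a2) }
}

implicit val orderInt: Order[Int] = Order.make[Int](_ < _)

sealed trait HListOrder[A]
object HListOrder {
  case object HNilOrder extends HListOrder[HNil]
  final case class HConsOrder[Head, Tail](orderHead: Order[Head], hlistOrderTail: HListOrder[Tail])
      extends HListOrder[HCons[Head, Tail]]

  implicit val hnilOrder: HListOrder[HNil] = HNilOrder
  implicit def hconsOrder[Head, Tail](implicit orderHead: Order[Head], hlistOrderTail: HListOrder[Tail]): HListOrder[HCons[Head, Tail]] =
    HConsOrder(orderHead, hlistOrderTail)
}

def lex[A](implicit v: HListOrder[A]): Order[A] =
  v match {
    case HListOrder.HNilOrder => Order.make[HNil]((_, _) => false)
    case hc: HListOrder.HConsOrder[head, tail] =>
      val orderHead: Order[head] = hc.orderHead
      val orderTail: Order[tail] = lex(hc.hlistOrderTail)
      Order.make[HCons[head, tail]] { (l1, l2) =>
        orderHead.lesserThan(l1.head, l2.head) ||
          (orderHead.areEqual(l1.head, l2.head) && orderTail.lesserThan(l1.tail, l2.tail))
      }
  }

// The Alternate ordering as a plain function: tails compared in reverse.
def alternate[A](implicit v: HListOrder[A]): Order[A] =
  v match {
    case HListOrder.HNilOrder => Order.make[HNil]((_, _) => false)
    case hc: HListOrder.HConsOrder[head, tail] =>
      val orderHead: Order[head] = hc.orderHead
      val orderTail: Order[tail] = alternate(hc.hlistOrderTail)
      Order.make[HCons[head, tail]] { (l1, l2) =>
        orderHead.lesserThan(l1.head, l2.head) ||
          (orderHead.areEqual(l1.head, l2.head) && orderTail.greaterThan(l1.tail, l2.tail))
      }
  }

type L = HCons[Int, HCons[Int, HNil]]
val a: L = HCons(1, HCons(5, HNil()))
val b: L = HCons(1, HCons(7, HNil()))
```

With equal heads, lex puts a first (5 < 7) whereas alternate puts b first, since it flips the comparison of the tails.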

This approach has several benefits. Whereas the initial approach needs one implicit type class per ordering, the GADT approach only needs one! Since implicit resolution can be a costly operation, reducing it can mean faster compilation times. Reading the code of the functions lex and revLex is easier than understanding how implicit resolution works for the traits Lex and RevLex. Furthermore, since they are just plain functions, you can use anything you can code to build the instance Order[A].

Conclusion

Not so trivial, is it? 😉 Actually, a fair amount of the complexity you may have experienced comes from the fact that reasoning about values and types is almost never taught in programming courses. What you consider simple now (web APIs, streaming, databases, etc.) would probably have terrified your younger self when you were introduced to “Hello World!” for the first time. You probably did not learn all you know about programming in three hours, so do not expect reasoning about programs to be magically easier.

This workshop aimed at inspiring you and opening your mind to this whole new set of possibilities. If you found the use cases interesting, take the time to understand the techniques.

Have fun and take care ❤️

Proving Primality with GADTs

Today we will explore the Curry–Howard correspondence. Our mission is to write, in Scala’s type system, the property of a natural number being prime. Wikipedia defines it as:

A natural number (1, 2, 3, 4, 5, 6, etc.) is called a prime number (or a prime) if it is greater than 1 and cannot be written as a product of two natural numbers that are both smaller than it.

An equivalent way to put it is:

A natural number is prime if it is greater than 1 and cannot be divided by any natural number greater than 1 but smaller than it.

These definitions are equivalent as, by definition, any natural number n is divisible by k if and only if it can be written n = k × p for some natural number p.

Writing a program whose execution checks whether a number is prime is easy. But we are not interested in executing programs, only in compiling them! We want the compiler to verify that a number is indeed prime. At that point, you may wonder how it is even possible to use the compiler to “prove” something about numbers. That’s exactly the point of the Curry–Howard correspondence.

The Challenge

You can write any positive integer in the input box below:

  • Please write 3 in the input box above. A button shall appear letting you download Prime3.scala. Run it via

    scala Prime3.scala

    The file should compile and run flawlessly, outputting ForAllRec(NotDivRec(NotDivBase(SIsPositive(),LTBase()),AddPlus1(AddZero())),ForAllBase()). Look into Prime3.scala: you should see a value prime3: Prime[_3] defined. The main method simply outputs this value.

  • Now, write 4 in the input box. Download and run Prime4.scala via

    scala Prime4.scala

    The file should compile but execution should fail with the exception scala.NotImplementedError: an implementation is missing. Look into Prime4.scala: the value prime4: Prime[_4] is defined by ???.

  • Read Prime4.scala carefully, starting from the beginning, and try to write a valid definition for val prime4: Prime[_4]. Remember to follow very scrupulously the rules stated in the first comment of Prime4.scala.

    • DO NOT ALTER, IN ANY WAY, THE DEFINITION OF ANY TYPE IN THE FILE
    • DO NOT ADD SUB CLASSES/OBJECTS TO TYPES IN THE FILE
    • DO NOT USE NULL IN ANY WAY
    • ONLY USE THE GIVEN CASE OBJECTS AND CASE CLASSES IN THE FILE
    • THE GOAL IS TO PRODUCE A val prime4: Prime[_4], NOT A def prime4: Prime[_4], NOT A lazy val prime4: Prime[_4]!
    • YOUR CODE SHOULD TYPE-CHECK AND RUN PRINTING THE VALUE prime4

Try to find valid values of type Prime[_N] when N is not a prime number.

What the hell is going on ???

To encode properties of natural numbers, we first need to encode natural numbers themselves. To do so, we associate a type to every natural number. Natural numbers can be constructed by starting from the first one, 0, and creating new ones by successively adding 1. For example, 1 = 0 + 1, 2 = 0 + 1 + 1, 3 = 0 + 1 + 1 + 1, etc. Our encoding mimics this construction, starting from a base type encoding 0 and an operation "+ 1". The number 0 is encoded as the type _0, defined as the final abstract class _0, and the operation "+ 1" as the type function final abstract class S[N <: Nat], which, for every type N encoding a natural number n, gives the type S[N] encoding the natural number n + 1. The type Nat can simply be defined as the alias type Nat = _0 | S[?] because a natural number is either 0 or obtained by some + 1 operations.

type Nat = _0 | S[?]
final abstract class _0
final abstract class S[N <: Nat]

type _1 = S[_0]
type _2 = S[_1] // S[S[_0]]
type _3 = S[_2] // S[S[S[_0]]]
type _4 = S[_3] // S[S[S[S[_0]]]]
type _5 = S[_4] // S[S[S[S[S[_0]]]]]
...

The next step is to define the type Prime[N] such that:

There exists a valid value of type Prime[N] if and only if N is (the type associated to) a prime number.

Proving that a Natural Number is Prime

Let n be a natural number and N its associated type (for example n=3 and N = _3). Then:

n is prime if and only if n ≥ 2 and, for every natural number m such that 2 ≤ m < n, m does not divide n.

The type ForAll[X, N] encodes this exact property. There exists a value of type ForAll[X,N] if and only if both:

  • X ≤ N
  • For all M such that X ≤ M < N, M does not divide N

Actually the type Prime[N] is an alias for ForAll[_2, N]. We need to encode two more properties:

  • For I and J two natural numbers, the property that I is less than or equal to J (I ≤ J). It is encoded as the type LessThan[I, J].
  • For I and J two natural numbers, the property that I does not divide J. It is encoded as the type NotDiv[I, J].

Read the file PrimeN.scala carefully: each step is described and explained in great detail.

Conclusion

Why ask you to build a value that does not exist? Because the main point is not knowing whether there exists a value of type Prime[_4] but understanding why such a value (following all the rules!) cannot exist!

It is widely known and accepted in programming culture that every type has values. After all, types exist only to qualify values, right? And instantiating a type T is as simple as calling new! There is one huge problem with this claim: it is completely wrong! The idea that a type can have no value, often called an empty type or uninhabited type, is the cornerstone of many techniques, including logic, maths, programming with rich types, formal systems like Coq, etc.

This example is indeed both complicated and complex. It is neither a regular usage of GADTs nor something meant for production! It’s perfectly OK to be confused about it or not to understand what is going on. As I said, it is a complicated and complex example!! But once you manage to understand it, consider that you have mastered the subject.

Recursion Schemes: the high-school introduction

I gave a talk about this on Thursday, March 28th 2019, at the 96th session of the Paris Scala User Group. The slides are here.

Recursion schemes are said to be a tough subject. Articles and presentations often flood the audience with lots of names such as Algebra, CoAlgebra, catamorphisms, anamorphisms, hylomorphisms, etc. Is knowing all these concepts required to understand recursion schemes? I have good news for you: it isn’t! All you need to see what recursion schemes are, and why they are useful, can be presented with just a single basic function often taught as an introduction to programming: factorial. I’m glad to welcome you to the high-school introduction to recursion schemes 😉.

Learning Recursion Schemes

Before diving into the subject, let’s take a moment to contextualize. Recursion schemes, like most advanced functional programming techniques, are almost never taught in programming courses or books. It means there is a strong chance the subject, and the concepts it relies upon, are totally new to you. I want you to remember you haven’t learnt programming in one day, and you probably did not start learning programming by implementing a distributed streaming application over a Spark cluster from scratch. Like most of us, you probably started by coding some sort of Hello, World!. Let’s face it, real business applications are a lot more complex than this. Can you imagine what a first introduction to programming would be if, instead of asking people to write a simple Hello, World!, we asked them to write a real state-of-the-art large-scale business application that meets all the requirements we expect in production nowadays? Learning takes time! Start with toy examples that are indeed far from real-world cases but enable you to grow your understanding, one step at a time.

The examples below are indeed toy examples. When I develop with recursion schemes, like any specialist in any field, I use specialist techniques and vocabulary (you know, the usual vocabulary from category and type theory). But if you’re reading this, it probably means you’re not a recursion-scheme specialist yet. Using complex words flatters our ego, which is very enjoyable, but developing a deep understanding of these notions is far better! So let’s put our ego aside for a moment and agree to start with the basics.

In Recursion Schemes, there is Recursion

First of all, let me introduce the famous factorial function. It is defined on non-negative numbers n as the product of all numbers between 1 and n inclusive (so fact(0) = 1, the empty product):

$$fact(n) = 1 \times 2 \times 3 \times \cdots \times n$$

To ease the presentation we will take Int as the type of non-negative integers. Obviously in production code negative values should be handled appropriately but for simplicity’s sake, we will define fact in Scala and in Haskell as

def fact(n: Int): Int =
  if (n == 0) 1
  else {
    val r = fact(n-1)
    n * r
  }
fact :: Int -> Int
fact 0 = 1
fact n = let r = fact (n - 1)
         in n * r

Factorial is written here as a recursive function. As you probably know, it can also be written as an iterative one (using a for or while loop) but the subject of this article is Recursion Schemes, not Iterative Schemes, so let’s use recursion. This function computes fact(2) as follows:

  • fact(2) = 2 * fact(1) so it needs to compute fact(1)
  • fact(1) = 1 * fact(0) so it needs to compute fact(0)
  • fact(0) = 1
  • now that the result of fact(0) is known, it can replace the call of fact(0) by its result which gives fact(1) = 1 * fact(0) = 1 * 1 = 1
  • now that the result of fact(1) is known, it can replace the call of fact(1) by its result which gives fact(2) = 2 * fact(1) = 2 * 1 = 2.

Look at how fact(n) calls itself: if n = 0 then it doesn’t call itself, otherwise it calls itself on n - 1. Let’s split this definition into two parts: the first one contains all the code relevant to how fact calls itself, and only that code; the second one is made of the rest. There is no clear rule for what is relevant and what is not. Different splits may work; they will just give rise to different schemes, which is not a problem at all. You just need to find one that fits your needs.

For fact, the key element to note is that it does not call itself when n = 0, but otherwise calls itself with n - 1. The constant returned in the n = 0 case and the operation done in the other one have no impact on how fact recurses. So I choose to split it by taking all code not relevant to recursion out of its body:

/* Part irrelevant to recursion:
 * The real definitions of these variables
 * have no impact on how fact is calling itself
 */
val baseCase: Int = 1
def recCase(n: Int, r: Int): Int = n * r

/* Recursion-only part:
 * The only implementation details it contains
 * are about how fact is calling itself
 */
def fact(n: Int): Int =
  if (n == 0) baseCase
  else {
    val r = fact(n-1)
    recCase(n, r)
  }
{-
 Part irrelevant to recursion:
 The real definitions of these variables
 have no impact on how fact is calling itself
-}
baseCase :: Int
baseCase  = 1

recCase :: Int -> Int -> Int
recCase n r = n * r
 
{-
 Recursion-only part:
 The only implementation details it contains
 are about how fact is calling itself
 -}
fact :: Int -> Int
fact 0 = baseCase
fact n = let r = fact (n-1)
         in recCase n r

Let me introduce another function, also defined on non-negative numbers n, but which this time computes the sum of all numbers between 1 and n inclusive:

$$sum(n) = 1 + 2 + 3 + \cdots + n$$
def sum(n: Int): Int =
  if (n == 0) 0
  else {
    val r = sum(n-1)
    n + r
  }
sum :: Int -> Int
sum 0 = 0
sum n = let r = sum (n - 1)
        in n + r

We can apply the same technique to sum: splitting the definition into two parts, one containing all but only recursion-relevant code, and the other the rest. It gives:

/* Part irrelevant to recursion:
 * The real definitions of these variables
 * have no impact on how sum recurs
 */
val baseCase: Int = 0
def recCase(n: Int, r: Int): Int = n + r

/* Recursion-only part:
 * The only implementation details it contains
 * are about how sum recurs
 */
def sum(n: Int): Int =
  if (n == 0) baseCase
  else {
    val r = sum(n-1)
    recCase(n, r)
  }
{-
 Part irrelevant to recursion:
 The real definitions of these variables
 have no impact on how sum is calling itself
-}
baseCase :: Int
baseCase = 0

recCase :: Int -> Int -> Int
recCase n r = n + r
 
{-
 Recursion-only part:
 The only implementation details it contains
 are about how sum is calling itself
 -}
sum :: Int -> Int
sum 0 = baseCase
sum n = let r = sum (n-1)
        in recCase n r

Do you see how similar the recursion-relevant parts of sum and fact are? They are actually identical! It means fact and sum have the same recursion structure. The recursion-irrelevant parts differ: the constant baseCase is 1 in fact but 0 in sum, and the operation recCase is n * r in fact but n + r in sum. Note that if we replace, in each case, occurrences of baseCase and recCase by their definitions, we get back the original functions. Look at the common recursion-relevant part:

def commonRecursiveRelevantPart(n: Int): Int =
  if (n == 0) baseCase
  else {
    val r = commonRecursiveRelevantPart(n-1)
    recCase(n, r)
  }
commonRecursiveRelevantPart :: Int -> Int
commonRecursiveRelevantPart 0 = baseCase
commonRecursiveRelevantPart n = let r = commonRecursiveRelevantPart (n-1)
                                in recCase n r

Obviously, for this code to be correct, baseCase and recCase have to be defined. Let’s fix this by taking them as arguments:

def scheme(baseCase: Int, recCase: (Int, Int) => Int): Int => Int = {
  def commonRecursiveRelevantPart(n: Int): Int =
    if (n == 0) baseCase
    else {
      val r = commonRecursiveRelevantPart(n-1)
      recCase(n, r)
    }
  
  commonRecursiveRelevantPart
}
scheme :: Int -> (Int -> Int -> Int) -> Int -> Int
scheme baseCase recCase = commonRecursiveRelevantPart
  where
    commonRecursiveRelevantPart :: Int -> Int
    commonRecursiveRelevantPart 0 = baseCase
    commonRecursiveRelevantPart n = let r = commonRecursiveRelevantPart (n-1)
                                    in recCase n r

It is then trivial to define both fact and sum by feeding scheme with corresponding definitions for baseCase and recCase:

def fact: Int => Int = scheme(1, (n: Int, r:Int) => n * r)
def sum : Int => Int = scheme(0, (n: Int, r:Int) => n + r)
fact :: Int -> Int
fact = scheme 1 (*)

sum :: Int -> Int
sum = scheme 0 (+)
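As a quick sanity check, here is a minimal self-contained sketch of the recursive scheme and its two clients, wrapped in a hypothetical object SchemeSketch so it compiles as a single Scala file. The extra client sumOfSquares is not from the text; it is only there to illustrate that one scheme can serve several functions.

```scala
object SchemeSketch {
  // Recursion-only part: the business logic comes in through baseCase and recCase.
  def scheme(baseCase: Int, recCase: (Int, Int) => Int): Int => Int = {
    def commonRecursiveRelevantPart(n: Int): Int =
      if (n == 0) baseCase
      else recCase(n, commonRecursiveRelevantPart(n - 1))

    commonRecursiveRelevantPart
  }

  val fact: Int => Int = scheme(1, _ * _)
  val sum: Int => Int = scheme(0, _ + _)

  // Hypothetical extra client: 1*1 + 2*2 + ... + n*n, same recursion structure.
  val sumOfSquares: Int => Int = scheme(0, (n, r) => n * n + r)
}
```

Only the two arguments change from one client to the next; the recursion itself is written once.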

We can now give a first answer to how recursion schemes can be useful: they let us write less code, and that code is both simpler and safer. But there is more! Recursive calls, like any function calls, consume the stack. If there are too many recursive calls (i.e. when n is too big), there is a risk of stack overflow. Some languages like Scala are smart enough to avoid this problem in some cases by transforming tail-recursive functions into iterative loops. Unfortunately not all recursive functions are tail-recursive. Rewriting every recursive function as an iterative loop is not the solution either, since it is intricate and error-prone. Fortunately, we only need to write the recursion scheme as a loop once:

def scheme(baseCase: Int, recCase: (Int, Int) => Int)(n: Int): Int = {
  var res = baseCase
  var i: Int = 1
  while (i <= n) {
    res = recCase(i, res)
    i += 1
  }
  res
}
scheme :: Int -> (Int -> Int -> Int) -> Int -> Int
scheme baseCase recCase n = aux baseCase 1
  where
    aux res i = if i <= n
                then aux (recCase i res) (i + 1)
                else res
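Since Scala compiles self tail calls into loops, the while loop above can also be written as a tail-recursive helper, exactly like the Haskell aux. This is a sketch under the same Int assumptions, in a hypothetical object TailRecSchemeSketch; the @tailrec annotation (from scala.annotation) makes the compiler reject the helper if its recursive call is not in tail position.

```scala
import scala.annotation.tailrec

object TailRecSchemeSketch {
  // Same contract as the while-loop scheme: applies recCase for i = 1..n, starting from baseCase.
  def scheme(baseCase: Int, recCase: (Int, Int) => Int)(n: Int): Int = {
    // Compilation fails here if aux's recursive call is not a tail call,
    // so aux is guaranteed to be compiled into a loop.
    @tailrec
    def aux(res: Int, i: Int): Int =
      if (i <= n) aux(recCase(i, res), i + 1)
      else res

    aux(baseCase, 1)
  }

  def fact(n: Int): Int = scheme(1, _ * _)(n)

  // Counts the steps: a deep input that runs in constant stack space.
  def count(n: Int): Int = scheme(0, (_, r) => r + 1)(n)
}
```

count(1000000) runs in constant stack space, whereas the non-tail-recursive scheme would very likely overflow the default JVM stack on such an input.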

Note that the scheme is usually simpler to write, as it only focuses on recursion, not business logic. Furthermore, one scheme may fit many functions, thus reducing complexity and bugs when writing business functions. Remember that fact and sum are purposely trivial: they are just toy examples to introduce the subject. In practice you will write much more complex recursive functions. Once you’ve understood this example, you’ll be able to scale this technique to any recursive function, however complex it is.

Scaling up!

To be sure we have a good understanding of the technique, let’s apply it to the Fibonacci function we all love. It is defined on non-negative integers by

$$fib(0) = 1$$
$$fib(1) = 1$$
$$fib(n+2) = fib(n+1) + fib(n)$$
def fib(n: Int): Int =
  n match {
    case 0 => 1
    case 1 => 1
    case n =>
      val r1 = fib(n-1)
      val r2 = fib(n-2)
      r1 + r2
  }
fib :: Int -> Int
fib 0 = 1
fib 1 = 1
fib n = r1 + r2
  where
    r1 = fib (n - 1)
    r2 = fib (n - 2)

The function fib does not call itself when n is 0 or 1, but otherwise calls itself twice, on n-1 and n-2. So we can, like fact and sum, split fib into two pieces: one containing only recursion-relevant code and the other one the rest. Once again the split is done by taking recursion-irrelevant code out of the function’s body. Remember there are many ways to split it up; this is just one sensible way of doing so:

/* Part irrelevant to recursion:
 * The real definitions of these variables
 * have no impact on how fib calls itself
 */
val baseCase0: Int = 1
val baseCase1: Int = 1
def recCase(r1: Int, r2: Int): Int = r1 + r2

/* Recursion-only part:
 * The only implementation details it contains
 * are about how fib calls itself
 */
def fib(n: Int): Int =
  n match {
    case 0 => baseCase0
    case 1 => baseCase1
    case n =>
      val r1 = fib(n-1)
      val r2 = fib(n-2)
      recCase(r1, r2)
  }
{-
 Part irrelevant to recursion:
 The real definitions of these variables
 have no impact on how fib calls itself
-}
baseCase0 :: Int
baseCase0 = 1

baseCase1 :: Int
baseCase1 = 1

recCase :: Int -> Int -> Int
recCase r1 r2 = r1 + r2

{-
 Recursion-only part:
 The only implementation details it contains
 are about how fib calls itself
 -}
fib :: Int -> Int
fib 0 = baseCase0
fib 1 = baseCase1
fib n = recCase r1 r2
  where
    r1 = fib (n - 1)
    r2 = fib (n - 2)

Which leads to the recursion scheme:

def scheme(baseCase0: Int, baseCase1: Int, recCase: (Int, Int) => Int)(n: Int): Int =
  n match {
    case 0 => baseCase0
    case 1 => baseCase1
    case n =>
      val r1 = scheme(baseCase0, baseCase1, recCase)(n-1)
      val r2 = scheme(baseCase0, baseCase1, recCase)(n-2)
      recCase(r1, r2)
  }
scheme :: Int -> Int -> (Int -> Int -> Int) -> Int -> Int
scheme baseCase0 baseCase1 recCase = aux
  where
    aux 0 = baseCase0
    aux 1 = baseCase1
    aux n = recCase r1 r2
      where
        r1 = aux (n - 1)
        r2 = aux (n - 2)

It is then trivial to define fib by giving appropriate definitions to the scheme’s arguments: baseCase0, baseCase1 and recCase.

def fib: Int => Int = scheme(1, 1, (r1: Int, r2: Int) => r1 + r2)
fib :: Int -> Int
fib = scheme 1 1 (+)

Once again this implementation is not optimal: each call to fib can make 2 recursive calls, which leads to an exponential time complexity. While computing fib(5) is fast, computing fib(1000) this way would not finish in any reasonable time (even ignoring that the result overflows an Int). As you have probably already guessed, writing the recursion scheme as an iterative loop, which sadly makes it more intricate, solves the problem:

def scheme(baseCase0: Int, baseCase1: Int, recCase: (Int, Int) => Int)(n: Int): Int =
  if (n == 0) baseCase0
  else {
    var b0 = baseCase0
    var b1 = baseCase1
    var i = 2
    while (i <= n) {
      val b2 = recCase(b1, b0) // same order as (r1, r2): result for i-1, then result for i-2
      b0 = b1
      b1 = b2
      i += 1
    }
    b1
  }
scheme :: Int -> Int -> (Int -> Int -> Int) -> Int -> Int
scheme baseCase0 baseCase1 recCase 0 = baseCase0
scheme baseCase0 baseCase1 recCase n = aux baseCase0 baseCase1 2
  where
    aux b0 b1 i = if i <= n
                  then aux b1 (recCase b1 b0) (i + 1)
                  else b1
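To see both claims concretely, here is a sketch (hypothetical object FibSchemeCheck) that instruments the recursive scheme with a call counter and compares it with the loop. One subtlety: the loop must hand recCase its arguments in the same order as the recursive version, (r1, r2) = (result for n-1, result for n-2). With + the order is invisible, so the last check below uses the non-commutative recCase (r1, r2) => 2 * r1 - r2.

```scala
object FibSchemeCheck {
  // Hypothetical instrumentation: counts invocations of recursiveScheme.
  private var calls: Int = 0

  def recursiveScheme(baseCase0: Int, baseCase1: Int, recCase: (Int, Int) => Int)(n: Int): Int = {
    calls += 1
    n match {
      case 0 => baseCase0
      case 1 => baseCase1
      case _ =>
        val r1 = recursiveScheme(baseCase0, baseCase1, recCase)(n - 1)
        val r2 = recursiveScheme(baseCase0, baseCase1, recCase)(n - 2)
        recCase(r1, r2)
    }
  }

  // Number of invocations needed to compute fib(n) with the recursive scheme.
  def countCalls(n: Int): Int = {
    calls = 0
    recursiveScheme(1, 1, _ + _)(n)
    calls
  }

  // Loop version: b0 holds the result for i-2, b1 the result for i-1.
  def loopScheme(baseCase0: Int, baseCase1: Int, recCase: (Int, Int) => Int)(n: Int): Int =
    if (n == 0) baseCase0
    else {
      var b0 = baseCase0
      var b1 = baseCase1
      var i = 2
      while (i <= n) {
        val b2 = recCase(b1, b0) // (r1, r2) order of the recursive version
        b0 = b1
        b1 = b2
        i += 1
      }
      b1
    }
}
```

With the indexing fib(0) = fib(1) = 1, the recursive scheme makes exactly 2 * fib(n) - 1 calls (by induction on n), so computing fib(20) = 10946 already takes 21891 calls, whereas the loop performs only n - 1 iterations.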

By now you should have a good grasp of what recursion schemes are. But we have only seen a tiny fraction of how useful they are. It’s about time to consider the real power of fact, sum and fib’s schemes.

Time to take off!

Previously we defined fact and sum’s schemes as

def scheme(baseCase: Int, recCase: (Int, Int) => Int)(n: Int): Int = {
  var res = baseCase
  var i: Int = 1
  while (i <= n) {
    res = recCase(i, res)
    i += 1
  }
  res
}
scheme :: Int -> (Int -> Int -> Int) -> Int -> Int
scheme baseCase recCase n = aux baseCase 1
  where
    aux res i = if i <= n
                then aux (recCase i res) (i + 1)
                else res

I have a small exercise for you: where does this code rely on baseCase being an Int? It’s important, so take the time to figure it out. The answer is simple: it does not! baseCase can actually be of any type A! We don’t even have to modify the code (only the type signature):

def scheme[A](baseCase: A, recCase: (Int, A) => A)(n: Int): A = {
  var res = baseCase
  var i: Int = 1
  while (i <= n) {
    res = recCase(i, res)
    i += 1
  }
  res
}
scheme :: a -> (Int -> a -> a) -> Int -> a
scheme baseCase recCase n = aux baseCase 1
  where
    aux res i = if i <= n
                then aux (recCase i res) (i + 1)
                else res

Not only can we still define fact (and sum) as above, but it also makes it trivial to define functions like list, which returns the list of integers from n down to 1:

def list: Int => List[Int] = scheme[List[Int]](Nil, (n: Int, r: List[Int]) => n :: r)
list :: Int -> [Int]
list = scheme [] (:)
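The generic scheme can be exercised at a few types. This sketch (hypothetical object GenericSchemeSketch) reuses the loop with A = List[Int] for list, and with A = BigInt for a factorial that no longer overflows: with Int, fact(13) already exceeds Int.MaxValue.

```scala
object GenericSchemeSketch {
  // The generic loop scheme: nothing in the body depends on A.
  def scheme[A](baseCase: A, recCase: (Int, A) => A)(n: Int): A = {
    var res = baseCase
    var i: Int = 1
    while (i <= n) {
      res = recCase(i, res)
      i += 1
    }
    res
  }

  // A = List[Int]: builds n :: (n-1) :: ... :: 1 :: Nil
  def list: Int => List[Int] = scheme[List[Int]](Nil, (n: Int, r: List[Int]) => n :: r)

  // A = BigInt: arbitrary-precision factorial (the Int n is widened to BigInt)
  def bigFact: Int => BigInt = scheme[BigInt](BigInt(1), (n: Int, r: BigInt) => r * n)
}
```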

Unsurprisingly, fib’s recursion scheme can also be generalized without changing a single line of code (only the type signature):

def scheme[A](baseCase0: A, baseCase1: A, recCase: (A, A) => A)(n: Int): A =
  if (n == 0) baseCase0
  else {
    var b0 = baseCase0
    var b1 = baseCase1
    var i = 2
    while (i <= n) {
      val b2 = recCase(b1, b0) // same order as (r1, r2): result for i-1, then result for i-2
      b0 = b1
      b1 = b2
      i += 1
    }
    b1
  }
scheme :: a -> a -> (a -> a -> a) -> Int -> a
scheme baseCase0 baseCase1 recCase 0 = baseCase0
scheme baseCase0 baseCase1 recCase n = aux baseCase0 baseCase1 2
  where
    aux b0 b1 i = if i <= n
                  then aux b1 (recCase b1 b0) (i + 1)
                  else b1

While fact’s scheme is related to lists, fib’s is related to trees:

sealed abstract class Tree[+A]
final case class Leaf[+A](value: A) extends Tree[A]
final case class Node[A](left: Tree[A], right: Tree[A]) extends Tree[A]

def tree: Int => Tree[Boolean] =
  scheme(
    Leaf(false),
    Leaf(true),
    (r1: Tree[Boolean], r2: Tree[Boolean]) => Node(r1,r2)
  )
data Tree a = Leaf a | Node (Tree a) (Tree a)

tree :: Int -> Tree Bool
tree = scheme (Leaf False) (Leaf True) Node
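One way to see the link between fib and trees: the tree produced for n has exactly fib(n) leaves, since a Node’s leaves are those of its two subtrees, just as fib(n) = fib(n-1) + fib(n-2). Here is a sketch under a few assumptions: the names FibTreeSketch and leaves are hypothetical, Node is declared covariant like Leaf, and the loop passes (result for i-1, result for i-2) to recCase, matching (r1, r2) in the recursive fib. Using A = BigInt also lets fib go well past the range of Int.

```scala
object FibTreeSketch {
  sealed abstract class Tree[+A]
  final case class Leaf[+A](value: A) extends Tree[A]
  final case class Node[+A](left: Tree[A], right: Tree[A]) extends Tree[A]

  // Generic loop scheme: b0 holds the result for i-2, b1 the result for i-1.
  def scheme[A](baseCase0: A, baseCase1: A, recCase: (A, A) => A)(n: Int): A =
    if (n == 0) baseCase0
    else {
      var b0 = baseCase0
      var b1 = baseCase1
      var i = 2
      while (i <= n) {
        val b2 = recCase(b1, b0)
        b0 = b1
        b1 = b2
        i += 1
      }
      b1
    }

  def fib: Int => BigInt = scheme[BigInt](BigInt(1), BigInt(1), _ + _)

  def tree: Int => Tree[Boolean] =
    scheme[Tree[Boolean]](Leaf(false), Leaf(true), (r1, r2) => Node(r1, r2))

  // Hypothetical helper: number of leaves of a tree.
  def leaves[A](t: Tree[A]): BigInt =
    t match {
      case Leaf(_)    => BigInt(1)
      case Node(l, r) => leaves(l) + leaves(r)
    }
}
```

For instance tree(2) is Node(Leaf(true), Leaf(false)), and leaves(tree(n)) == fib(n) for every n.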

I have a few real exercises for you this time:

  • find several spots in your production code where this scheme could be useful.
  • write schemes, as general as possible, for at least 5 recursive functions in your production code.

Obviously I won’t check that you did the exercises, but you should really do them. Reading is not sufficient to develop your understanding of the technique; you need to experiment! Try things, play with these notions until it clicks. Learning recursion schemes is like going on an expedition: preparation may seem the easier part, but if you did not prepare well enough, you’ll get lost.

Yeah! Buzzwords!

As we have seen, fact’s scheme takes 2 arguments:

def scheme[A](baseCase: A, recCase: (Int, A) => A): Int => A
scheme :: a -> (Int -> a -> a) -> Int -> a

While this definition is perfectly fine, we can group these arguments in any structure that can hold both values, such as a pair, an interface or a trait:

trait FactorialSchemeArguments[A] {
  val baseCase: A
  def recCase(n: Int, r: A): A
}

def scheme[A](arguments: FactorialSchemeArguments[A]): Int => A
class FactorialSchemeArguments a where
  baseCase :: a
  recCase :: Int -> a -> a

scheme :: FactorialSchemeArguments a => Int -> a
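Bundling the arguments does not change the loop at all. Here is a sketch of a full definition, in a hypothetical object SchemeArgumentsSketch, with a hypothetical FactArguments bundle that reproduces fact.

```scala
object SchemeArgumentsSketch {
  trait FactorialSchemeArguments[A] {
    val baseCase: A
    def recCase(n: Int, r: A): A
  }

  // Same loop as before; it only reads baseCase and recCase from the bundle.
  def scheme[A](arguments: FactorialSchemeArguments[A]): Int => A = { (n: Int) =>
    var res = arguments.baseCase
    var i = 1
    while (i <= n) {
      res = arguments.recCase(i, res)
      i += 1
    }
    res
  }

  // Hypothetical argument bundle for fact.
  object FactArguments extends FactorialSchemeArguments[Int] {
    val baseCase: Int = 1
    def recCase(n: Int, r: Int): Int = n * r
  }

  val fact: Int => Int = scheme(FactArguments)
}
```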

Note that scheme is still the same: it still takes the same two arguments. But even if the code didn’t change, this transformation makes us see scheme from a different perspective. It shows scheme as a function transforming an integer into an A, provided that we give some structure to A: a constant baseCase and an operation recCase. Let’s give this structure and the scheme names: I decided to call the structure an AkolovioaAlgebra (don’t look for it in the literature, I just coined the term), with baseCase renamed initial and recCase renamed action, and the scheme an akolovioaMorphism:

trait AkolovioaAlgebra[A] {
  val initial: A
  def action(n: Int, r: A): A
}

def akolovioaMorphism[A: AkolovioaAlgebra]: Int => A
class AkolovioaAlgebra a where
  initial :: a
  action :: Int -> a -> a

akolovioaMorphism :: AkolovioaAlgebra a => Int -> a

This looks smart, doesn’t it? 😉 It is actually very close to a very common structure in programming! Can you find which one? Obviously the same can be done for Fibonacci’s scheme. As an exercise, apply this technique to Fibonacci’s scheme and give the resulting structure and scheme pretty names.
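In the Scala version, the context bound A: AkolovioaAlgebra means the algebra is passed implicitly, and implicitly retrieves it inside the body. Here is a sketch of akolovioaMorphism with the same loop as before and a hypothetical instance listAlgebra that reproduces the list function (everything lives in a hypothetical object AkolovioaSketch).

```scala
object AkolovioaSketch {
  trait AkolovioaAlgebra[A] {
    val initial: A
    def action(n: Int, r: A): A
  }

  // The context bound adds an implicit AkolovioaAlgebra[A] parameter.
  def akolovioaMorphism[A: AkolovioaAlgebra]: Int => A = { (n: Int) =>
    val algebra = implicitly[AkolovioaAlgebra[A]]
    var res = algebra.initial
    var i = 1
    while (i <= n) {
      res = algebra.action(i, res)
      i += 1
    }
    res
  }

  // Hypothetical instance reproducing the list function.
  implicit val listAlgebra: AkolovioaAlgebra[List[Int]] = new AkolovioaAlgebra[List[Int]] {
    val initial: List[Int] = Nil
    def action(n: Int, r: List[Int]): List[Int] = n :: r
  }

  def list: Int => List[Int] = akolovioaMorphism[List[Int]]
}
```

Note that akolovioaMorphism[List[Int]](3) would not work: the (3) would be taken as the implicit argument. That is why the sketch goes through the list definition (calling .apply(3) explicitly would work too).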

Where to go from here?

As you know, this is not the end of the story: the subject is closely related to pervasive notions such as (co)algebras, inductive types, categories, initial objects, fixed points, algebraic data types, etc. Whichever subject you choose to dive into next, the approach this article follows, i.e. experimenting on toy examples, really helps develop a solid understanding. I want you to realize that each definition you read in books, articles, talks, etc. is the result of people experimenting. The common trap in this field is looking at definitions as sacred pieces of unquestionable truth no mortal can see through. It is actually the exact opposite! Science is by essence experimentation. It is by investigating and trying things that you end up figuring out how things work. But, as in science, for your investigation to be productive, your tests need to be done in a controlled environment with as few variables as possible, so that it is easy for you to see what’s going on. That’s why toy examples are so important: they contain the essence of what makes things work without all the noise real examples have.

Take care and spread recursion schemes around 😉

Demystifying GADTs

Generalized Algebraic Data Types (GADTs) are certainly one of the most feared concepts in programming nowadays. Very few mainstream languages support GADTs. The only ones I know of that do are Haskell, Scala, OCaml and Haxe. The idea is actually very simple but often presented in complicated ways. In fact, if you’re familiar with both basic Object-Oriented-with-Generics and basic functional programming, then you are most probably already familiar with GADTs without even knowing it. But if GADTs are so simple, why do so many people feel terrified by them? Well, GADTs rely on two fundamental ideas: one of them is known by every Object-Oriented-with-Generics programmer while the other is known by every functional programmer. The problem is that most people make the huge mistake of opposing them even though they are complementary. So before diving into GADTs, let me remind you of these elementary notions from Object-Oriented and functional programming.

Object-Oriented Programming 101

Let’s start with some plain old Java (the examples work in probably every Object-Oriented language that supports generics). We want to define an abstract class for sequences:

public abstract class Sequence<A> {
  abstract public int length();
  abstract public A   getNthElement(int nth);
}

In Java it would be better to define Sequence<A> as an interface, but I want this example to be as simple as possible. Would you be surprised if I told you a String is a sequence of characters? ;) As I said, GADTs rely on very basic programming knowledge.

public class MyString extends Sequence<Character> {
  private String str;

  public MyString(String s) {
    this.str = s;
  }

  public int length() {
    return this.str.length();
  }

  public Character getNthElement(int nth) {
    return this.str.charAt(nth);
  }
}

Likewise, bytes are sequences of 8 bits (we represent a bit by a Boolean):

public final class MyByte extends Sequence<Boolean> {
  private byte bte;

  public MyByte(byte x) {
    this.bte = x;
  }

  public int length() {
    return 8;
  }

  public Boolean getNthElement(int nth) {
    if (nth >= 0 && nth <= 7)
      return ((bte >>> nth & 1) == 1);
    else
      throw new java.lang.IndexOutOfBoundsException("");
  }
}

Have you noticed how MyByte and MyString declare themselves to be, respectively, a sequence of booleans (Sequence<Boolean>) and a sequence of characters (Sequence<Character>), but not sequences of A (Sequence<A>) for any type A? Let’s try to make it work for any type A:

public final class MyByte<A> extends Sequence<A> {
  private byte bte;

  public MyByte(byte x) {
    this.bte = x;
  }

  public int length() {
    ???
  }

  public A getNthElement(int nth) {
    ???
  }
}

How would you write the methods length and getNthElement? Can you really imagine what a MyByte<Graphics2D> would be? It just doesn’t make any sense at all. You could argue that a string is also a sequence of bytes, and a byte a sequence of one byte. Indeed this relation is not unique, but it does not change the fact that it works only for a small selection of types A, not for every one! We can go even deeper in Object-Oriented Programming:

public final class MyArray<A extends Number> extends Sequence<Number> {
  private A[] array;

  public MyArray(A[] a) {
    this.array = a;
  }

  public int length() {
    return this.array.length;
  }

  public Number getNthElement(int nth) {
    return this.array[nth];
  }
}

Note how the type parameter A, which is required to be a sub-class of Number, is present as a parameter of MyArray but not in extends Sequence<Number>. Now what do you think about this code? Do you think it can be wrong?

public static <A> void guess(Sequence<A> x) {
  if (x instanceof MyByte) {
    System.out.println("I guess A is actually Boolean, let's check!");
    System.out.println(((Sequence<Boolean>)x).getNthElement(0).getClass().getName());
  } else
  if (x instanceof MyString) {
    System.out.println("I guess A is actually Character");
    System.out.println(((Sequence<Character>)x).getNthElement(0).getClass().getName());
  } else
  if (x instanceof MyArray<?>) {
    System.out.println("I guess A is a sub-class of Number but i can not guess which one");
    System.out.println(((Sequence<?>) x).getNthElement(0).getClass().getName());
  }
  else
   System.out.println("I don't know what A is");
}
  • If x is an instance of MyByte, which is a sub-class of Sequence<Boolean>, then by trivial inheritance x is also an instance of Sequence<Boolean>. In this case A is forced to be Boolean.
  • If x is an instance of MyString, which is a sub-class of Sequence<Character>, then again by trivial inheritance x is also an instance of Sequence<Character>. In this case A has to be Character.
  • If x is an instance of MyArray<A> for some type A, which is a sub-class of Sequence<Number>, then once again by trivial inheritance x is an instance of Sequence<Number>. In this case we know A is a sub-class of Number but we don’t know which one.

This is the essence of Generalized Algebraic Data Types. If you understand the code above, then you understand how GADTs work. As you can see, this is very basic Object-Oriented programming with Generics. You can find lots of examples of this kind in almost every Java/C#/etc. project (search for the instanceof keyword).

Functional Programming 101

Functional languages often support a feature called Algebraic Data Types (ADTs), which are essentially enumerations on steroids. Like an enumeration, an ADT is a disjoint union of a fixed number of cases but, unlike enumerations where each case is a constant, ADT cases can have parameters. As an example, the type of lists whose elements are of type a, written List a in Haskell, is defined as:

data List a = Nil | Cons a (List a)

It means any value of type List a belongs to exactly one of the following cases:

  • either the value is the constant Nil which represents the empty list.
  • or the value is Cons hd tl, which represents the list whose first element is hd (of type a) and whose tail is tl (of type List a).

The list [1,2,3,4] is encoded by Cons 1 (Cons 2 (Cons 3 (Cons 4 Nil))). Where ADTs really shine is pattern-matching, which is a very powerful and flexible switch (as I said above, ADTs are enumerations on steroids). Since ADTs are made of a fixed number of distinct cases, pattern-matching enables us to inspect values and perform computations based on a case-by-case analysis of the form of the value. Here is how to implement the merge sort algorithm on this type:

split :: List a -> (List a, List a)
split l =
  case l of
    Cons x (Cons y tl) -> (case split tl of
                              (l1, l2) -> (Cons x l1, Cons y l2)
                          )
    Cons _ Nil         -> (l  , Nil)
    Nil                -> (Nil, Nil)

merge :: (a -> a -> Bool) -> List a -> List a -> List a
merge isLessThan l1 l2 =
  case (l1, l2) of
    (_           , Nil         ) -> l1
    (Nil         , _           ) -> l2
    (Cons hd1 tl1, Cons hd2 _  ) | hd1 `isLessThan` hd2 ->  Cons hd1 (merge isLessThan tl1 l2)
    (_           , Cons hd2 tl2)                        ->  Cons hd2 (merge isLessThan l1 tl2)

sort :: (a -> a -> Bool) -> List a -> List a
sort isLessThan l =
  case l of
    Nil        -> Nil
    Cons _ Nil -> l
    _          -> case split l of
                    (l1, l2) -> merge isLessThan (sort isLessThan l1) (sort isLessThan l2)

I know there are smarter ways to write it in Haskell, but this article is not about that. The code above could be translated trivially into OCaml by replacing case ... of with match ... with, into Scala with ... match { ... }, etc. This style is valid in probably all languages supporting pattern-matching, so it fits our goal.

The case l of expressions are pattern-matching. They are a sequence of pattern | condition -> code. The code being executed is the right-hand side of the first case for which the value l has the form of its pattern and satisfies the condition. l is then said to match this case. For example, the case Cons x (Cons y tl) -> (case split tl of (l1, l2) -> (Cons x l1, Cons y l2)) states that if l is of the form Cons x (Cons y tl), which means that there are three values x, y and tl such that l == Cons x (Cons y tl), then the code executed is (case split tl of (l1, l2) -> (Cons x l1, Cons y l2)). One very important condition is that pattern-matching must be exhaustive! It means that the sequence of cases must cover all possible values of l.
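As an illustration of how direct the translation is, here is a sketch of the same merge sort in Scala. It assumes Scala’s built-in List, whose constructors are Nil and :: instead of Nil and Cons, and passes isLessThan as an ordinary function argument; MergeSortSketch is a hypothetical name.

```scala
object MergeSortSketch {
  // Deal the elements alternately into two lists.
  def split[A](l: List[A]): (List[A], List[A]) =
    l match {
      case x :: y :: tl =>
        val (l1, l2) = split(tl)
        (x :: l1, y :: l2)
      case _ :: Nil => (l, Nil)
      case Nil      => (Nil, Nil)
    }

  // Merge two sorted lists; the guard plays the role of "| hd1 `isLessThan` hd2".
  def merge[A](isLessThan: (A, A) => Boolean, l1: List[A], l2: List[A]): List[A] =
    (l1, l2) match {
      case (_, Nil)                                       => l1
      case (Nil, _)                                       => l2
      case (hd1 :: tl1, hd2 :: _) if isLessThan(hd1, hd2) => hd1 :: merge(isLessThan, tl1, l2)
      case (_, hd2 :: tl2)                                => hd2 :: merge(isLessThan, l1, tl2)
    }

  def sort[A](isLessThan: (A, A) => Boolean, l: List[A]): List[A] =
    l match {
      case Nil      => Nil
      case _ :: Nil => l
      case _ =>
        val (l1, l2) = split(l)
        merge(isLessThan, sort(isLessThan, l1), sort(isLessThan, l2))
    }
}
```

For example, MergeSortSketch.sort[Int]((a, b) => a < b, List(5, 3, 8, 1, 2)) returns List(1, 2, 3, 5, 8).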

If you understand the previous section, the type List a and how pattern-matching works in the example above, then I am very glad to inform you that you already understand GADTs! Well done :)

Summing up!

In this section I assume the previous sections are clear to you. If you do not understand the previous examples, don’t go further: go back to the basics of generics and pattern-matching. Likewise, if you find what follows complicated, go back to the basics of generics and pattern-matching. There is no shame in doing so! Difficulty understanding an advanced notion often reflects a lack of understanding of the notions it relies upon. As I said, there is no shame in it; if you think programming paradigms are “simple”, then write a compiler ;)

It’s about time to sum everything up. First, note that List a is not one single type. Each type a actually gives rise to a distinct type List a. For example List Int, List String, List (List Bool), etc. are all distinct types. Indeed the list Cons 1 Nil is neither a list of strings nor of booleans! For each type a, the type List a has two constructors: the constant Nil and the function Cons :: a -> List a -> List a, which builds a List a from a value of type a and another List a.

There is another equivalent way to define List a in Haskell which makes the nature of the constructor more apparent:

data List a where
  Nil  :: List a
  Cons :: a -> List a -> List a

Indeed, for each type a, Nil is a constant of type List a while Cons is a function of type a -> List a -> List a. Note that this is actually very close to the way it is defined in Scala:

sealed abstract class List[A]
final case class Nil[A]() extends List[A]
final case class Cons[A](head: A, tail: List[A]) extends List[A]

Do you remember the example Sequence<A> from the first section? There were three sub-classes of Sequence<A>: MyString, which is actually a sub-class of Sequence<Character>, MyByte, which is a sub-class of Sequence<Boolean>, and MyArray<A extends Number>, which is a sub-class of Sequence<Number>. What are the types of their constructors? Admissible types for them are (in Scala notation):

def MyString             : String   => Sequence[Character]
def MyByte               : Byte     => Sequence[Boolean]
def MyArray[A <: Number] : Array[A] => Sequence[Number]

From this, it is trivial to write:

sealed abstract class Sequence[A]
final case class MyString(str: String)                 extends Sequence[Character]
final case class MyByte(bte: Byte)                     extends Sequence[Boolean]
final case class MyArray[A <: Number](array: Array[A]) extends Sequence[Number]

or in Haskell:

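{-# LANGUAGE GADTs, ExplicitForAll #-}
import Data.Word (Word8)

-- Existential wrapper: a value of some unknown type a that has a Num instance
data Number where
  MkNum :: forall a. Num a => a -> Number

data Sequence a where
  MyString :: String -> Sequence Char
  MyByte   :: Word8  -> Sequence Bool
  MyArray  :: forall a. Num a => List a -> Sequence Number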

Sequence is a GADT. What makes it different from List above? For any type a, values of type List a are built using the two constructors Nil and Cons. Note that this does not depend on what a is: values of type List Int are built using the exact same constructors as List Bool, List String, List (List Char), etc. Sequence has three constructors: MyString, MyByte and MyArray. But values of type Sequence[Character] can only be built by the constructor MyString, while values of type Sequence[Boolean] can only be built by the constructor MyByte and values of type Sequence[Number] can only be built by the constructor MyArray. What about values of type Sequence[Unit] or Sequence[String], etc.? There is simply no constructor to build values of these types, so there are no values of these types!

We can rewrite the methods on Sequence and the guess function to use pattern-matching:

def length[A](x: Sequence[A]): Int =
  x match {
    case MyByte(_)     => 8
    case MyString(str) => str.length
    case MyArray(arr)  => arr.size
  }

def getNthElement[A](x: Sequence[A], nth: Int): A =
  x match {
    case MyByte(bte) => // So A is actually Boolean
      if (nth >= 0 && nth <= 7)
        (bte >>> nth & 1) == 1
      else
        throw new java.lang.IndexOutOfBoundsException("")

    case MyString(str) => // So A is actually Character
      str.charAt(nth)

    case MyArray(array) => // So A is actually a sub-class of Number
      array(nth)
  }

def guess[A](x : Sequence[A]): Unit =
  x match {
    case MyByte(bte) =>
      println("I guess A is actually Boolean, let's check!")
      println(getNthElement(x, 0).getClass.getName)

    case MyString(str) =>
      println("I guess A is actually Character")
      println(getNthElement(x, 0).getClass.getName)
  
    case MyArray(array) =>
      println("I guess A is a sub-class of Number but i can not guess which one")
      println(getNthElement(x, 0).getClass.getName)
  }

As you can see, getNthElement must return a value of type A but the case MyByte returns a Boolean. It means Scala is aware that in this case A is actually Boolean. Likewise, in the case MyString, Scala knows that the only possible concrete type for A is Character, so it accepts that we return one. Scala is (most of the time) able to infer, depending on the case, what the constraints on A are. This is all the magic behind GADTs: specialized constructors like in Object-Oriented-with-Generics programming, and closed types (i.e. with a fixed number of cases) on which we can pattern-match like in usual functional programming.

How are GADTs useful? First of all, they are handy when you have a specialized constructor, like in everyday object-oriented programming. It makes sense for a byte (resp. string) to be a sequence of booleans (resp. characters) but not a sequence of anything. A prolific use of this is writing implicits in Scala as GADTs. This way we can pattern-match on the structure of the implicits to derive instances (see this gist for more details). They are also very useful to encode properties on types. As I said above, not all types Sequence[A] have (non-null) values! There is no (non-null) value of type Sequence[Unit] or Sequence[String], etc., but there are values of type Sequence[Boolean], Sequence[Character] and Sequence[Number]. So if I give you a value of type Sequence[A], then you know A is either Boolean, Character or Number. If you don’t believe me, try to call the function guess on a type A which is neither Boolean nor Character nor Number (without using null)! Let me give you some useful examples.

The first one is restricting a generic type like in the code below. The GADT IsIntOrString forces A to be either String or Int in the function and the case class.

sealed abstract class IsIntOrString[A]
implicit final case object IsInt    extends IsIntOrString[Int]
implicit final case object IsString extends IsIntOrString[String]

def canOnlyBeCalledOnIntOrString[A](a: A)(implicit ev: IsIntOrString[A]): A =
  ev match {
    case IsInt => // A is Int
      a + 7
    case IsString => // A is String
      a.reverse
  }

final case class AStringOrAnIntButNothingElse[A](value: A)(implicit val proof : IsIntOrString[A])

Another handy use is encoding effects:

trait UserId
trait User

sealed abstract class BusinessEffect[A]
final case class GetUser(userId: UserId) extends BusinessEffect[User]
final case class SetUser(user: User)     extends BusinessEffect[UserId]
final case class DelUser(userId: UserId) extends BusinessEffect[Unit]
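A common way to use such an effect GADT is an interpreter whose return type depends on the effect. Here is a minimal in-memory sketch; the article keeps UserId and User abstract, so the case classes below, the object BusinessEffectSketch and the class InMemoryInterpreter are all hypothetical. In each case of the match, Scala refines A (to User, UserId or Unit), which is what lets run return a plain A without any cast.

```scala
object BusinessEffectSketch {
  // Hypothetical concrete stand-ins for the abstract UserId and User.
  final case class UserId(value: Int)
  final case class User(name: String)

  sealed abstract class BusinessEffect[A]
  final case class GetUser(userId: UserId) extends BusinessEffect[User]
  final case class SetUser(user: User)     extends BusinessEffect[UserId]
  final case class DelUser(userId: UserId) extends BusinessEffect[Unit]

  // Hypothetical toy interpreter backed by an immutable Map.
  final class InMemoryInterpreter {
    private var store: Map[UserId, User] = Map.empty
    private var nextId: Int = 0

    def size: Int = store.size

    def run[A](effect: BusinessEffect[A]): A =
      effect match {
        case GetUser(id) => // here A = User; throws NoSuchElementException if id is unknown
          store(id)
        case SetUser(user) => // here A = UserId
          val id = UserId(nextId)
          nextId += 1
          store = store.updated(id, user)
          id
        case DelUser(id) => // here A = Unit
          store = store - id
      }
  }
}
```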

Have you ever heard that Set is not a functor? With the usual definition of a functor, indeed Set is not one.

trait Functor[F[_]] {
  def map[A,B](fa: F[A])(f: A => B): F[B]
}

The reason is you can only have a Set[A] for types A such that you can compare values. As an example let A be Int => Int. The two following functions are arguably equal:

val doubleByMult: Int => Int = (x: Int) => 2 * x
val doubleByPlus: Int => Int = (x: Int) => x + x

scala> Set(doubleByMult).contains(doubleByPlus)
res0: Boolean = false

It is simply impossible, in the general case, to know whether two functions compute the same thing. I didn’t just say we don’t know how to do it: it is actually proven to be impossible (no one can, and no one ever will!). Have a look at this List of undecidable problems for more information on the subject. Using extensional equality (the one where f == g if and only if f(x) == g(x) for all x), there is just no implementation of Set[Int => Int]. But if Set were a functor, it would be trivial to get a Set[Int => Int] using map:

Set[Boolean](true, false).map {
  case true  => doubleByMult
  case false => doubleByPlus
}: Set[Int => Int]

The conclusion is that Set is not a functor … in the usual (i.e. Scala) category. But it is one in some other categories. The problem with Functor is that map can be applied on any A and B, which is impossible for Set. But if we restrict A and B so that they have interesting properties (like having an Ordering), then it works. In the code below, the type parameter predicate is used to restrict which A and B map can be applied to:

trait GenFunctor[predicate[_],F[_]] {
  def map[A,B](fa: F[A])(f: A => B)(implicit proofA: predicate[A], proofB: predicate[B]): F[B]
}

Then Set becomes a functor with Ordering as predicate:

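import scala.collection.immutable.TreeSet

object SetInstance extends GenFunctor[Ordering, Set] {
  def map[A,B](fa: Set[A])(f: A => B)(implicit orderingA: Ordering[A], orderingB: Ordering[B]): Set[B] =  {
    // The result is a TreeSet ordered by orderingB, so only B needs an Ordering to store results
    val set = TreeSet.newBuilder(orderingB)
    for (a <- fa) set += f(a)
    set.result()
  }
}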
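Here is a self-contained sketch of SetInstance in use (hypothetical wrapper object SetFunctorSketch, with the TreeSet import the code needs). Mapping is accepted whenever both element types have an Ordering; mapping to Int => Int is rejected at compile time because no Ordering[Int => Int] can be found.

```scala
import scala.collection.immutable.TreeSet

object SetFunctorSketch {
  trait GenFunctor[predicate[_], F[_]] {
    def map[A, B](fa: F[A])(f: A => B)(implicit proofA: predicate[A], proofB: predicate[B]): F[B]
  }

  object SetInstance extends GenFunctor[Ordering, Set] {
    def map[A, B](fa: Set[A])(f: A => B)(implicit orderingA: Ordering[A], orderingB: Ordering[B]): Set[B] = {
      val set = TreeSet.newBuilder(orderingB)
      for (a <- fa) set += f(a)
      set.result()
    }
  }

  // Compiles: Ordering[Int] and Ordering[Boolean] are both available.
  val parities: Set[Boolean] = SetInstance.map(Set(1, 2, 3, 4))(_ % 2 == 0)

  // Would not compile: there is no Ordering[Int => Int].
  // SetInstance.map(Set(true, false))(b => (x: Int) => 2 * x)
}
```

Here parities is Set(false, true): the four inputs collapse to two distinct results, which is exactly where Set needs to compare values of B.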

Surprisingly even String can be a functor (with A and B being both Char)!!!

sealed abstract class IsItChar[A]
implicit final case object YesItIsChar extends IsItChar[Char]

type StringK[A] = String

object StringInstance extends GenFunctor[IsItChar, StringK] {
  def map[A,B](fa: String)(f: A => B)(implicit proofA: IsItChar[A], proofB: IsItChar[B]): String =
    (proofA, proofB) match {
      case (YesItIsChar, YesItIsChar) => // A and B are both Char!
        fa.toList.map(f).mkString
    }
}

GADTs are an example of Bushnell’s law: as you can see, they are easy to learn but can be used in very tricky situations, which makes them hard to master. They are clearly very helpful in many situations, but unfortunately they still seem to be rarely used. Haskell supports them very well! Scala’s support is actually very good, but not as good as Haskell’s. Scala 3 will probably support them as well as Haskell does, since Dotty’s support is excellent. The only two other mainstream languages I know of that support them are OCaml and Haxe. Even if those two have very good support, their lack of Higher-Kinded types forbids the most interesting uses.

As you probably know, it is possible to define a fold function for every Algebraic Data Type. It is also possible to define fold functions for every GADT. As an exercise, try to define fold functions for the following GADTs:

  • This GADT encodes the equality between two types A and B:

    sealed abstract class Eq[A,B]
    final case class Refl[A]() extends Eq[A,A]
  • This GADT represents an unknown type for which we have an instance of a type-class:

    sealed abstract class Ex[TypeClass[_]]
    final case class MakeEx[TypeClass[_],A](value:A, instance: TypeClass[A]) extends Ex[TypeClass]

You’ll find how to define such fold functions here. Have fun and spread the love of GADTs everywhere :)

Let's meet the charming fold family

Today we will meet an amazing family: the fold functions!

The well known foldRight

Lists are one of the first data structures every developer/computer scientist meets in her/his journey into programming:

sealed abstract class List[+A]
final case object Nil                              extends List[Nothing]
final case class  Cons[+A](head: A, tail: List[A]) extends List[A]

It means values of type List[A] can be of (only) two forms:

  • either Nil
  • or Cons(head, tail) for some values head of type A and tail of type List[A]

For example we can define the following lists:

val empty : List[Int] = Nil
val l1 : List[Int] = Cons(61, Nil)
val l2 : List[Int] = Cons(34, Cons(61, Nil))
val l3 : List[String] = Cons("a", Cons("b", Cons("c", Nil)))

In addition, Nil and Cons can be seen as constants and functions returning List[A]:

def nil[A]: List[A] = Nil
def cons[A](head: A, tail: List[A]): List[A] = Cons(head, tail)

The fold function, often called foldRight, answers the question:

What would have happened if, instead of using Nil and Cons in the construction of a list l:List[A], we had used another constant z:T and another function f:(A, T) => T for some type T?

Let’s illustrate this using the previous examples:

val empty : Int = 0 // z = 0
val v1 : Int = max(61, 0) // z = 0, f = max
val v2 : Int = mult(34, mult(61, 1)) // z = 1, f = mult
val v3 : String = concat("a", concat("b", concat("c", ""))) // z = "", f = concat

The definition of foldRight (called foldList below) illustrates the transformation process well. It deconstructs the list l:List[A] and replaces Nil by z and Cons by f:

def foldList[A,T](z: T, f: (A,T) => T): List[A] => T = {
  def transform(l: List[A]): T =
    l match {
      case Nil => z
      case Cons(head, tail) =>
        val transformedTail = transform(tail)
        f(head, transformedTail)
    }
  
  transform _
}
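To see foldList at work, here is a small self-contained sketch. It re-declares the article's List inside an object (FoldListDemo is a hypothetical name, used only so this List shadows Scala's own List locally) and derives a sum, a length and a conversion to a standard list purely by choosing z and f:

```scala
object FoldListDemo {
  // The article's List, re-declared so it only shadows scala.List inside this object
  sealed abstract class List[+A]
  case object Nil extends List[Nothing]
  final case class Cons[+A](head: A, tail: List[A]) extends List[A]

  def foldList[A, T](z: T, f: (A, T) => T): List[A] => T = {
    def transform(l: List[A]): T =
      l match {
        case Nil              => z
        case Cons(head, tail) => f(head, transform(tail))
      }
    transform _
  }

  val l2: List[Int] = Cons(34, Cons(61, Nil))

  // Nil becomes 0 and Cons becomes (+): the sum of the elements
  val sum: Int = foldList[Int, Int](0, _ + _)(l2)

  // Nil becomes 0 and Cons ignores the head and adds 1: the length
  val length: Int = foldList[Int, Int](0, (_, n) => n + 1)(l2)

  // Nil becomes scala.Nil and Cons becomes (::): back to a standard list
  val asScala: scala.List[Int] =
    foldList[Int, scala.List[Int]](scala.Nil, _ :: _)(l2)
}
```

Each result only differs by the constant and the function substituted for Nil and Cons; the traversal itself is written once.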

The simple cases: Enum Types

fold functions can be defined for a wide range of data structures. As a first example, let’s take this type:

sealed abstract class SingletonType
final case object SingleValue extends SingletonType

The type SingletonType admits one and only one value: SingleValue. Folding over SingletonType means replacing SingleValue by a constant z:T for some type T:

def foldSingletonType[T](z:T): SingletonType => T = {
  def transform(v: SingletonType): T =
    v match {
      case SingleValue => z
    }

  transform _
}

While SingletonType has only one value, the type Boolean has exactly two values, True and False:

sealed abstract class Boolean
final case object True  extends Boolean
final case object False extends Boolean

So folding over Booleans means, given a type T and two constants tt:T and ff:T, replacing True by tt and False by ff:

def foldBoolean[T](tt:T, ff:T): Boolean => T = {
  def transform(v: Boolean): T =
    v match {
      case True  => tt
      case False => ff
    }

  transform _
}

And so on for every enum type.
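As a quick illustration, folding a Boolean with tt and ff is exactly if/then/else, and an enum with three constructors simply needs three constants. In this sketch the article's Boolean is renamed Bool to avoid shadowing scala.Boolean, and TrafficLight, ifThenElse and canGo are hypothetical names introduced only for the example:

```scala
object EnumFoldDemo {
  // The article's Boolean, renamed so it does not shadow scala.Boolean
  sealed abstract class Bool
  case object True  extends Bool
  case object False extends Bool

  def foldBool[T](tt: T, ff: T): Bool => T = {
    case True  => tt
    case False => ff
  }

  // The condition selects which of the two constants is kept
  def ifThenElse[T](cond: Bool, thenBranch: T, elseBranch: T): T =
    foldBool(thenBranch, elseBranch)(cond)

  // Hypothetical enum with three constructors: its fold takes three constants
  sealed abstract class TrafficLight
  case object Red    extends TrafficLight
  case object Orange extends TrafficLight
  case object Green  extends TrafficLight

  def foldTrafficLight[T](red: T, orange: T, green: T): TrafficLight => T = {
    case Red    => red
    case Orange => orange
    case Green  => green
  }

  val canGo: TrafficLight => Bool = foldTrafficLight[Bool](False, False, True)
}
```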

Beyond enums

You may start to see the general process. If values of type C are built using constructors (Nil and Cons[A] for List[A], SingleValue for SingletonType, True and False for Boolean), then folding is all about transforming values of type C into another type T by replacing each constructor of C by a constant or function on T of the same shape. Let's consider the type Either[A,B]:

sealed abstract class Either[A,B]
final case class Left[A,B](value: A)  extends Either[A,B]
final case class Right[A,B](value: B) extends Either[A,B]

To transform values of type Either[A,B] into T we need two functions on T:

  • Left being of type A => Either[A,B], we need a function f: A => T.
  • Right being of type B => Either[A,B], we need a function g: B => T.

Then we can operate the transformation:

def foldEither[A,B,T](f: A => T, g: B => T): Either[A,B] => T = {
  def transform(v: Either[A,B]): T =
    v match {
      case Left(a)  => f(a)
      case Right(b) => g(b)
    }

  transform _
}
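Here is a usage sketch with the standard library's Either, whose Left and Right constructors have the same shape as the ones above. The describe functions are hypothetical helpers; the second one shows that scala.util.Either already exposes this fold as its fold method, taking the same two functions:

```scala
object FoldEitherDemo {
  def foldEither[A, B, T](f: A => T, g: B => T): Either[A, B] => T = {
    case Left(a)  => f(a)
    case Right(b) => g(b)
  }

  // Every Left becomes an error message, every Right a success message
  val describe: Either[String, Int] => String =
    foldEither[String, Int, String](err => s"error: $err", n => s"got $n")

  // Same transformation with the standard library's Either#fold
  val describeStd: Either[String, Int] => String =
    _.fold(err => s"error: $err", n => s"got $n")
}
```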

Recursive Types

Folding over recursive types obeys the previous rules. Recursion is handled by transforming sub-terms first. Let's consider the type of binary trees:

sealed abstract class Tree[+A]
final case object Empty extends Tree[Nothing]
final case class  Node[+A](value:A, left: Tree[A], right: Tree[A]) extends Tree[A]

To transform values of type Tree[A] into T we need:

  • Empty being a constant of type Tree[A], we need a constant z:T.
  • Node being a function of type (A, Tree[A], Tree[A]) => Tree[A], we need a function f: (A, T, T) => T. Note how all occurrences of Tree[A] have been replaced by T in the type.

Then we can operate the transformation:

def foldTree[A,T](z: T, f: (A, T, T) => T): Tree[A] => T = {
  def transform(v: Tree[A]): T =
    v match {
      case Empty => z
      case Node(a,l,r) =>
        val g: T = transform(l) // Transforming sub-term l
        val d: T = transform(r) // Transforming sub-term r
        f(a,g,d)
    }

  transform _
}
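A short sketch of foldTree on a hypothetical four-node tree. Because sub-trees are transformed first, f always receives the already computed results for the left and right children, so size, depth and an in-order listing are each a single line:

```scala
object FoldTreeDemo {
  sealed abstract class Tree[+A]
  case object Empty extends Tree[Nothing]
  final case class Node[+A](value: A, left: Tree[A], right: Tree[A]) extends Tree[A]

  def foldTree[A, T](z: T, f: (A, T, T) => T): Tree[A] => T = {
    def transform(v: Tree[A]): T =
      v match {
        case Empty         => z
        case Node(a, l, r) => f(a, transform(l), transform(r))
      }
    transform _
  }

  //       2
  //      / \
  //     1   3
  //          \
  //           4
  val tree: Tree[Int] =
    Node(2, Node(1, Empty, Empty), Node(3, Empty, Node(4, Empty, Empty)))

  // Empty counts 0, each Node counts 1 plus the sizes of its sub-trees
  val size: Int = foldTree[Int, Int](0, (_, l, r) => 1 + l + r)(tree)

  // Length of the longest path from the root down to an Empty leaf
  val depth: Int = foldTree[Int, Int](0, (_, l, r) => 1 + math.max(l, r))(tree)

  // Left sub-tree, then the value, then the right sub-tree
  val inOrder: List[Int] =
    foldTree[Int, List[Int]](Nil, (a, l, r) => l ++ (a :: r))(tree)
}
```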

Generalized Algebraic Data Types (GADT)

Instead of giving a formal definition of Generalized Algebraic Data Types, I will show you some examples.

Type Equalities

Consider the type:

sealed abstract class EmptyOrSingleton[A]
final case object SingleValueIfAisInt extends EmptyOrSingleton[Int]

This type looks very similar to SingletonType but, while SingleValue was always a value of SingletonType, SingleValueIfAisInt is only a value of EmptyOrSingleton[Int], i.e. when A is Int. So what happens to EmptyOrSingleton[A] when A is not Int? Then there is no constructor for EmptyOrSingleton[A], so no value of type EmptyOrSingleton[A] (excluding null, which we will pretend does not exist).

GADTs are very useful to encode predicates over types. Imagine you have a value v:EmptyOrSingleton[A] for some type A (remember we pretend null does not exist). What could you say about A? The only way to get a value of type EmptyOrSingleton[A] is through SingleValueIfAisInt. Thus v is SingleValueIfAisInt, which is of type EmptyOrSingleton[Int], and so is v. We can conclude that A is actually Int. Not convinced? Let A be String: can you build a value of type EmptyOrSingleton[String] without using null? Try it.

To find how to fold EmptyOrSingleton[A] into T, let’s apply the technique we used in the previous sections. EmptyOrSingleton[A] has only one constructor, SingleValueIfAisInt, so we need a constant z:T. But SingleValueIfAisInt is not of type EmptyOrSingleton[A] but EmptyOrSingleton[Int]. The argument A matters so let T depend on A: we want to transform values of type EmptyOrSingleton[A] into T[A].

  • SingleValueIfAisInt being of type EmptyOrSingleton[Int] we need a constant z:T[Int]

Then we can operate the transformation:

def foldEmptyOrSingleton[A, T[_]](z: T[Int]): EmptyOrSingleton[A] => T[A] = {
  def transform(v: EmptyOrSingleton[A]): T[A] =
    v match {
      case SingleValueIfAisInt => z // Because we know A = Int
    }

  transform _
}
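Since the only constructor fixes A to Int, the same fold can also be written as a method of the GADT itself, which avoids relying on the compiler to refine A inside a pattern match (the article uses this technique for Eq below). In this self-contained sketch, the names IsInt, proof and toInt are hypothetical; folding with T[X] = X =:= Int turns any value of EmptyOrSingleton[A] into evidence that A is Int:

```scala
object EmptyOrSingletonDemo {
  // Same GADT as above, with its fold written as a method
  sealed abstract class EmptyOrSingleton[A] {
    def fold[T[_]](z: T[Int]): T[A]
  }
  case object SingleValueIfAisInt extends EmptyOrSingleton[Int] {
    // Here A = Int, so returning z: T[Int] is exactly what is required
    def fold[T[_]](z: T[Int]): T[Int] = z
  }

  type IsInt[X] = X =:= Int

  // Folding with z = (Int =:= Int) yields a proof that A =:= Int
  def proof[A](v: EmptyOrSingleton[A]): A =:= Int =
    v.fold[IsInt](implicitly[Int =:= Int])

  // A =:= Int is also a function A => Int, so values of type A convert safely
  def toInt[A](v: EmptyOrSingleton[A])(a: A): Int = proof(v)(a)
}
```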

foldEmptyOrSingleton means that, for some T[_], if you have a value z:T[Int] then you can transform any value EmptyOrSingleton[A] into T[A]. For example, let’s take

type T[X] = X =:= Int
val z:T[Int] = implicitly[Int =:= Int]

Then foldEmptyOrSingleton[A,T](z) gives us, for any value v:EmptyOrSingleton[A], a proof that A =:= Int. Another important use case is asserting type equality:

sealed abstract class Eq[A,B]
final case class Refl[X]() extends Eq[X,X]

Any non-null value v:Eq[A,B] must be a Refl[X]() : Eq[X,X] for some X, so Eq[A,B] = Eq[X,X], proving that A = X = B. To transform a value of type Eq[A,B] into T[A,B] we need:

  • Refl[X]() is essentially a constant of type Eq[X,X] for every type X (note: Scala writes this type [X]Eq[X,X]). We need a constant z:T[X,X] for every type X (so of type [X]T[X,X]). Since Scala does not support transparent higher-ranked types, we need to emulate them with a trait:
trait ElimRefl[T[_,_]] {
  def apply[X]: T[X,X]
}

We could have hoped to operate the transformation as in the previous sections. But given a value v:Eq[A,B], convincing Scala that A = B is a bit tough. Instead, we can write the fold as a method:

sealed abstract class Eq[A,B] {
  def fold[T[_,_]](z: ElimRefl[T]): T[A,B]
}
final case class Refl[X]() extends Eq[X,X] {
  def fold[T[_,_]](z: ElimRefl[T]): T[X,X] = z[X]
}

def foldEq[A, B, T[_,_]](z: ElimRefl[T]): Eq[A,B] => T[A,B] =
  (v:Eq[A,B]) => v.fold[T](z)

Ingenious definitions of T[_,_] lead to interesting results:

trait C[X]

type T1[A,B] = C[A] =:= C[B]

val z1: ElimRefl[T1] =
  new ElimRefl[T1] {
    def apply[X]: T1[X,X] = implicitly[C[X] =:= C[X]]
  }

def transform[A,B]: Eq[A,B] => C[A] =:= C[B] =
  foldEq[A,B,T1](z1)
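Another choice of T[_,_] turns an Eq[A,B] into a safe cast from A to B. With the hypothetical alias Cast[A,B] = A => B, Refl only has to provide the identity function for every X. This self-contained sketch repeats the definitions above; identityCast and cast are names introduced for the example:

```scala
object EqDemo {
  trait ElimRefl[T[_, _]] {
    def apply[X]: T[X, X]
  }

  sealed abstract class Eq[A, B] {
    def fold[T[_, _]](z: ElimRefl[T]): T[A, B]
  }
  final case class Refl[X]() extends Eq[X, X] {
    def fold[T[_, _]](z: ElimRefl[T]): T[X, X] = z.apply[X]
  }

  // Hypothetical alias: a function from A to B
  type Cast[A, B] = A => B

  // For every X, the identity function is a Cast[X, X]
  val identityCast: ElimRefl[Cast] =
    new ElimRefl[Cast] {
      def apply[X]: Cast[X, X] = (x: X) => x
    }

  // Given a proof that A = B, any A can be used as a B, without asInstanceOf
  def cast[A, B](proof: Eq[A, B])(a: A): B = proof.fold[Cast](identityCast)(a)
}
```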

Existential Quantification

GADTs not only provide useful type equalities, they also offer existential quantification!

sealed abstract class Ex[F[_]] {
  type hidden
  val value: hidden
  val evidence: F[hidden]
}
final case class MakeEx[F[_],A](value: A, evidence: F[A]) extends Ex[F] {
  type hidden = A
}

Any value v:Ex[F] has to be an instance of MakeEx[F,A] for some type A. This means we have a value, v.value, of type A and an instance of the type-class F for A (for example an instance of Monoid[A] with F[X] = Monoid[X]).

To transform values of type Ex[F] into T we need:

  • MakeEx[F[_],?] being of type [A](A, F[A]) => Ex[F], meaning: for every type A, (A, F[A]) => Ex[F], we need a function f of type [A](A, F[A]) => T. Since Scala still does not support transparent higher-ranked types, we need to emulate them with another trait:
trait ElimMakeEx[F[_],T] {
  def apply[A](value: A, evidence: F[A]): T
}

Then we can operate the transformation:

def foldEx[F[_], T](f: ElimMakeEx[F, T]): Ex[F] => T = {
  def transform(v: Ex[F]): T =
    v match {
      case w@MakeEx(value, evidence) => f[w.hidden](value, evidence)
    }

  transform _
}
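Here is a sketch of existential quantification in use. It assumes a hypothetical Show type class (not a library one) to build a heterogeneous list where each element hides its own type. Instead of the pattern match above, this version of foldEx reads the abstract members directly through the lambda parameter v, which is a stable reference, so v.value and v.evidence are known to share the type v.hidden:

```scala
object ExDemo {
  // Hypothetical type class: how to render a value as a String
  trait Show[A] { def show(a: A): String }

  val showInt: Show[Int]       = (a: Int) => s"Int($a)"
  val showString: Show[String] = (a: String) => s"String($a)"

  sealed abstract class Ex[F[_]] {
    type hidden
    val value: hidden
    val evidence: F[hidden]
  }
  final case class MakeEx[F[_], A](value: A, evidence: F[A]) extends Ex[F] {
    type hidden = A
  }

  trait ElimMakeEx[F[_], T] {
    def apply[A](value: A, evidence: F[A]): T
  }

  // v is stable, so v.value: v.hidden and v.evidence: F[v.hidden] line up
  def foldEx[F[_], T](f: ElimMakeEx[F, T]): Ex[F] => T =
    (v: Ex[F]) => f.apply[v.hidden](v.value, v.evidence)

  // Each element hides its own type, but carries a Show instance for it
  val items: List[Ex[Show]] =
    List(MakeEx[Show, Int](1, showInt), MakeEx[Show, String]("a", showString))

  val render: Ex[Show] => String =
    foldEx[Show, String](new ElimMakeEx[Show, String] {
      def apply[A](value: A, evidence: Show[A]): String = evidence.show(value)
    })

  val rendered: List[String] = items.map(render)
}
```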

Duality

In this post we have deduced the fold functions from the definition of each type. It is possible to do the opposite: each constructor can be derived from the fold function of its type. For example:

trait List[+A] {
  def fold[T](z:T, f: (A,T) => T): T
}

def nil[A]: List[A] =
  new List[A] {
    def fold[T](z:T, f: (A,T) => T): T = z
  }

def cons[A](head:A, tail: List[A]): List[A] =
  new List[A] {
    def fold[T](z:T, f: (A,T) => T): T = f(head, tail.fold(z,f))
  }

def equality[A](l1: List[A], l2:List[A]): Boolean = ??? // Difficult but worthwhile exercise
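A small usage sketch of this fold-based encoding, repeating the trait and its two constructors. The list below is built only with nil and cons, and the helpers concatenated and toScalaList (hypothetical names) never pattern-match: they only call fold:

```scala
object FoldEncodedListDemo {
  // A list is nothing more than the ability to fold over it
  trait List[+A] {
    def fold[T](z: T, f: (A, T) => T): T
  }

  def nil[A]: List[A] =
    new List[A] {
      def fold[T](z: T, f: (A, T) => T): T = z
    }

  def cons[A](head: A, tail: List[A]): List[A] =
    new List[A] {
      def fold[T](z: T, f: (A, T) => T): T = f(head, tail.fold(z, f))
    }

  val l3: List[String] = cons("a", cons("b", cons("c", nil[String])))

  // nil becomes "" and cons becomes string concatenation
  val concatenated: String = l3.fold[String]("", _ + _)

  // nil becomes scala.Nil and cons becomes (::)
  def toScalaList[A](l: List[A]): scala.List[A] =
    l.fold[scala.List[A]](scala.Nil, _ :: _)
}
```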

Conclusion

I hope I convinced you that folds are immensely useful. First, they let us write complex transformation functions simply. But this is not their most interesting property. It is sometimes easier to define a type by its fold function. Java, for example, long supported neither sealed classes nor pattern matching. How could we define the List type so that Nil and Cons are the only two cases? The fold function forces any instance of List to fit into the desired shape (provided some rules are obeyed, like no null and no runtime reflection). Type inference is also sometimes not smart enough; fold functions provide an alternative that is often easier for the Scala type-checker.

How to make game in the browser thanks to ScalaJS

A few months ago, the Paris Scala User Group kindly invited me to present a workshop introducing Scala.js (https://www.scala-js.org/). Even better, I had the chance to present it at ScalaIO 2018. The workshop materials are here.

I will present how to develop a web application in Scala.js. This talk is for Scala developers who have a penchant for web development but are put off by JavaScript. It goes from Scala.js basics to the implementation of a naive virtual DOM written in Scala. It presents:

  • how to set up sbt for cross-compilation
  • what the DOM is and how to manipulate it
  • events and their propagation
  • the Model/View/Update architecture (a.k.a. the Elm architecture)

The final result can be tried out at slimetrail. The English materials for the ScalaIO workshop are not yet available, but those for the PSUG workshop, in French, are here.


JSON to XML: the probably a tiny bit over engineered way

The complete code of the article. You need Cats and Play-Json in order to run it.

The Slides are here

It regularly happens in software development that we have to connect systems speaking different languages. JSON is nowadays ubiquitous in service communication, especially in web development, but XML still holds a fair number of bastions. Imagine you need to pass information provided by a JSON API through an XML layer: you need a converter.

The easy way

This translation is actually pretty trivial: it takes essentially six cases of simple pattern matching in Scala:

import play.api.libs.json._
import scala.xml._

def json2xml(json: JsValue, rootLabel: String): Elem = {
  // XML node creation helper
  def mkElem(jsType: String, children: Node*): Elem =
    Elem(null, rootLabel,
         new UnprefixedAttribute("type", jsType, scala.xml.Null),
         TopScope, true, children: _*
        )

  // The real translation
  json match {
    case JsNull =>
      mkElem("null")

    case JsString(s) =>
      mkElem("string", PCData(s))

    case JsNumber(n) =>
      mkElem("number", Text(n.toString))

    case JsBoolean(b) =>
      mkElem("boolean", Text(b.toString))

    case JsArray(l) =>
      mkElem("array", l.map(json2xml(_, s"${rootLabel}Item")):_*)

    case JsObject(m) =>
      mkElem("object", m.toList.map { case (k,v) => json2xml(v, k) }: _*)
  }
}

The trickiest part of this example is figuring out how to build XML nodes in Scala. Called with the root label "films", it translates the following JSON:

[
  { "title": "2001 : A Space Odyssey",
    "release":
      { "day": 27,
        "month": 9,
        "year": 1968
      },
    "genres" : [ "Science fiction" ],
    "actors": [
      { "lastName": "Dullea",
        "firstName": "Keir",
        "role": "Dr. David Bowman"
      }
    ],
    "directors": [
      { "lastName": "Kubrick",
        "firstName": "Stanley"
      }
    ]
  }
]

into

<films type="array">
  <filmsItem type="object">
    <title type="string"><![CDATA[2001 : A Space Odyssey]]></title>
    <release type="object">
      <day type="number">27</day>
      <month type="number">9</month>
      <year type="number">1968</year>
    </release>
    <genres type="array">
      <genresItem type="string"><![CDATA[Science fiction]]></genresItem>
    </genres>
    <actors type="array">
      <actorsItem type="object">
        <lastName type="string"><![CDATA[Dullea]]></lastName>
        <firstName type="string"><![CDATA[Keir]]></firstName>
        <role type="string"><![CDATA[Dr. David Bowman]]></role>
      </actorsItem>
    </actors>
    <directors type="array">
      <directorsItem type="object">
        <lastName type="string"><![CDATA[Kubrick]]></lastName>
        <firstName type="string"><![CDATA[Stanley]]></firstName>
      </directorsItem>
    </directors>
  </filmsItem>
</films>

Note that, unlike JSON, XML has no notion of booleans, numbers or null, so we add type information as an attribute on each node. This has the benefit of enabling us to convert such XML back to its original JSON form. Also note that we need CDATA sections to preserve spaces.

Problem solved? Yes! But we can go much much further on this subject…

The Rocket Science way

There is much more to say about this example. First, let's expose some properties of JSON values.

Inviting (Co)Algebras to the Party

JSON values can be modelled with an Algebraic Data Type or ADT for short. Play-Json represents them by the type JsValue:

sealed abstract class JsValue
final case object JsNull extends JsValue
final case class JsNumber(value: BigDecimal) extends JsValue
final case class JsBoolean(value: Boolean) extends JsValue
final case class JsString(value: String) extends JsValue
final case class JsArray(value: List[JsValue]) extends JsValue
final case class JsObject(value: Map[String, JsValue]) extends JsValue

But in order to simplify the presentation, we will use a slightly different, but equivalent, definition of JSON values:

sealed abstract class Atomic
final case object Null extends Atomic
final case class Bool(value: Boolean) extends Atomic
final case class Number(value: BigDecimal) extends Atomic
final case class Str(value: String) extends Atomic

sealed abstract class JsValue
final case class JsAtom(value: Atomic) extends JsValue
final case class JsArray(value: List[JsValue]) extends JsValue
final case class JsObject(value: Map[String, JsValue]) extends JsValue

As in any algebraic data type, the constructors of JsValue can be seen as operations on it. JsAtom informs us that every number, boolean, string and null gives rise to a distinct JSON value. JsArray and JsObject tell us that each list of JSON values (labelled by keys, in the case of objects) forms a distinct JSON value itself. Considering that JSON values are defined in terms of these operations, and that we want to translate JSON into XML, it would make sense to define them on XML as well. First, let's make these operations explicit:

sealed abstract class JsLike[+R]
final case class Atom(value: Atomic) extends JsLike[Nothing]
final case class Arr[+R](value: List[R]) extends JsLike[R]
final case class Obj[+R](value: Map[String, R]) extends JsLike[R]

The interesting point here is that we can translate back and forth between JsValue and JsLike[JsValue]. These translations are even the inverse of each other, meaning both types are totally equivalent!

val jsLike2JsValue: JsLike[JsValue] => JsValue = {
  case Atom(a)         => JsAtom(a)
  case Arr(a)          => JsArray(a)
  case Obj(m)          => JsObject(m)
}

val jsValue2JsLike: JsValue => JsLike[JsValue] = {
  case JsAtom(a)    => Atom(a)
  case JsArray(a)   => Arr(a.toList)
  case JsObject(m)  => Obj(m)
}

jsLike2JsValue is called a JsLike-Algebra because it has the form JsLike[X] => X. It means jsLike2JsValue is a way to “compute” JsLike operations, i.e. it composes values to form new ones. Conversely, jsValue2JsLike is called a JsLike-CoAlgebra because it has the form X => JsLike[X]. It is a way to expose how a value is built, i.e. it deconstructs values to expose their structure.
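The claim that these two translations are inverse of each other can be checked on an example. This self-contained sketch repeats the simplified definitions above (it does not use Play-Json) and checks on the release date from the film example that decomposing then recomposing a value gives it back, and vice versa:

```scala
object JsLikeRoundTrip {
  sealed abstract class Atomic
  case object Null extends Atomic
  final case class Bool(value: Boolean) extends Atomic
  final case class Number(value: BigDecimal) extends Atomic
  final case class Str(value: String) extends Atomic

  sealed abstract class JsValue
  final case class JsAtom(value: Atomic) extends JsValue
  final case class JsArray(value: List[JsValue]) extends JsValue
  final case class JsObject(value: Map[String, JsValue]) extends JsValue

  sealed abstract class JsLike[+R]
  final case class Atom(value: Atomic) extends JsLike[Nothing]
  final case class Arr[+R](value: List[R]) extends JsLike[R]
  final case class Obj[+R](value: Map[String, R]) extends JsLike[R]

  // JsLike-algebra: one level of structure is composed into a JsValue
  val jsLike2JsValue: JsLike[JsValue] => JsValue = {
    case Atom(a) => JsAtom(a)
    case Arr(a)  => JsArray(a)
    case Obj(m)  => JsObject(m)
  }

  // JsLike-coalgebra: a JsValue is decomposed exactly one level deep
  val jsValue2JsLike: JsValue => JsLike[JsValue] = {
    case JsAtom(a)   => Atom(a)
    case JsArray(a)  => Arr(a)
    case JsObject(m) => Obj(m)
  }

  val release: JsValue = JsObject(Map(
    "day"   -> JsAtom(Number(27)),
    "month" -> JsAtom(Number(9)),
    "year"  -> JsAtom(Number(1968))
  ))
}
```

Note that each translation only looks at the outermost layer: the children stay untouched JsValues, which is what makes these functions so simple.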

Can we find such functions for XML values? We are looking for two functions:

val jsLike2Elem: JsLike[Elem] => Elem = ???
val elem2JsLike: Elem => JsLike[Elem] = ???

It would certainly be nice, but unfortunately it is not that simple! 5, true and null are valid JSON values, so jsLike2Elem(Atom(Number(5))), jsLike2Elem(Atom(Bool(true))) and jsLike2Elem(Atom(Null)) should be valid XML values! But what should the root tag of the resulting elements be? How do we translate 5 into valid XML? We know that it would have the form:

<someRootTag type="number">5</someRootTag>

But what should someRootTag be? We could pick an arbitrary one, but it would break composability (try it, you'll see!). There is no escape: all XML values need tags, but not every JSON value has one! The situation suggests that JSON values are closer to “XML values with an unknown root tag”, <X type="number">5</X> with X as the unknown, i.e. the function space String => Elem:

val _5: String => Elem =
  // XML literals cannot have a dynamic label, so we build the element and relabel it
  (someRootTag: String) => <X type="number">5</X>.copy(label = someRootTag)

Do you think we can define meaningful functions?

val jsLike2xml: JsLike[String => Elem] => (String => Elem) = ???
val xml2JsLike: (String => Elem) => JsLike[String => Elem] = ???

Yes we can … partially. We can define jsLike2xml:

val jsLike2xml: JsLike[String => Elem] => (String => Elem) = {
  def mkRoot(jsType: String, children: Node*): String => Elem =
    (someRootTag: String) =>
      Elem(null,
           someRootTag,
           new UnprefixedAttribute("type", jsType, scala.xml.Null),
           TopScope,
           true,
           children: _*
          )

  (j: JsLike[String => Elem]) =>
    j match {
      case Atom(Null)   =>
        mkRoot("null")

      case Atom(Str(s)) =>
        mkRoot("string", PCData(s))

      case Atom(Bool(b)) =>
        mkRoot("boolean", Text(b.toString))

      case Atom(Number(n)) =>
        mkRoot("number", Text(n.toString))
        
      case Arr(a) =>
        (root: String) => {
          mkRoot("array", a.map(_(s"${root}Item")): _*)(root)
        }

      case Obj(m) =>
        mkRoot("object", m.toList.map { case (k, v) => v(k) }: _*)
    }
}

but for xml2JsLike, we’re facing two not-that-small issues:

  • First, unlike jsValue2JsLike, we cannot pattern-match on functions. We have no sane way to know that

    (someRootTag: String) => <someRootTag type="number">5</someRootTag>

    is built from Atom(Number(5)).

  • Even if we could pattern-match on functions, jsLike2xml is not surjective, i.e. not every XML element is the result of jsLike2xml(f) for some f. To deal with invalid input, the return type of xml2JsLike can not be JsLike[String => Elem] but F[JsLike[String => Elem]] for some functor F able to deal with errors like Option, Either, etc. For simplicity’s sake, let’s consider F to be Option.

Let’s once again take a step back. We want to decompose a function (f: String => Elem) into an Option[JsLike[String => Elem]] without pattern-matching it. The only reasonable thing we can do with functions is pass them some arguments:

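def xml2JsLike(f: (String => Elem)): String => Option[JsLike[String => Elem]] =
  (someRootTag: String) => {
    val elem: Elem = f(someRootTag) // the only thing we can do with f: apply it
    ??? // inspect elem to rebuild one level of JsLike
  }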

The type String => Option[A] is actually a monad, known as ReaderT[Option, String, A], which makes xml2JsLike a monadic coalgebra. Let's give it a name:

import cats.data.ReaderT

type TagOpt[A] = ReaderT[Option, String, A]

As an exercise, try to implement xml2JsLike. To that end, it may be useful to notice that JsLike is a Traverse, i.e. that an instance of Traverse[JsLike] can be defined. Such an instance defines a function:

def traverse[G[_]: Applicative, A, B](ja: JsLike[A])(f: A => G[B]): G[JsLike[B]]

To summarize this part, we have these four functions:

val jsLike2JsValue: JsLike[JsValue]    => JsValue
val jsValue2JsLike: JsValue            => JsLike[JsValue]

val jsLike2xml: JsLike[String => Elem] => (String => Elem)
val xml2JsLike: (String => Elem)       => TagOpt[JsLike[String => Elem]]

Now we want to convert JsValue from/into String => Elem.

Converting back and forth

Now that we know how to compose and decompose both JSON and XML values, how do we write converters? For simplicity's sake, let's be a bit more abstract. Let A and B be two types (like JsValue and String => Elem) and F[_] a type constructor (like JsLike) with the nice property of being a functor (i.e. it has a function map: F[A] => (A => B) => F[B]). In addition, let decomposeA: A => F[A] and recomposeB: F[B] => B be two functions (like jsValue2JsLike and jsLike2xml). We want a function convert: A => B:

trait Direct {
  import cats.Functor
  import cats.syntax.functor._

  type A
  type B
  type F[_]

  implicit val fHasMap: Functor[F]

  val decomposeA: A    => F[A]
  val recomposeB: F[B] => B

  final def convert(a: A): B = {
    val fa: F[A] = decomposeA(a)
    val fb: F[B] = fa.map(convert)
    recomposeB(fb): B
  }
}

Or in a more compact way:

import cats.Functor
import cats.syntax.functor._

def hylo[A,B, F[_]: Functor](decompose: A => F[A], recompose: F[B] => B): A => B = {
  def convert(a: A): B =
    recompose(decompose(a).map(convert))

  convert _
}

And voilà, a converter in just one line of code:

def json2xml(json: JsValue): String => Elem =
  hylo(jsValue2JsLike, jsLike2xml).apply(json)
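To see hylo at work on something smaller than XML, here is a self-contained sketch without Cats. It assumes a minimal, hypothetical Functor trait standing in for cats.Functor, and a one-level list shape ListF (also hypothetical). The coalgebra countDown unfolds n into n, n-1, ..., 1 one level at a time, and the algebra product multiplies each level, so together they compute n!:

```scala
object HyloDemo {
  // Minimal stand-in for cats.Functor
  trait Functor[F[_]] {
    def map[A, B](fa: F[A])(f: A => B): F[B]
  }

  def hylo[A, B, F[_]](decompose: A => F[A], recompose: F[B] => B)(
      implicit F: Functor[F]): A => B = {
    def convert(a: A): B =
      recompose(F.map(decompose(a))(convert))
    convert _
  }

  // One level of a list of Ints whose tail is abstracted as R
  sealed abstract class ListF[+R]
  case object NilF extends ListF[Nothing]
  final case class ConsF[+R](head: Int, tail: R) extends ListF[R]

  implicit val listFFunctor: Functor[ListF] = new Functor[ListF] {
    def map[A, B](fa: ListF[A])(f: A => B): ListF[B] = fa match {
      case NilF           => NilF
      case ConsF(h, tail) => ConsF(h, f(tail))
    }
  }

  // Coalgebra: n is decomposed into the head n and the seed n - 1 for the tail
  val countDown: Int => ListF[Int] =
    n => if (n <= 0) NilF else ConsF(n, n - 1)

  // Algebra: NilF becomes 1, ConsF multiplies its head with the folded tail
  val product: ListF[BigInt] => BigInt = {
    case NilF          => BigInt(1)
    case ConsF(h, acc) => acc * h
  }

  val factorial: Int => BigInt = hylo[Int, BigInt, ListF](countDown, product)
}
```

The structure is the same as json2xml: one function that only knows how to take a value apart, one that only knows how to put a level together, and hylo doing the recursion.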

The way back is only a bit more involved. This time we require F to be Traverse and the function decomposeA to be of type A => M[F[A]] for some monad M:

trait WayBack {
  import cats.{Monad, Traverse}
  import cats.implicits._

  type A
  type B

  type F[_]
  implicit val fHasTraverse: Traverse[F]

  type M[_]
  implicit val mIsAMonad: Monad[M]

  val decomposeA: A    => M[F[A]]
  val recomposeB: F[B] => B

  final def convert(a: A): M[B] =
    for {
      fa <- decomposeA(a)
      fb <- fa.traverse(convert)
    } yield recomposeB(fb)
}

Again, in a more compact way:

import cats.{Monad, Traverse}
import cats.implicits._

def hyloish[A,B, F[_]: Traverse, M[_]: Monad](decompose: A => M[F[A]], recompose: F[B] => B): A => M[B] = {
  def convert(a: A): M[B] =
    for {
      fa <- decompose(a)
      fb <- fa.traverse(convert)
    } yield recompose(fb)

  convert _
}

which gives the way back as the one-liner:

def xml2json(f: String => Elem): TagOpt[JsValue] =
    hyloish(xml2JsLike, jsLike2JsValue).apply(f)
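The failure-handling behaviour of hyloish can be seen in a self-contained sketch without Cats, specialised to M = Option and to a one-level list shape ListF (a hypothetical type, as are traverseOpt, parseStep and parseAndSum). The decomposition parses one number from a comma-separated string and may fail; as soon as one step returns None, the whole conversion does:

```scala
object HyloishDemo {
  import scala.util.Try

  // One level of a list of Ints whose tail is abstracted as R
  sealed abstract class ListF[+R]
  case object NilF extends ListF[Nothing]
  final case class ConsF[+R](head: Int, tail: R) extends ListF[R]

  // Stand-in for Traverse[ListF], with the applicative fixed to Option
  def traverseOpt[A, B](fa: ListF[A])(f: A => Option[B]): Option[ListF[B]] =
    fa match {
      case NilF           => Some(NilF)
      case ConsF(h, tail) => f(tail).map(b => ConsF(h, b))
    }

  // hyloish with M = Option and F = ListF: any failing step makes the result None
  def hyloishOpt[A, B](decompose: A => Option[ListF[A]], recompose: ListF[B] => B): A => Option[B] = {
    def convert(a: A): Option[B] =
      for {
        fa <- decompose(a)
        fb <- traverseOpt(fa)(convert)
      } yield recompose(fb)
    convert _
  }

  // Monadic coalgebra: "1,2,3" becomes head 1 with seed "2,3"; a non-number fails
  val parseStep: String => Option[ListF[String]] = s =>
    if (s.isEmpty) Some(NilF)
    else {
      val (first, rest) = s.span(_ != ',')
      Try(first.toInt).toOption.map(n => ConsF(n, rest.drop(1)))
    }

  // Algebra: sum the parsed numbers
  val sum: ListF[Int] => Int = {
    case NilF          => 0
    case ConsF(h, acc) => h + acc
  }

  val parseAndSum: String => Option[Int] = hyloishOpt[String, Int](parseStep, sum)
}
```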

Reorganizing a bit, it leads to the two conversion functions between (String, JsValue) and Elem:

val json2xmlBetter: ((String, JsValue)) => Elem =
  (jsonPlusTag: (String, JsValue)) => json2xml(jsonPlusTag._2)(jsonPlusTag._1)

val xml2jsonBetter: Elem => TagOpt[(String, JsValue)] =
  (e: Elem) => xml2json((s: String) => e.copy(label = s)).map(e.label -> _)

What’s the point?

Apart from being so much more complicated than the trivial approach, are there any benefits? Actually, yes.

  • Firstly, given n formats, there are n(n-1) converters, one for each ordered pair of distinct formats. Writing and testing that many functions is a lot of tedious and error-prone work. But if you find some common operations F[_], you only need 2n functions (one X => F[X] and one F[X] => X for each format X) to achieve the same goal: for five formats, that is 10 functions instead of 20 converters. Furthermore, each of those functions will be easier to test, which is not to be neglected.
  • Secondly, algebras (functions F[X] => X) and coalgebras (functions X => F[X]) operate one level at a time. They let us treat format X as if it were an algebraic data type over operations F. Pattern-matching is such a nice feature!
  • Thirdly, you can write generic functions taking any type X for which you can provide functions X => F[X] and F[X] => X. These functions also have a higher chance of being correct because there is less room for unexpected behaviour.

If you want to dive deeper into this subject, you can look at Matryoshka, read "Functional Programming with Bananas, Lenses, Envelopes and Barbed Wire", or any resource on F-Algebras and recursion schemes.

Solution to exercises

Traverse instance for JsLike

import cats.{Applicative, Eval, Traverse}
import cats.implicits._

implicit val jsLikeInstances: Traverse[JsLike] =
  new Traverse[JsLike] {
    def traverse[G[_], A, B](fa: JsLike[A])(f: A => G[B])(
        implicit G: Applicative[G]): G[JsLike[B]] =
      fa match {
        case Atom(a) => G.pure[JsLike[B]](Atom(a)) // cats names it pure, not point
        case Arr(a)  => a.traverse[G, B](f).map[JsLike[B]](l => Arr(l))
        case Obj(m) =>
          m.toList
            .traverse[G, (String, B)] { case (s, a) => f(a).map(s -> _) }
            .map[JsLike[B]](i => Obj(i.toMap))
      }

    def foldLeft[A, B](fa: JsLike[A], b: B)(f: (B, A) => B): B =
      fa match {
        case Atom(_) => b
        case Arr(a)  => a.foldLeft(b)(f)
        case Obj(m)  => m.values.foldLeft(b)(f)
      }

    def foldRight[A, B](fa: JsLike[A], lb: Eval[B])(
        f: (A, Eval[B]) => Eval[B]): Eval[B] =
      fa match {
        case Atom(_) => lb
        case Arr(a)  => a.foldRight(lb)(f)
        case Obj(m)  => m.values.foldRight(lb)(f)
      }
  }

xml2JsLike

def xml2JsLike(f: String => Elem): TagOpt[JsLike[String => Elem]] =
  ReaderT[Option, String, JsLike[String => Elem]] { (s: String) =>
    val elem: Elem = f(s)

    elem
      .attributes
      .asAttrMap
      .get("type")
      .flatMap[JsLike[String => Elem]] {
        case "null" =>
          Some(Atom(Null))

        case "boolean" =>
          elem.text match {
            case "true"  => Some(Atom(Bool(true)))
            case "false" => Some(Atom(Bool(false)))
            case _       => None
          }

        case "number" =>
          import scala.util.Try

          Try(BigDecimal(elem.text))
            .toOption
            .map(n => Atom(Number(n)))

        case "string" =>
          Some(Atom(Str(elem.text)))

        case "array" =>
          Some(Arr(
            elem
              .child
              .toList
              .flatMap {
                case e: Elem => List((s: String) => e.copy(label = s))
                case _       => Nil
              }
          ))

        case "object" =>
          Some(Obj(
            elem
              .child
              .toList
              .flatMap {
                case e: Elem => List(e.label -> ((s: String) => e.copy(label = s)))
                case _       => Nil
              }.toMap
          ))

        case _ =>
          None
    }
  }

F-Algebra talk at ScalaIO 2017: Modéliser astucieusement vos données

I had the chance to present a talk about F-Algebras at ScalaIO 2017.

The Video

The Slides