Demystifying GADTs
Generalized Algebraic Data Types (GADT) is certainly one of the most feared concept in programming nowadays. Very few mainstream languages support GADTs. The only ones i know which does are Haskell, Scala, OCaml and Haxe. The idea is actually very simple but often presented in complicated ways. In fact, if you’re familiar to both basic Object-Oriented-with-Generics and basic functional programming, then you most probably are already familiar with GADTs without even knowing you are. But if GADTs are so simple, why so many people feel terrified by them? Well GADTs rely on two fundamental ideas, one of them is known by every Object-Oriented-with-Generics programmer while the other is known by every functional programmer. The problem is most people make the huge mistake of opposing them even though they are complementary. So before diving into GADTs, let me remind you of these elementary notions from Object-Oriented and functional programming.
Object-Oriented Programming 101
Let’s start by some plain old Java (the examples works in probably all Object-Oriented language which supports generics). We want to define an abstract class for sequences:
public abstract class Sequence<A> {
abstract public int length();
abstract public A getNthElement(int nth);
}
In Java it would be better to define Sequence<A>
as an interface but i want this example to be as simple as possible. Would you be surprised if told you a String
is a sequence of characters ? ;) As i said, GADTs rely on very basic programming knowledge.
public class MyString extends Sequence<Character> {
private String str;
public MyString(String s) {
this.str = s;
}
public int length() {
return this.str.length();
}
public Character getNthElement(int nth) {
return this.str.charAt(nth);
}
}
Likewise, byte
s are sequences of 8 bits (we represent a bit by a Boolean
):
public final class MyByte extends Sequence<Boolean> {
private byte bte;
public MyByte(byte x) {
this.bte = x;
}
public int length() {
return 8;
}
public Boolean getNthElement(int nth) {
if (nth >= 0 && nth <= 7)
return ((bte >>> nth & 1) == 1);
else
throw new java.lang.IndexOutOfBoundsException("");
}
}
Have you noticed how MyByte
and MyString
declares themselves being respectively a sequence of booleans (Sequence<Boolean>
) and a sequence of characters (Sequence<Character>
) but not sequences of A
(Sequence<A>
) for any type A
? Let’s try to make it work for any type A
:
public final class MyByte<A> extends Sequence<A> {
private byte bte;
public MyByte(byte x) {
this.bte = x;
}
public int length() {
???
}
public A getNthElement(int nth) {
???
}
}
How would you write the methods length
and getNthElement
? Do you really imagine what would be a MyByte<Graphics2D>
? It just doesn’t make any sense at all. You could argue that a string is also a sequence of byte
and a byte a sequence of one byte. Indeed this relation is not unique, but it does not change the fact that it works for only a small selection of type A
and not every one! We can go even deeper in Object-Oriented Programming:
public final class MyArray<A extends Number> extends Sequence<Number> {
private A[] array;
public MyArray(A[] a) {
this.array = a;
}
public int length() {
return this.array.length;
}
public Number getNthElement(int nth) {
return this.array[nth];
}
}
Note how the generics A
, which is required to be a sub-class of Number
, is present as argument of MyArray
but not in extends Sequence<Number>
. Now what do you think about this code? Do you think it can be wrong?
public static <A> void guess(Sequence<A> x) {
if (x instanceof MyByte) {
System.out.println("I guess A is actually Boolean, let's check!");
System.out.println(((Sequence<Boolean>)x).getNthElement(0).getClass().getName());
} else
if (x instanceof MyString) {
System.out.println("I guess A is actually Character");
System.out.println(((Sequence<Character>)x).getNthElement(0).getClass().getName());
} else
if (x instanceof MyArray<?>) {
System.out.println("I guess A is a sub-class of Number but i can not guess which one");
System.out.println(((Sequence<?>) x).getNthElement(0).getClass().getName());
}
else
System.out.println("I don't know what A is");
}
- If
x
is an instance ofMyByte
, which is a sub-class ofSequence<Boolean>
, then by trivial inheritancex
is also an instance ofSequence<Boolean>
. In this caseA
is forced to beBoolean
. - If
x
is an instance ofMyString
, which is a sub-class ofSequence<Character>
, then again by trivial inheritancex
is also an instance ofSequence<Character>
. In this caseA
has to beCharacter
. - If
x
is an instance ofMyArray<A>
for some typeA
, which is a sub-class ofSequence<Number>
, then once again by trivial inheritancex
is an instance ofSequence<Number>
. In this case we knowA
is a sub-class ofNumber
but we don’t know which one.
This is the essence of Generalized Algebraic Data Types. It you understand the code above, then you understand how GADTs work. As you see this is very basic Oriented-Object with Generics. You can find lots of examples of this kind in almost every Java/C#/etc project (search for the instanceof
keyword).
Functional Programming 101
Functional languages often support a feature called Algebraic Data Types (ADT) which is essentially enumerations on steroids. Like enumerations this is a disjoint union of a fixed number of cases but unlike enumerations, where each case is a constant, ADTs cases can have parameters. As an example, the type of lists whose elements are of type a
, written List a
in Haskell, is defined:
data List a = Nil | Cons a (List a)
It means any value of type List a
belong to exactly one of the following cases:
- either the value is the constant
Nil
which represents the empty list. - or the value is
Cons hd tl
which represent the list whose first element ishd
(of typea
) and whose tail istl
(of typeList a
).
The list [1,2,3,4]
is encoded by Cons 1 (Cons 2 (Cons 3 (Cons 4 Nil)))
. Where ADTs really shine is pattern-matching which is a very powerful and flexible switch
(as i said above, ADTs are enumerations on steroids). ADTs being made of a fixed number of distinct cases, pattern-matching enable to inspect values and perform computations based on a case by case analysis of the form of the value. Here is how to implement the merge sort algorithm on this type:
split :: List a -> (List a, List a)
split l =
case l of
Cons x (Cons y tl) -> (case split tl of
(l1, l2) -> (Cons x l1, Cons y l2)
)
Cons _ Nil -> (l , Nil)
Nil -> (Nil, Nil)
merge :: (a -> a -> Bool) -> List a -> List a -> List a
merge isLessThan l1 l2 =
case (l1, l2) of
(_ , Nil ) -> l1
(Nil , _ ) -> l2
(Cons hd1 tl1, Cons hd2 _ ) | hd1 `isLessThan` hd2 -> Cons hd1 (merge isLessThan tl1 l2)
(_ , Cons hd2 tl2) -> Cons hd2 (merge isLessThan l1 tl2)
sort :: (a -> a -> Bool) -> List a -> List a
sort isLessThan l =
case l of
Nil -> Nil
Cons _ Nil -> l
_ -> case split l of
(l1, l2) -> merge isLessThan (sort isLessThan l1) (sort isLessThan l2)
I know there are smarter ways to write it in Haskell but this article is not about it. The code above could be translated trivially in OCaml
by replacing case ... of
by match ... with
, in Scala by ... match { ... }
, etc. This style is valid is probably all languages supporting pattern-matching so it fits our goal.
The case l of
expressions are pattern-matching. They are a sequence of pattern | condition -> code
. The code being executed is the right-hand side of the first case for which the value l
is of the form of its pattern and satisfy the condition. l
is then said to match this case. For example, the case Cons x (Cons y tl) -> (case split tl of (l1, l2) -> (Cons x l1, Cons y l2))
states that if l
is of the form Cons x (Cons y tl)
, which means that there are three values x
, y
and tl
such that l == Cons x (Cons y tl)
, then the code executed is (case split tl of (l1, l2) -> (Cons x l1, Cons y l2))
. One very important condition is that pattern-matching must be exhaustive! It means that the sequence of cases must cover all possible value of l
.
If your understand the previous section, the type List a
and how pattern-matching works in the example above, then i am very glad to inform you that you already understand GADTs! Well done :)
Summing up!
In this section i assume previous sections are ok for you. If you do not understand previous examples, don’t go further but go back to the basics of generics and pattern-matching. Likewise, if you find what follows complicated, go back to the basics generics and pattern-matching. There is no shame in doing so! Difficulties in understanding advanced notion is often the reflect of a lack of understanding in the ones they rely upon. As i said, there is no shame in it, if you think programming paradigms are “simple” then write a compiler ;)
It’s about time to sum up everything. First, note that List a
is not one single type. Each type a
actually gives rise to a distinct type List a
. For example List Int
, List String
, List (List Bool)
, etc are all distinct types. Indeed the list Cons 1 Nil
is neither a list of strings nor of booleans! For each type a
, the type List a
have two constructors: the constant Nil
and the function Cons :: a -> List a -> List a
which builds a List a
from a value of type a
and other List a
.
There is another equivalent way to define List a
in Haskell which makes the nature of the constructor more apparent:
data List a where
Nil :: List a
Cons :: a -> List a -> List a
Indeed, for each type a
, Nil
is constant of type List a
while Cons
is a function of type a -> List a -> List a
. Note that it is actually very close to the way to define it in Scala:
sealed abstract class List[A]
final case class Nil[A]() extends List[A]
final case class Cons[A](head: A, tail: List[A]) extends List[A]
Do you remember the example of the first section Sequence<A>
? There was three sub-classes of Sequence<A
: MyString
which is actually a sub-class of Sequence<Character>
, MyByte
which is a sub-class of Sequence<Boolean>
and MyArray<A extends Number>
which is a sub-class of Sequence<Number>
. What is the type of their constructors? Some admissible type for them is (in Scala notation):
def MyString : String => Sequence[Character]
def MyByte : Byte => Sequence[Boolean]
def MyArray[A <: Number] : Array[A] => Sequence[Number]
From this, this is trivial to write:
sealed abstract class Sequence[A]
final case class MyString(str: String) extends Sequence[Character]
final case class MyByte(bte: Byte) extends Sequence[Boolean]
final case class MyArray[A <: Number](array: Array[A]) extends Sequence[Number]
or in Haskell:
data Number where
MkNum :: forall a. Num a => Number
data Sequence a where
MyString :: String -> Sequence Char
MyByte :: Word8 -> Sequence Bool
MyArray :: forall a. Num a => List a -> Sequence Number
Sequence
is a GADT. What makes it different from List
above? For any type a
, values of type List a
are build using the two constructors Nil
and Cons
. Note that it does not depend on what a
is. Values of type List Int
are build using the exact same constructors than List Bool
, List String
, List (List Char)
, etc. Sequence
have three constructors MyString
, MyByte
and MyArray
. But values of type Sequence[Character]
can only be built by the constructor MyString
while values of type Sequence[Boolean]
can only be built by the constructor MyByte
and values of type Sequence[Number]
can only be built by the constructor MyArray
. What about values of type Sequence[Unit]
or Sequence[String]
, etc? There is simply no constructor to build values of these types, so there is no values of these types!
We can rewrite the methods on Sequence
and the guess
function to use patten-matching:
def length[A](x: Sequence[A]): Int =
x match {
case MyByte(_) => 8
case MyString(str) => str.length
case MyArray(arr) => arr.size
}
def getNthElement[A](x: Sequence[A], nth: Int): A =
x match {
case MyByte(bte) => // So A is actually Boolean
if (nth >= 0 && nth <= 7)
(bte >>> nth & 1) == 1
else
throw new java.lang.IndexOutOfBoundsException("")
case MyString(str) => // So A is actually Character
str.charAt(nth)
case MyArray(array) => // So A is actually a sub-class of Number
array(nth)
}
def guess[A](x : Sequence[A]): Unit =
x match {
case MyByte(bte) =>
println("I guess A is actually Boolean, let's check!")
println(getNthElement(x, 0).getClass.getName)
case MyString(str) =>
println("I guess A is actually Character")
println(getNthElement(x, 0).getClass.getName)
case MyArray(array) =>
println("I guess A is a sub-class of Number but i can not guess which one")
println(getNthElement(x, 0).getClass.getName)
}
As you can see getNthElement
must returns a value of type A
but the case MyByte
returns a Boolean
. It means Scala is aware that in this case A
is actually Boolean. Likewise in the case MyString
, Scala knowns that the only possible concrete type for A
is Character
so it accepts we return one. Scala is (most of the time) able to guess, depending on the case, what are the constraints on A
. This is all the magic behind GADTs: specialized constructors like in object-oriented-with-generics programming and closed types (i.e. with a fixed number of cases) on which we can pattern-match like in usual functional programming.
How are GADTs useful? First of all, there are handy when you have a specialized constructor like in every day life object-oriented programming. It makes sense for a byte (resp. string) to be sequence of booleans (resp. characters) but not a sequence of anything. A prolific use of this is writing implicits in Scala as GADTs. This way we can pattern-match on the structure of the implicits to derive instances (see this gist for more details). They are also very useful to encode properties on types. As i said above, not all types Sequence[A]
have (non-null) values! There is no (non-null) value of type Sequence[Unit]
or Sequence[String]
etc but there are values of type Sequence[Boolean]
, Sequence[Character]
and Sequence[Number]
. So if i give you a value of type Sequence[A]
, then you know A
is either Boolean
, Character
or Number
. If you don’t believe me, try to call the function guess
on a type A
which is neither Boolean
nor Character
nor Number
(without using null
)! Let me give you some useful examples.
The first one is restricting a generic type like in the code below. The GADT IsIntOrString
forces A
to be either String
or Int
in the function and the case class.
sealed abstract class IsIntOrString[A]
implicit final case object IsInt extends IsIntOrString[Int]
implicit final case object IsString extends IsIntOrString[String]
def canOnlyBeCalledOnIntOrString[A](a: A)(implicit ev: IsIntOrString[A]): A =
ev match {
case IsInt => // A is Int
a + 7
case IsString => // A is String
a.reverse
}
final case class AStringOrAnIntButNothingElse[A](value: A)(implicit val proof : IsIntOrString[A])
Another handy use is encoding effects:
trait UserId
trait User
sealed abstract class BusinessEffect[A]
final case class GetUser(userId: UserId) extends BusinessEffect[User]
final case class SetUser(user: User) extends BusinessEffect[UserId]
final case class DelUser(userId: UserId) extends BusinessEffect[Unit]
Have you ever heard that Set
is not a functor? With the usual definition of a functor, indeed Set
is not one.
trait Functor[F[_]] {
def map[A,B](fa: F[A])(f: A => B): F[B]
}
The reason is you can only have a Set[A]
for types A
such that you can compare values. As an example let A
be Int => Int
. The two following functions are arguably equal:
val doubleByMult: Int => Int = (x: Int) => 2 * x
val doubleByPlus: Int => Int = (x: Int) => x + x
scala> Set(doubleByMult).contains(doubleByPlus)
res0: Boolean = false
This is just impossible, in the general case, to know if two functions compute the same thing. I didn’t just say we don’t know how to do it. It is actually proven that this is impossible (like no one can, and no one could for ever!). Have a look at this List of undecidable problems for more information on the subject. Using extensional equality (the one where f == g
if and only f(x) == g(x)
for all x
), there is just no implementation of Set[Int => Int]
. But if Set
was a functor, it would be trivial using map
to get a Set[Int => Int]
:
Set[Boolean](true, false).map {
case true => doubleByMult
case false => doubleByPlus
}: Set[Int => Int]
The conclusion is that Set
is not a functor … in the usual (i.e. Scal) category. But it is in for some categories. The problem with Functor
is map
can be applied on any A
and B
which is impossible for Set
. But if we restrict A
and B
such that they have interesting properties (like having an Ordering
), then it works. In the code below, the GADT predicate
is used to restrict on which A
and B
map
can be applied on:
trait GenFunctor[predicate[_],F[_]] {
def map[A,B](fa: F[A])(f: A => B)(implicit proofA: predicate[A], proofB: predicate[B]): F[B]
}
Then Set
is becomes a functor with Ordering
as predicate
:
object SetInstance extends GenFunctor[Ordering, Set] {
def map[A,B](fa: Set[A])(f: A => B)(implicit orderingA: Ordering[A], orderingB: Ordering[B]): Set[B] = {
val set = TreeSet.newBuilder(orderingB)
for (a <- fa) set += f(a)
set.result
}
}
Surprisingly even String
can be a functor (with A
and B
being both Char
)!!!
sealed abstract class IsItChar[A]
implicit final case object YesItIsChar extends IsItChar[Char]
type StringK[A] = String
object StringInstance extends GenFunctor[IsItChar, StringK] {
def map[A,B](fa: String)(f: A => B)(implicit proofA: IsItChar[A], proofB: IsItChar[B]): String =
(proofA, proofB) match {
case (YesItIsChar, YesItIsChar) => // A and B are both Char!
fa.toList.map(f).mkString
}
}
GADTs are an example of Bushnell’s law. As you can see, they are easy to learn but can be used in very tricky situations which makes them hard to master. They are clearly very helpful in many situations but it seems they are still unfortunately very little used. Haskell supports them very well! Scala’s support is actually very good but not as good as Haskell’s. Scala 3 will probably support them as well as Haskell since Dotty’s support is excellent. The only two other mainstream languages i know supporting them are OCaml and Haxe. Even if those two have a very good support, their lack of Higer-Kinded types forbids the most interesting uses.
As you probably know, it is possible to define a fold
functor for every Algebraic Data Type. It is also possible to define fold
functions for every GADT. As an exercise, try to define fold functions for the following GADTs:
-
This GADT encode the equality between two types
A
andB
:sealed abstract class Eq[A,B] final case class Refl[A]() extends Eq[A,A]
-
This GADT represent an unknown type for which we have an instance of a type-class:
sealed abstract class Ex[TypeClass[_]] final case class MakeEx[TypeClass[_],A](value:A, instance: TypeClass[A]) extends Ex[TypeClass]
You’ll find how to define such fold
functions here. Have fun and spread the love of GADTs everywhere :)