Exploring Tagless Final pattern for extensive and readable Scala code

In this post I will share a functional pattern I stumbled upon recently: Tagless Final. This pattern tries to address a vital problem for every software engineer: how do we make sure the programs we write are correct? I will explain how Tagless Final works and how it can be applied in practice, while keeping things as down to earth and practical as possible. Of course, I didn’t invent it, but I would like to share what I’ve learned and maybe popularize this solution.

Kudos to Oleg Kiselyov for describing the pattern in depth and John De Goes for inspiring me to write this post.

Let’s get started.

Introducing Tagless Final

Tagless Final allows you to build a subset of the host language which is sound, typesafe and predictable. When designed properly this subset makes it easy to write correct programs and hard to write incorrect ones. In fact, invalid states can’t be expressed at all! Later, the solution written in the hosted language is safely run to return the value to the hosting language.

For us, Scala will be the hosting language, but the pattern also applies to other ecosystems. Tagless Final can be seen as a composite pattern: it builds on top of, or is inspired by, many other patterns such as type classes or Free monads, so you might spot some similarities.

On a conceptual level our Scala implementation has 3 distinct parts:

  • Language, which defines the subset of operations that the hosted language allows
  • Bridges, helpers that express Scala values and business logic in the Language
  • Interpreters, the dual of Bridges: they run logic expressed in the Language and produce the final value

(This is not the official lingo, but I will be using this naming here to make explanations simpler.)

Let’s explore this with a simple example - basic math. You can find the whole code in the repo.

Language

Language is the heart of our DSL. It defines what we can do with the hosted language.

trait Language[Wrapper[_]] {
  def number(v: Int): Wrapper[Int]
  def increment(a: Wrapper[Int]): Wrapper[Int]
  def add(a: Wrapper[Int], b: Wrapper[Int]): Wrapper[Int]
  ...
}

Above we defined a few basic operations. The trait’s type parameter probably caught your eye. The Language is parameterized by a Wrapper type, which itself takes a type parameter. Both will be useful later: Wrapper lets us change the “package” on which we operate, while the type inside the Wrapper makes sure our API is typesafe.

Consider:

trait Language[Wrapper[_]] {
  ...
  def text(v: String): Wrapper[String]
  def toUpper(a: Wrapper[String]): Wrapper[String]
  def concat(a: Wrapper[String], b: Wrapper[String]): Wrapper[String]
  ...
}

These methods take or return Wrapper[String], in contrast to those above. The consequence is that you cannot call toUpper on an Int, which makes the API much harder to use incorrectly. In fact, to combine those two sets of methods we need a method that converts between the types, like the one below:

trait Language[Wrapper[_]] {
  ...
  def toString(v: Wrapper[Int]): Wrapper[String]
}

Fairly standard stuff: a few methods to create Wrapper[X] values and a set of operations on those values.
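To make the two roles of Wrapper more concrete, here is a minimal sketch that assembles the fragments above into one trait and writes a small program against it. The `shout` helper, the `LanguageSketch` object and the choice of `Option` as the Wrapper are assumptions made purely for illustration; they are not part of the original code.

```scala
object LanguageSketch {
  // The three fragments above, assembled into a single trait.
  trait Language[Wrapper[_]] {
    def number(v: Int): Wrapper[Int]
    def increment(a: Wrapper[Int]): Wrapper[Int]
    def add(a: Wrapper[Int], b: Wrapper[Int]): Wrapper[Int]

    def text(v: String): Wrapper[String]
    def toUpper(a: Wrapper[String]): Wrapper[String]
    def concat(a: Wrapper[String], b: Wrapper[String]): Wrapper[String]

    def toString(v: Wrapper[Int]): Wrapper[String]
  }

  // Hypothetical program written only against the Language:
  // upper-cases the label followed by (n + 1).
  def shout[Wrapper[_]](label: String, n: Int)(implicit L: Language[Wrapper]): Wrapper[String] =
    L.toUpper(L.concat(L.text(label), L.toString(L.increment(L.number(n)))))

  // Does not compile: toUpper expects a Wrapper[String], not a Wrapper[Int].
  // def broken[Wrapper[_]](implicit L: Language[Wrapper]) = L.toUpper(L.number(1))

  // Hypothetical implementation that picks Option as the "package".
  implicit val optionLanguage: Language[Option] = new Language[Option] {
    def number(v: Int): Option[Int] = Some(v)
    def increment(a: Option[Int]): Option[Int] = a.map(_ + 1)
    def add(a: Option[Int], b: Option[Int]): Option[Int] = for (x <- a; y <- b) yield x + y
    def text(v: String): Option[String] = Some(v)
    def toUpper(a: Option[String]): Option[String] = a.map(_.toUpperCase)
    def concat(a: Option[String], b: Option[String]): Option[String] = for (x <- a; y <- b) yield x + y
    def toString(v: Option[Int]): Option[String] = v.map(_.toString)
  }

  val result: Option[String] = shout[Option]("total: ", 41)
}
```

The program itself never mentions Option; only the implementation chosen at the call site decides what the Wrapper is.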

Bridges

Having a Language in place, we now need to “bridge” the gap between the hosting language (Scala) and our hosted language. Here we will build an interface for a generic bridge to do just that:

trait ScalaToLanguageBridge[ScalaValue] {
  def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[ScalaValue]
}

As you can see, it takes an implicit Language to build expressions, and calling apply returns a result wrapped in the desired Wrapper. Although the trait might seem confusing, taking a look at some examples will hopefully clear things up:

def buildNumber(number: Int) = new ScalaToLanguageBridge[Int] {
  override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = L.number(number)
}

def buildIncrementNumber(number: Int) = new ScalaToLanguageBridge[Int] {
  override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = L.increment(L.number(number))
}

Our bridges are really simple. They take a single Scala value, number: Int, and convert it into an expression in our language. Since we operate only on the given L, we cannot represent incorrect logic, like incrementing a String. This approach to bridges is “fine grained”: we build only simple expressions. But we can easily combine those simple expressions into bigger ones:

def buildIncrementExpression(expression: ScalaToLanguageBridge[Int]) = new ScalaToLanguageBridge[Int] {
  override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = L.increment(expression.apply)
}

Or we can just follow a “coarse grained” approach, where we express bigger algorithms straight away:

// builds an expression like: println(s"$text ${a + (b + 1)}")
def buildComplexExpression(text: String, a: Int, b: Int) = new ScalaToLanguageBridge[String] {
  override def apply[Wrapper[_]](implicit F: Language[Wrapper]): Wrapper[String] = {
    val addition = F.add(F.number(a), F.increment(F.number(b)))
    F.concat(F.text(text), F.toString(addition))
  }
}

val fullExpression = buildComplexExpression("Result is ", 10, 1)

Interpreters

We have defined what our meta-language can do. We expressed our problems in the language. Now it’s time to make it run. For instance with an interpreter like this:

type NoWrap[ScalaValue] = ScalaValue

val interpret = new Language[NoWrap] {
  override def number(v: Int): NoWrap[Int] = v
  override def increment(a: NoWrap[Int]): NoWrap[Int] = a + 1
  override def add(a: NoWrap[Int], b: NoWrap[Int]): NoWrap[Int] = a + b

  override def text(v: String): NoWrap[String] = v
  override def toUpper(a: NoWrap[String]): NoWrap[String] = a.toUpperCase
  override def concat(a: NoWrap[String], b: NoWrap[String]): NoWrap[String] = a + " " + b

  override def toString(v: NoWrap[Int]): NoWrap[String] = v.toString
}

First, we need to pick a Wrapper. Here we don’t need anything fancy, so we will just operate on plain values: our NoWrap literally resolves to the type it was parameterized with. After this is done, we simply need to implement our interface. Nothing fancy here.

Note: As David Barri pointed out in a comment, the NoWrap/Id type aliases might be tricky to use in production code, due to their eager evaluation and other concerns. Here we stick to them only for simplicity.
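To illustrate the eager-evaluation concern, here is a small sketch of an alternative Wrapper that defers all work until the result is explicitly requested. The `Lazy` alias, the `interpretLazily` interpreter and the `evaluated` counter are hypothetical names introduced for this example, and the Language is reduced to its Int operations to keep things short.

```scala
object LazySketch {
  // Reduced Language: only the Int operations from the post.
  trait Language[Wrapper[_]] {
    def number(v: Int): Wrapper[Int]
    def increment(a: Wrapper[Int]): Wrapper[Int]
    def add(a: Wrapper[Int], b: Wrapper[Int]): Wrapper[Int]
  }

  trait ScalaToLanguageBridge[ScalaValue] {
    def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[ScalaValue]
  }

  // Hypothetical wrapper: a thunk that computes the value only when called.
  type Lazy[ScalaValue] = () => ScalaValue

  // Counts how many leaf values were actually computed (demonstration only).
  var evaluated = 0

  val interpretLazily: Language[Lazy] = new Language[Lazy] {
    override def number(v: Int): Lazy[Int] = () => { evaluated += 1; v }
    override def increment(a: Lazy[Int]): Lazy[Int] = () => a() + 1
    override def add(a: Lazy[Int], b: Lazy[Int]): Lazy[Int] = () => a() + b()
  }

  val expression = new ScalaToLanguageBridge[Int] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] =
      L.add(L.number(10), L.increment(L.number(1)))
  }

  // Interpreting only builds the thunk; no arithmetic has happened yet.
  val thunk: Lazy[Int] = expression.apply[Lazy](interpretLazily)
}
```

Calling `LazySketch.thunk()` is what finally runs the computation. A production-grade wrapper would more likely be an effect type from a library, but the idea of swapping the Wrapper is the same.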

Having all 3 items in place we can run them together:

val fullExpression = buildComplexExpression("Result is ", 10, 1)
println(s"interpreted full: ${fullExpression.apply(interpret)}")
// interpreted full: Result is  12
// (two spaces, because this concat inserts its own separator after "Result is ")

Simple enough, but why should we bother with the Wrappers in the first place? If you have ever played with Free monads you probably know why: to build multiple, specialized interpreters. To give you an example, let’s build another one. This time we will build a utility interpreter that will be helpful in later stages. The new interpreter will pretty-print our math expressions in a Lisp-like syntax, so we can easily spot mistakes in our code.

type PrettyPrint[ScalaValue] = String

val interpretAsPrettyPrint = new Language[PrettyPrint] {
  override def number(v: Int): PrettyPrint[Int] = s"($v)"
  override def increment(a: PrettyPrint[Int]): PrettyPrint[Int] = s"(inc $a)"
  override def add(a: PrettyPrint[Int], b: PrettyPrint[Int]): PrettyPrint[Int] = s"(+ $a $b)"

  override def text(v: String): PrettyPrint[String] = s"[$v]"
  override def toUpper(a: PrettyPrint[String]): PrettyPrint[String] = s"(toUpper $a)"
  override def concat(a: PrettyPrint[String], b: PrettyPrint[String]): PrettyPrint[String] = s"(concat $a $b)"

  override def toString(v: PrettyPrint[Int]): PrettyPrint[String] = s"(toString $v)"
}

val fullExpression = buildComplexExpression("Result is ", 10, 1)
println(s"interpreted full (as pretty print): ${fullExpression.apply(interpretAsPrettyPrint)}")
// interpreted full (as pretty print): (concat [Result is ] (toString (+ (10) (inc (1)))))

So far, so good. To reiterate: we defined the possible operations in the form of a Language trait, we defined helpers to convert Scala values into expressions using the Language, and we implemented our meta-language and ran the code. Every element above is just plain Scala, but thanks to the way it is composed we get a few nice properties.

Benefits

Extensibility

You are probably familiar with the expression problem, one of the classical concepts in computer science. Long story short: given an interface and a set of implementations of that interface, we want to be able to easily add both new operations to the interface and new implementations of it. Ideally both would be easy, but OOP usually makes it hard to add interface methods and easy to add implementations. In FP, on the other hand, it’s easy to add new operations, but harder to add implementations.
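As a rough illustration of the two sides of this trade-off, consider the sketch below. The `OopExpr`/`Expr` types and the `eval`/`show` functions are made-up names for this example and are unrelated to the Language trait used elsewhere in this post.

```scala
object ExpressionProblemSketch {
  // OOP style: each implementation carries its own operations.
  // Adding a new class (say, a multiplication node) is a local change,
  // but adding a new method to OopExpr means editing every existing class.
  trait OopExpr { def eval: Int }
  final class OopNum(v: Int) extends OopExpr { def eval: Int = v }
  final class OopAdd(a: OopExpr, b: OopExpr) extends OopExpr { def eval: Int = a.eval + b.eval }

  // FP style: a closed data type with operations as separate functions.
  // Adding a new function (show) is a local change,
  // but adding a new case means editing every existing function.
  sealed trait Expr
  final case class Num(v: Int) extends Expr
  final case class Add(a: Expr, b: Expr) extends Expr

  def eval(e: Expr): Int = e match {
    case Num(v)    => v
    case Add(a, b) => eval(a) + eval(b)
  }

  def show(e: Expr): String = e match {
    case Num(v)    => v.toString
    case Add(a, b) => s"(${show(a)} + ${show(b)})"
  }

  // The same small expression, 1 + 2, in both styles.
  val oopResult: Int = new OopAdd(new OopNum(1), new OopNum(2)).eval
  val fpExpr: Expr = Add(Num(1), Num(2))
}
```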

This is a big deal for every programmer, even though we might not think about it every day. No useful software is static - written once and left unchanged for years. Our software needs to evolve, hence extensibility is an important factor when picking patterns.

Tagless Final finds itself closer to the FP part of the extensibility spectrum, which is OK in this case. We have tools to easily and safely modify our Language, which (I presume) is the more common operation. Creating interpreters might be harder, I agree, but the cool thing is that you don’t always need to update them: your change can be made in a non-breaking way (or at least while keeping the damage to a minimum).

To explain how this works, let’s go back to the example from before. We had an add operation, but let’s say we wanted another one called multiply. We could just go and add the new method to Language. That would be really simple, but unfortunately it would also cause a ripple effect throughout the system, as you would have to update every interpreter that implements the language. Not good. So let’s do something else instead.

trait LanguageWithMul[Wrapper[_]] extends Language[Wrapper] {
  def multiply(a: Wrapper[Int], b: Wrapper[Int]): Wrapper[Int]
}

Conceptually, what we did is create a child language, which has the whole API of its parent plus a few new things. From this point on we can just do the same things we did for the parent.

Define a bridge:

trait ScalaToLanguageWithMulBridge[ScalaValue] {
  def apply[Wrapper[_]](implicit L: LanguageWithMul[Wrapper]): Wrapper[ScalaValue]
}

def multiply(a: Int, b: Int) = new ScalaToLanguageWithMulBridge[Int] {
  override def apply[Wrapper[_]](implicit L: LanguageWithMul[Wrapper]): Wrapper[Int] = {
    L.multiply(L.number(a), L.number(b))
  }
}

Build an interpreter:

val interpretWithMul = new LanguageWithMul[NoWrap] {
  override def multiply(a: NoWrap[Int], b: NoWrap[Int]): NoWrap[Int] = a * b

  override def number(v: Int): NoWrap[Int] = v
  override def increment(a: NoWrap[Int]): NoWrap[Int] = a + 1
  override def add(a: NoWrap[Int], b: NoWrap[Int]): NoWrap[Int] = a + b

  override def text(v: String): NoWrap[String] = v
  override def toUpper(a: NoWrap[String]): NoWrap[String] = a.toUpperCase
  override def concat(a: NoWrap[String], b: NoWrap[String]): NoWrap[String] = a + " " + b

  override def toString(v: NoWrap[Int]): NoWrap[String] = v.toString
}

As you can see, it required a bit of writing (it could be less if we delegated or inherited in the interpretWithMul definition). But the good part is that at no point did we have to touch the older part: it runs as it did before.
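The inheritance option mentioned above could look roughly like the sketch below. It assumes the original interpreter is turned into a named class (`NoWrapInterpreter`, a hypothetical name) instead of an anonymous one, and it reduces the Language to its Int operations for brevity. It also shows that a bridge written against the parent Language still runs with the child interpreter.

```scala
object ExtensionSketch {
  // Reduced Language: only the Int operations from the post.
  trait Language[Wrapper[_]] {
    def number(v: Int): Wrapper[Int]
    def increment(a: Wrapper[Int]): Wrapper[Int]
    def add(a: Wrapper[Int], b: Wrapper[Int]): Wrapper[Int]
  }

  trait LanguageWithMul[Wrapper[_]] extends Language[Wrapper] {
    def multiply(a: Wrapper[Int], b: Wrapper[Int]): Wrapper[Int]
  }

  trait ScalaToLanguageBridge[ScalaValue] {
    def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[ScalaValue]
  }

  trait ScalaToLanguageWithMulBridge[ScalaValue] {
    def apply[Wrapper[_]](implicit L: LanguageWithMul[Wrapper]): Wrapper[ScalaValue]
  }

  type NoWrap[ScalaValue] = ScalaValue

  // Hypothetical named class holding the original interpreter, so it can be reused.
  class NoWrapInterpreter extends Language[NoWrap] {
    override def number(v: Int): NoWrap[Int] = v
    override def increment(a: NoWrap[Int]): NoWrap[Int] = a + 1
    override def add(a: NoWrap[Int], b: NoWrap[Int]): NoWrap[Int] = a + b
  }

  // Only the new operation has to be written; the rest is inherited.
  val interpretWithMul: LanguageWithMul[NoWrap] =
    new NoWrapInterpreter with LanguageWithMul[NoWrap] {
      override def multiply(a: NoWrap[Int], b: NoWrap[Int]): NoWrap[Int] = a * b
    }

  // A bridge for the parent language: 1 + 1.
  val oldBridge = new ScalaToLanguageBridge[Int] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] =
      L.increment(L.number(1))
  }

  // A bridge for the child language: 3 * (1 + 2).
  val newBridge = new ScalaToLanguageWithMulBridge[Int] {
    override def apply[Wrapper[_]](implicit L: LanguageWithMul[Wrapper]): Wrapper[Int] =
      L.multiply(L.number(3), L.add(L.number(1), L.number(2)))
  }

  // Both bridges run with the same, extended interpreter.
  val oldResult: Int = oldBridge.apply[NoWrap](interpretWithMul)
  val newResult: Int = newBridge.apply[NoWrap](interpretWithMul)
}
```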

Composability

Another interesting property that emerges from this pattern is composability: the ability to combine small parts into bigger, more powerful entities with new properties.

Let’s discuss it with a simple example. Using our original language we were able to express logic like this:

// builds a 10 + (((0 + 1)+1)+1) expression
def buildIncrementExpression() = new ScalaToLanguageBridge[Int] {
  override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = {
    L.add(L.number(10), L.increment(L.increment(L.increment(L.number(0)))))
  }
}

The code compiles just fine, but there’s something we could do better: performance. 10 + (((0 + 1)+1)+1) is not the optimal form for this operation. 10 + (0 + 3) or simply 10 + 3 would be much simpler and more performant. Of course here the difference is not huge, but if our language were making API calls or DB queries the difference would be much bigger. Conceptually, we would like to simplify our expression into a more convenient form. Here we will do that naively for one case: flattening many nested increment calls into a single add.

But how do we do that? The Tagless Final approach gives us 3 “moving parts”: Language, Bridge, Interpreter. We cannot add this optimization to the Language itself, as that would mean our DSL is in constant flux and, as mentioned above, we would prefer not to touch the language once it is defined. Adding the optimization to bridges is plausible, but unfortunately at call time they don’t know much about the shape of the expression we are building. Even worse, users of our language can build their own bridges, which makes distributing the improved version harder. We are left with just one place: the interpreter. We will have to figure something out so that our end users can easily opt in to the optimization when they find it helpful.

One interesting solution to this problem is interpreter composability. To recap: the idea of interpreters is that they take expressions in our custom language and turn them into plain Scala values. So, if we wrote an interpreter that takes an expression and turns it into a plain Scala value, and that value happened to be a bridge, then we could interpret its output later on.

In short: instead of interpreting into (say) a Scala Int, we will interpret into a rewritten expression, which is itself a bridge.

type Nested[ScalaValue] = ScalaToLanguageBridge[ScalaValue]

val simplify = new Language[Nested] {
  var nesting = 0

  override def number(v: Int): Nested[Int] = new ScalaToLanguageBridge[Int] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = {
      if (nesting > 0) {
        val temp = nesting
        nesting = 0
        L.add(L.number(temp), L.number(v))
      } else {
        L.number(v)
      }
    }
  }

  override def increment(a: ScalaToLanguageBridge[Int]): Nested[Int] = new ScalaToLanguageBridge[Int] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = {
      nesting = nesting + 1
      a.apply(L)
    }
  }

  override def add(a: ScalaToLanguageBridge[Int], b: ScalaToLanguageBridge[Int]): Nested[Int] = new ScalaToLanguageBridge[Int] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = {
      if (nesting > 0) {
        val temp = nesting
        nesting = 0
        L.add(L.number(temp), L.add(a.apply(L), b.apply(L)))
      } else {
        L.add(a.apply(L), b.apply(L))
      }
    }
  }

  ...
}

Of course the code is very simplistic, but you get the idea. During the interpretation process we can collect more detailed information about the expression and make decisions based on that.
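The simplify interpreter above tracks pending increments in a shared mutable `var nesting`. As one possible alternative, the sketch below carries that count inside the wrapper itself, so no mutable state is needed. The `Pending` case class and the `flush` helper are hypothetical names introduced here, and the Language is reduced to its Int operations; treat it as an illustration of the idea, not as the original implementation.

```scala
object StatelessSimplifySketch {
  // Reduced Language: only the Int operations from the post.
  trait Language[Wrapper[_]] {
    def number(v: Int): Wrapper[Int]
    def increment(a: Wrapper[Int]): Wrapper[Int]
    def add(a: Wrapper[Int], b: Wrapper[Int]): Wrapper[Int]
  }

  trait ScalaToLanguageBridge[ScalaValue] {
    def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[ScalaValue]
  }

  type PrettyPrint[ScalaValue] = String

  val interpretAsPrettyPrint: Language[PrettyPrint] = new Language[PrettyPrint] {
    override def number(v: Int): PrettyPrint[Int] = s"($v)"
    override def increment(a: PrettyPrint[Int]): PrettyPrint[Int] = s"(inc $a)"
    override def add(a: PrettyPrint[Int], b: PrettyPrint[Int]): PrettyPrint[Int] = s"(+ $a $b)"
  }

  // Hypothetical wrapper: a bridge plus the number of increments not yet emitted.
  final case class Pending[ScalaValue](increments: Int, expr: ScalaToLanguageBridge[ScalaValue])

  // Emits the collected increments as a single add, or nothing if there are none.
  def flush(p: Pending[Int]): ScalaToLanguageBridge[Int] = new ScalaToLanguageBridge[Int] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] =
      if (p.increments > 0) L.add(L.number(p.increments), p.expr.apply(L))
      else p.expr.apply(L)
  }

  val simplify: Language[Pending] = new Language[Pending] {
    override def number(v: Int): Pending[Int] = Pending(0, new ScalaToLanguageBridge[Int] {
      override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] = L.number(v)
    })

    // Nothing is emitted yet; we only remember one more increment.
    override def increment(a: Pending[Int]): Pending[Int] =
      a.copy(increments = a.increments + 1)

    override def add(a: Pending[Int], b: Pending[Int]): Pending[Int] = {
      val left = flush(a)
      val right = flush(b)
      Pending(0, new ScalaToLanguageBridge[Int] {
        override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] =
          L.add(left.apply(L), right.apply(L))
      })
    }
  }

  // The same 10 + (((0 + 1)+1)+1) expression as before.
  val unoptimized = new ScalaToLanguageBridge[Int] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Int] =
      L.add(L.number(10), L.increment(L.increment(L.increment(L.number(0)))))
  }

  // The final flush handles increments that wrap the whole expression.
  val optimized: ScalaToLanguageBridge[Int] = flush(unoptimized.apply[Pending](simplify))
  val printed: String = optimized.apply[PrettyPrint](interpretAsPrettyPrint)
}
```

Because the count travels with each sub-expression, interpreting the rewritten bridge more than once gives the same result every time.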

Here’s the output:

val simpleVersion = buildIncrementExpression()
println(s"Unoptimized ${simpleVersion.apply(interpretAsPrettyPrint)} = ${simpleVersion.apply(interpret)}")
// Unoptimized (+ (10) (inc (inc (inc (0))))) = 13

val example1 = simpleVersion.apply(simplify)
println(s"Optimized ${example1.apply(interpretAsPrettyPrint)} = ${example1.apply(interpret)}")
// Optimized (+ (10) (+ (3) (0))) = 13

As you can see the expression has been rewritten into a different form. Later, this new form can be interpreted as usual.

Real life example

Having explained the pattern, let’s build something that solves a more typical problem. Our toy example will use Slick to make simple queries to the database.

We will start from the language, where we will define some basic operations on our Person case class, together with some helper classes to pass data around and make the API more typesafe.

trait Language[Wrapper[_]] {
  type QueryObj

  case class Raw(q: QueryObj)
  case class WithFilter(q: QueryObj)
  case class WithPagination(q: QueryObj)

  def people(): Wrapper[Raw]

  // we could change plain Ints into our custom class, but will skip it for brevity
  def filterByIds(query: Wrapper[Raw], ids: Seq[Int]): Wrapper[WithFilter]
  def paginate(query: Wrapper[WithFilter], skip: Int, limit: Int): Wrapper[WithPagination]
  def run(query: Wrapper[WithPagination]): Wrapper[Seq[Person]]
}

Then we follow the same routine, which gets us to the Slick interpreter:

val slickInterpreter = new Language[Future] {
  // I've skipped the Slick-related part; you've probably seen it 100 times already
  private val slickPersonQuery = TableQuery[SlickPersonTable]

  override type QueryObj = Query[SlickPersonTable, Person, Seq]

  override def people(): Future[Raw] = {
    Future.successful(Raw(slickPersonQuery))
  }

  override def filterByIds(query: Future[Raw], ids: Seq[Int]): Future[WithFilter] = {
    query.map(_.q.filter(_.id inSet ids)).map(WithFilter)
  }

  override def paginate(query: Future[WithFilter], skip: Int, limit: Int): Future[WithPagination] = {
    query.map(_.q.drop(skip).take(limit)).map(WithPagination)
  }

  override def run(query: Future[WithPagination]): Future[Seq[Person]] = {
    query.flatMap { case finalQuery =>
      db.run(finalQuery.q.result)
    }
  }
}

Our interpreter doesn’t do anything fancy: it keeps the table definitions internally, and between calls the state is accumulated in the form of a QueryObj being passed around. When we reach the final step, run, the query is executed to retrieve the data.

val findMiddleUser = new ScalaToLanguageBridge[Seq[Person]] {
  override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Seq[Person]] = {
    val base = L.people()
    val full = L.paginate(L.filterByIds(base, Seq(1,2,3)), skip = 1, limit = 1)
    L.run(full)
  }
}

val result = Await.result(findMiddleUser.apply(slickInterpreter), 10.seconds)
println(s"Query result is $result")
// Query result is Vector(Person(2,person 2,1990))

And tada, we used Tagless Final with Slick.
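Since the bridge depends only on the Language, the same query logic could in principle be run without a database at all, which is handy in tests. The sketch below shows one way this might look. The `inMemoryInterpreter` function, the field names of `Person` and the sample rows are assumptions made for this example; the original post does not show the Person definition.

```scala
object InMemorySketch {
  // Assumed shape of Person; the field names are a guess based on the printed output above.
  final case class Person(id: Int, name: String, birthYear: Int)

  trait Language[Wrapper[_]] {
    type QueryObj

    case class Raw(q: QueryObj)
    case class WithFilter(q: QueryObj)
    case class WithPagination(q: QueryObj)

    def people(): Wrapper[Raw]
    def filterByIds(query: Wrapper[Raw], ids: Seq[Int]): Wrapper[WithFilter]
    def paginate(query: Wrapper[WithFilter], skip: Int, limit: Int): Wrapper[WithPagination]
    def run(query: Wrapper[WithPagination]): Wrapper[Seq[Person]]
  }

  trait ScalaToLanguageBridge[ScalaValue] {
    def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[ScalaValue]
  }

  type NoWrap[ScalaValue] = ScalaValue

  // Hypothetical interpreter backed by a plain Seq instead of a database.
  def inMemoryInterpreter(data: Seq[Person]): Language[NoWrap] = new Language[NoWrap] {
    override type QueryObj = Seq[Person]

    override def people(): NoWrap[Raw] = Raw(data)

    override def filterByIds(query: NoWrap[Raw], ids: Seq[Int]): NoWrap[WithFilter] =
      WithFilter(query.q.filter(p => ids.contains(p.id)))

    override def paginate(query: NoWrap[WithFilter], skip: Int, limit: Int): NoWrap[WithPagination] =
      WithPagination(query.q.drop(skip).take(limit))

    override def run(query: NoWrap[WithPagination]): NoWrap[Seq[Person]] = query.q
  }

  // The same bridge as in the Slick example.
  val findMiddleUser = new ScalaToLanguageBridge[Seq[Person]] {
    override def apply[Wrapper[_]](implicit L: Language[Wrapper]): Wrapper[Seq[Person]] = {
      val base = L.people()
      val full = L.paginate(L.filterByIds(base, Seq(1, 2, 3)), skip = 1, limit = 1)
      L.run(full)
    }
  }

  // Made-up sample rows, for illustration only.
  val sample: Seq[Person] = Seq(
    Person(1, "person 1", 1989),
    Person(2, "person 2", 1990),
    Person(3, "person 3", 1991)
  )

  val result: Seq[Person] = findMiddleUser.apply[NoWrap](inMemoryInterpreter(sample))
}
```

The bridge is unchanged; only the interpreter passed to apply differs from the Slick version.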

Summary

And that’s it. We applied the pattern in Scala and wrote code that can be extended and composed with ease.

I hope this post inspired you to take a look at Tagless Final on your own, the same way as I was inspired. Thanks for the time and keep on hacking!

Source: https://blog.scalac.io/exploring-tagless-f...