Mastering Scala Basics: Modeling Data with Traits

Traits allow us to express that multiple classes share a common super-type. Traits are similar to Java’s interfaces and can be thought of as an abstraction over classes.

// Trait declaration
trait TraitName {
    declarationExpression
}

// declare that a class is a subtype of the trait
class Name(...) extends TraitName {
    ...
}

Traits differ from classes in two ways:

  1. A trait cannot take constructor parameters. (This holds in Scala 2; Scala 3 adds trait parameters.)
  2. Traits can define abstract methods that have names and type signatures but no implementation. Concrete classes that extend a trait must implement its abstract methods.

trait Animal {
    def sound: String
    def color: String
    def food: String = "Kibbles"
}

// subtypes must implement sound and color
case class Dog(sound: String, color: String) extends Animal
case class Cat(sound: String, color: String) extends Animal
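
To see how abstract and concrete members interact, here is a minimal sketch built on the Animal trait above. The override of food in Cat and the describe helper in AnimalDemo are hypothetical additions for illustration only: Dog inherits the default food, while Cat replaces it.

```scala
trait Animal {
  def sound: String
  def color: String
  def food: String = "Kibbles" // concrete member: inherited unless overridden
}

// The case class parameters implement the abstract members sound and color
case class Dog(sound: String, color: String) extends Animal

// Hypothetical override, for illustration only
case class Cat(sound: String, color: String) extends Animal {
  override def food: String = "Tuna"
}

object AnimalDemo {
  // Hypothetical helper: code written against the super-type works with any subtype
  def describe(animal: Animal): String =
    s"${animal.color} animal says ${animal.sound} and eats ${animal.food}"
}
```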

If all the subtypes of a trait are known, seal the trait. A sealed trait can only be extended by types defined in the same file.

sealed trait TraitName {
    ...
}

// consider making the subtypes final if there is no reason to extend them
final case class Name(...) extends TraitName {
    ...
}

The main advantages of using a sealed trait with final case classes are:

  1. The compiler will warn if we miss a case in pattern matching; and
  2. We can control extension points of sealed traits and thus make stronger guarantees about the behavior of the subtypes.
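
A small sketch of both advantages, assuming a two-case Animal hierarchy with a hypothetical name field. Because Animal is sealed, the compiler knows Dog and Cat are its only subtypes and can check the match for exhaustiveness; because Dog is final, no one can add a further subtype of it.

```scala
sealed trait Animal
final case class Dog(name: String) extends Animal // name is a hypothetical field
final case class Cat(name: String) extends Animal

object Sounds {
  // Exhaustive match: deleting either case makes the compiler
  // warn that the match may not be exhaustive.
  def sound(animal: Animal): String =
    animal match {
      case Dog(_) => "Woof"
      case Cat(_) => "Meow"
    }
}

// Because Dog is final, a declaration such as
//   class Puppy extends Dog("Rex")
// would be rejected at compile time.
```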

In Scala, we can model data using logical ors and ands. In object-oriented programming, we express these as is-a and has-a relationships. In functional programming, they are known as sum and product types, which together are called algebraic data types.

Product Type Pattern

This pattern is used to model data that contains other data. For example, an Animal has a color and a sound.

If A has a b (with type B) and a c (with type C), write

case class A(b: B, c: C)

or

trait A {
    def b: B
    def c: C
}
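
Applied to the Animal example, both forms might look like the sketch below. The names AnimalLike and Dog for the trait form are hypothetical, chosen so the two versions can sit side by side in one file.

```scala
// Product type as a case class: an Animal has a color AND a sound
final case class Animal(color: String, sound: String)

// The same product expressed as a trait (hypothetical name AnimalLike);
// any class that extends it must supply both members
trait AnimalLike {
  def color: String
  def sound: String
}
final case class Dog(color: String, sound: String) extends AnimalLike
```

The case class form also gives us equality, copy, and pattern matching for free; the trait form is useful when several classes should share the same fields.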

Sum Type Pattern

This pattern is used to model data that is one of two or more distinct cases. For example, a visitor to a web page can be anonymous or logged in.

If A is a B or a C:

sealed trait A
final case class B() extends A
final case class C() extends A
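
For the visitor example, a minimal sketch might look like this. The case names Anonymous and User and the fields id and email are assumptions for illustration; note that each case of the sum is itself a product type.

```scala
// A Visitor is an Anonymous visitor OR a logged-in User (a sum type)
sealed trait Visitor {
  def id: String // hypothetical field shared by every case
}
final case class Anonymous(id: String) extends Visitor
final case class User(id: String, email: String) extends Visitor // email is hypothetical
```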

We work with data using structural recursion, which essentially means breaking the data down into smaller pieces and processing each piece. There are two patterns for structural recursion: polymorphism and pattern matching.

Structural Recursion using Polymorphism

I assume you’re already familiar with polymorphism; if you’re not, I’d recommend reading up on it before continuing.

The Product Type Polymorphism Pattern

If A has a b (with type B) and a c (with type C), and we want to write a method g returning G, define g as a method on A:

case class A(b: B, c: C) {
    def g: G = ???
}
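
As a concrete instance, here is a sketch using the Animal product type. The method describe (playing the role of g) and its output format are hypothetical; the point is that the method reads the fields directly.

```scala
final case class Animal(color: String, sound: String) {
  // g is a method on A, so it can use the fields b and c directly
  def describe: String = s"A $color animal that says $sound"
}
```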

The Sum Type Polymorphism Pattern

If A is a B or a C, and we want to write a method g returning G, define g as an abstract method on A and provide concrete implementations in B and C.

sealed trait A {
    def g: G
}

final case class B() extends A {
    def g: G = ???
}

final case class C() extends A {
    def g: G = ???
}
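
As a concrete instance, here is a sketch using a visitor sum type. The case names, the fields id and email, and the method greeting are all hypothetical, chosen only for illustration.

```scala
sealed trait Visitor {
  def id: String
  // Abstract on the trait; each case supplies its own implementation
  def greeting: String
}

final case class Anonymous(id: String) extends Visitor {
  def greeting: String = "Welcome! Please log in."
}

final case class User(id: String, email: String) extends Visitor {
  def greeting: String = s"Welcome back, $email"
}
```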

Structural Recursion using Pattern Matching

Here we write one case for every subtype, and each case extracts the fields we’re interested in.

The Product Type Matching Pattern

If A has a b (with type B) and a c (with type C), and we want to write a method g that accepts an A and returns G:

def g(a: A): G = 
    a match {
        case A(b, c) => ???
    }
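
Here is the same idea for the Animal product type, as a sketch. The object name AnimalOps and the method describe are hypothetical; the pattern Animal(color, sound) extracts both fields at once.

```scala
final case class Animal(color: String, sound: String)

object AnimalOps {
  // A single case covers the product type and binds both fields
  def describe(animal: Animal): String =
    animal match {
      case Animal(color, sound) => s"A $color animal that says $sound"
    }
}
```
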

The Sum Type Matching Pattern

If A is a B or a C, and we want to write a method g accepting an A and returning G, define a pattern matching case for B and C.

def g(a: A): G = 
    a match {
        case B() => ???
        case C() => ???
    }
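
For a concrete sum type, here is a sketch that writes the greeting logic with pattern matching instead of polymorphism. The Visitor cases, their fields, and the object Greeting are hypothetical names for illustration.

```scala
sealed trait Visitor
final case class Anonymous(id: String) extends Visitor
final case class User(id: String, email: String) extends Visitor

object Greeting {
  // One case per subtype; each case extracts only the fields it needs
  def greeting(visitor: Visitor): String =
    visitor match {
      case Anonymous(_)   => "Welcome! Please log in."
      case User(_, email) => s"Welcome back, $email"
    }
}
```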

For example, an Animal is a Dog, a Cat, or a Parrot, and a Dog has a favorite food.

sealed trait Animal
final case class Dog(favoriteFood: String) extends Animal
final case class Cat() extends Animal
final case class Parrot() extends Animal

Now suppose we need to implement a method dinner that returns the appropriate food for each animal. We could represent food with a String, but let’s model it as its own type.

sealed trait Food
final case class DogFood(food: String) extends Food
case object CatFood extends Food
case object Chillies extends Food

Now that we have a Food type, let’s implement dinner using polymorphism:

sealed trait Animal {
  def dinner: Food
}
final case class Dog(favoriteFood: String) extends Animal {
  def dinner: Food =
    DogFood(favoriteFood)
}
final case class Cat() extends Animal {
  def dinner: Food = CatFood
}
final case class Parrot() extends Animal {
  def dinner: Food = Chillies
}

We can also implement dinner with pattern matching, in one of two places:

  1. In a single method on the trait: use this when the method only depends on other fields and methods in the class.

sealed trait Animal {
  def dinner: Food =
    this match {
      case Dog(favoriteFood) => DogFood(favoriteFood)
      case Cat() => CatFood
      case Parrot() => Chillies
    }
}
final case class Dog(favoriteFood: String) extends Animal
final case class Cat() extends Animal
final case class Parrot() extends Animal

  2. In a method on another object: use this when the method depends on other data.

object Dinner {
  def dinner(animal: Animal): Food =
    animal match {
      case Dog(favoriteFood) => DogFood(favoriteFood)
      case Cat() => CatFood
      case Parrot() => Chillies
    }
}
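
The second option pays off when dinner needs information the animals themselves don’t carry. Here is a minimal sketch that assumes a hypothetical Pantry type (not part of the original example) recording whether the dog’s favorite food is in stock; the "leftovers" fallback is likewise made up for illustration.

```scala
sealed trait Food
final case class DogFood(food: String) extends Food
case object CatFood extends Food
case object Chillies extends Food

sealed trait Animal
final case class Dog(favoriteFood: String) extends Animal
final case class Cat() extends Animal
final case class Parrot() extends Animal

// Hypothetical extra data that lives outside the Animal hierarchy
final case class Pantry(dogFoodInStock: Boolean)

object Dinner {
  // dinner depends on the pantry as well as the animal, so it lives
  // on a separate object rather than as a method on Animal
  def dinner(animal: Animal, pantry: Pantry): Food =
    animal match {
      case Dog(favoriteFood) =>
        if (pantry.dogFoodInStock) DogFood(favoriteFood) else DogFood("leftovers")
      case Cat()    => CatFood
      case Parrot() => Chillies
    }
}
```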

Just as a recursive function needs a recursive case and a non-recursive base case, we build recursive data in Scala from a recursive case and a base case. For example, a singly-linked list is either the end of the list or a pair of a head element and a tail, where the tail is itself a list. To find the sum of all the elements, we add the head to the sum of the tail, computing the tail’s sum with a recursive call. When we reach the end of the list (the base case), we return 0.

// recursive algebraic data types pattern
sealed trait IntList  // define a trait data
case object End extends IntList // define base case
final case class Pair(head: Int, tail: IntList) extends IntList // define recursive case

The structural recursion pattern for recursive data:

  • Whenever we encounter a recursive element in the data, we make a recursive call to our method.
  • Whenever we encounter a base case in the data, we return the identity of the operation we are performing. For addition the identity is 0, since x + 0 = x.

// recursive function
def sum(list: IntList): Int =
    list match {
      case End => 0
      case Pair(hd, tl) => hd + sum(tl)
    }
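
The same recipe works for other operations; only the base case changes with the operation’s identity. As a sketch, here are two hypothetical methods (product and length, in a hypothetical object IntListOps) over the IntList type defined above.

```scala
sealed trait IntList
case object End extends IntList
final case class Pair(head: Int, tail: IntList) extends IntList

object IntListOps {
  // Multiplication: the base case returns 1, the identity of *, since x * 1 = x
  def product(list: IntList): Int =
    list match {
      case End          => 1
      case Pair(hd, tl) => hd * product(tl)
    }

  // Counting: each Pair contributes 1, and End returns 0, the identity of +
  def length(list: IntList): Int =
    list match {
      case End         => 0
      case Pair(_, tl) => 1 + length(tl)
    }
}
```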

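Each non-tail recursive call consumes a stack frame, so recursing over a very long list can overflow the stack. Scala mitigates this with tail call optimization: when a method that cannot be overridden calls itself in tail position, the compiler turns the recursion into a loop. A tail call is a call whose result the caller returns immediately, without doing any further work.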

def aMethod: Int = 1

def aTailCall: Int = aMethod

// not a tail call: the caller adds 10 to the result of aMethod before returning
def notATailCall: Int = aMethod + 10

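To turn a non-tail-recursive function into a tail-recursive one, add an accumulator parameter that carries the partial result through the recursive calls. The @tailrec annotation asks the compiler to report an error if the method is not actually tail recursive. Example: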

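import scala.annotation.tailrec

// acc carries the running total, so the recursive call is the last thing evaluated
@tailrec
def sum(list: IntList, acc: Int = 0): Int =
  list match {
    case End => acc
    case Pair(hd, tl) => sum(tl, hd + acc)
  }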
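
To check that the accumulator version is stack-safe, here is a sketch that sums a list of one million elements. The helper fill (in a hypothetical object TailRecDemo) is an assumption added for the demonstration and is itself tail recursive. A non-tail-recursive sum over a list this long would likely throw a StackOverflowError with default JVM stack settings.

```scala
import scala.annotation.tailrec

sealed trait IntList
case object End extends IntList
final case class Pair(head: Int, tail: IntList) extends IntList

object TailRecDemo {
  // Hypothetical helper: builds a list of n copies of value without growing the stack
  @tailrec
  def fill(n: Int, value: Int, acc: IntList = End): IntList =
    if (n <= 0) acc else fill(n - 1, value, Pair(value, acc))

  // Tail-recursive sum with an accumulator, as in the example above
  @tailrec
  def sum(list: IntList, acc: Int = 0): Int =
    list match {
      case End          => acc
      case Pair(hd, tl) => sum(tl, hd + acc)
    }
}
```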

References: the Scala documentation and Essential Scala.