Functional Programming: So what is for-comprehension and Monad in Scala? Part 2

Part 1: https://medium.com/@mageswaran1989/functional-programming-so-what-is-for-comprehension-and-monad-in-scala-eed22d52e3f7

In the last post we saw what a monad is, or rather what makes a class a monad in Scala.

Code @ https://github.com/dhiraa/medium/tree/master/for-comprehension

In short, a class/container that implements map() and flatMap() can informally be called a Monad, and it can then be used easily inside a for expression.

In this post we will see what really happens inside the for expression and how the function calls are chained.

When a `for expression` has multiple generators, it is desugared into a chain of flatMap calls, one for each generator except the last, followed by a map call on the last generator with the yield expression as the body of the map’s lambda function. For simplicity we use Option as our monad (though the Scala standard library does not define a Monad type as such).

def makeInt(s: String): Option[Int] = {
  try {
    Some(s.trim.toInt)
  } catch {
    case e: Exception => None
  }
}

val res = for {
  i <- makeInt("1")
  j <- makeInt("2")
  k <- makeInt("3")
} yield i + j + k

//compiler generated, cleaned for easy reference
val res1 = makeInt("1").flatMap(i =>
  makeInt("2").flatMap(j =>
    makeInt("3").map(k =>
      i.$plus(j).$plus(k)
    )
  )
)

println("res = " + res) //res = Some(6)
println("res1 = " + res1) //res1 = Some(6)
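A minimal sketch of why the flatMap chain matters, assuming one of the inputs cannot be parsed. The object name `ShortCircuitDemo` and the values `ok` and `bad` are made up for this illustration, and `makeInt` is rewritten here with `scala.util.Try` for brevity (for these inputs it behaves like the try/catch version above).

```scala
import scala.util.Try

object ShortCircuitDemo {
  // Stand-in for makeInt above, written with Try instead of try/catch.
  def makeInt(s: String): Option[Int] = Try(s.trim.toInt).toOption

  // All three generators succeed, so the yield expression runs.
  val ok: Option[Int] = for {
    i <- makeInt("1")
    j <- makeInt("2")
    k <- makeInt("3")
  } yield i + j + k

  // The second generator is None. None.flatMap never calls its lambda,
  // so makeInt("3") and the yield expression are skipped.
  val bad: Option[Int] = for {
    i <- makeInt("1")
    j <- makeInt("two")
    k <- makeInt("3")
  } yield i + j + k

  def main(args: Array[String]): Unit = {
    println("ok  = " + ok)  // ok  = Some(6)
    println("bad = " + bad) // bad = None
  }
}
```

Because every generator after the first lives inside the lambda of the previous flatMap, a None anywhere in the chain stops the rest of the chain from running.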

The following is a complete example that chains multiple functions, first with Higher Order Functions (H.O.F) and then with a for expression. The output of the latter is discussed in detail below.

package io.dhiraa

case class DebuggableWrapper[A](value: A, msg: String) {

  def map[B](f: A => B): DebuggableWrapper[B] = {
    println(">>>>>>> "+ this.toString + " map "+ value + " => ")
    val newValue = f(value)
    println("<<<<<<< " + this.toString + " map \n")
    DebuggableWrapper(newValue, msg)
  }

  def flatMap[B](f: A => DebuggableWrapper[B]): DebuggableWrapper[B] = {
    println(">>>>>>> "+ this.toString + " flatmap " + this.toString + " map "+ value + " => ")
    val newValue = f(value)
    val res = DebuggableWrapper(newValue.value, msg + " " + newValue.msg)
    println("<<<<<<< "+ this.toString + " flatmap " + res + "\n")
    res
  }
}

case class State(value: Int) {

  def flatMap(f: Int => State): State = {
    println(">>>>>>> "+ this.toString + " flatmap " + "\n")
    val newState = f(value)
    println("<<<<<<< " + this.toString + " flatmap \n")
    State(newState.value)
  }

  def map(f: Int => Int) = {
    println(">>>>>>> "+ this.toString + " map " + "\n")
    val res = State(f(value))
    println("<<<<<<< " + this.toString + " map \n")
    res
  }
}

object Debuggable {

  // Let's see how we can bind/combine pure functions in FP style using higher-order functions

  // A possible use case for the functions below: gathering debug info as we call a series of functions
  def f0(a: Int): (Int, String) = {
    val result = a + 1
    (result, " f0")
  }

  def g0(a: Int): (Int, String) = {
    val result = a + 2
    (result, " g0")
  }

  def h0(a: Int): (Int, String) = {
    val result = a + 3
    (result, " h0")
  }

  // bind, a HOF: applies fun to the Int in tup and appends fun's message to tup's message
  def bind(fun: (Int) => (Int, String),
           tup: Tuple2[Int, String]): (Int, String) = {
    val (intResult, stringResult) = fun(tup._1)
    (intResult, tup._2 + stringResult)
  }

  // Now let's see how to use a for expression to do this in FP style

  def f(v: Int): DebuggableWrapper[Int] = {
    val res = DebuggableWrapper(v + 1, "F")
    println("f(" + v +") -> " + res + "\n")
    res
  }

  def g(v: Int): DebuggableWrapper[Int] = {
    val res = DebuggableWrapper(v + 1, "G")
    println("g(" + v +") -> " + res + "\n")
    res
  }

  def h(v: Int): DebuggableWrapper[Int] = {
    val res = DebuggableWrapper(v + 1, "H")
    println("h(" + v +") -> " + res + "\n")
    res
  }

  def main(args: Array[String]): Unit = {

    val fResult = f0(0)
    val gResult = bind(g0, fResult)
    val hResult = bind(h0, gResult)

    println(s"result: ${hResult._1} debug: ${hResult._2}")

    println("-------------------------------------------------------")

    val res = for {
      i <- f(0)
      j <- g(i)
      k <- h(j)
    } yield k

    println(res)

    println("-----------------------------------------------------")

    //compiler generated : "scalac -Xprint:parse src/main/scala/io/dhiraa/Debuggable.scala "
    //val res = f(0).flatMap(((i) => g(i).flatMap(((j) => h(j).map(((k) => k))))));

    val res1 = f(0).flatMap { i =>
      println("This is called as part of lambda function with value of i as " + i)
      g(i).flatMap { j =>
        println("This is called as part of lambda function with value of j as " + j)
        h(j).map { k =>
          println("This is called as part of lambda function with value of k as " + k)
          k
        }
      }
    }

    println(res1)

    println("-----------------------------------------------------")

    val res3 = for {
      a <- State(20)
      b <- State(a + 15) //manually carry over `a`
      c <- State(b + 0) //manually carry over `b`
    } yield c
    println(s"res: $res3") //prints "State(35)"

    //val res3 = State(20).flatMap(((a) => State(a.$plus(15)).
    //  flatMap(((b) => State(b.$plus(0)).map(((c) => c))))));

  }
}
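The hand-written bind and the flatMap of a wrapper class do the same job: run the next function on the current value and append its message to the messages collected so far. Here is a rough sketch of that correspondence. The names `BindVsFlatMap`, `Logged`, `viaBind` and `viaFor` are hypothetical, and unlike DebuggableWrapper this `Logged.flatMap` adds no extra space, because the tuple messages already start with one.

```scala
object BindVsFlatMap {
  // Tuple-returning steps, same as f0/g0/h0 in the listing above.
  def f0(a: Int): (Int, String) = (a + 1, " f0")
  def g0(a: Int): (Int, String) = (a + 2, " g0")
  def h0(a: Int): (Int, String) = (a + 3, " h0")

  // bind from the listing above.
  def bind(fun: Int => (Int, String), tup: (Int, String)): (Int, String) = {
    val (intResult, stringResult) = fun(tup._1)
    (intResult, tup._2 + stringResult)
  }

  // Hypothetical wrapper whose flatMap does what bind does.
  case class Logged(value: Int, msg: String) {
    def flatMap(f: Int => Logged): Logged = {
      val next = f(value)
      Logged(next.value, msg + next.msg)
    }
    def map(f: Int => Int): Logged = Logged(f(value), msg)
  }

  // Helper (also hypothetical) that lifts a tuple into the wrapper.
  def logged(t: (Int, String)): Logged = Logged(t._1, t._2)

  // H.O.F style: nest the bind calls by hand.
  val viaBind: (Int, String) = bind(h0, bind(g0, f0(0)))

  // for-expression style: the compiler writes the flatMap/map chain.
  val viaFor: Logged = for {
    i <- logged(f0(0))
    j <- logged(g0(i))
    k <- logged(h0(j))
  } yield k

  def main(args: Array[String]): Unit = {
    println(viaBind) // (6, f0 g0 h0)
    println(viaFor)  // Logged(6, f0 g0 h0)
  }
}
```

Both styles arrive at the value 6 with the message " f0 g0 h0"; the for expression simply hides the plumbing that bind makes explicit.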

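Going from H.O.F to the `for expression`, you will see a class called DebuggableWrapper, which implements the two required functions, map() and flatMap(), along with some debug prints. In their signatures, the only difference between map and flatMap is the type of the function they take as an argument: map takes a function that returns a plain value (A => B), whereas flatMap takes a function that already returns a wrapped value (A => DebuggableWrapper[B]).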

Note: Inside map() and flatMap() we use the class field “value” as the input to f.

def map[B](f: A => B): DebuggableWrapper[B]
def flatMap[B](f: A => DebuggableWrapper[B]): DebuggableWrapper[B]
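To see why the two signatures matter, here is a small sketch using a stripped-down wrapper. `Box`, `inc`, `nested` and `flat` are names invented for this illustration and are not part of the code above.

```scala
object MapVsFlatMap {
  // Hypothetical minimal wrapper with the same two signatures.
  case class Box[A](value: A) {
    def map[B](f: A => B): Box[B] = Box(f(value))
    def flatMap[B](f: A => Box[B]): Box[B] = f(value)
  }

  // A function that already returns a wrapped value, like f, g and h.
  def inc(v: Int): Box[Int] = Box(v + 1)

  // map wraps whatever the function returns, so we get a Box inside a Box.
  val nested: Box[Box[Int]] = Box(1).map(v => inc(v))

  // flatMap expects the function to return a Box and does not wrap again.
  val flat: Box[Int] = Box(1).flatMap(v => inc(v))

  def main(args: Array[String]): Unit = {
    println(nested) // Box(Box(2))
    println(flat)   // Box(2)
  }
}
```

This is why every generator except the last becomes a flatMap: its lambda contains the next generator, which already returns a wrapper.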

When the following code is compiled

val res = for {
  i <- f(0)
  j <- g(i)
  k <- h(j)
} yield k

the compiler generates

val res = f(0).flatMap(((i) => g(i).flatMap(((j) => h(j).map(((k) => k))))));

After some clean-up, and with a print added inside each lambda, we have

val res1 = f(0).flatMap { i =>
  println("This is called as part of lambda function with value of i as " + i)
  g(i).flatMap { j =>
    println("This is called as part of lambda function with value of j as " + j)
    h(j).map { k =>
      println("This is called as part of lambda function with value of k as " + k)
      k
    }
  }
}

When we run res1 and print it, we get the following output:

-----------------------------------------------------
f(0) -> DebuggableWrapper(1,F)

>>>>>>> DebuggableWrapper(1,F) flatmap DebuggableWrapper(1,F) map 1 =>
This is called as part of lambda function with value of i as 1
g(1) -> DebuggableWrapper(2,G)

>>>>>>> DebuggableWrapper(2,G) flatmap DebuggableWrapper(2,G) map 2 =>
This is called as part of lambda function with value of j as 2
h(2) -> DebuggableWrapper(3,H)

>>>>>>> DebuggableWrapper(3,H) map 3 =>
This is called as part of lambda function with value of k as 3
<<<<<<< DebuggableWrapper(3,H) map

<<<<<<< DebuggableWrapper(2,G) flatmap DebuggableWrapper(3,G H)

<<<<<<< DebuggableWrapper(1,F) flatmap DebuggableWrapper(3,F G H)

DebuggableWrapper(3,F G H)
-----------------------------------------------------

For a better understanding, I recommend reading the chapter “The “Bind” Concept” in Alvin Alexander's book Functional Programming, Simplified (Functional-Programming-Simplified-Scala.pdf).

The steps below are my own understanding of the output, as I don’t want to repeat Alvin's work :) which is the best I have come across so far!

Step 1: f(0) is called, which returns DebuggableWrapper(1,F)

Step 2: flatMap is called on DebuggableWrapper(1,F). It invokes its lambda with i = 1, which calls g(1) and gets back DebuggableWrapper(2,G)

Step 3: Still inside the first lambda, flatMap is called on DebuggableWrapper(2,G). It invokes its lambda with j = 2, which calls h(2) and gets back DebuggableWrapper(3,H)

Step 4: h(j) is the last generator of the for expression, so map (not flatMap) is called on DebuggableWrapper(3,H), with the yield expression as the function body, i.e. k => k

Step 5: Executing the above step gives DebuggableWrapper(3,H)

Step 6: Now the call stack starts unwinding the nested flatMap calls, starting with map returning DebuggableWrapper(3,H)

Step 7: As the call stack unwinds, the second half of each flatMap executes, i.e. it creates a new instance that holds the value returned by its lambda and its own msg joined with the returned msg. The call stack first unwinds to the flatMap of Step 3:

DebuggableWrapper(2,G) flatmap DebuggableWrapper(3,G H)

Step 8: Next, the call stack unwinds to the flatMap of Step 2:

DebuggableWrapper(1,F) flatmap DebuggableWrapper(3,F G H)

Step 9: The first flatMap call then returns DebuggableWrapper(3,F G H), which holds the aggregated debug info alongside the computed value
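The steps above can be checked with a quiet copy of the wrapper. The following is a sketch, assuming the same map/flatMap logic as DebuggableWrapper with the println calls removed. The object name `UnwindDemo` and the value names `step6`, `step7`, `step9` and `viaFor` are made up for this illustration.

```scala
object UnwindDemo {
  // DebuggableWrapper from the listing above, minus the debug prints.
  case class DebuggableWrapper[A](value: A, msg: String) {
    def map[B](f: A => B): DebuggableWrapper[B] =
      DebuggableWrapper(f(value), msg)

    def flatMap[B](f: A => DebuggableWrapper[B]): DebuggableWrapper[B] = {
      val inner = f(value) // first half: run the rest of the chain
      DebuggableWrapper(inner.value, msg + " " + inner.msg) // second half: combine
    }
  }

  def f(v: Int): DebuggableWrapper[Int] = DebuggableWrapper(v + 1, "F")
  def g(v: Int): DebuggableWrapper[Int] = DebuggableWrapper(v + 1, "G")
  def h(v: Int): DebuggableWrapper[Int] = DebuggableWrapper(v + 1, "H")

  // Innermost call: what map returns in Steps 4 to 6.
  val step6 = h(2).map(k => k)

  // One level out: what the flatMap on DebuggableWrapper(2,G) returns in Step 7.
  val step7 = g(1).flatMap(j => h(j).map(k => k))

  // The whole chain: what the first flatMap returns in Step 9.
  val step9 = f(0).flatMap(i => g(i).flatMap(j => h(j).map(k => k)))

  // The for expression produces the same value as the hand-written chain.
  val viaFor = for {
    i <- f(0)
    j <- g(i)
    k <- h(j)
  } yield k

  def main(args: Array[String]): Unit = {
    println(step6)  // DebuggableWrapper(3,H)
    println(step7)  // DebuggableWrapper(3,G H)
    println(step9)  // DebuggableWrapper(3,F G H)
    println(viaFor) // DebuggableWrapper(3,F G H)
  }
}
```

Each flatMap can only build its result after its lambda has returned, which is why the messages are joined from the inside out: H first, then G H, then F G H.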

I hope this gives you a different perspective on flatMap calls than the usual List of List examples :)

In the next post we will see how state can be handled internally, without manually passing the previous state along as we do in this example.

And yeah, the State class is left for you to explore. Happy learning!

Mageswaran, Principal Engineer @