Functional Programming: So what is for-comprehension and Monad in Scala? Part-3

Mageswaran D
Aug 30, 2018

Part 2 @

Git @ https://github.com/dhiraa/medium/tree/master/for-comprehension

File @

In the last post we ended with a wrapper for cases where a state needs to be maintained and used in a for expression.

case class State(value: Int) {

  def flatMap(f: Int => State): State = {
    println(">>>>>>> " + this.toString + " flatmap " + "\n")
    val newState = f(value)
    println("<<<<<<< " + this.toString + " flatmap \n")
    State(newState.value)
  }

  def map(f: Int => Int): State = {
    println(">>>>>>> " + this.toString + " map " + "\n")
    val res = State(f(value))
    println("<<<<<<< " + this.toString + " map \n")
    res
  }
}

Which can be used as

val res3 = for {
  a <- State(20)
  b <- State(a + 15) //manually carry over `a`
  c <- State(b + 0)  //manually carry over `b`
} yield c
println(s"res: $res3")
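To make the rewrite visible, here is a print-free copy of the wrapper above, renamed IntBox (a hypothetical name, used only to avoid clashing with the State Monad later on), next to the nested flatMap/map calls that the compiler turns this for expression into. It is a sketch for comparison, not the author's code.

```scala
// Print-free copy of the Part-2 State(value: Int) wrapper, renamed IntBox (hypothetical name)
final case class IntBox(value: Int) {
  def flatMap(f: Int => IntBox): IntBox = f(value)
  def map(f: Int => Int): IntBox = IntBox(f(value))
}

// The for expression, written as before
val viaFor = for {
  a <- IntBox(20)
  b <- IntBox(a + 15) // manually carry over `a`
  c <- IntBox(b + 0)  // manually carry over `b`
} yield c

// What the compiler rewrites it into: nested flatMaps, with a final map for `yield`
val viaCalls =
  IntBox(20).flatMap(a => IntBox(a + 15).flatMap(b => IntBox(b + 0).map(c => c)))
```

Both values are IntBox(35); the for expression is only syntax over these calls.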

This works, but FP programmers would rather be lazy here and not carry the previous state over to the next step by hand. Don’t forget that FP programmers don’t like mutating values either. What they want to write is something like the following…

val stateWithNewDistance: StateM[GolfState, Int] = for {
  _ <- swing(distance=10)
  _ <- swing(distance=20)
  totalDistance <- swing(distance=30)
} yield totalDistance

Honestly, I found the code above much easier to read and manage once I had worked through the State Monad :)
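For contrast, here is a rough sketch of the same three swings with the state threaded by hand and no monad involved. swingManual is a hypothetical helper written for this comparison; it is not part of the original repository.

```scala
final case class GolfState(distance: Int)

// Hypothetical state-passing swing: the caller passes the previous state in
// and has to pick the new state out of the returned tuple.
def swingManual(distance: Int, s: GolfState): (GolfState, Int) = {
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

val start = GolfState(0)
val (s1, _)     = swingManual(10, start) // carry s1 forward by hand...
val (s2, _)     = swingManual(20, s1)    // ...then s2...
val (s3, total) = swingManual(30, s2)    // ...and it is easy to pass the wrong state
```

Every step has to receive the previous state explicitly; this is the bookkeeping the State Monad hides.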

Original code can be found @ https://github.com/alvinj/StateMonadExample/tree/master/src/main/scala/state_monad

I made some changes to the code and added a lot of debug statements to make the code flow easier to follow; they are mostly self-explanatory.

Original Code

case class State[S, A](run: S => (S, A)) {

  def flatMap[B](g: A => State[S, B]): State[S, B] = State { (s0: S) =>
    val (s1, a) = run(s0)
    g(a).run(s1)
  }

  def map[B](f: A => B): State[S, B] = flatMap(a => State.point(f(a)))
}

object State {
  def point[S, A](v: A): State[S, A] = State(run = s => (s, v))
}
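To see flatMap and point work on something other than golf, here is the original State class applied to a plain Int counter. The tick action is a hypothetical example, not part of the original code: it returns the current count as its value and increments the state.

```scala
final case class State[S, A](run: S => (S, A)) {
  def flatMap[B](g: A => State[S, B]): State[S, B] = State { (s0: S) =>
    val (s1, a) = run(s0) // run this step on the incoming state
    g(a).run(s1)          // give the value to g, and the new state to g's step
  }
  def map[B](f: A => B): State[S, B] = flatMap(a => State.point(f(a)))
}

object State {
  def point[S, A](v: A): State[S, A] = State((s: S) => (s, v))
}

// point leaves the state alone and only carries the value
val pointed = State.point[Int, String]("x").run(7)

// Hypothetical counter action: value = current count, state = count + 1
val tick: State[Int, Int] = State((s: Int) => (s + 1, s))

// Two ticks in a row; flatMap threads the counter from the first tick into the second
val twoTicks: State[Int, (Int, Int)] = tick.flatMap(a => tick.map(b => (a, b)))
val result = twoTicks.run(0)
```

Running twoTicks from 0 ends with counter 2 and values (0, 1), without either tick ever seeing the other's state directly.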

My version of the code:

/**
 * A class that threads a state of type S through a computation and, at each step,
 * produces a value of type A alongside the new state.
 *
 * @param run function that takes the current state and returns (next state, value)
 * @tparam S State type
 * @tparam A Value type (here, the aggregated distance)
 */
case class StateM[S, A](run: S => (S, A)) {

  //As a naive way of remembering, flatMap always returns a new value type inside the same wrapper.
  //Here the state type S stays the same, while the value type goes A -> B
  def flatMap[B](g: A => StateM[S, B]): StateM[S, B] = StateM { (currentState: S) =>
    println("\n>>> flatmap: currentState " + currentState + " calling run/swing on current state...")
    val (nextState, a) = run(currentState)
    // val res = g(a).run(nextState)
    // println("<<< aggregated state info " + res)
    // res
    println("--- flatmap: nextState " + nextState + " new distance " + a + ". calling g...")
    val stateChangeToB = g(a)
    println("--- flatmap: actually triggering run/swing on next state after initializing g()...\n")
    val res = stateChangeToB.run(nextState)
    println("<<< flatmap: aggregated state info " + res)
    res
  }

  /**
   * Map function that applies some logic to convert the type A -> B. Eg: 10 -> 10+5 or 10 -> 10,
   * while keeping the state the same. I think this must have come after try
   *
   * @param f Function that applies the logic on flatmapped value/distance here
   * @tparam B
   * @return
   */
  def map[B](f: A => B): StateM[S, B] = flatMap(a => StateM.point(f(a))) //Creates a new StateM with run s => (s, f(a))
}

Say we want to keep track of the GolfState distance after each swing.

First, we need to define a state class to hold the state:

case class GolfState(distance: Int)

Second, we need to write the swing() function. Compared to the previous examples, this swing() function has to do a little extra work.

In short, swing() takes the new swing distance and returns a State Monad wrapping an anonymous function. That anonymous function takes the previous state as input, adds the current swing distance to it, and returns the updated state together with the new total distance. The function is handed to the State Monad through constructor initialization. Don’t be afraid of the syntax: we are simply assigning the lambda function to “run” in the State Monad.

def swing(distance: Int): State[GolfState, Int] = State { (s: GolfState) =>
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}
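Because swing() only builds a wrapped function, one swing value can be run against different starting states. A minimal sketch, assuming a State class that keeps only the run field (flatMap and map are left out because they are not needed here):

```scala
final case class GolfState(distance: Int)

// Only `run` is kept in this sketch
final case class State[S, A](run: S => (S, A))

def swing(distance: Int): State[GolfState, Int] = State { (s: GolfState) =>
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

// swing(10) describes "add 10 to whatever the distance is";
// the same description can be run from different starting states
val tenMeters = swing(10)
val fromTee   = tenMeters.run(GolfState(0))
val fromFifty = tenMeters.run(GolfState(50))
```

fromTee is (GolfState(10), 10) and fromFifty is (GolfState(60), 60): the update logic is fixed, and only the input state differs.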

My version:

/**
 * swing() is a function that takes the swing distance as an Int and creates a State Monad,
 * which can then be run to compute the new state
 *
 * @param distance distance covered by this swing
 * @return a StateM that adds `distance` to the state when run
 */
def swing(distance: Int): StateM[GolfState, Int] = StateM { (s: GolfState) => //run: S => (S, A)
  println(">>> run/swing triggered: distance: " + distance + " state: " + s)
  val newDistance = s.distance + distance
  println("<<< run/swing: " + ((GolfState(newDistance), newDistance)))
  (GolfState(newDistance), newDistance)
}

Let's see it in action…

Our first swing moves the ball by 10 meters. swing() takes the distance and returns a StateM Monad, which is basically just a class instance. The catch is that swing() has wrapped the update logic inside an anonymous function, which is passed to the StateM/State class constructor parameter ‘run’.

Wrapping the logic this way makes the state-based operation lazy: nothing is computed until run() is called.

val beginningStateM = GolfState(0)
val firstHit = swing(distance=10) //Let's say we moved the ball by 10 meters
// println("firstHit: " + firstHit)
// State doesn't change unless we ask it to do so
val afterFirstHit = firstHit.run(beginningStateM)
println("afterFirstHit " + afterFirstHit)

println("\n-------------------------------------------------\n")

Output:

>>> run/swing triggered: distance: 10 state: GolfState(0)
<<< run/swing: (GolfState(10),10)
afterFirstHit (GolfState(10),10)

-------------------------------------------------
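One way to confirm that nothing happens before run() is to count how often the wrapped function executes. The runCount variable below is a hypothetical probe added only for this sketch, and StateM is reduced to its run field:

```scala
final case class GolfState(distance: Int)
final case class StateM[S, A](run: S => (S, A))

// Hypothetical probe: incremented every time the wrapped function actually runs
var runCount = 0

def swing(distance: Int): StateM[GolfState, Int] = StateM { (s: GolfState) =>
  runCount += 1
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

val firstHit        = swing(distance = 10)
val countAfterBuild = runCount                   // building the StateM ran nothing
val afterFirstHit   = firstHit.run(GolfState(0))
val countAfterRun   = runCount                   // run() executed the lambda once
```

countAfterBuild is 0 and countAfterRun is 1, which matches the output above: the swing debug lines only appear once run() is called.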

Next, let's see what happens in the flatMap() call…

println("about to make second hit... ")
//Let's move the ball by another 20 meters
val secondHit = firstHit.flatMap { _ =>
  println("secondhit by calling swing inside the flatMap")
  swing(20) //flatMap[B](g: A => StateM[S, B]): StateM[S, B]
}
println("Lets move the ball by another 20 meters " + secondHit.run(beginningStateM))

println("\n-------------------------------------------------\n")

A recap of flatMap():

//As a naive way of remembering, flatMap always returns a new value type inside the same wrapper.
//Here the state type S stays the same, while the value type goes A -> B
def flatMap[B](g: A => StateM[S, B]): StateM[S, B] = StateM { (currentState: S) =>
  println("\n>>> flatmap: currentState " + currentState + " calling run/swing on current state...")
  val (nextState, a) = run(currentState)
  // val res = g(a).run(nextState)
  // println("<<< aggregated state info " + res)
  // res
  println("--- flatmap: nextState " + nextState + " new distance " + a + ". calling g...")
  val stateChangeToB = g(a)
  println("--- flatmap: actually triggering run/swing on next state after initializing g()...\n")
  val res = stateChangeToB.run(nextState)
  println("<<< flatmap: aggregated state info " + res)
  res
}

flatMap expects an anonymous function that takes a value of type A and returns a StateM/State instance.

To keep our golf game simple, we ignore the previous hit distance and make a new swing inside the flatMap. That anonymous function is the g() passed to flatMap, with the type g: A => StateM[S, B].

The flatMap call hands us back a new instance of the State/StateM class.
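In this example g ignores its argument. As a sketch of what that argument is for, assuming print-free copies of StateM and swing, g can use the distance so far to decide the next swing; the "half as far" rule is invented purely for illustration.

```scala
final case class GolfState(distance: Int)

// Print-free copy of StateM
final case class StateM[S, A](run: S => (S, A)) {
  def flatMap[B](g: A => StateM[S, B]): StateM[S, B] = StateM { (s0: S) =>
    val (s1, a) = run(s0)
    g(a).run(s1)
  }
  def map[B](f: A => B): StateM[S, B] = flatMap(a => StateM.point(f(a)))
}
object StateM {
  def point[S, B](v: B): StateM[S, B] = StateM((s: S) => (s, v))
}

// Print-free copy of swing
def swing(distance: Int): StateM[GolfState, Int] = StateM { (s: GolfState) =>
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

// g receives the value produced by the first swing (the distance so far)
// and uses it to pick the next swing: half of that distance (invented rule)
val firstHit  = swing(10)
val halfAsFar = firstHit.flatMap(distanceSoFar => swing(distanceSoFar / 2))
val result    = halfAsFar.run(GolfState(0))
```

The first swing yields 10, so g builds swing(5), and the result is (GolfState(15), 15).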

firstHit → flatMap → secondHit → run() → new state

about to make second hit...

>>> flatmap: currentState GolfState(0) calling run/swing on current state...
>>> run/swing triggered: distance: 10 state: GolfState(0)
<<< run/swing: (GolfState(10),10)
--- flatmap: nextState GolfState(10) new distance 10. calling g...
secondhit by calling swing inside the flatMap
--- flatmap: actually triggering run/swing on next state after initializing g()...

>>> run/swing triggered: distance: 20 state: GolfState(10)
<<< run/swing: (GolfState(30),30)
<<< flatmap: aggregated state info (GolfState(30),30)
Lets move the ball by another 20 meters (GolfState(30),30)

-------------------------------------------------

flatMap() is triggered on the firstHit instance, which is why we first see GolfState(0), the state that was passed to run().

GolfState(0) -> run() -> swing anonymous function (distance=10) -> returns the new state and the total distance as a tuple -> now call g(), which is our second swing(distance=20) operation -> call run() explicitly on the new state -> return the new state, i.e. the second swing's return value
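The steps above amount to "run the first action, then run the next one on the state it produced". A sketch under that reading, using print-free copies of StateM and swing, checks that secondHit gives the same result as doing the two runs by hand:

```scala
final case class GolfState(distance: Int)

// Print-free copy of StateM
final case class StateM[S, A](run: S => (S, A)) {
  def flatMap[B](g: A => StateM[S, B]): StateM[S, B] = StateM { (s0: S) =>
    val (s1, a) = run(s0)
    g(a).run(s1)
  }
  def map[B](f: A => B): StateM[S, B] = flatMap(a => StateM.point(f(a)))
}
object StateM {
  def point[S, B](v: B): StateM[S, B] = StateM((s: S) => (s, v))
}

// Print-free copy of swing
def swing(distance: Int): StateM[GolfState, Int] = StateM { (s: GolfState) =>
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

val firstHit  = swing(10)
val secondHit = firstHit.flatMap(_ => swing(20))
val viaFlatMap = secondHit.run(GolfState(0))

// The same two steps threaded by hand, the way flatMap does it internally
val (stateAfterFirst, _) = firstHit.run(GolfState(0))
val manual = swing(20).run(stateAfterFirst)
```

Both give (GolfState(30), 30); flatMap just packages the hand-threading into a new StateM.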

Next, on to map()…

object StateM {
  //type is named B to be in line with the map() API
  def point[S, B](v: B): StateM[S, B] = StateM(run = s => (s, v))
}

The State Monad has a point/lift function that takes a plain value and wraps it in a State/StateM, passing the current state through unchanged.
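A small sketch of point and map, using print-free copies of StateM and swing. mapDirect is a hypothetical alternative written for comparison (it is not in the original code): it implements map without going through flatMap and point, and it should agree with map.

```scala
final case class GolfState(distance: Int)

// Print-free copy of StateM
final case class StateM[S, A](run: S => (S, A)) {
  def flatMap[B](g: A => StateM[S, B]): StateM[S, B] = StateM { (s0: S) =>
    val (s1, a) = run(s0)
    g(a).run(s1)
  }
  def map[B](f: A => B): StateM[S, B] = flatMap(a => StateM.point(f(a)))
}
object StateM {
  def point[S, B](v: B): StateM[S, B] = StateM((s: S) => (s, v))
}

// Print-free copy of swing
def swing(distance: Int): StateM[GolfState, Int] = StateM { (s: GolfState) =>
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

// point hands the incoming state back untouched and just carries the value
val pointed = StateM.point[GolfState, Int](42).run(GolfState(7))

// Hypothetical direct map: run the step, keep its state, transform only its value
def mapDirect[S, A, B](m: StateM[S, A])(f: A => B): StateM[S, B] =
  StateM { (s: S) =>
    val (s1, a) = m.run(s)
    (s1, f(a))
  }

val viaMap    = swing(10).map(_ + 5).run(GolfState(0))
val viaDirect = mapDirect(swing(10))(_ + 5).run(GolfState(0))
```

pointed is (GolfState(7), 42), and both map versions give (GolfState(10), 15): the value changes, the state does not.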

/**
 * Map function that applies some logic to convert the type A -> B. Eg: 10 -> 10+5 or 10 -> 10,
 * while keeping the state the same. I think this must have come after try
 *
 * @param f Function that applies the logic on flatmapped value/distance here
 * @tparam B
 * @return
 */
def map[B](f: A => B): StateM[S, B] = {
  println(">>> map")
  flatMap(a => StateM.point(f(a))) //Creates a new StateM with run s => (s, f(a))
}
println("\n-------------------------------------------------\n")

//Let's add another 30 meters to the distance value (map does not touch the GolfState itself)
val thirdHit = secondHit.map(distance => distance + 30)
println("Lets move the ball by another 30 meters " + thirdHit.run(beginningStateM))

println("\n-------------------------------------------------\n")

Output:

-------------------------------------------------

>>> map

>>> flatmap: currentState GolfState(0) calling run/swing on current state...

>>> flatmap: currentState GolfState(0) calling run/swing on current state...
>>> run/swing triggered: distance: 10 state: GolfState(0)
<<< run/swing: (GolfState(10),10)
--- flatmap: nextState GolfState(10) new distance 10. calling g...
secondhit by calling swing inside the flatMap
--- flatmap: actually triggering run/swing on next state after initializing g()...

>>> run/swing triggered: distance: 20 state: GolfState(10)
<<< run/swing: (GolfState(30),30)
<<< flatmap: aggregated state info (GolfState(30),30)
--- flatmap: nextState GolfState(30) new distance 30. calling g...
--- flatmap: actually triggering run/swing on next state after initializing g()...

<<< flatmap: aggregated state info (GolfState(30),60)
Lets move the ball by another 30 meters (GolfState(30),60)

-------------------------------------------------

run → thirdHit → map → wraps in flatMap with [g = a => point(f(a)), i.e. s => (s, v) where v is f(distance) = distance + 30] → secondHit flatMap → firstHit swing anonymous function (distance=10) → returns the new state and the total distance as a tuple (GolfState(10),10) → now call g(), which is our second swing(distance=20) operation → returns the new state and the total distance as a tuple (GolfState(30),30) → the secondHit flatMap call returns → now map (i.e. flatMap) executes g(), which adds 30 to the total distance and returns the tuple (GolfState(30),60).
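Note that the last tuple in the output is (GolfState(30),60): map changed the value, but the state, i.e. where the ball actually is, stayed at 30. As a sketch, assuming print-free copies of StateM and swing, compare map with a third swing chained through flatMap:

```scala
final case class GolfState(distance: Int)

// Print-free copy of StateM
final case class StateM[S, A](run: S => (S, A)) {
  def flatMap[B](g: A => StateM[S, B]): StateM[S, B] = StateM { (s0: S) =>
    val (s1, a) = run(s0)
    g(a).run(s1)
  }
  def map[B](f: A => B): StateM[S, B] = flatMap(a => StateM.point(f(a)))
}
object StateM {
  def point[S, B](v: B): StateM[S, B] = StateM((s: S) => (s, v))
}

// Print-free copy of swing
def swing(distance: Int): StateM[GolfState, Int] = StateM { (s: GolfState) =>
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

val secondHit = swing(10).flatMap(_ => swing(20))

// map only transforms the value: the ball stays at 30
val thirdViaMap = secondHit.map(distance => distance + 30).run(GolfState(0))

// flatMap with another swing actually updates the state
val thirdViaSwing = secondHit.flatMap(_ => swing(30)).run(GolfState(0))
```

thirdViaMap is (GolfState(30), 60) while thirdViaSwing is (GolfState(60), 60); only the second version really moves the ball.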

Note: You may need to sit with the IDE and step through the code to understand it better.

With a little understanding of the State Monad's internals, let's move on to the for-expression:

val stateWithNewDistance: StateM[GolfState, Int] = for {
  _ <- swing(distance=10)
  _ <- swing(distance=20)
  totalDistance <- swing(distance=30)
} yield totalDistance
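Because the for expression binds each intermediate value to a name, we can also keep every step's distance without carrying any state by hand. A sketch, assuming print-free copies of StateM and swing; allDistances is a hypothetical name:

```scala
final case class GolfState(distance: Int)

// Print-free copy of StateM
final case class StateM[S, A](run: S => (S, A)) {
  def flatMap[B](g: A => StateM[S, B]): StateM[S, B] = StateM { (s0: S) =>
    val (s1, a) = run(s0)
    g(a).run(s1)
  }
  def map[B](f: A => B): StateM[S, B] = flatMap(a => StateM.point(f(a)))
}
object StateM {
  def point[S, B](v: B): StateM[S, B] = StateM((s: S) => (s, v))
}

// Print-free copy of swing
def swing(distance: Int): StateM[GolfState, Int] = StateM { (s: GolfState) =>
  val newDistance = s.distance + distance
  (GolfState(newDistance), newDistance)
}

// Each name holds the total distance after that swing; the GolfState is threaded for us
val allDistances: StateM[GolfState, (Int, Int, Int)] = for {
  first  <- swing(distance = 10)
  second <- swing(distance = 20)
  total  <- swing(distance = 30)
} yield (first, second, total)

val result = allDistances.run(GolfState(0))
```

result is (GolfState(60), (10, 30, 60)): the state ends at 60, and every intermediate total is still available.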

I took the compiler-generated version of the above code and added a few debug prints, as follows:

//compiler generated : scalac -Xprint:parse src/main/scala/io/dhiraa/StateMonad.scala
//flatMap[B](g: A => StateM[S, B]): StateM[S, B]
val stateWithNewDistance1: StateM[GolfState, Int] = swing(distance = 10).flatMap { (distance: Int) =>
  println("g on flatmap annonymous function with distance " + distance)
  swing(distance = 20).flatMap { distance =>
    println("g on flatmap with distance " + distance)
    swing(distance = 30).map { totalDistance =>
      println("g on flatmap with totalDistance " + totalDistance)
      totalDistance
    }
  }
}

val result1 = stateWithNewDistance1.run(beginningStateM)

println(s"GolfState: ${result1._1}") //GolfState(60)
println(s"Total Distance: ${result1._2}") //60

Output:

>>> flatmap: currentState GolfState(0) calling run/swing on current state...
>>> run/swing triggered: distance: 10 state: GolfState(0)
<<< run/swing: (GolfState(10),10)
--- flatmap: nextState GolfState(10) new distance 10. calling g...
g on flatmap annonymous function with distance 10
--- flatmap: actually triggering run/swing on next state after initializing g()...


>>> flatmap: currentState GolfState(10) calling run/swing on current state...
>>> run/swing triggered: distance: 20 state: GolfState(10)
<<< run/swing: (GolfState(30),30)
--- flatmap: nextState GolfState(30) new distance 30. calling g...
g on flatmap with distance 30
>>> map
--- flatmap: actually triggering run/swing on next state after initializing g()...


>>> flatmap: currentState GolfState(30) calling run/swing on current state...
>>> run/swing triggered: distance: 30 state: GolfState(30)
<<< run/swing: (GolfState(60),60)
--- flatmap: nextState GolfState(60) new distance 60. calling g...
g on flatmap with totalDistance 60
--- flatmap: actually triggering run/swing on next state after initializing g()...

<<< flatmap: aggregated state info (GolfState(60),60)
<<< flatmap: aggregated state info (GolfState(60),60)
<<< flatmap: aggregated state info (GolfState(60),60)
GolfState: GolfState(60)
Total Distance: 60

As in the other examples, the for expression is rewritten into nested flatMap calls, with a final map for the yield.

I'll leave the last output for you to analyse :) Happy debugging!

You may also want to check out this game built with Scala FP @ https://github.com/jdegoes/lambdaconf-2014-introgame

Mageswaran, Principal Engineer @
