Back to Fpinscala

02.Answer

answerkey/localeffects/02.answer.md

latest 733 B
Original Source
scala
/** Partitions the slice `a(l)`..`a(r)` around the value stored at index `pivot`.
  *
  * The pivot value is first moved to index `r`. Indices `l` until `r` are then
  * visited in order, and every element strictly smaller than the pivot value is
  * swapped into the next free slot of a growing prefix that starts at `l`.
  * Finally the pivot is swapped from `r` into the slot just after that prefix.
  *
  * @param a     array being rearranged
  * @param l     first index of the slice
  * @param r     last index of the slice (inclusive); temporarily holds the pivot
  * @param pivot index of the element to partition around
  * @return an `ST` action yielding the index at which the pivot ends up
  */
def partition[S](
  a: STArray[S, Int], l: Int, r: Int, pivot: Int
): ST[S, Int] =
  // Action with no effect: seed of the fold and the "element stays put" branch.
  def skip: ST[S, Unit] = ST[S, Unit](())

  // Handles one index: if a(i) is smaller than the pivot value, swap it into
  // the slot named by `boundary` and advance `boundary` by one.
  def place(boundary: STRef[S, Int], pivotValue: Int, i: Int): ST[S, Unit] =
    for
      current <- a.read(i)
      _ <-
        if current < pivotValue then
          for
            slot <- boundary.read
            _ <- a.swap(i, slot)
            _ <- boundary.write(slot + 1)
          yield ()
        else skip
    yield ()

  for
    pivotValue <- a.read(pivot)
    // Park the pivot at the right edge so the sweep below never moves it.
    _ <- a.swap(pivot, r)
    boundary <- STRef(l)
    // Chain one `place` step per index, in increasing index order.
    _ <- (l until r).foldLeft(skip)((soFar, i) =>
      soFar.flatMap(_ => place(boundary, pivotValue, i)))
    split <- boundary.read
    // Drop the pivot between the "smaller" prefix and the rest.
    _ <- a.swap(split, r)
  yield split

/** Quicksort over the slice `a(l)`..`a(r)` (both bounds inclusive).
  *
  * When the slice holds fewer than two elements (`l >= r`) the result is an
  * action with no effect. Otherwise the slice is partitioned around its middle
  * index and the two sides of the returned pivot position are sorted in turn.
  *
  * @param a array being sorted
  * @param l first index of the slice
  * @param r last index of the slice (inclusive)
  * @return an `ST` action that performs the sort and yields `()`
  */
def qs[S](a: STArray[S,Int], l: Int, r: Int): ST[S, Unit] =
  if l >= r then ST[S, Unit](())
  else
    // Midpoint written as l + (r - l) / 2 rather than (l + r) / 2.
    val middle = l + (r - l) / 2
    for
      split <- partition(a, l, r, middle)
      _ <- qs(a, l, split - 1)
      _ <- qs(a, split + 1, r)
    yield ()