Scala Stream(Scala関数型デザインより)

Scala Lazy!(Scala関数型デザインより) - shutdown -r nowからの続きで、Streamを扱う。

package example.laziness.stream

import Stream._

trait Stream[+A] {
  def toList: List[A] = {
    @annotation.tailrec
    def go(s: Stream[A], acc: List[A]): List[A] = s match {
      case Cons(h,t) => go(t(), h() :: acc)
      case _ => acc
    }
    go(this, List()).reverse
  }

  def take(n: Int): Stream[A] = this match {
    case Cons(h, t) if n > 1 => cons(h(), t().take(n - 1))
    case Cons(h, _) if n == 1 => cons(h(), empty)
    case _ => empty
  }

  def takeWhile(f: A => Boolean): Stream[A] =
    foldRight(empty[A])((h, t) =>
      if (f(h)) cons(h,t)
      else empty)

  @annotation.tailrec
  final def drop(n: Int): Stream[A] = this match {
    case Cons(_, t) if n > 0 => t().drop(n - 1)
    case _ => this
  }

  // (A, => B) は、(A,B)と同じようではあるが、遅延実行させるため
  // 第2引数を名前渡しで受け取る
  // def foldRight[B](z: => B)(f: (A, B) => B): B = this match {
  def foldRight[B](z: => B)(f: (A, => B) => B): B = this match {
    case Cons(h,t) => f(h(), t().foldRight(z)(f))
    case _ => z
  }

  def forAll(f: A => Boolean): Boolean =
    foldRight(true)((a, b) => f(a) && b)

  def map[B](f: A => B): Stream[B] =
    foldRight(empty[B])((h,t) => cons(f(h), t))

  def filter(f: A => Boolean): Stream[A] =
    foldRight(empty[A])((h,t) =>
      if (f(h)) cons(h, t)
      else t)

  def append[B>:A](s: => Stream[B]): Stream[B] =
    foldRight(s)((h,t) => cons(h,t))

  def flatMap[B](f: A => Stream[B]): Stream[B] =
    foldRight(empty[B])((h,t) => f(h) append t)

  def mapViaUnfold[B](f: A => B): Stream[B] =
    unfold(this) {
      case Cons(h,t) => Some((f(h()), t()))
      case _ => None
    }

  def takeViaUnfold(n: Int): Stream[A] =
    unfold((this,n)) {
      case (Cons(h,_), 1) => Some((h(), (empty, 0)))
      case (Cons(h,t), n) if n > 1 => Some((h(), (t(), n-1)))
      case _ => None
    }

  def takeWhileViaUnfold(f: A => Boolean): Stream[A] =
    unfold(this) {
      case Cons(h,t) if f(h()) => Some((h(), t()))
      case _ => None
    }

  def zipWith[B,C](s2: Stream[B])(f: (A,B) => C): Stream[C] =
    unfold((this, s2)) {
      case (Cons(h1,t1), Cons(h2,t2)) =>
        Some((f(h1(), h2()), (t1(), t2())))
      case _ => None
    }

  def zip[B](s2: Stream[B]): Stream[(A,B)] =
    zipWith(s2)((_,_))

  def zipAll[B](s2: Stream[B]): Stream[(Option[A],Option[B])] =
    zipWithAll(s2)((_,_))

  def zipWithAll[B, C](s2: Stream[B])(f: (Option[A], Option[B]) => C): Stream[C] =
    unfold((this, s2)) {
      case (Empty, Empty) => None
      case (Cons(h, t), Empty) => Some(f(Some(h()), Option.empty[B]) -> (t(), empty[B]))
      case (Empty, Cons(h, t)) => Some(f(Option.empty[A], Some(h())) -> (empty[A] -> t()))
      case (Cons(h1, t1), Cons(h2, t2)) => Some(f(Some(h1()), Some(h2())) -> (t1() -> t2()))
    }

  def startsWith[A](s: Stream[A]): Boolean =
    zipAll(s).takeWhile(!_._2.isEmpty) forAll {
      case (h,h2) => h == h2
    }

  def tails: Stream[Stream[A]] =
    unfold(this) {
      case Empty => None
      case s => Some((s, s drop 1))
    } append Stream(empty)

  def scanRight[B](z: B)(f: (A, => B) => B): Stream[B] =
    foldRight((z, Stream(z)))((a, p0) => {
      lazy val p1 = p0
      val b2 = f(a, p1._1)
      (b2, cons(b2, p1._2))
    })._2
}
case object Empty extends Stream[Nothing]
case class Cons[+A](h: () => A, t: () => Stream[A]) extends Stream[A]

object Stream {

  def cons[A](hd: => A, tl: => Stream[A]): Stream[A] = {
    lazy val head = hd
    lazy val tail = tl
    Cons(() => head, () => tail)
  }

  def empty[A]: Stream[A] = Empty

  def apply[A](as: A*): Stream[A] =
    if (as.isEmpty) empty
    else cons(as.head, apply(as.tail: _*))

  val ones: Stream[Int] = cons(1, ones)

  def constant[A](a: A): Stream[A] = {
    lazy val tail: Stream[A] = Cons(() => a, () => tail)
    tail
  }

  def unfold[A, S](z: S)(f: S => Option[(A, S)]): Stream[A] =
    f(z) match {
      case Some((h,s)) => cons(h, unfold(s)(f))
      case None => empty
    }

  def map2[A,B,C](a: Option[A], b: Option[B])(f: (A, B) => C): Option[C] =
    a flatMap(aa => b map (bb => f(aa, bb)))

  def main(args: Array[String]): Unit = {
    val sl = Stream(1,2,3,4)
    val sl2 = Stream("a","b","c","d")

    val ass2 = constant("a").takeViaUnfold(5)
    println(s"ass2:${ass2}")
    println(s"ass2.toList:${ass2.toList}")

    val slstr = sl.mapViaUnfold((a) => a.toString() + "!")
    println(s"slstr:${slstr.toList}")

    val zips = sl.zip(sl2)
    println(s"zips:${zips.toList}")

    val zipws = sl.zipWith(sl2)((a,b) => a + "!!" + b)
    println(s"zipws:${zipws.toList}")

    val zipalls = sl.zipAll(sl2)
    println(s"zipalls:${zipalls.toList}")

    val zipwalls = sl.zipWithAll(sl2)((a,b) => map2(a,b)((aa,bb) => aa + "!!" + bb))
    println(s"zipwalls:${zipwalls.toList}")

    val b = sl.startsWith(sl2)
    println(s"b:${b}")

    println("tails")
    sl.tails.toList.foreach(x => println(x.toList))

    println("scanRight")
    println(sl.scanRight(0)(_ + _).toList)
  }

}

constant

  def constant[A](a: A): Stream[A] = {
    lazy val tail: Stream[A] = Cons(() => a, () => tail)
    tail
  }

無限ストリーム。Consで、tail側を余再帰呼び出しする。

    val ass2 = constant("a").takeViaUnfold(5)
    println(s"ass2:${ass2}")
    println(s"ass2.toList:${ass2.toList}")

> ass2.toList:List(a, a, a, a, a)

constantで、aの無限ストリームを作り、そこから、takeViaUnfold(5)で、5個取り出す。

map

trait Stream[+A] {
  def map[B](f: A => B): Stream[B] =
    foldRight(empty[B])((h,t) => cons(f(h), t))

  def mapViaUnfold[B](f: A => B): Stream[B] =
    unfold(this) {
      case Cons(h,t) => Some((f(h()), t()))
      case _ => None
    }
}

object Stream {
  def unfold[A, S](z: S)(f: S => Option[(A, S)]): Stream[A] =
    f(z) match {
      case Some((h,s)) => cons(h, unfold(s)(f))
      case None => empty
    }
}

ストリームに対するmap。mapは、トレイトに実装し、その中で使われるunfoldは、Streamオブジェクトに実装。
unfoldは、引数で渡した関数がSomeを返す限り、無限ストリームになる。
mapは、自分のストリームに対して、関数fを適用するための関数。
自分をもと(this)に値があれば、hに関数を適用して、unfoldに渡すために、Someでつつむ。

    val slstr = sl.mapViaUnfold((a) => a.toString() + "!")
    println(s"slstr:${slstr.toList}")

ストリームの各値のおしりに、"!"をつける。

zip

  def zipWith[B,C](s2: Stream[B])(f: (A,B) => C): Stream[C] =
    unfold((this, s2)) {
      case (Cons(h1,t1), Cons(h2,t2)) =>
        Some((f(h1(), h2()), (t1(), t2())))
      case _ => None
    }

  def zip[B](s2: Stream[B]): Stream[(A,B)] =
    zipWith(s2)((_,_))

実行結果

> zips:List((1,a), (2,b), (3,c), (4,d))
> zipws:List(1!!a, 2!!b, 3!!c, 4!!d)

以下は、リストの場合のzipWith

  def zipWith[A,B,C](a: List[A], b: List[B])(f: (A,B) => C): List[C] = (a,b) match {
    case (Nil, _) => Nil
    case (_, Nil) => Nil
    case (Cons(h1,t1), Cons(h2,t2)) => Cons(f(h1,h2), zipWith(t1,t2)(f))
  }

ストリームの場合は、タプルで、(自分と結合するストリーム)を見て、
Consでつなげ、unfoldにわたすため、Someでつつむ

  def zipWith[B,C](s2: Stream[B])(f: (A,B) => C): Stream[C] =
    unfold((this, s2)) {
      case (Cons(h1,t1), Cons(h2,t2)) =>
        Some((f(h1(), h2()), (t1(), t2())))
      case _ => None
    }

zipAll

  def zipAll[B](s2: Stream[B]): Stream[(Option[A],Option[B])] =
    zipWithAll(s2)((_,_))

  def zipWithAll[B, C](s2: Stream[B])(f: (Option[A], Option[B]) => C): Stream[C] =
    unfold((this, s2)) {
      case (Empty, Empty) => None
      case (Cons(h, t), Empty) => Some(f(Some(h()), Option.empty[B]) -> (t(), empty[B]))
      case (Empty, Cons(h, t)) => Some(f(Option.empty[A], Some(h())) -> (empty[A] -> t()))
      case (Cons(h1, t1), Cons(h2, t2)) => Some(f(Some(h1()), Some(h2())) -> (t1() -> t2()))
    }

実行結果

> zipalls:List((Some(1),Some(a)), (Some(2),Some(b)), (Some(3),Some(c)), (Some(4),Some(d)))
> zipwalls:List(Some(1!!a), Some(2!!b), Some(3!!c), Some(4!!d))

zipWithAllは、パターンマッチングで、4パターンに分かれる。タプルのうち、両方がEmptyの場合のみ、Noneをunfoldにわたす。

Some(f(Some(h()), Option.empty[B]) -> (t(), empty[B]))

で、 「->」があるが、これは、タプルを表す。

scala> 1 -> 2
res0: (Int, Int) = (1,2)

scala> List().->(2)
res1: (List[Nothing], Int) = (List(),2)

scala> (1 -> 2) == ((1, 2))
res2: Boolean = true

カッコが多いから、「->」を使っているようだが、逆にわかりづらい気もする。。
unfoldにわたすために、タプルのタプルをSomeでくるんでいる。

startsWith

  def startsWith[A](s: Stream[A]): Boolean =
    zipAll(s).takeWhile(!_._2.isEmpty) forAll {
      case (h,h2) => h == h2
    }

先頭比較のために、zipAllして、タプルのセットを作る。
takeWhileでタプルの2番目が空でない所まで取り出す。
forAllで、それぞれ取り出した値に対して、比較を行う。

map(_._ 2)はmap(x =>(x._2))の省略形
x._2はタプルの2番目の要素

tails

  def tails: Stream[Stream[A]] =
    unfold(this) {
      case Empty => None
      case s => Some((s, s drop 1))
    } append Stream(empty)

tailsは、実行例を見た方がイメージがつきやすい。

sl.tails.toList.foreach(x => println(x.toList))

List(1, 2, 3, 4)
List(2, 3, 4)
List(3, 4)
List(4)
List()

畳込みしないfoldRightのイメージで、途中式をStream[Stream[A]]にして返す。

scanRight

  def scanRight[B](z: B)(f: (A, => B) => B): Stream[B] =
    foldRight((z, Stream(z)))((a, p0) => {
      lazy val p1 = p0
      val b2 = f(a, p1._1)
      (b2, cons(b2, p1._2))
    })._2

実行例

    println(sl.scanRight(0)(_ + _).toList)

List(10, 9, 7, 4, 0)

各tailsの結果を畳み込んだイメージ