Go Generics で実装する Y コンビネータ、再帰関数を汎用的にメモ化する

Published on

先日面白い記事を読んだ。The Y combinator in Go with generics である。

一番最初目に入ってきたのは難解なジェネリクスのコードである。

type Func[T, U any] func(T) U
type TagFunc[T, U any] func(Func[T, U]) Func[T, U]
type CombinatorFunc[T, U any] func(CombinatorFunc[T, U]) Func[T, U]

func Y[T, U any](f TagFunc[T, U]) Func[T, U] {
  return func(self CombinatorFunc[T, U]) Func[T, U] {
    return f(func(n T) U {
      return self(self)(n)
    })
  }(func(self CombinatorFunc[T, U]) Func[T, U] {
    return f(func(n T) U {
      return self(self)(n)
    })
  })
}

このコードを読み解くために、Y コンビネータ1というものを理解する必要がある。そのためには前提となる無名関数を利用した処理を行うコードを理解しておくとよりわかりやすくなる。なぜならこの Y コンビネータは再帰を用いた無名関数呼び出しを行うことでループ処理を実現する機構として利用できるからである。

引数を一つにする

Go は無名関数を作成できる。足し算を行うようなコードを記述するとこんな感じ。

sum := func(x, y int) int { return x + y }
sum(1, 2) // 3

引数を 2 つ受け付けるようにしているが、1 つの引数で同様の挙動を実現させようとするとどうすべきだろうか。「カリー化」と呼ばれる手法を用いることでこれを実現する。 この手法は端的に言えば「複数の引数を1つに減らす」ための手法である。

まずは定義した足し算を行う無名関数の y を 2 で固定した add3 関数を考える。

add2 := func(x int) int { return sum(x, 2) }
add2(1) // 3

しかし、2 の箇所も引数として指定することで自由に計算できるようにしたい。そこで後で引数を指定してもらえるように、y を引数に取るような無名関数を返す形で定義する。

add := func(x int) func(y int) int {
  return func(y int) int { // y を引数に取る関数を返す
    return sum(x, y)
  }
}
add1 := add(1)
add1(2) // 3

こうして複数の引数を持つ関数を 1 引数の関数に変換できる。

無名関数での再帰

階乗の計算

階乗の計算を行う再帰関数を考える。0 の階乗を計算すると 1 となる。それ以外の計算方法は n! = n * (n - 1)! となる法則がある。まとめるとこんな感じ。

0! = 1
n! = n * (n - 1)!

これを素直にコードへ変換することができる。

func factorial(n int) int {
    if n == 0 {
        return 1
    }
    return n * factorial(n - 1)
}

factorial(10) // 3628800

さて、これは無名関数で表すことができるのだろうか。一番最初に思いつく素直な方法は次のようなコードだろう。しかしこれだと undefined: factorial のメッセージと共にコンパイルエラーになる。

factorial := func(n int) int {
  if n == 0 {
    return 1
  }
  return n * factorial(n-1) // undefined: factorial
}

ではどうすれば良いのか。少し複雑になるがその方法をここに記述する。

type FactorialFunc func(int) int
type FactorialMakerFunc func(FactorialMakerFunc) FactorialFunc

factorialMaker := func(self FactorialMakerFunc) FactorialFunc {
  return func(n int) int {
    if n == 0 {
      return 1
    }
    return n * self(self)(n-1)
  }
}
factorial := factorialMaker(factorialMaker)
factorial(10) // 3628800
  1. factorialMaker は階乗の計算を行わない。代わりに無名関数を返す。この無名関数は factorialMaker が返したのと同じ関数を受け取ることを期待している。
  2. factorial := factorialMaker(factorialMaker) では factorialMaker が返す無名関数に factorialMaker 自身を渡すことで階乗の計算を行う無名関数を完成させる。
  3. factorial(10) で実際に階乗の計算を行う。

同様にフィボナッチ数を求める計算についても考えてみる。

フィボナッチ数の計算

フィボナッチ数列は以下のような数列である。

0, 1, 1, 2, 3, 5, 8, 13, 21, ...

この数列は次のような法則で成り立っていることが分かる。

n = 0 => 0
n = 1 => 1
n = fib(n-1) + fib(n-2)

これを基に n 番目のフィボナッチ数を求める名前付き関数を定義する。

func fib(n int) int {
  if n <= 1 {
    return n
  }
  return fib(n-1) + fib(n-2)
}

fib(10) // 55

これも階乗の計算と同じように無名関数を利用して表現してみる。

type FibFunc func(int) int
type FibMakerFunc func(FibMakerFunc) FibFunc

fibMaker := func(self FibMakerFunc) FibFunc {
  return func(n int) int {
    if n <= 1 {
      return n
    }
    fib := self(self)
    return fib(n-1) + fib(n-2)
  }
}
fib := fibMaker(fibMaker)
fib(10) // 55

このコードも階乗の計算を行う無名関数と似たアプローチで作成できた。これまでの実装を見てみると、この無名関数の再帰呼び出しを行う仕組みは汎用化することができそうである。この汎用化の部分がY コンビネータになる。

Y コンビネータ

もう一度冒頭の難解なジェネリクスのコードを見てみよう。これは次のように少しシンプルに修正できる。これで少しは読みやすくなっただろうか。 無名関数を束縛した g は自身を受け付けることで再帰処理部分を汎用的に切り出すことができている。

type Func[T, U any] func(T) U
type TagFunc[T, U any] func(Func[T, U]) Func[T, U]
type CombinatorFunc[T, U any] func(CombinatorFunc[T, U]) Func[T, U]

func Y[T, U any](f TagFunc[T, U]) Func[T, U] {
  g := func(self CombinatorFunc[T, U]) Func[T, U] {
    return f(func(t T) U {
      return self(self)(t)
    })
  }
  return g(g)
}

このコードを理解するために筆者が理解したことを記述する。

  1. 引数 1 つで何か一つ返す関数型 Func を定義する。
  2. TagFunc が実際に処理を行うための無名関数の型になる。
  3. CombinatorFunc は関数 Y の中で利用される。この型を持つ selfself(self) とすることで引数である TagFunc と同じ関数が生成される。

この関数 Y は以下のような使い方ができる。

factorial := Y(func(self Func[int, int]) Func[int, int] {
  return func(n int) int {
    if n == 0 {
      return 1
    }
    return n * self(n-1)
  }
})

factorial(10) // 3628800

以上のように Y コンビネータを利用することで無名関数で再帰処理を記述することができた。そして冒頭で挙げたコードを少し理解できたのではないだろうか。

おとうさんにもわかるYコンビネータ!(絵解き解説編)」の記事で詳しく記述されていたので非常に参考になった。

関数のプラグイン機構を用意する

名前付き関数でも再帰処理を記述できるため、この Y コンビネータの利用シーンは少ないと思われるかもしれない。 しかし、汎用化できたということは当然再帰処理の結果をメモすることも汎用的に実装できるはずである。そこで Y コンビネータへ Pluggable に適用できるような関数 Adapt を定義する。

これらのコードは「さあ、Yコンビネータ(不動点演算子)を使おう!」の内容を基に作成してみた。

func Adapt[T, U any](f TagFunc[T, U], adapters ...TagFunc[T, U]) TagFunc[T, U] {
  return func(self Func[T, U]) Func[T, U] {
    for i := len(adapters) - 1; i >= 0; i-- {
      self = adapters[i](self)
    }
    return f(self)
  }
}

メモ化プラグイン

次にメモ化を行うプラグイン関数を定義する。

func Memo[T comparable, U any]() TagFunc[T, U] {
  memo := map[T]U{}
  return func(f Func[T, U]) Func[T, U] {
    return func(t T) U {
      result, ok := memo[t]
      if ok {
        return result
      }
      tmp := f(t)
      memo[t] = tmp
      return tmp
    }
  }
}

map を用いたシンプルなコードである。

  1. map memo を作成する。キーは comparable である必要がある。
  2. 入力値 t のメモが存在すればその値を返し、存在しなければ実行関数 f を実行し、結果をメモしてそれを返す。

定義した 2 つの関数を適用してみる。

factorialTag := func(self Func[int, int]) Func[int, int] {
  return func(n int) int {
    if n == 0 {
      return 1
    }
    return n * self(n-1)
  }
}

factorial := Y(Adapt(factorialTag, Memo[int, int]()))
factorial(10) // 3628800

実際にメモ化ができているのかベンチマークを取ってみる。ベンチマークに利用したコードを記述する。

func BenchmarkFac(b *testing.B) {
  factorial := Y(factorialTag)
  for i := 0; i < b.N; i++ {
    _ = factorial(i)
  }
}

func BenchmarkFacMemo(b *testing.B) {
  factorial := Y(Adapt(factorialTag, Memo[int, int]()))
  for i := 0; i < b.N; i++ {
    _ = factorial(i)
  }
}

ベンチマークの結果としてメモ化を行ったコードが断然速くなっていることが分かる。よって無事に適用できていることが分かる。

$ go test -benchmem -bench ./... github.com/Code-Hex/yc

goos: darwin
goarch: arm64
pkg: github.com/Code-Hex/yc
BenchmarkFac-8       	   10000	    247627 ns/op	  199980 B/op	    9999 allocs/op
BenchmarkFacMemo-8   	 5783738	       247.6 ns/op	     134 B/op	       3 allocs/op
PASS
coverage: 73.9% of statements
ok  	github.com/Code-Hex/yc	4.356s

まとめ

  • 無名関数の引数を一つにしても、返す関数に引数を持たせられれば、実質的に複数の引数を用意できる手段を紹介した。
  • この無名関数の形を利用した Y コンビネータを実装することで無名関数でも再帰処理を行うことができた。
  • プラグイン機構を用意することでメモ化処理も汎用的に行えるようにした。

今回利用したコードは https://github.com/Code-Hex/yc に置いた。

意外とまだ Go のジェネリクスのネタが少ないように思える。この記事がネタの一つになると嬉しい。

Footnotes

  1. Y コンビネータは不動点演算子とも呼ばれる。Google で素直に「Y コンビネータ」で検索すると Hacker News の会社情報ばかり出てくるため、「Y コンビネータ 関数」だったり「不動点演算子」で検索すると良い。