done is better than perfect

自分が学んだことや、作成したプログラムの記事を書きます。

Go言語でNgram (with channel)

Go言語で簡単なN-Gramのカウント部分(tokenに分割し、カウントするだけ)を実装してみたいと思います。 ただ実装するだけではつまらないので、channelを使って少しでも早くしてみます。

以下で出てくるGo言語のソースコードで、最初の行に書かれているのはファイル名です。また、データはソースコードと同じディレクトリに配置されていることを想定しています。

また、今回の記事はこちらのスライドにインスパイアされています。非常にわかりやすかったです。

  • 使用するデータ: Data Compression Program
    • word2vecが使っているデータと同様です。別にスペースで区切られている文章であればなんでも良かったです。

以下のコマンドでダウンロードと解凍ができます。

$ wget http://mattmahoney.net/dc/text8.zip -O text8.gz && gzip -d text8.gz -f

1. ワードカウント

まずは単純に単語をカウントしてみたいと思います。

  • 以下のページを参考にしました。
// test.go
package main

import (
    "bufio"
    "bytes"
    "fmt"
    "io/ioutil"
    "log"
    "os"
)

func main() {
    data, err := ioutil.ReadFile("./text8")
    if err != nil {
        log.Fatal(err)
    }
    scanner := bufio.NewScanner(bytes.NewReader(data))
    scanner.Split(bufio.ScanWords)
    count := 0
    for scanner.Scan() {
        count++
    }
    if err := scanner.Err(); err != nil {
        fmt.Fprintln(os.Stderr, "reading input:", err)
    }
    fmt.Printf("%d\n", count)
}

実行すると、以下のようになります。

$ go run test.go
17005207

念のため、wcコマンドでもカウントしてみます。

$ wc -w text8
17005207 text8

結果が同じであることが確認できました。

2. N-Gram (No concurrent)

次に、何も工夫せずにN-Gramを実装してみたいと思います。N-Gramについては以下のURLなどを参照してください。

以下、コードです。

// ngram.go
package main


import (
    "bufio"
    "bytes"
    "fmt"
    "io/ioutil"
    "log"
    "strings"
)

func Ngram(words []string, n int) map[string]int {
    ngrams := make(map[string]int)
    for i := 0; i < len(words)-n+1; i++ {
        ngrams[strings.Join(words[i:i+n], " ")]++
    }
    return ngrams
}

func main() {
    data, err := ioutil.ReadFile("./text8")
    if err != nil {
        log.Fatal(err)
    }
    scanner := bufio.NewScanner(bytes.NewReader(data))
    scanner.Split(bufio.ScanWords)
    words := make([]string, 0)
    for scanner.Scan() {
        words = append(words, scanner.Text())
    }
    if err := scanner.Err(); err != nil {
        log.Fatal(err)
    }
    n := 2
    ngrams := Ngram(words, n)
    fmt.Println(len(ngrams))
}

結果です。

$ go run ngram.go
4146848

とりあえずエラー無しに動いてはいるみたいです。ここでtimeコマンドなどで実行時間を計測してもいいのですが、Go言語にはBenchmarkが簡単にできるツールがデフォルトで入ります。せっかくなので、そちらを使ってみます。

以下がベンチマーク用のソースコードです。

// ngram_test.go
package main

import (
    "bufio"
    "bytes"
    "io/ioutil"
    "log"
    "testing"
)

func BenchmarkNgram(b *testing.B) {
    data, err := ioutil.ReadFile("./text8")
    if err != nil {
        log.Fatal(err)
    }
    scanner := bufio.NewScanner(bytes.NewReader(data))
    scanner.Split(bufio.ScanWords)
    words := make([]string, 0)
    for scanner.Scan() {
        words = append(words, scanner.Text())
    }
    if err := scanner.Err(); err != nil {
        log.Fatal(err)
    }
    n := 2
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        _ = Ngram(words, n)
    }
}

結果です。

$ go test -bench .
testing: warning: no tests to run
PASS
BenchmarkNgram         1    6508285253 ns/op
ok      _/home/masatana/mywork/testgo   9.842s

約6.5秒といったところですね。

3. N-Gram (Concurrent)

3.1 とりあえずChannel

上で書いたN-Gramのコードをベースに、Channelを使って実行速度を早くしてみます。

// ngram.go
func ConcurrentNgram(words []string, n int) map[string]int {
    ngrams := make(map[string]int)
    wordCh := make(chan string)
    go func() {
        for i := 0; i < len(words)-n+1; i++ {
            wordCh <- strings.Join(words[i:i+n], " ")
        }
        close(wordCh)
    }()
    for word := range wordCh {
        ngrams[word]++
    }
    return ngrams
}

ベンチマークソースコードです。ここは基本的に変わらないです。

  • Go benchmark
// ngram_test.go
func BenchmarkConcurrentNgram(b *testing.B) {
    data, err := ioutil.ReadFile("./text8")
    if err != nil {
        log.Fatal(err)
    }
    scanner := bufio.NewScanner(bytes.NewReader(data))
    scanner.Split(bufio.ScanWords)
    words := make([]string, 0)
    for scanner.Scan() {
        words = append(words, scanner.Text())
    }
    if err := scanner.Err(); err != nil {
        log.Fatal(err)
    }
    n := 2
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        _ = ConcurrentNgram(words, n)
    }
}

結果です。

$ go test -bench .
testing: warning: no tests to run
PASS
BenchmarkNgram         1    6556920394 ns/op
BenchmarkConcurrentNgram           1    9069591440 ns/op
ok      _/home/masatana/mywork/testgo   22.565s

あれ、遅くなっている……?

3.2 GOMAXPROCSについて

もはや定番の間違いとなりつつありますが、Go言語でCPUの歓声を聞きたければGOMAXPROCSの設定をする必要があります。

// ngram_test.go
func BenchmarkConcurrentNgram(b *testing.B) {
    runtime.GOMAXPROCS(runtime.NumCPU())
    data, err := ioutil.ReadFile("./text8")
    if err != nil {
        log.Fatal(err)
    }
    scanner := bufio.NewScanner(bytes.NewReader(data))
    scanner.Split(bufio.ScanWords)
    words := make([]string, 0)
    for scanner.Scan() {
        words = append(words, scanner.Text())
    }
    if err := scanner.Err(); err != nil {
        log.Fatal(err)
    }
    n := 2
    b.ResetTimer()
    for i := 0; i < b.N; i++ {
        _ = ConcurrentNgram(words, n)
    }
}

これで早くなったでしょう!結果です。

$ go test -bench .
testing: warning: no tests to run
PASS
BenchmarkNgram         1    6658105981 ns/op
BenchmarkConcurrentNgram           1    35865530767 ns/op
testing: BenchmarkConcurrentNgram left GOMAXPROCS set to 8
ok      _/home/masatana/mywork/testgo   49.801s

WTF?????????????

3.3 Buffered Channels

channelはbufferサイズを指定できます。今回は贅沢にbufferを使用してみます。

func ConcurrentNgram(words []string, n int) map[string]int {
    ngrams := make(map[string]int)
    wordCh := make(chan string, 1000000)
    go func() {
        for i := 0; i < len(words)-n+1; i++ {
            wordCh <- strings.Join(words[i:i+n], " ")
        }
        close(wordCh)
    }()
    for word := range wordCh {
        ngrams[word]++
    }
    return ngrams
}
  • Result
$ go test -bench .
testing: warning: no tests to run
PASS
BenchmarkNgram         1    6642945380 ns/op
BenchmarkConcurrentNgram           1    5502990729 ns/op
testing: BenchmarkConcurrentNgram left GOMAXPROCS set to 8
ok      _/home/masatana/mywork/testgo   19.360s

わずかですが、早くなりました!

終わりに

あまり劇的に高速化というわけにはいきませんでしたが、Go言語では簡単にchannelを使用した並列処理が書けるので負担が少なく高速化ができたと思います。

ホントはsync.Mutexを使って

func ConcurrentNgram(words []string, n int) map[string]int {
    m := new(sync.Mutex)
    ngrams := make(map[string]int)
    wordCh := make(chan string, 1000000)
    go func() {
        for i := 0; i < len(words)-n+1; i++ {
            wordCh <- strings.Join(words[i:i+n], " ")
        }
        close(wordCh)
    }()
    for word := range wordCh {
        m.Lock()    
        ngrams[word]++
        m.Unlock()    
    }
    return ngrams
}

のようにすべきかと思われますが、自分の中でatomicな処理というのが未だに消化できていないので、理解できたらまた記事を書いてみたいと思います。