done is better than perfect

自分が学んだことや、作成したプログラムの記事を書きます。すべての記載は他に定める場合を除き個人的なものです。

Sparkのaggregateが謎すぎたのでメモ

謎っていうのは自分にとってというだけなのであしからず。

初めてのSpark読んでて、aggregate関数なるものが説明読んでもよくわからなかったので、備忘録として今の自分の理解を書いておきます。

あまり良く調べていないので、間違っている可能性大です。後で調べてて間違っていたら追記します。

nums = sc.parallelize([1,2,3,4,5,6,7,8,9,10])
sumCount = nums.aggregate((0, 0),
                            (lambda acc, value: (acc[0] + value, acc[1] + 1)),
                            (lambda acc1, acc2: (acc1[0] + acc2[1], acc1[1] + acc2[1])))
avg = sumCount[0] / float(sumCount[1])

これで、avgには1 - 10までの平均が入っているという寸法です。

今の自分の理解としては、

  • 最初のlambda([0+1, 0+1], [0+2, 0+1], ... , [0+10, 0+1])みたいな変換が行われる
  • 続くlambdaで先に変換したRDDに対してreduceする

といったイメージです。間違っていたら誰かおしえて・・・