Monthly Hacker's Blog

毎月のテーマに沿ったプログラミング記事を中心に書きます。

chainerのtrainer機能を使ってWGAN(Wasserstein GAN)を実装した

できること

この記事では、次のことができるようになります。

  • updaterを書き換える
  • make_extensionでextensionを追加する
  • chainerのtrainerを使ってMNISTでWGAN

経緯

修士論文の発表と同時期にWGANが発表されて、出遅れてしまいましたが、chainerで実装しました。論文の著者であるMartin Arjovskyさんが、pytorchというchainerとよく似た最新のフレームワークでの実装を公開しています。そのためchainerを使う人ならかなり参考になると思います。

github.com

また、すでにchainerによる実装もあります。

github.com

ではなぜわざわざ自分で実装したかというと、trainerを使いたかったからです。PlotReportを始め、アップデートする度にtrainerにはどんどん便利な機能が追加されています。trainerを使っていればこうした恩恵を受けることができます。

一方で、trainerはちょっと癖がある(感じ方には個人差があると思います)ので、手を出しにくいと思っている人が多いと思います。そんな人の参考になればと思ってtrainerでの実装にこだわっています。とはいえ、やっていることはEBGANの実装とほぼ同じですので、分かりにくいところがあればこちらの記事も参考にしてください。

コードはGitHubにアップロードしてあります。

github.com

コードを書き換えるときの注意点

自前のデータで実行するときは、ImageDatasetを使うと楽です。例えばこんな感じです。

# train, _ = chainer.datasets.get_mnist(withlabel=False, ndim=3)
train = chainer.datasets.ImageDataset(paths)

pathsには学習データとする画像のパスのリストをいれます。numpy配列などでも大丈夫です。

また、ネットワークを変える際は、generatorの入力(乱数)の次元に合わせて、MyUpdaterとout_generated_imageのzを生成する部分を書き換えます。

WGANについて

Martin Arjovskyさんの実装を見ると、(見落としていなければ)論文には書いていなかった学習を安定させる工夫がありました。critic(通常のGANでいうdiscriminator)とgeneratorの更新回数は5:1の比率としか書いていません。しかし、generatorを25回更新するまでは100:1にしています。これは、学習初期に真の分布とのwasserstein距離の近似の精度を高めておくためだと考えられます。近似が不十分な状態でgeneratorを学習してはWGANではありません。もちろん、更新回数が5:1でも十分近似できる可能性はありますので、絶対に必要な工夫というわけではありません。musyokuさんの実装ではこの工夫を取り入れていませんが、うまく生成できています。ちなみに、Martin Arjovskyさんの実装ではgeneratorを500回更新するごとに1回、criticを5回ではなく100回更新する工夫もしています。手を抜いて、その実装していません。

最後にWGANがうまくいく理由について考えてみます。研究室の先生がすごく分かりやすく説明してくれたので、それを自分なりに解釈して書いていきます。

通常のGANではgeneratorの学習が安定しないことがあります。というか、安定しないことばかりです。generatorはdiscriminatorから伝播されてきた勾配を元に学習をするため、適切な勾配が与えられればうまく学習するはずです。ではなぜうまく学習せず安定しないかというと、その勾配が適切ではないからです。

多くの場合discriminatorから伝播されてきた勾配は0に近く、学習が進みません。なぜなら真のデータを生成する確率分布とgeneratorの確率分布がdisjointだと、discriminatorが真のデータと生成されたデータを完全に分類できてしまうからです。そのため誤差は0に近づき、当然勾配も0に近づきます。

この問題を解消するために、discriminatorは真のデータを生成する確率分布とgeneratorの確率分布の違い(距離)を学習するようにしました。確率分布同士の距離を定義する方法はいくつかありますが、wasserstein距離だと学習がうまくいったということです。論文だと、[0, 1]の一様分布と[2, 3]の一様分布の距離およびその勾配が消失しないことを例に、wasserstein距離が優れていると主張しています。