Monthly Hacker's Blog

プログラミングや機械学習の記事を中心に書きます。

chainerでEBGAN(Energy-Based Generative Adversarial Network)を実装した

2017/03/09
@akira_you さんにご指摘いただぎ、コードの一部を修正

できること

この記事では、次のことができるようになります。

  • updaterを書き換える
  • make_extensionでextensionを追加する
  • chainerのtrainerを使ってMNISTでGANおよびEBGAN(pull-away termあり)

経緯

今月のテーマはニコニコデータセットを使って何かをするということで、GANをやろうと思っていました。というのも、chainerでのGANの実装はいくつもあるが、trainerを使ってモダンな書き方をしているものは見当たらなかったからです。さらにちょうどよいタイミングでEBGANが発表されたので、こちらも合わせて実装して、trainerの書き方を学ぶことにしました。

ちなみに、現在MNISTでコードが動くことを確認した段階で、今月のテーマであるニコニコデータセットではまだ試していません。そのうち追記する予定です。

コードはgithubにアップロードしてあります。
github.com


また、以下のコード解説は以前書いた記事を眺めながら読むと理解が深まると思います。
www.monthly-hack.com

updaterを書き換える

class gan_updater(training.StandardUpdater):

    def __init__(self, iterator, discriminator, generator,
                 optimizer_d, optimizer_g, device,
                 batchsize, xp, converter=convert.concat_examples):
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator
        self.discriminator = discriminator
        self.generator = generator
        self._optimizers = {'discriminator': optimizer_d,
                            'generator': optimizer_g}
        self.device = device
        self.converter = converter
        self.iteration = 0
        self._xp = xp
        self.ones = chainer.Variable(self._xp.ones((batchsize, 1),
                                                   dtype=np.int32))
        self.zeros = chainer.Variable(self._xp.zeros((batchsize, 1),
                                                     dtype=np.int32))

    def update_core(self):
        batch = self._iterators['main'].next()
        in_arrays = self.converter(batch, self.device)

        in_var = chainer.Variable(in_arrays)
        generated = self.generator(False)

        label_data = self.discriminator(in_var)
        loss_dis = F.sigmoid_cross_entropy(label_data, self.zeros)

        label_generated = self.discriminator(generated)
        loss_dis += F.sigmoid_cross_entropy(label_generated, self.ones)
        loss_gen = F.sigmoid_cross_entropy(label_generated, self.zeros)

        reporter.report({'dis/loss': loss_dis})
        reporter.report({'gen/loss': loss_gen})

        self._optimizers['discriminator'].target.cleargrads()
        loss_dis.backward()
        self._optimizers['discriminator'].update()

        self._optimizers['generator'].target.cleargrads()
        loss_gen.backward()
        self._optimizers['generator'].update()

GANのように複数のlossやoptimizerを扱う場合は、自分でupdaterを書き換えることが多いです。そして、今回のようにiterationの管理などでは特別なことをしないけど、updateの規則だけ変えたいときは、StandardUpdaterを使うと便利です。

続いて、EBGANも見てみます。誤差関数が変わり、pull-away termが増えただけです。

class gan_updater(training.StandardUpdater):

    def __init__(self, iterator, discriminator, generator,
                 optimizer_d, optimizer_g, margin,
                 device, batchsize,
                 converter=convert.concat_examples, pt=True):
        if isinstance(iterator, iterator_module.Iterator):
            iterator = {'main': iterator}
        self._iterators = iterator
        self.discriminator = discriminator
        self.generator = generator
        self._optimizers = {'discriminator': optimizer_d,
                            'generator': optimizer_g}
        self.margin = margin
        self.pt = pt
        self.device = device
        self.converter = converter
        self.iteration = 0

    def update_core(self):
        batch = self._iterators['main'].next()
        in_arrays = self.converter(batch, self.device)

        in_var = chainer.Variable(in_arrays)
        generated = self.generator()

        y_data = self.discriminator(in_var)
        loss_dis = F.mean_squared_error(y_data, in_var)

        y_generated = self.discriminator(generated)
        loss_gen = F.mean_squared_error(y_generated, generated)
        loss_dis += F.relu(self.margin - loss_gen)

        if self.pt:
            s = self.discriminator.encode(generated)
            normalized_s = F.normalize(s)
            cosine_similarity = F.matmul(normalized_s, normalized_s,
                                         transb=True)
            ptterm = F.sum(cosine_similarity)
            ptterm /= s.shape[0] * s.shape[0]
            loss_gen += 0.1 * ptterm

        reporter.report({'dis/loss': loss_dis})
        reporter.report({'gen/loss': loss_gen})

        self._optimizers['discriminator'].target.cleargrads()
        loss_dis.backward()
        self._optimizers['discriminator'].update()

        self._optimizers['generator'].target.cleargrads()
        loss_gen.backward()
        self._optimizers['generator'].update()

ここで少し、論文と異なる実装をしているので、解説しておきます。pull-away termについてです。

chainerでのコサイン類似度の計算

chainer のFunctionsにはコサイン類似度がありません。そこで他の関数の組み合わせで実装する必要があります。

まず、論文中の定義式を確認します。(bsは論文と同様、バッチサイズの略です)

{
\displaystyle
\begin{equation}

f_{PT} (S) = \frac{1}{bs(bs - 1)} \sum_{i} \sum_{j \neq i} \left( \frac{S_{i}^{T}S_{j}}{\| S_{i} \| \| S_{j} \|} \right)^{2}

\end{equation}
}

この式だけ見ても少し分かりにくいので、もう少し噛み砕いてみましょう。

{
\displaystyle
\begin{eqnarray*}

f_{PT} \left( Encode \left( z \right) \right) &=& \frac{1}{bs(bs - 1)} \sum_{i=1}^{bs} \sum_{\substack{j=1 \\ j \neq i}}^{bs} \left( \frac{Encode \left( G \left( z_{i} \right) \right)^{T}Encode \left( G \left( z_{j} \right) \right)}{\| Encode \left( G \left( z_{i} \right) \right) \| \| Encode \left( G \left( z_{j} \right) \right) \|} \right)^{2} \\
&=& \frac{1}{bs(bs - 1)} \sum_{i=1}^{bs} \sum_{\substack{j=1 \\ j \neq i}}^{bs}
\left(
    \frac{
        Encode \left(
            G \left(
                z_{i}
            \right)
        \right)
    }
    {\|
        Encode \left(
            G \left(
                z_{i}
            \right)
        \right)
    \|}
\cdot
    \frac{
        Encode \left(
            G \left(
                z_{j}
            \right)
        \right)
    }
    {\|
        Encode \left(
            G \left(
                z_{j}
            \right)
        \right)
    \|}
\right)^{2}



\end{eqnarray*}
}

なるほど、i=jのとき1(定数)になってしまうから、右のシグマではj \neq iとしているわけです。ところが、場合分けはGPUの苦手分野です。なるべく避けたいところです。

ここで改めて考えることは、lossの特徴です。lossの値は学習の進み具合を確認するために使うもので、学習を進めるためには使いません。学習を進めるためには、勾配のみを使います。今i=jのとき1(定数)になってしまうことが問題ですが、微分したら0になるため、学習を進めることに全く影響しません。そこで、プログラム中では次のようにpull-away termを計算しています。

{
\displaystyle
\begin{eqnarray*}

f_{PT} &=& \frac{1}{bs^{2}} \sum_{i=1}^{bs} \sum_{j=1}^{bs}
\left(
    \frac{
        Encode \left(
            G \left(
                z_{i}
            \right)
        \right)
    }
    {\|
        Encode \left(
            G \left(
                z_{i}
            \right)
        \right)
    \|}
\cdot
    \frac{
        Encode \left(
            G \left(
                z_{j}
            \right)
        \right)
    }
    {\|
        Encode \left(
            G \left(
                z_{j}
            \right)
        \right)
    \|}
\right)^{2}



\end{eqnarray*}
}

定数倍は勾配の絶対値は変えるものの、方向は変えません。そのため、式やプログラムの見た目がきれいになるように改変しています。(本来は\frac{1}{bs^{2}}ではなく\frac{1}{bs \left( bs - 1 \right)}です)
また、lossの値も正確にするためには、定数倍する前にbsをひきます。

make_extensionでextensionを追加する

今回、エポックごとにgeneratorの出力を見てみたいと思い、extensionを追加することにしました。extensionの追加にはデコレーター(decorator)を使います。まだデコレーターをよく理解していないのですが、ひとまず、次のように書けばextensionを追加できます。

@training.make_extension()
def func(trainer):
    ...

trainer.extend(func)

プログラム内でいうと、次の部分です。

@training.make_extension(trigger=(1, 'epoch'))
def save_images(trainer):
    gen.save_images(args.image, result)

trainer.extend(save_images)

結果

「どのくらいきれいな結果が出せるか」は論文を見ていいただくとして、更新回数とGeneratorの出力結果の関係を知るための参考として、epoch1とepoch2で生成した画像を載せておきます。各エポック、DiscriminatorにMNISTのテストデータを入力したもの、Generatorに乱数を入力したものを載せています。

epoch1
Discriminator
f:id:d-higurashi:20161024142608p:plain

Generator
f:id:d-higurashi:20161024142621p:plain

epoch2
Discriminator
f:id:d-higurashi:20161024143119p:plain

Generator
f:id:d-higurashi:20161024143128p:plain

パラメータや構造について

出力層の活性化関数

出力層の活性化関数は、入力画像の値域に合わせてtanhやsigmoidを使うことが多くあると感じています。しかし、これらの関数だと絶対値が大きくなるとサチる傾向があり、学習が安定しないことが多くありました。EBGANの場合は(個人的にはGANの場合も)出力層の活性化関数は恒等関数が適任だと思います。

バッチ正規化( BatchNormalization )

GANでのバッチ正規化についてはこちらの記事でも言及されています。
Chainerを使ってコンピュータにイラストを描かせる - Qiita

よく読んでみると、気になる記述がありました。

Discriminatorのパラメータ更新には,実画像のバッチと偽画像のバッチ2つを1つのバッチにまとめて更新する方法と,明示的に損失関数を2つに分けて更新する2通りの方法がある.Batch Normalization layerが無い場合だと最終的に得られる勾配はどっちも変わらないのだが,含めた場合は明確な違いが出てしまう(複雑ですね).最初前者の方法で更新していたらG, Dどちらも勝率が100%になってしまう奇妙な結果が得られてしまった.最初chainer側のバグかと疑ってしまったのだが,最終的にBNの性質に着目し後者の実装にしたら綺麗に収束した(バグではなかった).

この記述(指摘)は「実画像バッチと偽画像バッチの2つのバッチ間で、バッチ正規化を適用する中間素子の平均分散に大きく差がある場合に、バッチ正規化が本来の機能を果たさないのではないか」という考えが前提にあります。そのため引用元の記事では、実画像は実画像でバッチ正規化をして、偽画像では偽画像でバッチ正規化をしています。

個人的にはこれでも中途半端で、次のようにすることが理想だと思っています。

実画像と偽画像でバッチを分ける。その際、実画像バッチはバッチ正規化をし、偽画像バッチは、実画像バッチのバッチ正規化に用いたパラメータ(係数と定数項)でバッチ正規化をする。

少し説明をします。引用した記事で個人的に「あれ?」と感じることは実画像バッチと偽画像バッチでバッチ正規化に用いたパラメータが異なることです。これでは、実画像と偽画像を別のDiscriminatorに通しているようなものだと思います。この現象の解決方法としては、実画像バッチと偽画像バッチを1つのバッチにまとめる*1か、上記で提案した手法を使うかです。

手を抜いて、この部分の実装はしていませんが、それでも数字を生成できたので、バッチ正規化についてあんまり頭を固くして考えなくても良い結果は出るのではないかと思っています。もし実装する場合、chainerのバッチ正規化にはtestという引数があるので、偽画像バッチのときにはこの引数をTrueにするとうまく動くはずです。

追記(2016/11/16)
MNISTで対照実験を何回か行ったところ、上記の手法を実装すると(あくまで主観ですが)生成される画像のクオリティが大きく下がりました。バッチ正規化をする場合は、実画像バッチと偽画像バッチを1つのバッチにまとめる手法が簡単かつ安定した結果を出す印象です。

今後

本来のテーマであるニコニコデータセットでの実行ができていないので、気が向いたらやろうと思っています。また、バッチ正規化のコードを追記する予定です。

*1:研究室の同期が実装しています。 https://github.com/1zk/Chainer-DCGAN-MNIST