Monthly Hacker's Blog

毎月のテーマに沿ったプログラミング記事を中心に書きます。

chainerのtrainer機能を使ってDiscoGANを実装した

できること

この記事では、次のことができるようになります。

  • CelebAデータセットをダウンロードする
  • updaterを書き換える
  • make_extensionでextensionを追加する
  • 画像の前処理を追加する
  • chainerのtrainerを使ってCelebAでDiscoGAN

はじめに

Twitterを始めました。
twitter.com

DiscoGANとは

DiscoGANは画像変換の一種です。最近は似た研究もあります。また、画像変換というとStyleNetやpix2pixなど様々な手法があります。DiscoGANがこれらの手法と大きく異なる点は教師なしということです。例えばStyleNetではスタイルにあたる画像を与える必要がありますし、pix2pixではペアの画像を与える必要があります。対してDiscGANでは、変換元の特徴を持つ大量の画像と、変換後の特徴を持つ大量の画像を与えるだけで学習ができます。具体例でいうと、男性の画像と女性の画像を大量に集めるだけで、性別変換ができるということです。

実はここまでの特徴は、先に紹介した似た研究でも言えることです。DiscoGANが面白いのは、変換前と変換後が全く異なる画像でもうまくいくと主張していることです。例えば、バッグから靴といった突拍子もない変換がうまくいった例が論文には載っています。

コードはGithubにアップロードしてあります。

github.com

CelebAデータセットのダウンロード

GAN系の論文でデファクトスタンダードになりつつあるCelebAデータセットをダウンロードします。こちらのサイトからダウンロードしてください。

mmlab.ie.cuhk.edu.hk

DropBox、Google Drive, Baidu Driveからダウンロードできると書いてありますが、DropBoxとBaidu Driveはダウンロードできない(おそらくアクセスが多すぎることが原因)ため、Google Driveからダウンロードするしかありません。余談ですが、Baidu Driveというのを初めて見たので、アクセスしてみたところ、案の定中国語だったのでChromeで自動翻訳してみると

笑、あなたはページが存在しません参照してください。

と出てきて少しイラっとしました。

Chromeで大容量のデータを大量にダウンロードしようとするとChromeが落ちてしまい、なかなかダウンロードできなかったので、データセットをダウンロードするスクリプトを書きました。こちらのやりとりが非常に参考になりました。

stackoverflow.com

以下のコードで一気にダウンロードできます。

python download.py /media/hdd/

このように引数にダウンロードしたいディレクトリを指定してください。引数を入れなかった場合はカレントディレクトリに保存します。python3系でないと動きませんが、requests周りを少し変えれば2系でも動くかもしれません。ファイルは7zipで圧縮されているので、解凍してお使いください。

import requests
import os
import sys
import subprocess


def download_file_from_google_drive(fileid, path):
    URL = "https://docs.google.com/uc?export=download"

    session = requests.Session()

    response = session.get(URL, params={'id': fileid}, stream=True)
    token = get_confirm_token(response)

    if token:
        params = {'id': fileid, 'confirm': token}
        response = session.get(URL, params=params, stream=True)

    save_response_content(response, path)


def get_confirm_token(response):
    for key, value in response.cookies.items():
        if key.startswith('download_warning'):
            return value

    return None


def save_response_content(response, path):
    CHUNK_SIZE = 32768

    with open(path, "wb") as f:
        for chunk in response.iter_content(CHUNK_SIZE):
            if chunk:
                f.write(chunk)

if __name__ == "__main__":
    # id and path
    readme_ids = [
        '0B7EVK8r0v71pOXBhSUdJWU1MYUk']
    readme_paths = [
        'README.txt']

    annotation_ids = [
        '0B7EVK8r0v71pbThiMVRxWXZ4dU0',
        '0B7EVK8r0v71pblRyaVFSWGxPY0U',
        '0B7EVK8r0v71pd0FJY3Blby1HUTQ',
        '0B7EVK8r0v71pTzJIdlJWdHczRlU']
    annotation_paths = [
        'Anno/list_bbox_celeba.txt',
        'Anno/list_attr_celeba.txt',
        'Anno/list_landmarks_align_celeba.txt',
        'Anno/list_landmarks_celeba.txt']

    eval_ids = [
        '0B7EVK8r0v71pY0NSMzRuSXJEVkk']
    eval_paths = [
        'Eval/list_eval_partition.txt']

    img_celeba_ids = [
        '0B7EVK8r0v71pQy1YUGtHeUM2dUE',
        '0B7EVK8r0v71peFphOHpxODd5SjQ',
        '0B7EVK8r0v71pMk5FeXRlOXcxVVU',
        '0B7EVK8r0v71peXc4WldxZGFUbk0',
        '0B7EVK8r0v71pMktaV1hjZUJhLWM',
        '0B7EVK8r0v71pbWFfbGRDOVZxOUU',
        '0B7EVK8r0v71pQlZrOENSOUhkQ3c',
        '0B7EVK8r0v71pLVltX2F6dzVwT0E',
        '0B7EVK8r0v71pVlg5SmtLa1ZiU0k',
        '0B7EVK8r0v71pa09rcFF4THRmSFU',
        '0B7EVK8r0v71pNU9BZVBEMF9KN28',
        '0B7EVK8r0v71pTVd3R2NpQ0FHaGM',
        '0B7EVK8r0v71paXBad2lfSzlzSlk',
        '0B7EVK8r0v71pcTFwT1VFZzkzZk0']
    img_celeba_paths = [
        'Img/img_celeba/img_celeba.7z.001',
        'Img/img_celeba/img_celeba.7z.002',
        'Img/img_celeba/img_celeba.7z.003',
        'Img/img_celeba/img_celeba.7z.004',
        'Img/img_celeba/img_celeba.7z.005',
        'Img/img_celeba/img_celeba.7z.006',
        'Img/img_celeba/img_celeba.7z.007',
        'Img/img_celeba/img_celeba.7z.008',
        'Img/img_celeba/img_celeba.7z.009',
        'Img/img_celeba/img_celeba.7z.010',
        'Img/img_celeba/img_celeba.7z.011',
        'Img/img_celeba/img_celeba.7z.012',
        'Img/img_celeba/img_celeba.7z.013',
        'Img/img_celeba/img_celeba.7z.014']

    img_align_celeba_png_ids = [
        '0B7EVK8r0v71pSVd0ZjQ3Sks2dzg',
        '0B7EVK8r0v71pR2NwRnU2cVZ2RTg',
        '0B7EVK8r0v71peUlHSDVhd0JTamM',
        '0B7EVK8r0v71pVmYwbmRtV2hZcDA',
        '0B7EVK8r0v71pVjRlNVB3cDVjaDQ',
        '0B7EVK8r0v71pa3NIcEgtTXZrM0U',
        '0B7EVK8r0v71pNE5aQmY5c2ZLOXc',
        '0B7EVK8r0v71pejhuem9QV2h0MDQ',
        '0B7EVK8r0v71pZk5QcUlObVltaEE',
        '0B7EVK8r0v71pLThPNzFETUNMUVE',
        '0B7EVK8r0v71pZWZ4UGdBbk9UVWs',
        '0B7EVK8r0v71pSk1zVWN2aHhMZ3c',
        '0B7EVK8r0v71pNjFfTGYzTWJDdUU',
        '0B7EVK8r0v71pbFlZaURkY3dhWWM',
        '0B7EVK8r0v71pczZ0NFNFdFRXSUU',
        '0B7EVK8r0v71pckZsdFFIYlJoN1k']
    img_align_celeba_png_paths = [
        'Img/img_align_celeba_png/img_align_celeba_png.7z.001',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.002',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.003',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.004',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.005',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.006',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.007',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.008',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.009',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.010',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.011',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.012',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.013',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.014',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.015',
        'Img/img_align_celeba_png/img_align_celeba_png.7z.016']

    ids = readme_ids + annotation_ids + eval_ids +\
        img_celeba_ids + img_align_celeba_png_ids

    paths = readme_paths + annotation_paths + eval_paths +\
        img_celeba_paths + img_align_celeba_png_paths

    # directory
    try:
        root = os.path.join(sys.argv[1], 'CelebA/')
    except:
        root = './CelebA/'
    Img_img_celeba = os.path.join(root, 'Img/img_celeba')
    Img_img_align_celeba_png = os.path.join(root, 'Img/img_align_celeba_png')
    Anno = os.path.join(root, 'Anno')
    Eval = os.path.join(root, 'Eval')

    if not os.path.exists(Img_img_celeba):
        os.makedirs(Img_img_celeba)

    if not os.path.exists(Img_img_align_celeba_png):
        os.makedirs(Img_img_align_celeba_png)

    if not os.path.exists(Anno):
        os.makedirs(Anno)

    if not os.path.exists(Eval):
        os.makedirs(Eval)

    # download
    for i, (fileid, path) in enumerate(zip(ids, paths)):
        print('{}/{} downloading {}'.format(i + 1, len(ids), path))
        path = os.path.join(root, path)
        if not os.path.exists(path):
            download_file_from_google_drive(fileid, path)

updaterの書き換え

class DiscoGANUpdater(training.StandardUpdater):
    def __init__(self, iterator_a, iterator_b, opt_g_ab, opt_g_ba,
                 opt_d_a, opt_d_b, device):
        self._iterators = {'main': iterator_a, 'second': iterator_b}
        self.generator_ab = opt_g_ab.target
        self.generator_ba = opt_g_ba.target
        self.discriminator_a = opt_d_a.target
        self.discriminator_b = opt_d_b.target
        self._optimizers = {'generator_ab': opt_g_ab,
                            'generator_ba': opt_g_ba,
                            'discriminator_a': opt_d_a,
                            'discriminator_b': opt_d_b}
        self.device = device
        self.converter = convert.concat_examples
        self.iteration = 0
        self.xp = self.generator_ab.xp

    def compute_loss_gan(self, y_real, y_fake):
        batchsize = y_real.shape[0]
        loss_dis = 0.5 * F.sum(F.softplus(-y_real) + F.softplus(y_fake))
        loss_gen = F.sum(F.softplus(-y_fake))
        return loss_dis / batchsize, loss_gen / batchsize

    def compute_loss_feat(self, feats_real, feats_fake):
        losses = 0
        for feat_real, feat_fake in zip(feats_real, feats_fake):
            feat_real_mean = F.sum(feat_real, 0) / feat_real.shape[0]
            feat_fake_mean = F.sum(feat_fake, 0) / feat_fake.shape[0]
            l2 = (feat_real_mean - feat_fake_mean) ** 2
            loss = F.sum(l2) / l2.size
            # loss = F.mean_absolute_error(feat_real_mean, feat_fake_mean)
            losses += loss
        return losses

    def update_core(self):

        # read data
        batch_a = self._iterators['main'].next()
        x_a = self.converter(batch_a, self.device)

        batch_b = self._iterators['second'].next()
        x_b = self.converter(batch_b, self.device)

        batchsize = x_a.shape[0]

        # conversion
        x_ab = self.generator_ab(x_a)
        x_ba = self.generator_ba(x_b)

        # reconversion
        x_aba = self.generator_ba(x_ab)
        x_bab = self.generator_ab(x_ba)

        # reconstruction loss
        recon_loss_a = F.mean_squared_error(x_a, x_aba)
        recon_loss_b = F.mean_squared_error(x_b, x_bab)

        # discriminate
        y_a_real, feats_a_real = self.discriminator_a(x_a)
        y_a_fake, feats_a_fake = self.discriminator_a(x_ba)

        y_b_real, feats_b_real = self.discriminator_b(x_b)
        y_b_fake, feats_b_fake = self.discriminator_b(x_ab)

        # GAN loss
        gan_loss_dis_a, gan_loss_gen_a =\
            self.compute_loss_gan(y_a_real, y_a_fake)
        feat_loss_a = self.compute_loss_feat(feats_a_real, feats_a_fake)

        gan_loss_dis_b, gan_loss_gen_b =\
            self.compute_loss_gan(y_b_real, y_b_fake)
        feat_loss_b = self.compute_loss_feat(feats_b_real, feats_b_fake)

        # compute loss
        if self.iteration < 10000:
            rate = 0.01
        else:
            rate = 0.5

        total_loss_gen_a = (1.-rate)*(0.1*gan_loss_gen_b + 0.9*feat_loss_b) + \
            rate * recon_loss_a
        total_loss_gen_b = (1.-rate)*(0.1*gan_loss_gen_a + 0.9*feat_loss_a) + \
            rate * recon_loss_b

        gen_loss = total_loss_gen_a + total_loss_gen_b
        dis_loss = gan_loss_dis_a + gan_loss_dis_b

        if self.iteration % 3 == 0:
            self.discriminator_a.cleargrads()
            self.discriminator_b.cleargrads()
            dis_loss.backward()
            self._optimizers['discriminator_a'].update()
            self._optimizers['discriminator_b'].update()
        else:
            self.generator_ab.cleargrads()
            self.generator_ba.cleargrads()
            gen_loss.backward()
            self._optimizers['generator_ab'].update()
            self._optimizers['generator_ba'].update()

        # report
        chainer.reporter.report({
            'loss/generator': gen_loss,
            'loss/feature maching loss': feat_loss_a + feat_loss_b,
            'loss/recon': recon_loss_a + recon_loss_b,
            'loss/discriminator': dis_loss})

GANを書くときには、以下の関係を利用することが多いです。

{
\displaystyle
\begin{eqnarray*}
Sigmoid Cross Entropy(x, 1) &=& softplus (-x) \\
Sigmoid Cross Entropy(x, 0) &=& softplus (x) \\
\end{eqnarray*}}

この関係を利用する利点は2つあります。1つ目はonesやzerosが必要ないことです。chainerにもsigmoid_cross_entropyという関数はありますが、引数にラベルを表す1あるいは0を入れた配列を用意する必要があります。softplusを用いればその手間を省くことができます。

2つ目は、nanを避けやすいことです。logの引数が0になるとnanが出力され学習が進まなくなります。どうやら、chainerだとnanを誤差逆伝播させて値を更新してしまうようです。最初はコードの簡潔さよりも論文との整合性を優先して-log(1-y)のように書いていましたが、このエラーが出たことでsoftplusを使うことにしました。

ひと通りsoftplusの利便性を説明したところで、SCEとsoftplusの関係を、実際に式変形することで求めてみます。ここまでは話を簡単にするために明言を避けていましたが、sigmoidやsoftplusはある特徴を満たす関数の総称です。式変形ではもっとも一般的な次の関数を使用します。

{
\displaystyle
\begin{eqnarray*}
sigmoid(x) &=& \frac{1}{1 + \exp(-x)} \\
softplus(x) &=& \log \left( 1 + \exp(x) \right) \\
\end{eqnarray*}}

式変形は次の通りです。

{
\displaystyle
\begin{eqnarray*}
Sigmoid Cross Entropy(x, 1) &=& -1 \cdot \log \left( \frac{1}{1 + \exp(-x)} \right) \\
&=& \log \left( \frac{1}{1 + \exp(-x)} \right)^{-1} \\
&=& \log ( 1 + \exp(-x) ) \\
&=& softplus ( -x ) \\
Sigmoid Cross Entropy(x, 0) &=& -1 \cdot \log \left( 1 - \frac{1}{1 + \exp(-x)} \right) \\
&=& - \log \left( \frac{1 + \exp(-x) - 1}{1 + \exp(-x)} \right) \\
&=& - \log \left( \frac{\exp(-x)}{1 + \exp(-x)} \right) \\
&=& - \log \left( \frac{1}{\exp(x) + 1} \right) \\
&=& \log \left( \frac{1}{\exp(x) + 1} \right)^{-1} \\
&=& \log ( \exp(x) + 1 ) \\
&=& softplus(x)
\end{eqnarray*}}

extensionの追加

def out_generated_image(iterator_a, iterator_b,
                    generator_ab, generator_ba, device, dst):
@chainer.training.make_extension()
def make_image(trainer):
    # read data
    batch_a = iterator_a.next()
    x_a = convert.concat_examples(batch_a, device)
    x_a = chainer.Variable(x_a, volatile='on')

    batch_b = iterator_b.next()
    x_b = convert.concat_examples(batch_b, device)
    x_b = chainer.Variable(x_b, volatile='on')

    # conversion
    x_ab = generator_ab(x_a, test=True)
    x_ba = generator_ba(x_b, test=True)

    # to cpu
    x_a = chainer.cuda.to_cpu(x_a.data)
    x_b = chainer.cuda.to_cpu(x_b.data)
    x_ab = chainer.cuda.to_cpu(x_ab.data)
    x_ba = chainer.cuda.to_cpu(x_ba.data)

    # reshape
    x = np.concatenate((x_a, x_ab, x_b, x_ba), 0)
    x = x.reshape(4, 10, 3, 64, 64)
    x = x.transpose(0, 3, 1, 4, 2)
    x = x.reshape((4 * 64, 10 * 64, 3))

    # to [0, 255]
    x += 1
    x *= (255 / 2)
    x = np.asarray(np.clip(x, 0, 255), dtype=np.uint8)

    preview_dir = '{}/preview'.format(dst)
    preview_path = preview_dir +\
        '/image{:0>5}.png'.format(trainer.updater.epoch)
    if not os.path.exists(preview_dir):
        os.makedirs(preview_dir)
    Image.fromarray(x).save(preview_path)
return make_image

DiscoGANは一般的なGANと違い、generatorの入力が乱数ではなく画像です。そのため、generatorの生成結果を確認するためにも画像を与える必要があります。今回は学習用に作ったデータセットを生成時にも使っています。データセットには生成結果を確認したい枚数(今回は各データセット10枚)だけ画像を入れています。そしてバッチサイズとデータセット中の画像の枚数を合わせることで、コードを簡単にしています。これはほとんどDCGANのexampleのコピーです。

前処理の追加

class PreprocessedDataset(chainer.dataset.DatasetMixin):
    def __init__(self, paths, root, size=64, random=True):
        self.paths = paths
        self.root = root
        self.size = size
        self.random = random

    def __len__(self):
        return len(self.paths)

    def read_image_as_array(self, path):
        f = Image.open(path)
        f = f.resize((109, 89), Image.ANTIALIAS)
        try:
            image = np.asarray(f, dtype=np.float32)
        finally:
            if hasattr(f, 'close'):
                f.close()
        return image.transpose((2, 0, 1))

    def get_example(self, i):
        # It reads the i-th image/label pair and return a preprocessed image.
        # It applies following preprocesses:
        #     - Cropping (random or center rectangular)
        #     - Random flip
        #     - Scaling to [-1, 1] value

        path = os.path.join(self.root, self.paths[i])
        image = self.read_image_as_array(path)
        _, h, w = image.shape

        if self.random:
            # Randomly crop a region and flip the image
            top = random.randint(0, h - self.size)
            left = random.randint(0, w - self.size)
            if random.randint(0, 1):
                image = image[:, :, ::-1]
        else:
            # Crop the center
            top = (h - self.size) // 2
            left = (w - self.size) // 2
        bottom = top + self.size
        right = left + self.size

        image = image[:, top:bottom, left:right]
        image *= (2 / 255)
        image -= 1
        return image

こちらはほとんどImageNetのexampleのコピーです。変更点は2点です。

  • 画像を半分のサイズ(109x89)にリサイズしてから64x64にクロップ
  • [-1, 1]に正規化

個人的にアスペクト比を変えると、生成する画像の質に影響しそうだと思ったため、このようにしました。正規化に関しては[0, 1]でも良かったのですが、好みで[-1, 1]にしました。

結果

f:id:d-higurashi:20170320111155p:plain
35epoch(15000iter)でこのくらいです。TITAN Xだと寝る前に動かして朝確認するとこのくらい学習が進んでいます。100epochくらい回せばそれらしくなると思います。