できること
この記事では、次のことができるようになります。
- CelebAデータセットをダウンロードする
- updaterを書き換える
- make_extensionでextensionを追加する
- 画像の前処理を追加する
- chainerのtrainerを使ってCelebAでDiscoGAN
はじめに
Twitterを始めました。
twitter.com
DiscoGANとは
DiscoGANは画像変換の一種です。最近は似た研究もあります。また、画像変換というとStyleNetやpix2pixなど様々な手法があります。DiscoGANがこれらの手法と大きく異なる点は教師なしということです。例えばStyleNetではスタイルにあたる画像を与える必要がありますし、pix2pixではペアの画像を与える必要があります。対してDiscGANでは、変換元の特徴を持つ大量の画像と、変換後の特徴を持つ大量の画像を与えるだけで学習ができます。具体例でいうと、男性の画像と女性の画像を大量に集めるだけで、性別変換ができるということです。
実はここまでの特徴は、先に紹介した似た研究でも言えることです。DiscoGANが面白いのは、変換前と変換後が全く異なる画像でもうまくいくと主張していることです。例えば、バッグから靴といった突拍子もない変換がうまくいった例が論文には載っています。
コードはGithubにアップロードしてあります。
CelebAデータセットのダウンロード
GAN系の論文でデファクトスタンダードになりつつあるCelebAデータセットをダウンロードします。こちらのサイトからダウンロードしてください。
DropBox、Google Drive, Baidu Driveからダウンロードできると書いてありますが、DropBoxとBaidu Driveはダウンロードできない(おそらくアクセスが多すぎることが原因)ため、Google Driveからダウンロードするしかありません。余談ですが、Baidu Driveというのを初めて見たので、アクセスしてみたところ、案の定中国語だったのでChromeで自動翻訳してみると
笑、あなたはページが存在しません参照してください。
と出てきて少しイラっとしました。
Chromeで大容量のデータを大量にダウンロードしようとするとChromeが落ちてしまい、なかなかダウンロードできなかったので、データセットをダウンロードするスクリプトを書きました。こちらのやりとりが非常に参考になりました。
以下のコードで一気にダウンロードできます。
python download.py /media/hdd/
このように引数にダウンロードしたいディレクトリを指定してください。引数を入れなかった場合はカレントディレクトリに保存します。python3系でないと動きませんが、requests周りを少し変えれば2系でも動くかもしれません。ファイルは7zipで圧縮されているので、解凍してお使いください。
import requests import os import sys import subprocess def download_file_from_google_drive(fileid, path): URL = "https://docs.google.com/uc?export=download" session = requests.Session() response = session.get(URL, params={'id': fileid}, stream=True) token = get_confirm_token(response) if token: params = {'id': fileid, 'confirm': token} response = session.get(URL, params=params, stream=True) save_response_content(response, path) def get_confirm_token(response): for key, value in response.cookies.items(): if key.startswith('download_warning'): return value return None def save_response_content(response, path): CHUNK_SIZE = 32768 with open(path, "wb") as f: for chunk in response.iter_content(CHUNK_SIZE): if chunk: f.write(chunk) if __name__ == "__main__": # id and path readme_ids = [ '0B7EVK8r0v71pOXBhSUdJWU1MYUk'] readme_paths = [ 'README.txt'] annotation_ids = [ '0B7EVK8r0v71pbThiMVRxWXZ4dU0', '0B7EVK8r0v71pblRyaVFSWGxPY0U', '0B7EVK8r0v71pd0FJY3Blby1HUTQ', '0B7EVK8r0v71pTzJIdlJWdHczRlU'] annotation_paths = [ 'Anno/list_bbox_celeba.txt', 'Anno/list_attr_celeba.txt', 'Anno/list_landmarks_align_celeba.txt', 'Anno/list_landmarks_celeba.txt'] eval_ids = [ '0B7EVK8r0v71pY0NSMzRuSXJEVkk'] eval_paths = [ 'Eval/list_eval_partition.txt'] img_celeba_ids = [ '0B7EVK8r0v71pQy1YUGtHeUM2dUE', '0B7EVK8r0v71peFphOHpxODd5SjQ', '0B7EVK8r0v71pMk5FeXRlOXcxVVU', '0B7EVK8r0v71peXc4WldxZGFUbk0', '0B7EVK8r0v71pMktaV1hjZUJhLWM', '0B7EVK8r0v71pbWFfbGRDOVZxOUU', '0B7EVK8r0v71pQlZrOENSOUhkQ3c', '0B7EVK8r0v71pLVltX2F6dzVwT0E', '0B7EVK8r0v71pVlg5SmtLa1ZiU0k', '0B7EVK8r0v71pa09rcFF4THRmSFU', '0B7EVK8r0v71pNU9BZVBEMF9KN28', '0B7EVK8r0v71pTVd3R2NpQ0FHaGM', '0B7EVK8r0v71paXBad2lfSzlzSlk', '0B7EVK8r0v71pcTFwT1VFZzkzZk0'] img_celeba_paths = [ 'Img/img_celeba/img_celeba.7z.001', 'Img/img_celeba/img_celeba.7z.002', 'Img/img_celeba/img_celeba.7z.003', 'Img/img_celeba/img_celeba.7z.004', 'Img/img_celeba/img_celeba.7z.005', 'Img/img_celeba/img_celeba.7z.006', 'Img/img_celeba/img_celeba.7z.007', 'Img/img_celeba/img_celeba.7z.008', 'Img/img_celeba/img_celeba.7z.009', 'Img/img_celeba/img_celeba.7z.010', 'Img/img_celeba/img_celeba.7z.011', 'Img/img_celeba/img_celeba.7z.012', 'Img/img_celeba/img_celeba.7z.013', 'Img/img_celeba/img_celeba.7z.014'] img_align_celeba_png_ids = [ '0B7EVK8r0v71pSVd0ZjQ3Sks2dzg', '0B7EVK8r0v71pR2NwRnU2cVZ2RTg', '0B7EVK8r0v71peUlHSDVhd0JTamM', '0B7EVK8r0v71pVmYwbmRtV2hZcDA', '0B7EVK8r0v71pVjRlNVB3cDVjaDQ', '0B7EVK8r0v71pa3NIcEgtTXZrM0U', '0B7EVK8r0v71pNE5aQmY5c2ZLOXc', '0B7EVK8r0v71pejhuem9QV2h0MDQ', '0B7EVK8r0v71pZk5QcUlObVltaEE', '0B7EVK8r0v71pLThPNzFETUNMUVE', '0B7EVK8r0v71pZWZ4UGdBbk9UVWs', '0B7EVK8r0v71pSk1zVWN2aHhMZ3c', '0B7EVK8r0v71pNjFfTGYzTWJDdUU', '0B7EVK8r0v71pbFlZaURkY3dhWWM', '0B7EVK8r0v71pczZ0NFNFdFRXSUU', '0B7EVK8r0v71pckZsdFFIYlJoN1k'] img_align_celeba_png_paths = [ 'Img/img_align_celeba_png/img_align_celeba_png.7z.001', 'Img/img_align_celeba_png/img_align_celeba_png.7z.002', 'Img/img_align_celeba_png/img_align_celeba_png.7z.003', 'Img/img_align_celeba_png/img_align_celeba_png.7z.004', 'Img/img_align_celeba_png/img_align_celeba_png.7z.005', 'Img/img_align_celeba_png/img_align_celeba_png.7z.006', 'Img/img_align_celeba_png/img_align_celeba_png.7z.007', 'Img/img_align_celeba_png/img_align_celeba_png.7z.008', 'Img/img_align_celeba_png/img_align_celeba_png.7z.009', 'Img/img_align_celeba_png/img_align_celeba_png.7z.010', 'Img/img_align_celeba_png/img_align_celeba_png.7z.011', 'Img/img_align_celeba_png/img_align_celeba_png.7z.012', 'Img/img_align_celeba_png/img_align_celeba_png.7z.013', 'Img/img_align_celeba_png/img_align_celeba_png.7z.014', 'Img/img_align_celeba_png/img_align_celeba_png.7z.015', 'Img/img_align_celeba_png/img_align_celeba_png.7z.016'] ids = readme_ids + annotation_ids + eval_ids +\ img_celeba_ids + img_align_celeba_png_ids paths = readme_paths + annotation_paths + eval_paths +\ img_celeba_paths + img_align_celeba_png_paths # directory try: root = os.path.join(sys.argv[1], 'CelebA/') except: root = './CelebA/' Img_img_celeba = os.path.join(root, 'Img/img_celeba') Img_img_align_celeba_png = os.path.join(root, 'Img/img_align_celeba_png') Anno = os.path.join(root, 'Anno') Eval = os.path.join(root, 'Eval') if not os.path.exists(Img_img_celeba): os.makedirs(Img_img_celeba) if not os.path.exists(Img_img_align_celeba_png): os.makedirs(Img_img_align_celeba_png) if not os.path.exists(Anno): os.makedirs(Anno) if not os.path.exists(Eval): os.makedirs(Eval) # download for i, (fileid, path) in enumerate(zip(ids, paths)): print('{}/{} downloading {}'.format(i + 1, len(ids), path)) path = os.path.join(root, path) if not os.path.exists(path): download_file_from_google_drive(fileid, path)
updaterの書き換え
class DiscoGANUpdater(training.StandardUpdater): def __init__(self, iterator_a, iterator_b, opt_g_ab, opt_g_ba, opt_d_a, opt_d_b, device): self._iterators = {'main': iterator_a, 'second': iterator_b} self.generator_ab = opt_g_ab.target self.generator_ba = opt_g_ba.target self.discriminator_a = opt_d_a.target self.discriminator_b = opt_d_b.target self._optimizers = {'generator_ab': opt_g_ab, 'generator_ba': opt_g_ba, 'discriminator_a': opt_d_a, 'discriminator_b': opt_d_b} self.device = device self.converter = convert.concat_examples self.iteration = 0 self.xp = self.generator_ab.xp def compute_loss_gan(self, y_real, y_fake): batchsize = y_real.shape[0] loss_dis = 0.5 * F.sum(F.softplus(-y_real) + F.softplus(y_fake)) loss_gen = F.sum(F.softplus(-y_fake)) return loss_dis / batchsize, loss_gen / batchsize def compute_loss_feat(self, feats_real, feats_fake): losses = 0 for feat_real, feat_fake in zip(feats_real, feats_fake): feat_real_mean = F.sum(feat_real, 0) / feat_real.shape[0] feat_fake_mean = F.sum(feat_fake, 0) / feat_fake.shape[0] l2 = (feat_real_mean - feat_fake_mean) ** 2 loss = F.sum(l2) / l2.size # loss = F.mean_absolute_error(feat_real_mean, feat_fake_mean) losses += loss return losses def update_core(self): # read data batch_a = self._iterators['main'].next() x_a = self.converter(batch_a, self.device) batch_b = self._iterators['second'].next() x_b = self.converter(batch_b, self.device) batchsize = x_a.shape[0] # conversion x_ab = self.generator_ab(x_a) x_ba = self.generator_ba(x_b) # reconversion x_aba = self.generator_ba(x_ab) x_bab = self.generator_ab(x_ba) # reconstruction loss recon_loss_a = F.mean_squared_error(x_a, x_aba) recon_loss_b = F.mean_squared_error(x_b, x_bab) # discriminate y_a_real, feats_a_real = self.discriminator_a(x_a) y_a_fake, feats_a_fake = self.discriminator_a(x_ba) y_b_real, feats_b_real = self.discriminator_b(x_b) y_b_fake, feats_b_fake = self.discriminator_b(x_ab) # GAN loss gan_loss_dis_a, gan_loss_gen_a =\ self.compute_loss_gan(y_a_real, y_a_fake) feat_loss_a = self.compute_loss_feat(feats_a_real, feats_a_fake) gan_loss_dis_b, gan_loss_gen_b =\ self.compute_loss_gan(y_b_real, y_b_fake) feat_loss_b = self.compute_loss_feat(feats_b_real, feats_b_fake) # compute loss if self.iteration < 10000: rate = 0.01 else: rate = 0.5 total_loss_gen_a = (1.-rate)*(0.1*gan_loss_gen_b + 0.9*feat_loss_b) + \ rate * recon_loss_a total_loss_gen_b = (1.-rate)*(0.1*gan_loss_gen_a + 0.9*feat_loss_a) + \ rate * recon_loss_b gen_loss = total_loss_gen_a + total_loss_gen_b dis_loss = gan_loss_dis_a + gan_loss_dis_b if self.iteration % 3 == 0: self.discriminator_a.cleargrads() self.discriminator_b.cleargrads() dis_loss.backward() self._optimizers['discriminator_a'].update() self._optimizers['discriminator_b'].update() else: self.generator_ab.cleargrads() self.generator_ba.cleargrads() gen_loss.backward() self._optimizers['generator_ab'].update() self._optimizers['generator_ba'].update() # report chainer.reporter.report({ 'loss/generator': gen_loss, 'loss/feature maching loss': feat_loss_a + feat_loss_b, 'loss/recon': recon_loss_a + recon_loss_b, 'loss/discriminator': dis_loss})
GANを書くときには、以下の関係を利用することが多いです。
この関係を利用する利点は2つあります。1つ目はonesやzerosが必要ないことです。chainerにもsigmoid_cross_entropyという関数はありますが、引数にラベルを表す1あるいは0を入れた配列を用意する必要があります。softplusを用いればその手間を省くことができます。
2つ目は、nanを避けやすいことです。logの引数が0になるとnanが出力され学習が進まなくなります。どうやら、chainerだとnanを誤差逆伝播させて値を更新してしまうようです。最初はコードの簡潔さよりも論文との整合性を優先して-log(1-y)のように書いていましたが、このエラーが出たことでsoftplusを使うことにしました。
ひと通りsoftplusの利便性を説明したところで、SCEとsoftplusの関係を、実際に式変形することで求めてみます。ここまでは話を簡単にするために明言を避けていましたが、sigmoidやsoftplusはある特徴を満たす関数の総称です。式変形ではもっとも一般的な次の関数を使用します。
式変形は次の通りです。
extensionの追加
def out_generated_image(iterator_a, iterator_b, generator_ab, generator_ba, device, dst): @chainer.training.make_extension() def make_image(trainer): # read data batch_a = iterator_a.next() x_a = convert.concat_examples(batch_a, device) x_a = chainer.Variable(x_a, volatile='on') batch_b = iterator_b.next() x_b = convert.concat_examples(batch_b, device) x_b = chainer.Variable(x_b, volatile='on') # conversion x_ab = generator_ab(x_a, test=True) x_ba = generator_ba(x_b, test=True) # to cpu x_a = chainer.cuda.to_cpu(x_a.data) x_b = chainer.cuda.to_cpu(x_b.data) x_ab = chainer.cuda.to_cpu(x_ab.data) x_ba = chainer.cuda.to_cpu(x_ba.data) # reshape x = np.concatenate((x_a, x_ab, x_b, x_ba), 0) x = x.reshape(4, 10, 3, 64, 64) x = x.transpose(0, 3, 1, 4, 2) x = x.reshape((4 * 64, 10 * 64, 3)) # to [0, 255] x += 1 x *= (255 / 2) x = np.asarray(np.clip(x, 0, 255), dtype=np.uint8) preview_dir = '{}/preview'.format(dst) preview_path = preview_dir +\ '/image{:0>5}.png'.format(trainer.updater.epoch) if not os.path.exists(preview_dir): os.makedirs(preview_dir) Image.fromarray(x).save(preview_path) return make_image
DiscoGANは一般的なGANと違い、generatorの入力が乱数ではなく画像です。そのため、generatorの生成結果を確認するためにも画像を与える必要があります。今回は学習用に作ったデータセットを生成時にも使っています。データセットには生成結果を確認したい枚数(今回は各データセット10枚)だけ画像を入れています。そしてバッチサイズとデータセット中の画像の枚数を合わせることで、コードを簡単にしています。これはほとんどDCGANのexampleのコピーです。
前処理の追加
class PreprocessedDataset(chainer.dataset.DatasetMixin): def __init__(self, paths, root, size=64, random=True): self.paths = paths self.root = root self.size = size self.random = random def __len__(self): return len(self.paths) def read_image_as_array(self, path): f = Image.open(path) f = f.resize((109, 89), Image.ANTIALIAS) try: image = np.asarray(f, dtype=np.float32) finally: if hasattr(f, 'close'): f.close() return image.transpose((2, 0, 1)) def get_example(self, i): # It reads the i-th image/label pair and return a preprocessed image. # It applies following preprocesses: # - Cropping (random or center rectangular) # - Random flip # - Scaling to [-1, 1] value path = os.path.join(self.root, self.paths[i]) image = self.read_image_as_array(path) _, h, w = image.shape if self.random: # Randomly crop a region and flip the image top = random.randint(0, h - self.size) left = random.randint(0, w - self.size) if random.randint(0, 1): image = image[:, :, ::-1] else: # Crop the center top = (h - self.size) // 2 left = (w - self.size) // 2 bottom = top + self.size right = left + self.size image = image[:, top:bottom, left:right] image *= (2 / 255) image -= 1 return image
こちらはほとんどImageNetのexampleのコピーです。変更点は2点です。
- 画像を半分のサイズ(109x89)にリサイズしてから64x64にクロップ
- [-1, 1]に正規化
個人的にアスペクト比を変えると、生成する画像の質に影響しそうだと思ったため、このようにしました。正規化に関しては[0, 1]でも良かったのですが、好みで[-1, 1]にしました。
結果
35epoch(15000iter)でこのくらいです。TITAN Xだと寝る前に動かして朝確認するとこのくらい学習が進んでいます。100epochくらい回せばそれらしくなると思います。