Monthly Hacker's Blog

毎月のテーマに沿ったプログラミング記事を中心に書きます。

chainer 1.11.0のMNISTサンプルを例にtrainerを読み解く

どんな記事?

Chainer 1.11.0で、trainerなどの新機能が追加されました。詳しくはこちら。この記事では、MNISTのサンプルを読みながら、次の機能を理解していきます。

  • Dataset
    • Iterator
  • Training
    • Trainer
    • Updater
    • Extension
      • Reporter

ソースコード

今回解説するのは、この部分です。

# Load the MNIST dataset
train, test = chainer.datasets.get_mnist()

train_iter = chainer.iterators.SerialIterator(train, args.batchsize)
test_iter = chainer.iterators.SerialIterator(test, args.batchsize,
                                             repeat=False, shuffle=False)

# Set up a trainer
updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu)
trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out)

# Evaluate the model with the test dataset for each epoch
trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu))

# Dump a computational graph from 'loss' variable at the first iteration
# The "main" refers to the target link of the "main" optimizer.
trainer.extend(extensions.dump_graph('main/loss'))

# Take a snapshot at each epoch
trainer.extend(extensions.snapshot())

# Write a log of evaluation statistics for each epoch
trainer.extend(extensions.LogReport())

# Print selected entries of the log to stdout
# Here "main" refers to the target link of the "main" optimizer again, and
# "validation" refers to the default name of the Evaluator extension.
# Entries other than 'epoch' are reported by the Classifier link, called by
# either the updater or the evaluator.
trainer.extend(extensions.PrintReport(
    ['epoch', 'main/loss', 'validation/main/loss',
     'main/accuracy', 'validation/main/accuracy']))

# Print a progress bar to stdout
trainer.extend(extensions.ProgressBar())

if args.resume:
    # Resume from a snapshot
    chainer.serializers.load_npz(args.resume, trainer)

# Run the training
trainer.run()

Dataset

一見trainとtestがdatasetクラスに見えます。しかし、そこが引っ掛けです。実はtrainとtestはこれまで同様numpy配列です。その後のchainer.iterators.SerialIteratorが内部でchainer.dataset.iterator(以下イテレーター)を呼んでいます。MNISTだけでなく、多くの場合でchainer.iterators.SerialIteratorが使えるでしょう。引数は次の通りです。

  • dataset
    • 入力したいnumpy配列(numpy配列以外でもイテレーションできれば可)
  • batch_size
    • バッチサイズ(int)
  • repeat
    • 繰り返しなし(エポック数1、主にテスト用)か繰り返しあり(エポック数1以上、主に訓練用)か(bool)
  • shuffle
    • 順番を入れ替えない(主に時系列データ用)か入れ替えるか(bool)

これだけで、エポックごとに自動でミニバッチを作ってくれます。np.random.permutationなんてやっていた時代が懐かしいです。

Training

Trainingクラスには多くの機能があります。サンプルコードに沿って順に見てみましょう。

Updater

最適化手法や利用するGPUなどを指定します。サンプルコードでいうtraining.StandardUpdaterが該当する部分です。引数を確認してみましょう。

  • iterator
    • 学習対象のイテレーターです。複数のイテレーターを辞書型で指定することができます。(イテレーター or dictionary)
  • optimizer
    • 最適化手法を入力します。こちらも複数の最適化手法を辞書型で指定することができます。(chainer.optimizers)
  • converter(default: concat_examples())
    • iteratorに辞書型で複数のイテレーターを入力したときに、結合に使う関数を指定します。個人的には、ここではデフォルトで済むようにデータ整形しておいたほうが分かりやすいコードになると思います。(function)
  • device(default: None)
    • GPUを使う場合は指定してください。指定しない場合はCPUで計算します。(int)
  • loss_func(default: None)
    • 対象の誤差関数を指定します。通常はoptimizerがセットアップしたリンクを指定するので、サンプルコードのように下準備を済ませておけばNoneで問題ありません。(chainer.Link)

Trainer

続いてTrainerです。ここではエポック数などを指定します。また、学習をスタートするときは、このクラスでrun()します。サンプルコードのtraining.Trainerが該当します。
引数はこちら。

  • updater
    • 使用するupdaterを指定します。(updater)
  • stop_trigger(default: None)
    • エポック数やイテレーション数などを(1000, 'iteration')や(1, 'epoch')のように指定します。(tuple)
  • out(default: 'result')
    • 出力先を指定します。(str)

Extension

for文を書かないとなると、途中経過の保存などなど、自分で処理が書けないと思うかもしれません。しかし、Extensionを使えばそうした心配が解消されます。<条件>のときに<動作>するを指定できるのがExtensionです。それを、Trainer.extend(extension)のようにextend関数を使って追加していきます。ちなみに、chainer.training.extensionsにはよく使うExtensionが用意されています。サンプルコードではこのchainer.training.extensionsのExtensionのみで書かれています。ここではサンプルコードで使われているextensionsを紹介していきます。*1

Evaluator
モデルの評価をするExtensionです。Updaterとほぼ同じような引数を持ちます。詳しくはこちら

dump_graph()
グラフを保存するExtensionです。出力はDOT言語です。詳しくはこちら

snapshot
trainerを保存するExtensionです。デフォルトでは毎エポック呼び出されます。詳しくはこちら

LogReport
ログを保存するExtensionです。json形式で出力されます。詳しくはこちら

PrintReport
ログを出力します。出力するログはデフォルトの値がないのでサンプルコードのように何かしら指定して使いましょう。詳しくはこちら

ProgressBar
学習の進行状況を出力します。こういう「あると便利だけどサボりがち」なコードが1行で書けると便利ですね。詳しくはこちら

最後に

大幅なアップデートで、最初は面倒だなと思いましたが、便利な機能が多く、使いこなしたら楽できそうです。しばらくDocsとにらめっこですが、頑張って使いこなしていこうと思います。

*1:自分でExtensionを作るときはchainer.training.make_extension()を使います