どんな記事?
Chainer 1.11.0で、trainerなどの新機能が追加されました。詳しくはこちら。この記事では、MNISTのサンプルを読みながら、次の機能を理解していきます。
- Dataset
- Iterator
- Training
- Trainer
- Updater
- Extension
- Reporter
ソースコード
今回解説するのは、この部分です。
# Load the MNIST dataset train, test = chainer.datasets.get_mnist() train_iter = chainer.iterators.SerialIterator(train, args.batchsize) test_iter = chainer.iterators.SerialIterator(test, args.batchsize, repeat=False, shuffle=False) # Set up a trainer updater = training.StandardUpdater(train_iter, optimizer, device=args.gpu) trainer = training.Trainer(updater, (args.epoch, 'epoch'), out=args.out) # Evaluate the model with the test dataset for each epoch trainer.extend(extensions.Evaluator(test_iter, model, device=args.gpu)) # Dump a computational graph from 'loss' variable at the first iteration # The "main" refers to the target link of the "main" optimizer. trainer.extend(extensions.dump_graph('main/loss')) # Take a snapshot at each epoch trainer.extend(extensions.snapshot()) # Write a log of evaluation statistics for each epoch trainer.extend(extensions.LogReport()) # Print selected entries of the log to stdout # Here "main" refers to the target link of the "main" optimizer again, and # "validation" refers to the default name of the Evaluator extension. # Entries other than 'epoch' are reported by the Classifier link, called by # either the updater or the evaluator. trainer.extend(extensions.PrintReport( ['epoch', 'main/loss', 'validation/main/loss', 'main/accuracy', 'validation/main/accuracy'])) # Print a progress bar to stdout trainer.extend(extensions.ProgressBar()) if args.resume: # Resume from a snapshot chainer.serializers.load_npz(args.resume, trainer) # Run the training trainer.run()
Dataset
一見trainとtestがdatasetクラスに見えます。しかし、そこが引っ掛けです。実はtrainとtestはこれまで同様numpy配列です。その後のchainer.iterators.SerialIteratorが内部でchainer.dataset.iterator(以下イテレーター)を呼んでいます。MNISTだけでなく、多くの場合でchainer.iterators.SerialIteratorが使えるでしょう。引数は次の通りです。
- dataset
- 入力したいnumpy配列(numpy配列以外でもイテレーションできれば可)
- batch_size
- バッチサイズ(int)
- repeat
- 繰り返しなし(エポック数1、主にテスト用)か繰り返しあり(エポック数1以上、主に訓練用)か(bool)
- shuffle
- 順番を入れ替えない(主に時系列データ用)か入れ替えるか(bool)
これだけで、エポックごとに自動でミニバッチを作ってくれます。np.random.permutationなんてやっていた時代が懐かしいです。
Training
Trainingクラスには多くの機能があります。サンプルコードに沿って順に見てみましょう。
Updater
最適化手法や利用するGPUなどを指定します。サンプルコードでいうtraining.StandardUpdaterが該当する部分です。引数を確認してみましょう。
- iterator
- 学習対象のイテレーターです。複数のイテレーターを辞書型で指定することができます。(イテレーター or dictionary)
- optimizer
- 最適化手法を入力します。こちらも複数の最適化手法を辞書型で指定することができます。(chainer.optimizers)
- converter(default: concat_examples())
- iteratorに辞書型で複数のイテレーターを入力したときに、結合に使う関数を指定します。個人的には、ここではデフォルトで済むようにデータ整形しておいたほうが分かりやすいコードになると思います。(function)
- device(default: None)
- GPUを使う場合は指定してください。指定しない場合はCPUで計算します。(int)
- loss_func(default: None)
- 対象の誤差関数を指定します。通常はoptimizerがセットアップしたリンクを指定するので、サンプルコードのように下準備を済ませておけばNoneで問題ありません。(chainer.Link)
Trainer
続いてTrainerです。ここではエポック数などを指定します。また、学習をスタートするときは、このクラスでrun()します。サンプルコードのtraining.Trainerが該当します。
引数はこちら。
- updater
- 使用するupdaterを指定します。(updater)
- stop_trigger(default: None)
- エポック数やイテレーション数などを(1000, 'iteration')や(1, 'epoch')のように指定します。(tuple)
- out(default: 'result')
- 出力先を指定します。(str)
Extension
for文を書かないとなると、途中経過の保存などなど、自分で処理が書けないと思うかもしれません。しかし、Extensionを使えばそうした心配が解消されます。<条件>のときに<動作>するを指定できるのがExtensionです。それを、Trainer.extend(extension)のようにextend関数を使って追加していきます。ちなみに、chainer.training.extensionsにはよく使うExtensionが用意されています。サンプルコードではこのchainer.training.extensionsのExtensionのみで書かれています。ここではサンプルコードで使われているextensionsを紹介していきます。*1
Evaluator
モデルの評価をするExtensionです。Updaterとほぼ同じような引数を持ちます。詳しくはこちら。
dump_graph()
グラフを保存するExtensionです。出力はDOT言語です。詳しくはこちら。
snapshot
trainerを保存するExtensionです。デフォルトでは毎エポック呼び出されます。詳しくはこちら。
LogReport
ログを保存するExtensionです。json形式で出力されます。詳しくはこちら。
PrintReport
ログを出力します。出力するログはデフォルトの値がないのでサンプルコードのように何かしら指定して使いましょう。詳しくはこちら。
ProgressBar
学習の進行状況を出力します。こういう「あると便利だけどサボりがち」なコードが1行で書けると便利ですね。詳しくはこちら。
最後に
大幅なアップデートで、最初は面倒だなと思いましたが、便利な機能が多く、使いこなしたら楽できそうです。しばらくDocsとにらめっこですが、頑張って使いこなしていこうと思います。
*1:自分でExtensionを作るときはchainer.training.make_extension()を使います