Sentencepieceの水増しをBERTで試してみる
前回、事前学習済みのbertモデルbert-japaneseを使って対話破綻検知チャレンジ(Dialog Breakdown Detection Challenge)コーパスでのファインチューニングを行った。
結果は決して悪いものではなかったが、当時(DBDC2)のトップモデルの性能には届かなかった。
bert-japaneseは文字列のトークン化にsentencepieceを採用しているので、sentencepieceのSampleEncodeAsPiecesを用いたデータの水増しが可能である。
今回は、dbdcのファインチューニングでトークン化に関する水増しを試してみる。
参考
変更点
モデル、コーパス、学習条件は前回のDBDCの学習と同一。
トークン化の際に水増しを行うため、学習スクリプトに以下の変更を施した。 (変更後の学習スクリプトrun_dbdc_classifier.py)
tokenizerのtokenizeでEncodeAsPieces, SampleEncodeAsPiecesを使い分け
SentencePieceTokenizerにenabled_samplingというboolの変数を持たせ、この真偽で2つのエンコード方法を使い分ける。
このようにして学習時はSampleし、他の時はベストのトークン化を行う。
SampleEncodeAsPiecesのn_best_sizeは-1, alpha=0.1を用いる。なお、この値はいくつかの文をサンプルにかけて適度に揺らぐことを確認ている。
トークン化前後でenabled_samplingを変更
bertのスクリプトで事例のトークン化が行われているのは、convert_single_example関数内である。
さらに、file_based_convert_examples_to_featuresで各事例に対してこの関数が呼び出される。
学習時にfile_based_convert_examples_to_featuresが実行される直前に
- examplesをエポックの数だけコピー
- Tokenizerのenabled_samplingをTrueに変更
とすることでトークン化をサンプリングモードに変え、トークン化を揺らがし、file_based_convert_examples_to_featuresが実行された後で普通のトークン化に戻す。
一応この部分のコードは下記:
if FLAGS.do_train: train_file = os.path.join(FLAGS.output_dir, "train.tf_record") # Augmentation on tokenization train_examples *= int(FLAGS.num_train_epochs) tokenizer.enabled_sampling = True file_based_convert_examples_to_features( train_examples, label_list, FLAGS.max_seq_length, tokenizer, train_file) tokenizer.enabled_sampling = False tf.logging.info("***** Running training *****") # ... コードが続く ...
この方法だと、tfrecordにサンプル数 x エポック数のデータを書き込むため多少ディスク容量を食うが、dbdcのデータは1エポック分が3.5MB程度だったので問題はないと思われる。
トークン化の水増しをして学習した結果
ロスの曲線
学習曲線にはそれほど差がみられない。
評価結果
前回同様参考としてDBDC2のNTTCSrun2の結果を引用した(DBDC2)。
BERT水増しなし-10は前回の記事の結果。
今回は、水増しありで10エポックと、水増しあり/なしで20エポックの学習を行った。
ばらつきの参考として、水増しあり20エポックは2回学習し、その両方の結果を示した。
NTTCS run2 | BERT 水増しなし -10 | BERT 水増しあり -10 | BERT 水増しなし -20 | BERT 水増しあり -20(1) | BERT 水増しあり -20(2) | |
---|---|---|---|---|---|---|
エポック数 | - | 10 | 10 | 20 | 20 | 20 |
(loss-valid) | - | 0.201 | 0.197 | 0.207 | 0.199 | 0.202 |
(Acc.-valid) | - | 0.591 | 0.611 | 0.579 | 0.585 | 0.598 |
Acc.-DCM | 0.565 | 0.523 | 0.507 | 0.529 | 0.545 | 0.544 |
Acc.-DIT | 0.655 | 0.62 | 0.618 | 0.62 | 0.655 | 0.627 |
Acc.-IRS | 0.584 | 0.595 | 0.611 | 0.593 | 0.607 | 0.607 |
JS-Div. (O,T,X)-DCM | 0.085 | 0.083 | 0.0829 | 0.0841 | 0.0783 | 0.0815 |
JS-Div. (O,T,X)-DIT | 0.046 | 0.043 | 0.0418 | 0.0455 | 0.0428 | 0.0445 |
JS-Div. (O,T,X)-IRS | 0.101 | 0.091 | 0.0899 | 0.0922 | 0.0867 | 0.09 |
(Acc.は破綻3ラベルの正解率、JS-Div.は3ラベル分布のJSダイバージェンス、DCM, DIT, IRSは対話システムの種類を表す)
水増しあり20の2回の結果には、DITのAccuracyと分布距離系統評価値にばらつきがみられる。
水増しあり10エポックは水増しなし10エポックよりもラベル系統指標における劣化が見られるが、20エポックまで学習すると水増しなしの方はほとんど変化がないのに対して水増しありの方は、平均的には精度が向上している。
十分に学習させれば、トークン化に関する水増しには一定の効果があると見える。
ただし、どのBERTの結果を見てもまだラベル系統はNTTCSrun2の精度には追いつけていないが。
感想
トークン化時の水増しの効果は確かにありそうだ。
Sentencepieceの使用時はトークン化の水増しを試してみる価値はあるのではないだろうか(そのときは、水増しをしていない時より学習に時間がかかるようになるため、少し長めに学習エポックを設定するとよい)。