BERTで対話破綻検知

対話破綻検知チャレンジは人と対話システムとの雑談対話に対して○、△、×の3値分類を行うコンペティションで今まで3回行われている。

データが公開刺されいるので、今回はこのタスクについてBERTをファインチューニングしてみる。

対話破綻検知チャレンジ(Dialog Breakdown Detection Challenge; DBDC)のタスクとBERTへの入力

DBDCでの破綻検知タスク

  • 対話システムの最終発話、そこに至るまでの人と対話システムの対話履歴がそれぞれ文字列として与えられ、最終発話に関して、破綻(×)、おそらく破綻(△)、破綻ではない(○)を判定する。
  • 対話履歴は最大で20ターンある。
  • データセットには人手で判定がアノテーションされている。
  • DBDC ~ DBDC3まで少しずつデータが増やされている。今回はDBDC, DBDC2のデータを用いる。
  • DBDCのdev, eval対話データとDBDC2のdevの対話データを学習に用いた(ただしそのうちの25%は検証データセット)。DBDC2の評価セットでDCM, DIT, IRSそれぞれの対話システムごとテストを行う。
  • DBDCデータが合計100対話(各21ターン)、DBDC2-devデータが合計150対話(各21ターン)で、学習に用いるのは(100+150)*0.75=188対話。

BERTへの入力

BERTへの最終発話と対話履歴の与え方は自明ではない。

今回は、BertのIsNext学習となるべく類似するようにと考え、以下の方針でいくこととした。

  • BERTでは各ターンの発話を再帰的に入力することはできないので、一度の入力で最終発話を対話履歴をすべて入力するものとする。
  • "[CLS]<ターン1発話>スペース<ターン2発話>....[SEP]<判定対象発話>[SEP]"をトークン数制限(512)まで詰め込む。
  • SegmentIDは前半の文脈発話では0、判定対象発話では1とする。
  • [CLS]の次に3次元に変換する線形変換をかませて3値分類する。

_truncate_seq_pair 関数の改変
run_classifier.pyで512より長いトークン系列を切り捨てる際には、tokena, tokenbで長いほうの後ろからトークンを捨てて行って収まるようにしてある。今回、文脈発話の後ろが捨てられると重要な部分が抜ける恐れがあるため、前方から捨てるように変更を加えた。

KLダイバージェンスの最適化

DBDCでは評価尺度としてラベル系統と分布距離系統の2つが採用されている。
今回は、分布距離系統を重視するためsoftmax_cross_entorpyの計算時に(正解ラベルではなく)ラベル分布を与えることでlossを計算する(soft target)。

といっても、tf.nn.softmax_cross_entropy_with_logits_v2のlabels引数に分布を与えるだけなので。難しいことはない。

\displaystyle{
loss = -\sum_{i}\sum_{j}p_{ij} \log{\hat{p}_{ij}} + \sum_{i}\sum_{j}p_{ij} \log{p_{ij}}
}

ハットがついているpはモデルの出力、ハットのないpは教師データの確率分布、iがデータを走り、jがクラスを走る。
tf.nn.softmax_cross_entropy_with_logits_v2が計算するのは前半のsum。 後半のsumは勾配に関与しないが分布が一致したときにlossが0になるためにつけている。これは結局のところpと\hat{p}のKLDだ。

関連部分のコードを抜き出すと次のようになる。

    # 学習済みモデルの出力(output_weights)をクラス分類の出力に線形変換
    logits = tf.matmul(output_layer, output_weights, transpose_b=True)
    logits = tf.nn.bias_add(logits, output_bias)
    
    # 後の判定処理のためにロジットを確率に変換
    probabilities = tf.nn.softmax(logits, axis=-1)
    
    # ロジットと正解分布(labels)からロスを計算
    per_example_loss = tf.nn.softmax_cross_entropy_with_logits_v2(labels, logits, axis=-1)
    per_example_loss += tf.reduce_sum(tf.log(tf.pow(labels, labels)), axis=-1)
    loss = tf.reduce_mean(per_example_loss)

データの準備と学習

  1. "dbdc_corpus.py -m fetch"を実行(データをダウンロードしてtsvに加工)。
  2. run_dbdc_classifier.pyを実行

run_dbdc_classifier.pyの引数は下記

python src/run_dbdc_classifier.py \
    --task_name=DBDC \
    --do_train=true \
    --do_eval=true \
    --data_dir=data/dbdc \
    --model_file=model/wiki-ja-mod.model \
    --vocab_file=model/wiki-ja-mod.vocab \
    --init_checkpoint=model/model.ckpt-1400000 \
    --max_seq_length=512 \
    --train_batch_size=4 \
    --learning_rate=2e-5 \
    --num_train_epochs=10 \
    --output_dir=model/dbdc_1

(model/wiki-ja-mod.model(.vocav)は後処理付きのsentencepieceだが、model/wiki-ja.model(.vocab)とほぼ同一)

スクリプトは以下のリポジトリに保存
https://github.com/iki-taichi/bert-japanese

主な関連スクリプト:

  • src/dbdc_corpus.py
  • src/run_dbdc_classifier.py

結果

同一条件で3回学習した結果、検証データに関して以下のようにラベル精度が得られた。

dbdc_1 dbdc_2 dbdc_3
3クラス分類精度 0.5881 0.5793 0.5910

dbdc_3の学習曲線:

f:id:lang-int:20190413005151p:plain
loss curve dbdc_3

検証データに関して、最も高い精度だったdbdc_3をテストセットで評価し、DBDC2のNTTCSrun2, HCUrun3と比較した。

DBDC2のサイトより引用したNTTCSrun2は3つの対話システムに関して一貫して高い性能を示し、学習したデータもほぼ同一であるため比較対象とした。 また同じところから引用したHCUrun3はRNNをベースにした深層モデルのアンサンブルであり、同じ深層モデルの目安として参照した。

各指標で最も良い値を太字にする。 なお、ラベル系統(-label)は数値が高いほどよく、分布距離系統(-distribution)はすうちが小さいほど元の分布に近く良い。

DCM-label

pretrained bert NTTCSrun2 HCUrun3
Accuracy 0.523 0.565 0.504
Precision (X) 0.556 (69/124) 0.523 0.520
Recall(X) 0.388 (69/178) 0.584 0.292
F-measure (X) 0.457 0.552 0.374
Precision (T+X) 0.837 (267/319) 0.875 0.910
Recall(T+X) 0.744 (267/359) 0.624 0.396
F-measure (T+X) 0.788 0.728 0.551

DCM-distribution

pretrained bert NTTCSrun2 HCUrun3
JS divergence (O,T,X) 0.083 0.085 0.100
JS divergence (O,T+X) 0.053 0.057 0.072
JS divergence (O+T,X) 0.052 0.054 0.061
Mean squared error (O,T,X) 0.046 0.044 0.055
Mean squared error (O,T+X) 0.055 0.056 0.074
Mean squared error (O+T,X) 0.055 0.054 0.067

DIT-label

pretrained bert NTTCSrun2 HCUrun3
Accuracy 0.62 0.655 0.624
Precision(X) 0.695 (216/311) 0.632 0.655
Recall(X) 0.818 (216/264) 0.943 0.818
F-measure(X) 0.751 0.757 0.727
Precision(T+X) 0.895 (374/418) 0.900 0.904
Recall(T+X) 0.908 (374/412) 0.891 0.842
F-measure(T+X) 0.901 0.910 0.872

DIT-distribution

pretrained bert NTTCSrun2 HCUrun3
JS divergence (O,T,X) 0.043 0.046 0.052
JS divergence (O,T+X) 0.025 0.030 0.033
JS divergence (O+T,X) 0.027 0.030 0.035
Mean squared error (O,T,X) 0.024 0.025 0.029
Mean squared error (O,T+X) 0.025 0.029 0.032
Mean squared error (O+T,X) 0.031 0.034 0.041

IRS-label

pretrained bert NTTCSrun2 HCUrun3
Accuracy 0.595 0.584 0.505
Precision(X) 0.635 (169/266) 0.554 0.534
Recall(X) 0.732 (169/231) 0.801 0.580
F-measure(X) 0.680 0.655 0.556
Precision(T+X) 0.793 (302/381) 0.791 0.757
Recall(T+X) 0.846 (302/357) 0.773 0.602
F-measure(T+X) 0.818 0.782 0.671

IRS-distribution

pretrained bert NTTCSrun2 HCUrun3
JS divergence (O,T,X) 0.091 0.101 0.118
JS divergence (O,T+X) 0.059 0.068 0.081
JS divergence (O+T,X) 0.059 0.070 0.082
Mean squared error (O,T,X) 0.050 0.054 0.066
Mean squared error (O,T+X) 0.061 0.065 0.080
Mean squared error (O+T,X) 0.062 0.074 0.091

感想

分布距離系統ではbertとNTTCSrun2の性能は互角か、bertのほうが若干良い傾向がでている。しかし、ラベル系統ではNTTCSrun2がDCM, DITのAccuracyで上回っている。

Bertのラベル系統の性能はHCUrun3より良いが劇的に変化しているわけではない。

分布距離系統が良いのはsoft targetの影響もあると考えられる(つまり、bertじゃなくてもsoft targetで分布距離を学習させたら改善する可能性がある)。

総合すると、対話のデータではないWikipediaの事前学習と対話のデータ(DBDC-dev, DBDC-eval, DBDC2-dev)だけでもそこそこの性能が出るが、分析に基づいた分類器を超すことはできなかったと言えそう。

もう少し対話に近いデータで事前学習する、マルチターン(あるいは3文以上)を考慮した構造にすることで、もう少し改善するかもしれない。