NTTドコモR&Dの技術ブログです。

マルチラベル分類モデル/多クラス分類モデルの閾値を最適な値に調整してみよう

こんにちは、ドコモの何と言います。
業務では自然言語処理に関する研究に取り組んでいます。
本記事は、ドコモアドベントカレンダー8日目の記事になります。
本記事では、BERTを使ったマルチラベル分類の閾値の調整に関する技術について紹介します。
マルチラベル分類とは、一つの文章に対して、複数のラベルを付与するような問題のことを指します。
一般的なマルチラベル分類では、まず、ラベル付けを行いたい文章を分類モデルに入力し、それぞれのラベルに対する確率値を計算します。
次に、モデルで計算した各ラベルの確率値に対して、あらかじめ決めておいた閾値を超えた場合にそのラベルを付与し、下回った場合には付与しないという2値分類を全てのラベルで行います。
この方法は、多くの場合でうまくいくのですが、次のような条件が重なるとうまくいきません。

  • 分類したいラベルの中に、互いに近い意味のラベルがいくつか存在する場合
  • 近い意味のラベル間に不均衡が生じている場合
    ※「ラベル間に不均衡が生じている」とは、あるラベルを付与されたデータ数は多く、他のあるラベルを付与されたデータ数は少ないという様なデータ内のラベル量に差がある状況を指します。

このような状況では、データ量が多いラベルの確率値が高くなり、データ量が少ないラベルの確率値が低くなってしまうという現象が起こります。
その結果、データ量の少ないラベルがデータ量の多いラベルに引っ張られ、上手く分類できないという問題が発生してしまいます。
解決方法として、次の3つの方法がよく用いられます。

  1. 近い意味のラベル同士を1つにまとめて新しいラベルにします。
  2. データ量が多いラベルが付与されているデータを削減します。
  3. データ量が少ないラベルのデータを新たに取得します。

1はこの中で最もよく用いられる解決方法だと思います。
ニュース記事に対するラベル付けを例にすると、「政治」と「経済」がお互いに近いラベルの場合、2つのラベルを結合して「社会」の様なもう少し上位のラベルを作成するイメージです。
この方法は、「社会」という少し荒い粒度の分類で十分だという場合には非常に有効な方法です。
2と3は不均衡データに対して行われる代表的な方法です。
3は可能であるなら行った方が良いですが、実際には新たなデータ取得が困難であったり、コスト的に出来ない場合が存在します。
2は入門書籍にもよく載っている方法で、データ量が多いラベルが付与されているデータを捨ててしまう事で不均衡を解消します。
この方法の欠点としては、データ量が多いラベルが付与されたデータを捨てて学習する必要があるため、そのデータに付与されている別のラベルにも影響が出てしまい、結果として他のラベルの分類性能にも影響が出る可能性があることです。
上で書いた3つの解決方法は、いずれも学習時に行う方法ですが、学習が終わったモデルに対してパラメータを調整する方法もあります。
はじめにも書きましたが、マルチラベル分類では、あらかじめ決めておいた閾値によってラベルをつけるかどうかの決定を行います。
その値を最適な値に調整する事でうまくラベル付けができる様になる場合があるのです。
そこで、この記事ではマルチラベル分類における閾値の調整方法をご紹介したいと思います。

事前準備

必要なライブラリをインストールします。

pip install numpy
pip install matplotlib

# datasetsライブラリ
pip install datasets

# 機械学習ライブラリ
pip install -U scikit-learn

# tensorflowで学習します。
pip install tensorflow==1.15.0

データセットについて

今回、go_emotionsデータセットを利用します。
go_emotionsデータセットには、27個の感情カテゴリまたはニュートラルのラベルが付けられた、慎重に精選された58,000件のRedditコメントが含まれています。生データだけでなく、あらかじめ定義されたトレーニング/検証/テストの分割を含む、データセットの小さく簡略化されたバージョンも含まれています。
下記のコードでgo_emotionsデータセットの小さく簡略化されたバージョンをダウンロードできます。

from datasets import load_dataset

dataset = load_dataset('go_emotions', 'simplified')

datasetの構成を確認します。

DatasetDict({
    train: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 43410
    })
    validation: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 5426
    })
    test: Dataset({
        features: ['text', 'labels', 'id'],
        num_rows: 5427
    })
})

datasetの中身は下記のような感じです。

{'text': ['To make her feel threatened',
  'Dirty Southern Wankers',
  "OmG pEyToN iSn'T gOoD eNoUgH tO hElP uS iN tHe PlAyOfFs! Dumbass Broncos fans circa December 2015.",
  'Yes I heard abt the f bombs! That has to be why. Thanks for your reply:) until then hubby and I will anxiously wait 😝',
  'We need more boards and to create a bit more space for [NAME]. Then we’ll be good.',
  'Damn youtube and outrage drama is super lucrative for reddit',
  'It might be linked to the trust factor of your friend.'],
 'labels': [[14], [3], [26], [15], [8, 20], [0], [27]],
 'id': ['ed7ypvh',  'ed0bdzj', 'edvnz26', 'ee3b6wu', 'ef4qmod', 'ed8wbdn', 'eczgv1o']}

上記のように、1文あたりに1個以上のラベルがついたデータセットとなっています。
ラベルはidが0から27の全部で28個ありました。
ラベル詳細:

# 各ラベルidの内容
emotions = {"0": "admiration", "1": "amusement", "2": "anger", "3": "annoyance", "4": "approval", "5": "caring", "6": "confusion", "7": "curiosity", "8": "desire", "9": "disappointment", "10": "disapproval", "11": "disgust", "12": "embarrassment", "13": "excitement", "14": "fear", "15": "gratitude", "16": "grief", "17": "joy", "18": "love", "19": "nervousness", "20": "optimism", "21": "pride", "22": "realization", "23": "relief", "24": "remorse", "25": "sadness", "26": "surprise", "27": "neutral"}

評価指標

マルチラベル分類問題でよく使われる評価指標は、Precision, Recall, F1-scoreです。
それぞれの計算式:

# 適合率
precision = tp / (tp + fp) 

# 再現率
recall = tp / (tp + fn)

# precisionとrecallの調和平均を計算します。
f1_score = 2 * (precision * recall) / (precision + recall)

結果を比較する時、precisionとrecall両方とも重視する場合は、f1_scoreで判断を行います。f1_score値は大きいほうがいいです。
今回の実験も上記の三つの指標を使って評価を行います。

precision_recall_curveについて

precision_recall_curveはsklearn.metrics中の関数のひとつです。
今回、この関数を使って閾値を調整します。
この関数を使って、異なる閾値(確率)のprecisionとrecallを計算できます。
具体例

import numpy as np
from sklearn.metrics import precision_recall_curve

y_trues = np.array([0, 0, 1, 1, 1, 0, 1])
y_scores = np.array([0.6, 0.7, 0.51, 0.8, 0.7, 0.56, 0.67])
precisions, recalls, thresholds = precision_recall_curve(y_trues, y_scores)

実行結果

>>> precisions
array([0.57142857, 0.5, 0.6, 0.75, 0.66666667, 1., 1.])
>>> recalls
array([1., 0.75, 0.75, 0.75, 0.5, 0.25, 0.])
>>> thresholds
array([0.59, 0.6 , 0.7 , 0.8 , 0.9 ])

この結果は例えばthresholdが0.59の時、precisionは0.5714、recallは1.0ということを表します。

  • threshold = 0.59の時、precision = 0.57 and recall = 1.0
  • threshold = 0.8の時、precision = 0.75 and recall = 0.75

precisionとrecall両方とも重視する指標として、f1_scoreを利用します。

  • thresholdが0.59の時、precisionは0.57でrecallが1.0なのでf1_scoreは0.73になります。
  • thresholdが0.8の時、precisionは0.75でrecallが0.75なのでf1_scoreは0.75になります。

threshold=0.8の時のf1_score = 0.75のほうが大きいので、閾値を0.8に設定するほうが良いです。

Bertで学習

この記事では閾値の調整の仕方を紹介したいので、google-researchのbertのソースコードを使って簡単に学習を行いました。
run_classifier.pyでは、multi-label分類問題に利用できないため、それをベースにして改造しました。
下記の順番で実験を行いました。

  • STEP1:データを作成します。
# 学習用のデータ
train_texts = [_ for _ in dataset['train']['text']]
train_ids = [_ for _ in dataset['train']['id']]
train_labels = [_ for _ in dataset['train']['labels']]

new_train_labels = []
for labels in train_labels:
    cur_labels = ["0"] * 28
    for label in labels:
        cur_labels[label] = "1"

    new_train_labels.append(cur_labels)

# テスト用のデータ
test_texts = [_ for _ in dataset['test']['text']]
test_ids = [_ for _ in dataset['test']['id']]
test_labels = [_ for _ in dataset['test']['labels']]

new_test_labels = []
for labels in test_labels:
    cur_labels = ["0"] * 28
    for label in labels:
        cur_labels[label] = "1"

    new_test_labels.append(cur_labels)

# 検証用のデータ
dev_texts = [_ for _ in dataset['validation']['text']]
dev_ids = [_ for _ in dataset['validation']['id']]
dev_labels = [_ for _ in dataset['validation']['labels']]

new_dev_labels = []
for labels in dev_labels:
    cur_labels = ["0"] * 28
    for label in labels:
        cur_labels[label] = "1"

    new_dev_labels.append(cur_labels)

# データをファイルに保存します。
with open("train.tsv", "w", encoding="utf-8") as file:
    for index, item in enumerate(train_ids):
        file.write(item + "\t" + train_texts[index] + "\t" + "\t".join(new_train_labels[index]) + "\n")

with open("test.tsv", "w", encoding="utf-8") as file:
    for index, item in enumerate(test_ids):
        file.write(item + "\t" + test_texts[index] + "\t" + "\t".join(new_test_labels[index]) + "\n")

with open("dev.tsv", "w", encoding="utf-8") as file:
    for index, item in enumerate(dev_ids):
        file.write(item + "\t" + dev_texts[index] + "\t" + "\t".join(new_dev_labels[index]) + "\n")
  • STEP2: run_classifier.pyからコードをコピーして、一個新しいファイルrun_multi_label_classifier.pyを作ります。
  • STEP3: run_multi_label_classifier.pyの中身を修正します。
    まず、一個新しいProcessorを追加します。
class GoEmotionsProcessor(DataProcessor):

    def get_train_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "train.tsv")), "train")

    def get_dev_examples(self, data_dir):
        """See base class."""
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "dev.tsv")), "dev")

    def get_test_examples(self, data_dir):
        return self._create_examples(
            self._read_tsv(os.path.join(data_dir, "test.tsv")), "test")

    def get_labels(self):
        """See base class."""
        return list(range(28))

    def _create_examples(self, lines, set_type):
        """Creates examples for the training and dev sets."""
        examples = []
        for (i, line) in enumerate(lines):
            guid = line[0]
            text_a = line[1]

            if set_type == "test":
                labels = [0] * 28
            else:
                labels = [int(_) for _ in line[2:]]

            examples.append(
                InputExample(guid=guid, text_a=text_a, label=labels))

        return examples

次は、convert_single_example関数の中身の修正です。

# 元のコード
# label_id = label_map[example.label]

# 下記のように修正します。
label_ids = [int(_) for _ in example.label]

...... 

feature = InputFeatures(
        input_ids=input_ids,
        input_mask=input_mask,
        segment_ids=segment_ids,
        label_id=labels_ids,
        is_real_example=True)
return feature

次は、file_based_input_fn_builder関数の中身の修正です。

name_to_features = {
        "input_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "input_mask": tf.FixedLenFeature([seq_length], tf.int64),
        "segment_ids": tf.FixedLenFeature([seq_length], tf.int64),
        "label_ids": tf.FixedLenFeature([28], tf.int64), # ラベル数を追加します。
        "is_real_example": tf.FixedLenFeature([], tf.int64),
    }

その後、create_model関数の中身の修正です。

# 元のコード:multi-class-classificationの場合は、softmaxを使います。
# probabilities = tf.nn.softmax(logits, axis=-1)
# log_probs = tf.nn.log_softmax(logits, axis=-1)
# one_hot_labels = tf.one_hot(labels, depth=num_labels, dtype=tf.float32)
# per_example_loss = -tf.reduce_sum(one_hot_labels * log_probs, axis=-1)

# 今回、multi-label-classificationなので、sigmoidに変更します。
probabilities = tf.nn.sigmoid(logits)
labels = tf.cast(labels, tf.float32)
per_example_loss = tf.nn.sigmoid_cross_entropy_with_logits(labels=labels, logits=logits)

return (loss, per_example_loss, logits, probabilities)

最後、main関数に感情分類のタスクを追加します。

processors = {
      "cola": ColaProcessor,
      "mnli": MnliProcessor,
      "mrpc": MrpcProcessor,
      "xnli": XnliProcessor,
      "go_emotions": GoEmotionsProcessor #追加したタスク
}
  • STEP4: bert modelをダウンロードします。今回BERT-Baseの方を利用します。
  • STEP5: 環境変数を設定します。
# base bert modelのパス
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

# DATA_DIRは学習データと検証データのパス
export DATA_DIR=/path/to/data
  • STEP6: 下記のコマンドで学習をスタートします。学習データモデルが/tmp/go_emotions_output/に出力されます。
python run_multi_label_classifier.py \
  --task_name=go_emotions \
  --do_train=true \
  --do_eval=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$BERT_BASE_DIR/bert_model.ckpt \
  --max_seq_length=128 \
  --train_batch_size=32 \
  --learning_rate=2e-5 \
  --num_train_epochs=3.0 \
  --output_dir=/tmp/go_emotions_output/
  • STEP7: 学習データモデルを使ってテストデータのラベルを予測します。テスト結果が/tmp/go_emotions_output/test_results.tsvに出力されます。
# base bert modelのパス
export BERT_BASE_DIR=/path/to/bert/uncased_L-12_H-768_A-12

# test dataのパス
export DATA_DIR=/path/to/test_data

# 学習済モデルのパス
export TRAINED_CLASSIFIER=/tmp/go_emotions_output/

python run_multi_label_classifier.py \
  --task_name=go_emotions \
  --do_predict=true \
  --data_dir=$DATA_DIR \
  --vocab_file=$BERT_BASE_DIR/vocab.txt \
  --bert_config_file=$BERT_BASE_DIR/bert_config.json \
  --init_checkpoint=$TRAINED_CLASSIFIER \
  --max_seq_length=128 \
  --output_dir=/tmp/go_emotions_output/

テストについて

予測したラベル結果(test_results.tsv)と実際ラベルの情報(test.tsv)を読み取ります。

# test data
with open("test.tsv", "r", encoding="utf-8") as file:
    test_data = file.readlines()

# 予測したラベル結果
with open("/tmp/go_emotions_output/test_results.tsv", "r", encoding="utf-8") as file:
    result = file.readlines()

# 実際のラベルのまとめ
true_labels = [[0] * len(test_data) for _ in range(28)]
for d_index, item in enumerate(test_data):
    items = item.replace("\n", "").split("\t")
    labels = items[2:]
    for l_index, label in enumerate(labels):
        if label == "1":
           true_labels[l_index][d_index] = 1

# 予測したラベル結果のまとめ
prob_scores = [[] * len(test_data) for _ in range(28)]
for item in result:
    item = item.replace("\n", "").split(",")
    labels = item[1:]
        for index, item in enumerate(labels):
          prob_scores[index].append(float(item))

下記の関数で各ラベルのprecision, recall, f1-scoreの算出ができます。

def calculate_precision_recall_and_f1(trues, probs, threshold):
    tp, fp, tn, fn = 0, 0, 0, 0

    for index, item in enumerate(trues):
        if probs[index] >= threshold:
            cur_pred_label = 1
        else:
            cur_pred_label = 0

        if item == 1 and cur_pred_label == 1:
            tp += 1
        elif item == 1 and cur_pred_label == 0:
            fn += 1
        elif item == 0 and cur_pred_label == 1:
            fp += 1
        else:
            tn += 1

    if (fp + tp) == 0:
        precision = 0.0
    else:
        precision = tp / (fp + tp)

    if (tp + fn) == 0:
        recall = 0.0
    else:
        recall = tp / (tp + fn)

    if (precision + recall) == 0:
        f1_score = 0.0
    else:
        f1_score = 2 * precision * recall / (precision + recall)

    return tp, fp, tn, fn, precision, recall, f1_score

調整前のテスト結果

全てのラベルの閾値を0.5に設定、各ラベルの精度を計算します。

ttp, tfp, ttn, tfn = 0, 0, 0, 0
for index in range(28):
    tp, fp, tn, fn, precision, recall, f1_score = calculate_precision_recall_and_f1(true_labels[index], prob_scores[index], 0.5)
    print(tp, fp, tn, fn, precision, recall, f1_score)

    ttp += tp
    tfp += fp
    ttn += tn
    tfn += fn

precision = ttp / (tfp + ttp)
recall = ttp / (ttp + tfn)
f1_score = 2 * precision * recall / (precision + recall)

print("全体的な精度情報:")
print(ttp, tfp, ttn, tfn, precision, recall, f1_score)

結果は下記になります。

label precision recall f1-score
admiration 0.707 0.651 0.678
amusement 0.794 0.848 0.821
anger 0.640 0.369 0.468
annoyance 0.571 0.0125 0.024
approval 0.651 0.239 0.35
caring 0.75 0.156 0.258
confusion 0.631 0.268 0.376
curiosity 0.564 0.451 0.501
desire 0.72 0.217 0.333
disappointment 0.0 0.0 0.0
disapproval 0.551 0.202 0.296
disgust 0.889 0.260 0.403
embarrassment 0.0 0.0 0.0
excitement 0.947 0.175 0.295
fear 0.854 0.449 0.588
gratitude 0.951 0.889 0.919
grief 0.0 0.0 0.0
joy 0.748 0.516 0.610
love 0.817 0.824 0.820
nervousness 0.0 0.0 0.0
optimism 0.721 0.430 0.539
pride 0.0 0.0 0.0
realization 0.0 0.0 0.0
relief 0.0 0.0 0.0
remorse 0.582 0.571 0.577
sadness 0.724 0.404 0.519
surprise 0.691 0.397 0.505
neutral 0.741 0.534 0.621
total 0.740 0.448 0.558

調整後のテスト結果

precision_recall_curveを利用して、各ラベルのPrecision-Recall曲線を確認します。
例:ラベルamusementのPrecision-Recall曲線

import matplotlib.pyplot as plt
from sklearn.metrics import precision_recall_curve
    
precisions, recalls, thresholds = precision_recall_curve(true_labels[1], prob_scores[1])

fig, ax = plt.subplots()
ax.set_title('Precision-Recall Curve')
ax.set_ylabel('Precision')
ax.set_xlabel('Recall')
ax.plot(recalls, precisions)
plt.show()

このラベルの一番いい結果を得るため、上の図でf1_score(precisionとrecallの調和平均)が最大値に達する点(閾値)を見つける必要があります。
どうやって見つけるか、具体的なやり方は下記になります。

  • STEP1: precision_recall_curveでこのラベルのprecisions, recallsとthresholdsの情報を得ます。
  • STEP2: precisionsとrecallsの情報を利用して各f1_scoreを算出します。
  • STEP3: 算出されたf1_scoreの中の値が一番最大のindexを探し出します。
  • STEP4: thresholdsの中の同じindexの値はこのラベルの一番いい閾値です。

上記をコード化すると、下記のような感じです。

# ラベル単位で一番いい閾値を探し出します。
def find_out_the_best_threshold(trues, probs):
    import numpy as np
    from sklearn.metrics import precision_recall_curve

    y_trues = np.array(trues)
    y_scores = np.array(probs)

    precisions, recalls, thresholds = precision_recall_curve(y_trues, y_scores)
    min_length = min(min(len(precisions), len(recalls)), len(thresholds))

    # 全てのthresholdのf1-scoreを算出します。
    max_f1_score, max_f1_index = 0.0, 0
    for index, item in enumerate(precisions[:min_length]):
        if (item + recalls[index]) == 0:
            cur_f1_score = 0.0
        else:
            cur_f1_score = 2 * item * recalls[index] / (item + recalls[index])

        if cur_f1_score > max_f1_score:
            max_f1_score = cur_f1_score
            max_f1_index = index

    return thresholds[max_f1_index]

全てのラベルを自分の一番いい閾値に設定して、もう一回精度を計算します。

ftp, ffp, ftn, ffn = 0, 0, 0, 0
for index in range(28):
    threshold = find_out_the_best_threshold(true_labels[index], prob_scores[index])
    tp, fp, tn, fn, precision, recall, f1_score = calculate_precision_recall_and_f1(true_labels[index], prob_scores[index], threshold)
    print(tp, fp, tn, fn, precision, recall, f1_score)

    ftp += tp
    ffp += fp
    ftn += tn
    ffn += fn

precision = ftp / (ffp + ftp)
recall = ftp / (ftp + ffn)
f1_score = 2 * precision * recall / (precision + recall)
print("調整した後の精度情報:")
print(ftp, ffp, ftn, ffn, precision, recall, f1_score)

調整した後の結果です。

label threshold 調整前のprecision 調整後のprecision 調整前のrecall 調整後のrecall 調整前のf1-score 調整後のf1-score
admiration 0.364 0.707 0.680 0.651 0.728 0.678 0.703
amusement 0.483 0.794 0.791 0.848 0.860 0.821 0.824
anger 0.464 0.640 0.590 0.369 0.429 0.468 0.497
annoyance 0.267 0.571 0.326 0.0125 0.413 0.024 0.364
approval 0.232 0.651 0.432 0.239 0.430 0.35 0.431
caring 0.312 0.75 0.516 0.156 0.356 0.258 0.421
confusion 0.331 0.631 0.543 0.268 0.412 0.376 0.468
curiosity 0.200 0.564 0.460 0.451 0.820 0.501 0.589
desire 0.102 0.72 0.519 0.217 0.506 0.333 0.512
disappointment 0.200 0.0 0.381 0.0 0.265 0.0 0.313
disapproval 0.246 0.551 0.403 0.202 0.438 0.296 0.420
disgust 0.208 0.889 0.514 0.260 0.463 0.403 0.487
embarrassment 0.066 0.0 0.577 0.0 0.405 0.0 0.476
excitement 0.201 0.947 0.52 0.175 0.379 0.295 0.438
fear 0.278 0.854 0.651 0.449 0.692 0.588 0.671
gratitude 0.311 0.951 0.936 0.889 0.915 0.919 0.925
grief 0.015 0.0 0.019 0.0 0.167 0.0 0.034
joy 0.396 0.748 0.705 0.516 0.565 0.610 0.628
love 0.457 0.817 0.807 0.824 0.845 0.820 0.825
nervousness 0.070 0.0 0.5 0.0 0.217 0.0 0.303
optimism 0.333 0.721 0.712 0.430 0.532 0.539 0.609
pride 0.010 0.0 0.010 0.0 0.313 0.0 0.019
realization 0.11 0.0 0.256 0.0 0.159 0.0 0.200
relief 0.025 0.0 0.182 0.0 0.182 0.0 0.182
remorse 0.295 0.582 0.625 0.571 0.804 0.577 0.703
sadness 0.271 0.724 0.542 0.404 0.538 0.519 0.540
surprise 0.405 0.691 0.623 0.397 0.468 0.505 0.534
neutral 0.206 0.741 0.595 0.534 0.834 0.621 0.695
total 0.740 0.541 0.448 0.649 0.558 0.590

まとめ

今回は、分類モデルの精度向上のために、precision-recall curveの値から、分類の閾値の調整を行いました。

  • 一般的なモデル学習の精度改善方法として、学習データの調整などの方法ありますが、出力結果の確率情報を活用する方法でも、精度が向上することを確認できました。
  • しかしながら、今回の方法では、解決できない問題があります。それは、データの少ないラベルの精度向上です。学習データが少ないため、ラベルの確率値の調整が精度向上につながりませんでした。

参考・引用