オーバーサンプリングの問題点 - smote VS GICaPS
今回は、不均衡データでのクラス分類の方法および問題点について投稿します。
不均衡データと主要な分析手法
不均衡データとは、クラス分類したい目的変数に大きな偏りがあるデータを指します。例としてよく挙げられるのが、ガンの予測です。100人がガン検診を受診した場合、実際にガンである患者が5人のみだったとします。この結果を予測する場合、全患者が陰性であると予測すると正解率は95%と高確率である一方、陽性患者を誰一人当てることができていません。マイノリティ側(この例では、陽性)の予測精度に関心がある場合、この結果は好ましくありません。また、機械学習モデルを用いる上でも、マジョリティ側の予測精度を向上させたほうが、全体の精度が向上するため、マイノリティ側の予測精度がおざなりになるという問題があります。
そこで、不均衡データを均衡データに変換してしまおうという考えが提案されています。主に、アンダーサンプリングとオーバーサンプリングという手法です。
アンダーサンプリングとは、マイノリティ側のデータ数と一致するようにマジョリティ側のデータを削減する方法です。上記のガンの陽性・陰性の予測例では、マイノリティ側(陽性)は5人ですので、マジョリティ側(陰性)も5人に合わせ全体を10人としデータを再構築します。アンダーサンプリングの問題は、マイノリティ側のデータ数が非常に少ない場合、全体のデータ数が激減してしまうことです。
反対に、オーバーサンプリングはマジョリティ側のデータ数にマイノリティ側のデータ数を合わせる方法です。つまり、マジョリティ側(陰性)は95人ですので、マイノリティ側(陽性)を95人になるようにし、全体で190のデータ数を構築します。アンダーサンプリングでは、データを削減するという方法だったため、全データが実際のデータでしたが、オーバーサンプリングの場合、データを増やすため擬似データを使用することになります。不均衡の傾向が顕著な場合、擬似データの割合が大きくなってしまいます。
どちらの手法にも課題はありますが、疑似データをより正確に生成することが可能となれば、オーバーサンプリングは今後不均衡データを扱う際に主流となる可能性が高いのではないかと考えています。そのため、本記事では、現状のオーバーサンプリングの手法の考察と最新の研究内容について紹介したいと思います。
smote : Synthetic Minority Over-sampling Technique
smoteとは、現在オーバーサンプリングを行う上で主要な手法です。というのも、pythonのモジュールにあるため、簡単に実装することができるためです。私も最近になって、初めて不均衡データというものを知り実装してみましたが非常に簡単でした。一方、モデルの予測精度は向上しませんでした。ビッグデータを扱う上では、一つ一つのデータというのをじっくり見る機会というのは少ないと思います。そのため、簡単に実装できるモジュールを安易に利用するのではなく、その原理と問題点を理解することが求められます。
smoteの原理
smoteの原理は公式のページを参照して頂くことが最も間違いありません。今回は、公式ページを参照して簡単に紹介します。
newサンプルを生成する仕組みは上図に示すように非常に単純です。2つのマイノリティデータを選択し、2点間の線上に乱数λ(0<λ<1)を振るだけです。このアルゴリズムをマイノリティデータがマジョリティデータと一致するまで繰り返すのです。
smoteの実装
実際に簡単な例を作成し実装します。必要なモジュールをimportします。
import pandas as pd import numpy as np import seaborn as sns import matplotlib.pyplot as plt sns.set_context('talk') import random from imblearn.over_sampling import SMOTE # smoteのモジュールとなります
不均衡データは、クラス1が300個、クラス0が10個となるように作成しました。また、説明変数は2つ用意しました。
# クラス0のデータを作成 class0_num = 10 class0_a0 = [random.randint(200, 500) for i in range(class0_num)] class0_a1 = [random.randint(200, 500) for i in range(class0_num)] class0 = [0 for i in range(class0_num)]
# クラス1のデータを作成 class1_num = 300 class1_a0 = [random.randint(0, 500) for i in range(class1_num)] class1_a1 = [random.randint(0, 500) for i in range(class1_num)] class1 = [1 for i in range(class1_num)] # クラス0と1を結合 class1_a0.extend(class0_a0) class1_a1.extend(class0_a1) class1.extend(class0)
作成したデータは下図のようになります。
lt.scatter(class1_a0, class1_a1, color='olive') plt.scatter(class0_a0, class0_a1, color='tomato') plt.show()
赤のデータがクラス0のマイノリティを表しています。今回のデータでは、10:300なので、クラス0のデータを290個増やします。
smote = SMOTE()
origin = pd.DataFrame(data = [class1_a0, class1_a1]).T
origin_smote, class_smote = smote.fit_resample(origin, class1) # smoteの引数は説明変数、目的変数となります
上記コードを実行するだけで、origin_smote, class_smoteとして均衡データが作成されます。実際に散布図で確認します。
index_0 = [i for i, x in enumerate(class_smote) if x == 0] index_1 = [i for i, x in enumerate(class_smote) if x == 0] plt.scatter(class1_a0, class1_a1, color='olive') plt.scatter(origin_smote.loc[index_0][0], origin_smote.loc[index_0][1], color= 'navy') plt.scatter(origin.loc[class1_num:][0], origin[class1_num:][1], color='crimson') plt.title('all data') plt.show() plt.scatter(origin_smote.loc[index_0][0], origin_smote.loc[index_0][1], color= 'navy') plt.scatter(origin.loc[class1_num:][0], origin[class1_num:][1], color='crimson') plt.title('class0 data') plt.show()
青で示されるプロットがsmoteにより生成された擬似データとなります。all dataで生成データの傾向を観察すると、クラス0のデータの範囲内で分散しているように見えます。しかし、クラス0のみで観察した場合、少しいびつなデータとなっていると感じるのではないでしょうか。というのもドーナッツ型のようなデータ補完がされているためです。2点の赤いプロットを結ぶ線上にしかデータは生成されないため、originalデータの位置よって、いびつなデータとなる可能性があります。
同じ条件で乱数を振り、複数回データを生成した結果を載せます。
何度か繰り返しましたがどの場合においても、いびつな傾向が感じられました。これは、優先的に決定される2点の赤プロットが近しい2点であるためです。そのため、対角線上にデータを生成するのではなく、外枠にデータが生成される傾向となります。
smoteの問題点
データの外枠から補完されることの問題点は、説明変数間の相関が変化(低下)してしまうことにあります。実際にdata3でのオリジナルデータとsmoteデータの相関係数を計算すると大きく低下していることとがわかります。
print('original dataの相関係数') print(origin.loc[class1_num:].corr()) print() print('smote dataの相関係数') print(origin_smote.loc[class1_num:].corr())
[out] origin dataの相関係数 0 1 0 1.00000 -0.25939 1 -0.25939 1.00000 smote dataの相関係数 0 1 0 1.000000 -0.090797 1 -0.090797 1.000000
このように、一見上手くデータを補完することができているように見えたとしてもマイノリティ側のデータ傾向は大きく変わっている可能性があります。
説明変数の次元数と相関係数の関係
上記までの例では、smoteを簡単に理解するため説明変数は2次元としていました。しかし、実際に扱うデータは高次元であるため、その影響を確認します。
data3に説明変数と一つ追加し3次元とします。
# a2という説明変数を追加 class0_a2 = [random.randint(200, 500) for i in range(class0_num)] class1_a2 = [random.randint(0, 500) for i in range(class1_num)] class1_a2.extend(class0_a2) origin_3 = pd.DataFrame(data = [class1_a0, class1_a1, class1_a2]).T # smoteを実行 origin3_smote, class_smote = smote.fit_resample(origin_3, class1) # 相関行列を出力 print('original dataの相関係数') print(origin_3.loc[class1_num:].corr().round(3)) print() print('smote dataの相関係数') print(origin3_smote.loc[class1_num:].corr().round(3))
[out] original dataの相関係数 0 1 2 0 1.000 -0.259 -0.080 1 -0.259 1.000 -0.008 2 -0.080 -0.008 1.000 smote dataの相関係数 0 1 2 0 1.000 -0.413 -0.049 1 -0.413 1.000 0.145 2 -0.049 0.145 1.000
2次元データでは説明変数a0, a1の相関係数はsmoteにより低下していましたが、3次元データの場合大きく増加しました。
実際に両データを比較します。
左が説明変数が2次元である場合、右が説明変数が3次元である場合のsmoteの実行結果です。2つの図からわかるように、たしかに3次元の場合がより外枠だけでなく、中心も満遍なく補完されています。これは3次元となったことで、smoteにより結ばれる2点の選択肢が増えたためです。
その傾向を確かめるため、説明変数を10次元とし検証します。
# +7次元分のデータを作成 origin_p = origin_3 for p in range(7): class0_ap = [random.randint(200, 500) for i in range(class0_num)] class1_ap = [random.randint(0, 500) for i in range(class1_num)] class1_ap.extend(class0_ap) class1_ap_df = pd.DataFrame({p+3:class1_ap}) origin_p = pd.concat([origin_p, class1_ap_df], axis = 1) originp_smote, class_smote = smote.fit_resample(origin_p, class1)
print('original dataの相関係数') print(origin_p.loc[class1_num:].corr().round(3)) print() print('smote dataの相関係数') print(originp_smote.loc[class1_num:].corr().round(3))
[out] original dataの相関係数 0 1 2 3 4 5 6 7 8 9 0 1.000 -0.259 -0.080 0.265 -0.071 0.233 -0.123 -0.529 0.030 0.187 1 -0.259 1.000 -0.008 0.387 0.379 0.300 -0.210 0.095 -0.318 -0.273 2 -0.080 -0.008 1.000 -0.382 -0.370 0.558 -0.271 -0.508 -0.501 0.493 3 0.265 0.387 -0.382 1.000 0.302 0.247 0.175 -0.032 0.209 -0.108 4 -0.071 0.379 -0.370 0.302 1.000 -0.424 0.715 0.036 -0.011 -0.677 5 0.233 0.300 0.558 0.247 -0.424 1.000 -0.491 -0.600 -0.599 0.385 6 -0.123 -0.210 -0.271 0.175 0.715 -0.491 1.000 0.066 0.039 -0.340 7 -0.529 0.095 -0.508 -0.032 0.036 -0.600 0.066 1.000 0.391 -0.183 8 0.030 -0.318 -0.501 0.209 -0.011 -0.599 0.039 0.391 1.000 -0.143 9 0.187 -0.273 0.493 -0.108 -0.677 0.385 -0.340 -0.183 -0.143 1.000 smote dataの相関係数 0 1 2 3 4 5 6 7 8 9 0 1.000 -0.375 0.026 0.136 -0.160 0.179 -0.135 -0.569 0.105 0.233 1 -0.375 1.000 -0.013 0.406 0.329 0.333 -0.196 0.123 -0.368 -0.263 2 0.026 -0.013 1.000 -0.415 -0.510 0.633 -0.428 -0.561 -0.517 0.619 3 0.136 0.406 -0.415 1.000 0.355 0.156 0.229 0.059 0.218 -0.212 4 -0.160 0.329 -0.510 0.355 1.000 -0.502 0.775 0.182 0.033 -0.754 5 0.179 0.333 0.633 0.156 -0.502 1.000 -0.585 -0.620 -0.619 0.444 6 -0.135 -0.196 -0.428 0.229 0.775 -0.585 1.000 0.193 0.101 -0.475 7 -0.569 0.123 -0.561 0.059 0.182 -0.620 0.193 1.000 0.409 -0.263 8 0.105 -0.368 -0.517 0.218 0.033 -0.619 0.101 0.409 1.000 -0.149 9 0.233 -0.263 0.619 -0.212 -0.754 0.444 -0.475 -0.263 -0.149 1.000
相関行列から、10次元でのa0とa1のsmote dataの相関係数は3次元だった場合と比較し低下していますが、2次元smote dataと比較すると相関係数は増加しています。この結果から、説明変数の次元を増やした場合、相関係数の振れ幅がoriginal dataの相関係数に収束していくと考えられます。これは、下図に示すように次元が増えることでマイノリティクラスの領域内が次第に埋められていくためだと考えます。
オーバーサンプリングの研究動向
上記までのようにsmoteはシンプルなアルゴリズムである反面、補完データに偏りやマイノリティクラス内での不均衡が生じ得ます。そこで、「A Method for Handling Multi-class Imbalanced Data byGeometry based Information Sampling and Class Prioritized Synthetic Data Generation (GICaPS)」の論文で、データ生成がより自然(均等)となるようなアルゴリズムについて述べられていますので紹介します。
GICaPS (Geometry based Information Sampling and Class Prioritized Synthetic Data Generation)
GICaPSとは、データの幾何学情報(分布)に基づきサンプリングを行う手法です。アンダーサンプリングとオーバーサンプリングを組み合わせた手法が論文内で提案されています。しかし、今回はオーバーサンプリングに着目しそのアルゴリズムを理解します。本論文でsmoteの問題点として、クラス間の干渉の影響が考慮されていないということを挙げています。これは、smoteのアルゴリズムからもわかるようにマイノリティ側のデータ配置のみしか考慮していません。つまり、マジョリティとマイノリティの境界と思われる周辺に擬似データが生成されると機械学習モデルの精度を低下させる危険があります。 簡単な例を下に示します。
この例では、クラス1は0〜300の範囲、クラス0では200〜500としデータを生成しました。そのため、originalデータから確認できるようにクラス0と1ではデータ傾向は大きく異なります。しかし、smoteにより擬似データを生成するとその境界は考慮されていないため、クラスが重なる領域にoriginalでは1個のみでしたが、35個のデータが生成されてしまいました。これを回避する手法として、no man’s landという概念を導入しています。
no man’s landを用いたオーバーサンプリング
no man’s landとは、翻訳すると誰のものでもない土地を意味します。今回の場合では、class0, 1どちらにも属さない領域となります。no man’s landの領域の決定方法は論文内で下図のように説明されています。
図中の(a)が最もシンプルな例です。k近傍法で切り取られた空間Vで、マイノリティクラスの2点(Xm, Xv)を結んだ際に、マジョリティクラスQが存在する場合、長方形(高さ:線分に最も近いQの線分との垂直距離、幅:線分間にQが存在する幅)がNo man's landを表します。
論文で紹介されているNo man's landを求めるより一般化した方法を紹介します。下図では、マジョリティクラスの任意の点を線分ab上に投射した際の長さを求める方法が述べられています。
実装し、線分abとno man's landのlを求めます。
# a, b点を決定し、ベクトルabを求める a = np.array([[class0_a0[1]], [class0_a1[1]]]) b = np.array([[class0_a0[0]], [class0_a1[0]]]) ab = b - a # t1, t2を決定し、at1, at2ベクトルを求める t1 = np.array([[class1_a0[157]], [class1_a1[157]]]) t2 = np.array([[class1_a0[46]], [class1_a1[46]]]) at1 = t1 - a at2 = t2 - a # 論文の式から、pt1, pt2を求める pt1 = (np.dot(ab.T, at1) * ab) / np.dot(ab.T, ab) pt2 = (np.dot(ab.T, at2) * ab) / np.dot(ab.T, ab) apt1 = pt1 + a apt2 = pt2 + a # 散布図を作成 plt.scatter(class1_a0, class1_a1, color='olive', label = 'class1') plt.scatter(origin.loc[class1_num:][0], origin[class1_num:][1], color='crimson', label = 'class0') plt.title('smote data') plt.legend(loc="upper left", bbox_to_anchor=(1.02, 1.0,), borderaxespad=0) plt.scatter(a[0], a[1], color = 'fuchsia') plt.text(a[0], a[1], 'a') plt.scatter(b[0], b[1], color = 'fuchsia') plt.text(b[0], b[1], 'b') plt.scatter(t1[0], t1[1], color = 'lime') plt.text(t1[0], t1[1], 't1') plt.scatter(t2[0], t2[1], color = 'lime') plt.text(t2[0], t2[1], 't2')
plt.scatter(a[0], a[1], color = 'fuchsia') plt.text(a[0], a[1], 'a') plt.scatter(b[0], b[1], color = 'fuchsia') plt.text(b[0], b[1], 'b') plt.scatter(t1[0], t1[1], color = 'lime') plt.text(t1[0], t1[1], 't1') plt.scatter(t2[0], t2[1], color = 'lime') plt.text(t2[0], t2[1], 't2') plt.annotate('', xy = (apt2[0], apt2[1]), xytext = (a[0], a[1]), arrowprops = dict(width = 0.5, headwidth = 7, color = 'black')) plt.annotate('', xy = (apt1[0], apt1[1]), xytext = (a[0], a[1]), arrowprops = dict(width = 0.5, headwidth = 7, color = 'black')) plt.text(apt1[0], apt1[1], 'pt1') plt.text(apt2[0]-50, apt2[1], 'pt2') plt.xlim([0, 500]) plt.ylim([0, 500]) plt.title('pt1 and pt2')
上記図から、pt1およびpt2を求めることができました。しかし、今回選択したt2は線分ab上で垂直に交わりません。no man' landの対象となり得るのは線分ab上で垂直に交わるt1のようなプロットです。これを満たす条件として、下記2つを考えます。
1.線分aptの長さが線分abより短い
2.線分aptと線分abが同方向を向く
また、点aに対しマジョリティクラス全てを点tとみなすわけではありません。論文の通り、Vmという空間を定義し、Vm内のプロットに対しno man's landを決定していきます。上記例では、t2を含めVmを定義した場合、データ数が膨大となりオーバーサンプリング能力と計算量が見合わなくなってしまいます。そのため、KNN(K-近傍法)を用いて適切な範囲のVmを設定する必要があります。
これらを踏まえた上でGICAPsを実装します。
GICaPSの実装
GICaPSを実装するため、まずオーバーサンプリングの対象となるデータ(origin)を用意します。上記までの例で使用したデータを今回も用います。
実装したプログラムが下記となります。
まず、k=5のKNNを用い全マイノリティデータに空間Vmを定義しました。次に、k=5のデータの中に含まれるマイノリティクラスのデータを除外しました。(no man's landに関係するのはマジョリティデータのみであるためです)
そして残ったデータで線分ab上で垂直に交わるデータがtとなります。
no man's landのlは、tのうち線分atの距離が最大となるtmaxと最低となるtminを距離マトリックスから抽出し線分ptminと線分ptmaxの差とします。
この作業を全てのマイノリティクラスのデータに対し適用することで、全てのマイノリティクラス間(線分ab上)のlを求めることができます。線分abに対する線分lの大きさの比がその線分ab間でオーバーサンプリングするデータ数の割合となります。
このようにしてGICaPSでは、クラス間のデータ位置を考慮しオーバーサンプリングすることが可能となります。
H = 300 # オーバーサンプリングするデータ数 p = 1 # no man's landの範囲調整パラメータ # クラスごとのインデックスを取得 index_0 = [i for i, label in enumerate(origin['class_label']) if label == 0] index_1 = [i for i, label in enumerate(origin['class_label']) if label == 1] # class0のデータを抽出 origin_0 = origin[origin['class_label'] == 0] origin_0np = np.array(origin_0.drop('class_label', axis = 1)) # class1のデータを抽出 origin_1 = origin[origin['class_label'] == 1] origin_1np = np.array(origin_1.drop('class_label', axis = 1)) # オーバーサンプリングの対象となるデータ origin_list = np.array(origin.drop('class_label', axis = 1)) # 全てのデータ間の距離マトリックスを作成 distance_list = distance.pdist(origin_list) distance_matrix = squareform(distance_list) # knn空間(k=5)を作成 knn_list = [] # class0でknn5個を抽出したもの S_list = [] # knnの最大距離、最小距離の要素を抽出 for i in range(len(index_0)): # i = a 少数クラスのデータ数分繰り返す idx_6 = np.argpartition(distance_matrix[index_0[i]], 6)[:6] # 行で距離が短い下位6つの要素のインデックスを取得 idx_5 = idx_6[~(idx_6 == i)] # matrixの対角線は0なので除外する knn_list.append(idx_5) # knn内のクラス1のみを抽出 knn_list_majority = [] # for knn_list_row in knn_list: knn_majority = [index for i, index in enumerate(knn_list_row) if index not in index_0] # 少数派を除外 knn_list_majority.append(knn_majority) Sa_list = [] for j in range(len(index_0)): # j = b a = origin_0np[i].reshape(2, 1) b = origin_0np[j].reshape(2, 1) ab = b - a ab_length = np.linalg.norm(ab) # tmax, tmin を求める distance_row = distance_matrix[index_0[i]] knn_distance_row = distance_row[knn_list_majority[i]] if len(knn_distance_row) > 0: # knnにクラス1が含まれている場合 idx_list = np.argpartition(knn_distance_row , kth = len(knn_distance_row)-1) ab_element = [] # 線分ab上にあるtを抽出 for idx in idx_list: knn_idx = knn_list_majority[i][idx] t = origin_list[knn_idx].reshape(2, 1) at = t - a at_internal = np.dot(at.T, ab) at_length = np.linalg.norm(at) theata = np.arccos(at_internal/(at_length * ab_length)) if (0 < theata < math.pi /2) and at_length / ab_length < 1: # 線分atと線分abのなす角が0以上90度以下(abと同じ向き)かつ絶対値が1以下(abの線分内)の場合 ab_element.append(knn_idx) if len(distance_matrix[i][ab_element]) > 1: # knnにクラス1が2つ以上含まれている場合 (1つしかないとl=0となりnomanslandなくなる) idx_max = np.argpartition(distance_matrix[i][ab_element], kth = len(distance_matrix[i][ab_element])-1)[-1] # knnの範囲内で距離が最も大きいmajority_indexを取得 idx_min = np.argpartition(distance_matrix[i][ab_element], kth = len(distance_matrix[i][ab_element])-1)[0] knn_idx_max = ab_element[idx_max] knn_idx_min = ab_element[idx_min] tmax = origin_list[knn_idx_max].reshape(2, 1) tmin = origin_list[knn_idx_min].reshape(2, 1) atmax = tmax - a atmin = tmin -a ptmax = (np.dot(ab.T, atmax) * ab) / np.dot(ab.T, ab) ptmin = (np.dot(ab.T, atmin) * ab) / np.dot(ab.T, ab) aptmax = ptmax + a aptmin = ptmin + a l = np.linalg.norm(aptmax - aptmin) # 線分ab上のno man's land の区間 S = p * (ab_length - l) # 線分ab上のオーバーサンプリング対象区間の大きさ Sa_list.append(S) # aに対する全てのbでのSをSa_listに追加 else: S = p * ab_length #tmax, tminが存在しない場合は線分ab自体がSとなる Sa_list.append(S) # Sa_listは任意のaに対するSを格納 else: S = p * ab_length Sa_list.append(S) S_list.append(Sa_list) # S_lsitは全てのa, bのSを格納 Sa_sum_list = [] # ΣVmSを算出 for i in S_list: Sa_sum_list.append(sum(i)) Nm_list = [] # 任意のaでいくつオーバーサンプリングするか算出 S_sum = sum(Sa_sum_list) # ΣxmΣVmSを算出 for Sa_sum in Sa_sum_list: Nm = H * Sa_sum / S_sum Nm_list.append(Nm) Nv_list = [] # 任意の線分abでいくつオーバーサンプリングするか計算 for i in range(10): Nv_list_a = [] for j in range(10): Nv = (Nm_list[i] * S_list[i][j]) / Sa_sum_list[i] Nv_list_a.append(Nv) Nv_list.append(Nv_list_a) X_new = [] # 新しいデータを生成 for i in range(10): for j in range(10): Nv = Nv_list[i][j] N = round(Nv) a = np.array([[class0_a0[i]], [class0_a1[i]]]) b = np.array([[class0_a0[j]], [class0_a1[j]]]) ab = b - a for y in range(N): r = np.random.random(a.shape[0]) rm = np.array([[r[0]], [r[1]]]) X = a + y * (ab / Nv) + rm X_new.append(X)
GICaPSとsmoteの比較
マイノリティデータとマジョリティデータのオリジナルデータを下図に示します。
オリジナルデータに対しsmoteとGICaPSでそれぞれオーバーサンプリングした結果が下図になります。
GICaPSとsmoteを重ねた図が下図になります。
上図より、GICaPSでは、クラス間のデータ位置が考慮されているため、マジョリティクラスが存在する範囲のオーバーサンプリング数はsmoteより少ないことがわかります。また、smoteでは、マイノリティクラスデータを結ぶ線分上かつマイノリティデータ間でKNNが適用されるため、オーバーサンプリングした結果が不均一となっていました。しかし、GICaPSでは、全てのマイノリティデータ間でオーバーサンプリングを行うため不均一性が和らいでいることが確認できます。
以上より、GICaPSはオーバーサンプリング手法として有効であると考えます。しかし、予測に疑似データを使用する危険性を無くせたわけではありません。オリジナルデータおよびオーバーサンプリング後のデータを確認することが必要だと考えます。