-
[딥러닝] imbalanced data 학습Deep learning 2022. 9. 21. 23:38
Imbalanced data 학습 방법
기계학습 알고리즘들은 각 클래스의 비율이 비슷한 상황을 가정하기 때문에, 클래스가 불균형한 dataset의 경우 전체적인 데이터에 대해 제대로 학습하지 못하고 큰 비중을 차지하는 클래스에 편향되어 학습한다. 그 결과 정확도는 높으나 정작 원하는 항목에 대해서는 분류해내지 못하는 클래스 불균형 현상이 발생된다. 예를 들면 병원에서 암 진단검사를 받는 환자, 네트워크의 침입탐지, 은행에서의 사기 탐지와 같은 dataset에서 클래스 불균형 현상이 발생한다.
클래스의 불균형을 해소하기 위한 기법으로는 Data Resampling, K-fold cross validation, Weight balancing등이 있다.
- Resample the training set
불균형한 데이터셋에서 균형을 잡는 접근 방식은 under-sampling, over-sampling이다.
Under-sampling은 양이 많은 클래스의 크기를 줄여 dataset의 균형을 유지한다. 이 방법은 데이터의 양이 충분히 많을 때 사용한다. 극도로 적은 클래스의 모든 샘플을 유지하고 양이 많은 클래스의 동일한 수의 샘플을 무작위로 선택하는 방법으로 균형 잡힌 새 dataset을 구성할 수 있다. 반면에 데이터의 양이 부족한 경우엔 Over-sampling을 사용한다. 풍부한 샘플을 제거하는 대신 희귀 샘플의 크기를 늘려 dataset의 균형을 맞추는 방법이다.
이 프로젝트에서는 위 두가지 방법을 조합하여 dataset을 구성하였다.
- K-fold cross validation
불균형 문제를 해결하기 위해 over-sampling을 사용할 경우 cross-validation을 적절하게 적용해야 한다. Over-sampling은 희귀 샘플의 분포 함수를 기반으로 새로운 random data를 생성하기 위해 bootstrap sampling을 적용한다. 만약 cross-validation이 over-sampling 이후에 적용하면 이는 우리의 모델을 특정 artificial bootstrapping 결과에 맞추는 결과가 된다. cross-validation은 over-sampling을 하기 전에 반드시 완료해야 한다.
- Weight balancing
Weight balancing은 train 데이터에서 각 loss를 계산할 때 특정 클래스에 대해서는 더 큰 loss를 계산해 주는 것이다. 희귀 클래스에는 더 큰 정확도가 필요하므로 해당 클래스에 더 큰 loss를 취해주는 것이다. 또 다른 방법으로는 클래스의 비율에 대해 가중치를 두는 방법이다. 예를 들면 두 클래스의 비율이 1:9라면 가중치를 9:1로 줌으로써 전체 클래스의 loss에 동일하게 기여하도록 한다.
imbalanced data 평가 방법
- Evaluation Metrics
Imbalanced data를 사용해서 model을 생성할 경우 evaluation metrics를 부적절하게 해석할 수 있으므로 매우 위험하다. 1과 0중 하나로 분류해야하는 binary classification 문제에서 만약 0이 데이터의 99%를 차지한다면, 모델의 정확도는 99%를 보여주겠지만 이 정확도의 신뢰도는 낮다. 이러한 경우, 아래와 같은 evaluation metrics를 사용하는게 좋다.
- Precision/Specificity: how many selected instances are relevant.
- Recall/Sensitivity: how many relevant instances are selected.
- F1 score: harmonic mean of precision and recall.
- MCC: correlation coefficient between the observed and predicted binary classifications.
- AUC: relation between true-positive rate and false positive rate.
Binary classification 에서 true positive - false positive와 precision - recall 은 threshold를 조정함에 따라 그 수치가 달라지며, 모델이 갖을 수 있는 두 값을 그래프로 그린 것이 ROC curve와 PR curve 이다. 두 curve에 대한 예시와 설명은 아래와 같다.
- ROC curve
모델이 가질 수 있는 true positive와 false positive의 값. ‘Γ’모양이 optimal한 모델의 그래프. 모든 가능한 threshold별 False Positive Rate, True Positive Rate을 한 눈에 알아볼 수 있다.
- Precision-Recall Curve
모델이 가질 수 있는 precision, recall의 값. ‘ㄱ’모양이 optimal 한 모델의 그래프. 이 그래프에서는 precision과 recall에 따른 threshold값까지는 파악할 수 없으나, 모델이 precision 100% 일 때 recall값과, recall 100% 일 때 precision 값에 대하여 파악할 수 있다.
'Deep learning' 카테고리의 다른 글