본문 바로가기
Study/Machine learning

[Machine learning] 쉽게 설명하는 Cross Validation 교차검증

by 후이 (hui) 2020. 7. 26.
728x90
반응형

 

    index 

  1. 교차검증이란?
  2. 교차검증을 사용하는 이유 
  3. 코드 및 결과  
  4. 추가 질문들 (Stratified K-fold 교차검증)

 

1. 교차 검증 (cross validation)

: 모델의 학습 과정에서 학습 / 검증데이터를 나눌때 단순히 1번 나누는게 아니라 

K번 나누고 각각의 학습 모델의 성능을 비교하여 평균 값으로 

 

0) 우선 데이터를 왜 나누나? 

학습 모델의 한명의 학생이라하고, 우리는 대량의 문제집(데이터) 로 학생을 학습시킨다.

이때 효율적인 학습을 위해 문제집(데이터) 학습 분량을 나누는데 

 

학습 데이터 (Training set) - 문제집의 문제은행

검증 데이터 (Validation set) - 문제집에 속한 기출 모의고사 (성능 검증 / 학습에는 활용되지 않음 - 자세한 설명은 뒤에) 

시험 데이터(Test set) - 실제 시험 

 

여기서 질문 Q. Validation set 은 과연 학습에 영향을 미칠까? 

답 : 학습에 참조하긴 한다. 하지만, 학습 가중치 업데이트에는 영향을 미치지 못한다

 

학습 모델은 가중치를 업데이트 하면서 학습을 한다. 

Training set 배치 마다 학습을 하며, 해당 가중치에 대해서 업데이트가 되는데 

validation set 은 한 에폭을 돈 뒤에 학습이 잘되었는지 체크만 해주는 역할이기 때문에, 가중치 업데이트에는 영향을 미치지 않는다. 

 

정리하자면 : 가중치 업데이트에는 영향 x / 학습 과정에서 참조 o / 성능 평가에 활용 o

 

+) 다시 문제집 비유로 돌아오자면, 수능전 9평 6평을 치긴치는데 오답풀이는 안하는?

그때 당시 내 학습 수준을 확인하는 딱 그정도 느낌? 이라고 이해하면 좋을거같다.

 

** validation set 에 대한 더 자세한 설명은 아래의 링크를 따라서

https://untitledtblog.tistory.com/158

 

[머신 러닝] 과적합 (Overfitting)과 Validation Dataset의 개념

1. 과소적합 (underfitting)과 과적합 (overfitting) 머신 러닝의 궁극적인 목표는 training dataset을 이용하여 학습한 모델을 가지고 test dataset를 예측하는 것이다. 이 때 test dataset은 학습 과정에서 참조..

untitledtblog.tistory.com

 

 

1) 교차 검증은 그럼 뭔가? 

Training set과 Validation 을 여러번 나눈 뒤 모델의 학습을 검증하는 방식이다. 

 

 

1. 데이터를 K 등분한다. (이미지의 경우 K=5)

2. 1/5 를 검증데이터로, 나머지 4/5 를 학습 데이터로 

3. 1/5 를 검증데이터를 바꾸며 성능 평가 

--> 총 5개의 성능 결과가 나올 것이다. 이 5개의 평균을 해당 학습 모델의 성능이라 한다. 

 

 

2. 교차 검증의 효과 및 사용 이유 

1) 모든 데이터 셋을 평가에 활용하기 때문에 데이터셋이 부족할 때 적용하는 방법

별도로 Validation set으로 빼두었던 데이터도 다시 학습에 재활용되기 때문에, 전체 데이터가 학습/검증으로 한번에 나누기 작은 경우 위와 같이 여러번 데이터를 나누고 각 교차검증마다의 모델 성능을 비교하는 방식으로 학습을 진행하면 된다. 

 

2) K개의 성능 결과를 통합하여 하나의 결과를 도출하기 때문에 보다 일반화된 모델 성능 평가 가능 

하나의 학습/ 검증 데이터로 이루어진 모델은 해당 학습데이터에만 과적합되었을 가능성이 높다. 

하지만 여러차례 나누는 교차검증 방식을 통해 전체 데이터 전 범위를 학습하고, 검증 데이터로 성능을 평가함으로써

보다 일반화된 모델을 생성할 수 있다. 

 

3. 코드 구현 

from sklearn.tree import DecisionTreeClassifier
from sklearn.model_selection import cross_val_score , cross_validate
from sklearn.datasets import load_iris

iris_data = load_iris()
dt_clf = DecisionTreeClassifier(random_state=156)

data = iris_data.data
label = iris_data.target

# 성능 지표는 정확도(accuracy) , 교차 검증 세트는 3개 
scores = cross_val_score(dt_clf , data , label , scoring='accuracy',cv=3)
print('교차 검증별 정확도:',np.round(scores, 4))
print('평균 검증 정확도:', np.round(np.mean(scores), 4))

 

scikit learn 에서 제공하는 cross_val_score을 활용하면 

교차 검증을 통해 모델 학습을 진행하고 성능을 평가한다. 저기서 cv =3 은 전체 데이터를 3등분하여 교차검증 하는 것. 

 

 

4. 예외적인 상황 

데이터 클래스가 불균형한 경우

 --> "계층별 k-겹 교차 검증(Stratified k-fold cross validation) "

가령 금융거래 사기 분류 모델에서 전체 데이터중 정상 거래 건수는 95% 사기인 거래건수는 5% 라면, 

앞서 설명한 일반적인 교차 검증으로 데이터를 분할했을 때, 사기 거래 건수가 고루 분할 되지 못하고 한 분할에 몰릴 수 있다. 

이때 데이터 클래스 별 분포를 고려해서 데이터 폴드 세트를 만드는 방법이 계층별 k-겹 교차 검증 이다. 

각 데이터 폴드마다 정상 거래건수, 사기 거래 건수가 고루 들어갈 수 있도록 데이터 클래스별 분포를 고려한 분할 방식.

데이터 클래스 별 데이터가 아주 불균형한 상황에서는 이 방법을 써야한다. 

 

 

 코드로 살펴보면 

1) 일반적인 Kfold 를 사용한 경우 

kfold = KFold(n_splits=3)
# kfold.split(X)는 폴드 세트를 5번 반복할 때마다 달라지는 학습/테스트 용 데이터 로우 인덱스 번호 반환. 
n_iter =0
for train_index, test_index  in kfold.split(iris_df):
    n_iter += 1
    label_train= iris_df['label'].iloc[train_index]
    label_test= iris_df['label'].iloc[test_index]
    print('## 교차 검증: {0}'.format(n_iter))
    print('학습 레이블 데이터 분포:\n', label_train.value_counts())
    print('검증 레이블 데이터 분포:\n', label_test.value_counts())

2) stratifiedKfold 를 사용한 경우

from sklearn.model_selection import StratifiedKFold

skf = StratifiedKFold(n_splits=3)
n_iter=0

for train_index, test_index in skf.split(iris_df, iris_df['label']):
    n_iter += 1
    label_train= iris_df['label'].iloc[train_index]
    label_test= iris_df['label'].iloc[test_index]
    print('## 교차 검증: {0}'.format(n_iter))
    print('학습 레이블 데이터 분포:\n', label_train.value_counts())
    print('검증 레이블 데이터 분포:\n', label_test.value_counts())


--> 1) 일반적인 Kfold 를 사용한 경우  결과
데이터 폴드별 레이블 개수가 다 다르다. 

--> 2) stratifiedKfold 를 사용한 경우 결과
데이터 폴드별 레이블 개수가 동일하다. 
세가지 레이블에 상응하는 데이터를 고루 학습하고 있음을 확인

 

 

 

 

출처 : 파이선 머신러닝 완벽가이드

728x90
반응형

댓글