알코올 도수, 당도, pH를 이용해 화이트와인과 레드와인을 구분하는 task를 풀어보자.
info() : 데이터프레임의 각 열의 데이터 타입과 누락된 데이터가 있는지 확인하는 데 유용.
describe() : 열에 대한 간략한 통계를 출력. 최소, 최대, 평균값 등을 볼 수 있다.
여기서 알코올 도수와 당도, pH 값의 스케일이 다르니 이전에 했던 것처럼 사이킷런의 StandardScaler클래스를 사용해 특성을 표준화해야한다.
점수가 높지 않다. 훈련 세트와 테스트 세트의 점수가 모두 낮으니 모델이 과소적합된 것 같다. 이 모델을 설명하기 위해 로지스틱 회귀가 학습한 계수와 절편을 출력해 보자.
이렇게 보면 우리가 이 모델이 왜 저런 계수 값으로 학습했는지 정확히 이해하기 어렵다. 아마도 알코올 도수와 당도가 높을수록 화이트 와인일 가능성이 높고, pH가 높을수록 레드 와인일 가능성이 높은 것 같지만 정확히 이숫자가 어떤 의미인지 설명하기는 어렵다. 대부분 머신러닝 모델은 이렇게 학습의 결과를 설명하기 어렵다.
결정 트리
결정 트리 모델은 이유를 설명하기 쉽다. 마치 스무고개와 비슷하다. 사이킷런의 DecisionTreeClassifier클래스를 사용해 결정 트리 모델을 훈련해 보자.
훈련 세트의 점수에 비해 테스트 세트는 조금 낮다. 과대적합된 모델이라고 볼 수 있다. 그런데 이 모델을 그림으로 어떻게 표현할 수 있을까? 사이킷런은 plot_tree()함수를 사용해 결정 트리를 이해하기 쉬운 트리 그림으로 출력해 준다.
맨 위의 노드를 루트 노드(root node), 맨 아래 끝에 달린 노드를 리프 노드(leaf node)라고 한다.
너무 복잡하니 plot_tree() 함수에서 트리의 깊이를 제한해서 출력해 보자. max_depth매개변수를 1로 주면 루트 노드를 제외하고 하나의 노드를 더 확장하여 그린다. 또 filled 매개변수에서 클래스에 맞게 노드의 색을 칠할 수 있다. feature_names 매개변수에는 특성의 이름을 전달할 수 있다. 이렇게 하면 노드가 어떤 특성으로 나뉘는지 좀 더 잘 이해할 수 있을 것이다.
루트 노드는 sugar이 -0.239 이하인지 질문을 한다. 만약 어떤 샘플의 당도가 -0.239와 같거나 작으면 왼쪽 가지로 간다. 그렇지 않으면 오른쪽 가지로 이동한다. 즉 왼쪽이 Yes, 오른쪽이 No이다. 루트 노드의 총 샘플 수는 5197개이다. 이 중에서 음성 클래스(레드 와인)은 1258개이고, 양성 클래스(화이트 와인)은 3939개이다. 이 값이 value에 나타나 있다.
이어서 왼쪽 노드를 보자. 이 노드는 당도가 더 낮은지를 물어본다. 당도가 -0.802와 같거나 낮다면 다시 왼쪽 가지로, 그렇지 않으면 오른쪽 가지로 이동한다. 이 노드에서 음성 클래스와 샹성 클래스의 샘플 새우는 각각 1177개와 1745개이다. 루트 노드보다 양성 클래스, 즉 화이트 와인의 비율이 크게 줄어들었다. -> 화이트 와인이 오른쪽 노드로 많이 갔기 때문.
오른쪽 노드는 음성 클래스가 8개 양성 클래스가 2194개로 대부분의 화이트 와인 샘플이 이 노드로 이도앻ㅆ다. 노드의 바탕 색깔을 유심히 보자. 루트 노드보다 이 노드가 더 진하고, 왼쪽 노드는 더 연해졌다. plot_tree()함수에서 filled=True로 지정하면 클래스마다 색깔을 부여하고, 어떤 클래스의 비율이 높아지면 점점 진한 색으로 표시하여 직관적으로 알아볼 수 있다.
결정 트리에서 예측하는 방법은 간단하다. 리프 노드에서 가장 많은 클래스가 예측 클래스가 된다. 만약 이 결정 트리의 성장을 여기서 멈춘다면 왼쪽 노드에 도달한 샘플과 오른쪽 노드에 도달한 샘플은 모두 양성 클래스로 예측된다. 두 노드 모두 양성 클래스의 개수가 많기 때문이다.
불순도
노드 상자 안에 gini는 지니 불순도를 의미한다. DecisionTreeClassifier클래스의 criterion매개변수의 기본값이 gini이다. 앞의 그린 트리에서 루트 노드는 어떻게 당도 -0.239를 기준으로 왼쪽과 오른쪽 노드로 나누었을가? criterion매개변수에 지정한 지니 불순도를 사용한 것이다. 그럼 지니 불순도는 어떻게 계산하는지 알아보자.
지니 불순도 = 1 - (음성 클래스 비율**2 + 양성 클래스 비율**2)
다중 클래스 문제라면 클래스가 더 많겠지만 계산하는 방법은 동일하다. 그럼 이전 트리 그림에 있던 루트 노드의 지니 불순도를 계산해 보자. 루트 노드는 총 5197개의 샘플이 있고 그중에 1258개가 음성 클래스, 3939개가 양성 클래스이다. 따라서 다음과 같이 지니 불순도를 계산할 수 있다.
1 - ((1258 / 5197)**2 + (3939 / 5197)**2 ) = 0.367
만약 노드에 하나의 클래스만 있다면 지니 불순도는 0이 되어 가장 작다. 이런 노드를 순수 노드라고 한다.
결정 트리 모델은 부모 노드와 자식 노드의 불순도 차이가 가능한 크도록 트리를 성장시킨다. 부모 노드와 자식 노드의 불순도 차이를 계산하는 방법을 알아보자. 먼저 자식 노드의 불순도를 샘플 개수에 비례하여 모두 더한다. 그다음 부모 노드의 불순도에서 빼면 된다.
부모의 불순도 - (왼쪽 노드 샘플 수 / 부모의 샘플 수) * 왼쪽 노드 불순도 - (오른쪽 노드 샘플 수 / 부모의 샘플 수) * 오른쪽 노드 불순도 = 0.367 - (2922 / 5197) * 0.481 - (2275 / 5197) * 0.069 = 0.066
이렇게 부모와 자식 노드 사이의 불순도 차이를 정보 이득이라고 부른다.이제 결정 트리의 노드를 어떻게 나누는지 이해했다. 이 알고리즘은 정보 이득이 최대가 되도록 데이터를 나누는 것이다. 이때 지니 불순도를 기준으로 사용한다. 그런데 사이킷런에 또 다른 불순도 기준이 있다.
DecisionTreeClassifier클래스에서 criterion = 'entropy'를 지정하여 엔트로피 불순도를 사용할 수 있다. 엔트로피 불순도도 노드의 클래스 비율을 사용하지만 지니 불순도처럼 제곱이 아니라 밑이 2인 로그를 사용하여 곱한다. 보통 기본값인 지니 불순도와 엔트로피 불순도가 만든 결과의 차이는 크지 않다.
가지치기
열매를 잘 맺기 위해 과수원에서 가지치기를 하는 것처럼 결정 트리도 가지치기를 해야한다. 그렇지 않으면 무작정 끝까지 자라나는 트리가 만들어진다. 훈련 세트에는 아주 잘 맞겠지만 테스트셋에서는 점수가 안 좋을 것이다. 이를 일반화가 안 되었다고 말한다.
결정 트리에서 가지치기를 하는 가장 간단한 방법은 자라날 수 있는 트리의 최대 깊이를 지정하는 것이다. DecisionTreeClassitier 클래스의 max_depth매개별수를 3으로 지정하여 모델을 만들어 보자. 이렇게 하면 루트 노드 아래로 최대 3개의 노드까지만 성장할 수 있다.
왼쪽에서 세번째에 있는 노드만 음성 클래스가 더 많은 것으로 보인다. 이 노드에 도착해야만 레드와인으로 예측한다. 그럼 루트 노드부터 이 노드까지 도달하려면 당도는 -0.239보다 작고 -0.802보다 커야한다. 그리고 알코올 도수는 0.454보다 작아야한다. 그러면 레드 와인으로 분류된다.
그런데 -0.802라는 음수로 된 당도를 어떻게 설명해야 할까? 앞서 불순도를 기준으로 샘플을 나눈다고 했다. 불순도는 클래스별 비율을 가지고 계산했다. 샘플을 어떤 클래스 비율로 나누는지 계산할 때 특성값의 스케일은 계산에 영향을 미치지 않기 때문에 표준화 전처리를 할 필요가 없다.
결과를 보면 같은 트리지만, 특성값을 표준점수로 바꾸지 않아서 이해하기가 훨씬 쉽다.
마지막으로 결정 트리는 어떤 특성이 가장 유용한지 나타내는 특성 중요도를 계산해준다. 이 트리의 루트 노드와 깊이 1에서 당도를 사용했기 때문에 아마도 당도(sugar)가 가장 유용한 특성 중 하나일 것 같다. 특성 중요도는 결정 트리 모델의 feature_importances_ 속성에 저장되어 있다. 특성 중요도를 활용하면 결정 트리 모델의 특성 선택에 활용할 수 있다. 이것이 결정 트리 알고리즘의 또 다른 장점 중 하나이다.
'AI > 혼공파 머신러닝+딥러닝' 카테고리의 다른 글
[DL 07-1] 인공 신경망 (0) | 2024.05.26 |
---|---|
[ML 06-1] 군집 알고리즘 (0) | 2024.05.24 |
[ML 04-2] 확률적 경사 하강법 (0) | 2024.05.09 |
[ML 04-1] 로지스틱 회귀 (0) | 2024.04.04 |
[ML 03-3] 특성공학과 규제 - 릿지(Ridge), 라쏘(Lasso) (1) | 2024.03.31 |