본문 바로가기

Machine Learning/Business Analytics 1

Decision Tree

자료출처

 

1. 목표

  • feature들을 기반으로 결과를 분류하거나 예측
  • 결과는 규칙들의 집합

 

  • 아이가 오늘 놀지 안놀지에 대해서 예측하는 문제를 보자.
  • 맨 위를 보면 노는게 9, 안노는게 5로 총 14일의 데이터가 있다. 맨 위에서는 그 날의 바깥 날씨가 어떤지를 가지고 분기(split)를 하였다. overcast의 경우 play만 있기 때문에 완벽하게 분류되었다.
  • sunny로 분기된 다음에 습도에 따라 분기하였고 rain으로 분기된 다음엔 바람이 부는지 안부는지에 따라 분기하였다.

데이터 분석을 할 때 트리 기반 모델을 사용하는 이유 중 하나가 이런식으로 예측 결과물에 대해 사람의 언어로 설명이 가능하기 때문이다.

 

2. 용어 정리

  • root node : 부모 노드가 없고 자식 노드만 있는 노드
  • leaf node : 자식 노드가 없고 부모 노드만 있는 노드
  • parent node : 분기되기 전 노드
  • child node : 분기 이후 노드
  • split criterion : 노드를 분기하기 위해 사용된 특정한 변수의 값

추가적으로 우리가 알아볼 것은 Classification And Regression Tree의 약자인 CART라고 부르는 것이고 key idea는 2개가 있다.

  • Recursive Partitioning : 최대 homogeneity를 달성하기 위해 반복적으로 레코드를 두 부분으로 분기
  • Pruning the Tree : 오버피팅을 방지하기 위해서 주변 가지를 잘라내어 트리를 단순하게 하는 것

Recursive Partitioning의 핵심은 자식 노드들의 purity를 최대화 하기 위한 분기점을 선택하는 것이다.

 

그럼 purity를 어떻게 최대화 할까?

purity의 반대인 impurity 즉 불순도를 계산하는 방법을 알아보자. 불순도를 최소화 하는 것이 곧 purity를 최대화 하는 것이기 때문이다.

 

3. Impurity Measure

3.1 Gini Index

 

 

  • m은 클래스 개수를 의미한다. 여기서는 이진 분류이기 때문에 2가 된다.
  • $p_k$는 (해당 클래스의 샘플 개수/전체 샘플 개수)이다.
  • 주황색과 파란색으로 이루어진 rectangle의 지니 인덱스를 계산해보면 약 0.47이 된다. 
  • rectangle 안의 샘플이 모두 같은 클래스라면 지니 인덱스의 값은 0이고, 각각 절반씩 차지한다면 0.5가 된다, 따라서 지니 인덱스의 값이 작을 수록 purity가 높은 것이다.

 

 

 

  • 빨간색 경계선을 통해 2개의 rectangle로 나누었다. $R_i$는 (나눠진 rectangle의 샘플 수/전체 샘플 수) 이다.
  • 2개로 나누어진 rectangle에 대해서 지니 인덱스를 계산해보면 약 0.34이다.
  • 즉 빨간색 경계선을 통해 0.47에서 0.34로 지니 인덱스 값을 낮춰 purity를 높인 것이다. 이때 이 차이를 information gain이라 한다.

 

3.2 Deviance

 

 

  • $n_{ik}$ : k 클래스의 샘플 수
  • $p_{ik}$ : k 클래스의 샘플 수 / 전체 샘플 수
  • Deviance 값이 0이면 rectangle 안에 모든 같은 샘플만 있는 것이다. 지니 인덱스 처럼 Deviance도 낮은 것이 purity가 높은 것

 

 

 

  • 빨간색 경계선으로 rectangle을 두 부분으로 나누고 각각의 Deviance를 계산한 뒤 이를 더해준다.
  • 마찬가지로 경계선을 나눈 후와 나누기 전의 차이로 gain을 계산한다.

 

4. Recursive Partitioning

 

  • recursive partitioning이 어떻게 이루어지는지 알아보자.
  • split을 할 때는 항상 축에 수직인 경계선을 결정해야 한다.

 

 

  • 먼저 Lot size의 값을 기준으로 나눈다. 처음에는 가장 위에 보이는 14.0과 14.8의 평균인 14.4를 기준으로 잡고 split을 수행한다.
  • 그런 다음 이렇게 나누었을 때의 지니 인덱스를 계산한다.
  • split 전에는 지니 값이 0.5이었고 나눈 뒤엔 0.48이 되었으니 gain은 0.02가 된다.
  • Lot size가 몇 개 중복된 경우가 보이는데 모두 unique한 값이라고 가정한다면 24개의 샘플이 있기 때문에 위와 같은 과정을 총 23번 수행하여 가장 gain이 큰 경우를 찾는다.

위와 같은 과정을 분기된 자식 노드에 대해서도 gain이 없을 때 까지 splt을 수행하면 최종적으로 다음과 같이 분기된다.

 

 

  • 이렇게 완벽하게 분할된 트리를 full tree라고 부른다. 즉 각 영역에는 오직 하나의 클래스만 존재한다.
  • 새로운 데이터가 들어오게 되면 full tree가 아닌 경우에는 각 영역의 클래스 비율을 통해서 어떤 클래스인지 예측을 수행한다. 예를들어 어떤 영역에 클래스 1의 샘플이 3개, 클래스 2의 샘플이 2개라면 클래스 1일 확률을 0.6이라고 예측하는 것이다.
  • 이 예측 기준은 cutoff 값에 따라 달라진다. 

 

5. Pruning

직감적으로 full tree의 검증 성능은 안좋을 것이라고 알 수 있다. 당연히 train 데이터에 과적합된 모형이기 때문이다. 그래서 수행하는 것이 pruning인데, 과적합되기 전에 학습을 중지하는 것이 pre-pruning이고 full tree 상태에서 분기된 leaf 노드들을 다시 하나의 노드로 합치는 것을 post-pruning이라고 한다.

 

그림과 같은 경우에는 트리의 depth가 5인 지점이 가장 최선일 것이다. depth가 깊어질 수록 full tree에 가까워진다.

우리는 일반적으로 pre-pruning을 수행하여 모델을 학습하는데 여기서는 post 방식을 개념적으로만 이해해보자.

 

pruning의 기준은 Cost complexity이다.

$CC(T) = Error(T) + {\alpha}*L(T)$

 

  • CC(T) : tree의 cost complexity
  • Error(T) : 검증 데이터에 대한 잘못 분류된 샘플의 비율
  • alpha : penalty factor로써 하이퍼 파라미터이다.
  • L(T) : leaf node의 개수

2가지 예시를 살펴보자

 

 

두 트리의 리프 노드의 개수가 같다면 당연히 검증 에러가 낮은 1번 트리를 선택한다.

 

 

두 트리의 검증 에러가 같다면 리프노드가 적은, 즉 상대적으로 simple한 트리 A를 선택한다.

 

post-pruning을 할 일은 실제로 거의 없으므로 이정도로만 알아보고 regression tree로 넘어가자.

 

6. Regression Tree

 

  • regression은 각 영역에 대해 target 값의 평균으로 예측을 수행한다.
  • 분류와 다른 점은 불순도를 계산하는 방식인데 sum of squared error를 사용한다.

 

 

예를들어 위와 같이 15를 기준으로 split하였고, split 전의 sse와 split 이후의 sse를 계산하여 최종적인 gain을 얻을 수 있다.

'Machine Learning > Business Analytics 1' 카테고리의 다른 글

Logistic Regression : Interpretation  (0) 2023.12.20
Logistic Regression : Learning  (0) 2023.12.19
Logistic Regression : Formulation  (0) 2023.12.19
Evaluating Regression Models  (0) 2023.12.08
Multiple Linear Regression  (0) 2023.12.08