본문 바로가기

Programming/Machine Learning

버섯데이터 분류

mushroom.csv
0.36MB

목표

- 버섯의 특징을 활용해 독/식용 버섯을 분류

- Decision tree 시각화 & 과대적합 속성 제어

- 특성선택(Feature selection) 해보기

import pandas as pd
from sklearn.model_selection import train_test_split
from sklearn.tree import DecisionTreeClassifier

# 1. 데이터를 로딩
# 2. 전체 컬럼,행 숫자 파악
# 3. 결측치 파악
# 4. 문제와 답 분리
# 5. 기술통계 -> 범주형데이터 : 갯수,최빈값,종류
# 6. label의 비율을 확인해보자.

data = pd.read_csv('data/mushroom.csv')
data.head()

data.shape

data.info()

X = data.loc[ : , 'cap-shape': ]
y = data.loc[ : , 'poisonous']

data.describe()

y.value_counts()

 

원 핫 인코딩

X_one_hot = pd.get_dummies(X)
X_one_hot.head()

 

라벨인코딩

X['habitat'].unique()

habitat_dic = {
    'u' : 2,
    'g' : 1,
    'm' : 3,
    'd' : 5,
    'p' : 4,
    'w' : 6,
    'l' : 7
}
X['habitat'].map(habitat_dic)

 

모델링

X_train,X_test,y_train,y_test = train_test_split(X_one_hot,
                                                y,
                                            test_size = 0.3)
tree_model = DecisionTreeClassifier()
tree_model.fit(X_train, y_train)
tree_model.score(X_test, y_test)

 

시각화 패키지 설치

# !pip install graphviz

from sklearn.tree import export_graphviz
export_graphviz(tree_model, out_file='tree.dot',
               class_names=['독','식용'],
               feature_names=X_one_hot.columns,
               impurity=False,
               filled=True)
               
import graphviz

with open('tree.dot', encoding='UTF8') as f:
    dot_graph = f.read()

display(graphviz.Source(dot_graph))

 

과대적합제어

# max_depth, max_leaf_nodes, min_samples_leaf
tree_model2 = DecisionTreeClassifier(max_depth=1)
tree_model2.fit(X_train,y_train)

export_graphviz(tree_model2, out_file='tree2.dot',
               class_names=['독','식용'],
               feature_names=X_one_hot.columns,
               impurity=False,
               filled=True)

with open('tree2.dot', encoding='UTF8') as f:
    dot_graph = f.read()

display(graphviz.Source(dot_graph))

tree_model2.score(X_test,y_test)

 

특성선택

- tree 모델의 특성중요도

fi = tree_model.feature_importances_
fi

importance_df = pd.DataFrame(fi, index=X_one_hot.columns)
importance_df.sort_values(by=0,ascending=False)

'Programming > Machine Learning' 카테고리의 다른 글

Linear Model - Regression  (0) 2020.02.17
타이타닉 생존자 예측 분석  (0) 2020.02.17
iris 품종분류  (0) 2020.02.17
BMI 학습하기  (0) 2020.02.16
서울시 구별 CCTV 현황 분석  (0) 2020.02.16