{
"cells": [
{
"cell_type": "markdown",
"metadata": {},
"source": [
"# 결정트리 데모"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 라이브러리 import 및 설정"
]
},
{
"cell_type": "code",
"execution_count": 1,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:24.311325Z",
"start_time": "2020-09-21T07:58:24.007955Z"
}
},
"outputs": [],
"source": [
"%reload_ext autoreload\n",
"%autoreload 2\n",
"%matplotlib inline"
]
},
{
"cell_type": "code",
"execution_count": 2,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:24.916615Z",
"start_time": "2020-09-21T07:58:24.313398Z"
}
},
"outputs": [],
"source": [
"import graphviz\n",
"from matplotlib import pyplot as plt\n",
"from matplotlib import rcParams\n",
"import numpy as np\n",
"from pathlib import Path\n",
"import pandas as pd\n",
"from sklearn.tree import DecisionTreeClassifier, export_graphviz\n",
"from sklearn.metrics import accuracy_score\n",
"from sklearn.model_selection import KFold\n",
"import warnings"
]
},
{
"cell_type": "code",
"execution_count": 3,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:24.942305Z",
"start_time": "2020-09-21T07:58:24.920530Z"
}
},
"outputs": [],
"source": [
"rcParams['figure.figsize'] = (16, 8)\n",
"plt.style.use('fivethirtyeight')\n",
"pd.set_option('max_columns', 100)\n",
"pd.set_option(\"display.precision\", 4)\n",
"warnings.simplefilter('ignore')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 데이터 로드"
]
},
{
"cell_type": "code",
"execution_count": 4,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:24.967204Z",
"start_time": "2020-09-21T07:58:24.944918Z"
}
},
"outputs": [],
"source": [
"data_dir = Path('../data/dacon-dku')\n",
"sub_dir = Path('../build/sub')\n",
"\n",
"trn_file = data_dir / 'train.csv'\n",
"tst_file = data_dir / 'test.csv'\n",
"sample_file = data_dir / 'sample_submission.csv'\n",
"\n",
"target_col = 'class'\n",
"seed = 42"
]
},
{
"cell_type": "code",
"execution_count": 5,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:24.993118Z",
"start_time": "2020-09-21T07:58:24.969489Z"
}
},
"outputs": [],
"source": [
"algo_name = 'dt'\n",
"feature_name = 'j1'\n",
"model_name = f'{algo_name}_{feature_name}'\n",
"\n",
"sub_file = sub_dir / f'{model_name}.csv'"
]
},
{
"cell_type": "code",
"execution_count": 6,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:26.103285Z",
"start_time": "2020-09-21T07:58:24.995160Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(320000,) (320000, 18) (80000, 18)\n"
]
},
{
"data": {
"text/html": [
"
\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" u | \n",
" g | \n",
" r | \n",
" i | \n",
" z | \n",
" redshift | \n",
" dered_u | \n",
" dered_g | \n",
" dered_r | \n",
" dered_i | \n",
" dered_z | \n",
" nObserve | \n",
" nDetect | \n",
" airmass_u | \n",
" airmass_g | \n",
" airmass_r | \n",
" airmass_i | \n",
" airmass_z | \n",
"
\n",
" \n",
" id | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 0 | \n",
" 23.2640 | \n",
" 20.3368 | \n",
" 19.0095 | \n",
" 17.6724 | \n",
" 16.9396 | \n",
" -8.1086e-05 | \n",
" 23.1243 | \n",
" 20.2578 | \n",
" 18.9551 | \n",
" 17.6321 | \n",
" 16.9089 | \n",
" 18 | \n",
" 18 | \n",
" 1.1898 | \n",
" 1.1907 | \n",
" 1.1890 | \n",
" 1.1894 | \n",
" 1.1902 | \n",
"
\n",
" \n",
" 1 | \n",
" 15.0521 | \n",
" 14.0620 | \n",
" 13.4524 | \n",
" 13.2684 | \n",
" 13.1689 | \n",
" 4.5061e-03 | \n",
" 14.9664 | \n",
" 14.0045 | \n",
" 13.4114 | \n",
" 13.2363 | \n",
" 13.1347 | \n",
" 1 | \n",
" 1 | \n",
" 1.2533 | \n",
" 1.2578 | \n",
" 1.2488 | \n",
" 1.2510 | \n",
" 1.2555 | \n",
"
\n",
" \n",
" 2 | \n",
" 16.7864 | \n",
" 15.8254 | \n",
" 15.5363 | \n",
" 15.3935 | \n",
" 15.3500 | \n",
" 4.7198e-04 | \n",
" 16.6076 | \n",
" 15.6866 | \n",
" 15.4400 | \n",
" 15.3217 | \n",
" 15.2961 | \n",
" 2 | \n",
" 2 | \n",
" 1.0225 | \n",
" 1.0241 | \n",
" 1.0210 | \n",
" 1.0217 | \n",
" 1.0233 | \n",
"
\n",
" \n",
" 3 | \n",
" 25.6606 | \n",
" 21.1887 | \n",
" 20.2212 | \n",
" 19.8949 | \n",
" 19.6346 | \n",
" 5.8143e-06 | \n",
" 25.3536 | \n",
" 20.9947 | \n",
" 20.0873 | \n",
" 19.7947 | \n",
" 19.5552 | \n",
" 4 | \n",
" 3 | \n",
" 1.2054 | \n",
" 1.2061 | \n",
" 1.2049 | \n",
" 1.2051 | \n",
" 1.2057 | \n",
"
\n",
" \n",
" 4 | \n",
" 24.4534 | \n",
" 20.6992 | \n",
" 19.0424 | \n",
" 18.3242 | \n",
" 17.9826 | \n",
" -3.3247e-05 | \n",
" 23.7714 | \n",
" 20.4338 | \n",
" 18.8630 | \n",
" 18.1903 | \n",
" 17.8759 | \n",
" 13 | \n",
" 12 | \n",
" 1.1939 | \n",
" 1.1943 | \n",
" 1.1937 | \n",
" 1.1938 | \n",
" 1.1941 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" u g r i z redshift dered_u dered_g \\\n",
"id \n",
"0 23.2640 20.3368 19.0095 17.6724 16.9396 -8.1086e-05 23.1243 20.2578 \n",
"1 15.0521 14.0620 13.4524 13.2684 13.1689 4.5061e-03 14.9664 14.0045 \n",
"2 16.7864 15.8254 15.5363 15.3935 15.3500 4.7198e-04 16.6076 15.6866 \n",
"3 25.6606 21.1887 20.2212 19.8949 19.6346 5.8143e-06 25.3536 20.9947 \n",
"4 24.4534 20.6992 19.0424 18.3242 17.9826 -3.3247e-05 23.7714 20.4338 \n",
"\n",
" dered_r dered_i dered_z nObserve nDetect airmass_u airmass_g \\\n",
"id \n",
"0 18.9551 17.6321 16.9089 18 18 1.1898 1.1907 \n",
"1 13.4114 13.2363 13.1347 1 1 1.2533 1.2578 \n",
"2 15.4400 15.3217 15.2961 2 2 1.0225 1.0241 \n",
"3 20.0873 19.7947 19.5552 4 3 1.2054 1.2061 \n",
"4 18.8630 18.1903 17.8759 13 12 1.1939 1.1943 \n",
"\n",
" airmass_r airmass_i airmass_z \n",
"id \n",
"0 1.1890 1.1894 1.1902 \n",
"1 1.2488 1.2510 1.2555 \n",
"2 1.0210 1.0217 1.0233 \n",
"3 1.2049 1.2051 1.2057 \n",
"4 1.1937 1.1938 1.1941 "
]
},
"execution_count": 6,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"trn = pd.read_csv(trn_file, index_col=0)\n",
"tst = pd.read_csv(tst_file, index_col=0)\n",
"y = trn[target_col]\n",
"trn.drop(target_col, axis=1, inplace=True)\n",
"print(y.shape, trn.shape, tst.shape)\n",
"trn.head()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 결정트리 학습"
]
},
{
"cell_type": "code",
"execution_count": 7,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:29.860475Z",
"start_time": "2020-09-21T07:58:26.105574Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n",
" max_depth=5, max_features=None, max_leaf_nodes=None,\n",
" min_impurity_decrease=0.0, min_impurity_split=None,\n",
" min_samples_leaf=10, min_samples_split=2,\n",
" min_weight_fraction_leaf=0.0, presort='deprecated',\n",
" random_state=42, splitter='best')"
]
},
"execution_count": 7,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"clf = DecisionTreeClassifier(max_depth=5, min_samples_leaf=10, random_state=42)\n",
"clf.fit(trn, y)"
]
},
{
"cell_type": "code",
"execution_count": 8,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:29.944204Z",
"start_time": "2020-09-21T07:58:29.862374Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"88.8669%\n"
]
}
],
"source": [
"print(f'{accuracy_score(y, clf.predict(trn)) * 100:.4f}%')"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 결정트리 시각화"
]
},
{
"cell_type": "code",
"execution_count": 9,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:30.321220Z",
"start_time": "2020-09-21T07:58:29.946555Z"
}
},
"outputs": [
{
"data": {
"image/svg+xml": [
"\n",
"\n",
"\n",
"\n",
"\n"
],
"text/plain": [
""
]
},
"execution_count": 9,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"dot_data = export_graphviz(clf, out_file=None,\n",
" feature_names=trn.columns,\n",
" filled=True,\n",
" rounded=True,\n",
" special_characters=True) \n",
"graph = graphviz.Source(dot_data) \n",
"graph "
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 시험데이터 예측"
]
},
{
"cell_type": "code",
"execution_count": 10,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:30.388047Z",
"start_time": "2020-09-21T07:58:30.324422Z"
}
},
"outputs": [
{
"name": "stdout",
"output_type": "stream",
"text": [
"(80000, 1)\n"
]
},
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" class | \n",
"
\n",
" \n",
" id | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 320000 | \n",
" 0 | \n",
"
\n",
" \n",
" 320001 | \n",
" 0 | \n",
"
\n",
" \n",
" 320002 | \n",
" 0 | \n",
"
\n",
" \n",
" 320003 | \n",
" 0 | \n",
"
\n",
" \n",
" 320004 | \n",
" 0 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" class\n",
"id \n",
"320000 0\n",
"320001 0\n",
"320002 0\n",
"320003 0\n",
"320004 0"
]
},
"execution_count": 10,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sub = pd.read_csv(sample_file, index_col=0)\n",
"print(sub.shape)\n",
"sub.head()"
]
},
{
"cell_type": "code",
"execution_count": 11,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:30.444461Z",
"start_time": "2020-09-21T07:58:30.391820Z"
}
},
"outputs": [
{
"data": {
"text/html": [
"\n",
"\n",
"
\n",
" \n",
" \n",
" | \n",
" class | \n",
"
\n",
" \n",
" id | \n",
" | \n",
"
\n",
" \n",
" \n",
" \n",
" 320000 | \n",
" 2 | \n",
"
\n",
" \n",
" 320001 | \n",
" 0 | \n",
"
\n",
" \n",
" 320002 | \n",
" 2 | \n",
"
\n",
" \n",
" 320003 | \n",
" 0 | \n",
"
\n",
" \n",
" 320004 | \n",
" 2 | \n",
"
\n",
" \n",
"
\n",
"
"
],
"text/plain": [
" class\n",
"id \n",
"320000 2\n",
"320001 0\n",
"320002 2\n",
"320003 0\n",
"320004 2"
]
},
"execution_count": 11,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sub[target_col] = clf.predict(tst)\n",
"sub.head()"
]
},
{
"cell_type": "code",
"execution_count": 12,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:30.480636Z",
"start_time": "2020-09-21T07:58:30.447567Z"
}
},
"outputs": [
{
"data": {
"text/plain": [
"2 45299\n",
"0 29882\n",
"1 4819\n",
"Name: class, dtype: int64"
]
},
"execution_count": 12,
"metadata": {},
"output_type": "execute_result"
}
],
"source": [
"sub[target_col].value_counts()"
]
},
{
"cell_type": "markdown",
"metadata": {},
"source": [
"## 제출파일 저장"
]
},
{
"cell_type": "code",
"execution_count": 13,
"metadata": {
"ExecuteTime": {
"end_time": "2020-09-21T07:58:30.616584Z",
"start_time": "2020-09-21T07:58:30.482921Z"
}
},
"outputs": [],
"source": [
"sub.to_csv(sub_file)"
]
},
{
"cell_type": "code",
"execution_count": null,
"metadata": {},
"outputs": [],
"source": []
}
],
"metadata": {
"kernelspec": {
"display_name": "Python 3",
"language": "python",
"name": "python3"
},
"language_info": {
"codemirror_mode": {
"name": "ipython",
"version": 3
},
"file_extension": ".py",
"mimetype": "text/x-python",
"name": "python",
"nbconvert_exporter": "python",
"pygments_lexer": "ipython3",
"version": "3.7.6"
},
"toc": {
"base_numbering": 1,
"nav_menu": {},
"number_sections": true,
"sideBar": true,
"skip_h1_title": true,
"title_cell": "Table of Contents",
"title_sidebar": "Contents",
"toc_cell": false,
"toc_position": {},
"toc_section_display": true,
"toc_window_display": false
}
},
"nbformat": 4,
"nbformat_minor": 4
}