{ "cells": [ { "cell_type": "markdown", "metadata": {}, "source": [ "# 결정트리 데모" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 라이브러리 import 및 설정" ] }, { "cell_type": "code", "execution_count": 1, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:24.311325Z", "start_time": "2020-09-21T07:58:24.007955Z" } }, "outputs": [], "source": [ "# Reload edited modules automatically and render matplotlib figures inline.\n", "%reload_ext autoreload\n", "%autoreload 2\n", "%matplotlib inline" ] }, { "cell_type": "code", "execution_count": 2, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:24.916615Z", "start_time": "2020-09-21T07:58:24.313398Z" } }, "outputs": [], "source": [ "import graphviz\n", "from matplotlib import pyplot as plt\n", "from matplotlib import rcParams\n", "import numpy as np\n", "from pathlib import Path\n", "import pandas as pd\n", "from sklearn.tree import DecisionTreeClassifier, export_graphviz\n", "from sklearn.metrics import accuracy_score\n", "from sklearn.model_selection import KFold\n", "import warnings" ] }, { "cell_type": "code", "execution_count": 3, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:24.942305Z", "start_time": "2020-09-21T07:58:24.920530Z" } }, "outputs": [], "source": [ "# Notebook-wide display defaults: figure size, plot style, pandas display limits.\n", "rcParams['figure.figsize'] = (16, 8)\n", "plt.style.use('fivethirtyeight')\n", "# Use the fully qualified option name: the bare 'max_columns' shorthand matches more\n", "# than one option in pandas >= 1.4 (display.max_columns, styler.render.max_columns)\n", "# and raises OptionError there.\n", "pd.set_option('display.max_columns', 100)\n", "pd.set_option(\"display.precision\", 4)\n", "# NOTE(review): this silences every warning category for the whole notebook,\n", "# including deprecation notices; consider narrowing the filter.\n", "warnings.simplefilter('ignore')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 데이터 로드" ] }, { "cell_type": "code", "execution_count": 4, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:24.967204Z", "start_time": "2020-09-21T07:58:24.944918Z" } }, "outputs": [], "source": [ "# Input and output directories, relative to the notebook's working directory.\n", "data_dir = Path('../data/dacon-dku')\n", "sub_dir = Path('../build/sub')\n", "\n", "# CSV files read later in the notebook.\n", "trn_file = data_dir / 'train.csv'\n", "tst_file = data_dir / 'test.csv'\n", "sample_file = data_dir / 'sample_submission.csv'\n", "\n", "# Label column of the training CSV and the notebook-level random seed constant.\n", "target_col = 'class'\n", "seed = 42" ] }, { "cell_type": "code", "execution_count": 5, "metadata": { 
"ExecuteTime": { "end_time": "2020-09-21T07:58:24.993118Z", "start_time": "2020-09-21T07:58:24.969489Z" } }, "outputs": [], "source": [ "algo_name = 'dt'\n", "feature_name = 'j1'\n", "model_name = f'{algo_name}_{feature_name}'\n", "\n", "sub_file = sub_dir / f'{model_name}.csv'" ] }, { "cell_type": "code", "execution_count": 6, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:26.103285Z", "start_time": "2020-09-21T07:58:24.995160Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(320000,) (320000, 18) (80000, 18)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
ugrizredshiftdered_udered_gdered_rdered_idered_znObservenDetectairmass_uairmass_gairmass_rairmass_iairmass_z
id
023.264020.336819.009517.672416.9396-8.1086e-0523.124320.257818.955117.632116.908918181.18981.19071.18901.18941.1902
115.052114.062013.452413.268413.16894.5061e-0314.966414.004513.411413.236313.1347111.25331.25781.24881.25101.2555
216.786415.825415.536315.393515.35004.7198e-0416.607615.686615.440015.321715.2961221.02251.02411.02101.02171.0233
325.660621.188720.221219.894919.63465.8143e-0625.353620.994720.087319.794719.5552431.20541.20611.20491.20511.2057
424.453420.699219.042418.324217.9826-3.3247e-0523.771420.433818.863018.190317.875913121.19391.19431.19371.19381.1941
\n", "
" ], "text/plain": [ " u g r i z redshift dered_u dered_g \\\n", "id \n", "0 23.2640 20.3368 19.0095 17.6724 16.9396 -8.1086e-05 23.1243 20.2578 \n", "1 15.0521 14.0620 13.4524 13.2684 13.1689 4.5061e-03 14.9664 14.0045 \n", "2 16.7864 15.8254 15.5363 15.3935 15.3500 4.7198e-04 16.6076 15.6866 \n", "3 25.6606 21.1887 20.2212 19.8949 19.6346 5.8143e-06 25.3536 20.9947 \n", "4 24.4534 20.6992 19.0424 18.3242 17.9826 -3.3247e-05 23.7714 20.4338 \n", "\n", " dered_r dered_i dered_z nObserve nDetect airmass_u airmass_g \\\n", "id \n", "0 18.9551 17.6321 16.9089 18 18 1.1898 1.1907 \n", "1 13.4114 13.2363 13.1347 1 1 1.2533 1.2578 \n", "2 15.4400 15.3217 15.2961 2 2 1.0225 1.0241 \n", "3 20.0873 19.7947 19.5552 4 3 1.2054 1.2061 \n", "4 18.8630 18.1903 17.8759 13 12 1.1939 1.1943 \n", "\n", " airmass_r airmass_i airmass_z \n", "id \n", "0 1.1890 1.1894 1.1902 \n", "1 1.2488 1.2510 1.2555 \n", "2 1.0210 1.0217 1.0233 \n", "3 1.2049 1.2051 1.2057 \n", "4 1.1937 1.1938 1.1941 " ] }, "execution_count": 6, "metadata": {}, "output_type": "execute_result" } ], "source": [ "trn = pd.read_csv(trn_file, index_col=0)\n", "tst = pd.read_csv(tst_file, index_col=0)\n", "y = trn[target_col]\n", "trn.drop(target_col, axis=1, inplace=True)\n", "print(y.shape, trn.shape, tst.shape)\n", "trn.head()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 결정트리 학습" ] }, { "cell_type": "code", "execution_count": 7, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:29.860475Z", "start_time": "2020-09-21T07:58:26.105574Z" } }, "outputs": [ { "data": { "text/plain": [ "DecisionTreeClassifier(ccp_alpha=0.0, class_weight=None, criterion='gini',\n", " max_depth=5, max_features=None, max_leaf_nodes=None,\n", " min_impurity_decrease=0.0, min_impurity_split=None,\n", " min_samples_leaf=10, min_samples_split=2,\n", " min_weight_fraction_leaf=0.0, presort='deprecated',\n", " random_state=42, splitter='best')" ] }, "execution_count": 7, "metadata": {}, "output_type": "execute_result" } 
], "source": [ "# Fit a depth-limited decision tree on the full training frame.\n", "# random_state reuses the notebook-level seed constant instead of repeating the literal 42.\n", "clf = DecisionTreeClassifier(max_depth=5, min_samples_leaf=10, random_state=seed)\n", "clf.fit(trn, y)" ] }, { "cell_type": "code", "execution_count": 8, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:29.944204Z", "start_time": "2020-09-21T07:58:29.862374Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "88.8669%\n" ] } ], "source": [ "# Accuracy on the same rows the tree was fitted on (training accuracy),\n", "# not a held-out estimate.\n", "print(f'{accuracy_score(y, clf.predict(trn)) * 100:.4f}%')" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 결정트리 시각화" ] }, { "cell_type": "code", "execution_count": 9, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:30.321220Z", "start_time": "2020-09-21T07:58:29.946555Z" } }, "outputs": [ { "data": { "image/svg+xml": [ "\n", "\n", "\n", "\n", "\n", "\n", "Tree\n", "\n", "\n", "\n", "0\n", "\n", "redshift ≤ 0.004\n", "gini = 0.6\n", "samples = 320000\n", "value = [119996, 42785, 157219]\n", "\n", "\n", "\n", "1\n", "\n", "redshift ≤ 0.002\n", "gini = 0.008\n", "samples = 119808\n", "value = [119331, 248, 229]\n", "\n", "\n", "\n", "0->1\n", "\n", "\n", "True\n", "\n", "\n", "\n", "30\n", "\n", "dered_z ≤ 17.267\n", "gini = 0.34\n", "samples = 200192\n", "value = [665, 42537, 156990]\n", "\n", "\n", "\n", "0->30\n", "\n", "\n", "False\n", "\n", "\n", "\n", "2\n", "\n", "redshift ≤ 0.001\n", "gini = 0.001\n", "samples = 118829\n", "value = [118771, 35, 23]\n", "\n", "\n", "\n", "1->2\n", "\n", "\n", "\n", "\n", "\n", "15\n", "\n", "u ≤ 18.607\n", "gini = 0.581\n", "samples = 979\n", "value = [560, 213, 206]\n", "\n", "\n", "\n", "1->15\n", "\n", "\n", "\n", "\n", "\n", "3\n", "\n", "z ≤ 12.636\n", "gini = 0.001\n", "samples = 118764\n", "value = [118716, 30, 18]\n", "\n", "\n", "\n", "2->3\n", "\n", "\n", "\n", "\n", "\n", "10\n", "\n", "u ≤ 17.819\n", "gini = 0.272\n", "samples = 65\n", "value = [55, 5, 5]\n", "\n", "\n", "\n", "2->10\n", "\n", "\n", "\n", "\n", "\n", "4\n", "\n", "redshift ≤ 0.0\n", "gini = 0.259\n", "samples = 41\n", "value = [35, 2, 
4]\n", "\n", "\n", "\n", "3->4\n", "\n", "\n", "\n", "\n", "\n", "7\n", "\n", "g ≤ 14.614\n", "gini = 0.001\n", "samples = 118723\n", "value = [118681, 28, 14]\n", "\n", "\n", "\n", "3->7\n", "\n", "\n", "\n", "\n", "\n", "5\n", "\n", "gini = 0.0\n", "samples = 31\n", "value = [31, 0, 0]\n", "\n", "\n", "\n", "4->5\n", "\n", "\n", "\n", "\n", "\n", "6\n", "\n", "gini = 0.64\n", "samples = 10\n", "value = [4, 2, 4]\n", "\n", "\n", "\n", "4->6\n", "\n", "\n", "\n", "\n", "\n", "8\n", "\n", "gini = 0.096\n", "samples = 160\n", "value = [152, 6, 2]\n", "\n", "\n", "\n", "7->8\n", "\n", "\n", "\n", "\n", "\n", "9\n", "\n", "gini = 0.001\n", "samples = 118563\n", "value = [118529, 22, 12]\n", "\n", "\n", "\n", "7->9\n", "\n", "\n", "\n", "\n", "\n", "11\n", "\n", "gini = 0.639\n", "samples = 13\n", "value = [6, 4, 3]\n", "\n", "\n", "\n", "10->11\n", "\n", "\n", "\n", "\n", "\n", "12\n", "\n", "dered_u ≤ 20.504\n", "gini = 0.11\n", "samples = 52\n", "value = [49, 1, 2]\n", "\n", "\n", "\n", "10->12\n", "\n", "\n", "\n", "\n", "\n", "13\n", "\n", "gini = 0.053\n", "samples = 37\n", "value = [36, 1, 0]\n", "\n", "\n", "\n", "12->13\n", "\n", "\n", "\n", "\n", "\n", "14\n", "\n", "gini = 0.231\n", "samples = 15\n", "value = [13, 0, 2]\n", "\n", "\n", "\n", "12->14\n", "\n", "\n", "\n", "\n", "\n", "16\n", "\n", "dered_z ≤ 16.262\n", "gini = 0.656\n", "samples = 441\n", "value = [110, 170, 161]\n", "\n", "\n", "\n", "15->16\n", "\n", "\n", "\n", "\n", "\n", "23\n", "\n", "dered_i ≤ 20.975\n", "gini = 0.287\n", "samples = 538\n", "value = [450, 43, 45]\n", "\n", "\n", "\n", "15->23\n", "\n", "\n", "\n", "\n", "\n", "17\n", "\n", "dered_u ≤ 16.942\n", "gini = 0.645\n", "samples = 307\n", "value = [78, 90, 139]\n", "\n", "\n", "\n", "16->17\n", "\n", "\n", "\n", "\n", "\n", "20\n", "\n", "u ≤ 17.735\n", "gini = 0.56\n", "samples = 134\n", "value = [32, 80, 22]\n", "\n", "\n", "\n", "16->20\n", "\n", "\n", "\n", "\n", "\n", "18\n", "\n", "gini = 0.608\n", "samples = 209\n", 
"value = [32, 74, 103]\n", "\n", "\n", "\n", "17->18\n", "\n", "\n", "\n", "\n", "\n", "19\n", "\n", "gini = 0.618\n", "samples = 98\n", "value = [46, 16, 36]\n", "\n", "\n", "\n", "17->19\n", "\n", "\n", "\n", "\n", "\n", "21\n", "\n", "gini = 0.194\n", "samples = 47\n", "value = [4, 42, 1]\n", "\n", "\n", "\n", "20->21\n", "\n", "\n", "\n", "\n", "\n", "22\n", "\n", "gini = 0.647\n", "samples = 87\n", "value = [28, 38, 21]\n", "\n", "\n", "\n", "20->22\n", "\n", "\n", "\n", "\n", "\n", "24\n", "\n", "u ≤ 19.053\n", "gini = 0.211\n", "samples = 458\n", "value = [405, 26, 27]\n", "\n", "\n", "\n", "23->24\n", "\n", "\n", "\n", "\n", "\n", "27\n", "\n", "g ≤ 22.147\n", "gini = 0.588\n", "samples = 80\n", "value = [45, 17, 18]\n", "\n", "\n", "\n", "23->27\n", "\n", "\n", "\n", "\n", "\n", "25\n", "\n", "gini = 0.524\n", "samples = 54\n", "value = [34, 14, 6]\n", "\n", "\n", "\n", "24->25\n", "\n", "\n", "\n", "\n", "\n", "26\n", "\n", "gini = 0.153\n", "samples = 404\n", "value = [371, 12, 21]\n", "\n", "\n", "\n", "24->26\n", "\n", "\n", "\n", "\n", "\n", "28\n", "\n", "gini = 0.581\n", "samples = 31\n", "value = [3, 15, 13]\n", "\n", "\n", "\n", "27->28\n", "\n", "\n", "\n", "\n", "\n", "29\n", "\n", "gini = 0.253\n", "samples = 49\n", "value = [42, 2, 5]\n", "\n", "\n", "\n", "27->29\n", "\n", "\n", "\n", "\n", "\n", "31\n", "\n", "dered_u ≤ 19.041\n", "gini = 0.242\n", "samples = 160119\n", "value = [96, 22466, 137557]\n", "\n", "\n", "\n", "30->31\n", "\n", "\n", "\n", "\n", "\n", "46\n", "\n", "dered_u ≤ 18.984\n", "gini = 0.514\n", "samples = 40073\n", "value = [569, 20071, 19433]\n", "\n", "\n", "\n", "30->46\n", "\n", "\n", "\n", "\n", "\n", "32\n", "\n", "z ≤ 16.652\n", "gini = 0.324\n", "samples = 88040\n", "value = [59, 17805, 70176]\n", "\n", "\n", "\n", "31->32\n", "\n", "\n", "\n", "\n", "\n", "39\n", "\n", "z ≤ 17.029\n", "gini = 0.122\n", "samples = 72079\n", "value = [37, 4661, 67381]\n", "\n", "\n", "\n", "31->39\n", "\n", "\n", "\n", "\n", "\n", 
"33\n", "\n", "dered_u ≤ 18.279\n", "gini = 0.239\n", "samples = 64062\n", "value = [26, 8853, 55183]\n", "\n", "\n", "\n", "32->33\n", "\n", "\n", "\n", "\n", "\n", "36\n", "\n", "dered_u ≤ 18.555\n", "gini = 0.47\n", "samples = 23978\n", "value = [33, 8952, 14993]\n", "\n", "\n", "\n", "32->36\n", "\n", "\n", "\n", "\n", "\n", "34\n", "\n", "gini = 0.317\n", "samples = 32203\n", "value = [16, 6331, 25856]\n", "\n", "\n", "\n", "33->34\n", "\n", "\n", "\n", "\n", "\n", "35\n", "\n", "gini = 0.146\n", "samples = 31859\n", "value = [10, 2522, 29327]\n", "\n", "\n", "\n", "33->35\n", "\n", "\n", "\n", "\n", "\n", "37\n", "\n", "gini = 0.478\n", "samples = 5402\n", "value = [27, 3309, 2066]\n", "\n", "\n", "\n", "36->37\n", "\n", "\n", "\n", "\n", "\n", "38\n", "\n", "gini = 0.423\n", "samples = 18576\n", "value = [6, 5643, 12927]\n", "\n", "\n", "\n", "36->38\n", "\n", "\n", "\n", "\n", "\n", "40\n", "\n", "dered_z ≤ 16.761\n", "gini = 0.075\n", "samples = 51770\n", "value = [22, 2004, 49744]\n", "\n", "\n", "\n", "39->40\n", "\n", "\n", "\n", "\n", "\n", "43\n", "\n", "dered_u ≤ 19.392\n", "gini = 0.229\n", "samples = 20309\n", "value = [15, 2657, 17637]\n", "\n", "\n", "\n", "39->43\n", "\n", "\n", "\n", "\n", "\n", "41\n", "\n", "gini = 0.048\n", "samples = 32506\n", "value = [18, 786, 31702]\n", "\n", "\n", "\n", "40->41\n", "\n", "\n", "\n", "\n", "\n", "42\n", "\n", "gini = 0.119\n", "samples = 19264\n", "value = [4, 1218, 18042]\n", "\n", "\n", "\n", "40->42\n", "\n", "\n", "\n", "\n", "\n", "44\n", "\n", "gini = 0.331\n", "samples = 9118\n", "value = [4, 1905, 7209]\n", "\n", "\n", "\n", "43->44\n", "\n", "\n", "\n", "\n", "\n", "45\n", "\n", "gini = 0.127\n", "samples = 11191\n", "value = [11, 752, 10428]\n", "\n", "\n", "\n", "43->45\n", "\n", "\n", "\n", "\n", "\n", "47\n", "\n", "dered_r ≤ 17.697\n", "gini = 0.363\n", "samples = 4356\n", "value = [34, 3334, 988]\n", "\n", "\n", "\n", "46->47\n", "\n", "\n", "\n", "\n", "\n", "54\n", "\n", "dered_z ≤ 
17.608\n", "gini = 0.513\n", "samples = 35717\n", "value = [535, 16737, 18445]\n", "\n", "\n", "\n", "46->54\n", "\n", "\n", "\n", "\n", "\n", "48\n", "\n", "dered_u ≤ 18.773\n", "gini = 0.442\n", "samples = 2432\n", "value = [5, 1636, 791]\n", "\n", "\n", "\n", "47->48\n", "\n", "\n", "\n", "\n", "\n", "51\n", "\n", "redshift ≤ 0.425\n", "gini = 0.21\n", "samples = 1924\n", "value = [29, 1698, 197]\n", "\n", "\n", "\n", "47->51\n", "\n", "\n", "\n", "\n", "\n", "49\n", "\n", "gini = 0.302\n", "samples = 1297\n", "value = [5, 1058, 234]\n", "\n", "\n", "\n", "48->49\n", "\n", "\n", "\n", "\n", "\n", "50\n", "\n", "gini = 0.5\n", "samples = 1135\n", "value = [0, 578, 557]\n", "\n", "\n", "\n", "48->50\n", "\n", "\n", "\n", "\n", "\n", "52\n", "\n", "gini = 0.197\n", "samples = 1908\n", "value = [13, 1698, 197]\n", "\n", "\n", "\n", "51->52\n", "\n", "\n", "\n", "\n", "\n", "53\n", "\n", "gini = 0.0\n", "samples = 16\n", "value = [16, 0, 0]\n", "\n", "\n", "\n", "51->53\n", "\n", "\n", "\n", "\n", "\n", "55\n", "\n", "dered_u ≤ 19.213\n", "gini = 0.416\n", "samples = 9153\n", "value = [10, 2685, 6458]\n", "\n", "\n", "\n", "54->55\n", "\n", "\n", "\n", "\n", "\n", "58\n", "\n", "u ≤ 20.239\n", "gini = 0.516\n", "samples = 26564\n", "value = [525, 14052, 11987]\n", "\n", "\n", "\n", "54->58\n", "\n", "\n", "\n", "\n", "\n", "56\n", "\n", "gini = 0.498\n", "samples = 3019\n", "value = [2, 1392, 1625]\n", "\n", "\n", "\n", "55->56\n", "\n", "\n", "\n", "\n", "\n", "57\n", "\n", "gini = 0.335\n", "samples = 6134\n", "value = [8, 1293, 4833]\n", "\n", "\n", "\n", "55->57\n", "\n", "\n", "\n", "\n", "\n", "59\n", "\n", "gini = 0.396\n", "samples = 9233\n", "value = [46, 6752, 2435]\n", "\n", "\n", "\n", "58->59\n", "\n", "\n", "\n", "\n", "\n", "60\n", "\n", "gini = 0.518\n", "samples = 17331\n", "value = [479, 7300, 9552]\n", "\n", "\n", "\n", "58->60\n", "\n", "\n", "\n", "\n", "\n" ], "text/plain": [ "" ] }, "execution_count": 9, "metadata": {}, "output_type": 
"execute_result" } ], "source": [ "dot_data = export_graphviz(clf, out_file=None,\n", " feature_names=trn.columns,\n", " filled=True,\n", " rounded=True,\n", " special_characters=True) \n", "graph = graphviz.Source(dot_data) \n", "graph " ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 시험데이터 예측" ] }, { "cell_type": "code", "execution_count": 10, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:30.388047Z", "start_time": "2020-09-21T07:58:30.324422Z" } }, "outputs": [ { "name": "stdout", "output_type": "stream", "text": [ "(80000, 1)\n" ] }, { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
class
id
3200000
3200010
3200020
3200030
3200040
\n", "
" ], "text/plain": [ " class\n", "id \n", "320000 0\n", "320001 0\n", "320002 0\n", "320003 0\n", "320004 0" ] }, "execution_count": 10, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sub = pd.read_csv(sample_file, index_col=0)\n", "print(sub.shape)\n", "sub.head()" ] }, { "cell_type": "code", "execution_count": 11, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:30.444461Z", "start_time": "2020-09-21T07:58:30.391820Z" } }, "outputs": [ { "data": { "text/html": [ "
\n", "\n", "\n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", " \n", "
class
id
3200002
3200010
3200022
3200030
3200042
\n", "
" ], "text/plain": [ " class\n", "id \n", "320000 2\n", "320001 0\n", "320002 2\n", "320003 0\n", "320004 2" ] }, "execution_count": 11, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sub[target_col] = clf.predict(tst)\n", "sub.head()" ] }, { "cell_type": "code", "execution_count": 12, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:30.480636Z", "start_time": "2020-09-21T07:58:30.447567Z" } }, "outputs": [ { "data": { "text/plain": [ "2 45299\n", "0 29882\n", "1 4819\n", "Name: class, dtype: int64" ] }, "execution_count": 12, "metadata": {}, "output_type": "execute_result" } ], "source": [ "sub[target_col].value_counts()" ] }, { "cell_type": "markdown", "metadata": {}, "source": [ "## 제출파일 저장" ] }, { "cell_type": "code", "execution_count": 13, "metadata": { "ExecuteTime": { "end_time": "2020-09-21T07:58:30.616584Z", "start_time": "2020-09-21T07:58:30.482921Z" } }, "outputs": [], "source": [ "sub.to_csv(sub_file)" ] }, { "cell_type": "code", "execution_count": null, "metadata": {}, "outputs": [], "source": [] } ], "metadata": { "kernelspec": { "display_name": "Python 3", "language": "python", "name": "python3" }, "language_info": { "codemirror_mode": { "name": "ipython", "version": 3 }, "file_extension": ".py", "mimetype": "text/x-python", "name": "python", "nbconvert_exporter": "python", "pygments_lexer": "ipython3", "version": "3.7.6" }, "toc": { "base_numbering": 1, "nav_menu": {}, "number_sections": true, "sideBar": true, "skip_h1_title": true, "title_cell": "Table of Contents", "title_sidebar": "Contents", "toc_cell": false, "toc_position": {}, "toc_section_display": true, "toc_window_display": false } }, "nbformat": 4, "nbformat_minor": 4 }