LightGBM Hands-on Practice
Library Imports and Settings
%reload_ext autoreload
%autoreload 2
%matplotlib inline
import lightgbm as lgb
from matplotlib import pyplot as plt
from matplotlib import rcParams
import numpy as np
from pathlib import Path
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import seaborn as sns
import warnings
rcParams['figure.figsize'] = (16, 8)
plt.style.use('fivethirtyeight')
pd.set_option('display.max_columns', 100)
pd.set_option("display.precision", 4)
warnings.simplefilter('ignore')
Load Training Data
Use the feature file feature.csv generated in 03-pandas-eda.ipynb.
data_dir = Path('../data/dacon-dku')
feature_dir = Path('../build/feature')
sub_dir = Path('../build/sub')
trn_file = data_dir / 'train.csv'
tst_file = data_dir / 'test.csv'
sample_file = data_dir / 'sample_submission.csv'
target_col = 'class'
seed = 42
algo_name = 'lgb'
feature_name = 'feature'
model_name = f'{algo_name}_{feature_name}'
feature_file = feature_dir / f'{feature_name}.csv'
sub_file = sub_dir / f'{model_name}.csv'
df = pd.read_csv(feature_file, index_col=0)
print(df.shape)
df.head()
(400000, 20)
id | z | redshift | dered_u | dered_g | dered_r | dered_i | dered_z | nObserve | airmass_u | class | d_dered_u | d_dered_g | d_dered_r | d_dered_i | d_dered_z | d_dered_ig | d_dered_zg | d_dered_rz | d_dered_iz | d_obs_det |
---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|---|
0 | 16.9396 | -8.1086e-05 | 23.1243 | 20.2578 | 18.9551 | 17.6321 | 16.9089 | 2.9444 | 1.1898 | 0.0 | -0.1397 | -0.0790 | -0.0544 | -0.0403 | -0.0307 | -2.6257 | -3.3488 | 2.0462 | 0.7232 | -15.0556 |
1 | 13.1689 | 4.5061e-03 | 14.9664 | 14.0045 | 13.4114 | 13.2363 | 13.1347 | 0.6931 | 1.2533 | 1.0 | -0.0857 | -0.0574 | -0.0410 | -0.0322 | -0.0343 | -0.7683 | -0.8698 | 0.2767 | 0.1016 | -0.3069 |
2 | 15.3500 | 4.7198e-04 | 16.6076 | 15.6866 | 15.4400 | 15.3217 | 15.2961 | 1.0986 | 1.0225 | 0.0 | -0.1787 | -0.1388 | -0.0963 | -0.0718 | -0.0540 | -0.3649 | -0.3905 | 0.1440 | 0.0257 | -0.9014 |
3 | 19.6346 | 5.8143e-06 | 25.3536 | 20.9947 | 20.0873 | 19.7947 | 19.5552 | 1.6094 | 1.2054 | 0.0 | -0.3070 | -0.1941 | -0.1339 | -0.1003 | -0.0795 | -1.2000 | -1.4395 | 0.5321 | 0.2395 | -1.3906 |
4 | 17.9826 | -3.3247e-05 | 23.7714 | 20.4338 | 18.8630 | 18.1903 | 17.8759 | 2.6391 | 1.1939 | 0.0 | -0.6820 | -0.2653 | -0.1794 | -0.1339 | -0.1067 | -2.2436 | -2.5579 | 0.9871 | 0.3144 | -9.3609 |
y = df[target_col].values[:320000]
df.drop(target_col, axis=1, inplace=True)
trn = df.iloc[:320000]
tst = df.iloc[320000:]
feature_name = df.columns.tolist()
print(y.shape, trn.shape, tst.shape)
(320000,) (320000, 19) (80000, 19)
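The split above relies on the hard-coded row count 320000. As a hedged alternative, the sketch below separates labeled and unlabeled rows by checking whether the target is missing. It assumes that test rows carry a missing `class` value, which the float-valued `class` column suggests but the source does not confirm. The toy frame and the names `toy`, `is_train`, `trn_toy`, and `tst_toy` are hypothetical.

```python
import numpy as np
import pandas as pd

# Toy stand-in for the feature table: 3 labeled rows followed by 2 unlabeled rows.
toy = pd.DataFrame(
    {"z": [16.9, 13.2, 15.4, 19.6, 18.0],
     "class": [0.0, 1.0, 0.0, np.nan, np.nan]}
)

# Assumption: rows without a target value belong to the test set.
is_train = toy["class"].notna()

y_toy = toy.loc[is_train, "class"].astype(int).values
X_all = toy.drop(columns="class")
trn_toy, tst_toy = X_all[is_train], X_all[~is_train]

print(y_toy.shape, trn_toy.shape, tst_toy.shape)
```

If the assumption holds for feature.csv, the same mask avoids repeating the literal 320000 in several places.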
Train/Validation Split
X_trn, X_val, y_trn, y_val = train_test_split(trn, y, test_size=.2, random_state=seed)
print(X_trn.shape, X_val.shape, y_trn.shape, y_val.shape)
(256000, 19) (64000, 19) (256000,) (64000,)
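The split above is purely random. When class frequencies are uneven, one option (not used in this notebook) is to pass `stratify` to `train_test_split` so that both parts keep the same class proportions. The sketch below shows this on synthetic data. The arrays `X_toy` and `y_toy` and the class counts are made up for illustration.

```python
import numpy as np
from sklearn.model_selection import train_test_split

rng = np.random.default_rng(42)
X_toy = rng.normal(size=(1000, 3))
# Hypothetical imbalanced labels: 400 / 100 / 500 samples for classes 0 / 1 / 2.
y_toy = np.repeat([0, 1, 2], [400, 100, 500])

X_a, X_b, y_a, y_b = train_test_split(
    X_toy, y_toy, test_size=0.2, random_state=42, stratify=y_toy
)

# Class fractions in the full data and in the validation part.
frac_all = np.bincount(y_toy) / len(y_toy)
frac_val = np.bincount(y_b) / len(y_b)
print(frac_all, frac_val)
```

In the notebook this would correspond to adding `stratify=y` to the existing call. Whether it changes the validation score here is not something the source shows.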
Train the LightGBM Model
clf = lgb.LGBMClassifier(objective='multiclass',
                         n_estimators=1000,
                         num_leaves=64,
                         learning_rate=0.1,
                         min_child_samples=10,
                         subsample=.5,
                         subsample_freq=1,
                         colsample_bytree=.8,
                         random_state=seed,
                         n_jobs=-1)
# Note: LightGBM 4.0+ removed early_stopping_rounds from fit();
# on those versions, pass callbacks=[lgb.early_stopping(10)] instead.
clf.fit(X_trn, y_trn,
        eval_set=[(X_val, y_val)],
        eval_metric='multiclass',
        early_stopping_rounds=10)
p_val = clf.predict(X_val)
p_tst = clf.predict(tst)
[1] valid_0's multi_logloss: 0.888817
Training until validation scores don't improve for 10 rounds
[2] valid_0's multi_logloss: 0.808716
[3] valid_0's multi_logloss: 0.73642
[4] valid_0's multi_logloss: 0.67493
[5] valid_0's multi_logloss: 0.622153
[6] valid_0's multi_logloss: 0.580522
[7] valid_0's multi_logloss: 0.541283
[8] valid_0's multi_logloss: 0.509037
[9] valid_0's multi_logloss: 0.482079
[10] valid_0's multi_logloss: 0.452677
[11] valid_0's multi_logloss: 0.427118
[12] valid_0's multi_logloss: 0.403961
[13] valid_0's multi_logloss: 0.383394
[14] valid_0's multi_logloss: 0.364428
[15] valid_0's multi_logloss: 0.348016
[16] valid_0's multi_logloss: 0.334241
[17] valid_0's multi_logloss: 0.321063
[18] valid_0's multi_logloss: 0.308737
[19] valid_0's multi_logloss: 0.297626
[20] valid_0's multi_logloss: 0.287811
[21] valid_0's multi_logloss: 0.27825
[22] valid_0's multi_logloss: 0.26961
[23] valid_0's multi_logloss: 0.261913
[24] valid_0's multi_logloss: 0.255251
[25] valid_0's multi_logloss: 0.249147
[26] valid_0's multi_logloss: 0.244276
[27] valid_0's multi_logloss: 0.238584
[28] valid_0's multi_logloss: 0.233369
[29] valid_0's multi_logloss: 0.228661
[30] valid_0's multi_logloss: 0.224218
[31] valid_0's multi_logloss: 0.221098
[32] valid_0's multi_logloss: 0.217668
[33] valid_0's multi_logloss: 0.214128
[34] valid_0's multi_logloss: 0.210878
[35] valid_0's multi_logloss: 0.207931
[36] valid_0's multi_logloss: 0.20538
[37] valid_0's multi_logloss: 0.20288
[38] valid_0's multi_logloss: 0.200445
[39] valid_0's multi_logloss: 0.198375
[40] valid_0's multi_logloss: 0.196419
[41] valid_0's multi_logloss: 0.194399
[42] valid_0's multi_logloss: 0.192703
[43] valid_0's multi_logloss: 0.191124
[44] valid_0's multi_logloss: 0.189626
[45] valid_0's multi_logloss: 0.188421
[46] valid_0's multi_logloss: 0.187082
[47] valid_0's multi_logloss: 0.185866
[48] valid_0's multi_logloss: 0.184847
[49] valid_0's multi_logloss: 0.183792
[50] valid_0's multi_logloss: 0.182771
[51] valid_0's multi_logloss: 0.181812
[52] valid_0's multi_logloss: 0.181021
[53] valid_0's multi_logloss: 0.180188
[54] valid_0's multi_logloss: 0.179476
[55] valid_0's multi_logloss: 0.178834
[56] valid_0's multi_logloss: 0.178237
[57] valid_0's multi_logloss: 0.177583
[58] valid_0's multi_logloss: 0.176999
[59] valid_0's multi_logloss: 0.176388
[60] valid_0's multi_logloss: 0.175812
[61] valid_0's multi_logloss: 0.175346
[62] valid_0's multi_logloss: 0.174852
[63] valid_0's multi_logloss: 0.174447
[64] valid_0's multi_logloss: 0.173964
[65] valid_0's multi_logloss: 0.173533
[66] valid_0's multi_logloss: 0.173169
[67] valid_0's multi_logloss: 0.172809
[68] valid_0's multi_logloss: 0.172418
[69] valid_0's multi_logloss: 0.172074
[70] valid_0's multi_logloss: 0.171756
[71] valid_0's multi_logloss: 0.171481
[72] valid_0's multi_logloss: 0.171192
[73] valid_0's multi_logloss: 0.170923
[74] valid_0's multi_logloss: 0.170707
[75] valid_0's multi_logloss: 0.17042
[76] valid_0's multi_logloss: 0.170871
[77] valid_0's multi_logloss: 0.170594
[78] valid_0's multi_logloss: 0.170273
[79] valid_0's multi_logloss: 0.169802
[80] valid_0's multi_logloss: 0.169603
[81] valid_0's multi_logloss: 0.169404
[82] valid_0's multi_logloss: 0.16919
[83] valid_0's multi_logloss: 0.169024
[84] valid_0's multi_logloss: 0.1688
[85] valid_0's multi_logloss: 0.170769
[86] valid_0's multi_logloss: 0.169504
[87] valid_0's multi_logloss: 0.169375
[88] valid_0's multi_logloss: 0.169265
[89] valid_0's multi_logloss: 0.169105
[90] valid_0's multi_logloss: 0.168939
[91] valid_0's multi_logloss: 0.168824
[92] valid_0's multi_logloss: 0.168713
[93] valid_0's multi_logloss: 0.168551
[94] valid_0's multi_logloss: 0.168442
[95] valid_0's multi_logloss: 0.168354
[96] valid_0's multi_logloss: 0.16821
[97] valid_0's multi_logloss: 0.168084
[98] valid_0's multi_logloss: 0.167978
[99] valid_0's multi_logloss: 0.16785
[100] valid_0's multi_logloss: 0.167742
[101] valid_0's multi_logloss: 0.167693
[102] valid_0's multi_logloss: 0.167673
[103] valid_0's multi_logloss: 0.167475
[104] valid_0's multi_logloss: 0.167424
[105] valid_0's multi_logloss: 0.16737
[106] valid_0's multi_logloss: 0.167374
[107] valid_0's multi_logloss: 0.167552
[108] valid_0's multi_logloss: 0.16726
[109] valid_0's multi_logloss: 0.167194
[110] valid_0's multi_logloss: 0.167149
[111] valid_0's multi_logloss: 0.167097
[112] valid_0's multi_logloss: 0.167004
[113] valid_0's multi_logloss: 0.166939
[114] valid_0's multi_logloss: 0.166868
[115] valid_0's multi_logloss: 0.166838
[116] valid_0's multi_logloss: 0.16676
[117] valid_0's multi_logloss: 0.166722
[118] valid_0's multi_logloss: 0.166703
[119] valid_0's multi_logloss: 0.166664
[120] valid_0's multi_logloss: 0.166539
[121] valid_0's multi_logloss: 0.166496
[122] valid_0's multi_logloss: 0.166463
[123] valid_0's multi_logloss: 0.166434
[124] valid_0's multi_logloss: 0.166445
[125] valid_0's multi_logloss: 0.166406
[126] valid_0's multi_logloss: 0.166361
[127] valid_0's multi_logloss: 0.166329
[128] valid_0's multi_logloss: 0.166508
[129] valid_0's multi_logloss: 0.166394
[130] valid_0's multi_logloss: 0.166367
[131] valid_0's multi_logloss: 0.166314
[132] valid_0's multi_logloss: 0.166249
[133] valid_0's multi_logloss: 0.166234
[134] valid_0's multi_logloss: 0.166162
[135] valid_0's multi_logloss: 0.16609
[136] valid_0's multi_logloss: 0.166088
[137] valid_0's multi_logloss: 0.16606
[138] valid_0's multi_logloss: 0.166025
[139] valid_0's multi_logloss: 0.166032
[140] valid_0's multi_logloss: 0.165969
[141] valid_0's multi_logloss: 0.165969
[142] valid_0's multi_logloss: 0.165949
[143] valid_0's multi_logloss: 0.165998
[144] valid_0's multi_logloss: 0.165946
[145] valid_0's multi_logloss: 0.165898
[146] valid_0's multi_logloss: 0.165873
[147] valid_0's multi_logloss: 0.165844
[148] valid_0's multi_logloss: 0.165806
[149] valid_0's multi_logloss: 0.165815
[150] valid_0's multi_logloss: 0.16581
[151] valid_0's multi_logloss: 0.165789
[152] valid_0's multi_logloss: 0.165771
[153] valid_0's multi_logloss: 0.165726
[154] valid_0's multi_logloss: 0.165692
[155] valid_0's multi_logloss: 0.165629
[156] valid_0's multi_logloss: 0.165621
[157] valid_0's multi_logloss: 0.165612
[158] valid_0's multi_logloss: 0.165554
[159] valid_0's multi_logloss: 0.16551
[160] valid_0's multi_logloss: 0.165483
[161] valid_0's multi_logloss: 0.165447
[162] valid_0's multi_logloss: 0.165442
[163] valid_0's multi_logloss: 0.165427
[164] valid_0's multi_logloss: 0.165416
[165] valid_0's multi_logloss: 0.165384
[166] valid_0's multi_logloss: 0.165401
[167] valid_0's multi_logloss: 0.165404
[168] valid_0's multi_logloss: 0.165413
[169] valid_0's multi_logloss: 0.165382
[170] valid_0's multi_logloss: 0.165358
[171] valid_0's multi_logloss: 0.165325
[172] valid_0's multi_logloss: 0.165303
[173] valid_0's multi_logloss: 0.165287
[174] valid_0's multi_logloss: 0.165236
[175] valid_0's multi_logloss: 0.165245
[176] valid_0's multi_logloss: 0.165261
[177] valid_0's multi_logloss: 0.165269
[178] valid_0's multi_logloss: 0.165265
[179] valid_0's multi_logloss: 0.165246
[180] valid_0's multi_logloss: 0.165202
[181] valid_0's multi_logloss: 0.16519
[182] valid_0's multi_logloss: 0.165189
[183] valid_0's multi_logloss: 0.165184
[184] valid_0's multi_logloss: 0.165182
[185] valid_0's multi_logloss: 0.165178
[186] valid_0's multi_logloss: 0.165086
[187] valid_0's multi_logloss: 0.165076
[188] valid_0's multi_logloss: 0.165062
[189] valid_0's multi_logloss: 0.165028
[190] valid_0's multi_logloss: 0.165012
[191] valid_0's multi_logloss: 0.165017
[192] valid_0's multi_logloss: 0.165027
[193] valid_0's multi_logloss: 0.165022
[194] valid_0's multi_logloss: 0.165048
[195] valid_0's multi_logloss: 0.165034
[196] valid_0's multi_logloss: 0.165014
[197] valid_0's multi_logloss: 0.165018
[198] valid_0's multi_logloss: 0.165015
[199] valid_0's multi_logloss: 0.164988
[200] valid_0's multi_logloss: 0.164967
[201] valid_0's multi_logloss: 0.164949
[202] valid_0's multi_logloss: 0.164952
[203] valid_0's multi_logloss: 0.164935
[204] valid_0's multi_logloss: 0.164923
[205] valid_0's multi_logloss: 0.164958
[206] valid_0's multi_logloss: 0.164947
[207] valid_0's multi_logloss: 0.164956
[208] valid_0's multi_logloss: 0.164933
[209] valid_0's multi_logloss: 0.164924
[210] valid_0's multi_logloss: 0.164938
[211] valid_0's multi_logloss: 0.164912
[212] valid_0's multi_logloss: 0.164944
[213] valid_0's multi_logloss: 0.164909
[214] valid_0's multi_logloss: 0.164904
[215] valid_0's multi_logloss: 0.164936
[216] valid_0's multi_logloss: 0.164909
[217] valid_0's multi_logloss: 0.164912
[218] valid_0's multi_logloss: 0.164897
[219] valid_0's multi_logloss: 0.16492
[220] valid_0's multi_logloss: 0.164907
[221] valid_0's multi_logloss: 0.164894
[222] valid_0's multi_logloss: 0.164931
[223] valid_0's multi_logloss: 0.164911
[224] valid_0's multi_logloss: 0.164979
[225] valid_0's multi_logloss: 0.164981
[226] valid_0's multi_logloss: 0.164956
[227] valid_0's multi_logloss: 0.164955
[228] valid_0's multi_logloss: 0.164969
[229] valid_0's multi_logloss: 0.164976
[230] valid_0's multi_logloss: 0.164977
[231] valid_0's multi_logloss: 0.164981
Early stopping, best iteration is:
[221] valid_0's multi_logloss: 0.164894
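In the log above, the lowest validation loss occurs at iteration 221 and training stops at iteration 231, ten rounds later, which matches `early_stopping_rounds=10`. The sketch below reproduces that stopping rule in plain Python. It is a simplified illustration, not LightGBM's actual implementation, and the function name `best_iteration` and the toy losses are hypothetical.

```python
def best_iteration(losses, patience):
    """Return (best_iter, stop_iter), both 1-based, for a list of validation losses.

    Training stops once `patience` rounds have passed without a new minimum.
    """
    best_loss = float("inf")
    best_iter = 0
    for i, loss in enumerate(losses, start=1):
        if loss < best_loss:
            # New minimum: remember it and keep training.
            best_loss, best_iter = loss, i
        elif i - best_iter >= patience:
            # No improvement for `patience` rounds: stop here.
            return best_iter, i
    # Ran out of rounds before the patience was exhausted.
    return best_iter, len(losses)


toy_losses = [0.90, 0.80, 0.70, 0.71, 0.72, 0.73, 0.69]
print(best_iteration(toy_losses, patience=3))
```

The last toy value (0.69) is never reached because the rule stops before it. Early stopping can miss a later improvement in the same way, which is the trade-off of a small patience.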
print(f'{accuracy_score(y_val, p_val) * 100:.4f}%')
93.1250%
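The training log reports `multi_logloss`, while the line above reports accuracy. The two measure different things: log loss depends on the predicted probabilities, whereas accuracy only looks at the most probable class. The sketch below computes both with NumPy on a toy example. The helper names and the toy arrays are hypothetical. In the notebook, the probabilities would come from `clf.predict_proba(X_val)`, assuming the labels are integer-coded.

```python
import numpy as np


def multi_logloss(y_true, proba, eps=1e-15):
    """Mean negative log-probability assigned to the true class."""
    proba = np.clip(proba, eps, 1 - eps)  # avoid log(0)
    return -np.mean(np.log(proba[np.arange(len(y_true)), y_true]))


def accuracy(y_true, proba):
    """Fraction of rows whose most probable class equals the true class."""
    return np.mean(np.argmax(proba, axis=1) == y_true)


y_toy = np.array([0, 1, 2, 2])
proba_toy = np.array([
    [0.7, 0.2, 0.1],
    [0.1, 0.8, 0.1],
    [0.2, 0.2, 0.6],
    [0.5, 0.3, 0.2],  # wrong and fairly confident: hurts log loss more
])

print(multi_logloss(y_toy, proba_toy), accuracy(y_toy, proba_toy))
```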
Feature Importance Visualization
imp = pd.DataFrame({'feature': trn.columns, 'importance': clf.feature_importances_})
imp = imp.sort_values('importance').set_index('feature')
imp.plot(kind='barh')
<matplotlib.axes._subplots.AxesSubplot at 0x7ff3b07ba6d0>
Create the Submission File
sub = pd.read_csv(sample_file, index_col=0)
print(sub.shape)
sub.head()
(80000, 1)
id | class |
---|---|
320000 | 0 |
320001 | 0 |
320002 | 0 |
320003 | 0 |
320004 | 0 |
sub[target_col] = p_tst
sub.head()
id | class |
---|---|
320000 | 2.0 |
320001 | 0.0 |
320002 | 2.0 |
320003 | 0.0 |
320004 | 2.0 |
sub[target_col].value_counts()
2.0 41006
0.0 29976
1.0 9018
Name: class, dtype: int64
sub.to_csv(sub_file)
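The predictions above are stored as floats (2.0, 0.0), while the sample submission holds integers. If the grader expects integer labels (an assumption, since the source does not say), the column can be cast before writing. The sketch below shows the cast on a toy frame. The names `toy_sub` and `toy_pred` are hypothetical.

```python
import pandas as pd

# Toy stand-in for the sample submission, indexed by id.
toy_sub = pd.DataFrame(
    {"class": [0, 0, 0]},
    index=pd.Index([320000, 320001, 320002], name="id"),
)

# Float predictions, as produced when the model was trained on float labels.
toy_pred = [2.0, 0.0, 2.0]

# Cast to int so the CSV contains 2 rather than 2.0.
toy_sub["class"] = pd.Series(toy_pred, index=toy_sub.index).astype(int)
print(toy_sub.to_csv())
```

In the notebook, the equivalent change would be `sub[target_col] = p_tst.astype(int)` before the `to_csv` call.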