LightGBM Hands-On

Library imports and settings

%reload_ext autoreload
%autoreload 2
%matplotlib inline
import lightgbm as lgb
from matplotlib import pyplot as plt
from matplotlib import rcParams
import numpy as np
from pathlib import Path
import pandas as pd
from sklearn.metrics import accuracy_score
from sklearn.model_selection import train_test_split
import seaborn as sns
import warnings
rcParams['figure.figsize'] = (16, 8)
plt.style.use('fivethirtyeight')
pd.set_option('display.max_columns', 100)
pd.set_option("display.precision", 4)
warnings.simplefilter('ignore')

Loading the training data

We use the feature.csv feature file generated in 03-pandas-eda.ipynb; a sketch of how its derived columns might have been built follows the data preview below.

data_dir = Path('../data/dacon-dku')
feature_dir = Path('../build/feature')
sub_dir = Path('../build/sub')

trn_file = data_dir / 'train.csv'
tst_file = data_dir / 'test.csv'
sample_file = data_dir / 'sample_submission.csv'

target_col = 'class'
seed = 42
algo_name = 'lgb'
feature_name = 'feature'
model_name = f'{algo_name}_{feature_name}'

feature_file = feature_dir / f'{feature_name}.csv'
sub_file = sub_dir / f'{model_name}.csv'
df = pd.read_csv(feature_file, index_col=0)
print(df.shape)
df.head()
(400000, 20)
z redshift dered_u dered_g dered_r dered_i dered_z nObserve airmass_u class d_dered_u d_dered_g d_dered_r d_dered_i d_dered_z d_dered_ig d_dered_zg d_dered_rz d_dered_iz d_obs_det
id
0 16.9396 -8.1086e-05 23.1243 20.2578 18.9551 17.6321 16.9089 2.9444 1.1898 0.0 -0.1397 -0.0790 -0.0544 -0.0403 -0.0307 -2.6257 -3.3488 2.0462 0.7232 -15.0556
1 13.1689 4.5061e-03 14.9664 14.0045 13.4114 13.2363 13.1347 0.6931 1.2533 1.0 -0.0857 -0.0574 -0.0410 -0.0322 -0.0343 -0.7683 -0.8698 0.2767 0.1016 -0.3069
2 15.3500 4.7198e-04 16.6076 15.6866 15.4400 15.3217 15.2961 1.0986 1.0225 0.0 -0.1787 -0.1388 -0.0963 -0.0718 -0.0540 -0.3649 -0.3905 0.1440 0.0257 -0.9014
3 19.6346 5.8143e-06 25.3536 20.9947 20.0873 19.7947 19.5552 1.6094 1.2054 0.0 -0.3070 -0.1941 -0.1339 -0.1003 -0.0795 -1.2000 -1.4395 0.5321 0.2395 -1.3906
4 17.9826 -3.3247e-05 23.7714 20.4338 18.8630 18.1903 17.8759 2.6391 1.1939 0.0 -0.6820 -0.2653 -0.1794 -0.1339 -0.1067 -2.2436 -2.5579 0.9871 0.3144 -9.3609
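For reference, the d_* columns above look like simple difference features: in row 0, d_dered_z = 16.9089 - 16.9396 = -0.0307, i.e., dered_z - z, and the same arithmetic checks out for d_dered_ig, d_dered_zg, d_dered_rz, and d_dered_iz. Below is a hypothetical reconstruction inferred from the column names, assuming train.csv carries the raw SDSS columns u, g, r, i, z, dered_*, nObserve, and nDetect; it is a sketch, not the actual 03-pandas-eda.ipynb code.

import numpy as np
import pandas as pd

df_raw = pd.read_csv('../data/dacon-dku/train.csv', index_col=0)  # raw columns assumed present

# Extinction-correction residuals: dereddened minus raw magnitude per band.
for band in ['u', 'g', 'r', 'i', 'z']:
    df_raw[f'd_dered_{band}'] = df_raw[f'dered_{band}'] - df_raw[band]

# Color-like differences between dereddened bands.
df_raw['d_dered_ig'] = df_raw['dered_i'] - df_raw['dered_g']
df_raw['d_dered_zg'] = df_raw['dered_z'] - df_raw['dered_g']
df_raw['d_dered_rz'] = df_raw['dered_r'] - df_raw['dered_z']
df_raw['d_dered_iz'] = df_raw['dered_i'] - df_raw['dered_z']

# nObserve looks log-scaled in the preview above (0.6931 = log 2); whether the
# transform is log or log1p is an assumption. d_obs_det then appears to be
# the log-scaled nObserve minus the raw nDetect.
df_raw['nObserve'] = np.log(df_raw['nObserve'])
df_raw['d_obs_det'] = df_raw['nObserve'] - df_raw['nDetect']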
y = df[target_col].values[:320000]        # first 320,000 rows are the labeled training set
df.drop(target_col, axis=1, inplace=True)
trn = df.iloc[:320000]                    # training features
tst = df.iloc[320000:]                    # test features (ids 320000-399999)
feature_names = df.columns.tolist()       # avoid clobbering the feature_name string used for file paths above
print(y.shape, trn.shape, tst.shape)
(320000,) (320000, 19) (80000, 19)
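The hard-coded 320000 relies on feature.csv stacking the train rows above the test rows. A less position-dependent sketch, assuming the file keeps the original ids and the test ids match the sample submission index (run on the freshly loaded df, before the class column is dropped):

# Sketch: derive the train/test boundary from the sample submission ids
# instead of hard-coding the row count (sub ids assumed to equal test ids).
sub_idx = pd.read_csv(sample_file, index_col=0).index
X = df.drop(columns=target_col)
tst = X.loc[sub_idx]
trn = X.drop(index=sub_idx)
y = df.loc[trn.index, target_col].values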

Splitting into training and validation sets

X_trn, X_val, y_trn, y_val = train_test_split(trn, y, test_size=.2, random_state=seed)
print(X_trn.shape, X_val.shape, y_trn.shape, y_val.shape)
(256000, 19) (64000, 19) (256000,) (64000,)
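The target here is likely imbalanced (the prediction counts at the end suggest class 1 is rare), so a stratified split that preserves the class ratios in both partitions can give a more reliable validation score; passing stratify=y is the usual one-line change:

# Stratified variant (sketch): preserves class proportions in train and validation.
X_trn, X_val, y_trn, y_val = train_test_split(
    trn, y, test_size=.2, random_state=seed, stratify=y)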

Training the LightGBM model

clf = lgb.LGBMClassifier(objective='multiclass',  # softmax multi-class objective
                         n_estimators=1000,       # upper bound; early stopping picks the final count
                         num_leaves=64,           # max leaves per tree (model complexity)
                         learning_rate=0.1,
                         min_child_samples=10,    # min samples required in a leaf
                         subsample=.5,            # row sampling ratio (bagging)
                         subsample_freq=1,        # re-sample rows every iteration
                         colsample_bytree=.8,     # feature sampling ratio per tree
                         random_state=seed,
                         n_jobs=-1)
clf.fit(X_trn, y_trn,
        eval_set=[(X_val, y_val)],
        eval_metric='multiclass',     # alias for multi_logloss
        early_stopping_rounds=10)     # removed in LightGBM >= 4.0; see the callback sketch after the log
p_val = clf.predict(X_val)   # validation predictions (uses the best iteration found by early stopping)
p_tst = clf.predict(tst)     # test-set predictions for the submission
[1]	valid_0's multi_logloss: 0.888817
Training until validation scores don't improve for 10 rounds
[2]	valid_0's multi_logloss: 0.808716
[3]	valid_0's multi_logloss: 0.73642
[4]	valid_0's multi_logloss: 0.67493
[5]	valid_0's multi_logloss: 0.622153
[6]	valid_0's multi_logloss: 0.580522
[7]	valid_0's multi_logloss: 0.541283
[8]	valid_0's multi_logloss: 0.509037
[9]	valid_0's multi_logloss: 0.482079
[10]	valid_0's multi_logloss: 0.452677
[11]	valid_0's multi_logloss: 0.427118
[12]	valid_0's multi_logloss: 0.403961
[13]	valid_0's multi_logloss: 0.383394
[14]	valid_0's multi_logloss: 0.364428
[15]	valid_0's multi_logloss: 0.348016
[16]	valid_0's multi_logloss: 0.334241
[17]	valid_0's multi_logloss: 0.321063
[18]	valid_0's multi_logloss: 0.308737
[19]	valid_0's multi_logloss: 0.297626
[20]	valid_0's multi_logloss: 0.287811
[21]	valid_0's multi_logloss: 0.27825
[22]	valid_0's multi_logloss: 0.26961
[23]	valid_0's multi_logloss: 0.261913
[24]	valid_0's multi_logloss: 0.255251
[25]	valid_0's multi_logloss: 0.249147
[26]	valid_0's multi_logloss: 0.244276
[27]	valid_0's multi_logloss: 0.238584
[28]	valid_0's multi_logloss: 0.233369
[29]	valid_0's multi_logloss: 0.228661
[30]	valid_0's multi_logloss: 0.224218
[31]	valid_0's multi_logloss: 0.221098
[32]	valid_0's multi_logloss: 0.217668
[33]	valid_0's multi_logloss: 0.214128
[34]	valid_0's multi_logloss: 0.210878
[35]	valid_0's multi_logloss: 0.207931
[36]	valid_0's multi_logloss: 0.20538
[37]	valid_0's multi_logloss: 0.20288
[38]	valid_0's multi_logloss: 0.200445
[39]	valid_0's multi_logloss: 0.198375
[40]	valid_0's multi_logloss: 0.196419
[41]	valid_0's multi_logloss: 0.194399
[42]	valid_0's multi_logloss: 0.192703
[43]	valid_0's multi_logloss: 0.191124
[44]	valid_0's multi_logloss: 0.189626
[45]	valid_0's multi_logloss: 0.188421
[46]	valid_0's multi_logloss: 0.187082
[47]	valid_0's multi_logloss: 0.185866
[48]	valid_0's multi_logloss: 0.184847
[49]	valid_0's multi_logloss: 0.183792
[50]	valid_0's multi_logloss: 0.182771
[51]	valid_0's multi_logloss: 0.181812
[52]	valid_0's multi_logloss: 0.181021
[53]	valid_0's multi_logloss: 0.180188
[54]	valid_0's multi_logloss: 0.179476
[55]	valid_0's multi_logloss: 0.178834
[56]	valid_0's multi_logloss: 0.178237
[57]	valid_0's multi_logloss: 0.177583
[58]	valid_0's multi_logloss: 0.176999
[59]	valid_0's multi_logloss: 0.176388
[60]	valid_0's multi_logloss: 0.175812
[61]	valid_0's multi_logloss: 0.175346
[62]	valid_0's multi_logloss: 0.174852
[63]	valid_0's multi_logloss: 0.174447
[64]	valid_0's multi_logloss: 0.173964
[65]	valid_0's multi_logloss: 0.173533
[66]	valid_0's multi_logloss: 0.173169
[67]	valid_0's multi_logloss: 0.172809
[68]	valid_0's multi_logloss: 0.172418
[69]	valid_0's multi_logloss: 0.172074
[70]	valid_0's multi_logloss: 0.171756
[71]	valid_0's multi_logloss: 0.171481
[72]	valid_0's multi_logloss: 0.171192
[73]	valid_0's multi_logloss: 0.170923
[74]	valid_0's multi_logloss: 0.170707
[75]	valid_0's multi_logloss: 0.17042
[76]	valid_0's multi_logloss: 0.170871
[77]	valid_0's multi_logloss: 0.170594
[78]	valid_0's multi_logloss: 0.170273
[79]	valid_0's multi_logloss: 0.169802
[80]	valid_0's multi_logloss: 0.169603
[81]	valid_0's multi_logloss: 0.169404
[82]	valid_0's multi_logloss: 0.16919
[83]	valid_0's multi_logloss: 0.169024
[84]	valid_0's multi_logloss: 0.1688
[85]	valid_0's multi_logloss: 0.170769
[86]	valid_0's multi_logloss: 0.169504
[87]	valid_0's multi_logloss: 0.169375
[88]	valid_0's multi_logloss: 0.169265
[89]	valid_0's multi_logloss: 0.169105
[90]	valid_0's multi_logloss: 0.168939
[91]	valid_0's multi_logloss: 0.168824
[92]	valid_0's multi_logloss: 0.168713
[93]	valid_0's multi_logloss: 0.168551
[94]	valid_0's multi_logloss: 0.168442
[95]	valid_0's multi_logloss: 0.168354
[96]	valid_0's multi_logloss: 0.16821
[97]	valid_0's multi_logloss: 0.168084
[98]	valid_0's multi_logloss: 0.167978
[99]	valid_0's multi_logloss: 0.16785
[100]	valid_0's multi_logloss: 0.167742
[101]	valid_0's multi_logloss: 0.167693
[102]	valid_0's multi_logloss: 0.167673
[103]	valid_0's multi_logloss: 0.167475
[104]	valid_0's multi_logloss: 0.167424
[105]	valid_0's multi_logloss: 0.16737
[106]	valid_0's multi_logloss: 0.167374
[107]	valid_0's multi_logloss: 0.167552
[108]	valid_0's multi_logloss: 0.16726
[109]	valid_0's multi_logloss: 0.167194
[110]	valid_0's multi_logloss: 0.167149
[111]	valid_0's multi_logloss: 0.167097
[112]	valid_0's multi_logloss: 0.167004
[113]	valid_0's multi_logloss: 0.166939
[114]	valid_0's multi_logloss: 0.166868
[115]	valid_0's multi_logloss: 0.166838
[116]	valid_0's multi_logloss: 0.16676
[117]	valid_0's multi_logloss: 0.166722
[118]	valid_0's multi_logloss: 0.166703
[119]	valid_0's multi_logloss: 0.166664
[120]	valid_0's multi_logloss: 0.166539
[121]	valid_0's multi_logloss: 0.166496
[122]	valid_0's multi_logloss: 0.166463
[123]	valid_0's multi_logloss: 0.166434
[124]	valid_0's multi_logloss: 0.166445
[125]	valid_0's multi_logloss: 0.166406
[126]	valid_0's multi_logloss: 0.166361
[127]	valid_0's multi_logloss: 0.166329
[128]	valid_0's multi_logloss: 0.166508
[129]	valid_0's multi_logloss: 0.166394
[130]	valid_0's multi_logloss: 0.166367
[131]	valid_0's multi_logloss: 0.166314
[132]	valid_0's multi_logloss: 0.166249
[133]	valid_0's multi_logloss: 0.166234
[134]	valid_0's multi_logloss: 0.166162
[135]	valid_0's multi_logloss: 0.16609
[136]	valid_0's multi_logloss: 0.166088
[137]	valid_0's multi_logloss: 0.16606
[138]	valid_0's multi_logloss: 0.166025
[139]	valid_0's multi_logloss: 0.166032
[140]	valid_0's multi_logloss: 0.165969
[141]	valid_0's multi_logloss: 0.165969
[142]	valid_0's multi_logloss: 0.165949
[143]	valid_0's multi_logloss: 0.165998
[144]	valid_0's multi_logloss: 0.165946
[145]	valid_0's multi_logloss: 0.165898
[146]	valid_0's multi_logloss: 0.165873
[147]	valid_0's multi_logloss: 0.165844
[148]	valid_0's multi_logloss: 0.165806
[149]	valid_0's multi_logloss: 0.165815
[150]	valid_0's multi_logloss: 0.16581
[151]	valid_0's multi_logloss: 0.165789
[152]	valid_0's multi_logloss: 0.165771
[153]	valid_0's multi_logloss: 0.165726
[154]	valid_0's multi_logloss: 0.165692
[155]	valid_0's multi_logloss: 0.165629
[156]	valid_0's multi_logloss: 0.165621
[157]	valid_0's multi_logloss: 0.165612
[158]	valid_0's multi_logloss: 0.165554
[159]	valid_0's multi_logloss: 0.16551
[160]	valid_0's multi_logloss: 0.165483
[161]	valid_0's multi_logloss: 0.165447
[162]	valid_0's multi_logloss: 0.165442
[163]	valid_0's multi_logloss: 0.165427
[164]	valid_0's multi_logloss: 0.165416
[165]	valid_0's multi_logloss: 0.165384
[166]	valid_0's multi_logloss: 0.165401
[167]	valid_0's multi_logloss: 0.165404
[168]	valid_0's multi_logloss: 0.165413
[169]	valid_0's multi_logloss: 0.165382
[170]	valid_0's multi_logloss: 0.165358
[171]	valid_0's multi_logloss: 0.165325
[172]	valid_0's multi_logloss: 0.165303
[173]	valid_0's multi_logloss: 0.165287
[174]	valid_0's multi_logloss: 0.165236
[175]	valid_0's multi_logloss: 0.165245
[176]	valid_0's multi_logloss: 0.165261
[177]	valid_0's multi_logloss: 0.165269
[178]	valid_0's multi_logloss: 0.165265
[179]	valid_0's multi_logloss: 0.165246
[180]	valid_0's multi_logloss: 0.165202
[181]	valid_0's multi_logloss: 0.16519
[182]	valid_0's multi_logloss: 0.165189
[183]	valid_0's multi_logloss: 0.165184
[184]	valid_0's multi_logloss: 0.165182
[185]	valid_0's multi_logloss: 0.165178
[186]	valid_0's multi_logloss: 0.165086
[187]	valid_0's multi_logloss: 0.165076
[188]	valid_0's multi_logloss: 0.165062
[189]	valid_0's multi_logloss: 0.165028
[190]	valid_0's multi_logloss: 0.165012
[191]	valid_0's multi_logloss: 0.165017
[192]	valid_0's multi_logloss: 0.165027
[193]	valid_0's multi_logloss: 0.165022
[194]	valid_0's multi_logloss: 0.165048
[195]	valid_0's multi_logloss: 0.165034
[196]	valid_0's multi_logloss: 0.165014
[197]	valid_0's multi_logloss: 0.165018
[198]	valid_0's multi_logloss: 0.165015
[199]	valid_0's multi_logloss: 0.164988
[200]	valid_0's multi_logloss: 0.164967
[201]	valid_0's multi_logloss: 0.164949
[202]	valid_0's multi_logloss: 0.164952
[203]	valid_0's multi_logloss: 0.164935
[204]	valid_0's multi_logloss: 0.164923
[205]	valid_0's multi_logloss: 0.164958
[206]	valid_0's multi_logloss: 0.164947
[207]	valid_0's multi_logloss: 0.164956
[208]	valid_0's multi_logloss: 0.164933
[209]	valid_0's multi_logloss: 0.164924
[210]	valid_0's multi_logloss: 0.164938
[211]	valid_0's multi_logloss: 0.164912
[212]	valid_0's multi_logloss: 0.164944
[213]	valid_0's multi_logloss: 0.164909
[214]	valid_0's multi_logloss: 0.164904
[215]	valid_0's multi_logloss: 0.164936
[216]	valid_0's multi_logloss: 0.164909
[217]	valid_0's multi_logloss: 0.164912
[218]	valid_0's multi_logloss: 0.164897
[219]	valid_0's multi_logloss: 0.16492
[220]	valid_0's multi_logloss: 0.164907
[221]	valid_0's multi_logloss: 0.164894
[222]	valid_0's multi_logloss: 0.164931
[223]	valid_0's multi_logloss: 0.164911
[224]	valid_0's multi_logloss: 0.164979
[225]	valid_0's multi_logloss: 0.164981
[226]	valid_0's multi_logloss: 0.164956
[227]	valid_0's multi_logloss: 0.164955
[228]	valid_0's multi_logloss: 0.164969
[229]	valid_0's multi_logloss: 0.164976
[230]	valid_0's multi_logloss: 0.164977
[231]	valid_0's multi_logloss: 0.164981
Early stopping, best iteration is:
[221]	valid_0's multi_logloss: 0.164894
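A compatibility note: the early_stopping_rounds argument to fit() was removed in LightGBM 4.0. On recent versions the same behavior goes through callbacks; a minimal sketch:

# LightGBM >= 4.0: early stopping and per-iteration logging via callbacks.
clf.fit(X_trn, y_trn,
        eval_set=[(X_val, y_val)],
        eval_metric='multi_logloss',
        callbacks=[lgb.early_stopping(stopping_rounds=10),
                   lgb.log_evaluation(period=1)])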
print(f'{accuracy_score(y_val, p_val) * 100:.4f}%')
93.1250%
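Accuracy alone can hide uneven per-class behavior on an imbalanced target; a confusion matrix and per-class report over the validation predictions make this visible (standard sklearn calls):

from sklearn.metrics import classification_report, confusion_matrix

print(confusion_matrix(y_val, p_val))                 # rows = true class, columns = predicted
print(classification_report(y_val, p_val, digits=4))  # per-class precision / recall / F1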

Visualizing feature importance

imp = pd.DataFrame({'feature': trn.columns, 'importance': clf.feature_importances_})
imp = imp.sort_values('importance').set_index('feature')
imp.plot(kind='barh')
(Output: horizontal bar chart of feature importances; image: ../_images/07-lightgbm_17_1.png)
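Note that feature_importances_ on the sklearn wrapper reports split counts by default; gain-based importance often ranks features differently and is available from the underlying booster, as in this sketch:

# Gain-based importance from the fitted Booster (the chart above uses split counts).
imp_gain = pd.DataFrame({
    'feature': clf.booster_.feature_name(),
    'importance': clf.booster_.feature_importance(importance_type='gain'),
}).sort_values('importance').set_index('feature')
imp_gain.plot(kind='barh')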

Creating the submission file

sub = pd.read_csv(sample_file, index_col=0)
print(sub.shape)
sub.head()
(80000, 1)
class
id
320000 0
320001 0
320002 0
320003 0
320004 0
sub[target_col] = p_tst
sub.head()
class
id
320000 2.0
320001 0.0
320002 2.0
320003 0.0
320004 2.0
sub[target_col].value_counts()
2.0    41006
0.0    29976
1.0     9018
Name: class, dtype: int64
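The labels print as floats (2.0, 1.0, 0.0) because the target column was read as float64; if the grader expects integer class labels, cast before saving. Whether integers are required depends on the competition's expected format, so treat this as an assumption:

# Cast to int in case the grader expects integer class labels (assumption).
sub[target_col] = sub[target_col].astype(int)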
sub.to_csv(sub_file)