CatBoost & SHAP) Example Code

2022. 11. 16. 19:54 Analysis Python/chatbot


Example code for CatBoost and SHAP.

When using CatBoost, you have to use shap.Explanation properly.

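import catboost as cb
import matplotlib.pyplot as plt
import numpy as np
import pandas as pd
import shap

# Columns with dtype "category" are passed to CatBoost as categorical features.
cat_features = list(train_x.select_dtypes("category"))
train_dataset = cb.Pool(train_x, train_y, cat_features=cat_features)
test_dataset = cb.Pool(test_x, test_y, cat_features=cat_features)
model = cb.CatBoostRegressor(loss_function='RMSE', random_state=1234, custom_metric=['RMSE', 'MAE', 'R2'])
# Search space for the randomized hyperparameter search.
grid = {'iterations': [100, 150, 200],
        'learning_rate': [0.03, 0.05, 0.07, 0.1],
        'depth': [6, 8, 10],
        'l2_leaf_reg': [0.2, 0.5, 1, 3]}
# Evaluate 30 parameter settings sampled from the grid.
model.randomized_search(grid, train_dataset, n_iter=30, verbose=False, partition_random_seed=0)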
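
For a sense of scale, the grid above defines 3 x 4 x 3 x 4 = 144 parameter combinations, while n_iter=30 asks the search to try 30 settings. The sketch below only counts and samples combinations with the standard library. It illustrates the idea and is not a description of how CatBoost samples internally; the seed and the sampling without replacement are assumptions made for the example.

```python
import itertools
import random

# Same grid as in the snippet above.
grid = {'iterations': [100, 150, 200],
        'learning_rate': [0.03, 0.05, 0.07, 0.1],
        'depth': [6, 8, 10],
        'l2_leaf_reg': [0.2, 0.5, 1, 3]}

# Every combination an exhaustive grid search would have to fit.
all_combos = [dict(zip(grid, values)) for values in itertools.product(*grid.values())]

# A randomized search evaluates only a subset; here 30 settings are drawn
# without replacement (illustrative sampling, arbitrary seed).
rng = random.Random(0)
sampled = rng.sample(all_combos, 30)

print(len(all_combos))  # 144
print(len(sampled))     # 30
```

With these numbers, the randomized search covers roughly a fifth of the full grid.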

# Show the model's parameters after the search as a one-row table.
pd.DataFrame([model.get_params()])

# Horizontal bar chart of CatBoost's built-in feature importance, sorted ascending.
sorted_feature_importance = model.feature_importances_.argsort()
plt.barh(np.array(list(train_x))[sorted_feature_importance],
         model.feature_importances_[sorted_feature_importance],
         color='turquoise')
plt.xlabel("CatBoost Feature Importance")


# SHAP values computed by CatBoost itself, once per shap_calc_type.
# The last column holds the expected value, so [:, :-1] keeps only the per-feature values.
shap_approx = model.get_feature_importance(train_dataset, type="ShapValues", shap_calc_type="Approximate")[:, :-1]
shap.summary_plot(shap_approx, train_x, plot_type="bar")
shap_exact = model.get_feature_importance(train_dataset, type="ShapValues", shap_calc_type="Exact")[:, :-1]
shap.summary_plot(shap_exact, train_x, plot_type="bar")
shap_regular = model.get_feature_importance(train_dataset, type="ShapValues", shap_calc_type="Regular")[:, :-1]
shap.summary_plot(shap_regular, train_x, plot_type="bar")
# Replace "[feature]" with a column name from train_x.
shap.dependence_plot("[feature]", shap_regular, train_x)
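
To make the [:, :-1] slice above concrete: for a regression model, get_feature_importance(..., type="ShapValues") returns one row per sample, with one column per feature plus a final column holding the expected value, and each row sums to that sample's raw prediction. The sketch below uses a small made-up matrix (the numbers and feature names are hypothetical, not output from a real model) to show the split and the mean absolute SHAP value that the bar-type summary plot ranks features by.

```python
import numpy as np

# Hypothetical matrix in CatBoost's ShapValues layout: 3 samples, 2 features,
# last column = expected value (the same for every row).
raw = np.array([[ 1.5, -0.5,  10.0],
                [-2.0,  0.25, 10.0],
                [ 0.5,  1.0,  10.0]])
feature_names = ["f0", "f1"]  # hypothetical names

shap_part = raw[:, :-1]      # per-feature contributions
expected_value = raw[0, -1]  # base value

# Additivity: contributions plus the base value give each prediction.
predictions = shap_part.sum(axis=1) + expected_value

# Global importance as drawn by a bar summary plot: mean |SHAP| per feature.
mean_abs = np.abs(shap_part).mean(axis=0)
ranking = [feature_names[i] for i in np.argsort(mean_abs)[::-1]]

print(predictions)  # predictions rebuilt from the SHAP values
print(ranking)      # features ordered from most to least important
```

The same split gives the two pieces that shap.Explanation expects as values and base_values.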


# SHAP values through shap's TreeExplainer.
explainer = shap.TreeExplainer(model)
shap_values = explainer.shap_values(cb.Pool(train_x, train_y, cat_features=cat_features))

# Force plot for the first three training rows (initjs loads the JavaScript needed in a notebook).
shap.initjs()
sample_idxs = [0, 1, 2]
shap.force_plot(explainer.expected_value, shap_values[sample_idxs, :], train_x.iloc[sample_idxs, :])

shap.summary_plot(shap_values, train_x, plot_type="bar")


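# Calling the explainer returns a shap.Explanation. Rebuild it with the raw feature
# values and column names so the shap.plots functions get plain arrays to work with.
shap_values_v2 = explainer(cb.Pool(train_x, train_y, cat_features=cat_features))
exp = shap.Explanation(values=shap_values_v2.values, base_values=shap_values_v2.base_values, data=train_x.values, feature_names=list(train_x))
# Scatter plots of SHAP value against feature value for the second and first columns.
shap.plots.scatter(exp[:, 1])
shap.plots.scatter(exp[:, 0])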

from shap.plots import waterfall
# Waterfall plot for the second training row, then a heatmap of the first 500 rows.
waterfall(exp[1], show=True, max_display=20)
shap.plots.heatmap(exp[0:500], max_display=10)