[ R ] AUCROC and KS (H2O)

2019. 3. 16. 17:59 · 분석 R/구현

HTML이나 Rmd 파일이 필요하시면 댓글을 남겨주세요.

Library Load

library(dplyr)
library(tidyverse)
library(h2o)
library(caret) 
library(riskr)

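How to install riskr

Download the riskr source archive (riskr-master.zip) — the download link from the original post was not preserved — and install it locally with devtools: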

 

 

# Install riskr from a locally downloaded source archive (riskr-master.zip).
# The archive location can be overridden with the RISKR_ZIP environment
# variable; the default is the author's original, machine-specific path.
library(devtools)
install_local(
  Sys.getenv("RISKR_ZIP", unset = "C:/Users/lee/Desktop/riskr-master.zip")
)

K-S Plot Function

# Internal helper for ks_plot(): locate the two-sample K-S statistic.
# score0 / score1 are the scores of class 0 and class 1; NAs are ignored.
# Returns a list with the score `x` where |F0 - F1| is largest, the two ECDF
# values there (`y0` = F0(x), `y1` = F1(x)) and the statistic `ks` = |y0 - y1|.
.ks_max_gap <- function(score0, score1) {
  score0 <- score0[!is.na(score0)]
  score1 <- score1[!is.na(score1)]
  if (length(score0) == 0L || length(score1) == 0L) {
    stop("ks_plot() needs at least one non-missing score in each class (0 and 1).",
         call. = FALSE)
  }
  cdf0 <- stats::ecdf(score0)
  cdf1 <- stats::ecdf(score1)
  # Both ECDFs are right-continuous step functions that only jump at observed
  # scores, so the largest gap is attained at one of them: evaluating on the
  # pooled observed scores gives the exact statistic (no arbitrary grid).
  grid <- sort(unique(c(score0, score1)))
  f0 <- cdf0(grid)
  f1 <- cdf1(grid)
  gap <- abs(f0 - f1)
  at <- which.max(gap)  # first maximiser if several points tie
  list(x = grid[at], y0 = f0[at], y1 = f1[at], ks = gap[at])
}

# Plot the empirical CDFs of the predicted scores for the two classes and mark
# the Kolmogorov-Smirnov (K-S) statistic, i.e. the largest vertical gap.
#
# output : data frame holding a 0/1 class column and a numeric score column
#          (e.g. an H2O prediction frame converted with as.data.frame()).
# name   : optional text prepended to the plot title; NULL means no prefix.
# target : name of the class column (default "left").
# score  : name of the score column (default "p1").
# Returns a ggplot object.
ks_plot <- function(output, name = NULL, target = "left", score = "p1") {
  missing_cols <- setdiff(c(target, score), names(output))
  if (length(missing_cols) > 0L) {
    stop("ks_plot(): column(s) not found in `output`: ",
         paste(missing_cols, collapse = ", "), call. = FALSE)
  }

  labels <- output[[target]]
  scores <- output[[score]]

  # which() drops rows whose label is NA, like the dplyr::filter() it replaces.
  ks <- .ks_max_gap(scores[which(labels == 0)], scores[which(labels == 1)])

  ks_label <- paste("K-S Test ks =", round(ks$ks, 3))
  plot_title <- if (is.null(name)) ks_label else paste(name, "/", ks_label)

  plot_df <- data.frame(pred = scores, cls = factor(labels))

  ggplot(plot_df, aes(x = pred, group = cls, color = cls)) +
    stat_ecdf(size = 1) +
    theme_bw(base_size = 14) +
    xlab("Predict") +
    ylab("ECDF") +
    # annotate() draws exactly one marker; mapping constants through aes() on
    # the full data frame would draw one overlapping copy per row.
    annotate("segment",
             x = ks$x, xend = ks$x, y = ks$y0, yend = ks$y1,
             linetype = "dashed", color = "red") +
    annotate("point",
             x = ks$x, y = c(ks$y0, ks$y1),
             color = "red", size = 4) +
    ggtitle(plot_title) +
    theme(legend.title = element_blank())
}

Data Load

# Read the HR data and make stratified train / validation / test splits.
# `left` is the class label; it is made a factor so caret stratifies on it.
data <- read.csv("./HR.csv")
data$left <- factor(data$left)

# Fix the RNG so the partitions are reproducible (2019 matches the seed used
# in the grid-search criteria).
set.seed(2019)

# 70% of rows go to model building, 30% are held out for the final test.
index <- caret::createDataPartition(data$left, p = 0.7, list = FALSE)

train_data <- data[index, ]
test_data  <- data[-index, ]

# Split the model-building rows 70/30 again: train (~49% of all rows) and
# validation (~21%). valid_data must be taken before train_data is overwritten.
index2 <- caret::createDataPartition(train_data$left, p = 0.7, list = FALSE)

valid_data <- train_data[-index2, ]
train_data <- train_data[index2, ]

H2O

# h2o is already attached in the library-load section; the line below is a
# leftover and stays commented out.
#library(h2o)
# Start (or connect to) a local H2O cluster; nthreads = -1 lets it use all
# available cores (the captured output below reports 8 of 8 allowed).
h2o.init(nthreads = -1)
# Console output captured when this was originally run:
##  Connection successful!
## 
## R is connected to the H2O cluster: 
##     H2O cluster uptime:         2 minutes 14 seconds 
##     H2O cluster timezone:       Asia/Seoul 
##     H2O data parsing timezone:  UTC 
##     H2O cluster version:        3.20.0.8 
##     H2O cluster version age:    5 months and 23 days !!! 
##     H2O cluster name:           H2O_started_from_R_lee_roj930 
##     H2O cluster total nodes:    1 
##     H2O cluster total memory:   3.51 GB 
##     H2O cluster total cores:    8 
##     H2O cluster allowed cores:  8 
##     H2O cluster healthy:        TRUE 
##     H2O Connection ip:          localhost 
##     H2O Connection port:        54321 
##     H2O Connection proxy:       NA 
##     H2O Internal Security:      FALSE 
##     H2O API Extensions:         Algos, AutoML, Core V3, Core V4 
##     R Version:                  R version 3.5.1 (2018-07-02)
## Warning in h2o.clusterInfo(): 
## Your H2O cluster version is too old (5 months and 23 days)!
## Please download and install the latest version from http://h2o.ai/download/
# Turn off H2O's progress bars for the calls that follow.
h2o.no_progress()

# Upload each R split to the cluster as an H2OFrame. Kept as three direct calls:
# as.h2o() derives the frame key from the argument expression.
train_hf <- as.h2o(train_data)
valid_hf <- as.h2o(valid_data)
test_hf <- as.h2o(test_data)

Parameter

response <- "left"                              # target column
features <- setdiff(names(train_hf), response)  # every other column

# Hyperparameter space for the random grid search.
hyper_params <- list(
  ntrees    = c(100, 300),
  max_depth = c(3, 5, 7),
  # -1 keeps H2O's default; the alternative samples one third of the features
  mtries    = c(-1, round(length(features) / 3))
)

# Sample at most 5 hyperparameter combinations, reproducibly.
search_criteria <- list(
  strategy   = "RandomDiscrete",
  max_models = 5,
  seed       = 2019
)

Random Forest Random Grid Search

# Random-discrete grid search over Random Forest hyperparameters, scored on the
# validation frame. search_criteria's seed only governs which hyperparameter
# combinations are sampled; `seed` here is forwarded to each forest so the
# individual models are reproducible as well.
rf_grid <- h2o.grid("randomForest",
                    x = features,
                    y = response,
                    grid_id = "rf_grid",
                    training_frame = train_hf,
                    validation_frame = valid_hf,
                    hyper_params = hyper_params,
                    search_criteria = search_criteria,
                    seed = 2019)

Get Grid Results (Sorted by AUC)

# Re-fetch the grid with its models ordered by AUC, highest first.
rf_gridperf <- h2o.getGrid(
  grid_id    = "rf_grid",
  sort_by    = "auc",
  decreasing = TRUE
)

Get the Best Model

# Take the top model from the AUC-sorted grid (rf_gridperf). The raw grid
# object's model_ids are not guaranteed to be in AUC order, so indexing rf_grid
# would return an arbitrary model rather than the best one.
best_model <- h2o.getModel(rf_gridperf@model_ids[[1]])

Predict and Store

# Score the held-out test set with the best model, bring the predictions back
# into R as a data frame, and attach the true labels for evaluation.
rf.pred <- as.data.frame(predict(object = best_model, newdata = test_hf))
rf.pred[["left"]] <- test_data[["left"]]

K-S

# Pull the true labels and the class-1 scores out of the prediction frame,
# then compute the K-S statistic with riskr's ks().
target <- rf.pred[["left"]]
score  <- rf.pred[["p1"]]

ks(target, score)
## [1] 0.9324294

K-S Plot Visualization

# ECDF plot of the test-set scores per class, titled with the model name.
ks_plot(output = rf.pred, name = "Randomforest")

AUCROC

# Area under the ROC curve for the same labels and scores (riskr).
print(aucroc(target, score))
## [1] 0.9897863

ROC Curve Visualization

# ROC curve plot for the same labels and scores (riskr).
print(gg_roc(target, score))

 



 


 

 

728x90