Sklearn SVM + OneVsRestClassifier GridSearch

2019. 6. 15. 18:43 · 분석 Python/Scikit Learn (싸이킷런)

 

SVM Pipeline & Grid Search & N-fold (ShuffleSplit) Cross-Validation

 

sklearn에서 pipeline을 이용하면 쉽게 모델링과 GridSearch가 가능하다.

 

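하지만 OneVsRestClassifier를 추가하려니, 파이프라인이 OneVsRestClassifier 안으로 감싸지면서 GridSearchCV에 넘기는 파라미터 이름(예: `clf__C`)이 더 이상 동작하지 않는 문제가 있어서, 이것을 해결해서 공유한다. 핵심은 중첩된 추정기의 파라미터를 `바깥단계이름__estimator__안쪽단계이름__파라미터` 형식(아래 코드의 `ova__estimator__clf__C`)으로 지정하는 것이다.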

OneVsOneClassifier도 감싸는 추정기를 `estimator` 파라미터로 받으므로, 같은 `estimator__` 접두사 방식으로 가능할 것이다.

 

 

# Inner pipeline: standardize features, then fit an SVC.
# probability=True is required for predict_proba (used for the ROC curve).
svm = Pipeline([('scl', StandardScaler()) , ("clf", SVC(random_state=1 , probability = True))])
# Wrap the whole pipeline in OneVsRestClassifier. The outer one-step Pipeline
# names the wrapper 'ova', so grid keys take the form
# <outer step>__estimator__<inner step>__<param>, e.g. 'ova__estimator__clf__C'.
pipe_svc = Pipeline([('ova', OneVsRestClassifier(svm))])
# 5 random train/validation splits, each holding out 30% for validation.
cv = ShuffleSplit(n_splits=5, test_size=0.3, random_state=0)
param_range = [0.0001, 0.001, 0.01, 0.1, 1.0, 10.0, 100.0, 1000.0]
# Two sub-grids: 'linear' tunes only C; 'rbf' tunes C and gamma together.
param_grid = [
    {'ova__estimator__clf__C': param_range, 'ova__estimator__clf__kernel': ['linear']},
    {'ova__estimator__clf__C': param_range, 'ova__estimator__clf__gamma': param_range, 'ova__estimator__clf__kernel': ['rbf']}]

gs = GridSearchCV(estimator=pipe_svc, param_grid=param_grid,
                  scoring='accuracy', cv=cv , n_jobs=10)

gs.fit(X_train , y_train)

print(gs.best_score_)
print(gs.best_params_)

## prob: class-membership probabilities from the refit best estimator
y_prob = gs.predict_proba(X_test)
## predict: hard class labels
y_pred = gs.predict(X_test)
# Raw SVM decision values
y_score = gs.decision_function(X_test)

# One display name per class seen during fit. A hard-coded two-name list makes
# classification_report raise as soon as there are more than two classes.
target_names = ["class {}".format(c) for c in gs.classes_]
print(classification_report(y_test, y_pred, target_names=target_names))

# Bare f1_score(...) calls at script level discard their results; print them.
for avg in ('macro', 'micro', 'weighted'):
    print("F1 ({}): {:.4f}".format(avg, f1_score(y_test, y_pred, average=avg)))


def rocvis(true , prob , label ) :
    """Plot one ROC curve on the current matplotlib axes.

    Parameters
    ----------
    true : array-like
        Ground-truth binary labels. String labels are integer-encoded with
        ``LabelEncoder`` (sorted order), so the label that sorts last is
        treated as the positive class.
    prob : array-like
        Score for the positive class per sample, e.g. the second column of
        ``predict_proba``.
    label : str
        Name shown in the legend next to the AUC value.

    Returns
    -------
    float
        Area under the ROC curve (previously the function returned None).
    """
    from sklearn.metrics import auc

    # asarray: positional indexing below (a pandas Series with a non-zero
    # index would otherwise do a label lookup on true[0]).
    true = np.asarray(true)
    prob = np.asarray(prob).ravel()
    # isinstance (not type(...) == str) so NumPy string scalars (np.str_) are
    # encoded too; roc_curve cannot infer the positive class from raw strings.
    if len(true) and isinstance(true[0], str):
        from sklearn.preprocessing import LabelEncoder
        true = LabelEncoder().fit_transform(true)
    fpr, tpr, _ = roc_curve(true, prob)
    # Real area under the curve. The old value was accuracy at a 0.5
    # threshold, and it was computed before string labels were encoded.
    roc_auc = auc(fpr, tpr)
    plt.plot(fpr, tpr, marker='.', label =  "AUC : {:.2f} , {}".format(roc_auc,label)   )
    return roc_auc
    
    
# Draw the ROC figure (chance diagonal + SVM curve), save it, then display it.
plt.style.use('ggplot')  # set the style before the figure is created
fig , ax = plt.subplots(figsize= (20,10))
# Diagonal = ROC of a random classifier, for reference
ax.plot([0, 1], [0, 1], linestyle='--', label="chance")
# Column 1 of predict_proba = probability of the second class in sorted order
rocvis(true = y_test , prob = y_prob[:,1] , label = "svm" )
ax.set_xlabel("False Positive Rate", fontsize=20)
ax.set_ylabel("True Positive Rate", fontsize=20)
# lower right keeps the legend off the curves (which rise toward the upper left)
plt.legend(fontsize = 20 , loc='lower right', shadow=True )
plt.title("Models Roc Curve" , fontsize= 25)
# Save before show(): after show() the figure may be closed/empty
plt.savefig("./Model_Result.png")
plt.show()


728x90