混淆矩陣為常見的分析方法,能夠分析模型在每個類別上的表現狀況。

要利用混淆矩陣分析,首先要載入相關的函式

from sklearn.metrics import classification_report

import itertools
import matplotlib.pyplot as plt
from sklearn.metrics import confusion_matrix
import numpy as np

 

 

將繪制圖表的部份程式碼整理程一個函式,方便調用:

def plot_confusion_matrix(cm, classes,
                          normalize=False,
                          title='Confusion matrix',
                          cmap=plt.cm.Blues):
    """
    This function prints and plots the confusion matrix.
    Normalization can be applied by setting `normalize=True`.
    """
    if normalize:
        cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis]
        print("Normalized confusion matrix")
    else:
        print('Confusion matrix, without normalization')

    print(cm)

    plt.imshow(cm, interpolation='nearest', cmap=cmap)
    plt.title(title)
    plt.colorbar()
    tick_marks = np.arange(len(classes))
    plt.xticks(tick_marks, classes, rotation=45)
    plt.yticks(tick_marks, classes)

    fmt = '.2f' if normalize else 'd'
    thresh = cm.max() / 2.
    for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])):
        plt.text(j, i, format(cm[i, j], fmt),
                 horizontalalignment="center",
                 color="white" if cm[i, j] > thresh else "black")

    plt.ylabel('True label')
    plt.xlabel('Predicted label')
    plt.tight_layout()

 

 

最終m表示要測試的模型檔案,X_test表示測試資料集的輸入向量,可以得到y_predict為模型的預測結果。

透過convert_to_labels 將one-hot-encode的資料轉成label ( 詳細的程式碼可以參考

target_names 為各個label所代表的含意,為方便對應會在此定義。

最終利用剛剛定義的plot_confusion_matrix方法進行繪製。

y_predict = m.predict(X_test, batch_size=None, verbose=0, steps=None)

y_pred = convert_to_labels(y_predict)
y_true = convert_to_labels(y_test)
target_names = [ 'Hunger','Sleepy' ,'Diaper','Painful']
print ("month = " + str(month))
print(classification_report(y_true, y_pred, target_names=target_names))
print ("**************************************************************")

plt.figure()
cnf_matrix = confusion_matrix(y_true, y_pred)
plot_confusion_matrix(cnf_matrix, classes=target_names,normalize=True,
                    title="month = " + str(month) + ' confusion matrix')

plt.show()

 

顯示結果如下圖:

arrow
arrow
    文章標籤
    confusion sklearn python
    全站熱搜

    Lung-Yu,Tsai 發表在 痞客邦 留言(2) 人氣()