混淆矩陣為常見的分析方法,能夠分析模型在每個類別上的表現狀況。
要利用混淆矩陣分析,首先要載入相關的函式
from sklearn.metrics import classification_report import itertools import matplotlib.pyplot as plt from sklearn.metrics import confusion_matrix import numpy as np
將繪制圖表的部份程式碼整理程一個函式,方便調用:
def plot_confusion_matrix(cm, classes, normalize=False, title='Confusion matrix', cmap=plt.cm.Blues): """ This function prints and plots the confusion matrix. Normalization can be applied by setting `normalize=True`. """ if normalize: cm = cm.astype('float') / cm.sum(axis=1)[:, np.newaxis] print("Normalized confusion matrix") else: print('Confusion matrix, without normalization') print(cm) plt.imshow(cm, interpolation='nearest', cmap=cmap) plt.title(title) plt.colorbar() tick_marks = np.arange(len(classes)) plt.xticks(tick_marks, classes, rotation=45) plt.yticks(tick_marks, classes) fmt = '.2f' if normalize else 'd' thresh = cm.max() / 2. for i, j in itertools.product(range(cm.shape[0]), range(cm.shape[1])): plt.text(j, i, format(cm[i, j], fmt), horizontalalignment="center", color="white" if cm[i, j] > thresh else "black") plt.ylabel('True label') plt.xlabel('Predicted label') plt.tight_layout()
最終m表示要測試的模型檔案,X_test表示測試資料集的輸入向量,可以得到y_predict為模型的預測結果。
透過convert_to_labels 將one-hot-encode的資料轉成label ( 詳細的程式碼可以參考)
target_names 為各個label所代表的含意,為方便對應會在此定義。
最終利用剛剛定義的plot_confusion_matrix方法進行繪製。
y_predict = m.predict(X_test, batch_size=None, verbose=0, steps=None) y_pred = convert_to_labels(y_predict) y_true = convert_to_labels(y_test) target_names = [ 'Hunger','Sleepy' ,'Diaper','Painful'] print ("month = " + str(month)) print(classification_report(y_true, y_pred, target_names=target_names)) print ("**************************************************************") plt.figure() cnf_matrix = confusion_matrix(y_true, y_pred) plot_confusion_matrix(cnf_matrix, classes=target_names,normalize=True, title="month = " + str(month) + ' confusion matrix') plt.show()
顯示結果如下圖:
文章標籤
全站熱搜