2016-02-23 15 views
18

I am using scikit-learn to classify text documents (22,000 of them) into 100 classes. I use scikit-learn's confusion matrix method to compute the confusion matrix. How can I plot a confusion matrix?

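from sklearn.linear_model import LogisticRegression
from sklearn import metrics
import matplotlib.pyplot as plt

model1 = LogisticRegression()
model1 = model1.fit(matrix, labels)
pred = model1.predict(test_matrix)
cm = metrics.confusion_matrix(test_labels, pred)
print(cm)
plt.imshow(cm, cmap='binary')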

This is what my confusion matrix looks like:

[[3962 325 0 ..., 0 0 0] 
[ 250 2765 0 ..., 0 0 0] 
[ 2 8 17 ..., 0 0 0] 
..., 
[ 1 6 0 ..., 5 0 0] 
[ 1 1 0 ..., 0 0 0] 
[ 9 0 0 ..., 0 0 9]] 

However, I do not get a clear or readable plot. Is there a better way to do this?

Answers

13

@amillerrhodes gives a perfect answer in "How to plot confusion matrix with string axis rather than integer in python".

confusion matrix example

Here is the code that generates the image above:

import numpy as np 
import matplotlib.pyplot as plt 

conf_arr = [[33,2,0,0,0,0,0,0,0,1,3], 
      [3,31,0,0,0,0,0,0,0,0,0], 
      [0,4,41,0,0,0,0,0,0,0,1], 
      [0,1,0,30,0,6,0,0,0,0,1], 
      [0,0,0,0,38,10,0,0,0,0,0], 
      [0,0,0,3,1,39,0,0,0,0,4], 
      [0,2,2,0,4,1,31,0,0,0,2], 
      [0,1,0,0,0,0,0,36,0,2,0], 
      [0,0,0,0,0,0,1,5,37,5,1], 
      [3,0,0,0,0,0,0,0,0,39,0], 
      [0,0,0,0,0,0,0,0,0,0,38]] 

# Normalize each row so the colors show the fraction of each true class
norm_conf = []
for row in conf_arr:
    row_total = sum(row)
    norm_conf.append([float(value) / row_total for value in row])

fig = plt.figure() 
plt.clf() 
ax = fig.add_subplot(111) 
ax.set_aspect(1) 
res = ax.imshow(np.array(norm_conf), cmap=plt.cm.jet, 
       interpolation='nearest') 

# conf_arr is a plain list, so convert it to an array to read its shape
n_rows, n_cols = np.array(conf_arr).shape

# Write the raw count in the center of each cell
for row in range(n_rows):
    for col in range(n_cols):
        ax.annotate(str(conf_arr[row][col]), xy=(col, row),
                    horizontalalignment='center',
                    verticalalignment='center')

cb = fig.colorbar(res)
alphabet = 'ABCDEFGHIJKLMNOPQRSTUVWXYZ'
plt.xticks(range(n_cols), list(alphabet[:n_cols]))
plt.yticks(range(n_rows), list(alphabet[:n_rows]))
plt.savefig('confusion_matrix.png', format='png')

Hope that helps.
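The row normalization in the code above can also be done with NumPy in one step. The following is only a sketch: the helper name `normalize_rows` and the small example matrix are illustrative assumptions, not part of the original answer. It also leaves all-zero rows at 0 (a class with no true samples), a case the loop version would fail on with a division by zero.

```python
import numpy as np

def normalize_rows(cm):
    """Scale each row of a confusion matrix so it sums to 1 (hypothetical helper)."""
    cm = np.asarray(cm, dtype=float)
    row_totals = cm.sum(axis=1, keepdims=True)
    # A class with no true samples gives an all-zero row; keep it at 0
    # instead of dividing by zero.
    return np.divide(cm, row_totals, out=np.zeros_like(cm), where=row_totals != 0)

# Illustrative data only: the last class has no true samples.
example = [[33, 2, 0],
           [3, 31, 0],
           [0, 0, 0]]
norm = normalize_rows(example)
print(norm.round(2))
```

The normalized array can be passed to `ax.imshow` in place of `np.array(norm_conf)`, while the raw counts are still used for the cell annotations.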

43


You can use plt.matshow() instead of plt.imshow(), or you can use Seaborn's heatmap function to plot the confusion matrix:

import seaborn as sn 
import pandas as pd 
import matplotlib.pyplot as plt 
array = [[33,2,0,0,0,0,0,0,0,1,3], 
     [3,31,0,0,0,0,0,0,0,0,0], 
     [0,4,41,0,0,0,0,0,0,0,1], 
     [0,1,0,30,0,6,0,0,0,0,1], 
     [0,0,0,0,38,10,0,0,0,0,0], 
     [0,0,0,3,1,39,0,0,0,0,4], 
     [0,2,2,0,4,1,31,0,0,0,2], 
     [0,1,0,0,0,0,0,36,0,2,0], 
     [0,0,0,0,0,0,1,5,37,5,1], 
     [3,0,0,0,0,0,0,0,0,39,0], 
     [0,0,0,0,0,0,0,0,0,0,38]] 
df_cm = pd.DataFrame(array, index = [i for i in "ABCDEFGHIJK"], 
        columns = [i for i in "ABCDEFGHIJK"]) 
plt.figure(figsize = (10,7)) 
sn.heatmap(df_cm, annot=True)

14

@bninopaul's answer is not entirely suited to beginners.

Here is code you can simply copy and run:

import seaborn as sn 
import pandas as pd 
import matplotlib.pyplot as plt 

array = [[13,1,1,0,2,0], 
    [3,9,6,0,1,0], 
    [0,0,16,2,0,0], 
    [0,0,0,13,0,0], 
    [0,0,0,0,15,0], 
    [0,0,1,0,0,15]]   
df_cm = pd.DataFrame(array, range(6), range(6))
# plt.figure(figsize=(10, 7))
sn.set(font_scale=1.4)  # label size
sn.heatmap(df_cm, annot=True, annot_kws={"size": 16})  # annotation font size
plt.show()

result
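The examples above use small matrices. With 100 classes, as in the question, per-cell annotations become unreadable, and a few very large diagonal counts can wash out everything else on a linear color scale. One possible approach, sketched here under stated assumptions, is to drop the annotations, use `matshow`, and color the cells on a logarithmic scale. The function name `plot_large_confusion_matrix`, the synthetic data, and the output file name are made up for illustration; in practice `cm` would be the array returned by `metrics.confusion_matrix`.

```python
import numpy as np
import matplotlib
matplotlib.use("Agg")  # optional: non-interactive backend, lets the script run without a display
import matplotlib.pyplot as plt
from matplotlib.colors import LogNorm

def plot_large_confusion_matrix(cm, tick_step=10):
    """Plot a large confusion matrix without annotations (hypothetical helper)."""
    cm = np.asarray(cm, dtype=float)
    # Mask empty cells: a log scale cannot represent zero counts.
    masked = np.ma.masked_equal(cm, 0)
    fig, ax = plt.subplots(figsize=(10, 10))
    im = ax.matshow(masked, cmap="viridis",
                    norm=LogNorm(vmin=1, vmax=max(cm.max(), 2)))
    fig.colorbar(im, ax=ax, label="count (log scale)")
    # Label only every tick_step-th class to keep the axes readable.
    ticks = np.arange(0, cm.shape[0], tick_step)
    ax.set_xticks(ticks)
    ax.set_yticks(ticks)
    ax.set_xlabel("Predicted label")
    ax.set_ylabel("True label")
    return fig, im

# Synthetic 100-class matrix, for illustration only.
rng = np.random.default_rng(0)
n_classes = 100
cm = rng.poisson(0.05, size=(n_classes, n_classes))
cm = cm + np.diag(rng.integers(5, 4000, size=n_classes))

fig, im = plot_large_confusion_matrix(cm)
fig.savefig("confusion_matrix_large.png")
```

Masked cells are left blank, so the off-diagonal confusions stand out against an empty background. Whether a log scale or row normalization reads better depends on how unbalanced the classes are.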
