2016-01-14 12 views
6

Oggi ho provato a tracciare la matrice di confusione dal mio modello di classificazione.Matstotlib matshow con molte stringhe

Dopo la ricerca in alcune pagine, ho trovato che matshow da pyplot può aiutarmi.

import matplotlib.pyplot as plt 
from sklearn.metrics import confusion_matrix 

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None): 
    fig = plt.figure() 
    ax = fig.add_subplot(111) 
    cax = ax.matshow(cm) 
    plt.title(title) 
    fig.colorbar(cax) 
    if labels: 
     ax.set_xticklabels([''] + labels) 
     ax.set_yticklabels([''] + labels) 
    plt.xlabel('Predicted') 
    plt.ylabel('True') 
    plt.show() 

Funziona bene se ho alcune etichette

y_true = ['a', 'b', 'c', 'd', 'a', 'b', 'c', 'a', 'c', 'd', 'b', 'a', 'b', 'a'] 
y_pred = ['a', 'b', 'c', 'd', 'a', 'b', 'b', 'a', 'c', 'a', 'a', 'a', 'a', 'a'] 
labels = list(set(y_true)) 
cm = confusion_matrix(y_true, y_pred) 
plot_confusion_matrix(cm, labels=labels) 

enter image description here

Ma se ho molte etichette, alcune etichette non mostrano correttamente

y_true = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'] 
y_pred = ['a', 'b', 'c', 'd', 'e', 'f', 'g', 'h', 'i', 'j', 'k', 'l', 'm', 'n'] 
labels = list(set(y_true)) 
cm = confusion_matrix(y_true, y_pred) 
plot_confusion_matrix(cm, labels=labels) 

enter image description here

La mia domanda è: come posso visualizzare TUTTE le etichette nella trama di matshow? Ho provato qualcosa come fontdict ma ancora non funziona

risposta

3

È possibile controllare la frequenza delle zecche utilizzando il modulo matplotlib.ticker.

In questo caso, si desidera impostare un segno di spunta ogni multiplo di 1, in modo che possiamo utilizzare un MultipleLocator

Aggiungere queste due righe prima di chiamare plt.show():

ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 
ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 

e produrrà un barrare e contrassegnare per ogni lettera nel numero y_true e y_pred.

Ho cambiato anche il tuo matshow chiamata a fare uso della mappa di colori si specifica nella chiamata di funzione:

cax = ax.matshow(cm,cmap=cmap) 

enter image description here

Per completezza, la vostra intera funzione sarà simile a questa:

import matplotlib.pyplot as plt 
from sklearn.metrics import confusion_matrix 
import matplotlib.ticker as ticker 

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None): 
    fig = plt.figure() 
    ax = fig.add_subplot(111) 

    # I also added cmap=cmap here, to make use of the 
    # colormap you specify in the function call 
    cax = ax.matshow(cm,cmap=cmap) 
    plt.title(title) 
    fig.colorbar(cax) 
    if labels: 
     ax.set_xticklabels([''] + labels) 
     ax.set_yticklabels([''] + labels) 

    ax.xaxis.set_major_locator(ticker.MultipleLocator(1)) 
    ax.yaxis.set_major_locator(ticker.MultipleLocator(1)) 

    plt.xlabel('Predicted') 
    plt.ylabel('True') 
    plt.savefig('confusionmatrix.png') 
+0

ha funzionato. grazie Tom Un'altra domanda, sai come trarre valore dalla matrice di confusione nella trama di matshow? –

+0

probabilmente è meglio fare una nuova domanda con maggiori dettagli di quello che vuoi – tom

4

È possibile utilizzare il metodo xticks per specificare le etichette. La tua funzione sarà simile a questa (modificando la funzione dalla risposta sopra):

import matplotlib.pyplot as plt 
from sklearn.metrics import confusion_matrix 

def plot_confusion_matrix(cm, title='Confusion matrix', cmap=plt.cm.Blues, labels=None): 
    fig = plt.figure() 
    ax = fig.add_subplot(111) 

    # I also added cmap=cmap here, to make use of the 
    # colormap you specify in the function call 
    cax = ax.matshow(cm,cmap=cmap) 
    plt.title(title) 
    fig.colorbar(cax) 
    if labels: 
     plt.xticks(range(len(labels)), labels) 
     plt.yticks(range(len(labels)), labels) 

    plt.xlabel('Predicted') 
    plt.ylabel('True') 
    plt.savefig('confusionmatrix.png') 
Problemi correlati