2015-11-14 31 views
54

Ho seguito le esercitazioni di mnist fornite ed è stato in grado di addestrare un modello e valutarne l'accuratezza. Tuttavia, le esercitazioni non mostrano come fare previsioni date un modello. Non mi interessa la precisione, voglio solo usare il modello per predire un nuovo esempio e nell'output vedere tutti i risultati (etichette), ciascuno con il punteggio assegnato (ordinato o meno).Creazione di previsioni con un modello TensorFlow

+0

Ho creato un archivio in cui è possibile disegnare numeri e testare il modello con i propri dati. https://github.com/EddieOne/mnist-live-test Non viene fornito con le istruzioni. Ho comunque realizzato un video con una panoramica di alto livello. https://www.youtube.com/watch?v=pudJU-cDkMo – Eddie

risposta

53

Nella "Deep MNIST for Experts" esempio, vedere questa linea:

ora possiamo implementare il nostro modello di regressione. Prende solo una linea! Noi moltiplichiamo le immagini di input vettorizzate x per la matrice di peso W, aggiungiamo la polarizzazione b e calcoliamo le probabilità di softmax assegnate a ogni classe.

y = tf.nn.softmax(tf.matmul(x,W) + b) 

Basta tirare sul nodo y e sarete avere quello che vuoi.

feed_dict = {x: [your_image]} 
classification = tf.run(y, feed_dict) 
print classification 

questo vale per quasi ogni modello si crea - avrai calcolato le probabilità di previsione come uno degli ultimi passi prima di calcolare la perdita.

+1

Quando si verifica questo suggerimento sull'esempio convnet (con 'y_conv = tf.nn.softmax (tf.matmul (h_fc1_drop, W_fc2) + b_fc2)' I get 'Argomento non valido: è necessario alimentare un valore per il tensore segnaposto 'Placeholder_2' con dtype float', per l'esempio semplice di softmax funziona bene.Qualsiasi idea perché sia ​​così? –

+3

Posso rispondere al mio commento: l'esempio di convnet ha un variabile addizionale in feed_dict, mi sono perso per aggiungere quello. In questo caso il feed_dict dovrebbe apparire così: 'feed_dict = {x: [your_image], keep_prob: 1.0}' –

+0

L'output del tuo codice sarà qualcosa di simile a [False Vero Falso ..., Vero Falso Vero], ma voglio convertirlo in [3 1 3 ..., 1 5 1], quali etichette di classe errate invece di False. Come possiamo ottenere quell'etichetta che è errata classificato invece di falso? –

12

Come suggerito da @dga, è necessario eseguire la nuova istanza dei dati tramite il modello già previsto.

Ecco un esempio:

presuppongono che è andato anche se il primo tutorial e calcolato l'accuratezza del modello (il modello è questo: y = tf.nn.softmax(tf.matmul(x, W) + b)). Ora prendi il tuo modello e applica il nuovo punto dati ad esso. Nel seguente codice, calcolo il vettore, ottenendo la posizione del valore massimo. Mostra l'immagine e stampa quella posizione massima.

from matplotlib import pyplot as plt 
from random import randint 
num = randint(0, mnist.test.images.shape[0]) 
img = mnist.test.images[num] 

classification = sess.run(tf.argmax(y, 1), feed_dict={x: [img]}) 
plt.imshow(img.reshape(28, 28), cmap=plt.cm.binary) 
plt.show() 
print 'NN predicted', classification[0] 
Problemi correlati