2015-12-02 23 views
11

Sto cercando di eseguire il debug di un'architettura NN un po 'complicata e non canonica. Calcolare l'inoltro in avanti va bene e mi dà i risultati attesi, ma quando provo a ottimizzare l'utilizzo di Adam o di uno qualsiasi degli ottimizzatori standard, anche dopo un'iterazione con un tasso di apprendimento molto piccolo ottengo nans ovunque. Sto cercando di localizzarli e mi stavo chiedendo se c'è un modo per catturare la prima occorrenza di una nan e rilevare in quale op è sorta? Ho provato tf.add_check_numerics_ops() ma sembra che non stia facendo nulla, o forse lo sto usando in modo errato.Debugging nans nel pass backward

risposta

18

Il debug dei NaN può essere complicato, soprattutto se si dispone di una rete di grandi dimensioni. tf.add_check_numerics_ops() aggiunge operazioni al grafico che asseriscono che ciascun tensore a virgola mobile nel grafico non contiene alcun valore NaN, ma non esegue questi controlli per impostazione predefinita. Invece si restituisce un op che è possibile eseguire periodicamente, o ad ogni passo, come segue:

train_op = ... 
check_op = tf.add_check_numerics_ops() 

sess = tf.Session() 
sess.run([train_op, check_op]) # Runs training and checks for NaNs 
+0

Il problema è che una volta eseguito il treno_op, i nans si propagano attraverso la rete e quindi è inutile trovare la causa. Quello che mi piacerebbe fare è eseguire i passaggi avanti e indietro, e non appena viene generata una nan, viene lanciata un'eccezione dall'opzione offendente. –

+6

Se esegui 'train_op' e' check_op' insieme, dovresti ricevere un errore che riporta il primo nodo che ha un NaN - puoi catturare il 'tf.InvalidArgumentError' che viene sollevato ed estrarre l'op dal suo'. op' proprietà. Con un handle per l'op, è possibile accedere alla sua proprietà 'op.inputs [0]' per vedere quale tensore aveva valori NaN. – mrry

+0

Ok grazie questo lo farà! –

2

Forse si potrebbe aggiungere ops stampa a sospettare valori ops di stampa, qualcosa di simile

print_ops = [] 
for op in ops: 
    print_ops.append(tf.Print(op, [op], 
        message='%s :' % op.name, summarize=10)) 
print_op = tf.group(*print_ops) 
sess.run([train_op, print_op]) 

Per aggiungere a tutte le operazioni, è possibile fare un ciclo lungo le linee di add_check_numerics_ops.