2015-11-13 34 views
8

Come posso scegliere di eseguire una parte del grafico in base a una condizione?Esecuzione condizionale in TensorFlow

Ho una parte della mia rete che deve essere eseguita solo se viene fornito un valore di segnaposto in feed_dict. Viene eseguito un percorso alternativo se il valore non viene fornito. Come faccio a implementare questo utilizzando tensorflow?

Qui ci sono le porzioni rilevanti del mio codice:

sess.run(accuracy, feed_dict={inputs: mnist.test.images, outputs: mnist.test.labels}) 

N = tf.shape(outputs) 
    cost = 0 
    if N > 0: 
     y_N = tf.slice(h_c, [0, 0], N) 
     cross_entropy = tf.nn.softmax_cross_entropy_with_logits(y_N, outputs, name='xentropy') 
     cost = tf.reduce_mean(cross_entropy, name='xentropy_mean') 

Nel codice di cui sopra, sto cercando qualcosa da usare al posto di if N > 0:

risposta

7

Hrm. E 'possibile che ciò che si vuole è tf.control_flow_ops.cond() https://github.com/tensorflow/tensorflow/blob/master/tensorflow/python/ops/control_flow_ops.py#L597

Ma non è esportato nel namespace tf, e sto rispondendo senza controllare come garantito stabile questa interfaccia è, ma è utilizzato in modelli rilasciati , quindi vai per questo. :)

Tuttavia: poiché in realtà sai in anticipo quale percorso desideri quando costruisci il feed_dict, potresti anche adottare un approccio diverso per invocare un percorso separato attraverso il tuo modello. Il metodo standard per farlo è quello di, per esempio, impostare il codice come:

def model(input, n_greater_than): 
    ... cleverness ... 
    if n_greater_than: 
    ... other cleverness... 
    return tf.reduce_mean(input) 


out1 = model(input, True) 
out2 = model(input, False) 

E poi tirare l'out1 o nodi OUT2 seconda di ciò che si sa quando si sta per eseguire il calcolo e impostare la feed_dict. Ricorda che per impostazione predefinita, se il modello fa riferimento alle stesse variabili (creale all'esterno di, il modello() func), in pratica avrai due percorsi separati.

potete vedere un esempio di questo nell'esempio mnist convoluzionale: https://github.com/tensorflow/tensorflow/blob/master/tensorflow/models/image/mnist/convolutional.py#L165

io sono un fan di farlo in questo modo, senza introdurre le dipendenze del flusso di controllo, se potete.

0

Ecco un semplice esempio, che può iniziare. Esegue diverse parti del grafico in base alla forma del tensore:

import tensorflow as tf 

a = tf.Variable([[3.0, 3.0], [3.0, 3.0]]) 
b = tf.Variable([[1.0, 1.0], [2.0, 2.0]]) 
l = tf.shape(a) 

add_op, sub_op = tf.add(a, b), tf.sub(a, b) 

sess = tf.Session() 
init = tf.initialize_all_variables() 
sess.run(init) 
t = sess.run(l) 

print sess.run(sub_op if t[0] == 3 else add_op) 

sess.close() 

Change 3 a 2 per vedere come sarà detratta tensore. Come vedi, ho avviato i nodi per add e sub e shape, quindi nel grafico controllo la forma ed eseguo la parte specifica.