2016-03-06 40 views
26

Diciamo che sono seguente codice:Come aggiungere se la condizione in un grafico TensorFlow?

x = tf.placeholder("float32", shape=[None, ins_size**2*3], name = "x_input") 
condition = tf.placeholder("int32", shape=[1, 1], name = "condition") 
W = tf.Variable(tf.zeros([ins_size**2*3,label_option]), name = "weights") 
b = tf.Variable(tf.zeros([label_option]), name = "bias") 

if condition > 0: 
    y = tf.nn.softmax(tf.matmul(x, W) + b) 
else: 
    y = tf.nn.softmax(tf.matmul(x, W) - b) 

Sarebbe il lavoro if dichiarazione nel calcolo (io non la penso così)? In caso contrario, come è possibile aggiungere un'istruzione if nel grafico di calcolo TensorFlow?

risposta

51

È corretto che l'istruzione if non funzioni qui, poiché la condizione viene valutata al momento della costruzione del grafico, mentre presumibilmente si desidera che la condizione dipenda dal valore inviato al segnaposto in fase di esecuzione. (In realtà, sarà sempre prendere il primo ramo, perché condition > 0 restituisce un Tensor, che è "truthy" in Python.)

Per supportare flusso controllo condizionale, tensorflow fornisce all'operatore tf.cond(), che valuta uno dei due rami, a seconda di una condizione booleana. Per mostrare come usarlo, io riscrivere il programma in modo che condition è un valore scalare tf.int32 per semplicità:

x = tf.placeholder(tf.float32, shape=[None, ins_size**2*3], name="x_input") 
condition = tf.placeholder(tf.int32, shape=[], name="condition") 
W = tf.Variable(tf.zeros([ins_size**2 * 3, label_option]), name="weights") 
b = tf.Variable(tf.zeros([label_option]), name="bias") 

y = tf.cond(condition > 0, lambda: tf.matmul(x, W) + b, lambda: tf.matmul(x, W) - b) 
+1

Grazie mille per la spiegazione in dettaglio! –

+1

@mrry Entrambi i rami sono eseguiti di default? Ho tf.cond (c, lambda x: train_op1, lambda x: train_op2) ed entrambi train_ops sono eseguiti ad ogni esecuzione di cond indipendentemente dal valore di c. Sto facendo qualcosa di sbagliato? –

+5

@PiotrDabkowski Questo è un comportamento a volte sorprendente di 'tf.cond()', che viene toccato [nei documenti] (https://www.tensorflow.org/api_docs/python/tf/cond). In breve, è necessario creare gli ops che si desidera eseguire condizionatamente * all'interno * dei rispettivi lambda. Tutto ciò che crei al di fuori dei lambda, ma si riferisce a entrambi i rami, verrà eseguito in entrambi i casi. – mrry

Problemi correlati