2016-05-06 4 views
11

Ho bisogno di un flusso di controllo condizionale nel mio grafico. Se pred è True, il grafico dovrebbe chiamare un op che aggiorna una variabile e quindi la restituisce, altrimenti restituisce la variabile invariata. Una versione semplificata è:Confuso dal comportamento di `tf.cond`

pred = tf.constant(True) 
x = tf.Variable([1]) 
assign_x_2 = tf.assign(x, [2]) 
def update_x_2(): 
    with tf.control_dependencies([assign_x_2]): 
    return tf.identity(x) 
y = tf.cond(pred, update_x_2, lambda: tf.identity(x)) 
with tf.Session() as session: 
    session.run(tf.initialize_all_variables()) 
    print(y.eval()) 

Tuttavia, ritengo che sia pred=True e pred=False conducono allo stesso risultato y=[2], che significa che l'op assegnazione viene chiamato anche quando update_x_2 non è selezionata per tf.cond. Come spiegarlo? E come risolvere questo problema?

risposta

17

TL; DR: Se si desidera tf.cond() eseguire un effetto collaterale (come una cessione) in uno dei rami, è necessario creare l'op che esegue l'effetto collaterale all'interno la funzione che si passa a tf.cond() .

Il comportamento di tf.cond() è un po 'non intuitivo. Poiché l'esecuzione in un grafico TensorFlow scorre in avanti nel grafico, tutte le operazioni a cui si fa riferimento nel ramo devono essere eseguite prima che il condizionale venga valutato. Ciò significa che sia il ramo vero che quello falso ricevono una dipendenza di controllo dall'opzione tf.assign(), e quindi , anche se pred is False.

La soluzione è creare l'op tf.assign() all'interno della funzione che definisce il ramo vero. Ad esempio, è possibile strutturare il codice come segue:

pred = tf.placeholder(tf.bool, shape=[]) 
x = tf.Variable([1]) 
def update_x_2(): 
    with tf.control_dependencies([tf.assign(x, [2])]): 
    return tf.identity(x) 
y = tf.cond(pred, update_x_2, lambda: tf.identity(x)) 
with tf.Session() as session: 
    session.run(tf.initialize_all_variables()) 
    print(y.eval(feed_dict={pred: False})) # ==> [1] 
    print(y.eval(feed_dict={pred: True})) # ==> [2] 
+0

Sì, questo è quello che mi confonde anche. Capisco che prima di eseguire 'tf.cond', il runtime si assicura che tutte le dipendenze siano eseguite. Le dipendenze delle operazioni in 'True' e' False' sono anche dipendenze da 'cond', quindi anche se un op in un ramo non può mai essere eseguito, tutte le dipendenze sono eseguite, non è corretto? –

+1

Sì, l'eliminazione del grafico considera tutte le potenziali dipendenze (di uno dei due rami) per l'esecuzione e inibisce solo la loro esecuzione se fossero definite all'interno di uno dei rami, perché "CondContext" [aggiunge una dipendenza di controllo sul pivot] (https: //github.com/tensorflow/tensorflow/blob/2b2f312cb07765c628d264abe326bfc286f462c1/tensorflow/python/ops/control_flow_ops.py#L1092) e tale dipendenza sarà un tensore morto (impedendo l'esecuzione dell'opzione) se non si trova nel ramo. – mrry

+0

Qual è stato il ragionamento in questo modo? Perché non potare il sottografo dietro il ramo non attivo? –