2016-02-02 20 views
7

Stavo cercando di creare un grafico autoincrementante in TensorFlow. Ho pensato che l'op assign potrebbe essere adatto a questo, ma non ho trovato documentazione per questo.Assegna op in TensorFlow: qual è il valore di ritorno?

ho pensato che questo op restituisce il suo valore — come nei linguaggi C-like — e ha scritto il seguente codice:

import tensorflow as tf 

counter = tf.Variable(0, name="counter") 

one = tf.constant(1) 
ten = tf.constant(10) 

new_counter = tf.add(counter, one) 
assign = tf.assign(counter, new_counter) 
result = tf.add(assign, ten) 

init_op = tf.initialize_all_variables() 

with tf.Session() as sess: 

    sess.run(init_op) 

    for _ in range(3): 

    print sess.run(result) 

e questo codice funziona.

La domanda è: è questo il comportamento previsto? Perché l'op di assegnazione non è documentata qui: https://www.tensorflow.org/versions/0.6.0/api_docs/index.html

È un'opzione non consigliata?

risposta

13

L'operatore tf.assign() è il meccanismo sottostante che implementa il metodo Variable.assign(). È necessario un tensore mutabile (con tipo tf.*_ref) e un nuovo valore, e restituisce un tensore mutabile che è stato aggiornato con il nuovo valore. Il valore restituito viene fornito per semplificare l'ordine di un compito prima di una lettura successiva, ma questa funzione non è ben documentata. Un esempio si spera illustrare:

v = tf.Variable(0) 
new_v = v.assign(10) 
output = v + 5 # `v` is evaluated before or after the assignment. 

sess.run(v.initializer) 

result, _ = sess.run([output, new_v.op]) 
print result # ==> 10 or 15, depending on the order of execution. 

v = tf.Variable(0) 
new_v = v.assign(10) 
output = new_v + 5 # `new_v` is evaluated after the assignment. 

sess.run(v.initializer) 

result = sess.run([output]) 
print result # ==> 15 

Nel esempio di codice le dipendenze dataflow far rispettare l'ordine di esecuzione [read counter] -> new_counter = tf.add(...) -> tf.assign(...) -> [read output of assign] -> result = tf.add(...), il che significa che la semantica è ambigua. Tuttavia,, i passaggi di lettura-modifica-scrittura per aggiornare il contatore sono alquanto inefficienti e possono comportare un comportamento imprevisto quando sono in esecuzione più passaggi contemporaneamente. Ad esempio, più thread che accedono alla stessa variabile potrebbero osservare il contatore che si sposta all'indietro (nel caso in cui un valore precedente sia stato riscritto dopo un nuovo valore).

mi consiglia di utilizzare Variable.assign_add() per aggiornare il contatore, come segue:

counter = tf.Variable(0, name="counter") 

one = tf.constant(1) 
ten = tf.constant(10) 

# assign_add ensures that the counter always moves forward. 
updated_counter = counter.assign_add(one, use_locking=True) 

result = tf.add(updated_counter, ten) 
# ... 
+3

Eseguiamo il primo snippet di codice e abbiamo solo osservato che l'output può essere solo 5 o 15. –

4

tf.assign() è ben documented in the latest versions ed è usato frequentemente nei progetti.

Questa operazione restituisce "ref" dopo l'assegnazione. Questo rende più semplice per le operazioni a catena che devono utilizzare il valore di reset.

In parole semplici, ci vuole il tensore originale e un nuovo tensore. Aggiorna il valore originale del tuo tensore con un nuovo valore e restituisce il riferimento del tuo tensore originale.

Problemi correlati