Qualcuno potrebbe chiarire se lo stato iniziale dell'RNN in TF viene ripristinato per i mini-lotti successivi, oppure viene utilizzato l'ultimo stato del mini-lotto precedente come indicato in Ilya Sutskever et al., ICLR 2015?È stato ripristinato lo stato iniziale RNN per i mini-lotti successivi?
risposta
Le operazioni tf.nn.dynamic_rnn()
o tf.nn.rnn()
consentono di specificare lo stato iniziale dell'RNN utilizzando il parametro initial_state
. Se non si specifica questo parametro, gli stati nascosti verranno inizializzati su zero vettori all'inizio di ciascun batch di addestramento.
In TensorFlow, è possibile avvolgere i tensori in tf.Variable()
per mantenere i loro valori nel grafico tra più sessioni di sessione. Assicurati di contrassegnarli come non addestrabili perché gli ottimizzatori regolano tutte le variabili trainabili per impostazione predefinita.
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell = tf.nn.rnn_cell.GRUCell(256)
state = tf.Variable(cell.zero_states(batch_size, tf.float32), trainable=False)
output, new_state = tf.nn.dynamic_rnn(cell, data, initial_state=state)
with tf.control_dependencies([state.assign(new_state)]):
output = tf.identity(output)
sess = tf.Session()
sess.run(tf.initialize_all_variables())
sess.run(output, {data: ...})
Non ho testato questo codice ma dovrebbe darvi un suggerimento nella giusta direzione. C'è anche un tf.nn.state_saving_rnn()
a cui puoi fornire un oggetto salvaschermo, ma non l'ho ancora usato.
Oltre alla risposta di danijar, ecco il codice per un LSTM, il cui stato è una tupla (state_is_tuple=True
). Supporta anche più livelli.
Definiamo due funzioni: una per ottenere le variabili di stato con uno stato iniziale zero e una funzione per restituire un'operazione, che possiamo passare a session.run
per aggiornare le variabili di stato con l'ultimo stato nascosto dell'LSTM. risposta
def get_state_variables(batch_size, cell):
# For each layer, get the initial state and make a variable out of it
# to enable updating its value.
state_variables = []
for state_c, state_h in cell.zero_state(batch_size, tf.float32):
state_variables.append(tf.contrib.rnn.LSTMStateTuple(
tf.Variable(state_c, trainable=False),
tf.Variable(state_h, trainable=False)))
# Return as a tuple, so that it can be fed to dynamic_rnn as an initial state
return tuple(state_variables)
def get_state_update_op(state_variables, new_states):
# Add an operation to update the train states with the last state tensors
update_ops = []
for state_variable, new_state in zip(state_variables, new_states):
# Assign the new state to the state variables on this layer
update_ops.extend([state_variable[0].assign(new_state[0]),
state_variable[1].assign(new_state[1])])
# Return a tuple in order to combine all update_ops into a single operation.
# The tuple's actual value should not be used.
return tf.tuple(update_ops)
simili di danijar, possiamo usare che per aggiornare lo stato del LSTM dopo ogni batch:
data = tf.placeholder(tf.float32, (batch_size, max_length, frame_size))
cell_layer = tf.contrib.rnn.GRUCell(256)
cell = tf.contrib.rnn.MultiRNNCell([cell_layer] * num_layers)
# For each layer, get the initial state. states will be a tuple of LSTMStateTuples.
states = get_state_variables(batch_size, cell)
# Unroll the LSTM
outputs, new_states = tf.nn.dynamic_rnn(cell, data, initial_state=states)
# Add an operation to update the train states with the last state tensors.
update_op = get_state_update_op(states, new_states)
sess = tf.Session()
sess.run(tf.global_variables_initializer())
sess.run([outputs, update_op], {data: ...})
La differenza principale è che state_is_tuple=True
rende lo stato del LSTM un LSTMStateTuple contenente due variabili (stato cellule e stato nascosto) invece di una singola variabile. L'utilizzo di più livelli rende quindi lo stato di LSTM una tupla di LSTMStateTuples, uno per livello.
- 1. come impostare lo stato iniziale in redux
- 2. Redux React crea lo stato iniziale dall'API
- 3. Quale stato viene perso/ripristinato dopo glUseProgram?
- 4. Esegui codice non appena il modulo è stato ripristinato
- 5. Come fa 'ottenere' effettivamente/ottenere/lo stato iniziale in Haskell?
- 6. Dove recuperare lo stato iniziale quando si usa flummox (flusso)?
- 7. Lo stato esistente dei pacchetti è stato scartato
- 8. Java. System.exit (stato int). Valore per lo stato di uscita
- 9. Perché lo stato ColorAnimation per lo stato selezionato non è persistente dopo l'attivazione dello stato di MouseOver?
- 10. Stato iniziale per un processo di flusso di dati
- 11. Qual è lo stato corrente di PocoCapsule?
- 12. lo spostamento interno non è stato corretto
- 13. Qual è lo stato di Spring.Net?
- 14. Qual è lo stato di CAT.NET?
- 15. Qual è lo stato di JMX 2.0?
- 16. Qual è lo stato di PHPDoc?
- 17. Lo stato di qualsiasi classe standard dopo lo spostamento è stato specificato?
- 18. Qual è lo stato della lingua Javascript?
- 19. Qual è lo stato corrente del biocode?
- 20. I controlli di Youtube (riproduzione/pausa) non funzionano dopo che il lettore è stato ripristinato/l'attività è ripresa
- 21. Qual è il modo migliore, usando lo schema di progettazione "Stato", per cambiare stato?
- 22. Come si pulisce lo stato di Redux?
- 23. Controllare lo stato di AutoResetEvent
- 24. : Ottieni lo stato del provider
- 25. Distribuito lo stato dell'arte Haskell nel 2011?
- 26. Rcpp imposta lo stato RNG su uno stato precedente
- 27. Come resettare lo stato dell'animazione?
- 28. React + Flux: acquisizione dello stato iniziale in un negozio
- 29. Stato ordine Magento Stato vs Stato
- 30. Lo stato di GMaps v3
Nota come lo fai tu crei num_layers _identical_ cells che non è quello che vuoi fare probabilmente –