2016-01-26 13 views
16

Ho una configurazione in cui ho bisogno di inizializzare un LSTM dopo l'inizializzazione principale che utilizza tf.initialize_all_variables(). Cioè Voglio chiamare tf.initialize_variables([var_list])Tensorflow: come ottenere tutte le variabili da rnn_cell.BasicLSTM & rnn_cell.MultiRNNCell

Is tutto per raccogliere tutte le variabili interne addestrabili per entrambi:

  • rnn_cell.BasicLSTM
  • rnn_cell.MultiRNNCell

modo che io possa inizializzare SOLO questi parametri?

Il motivo principale per cui voglio questo è perché non voglio re-inizializzare alcuni valori addestrati di prima.

risposta

17

Il modo più semplice per risolvere il problema è utilizzare l'ambito variabile. I nomi delle variabili all'interno di uno scope saranno preceduti dal suo nome. Ecco un breve frammento:

cell = rnn_cell.BasicLSTMCell(num_nodes) 

with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    # Retrieve just the LSTM variables. 
    lstm_variables = [v for v in tf.all_variables() 
        if v.name.startswith(vs.name)] 

# [..] 
# Initialize the LSTM variables. 
tf.initialize_variables(lstm_variables) 

Funzionerebbe allo stesso modo con MultiRNNCell.

EDIT: cambiato tf.trainable_variables per tf.all_variables()

+0

Questo è perfetto, grazie. Non ho capito che 'tf.trainable_variables()' rispetta lo scope, ma immagino che con il senno di poi abbia senso! – bge0

+1

Vorrei aggiungere che 'tf.all_variables()' invece di 'tf.trainable_variables()' sarebbe una scelta migliore. Principalmente perché ci sono cose come gli ottimizzatori che non hanno variabili addestrabili, che tuttavia dovrebbero comunque essere inizializzate. – bge0

+1

Grazie, hai ragione. Ho aggiornato il codice. –

11

È inoltre possibile utilizzare tf.get_collection():

cell = rnn_cell.BasicLSTMCell(num_nodes) 
with tf.variable_scope("LSTM") as vs: 
    # Execute the LSTM cell here in any way, for example: 
    for i in range(num_steps): 
    output[i], state = cell(input_data[i], state) 

    lstm_variables = tf.get_collection(tf.GraphKeys.VARIABLES, scope=vs.name) 

(in parte copiato da risposta di Rafal)

Si noti che l'ultima riga è equivalente alla comprensione lista in Il codice di Rafal

Fondamentalmente, tensorflow memorizza una raccolta globale di variabili, che può essere recuperata da tf.all_variables() o da tf.get_collection(tf.GraphKeys.VARIABLES). Se si specifica scope (nome ambito) nella funzione tf.get_collection(), si recuperano solo tensori (variabili in questo caso) nella raccolta i cui ambiti sono nell'ambito specificato.

MODIFICA: È anche possibile utilizzare tf.GraphKeys.TRAINABLE_VARIABLES per ottenere solo variabili trainabili. Ma poiché vanilla BasicLSTMCell non inizializza alcuna variabile non addestrabile, entrambi saranno funzionalmente equivalenti. Per un elenco completo delle raccolte di grafici predefinite, selezionare this out.

+0

Questo è il modo migliore della soluzione di Rafal :-) –

+1

Proprio come ho commentato sopra, forse dovresti usare meglio ' tf.get_collection (..., scope = vs.name + "/") 'perché potrebbe esserci un altro ambito denominato" LSTM2 "o così. – Albert

Problemi correlati