2016-03-07 8 views
5

Sto addestrando una cella LSTM su lotti di sequenze che hanno lunghezze diverse. Il tf.nn.rnn ha il parametro molto conveniente sequence_length, ma dopo averlo chiamato, non so come selezionare le righe di output corrispondenti all'ultimo passaggio temporale di ciascun elemento nel batch.Come scegliere gli ultimi valori di uscita validi da tensorflow RNN

Il mio codice è sostanzialmente la seguente:

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size) 
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths) 

lstm_outputs è una lista con l'uscita LSTM ad ogni passo temporale. Tuttavia, ogni elemento del mio batch ha una lunghezza diversa, quindi vorrei creare un tensore contenente l'ultima uscita LSTM valida per ciascun articolo nel mio batch.

Se potessi utilizzare l'indicizzazione NumPy, vorrei solo fare qualcosa di simile:

all_outputs = tf.pack(lstm_outputs) 
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :] 

ma si scopre che per il momento iniziare tensorflow non lo supporta (io sono a conoscenza del feature request) .

Quindi, come posso ottenere questi valori?

risposta

5

Una soluzione più accettabile è stata pubblicata da danijar sulla pagina di richiesta di funzionalità che ho collegato nella domanda. Non ha bisogno di valutare i tensori, che è un grande vantaggio.

Ho capito che funziona con tensorflow 0.8. Ecco il codice:

def extract_last_relevant(outputs, length): 
    """ 
    Args: 
     outputs: [Tensor(batch_size, output_neurons)]: A list containing the output 
      activations of each in the batch for each time step as returned by 
      tensorflow.models.rnn.rnn. 
     length: Tensor(batch_size): The used sequence length of each example in the 
      batch with all later time steps being zeros. Should be of type tf.int32. 

    Returns: 
     Tensor(batch_size, output_neurons): The last relevant output activation for 
      each example in the batch. 
    """ 
    output = tf.transpose(tf.pack(outputs), perm=[1, 0, 2]) 
    # Query shape. 
    batch_size = tf.shape(output)[0] 
    max_length = int(output.get_shape()[1]) 
    num_neurons = int(output.get_shape()[2]) 
    # Index into flattened array as a workaround. 
    index = tf.range(0, batch_size) * max_length + (length - 1) 
    flat = tf.reshape(output, [-1, num_neurons]) 
    relevant = tf.gather(flat, index) 
    return relevant 
2

Non è la soluzione migliore ma è possibile valutare le uscite, quindi utilizzare solo l'indicizzazione numpy per ottenere i risultati e creare una variabile tensoriale? Potrebbe funzionare come un intervallo di interruzione fino a quando tensorflow ottiene questa caratteristica. per esempio.

all_outputs = session.run(lstm_outputs, feed_dict={'your inputs'}) 
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :] 
use_this_as_an_input_to_new_tensorflow_op = tf.constant(last_outputs) 
+0

Sì, sarebbe sicuramente non essere la soluzione più bella. Ma al momento non riesco a vedere in nessun altro modo. – erickrf

+1

Ci sono modi migliori ora? – Zhao

1

se siete interessati solo l'ultima uscita valida è possibile recuperare attraverso lo stato restituito da tf.nn.rnn() considerando che è sempre una tupla (c, h) dove c è l'ultimo stato e h è l'ultimo risultato. Quando lo stato è un LSTMStateTuple è possibile utilizzare il seguente frammento (che lavorano in tensorflow 0,12):

lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size) 
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths) 
last_output = state[1] 
+0

Questo è molto più semplice della risposta accettata a partire da tensorflow 0.12 e successive – erickrf

Problemi correlati