Sto addestrando una cella LSTM su lotti di sequenze che hanno lunghezze diverse. Il tf.nn.rnn
ha il parametro molto conveniente sequence_length
, ma dopo averlo chiamato, non so come selezionare le righe di output corrispondenti all'ultimo passaggio temporale di ciascun elemento nel batch.Come scegliere gli ultimi valori di uscita validi da tensorflow RNN
Il mio codice è sostanzialmente la seguente:
lstm_cell = tf.nn.rnn_cell.LSTMCell(num_lstm_units, input_size)
lstm_outputs, state = tf.nn.rnn(lstm_cell, input_list, dtype=tf.float32, sequence_length=sequence_lengths)
lstm_outputs
è una lista con l'uscita LSTM ad ogni passo temporale. Tuttavia, ogni elemento del mio batch ha una lunghezza diversa, quindi vorrei creare un tensore contenente l'ultima uscita LSTM valida per ciascun articolo nel mio batch.
Se potessi utilizzare l'indicizzazione NumPy, vorrei solo fare qualcosa di simile:
all_outputs = tf.pack(lstm_outputs)
last_outputs = all_outputs[sequence_lengths, tf.range(batch_size), :]
ma si scopre che per il momento iniziare tensorflow non lo supporta (io sono a conoscenza del feature request) .
Quindi, come posso ottenere questi valori?
Sì, sarebbe sicuramente non essere la soluzione più bella. Ma al momento non riesco a vedere in nessun altro modo. – erickrf
Ci sono modi migliori ora? – Zhao