Ho un tensore logits
con le dimensioni [batch_size, num_rows, num_coordinates]
(vale a dire che ciascun logit nel batch è una matrice). Nel mio caso la dimensione del lotto è 2, ci sono 4 righe e 4 coordinate.Come selezionare le righe da un tensore 3D in TensorFlow?
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0],
[12.0, 10.0, 10.0, 20.0],
[13.0, 10.0, 10.0, 20.0]],
[[14.0, 11.0, 21.0, 31.0],
[15.0, 11.0, 11.0, 21.0],
[16.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
Desidero selezionare la prima e la seconda riga del primo lotto e la seconda e la quarta riga del secondo lotto.
indices = tf.constant([[0, 1], [1, 3]])
Così l'output desiderato sarebbe
logits = tf.constant([[[10.0, 10.0, 20.0, 20.0],
[11.0, 10.0, 10.0, 30.0]],
[[15.0, 11.0, 11.0, 21.0],
[17.0, 11.0, 11.0, 21.0]]])
Come faccio a fare questo usando tensorflow? Ho provato a utilizzare tf.gather(logits, indices)
ma non ha restituito quello che mi aspettavo. Grazie!
Mentre la tua risposta è ottima, penso che oggi possa essere sostituita con 'tf.gather_nd', che probabilmente non era ancora disponibile al momento della scrittura (vedi la mia risposta) – kaufmanu