2016-03-18 60 views
6

Ho un tensore logits con le dimensioni [batch_size, num_rows, num_coordinates] (vale a dire che ciascun logit nel batch è una matrice). Nel mio caso la dimensione del lotto è 2, ci sono 4 righe e 4 coordinate.Come selezionare le righe da un tensore 3D in TensorFlow?

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0], 
         [12.0, 10.0, 10.0, 20.0], 
         [13.0, 10.0, 10.0, 20.0]], 
        [[14.0, 11.0, 21.0, 31.0], 
         [15.0, 11.0, 11.0, 21.0], 
         [16.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

Desidero selezionare la prima e la seconda riga del primo lotto e la seconda e la quarta riga del secondo lotto.

indices = tf.constant([[0, 1], [1, 3]]) 

Così l'output desiderato sarebbe

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0]], 
        [[15.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

Come faccio a fare questo usando tensorflow? Ho provato a utilizzare tf.gather(logits, indices) ma non ha restituito quello che mi aspettavo. Grazie!

risposta

7

Questo è possibile in TensorFlow, ma leggermente scomodo, poiché tf.gather() attualmente funziona solo con indici unidimensionali e seleziona solo sezioni dalla dimensione 0 di un tensore. Tuttavia, è ancora possibile per risolvere il problema in modo efficace, trasformando gli argomenti in modo che possano essere passati al tf.gather():

logits = ... # [2 x 4 x 4] tensor 
indices = tf.constant([[0, 1], [1, 3]]) 

# Use tf.shape() to make this work with dynamic shapes. 
batch_size = tf.shape(logits)[0] 
rows_per_batch = tf.shape(logits)[1] 
indices_per_batch = tf.shape(indices)[1] 

# Offset to add to each row in indices. We use `tf.expand_dims()` to make 
# this broadcast appropriately. 
offset = tf.expand_dims(tf.range(0, batch_size) * rows_per_batch, 1) 

# Convert indices and logits into appropriate form for `tf.gather()`. 
flattened_indices = tf.reshape(indices + offset, [-1]) 
flattened_logits = tf.reshape(logits, tf.concat(0, [[-1], tf.shape(logits)[2:]])) 

selected_rows = tf.gather(flattened_logits, flattened_indices) 

result = tf.reshape(selected_rows, 
        tf.concat(0, [tf.pack([batch_size, indices_per_batch]), 
            tf.shape(logits)[2:]])) 

Nota che, dal momento che questo utilizza tf.reshape() e non tf.transpose(), non ha bisogno di modificare i (potenzialmente grandi) dati nel tensore logits, quindi dovrebbe essere abbastanza efficiente. risposta

+0

Mentre la tua risposta è ottima, penso che oggi possa essere sostituita con 'tf.gather_nd', che probabilmente non era ancora disponibile al momento della scrittura (vedi la mia risposta) – kaufmanu

4

di mrry è grande, ma credo che con la funzione tf.gather_nd il problema può essere risolto con molti meno righe di codice (probabilmente questa funzione non era ancora disponibile al momento della scrittura di mrry):

logits = tf.constant([[[10.0, 10.0, 20.0, 20.0], 
         [11.0, 10.0, 10.0, 30.0], 
         [12.0, 10.0, 10.0, 20.0], 
         [13.0, 10.0, 10.0, 20.0]], 
        [[14.0, 11.0, 21.0, 31.0], 
         [15.0, 11.0, 11.0, 21.0], 
         [16.0, 11.0, 11.0, 21.0], 
         [17.0, 11.0, 11.0, 21.0]]]) 

indices = tf.constant([[[0, 0], [0, 1]], [[1, 1], [1, 3]]]) 

result = tf.gather_nd(logits, indices) 
with tf.Session() as sess: 
    print(sess.run(result)) 

Questo stamperà

[[[ 10. 10. 20. 20.] 
    [ 11. 10. 10. 30.]] 

[[ 15. 11. 11. 21.] 
    [ 17. 11. 11. 21.]]] 

tf.gather_nd dovrebbe essere disponibile a partire da v0.10. Controlla this github issue per ulteriori discussioni su questo.

+0

come hai cambiato gli indici in 3d da 2d (come chiesto in questione)? – Tulsi

+0

@Tulsi Non capisco la tua domanda. Non c'è menzione degli indici 3D nella domanda, o c'è? – kaufmanu

Problemi correlati