2016-04-03 36 views
5

Dire che ho un tensore di dimensioni BxWxHxD. Voglio elaborare il tensore in modo tale da avere un nuovo tensore BxWxHxD in cui viene mantenuto solo l'elemento massimo in ogni slice WxH e tutti gli altri valori sono zero. In altre parole, penso che il modo migliore per ottenere questo risultato sia in qualche modo prendere un argmax 2D attraverso le fette WxH, risultando in tensori di indice BxD per le righe e le colonne che possono poi essere convertite in un tensore BxWxHxD unico caldo essere usato come maschera Come faccio a fare questo lavoro?Tensorflow multidimensionale argmax

risposta

1

È possibile utilizzare la seguente funzione come punto di partenza. Calcola gli indici dell'elemento massimo per ciascun batch e per ciascun canale. La matrice risultante è nel formato (dimensione batch, 2, numero di canali).

def argmax_2d(tensor): 

    # input format: BxHxWxD 
    assert rank(tensor) == 4 

    # flatten the Tensor along the height and width axes 
    flat_tensor = tf.reshape(tensor, (tf.shape(tensor)[0], -1, tf.shape(tensor)[3])) 

    # argmax of the flat tensor 
    argmax = tf.cast(tf.argmax(flat_tensor, axis=1), tf.int32) 

    # convert indexes into 2D coordinates 
    argmax_x = argmax // tf.shape(tensor)[2] 
    argmax_y = argmax % tf.shape(tensor)[2] 

    # stack and return 2D coordinates 
    return tf.stack((argmax_x, argmax_y), axis=1) 

def rank(tensor): 

    # return the rank of a Tensor 
    return len(tensor.get_shape())