2015-11-13 24 views
5

Esiste un modo per estrarre la diagonale di una matrice quadrata in TensorFlow? Cioè, per una matrice simile a questo:Ottieni la diagonale di una matrice in TensorFlow

[ 
[0, 1, 2], 
[3, 4, 5], 
[6, 7, 8] 
] 

Voglio andare a prendere gli elementi: [0, 4, 8]

In NumPy, questo è abbastanza straight-forward tramite np.diag:

In tensorflow, c'è un diag function, ma forma solo una nuova matrice con gli elementi specificati nell'argomento sulla diagonale, che non è quello che voglio.

Potrei immaginare come questo potrebbe essere fatto via a passo falso ... ma non vedo il progresso per i tensori in TensorFlow.

risposta

8

con tensorflow 0.8 la sua possibile estrarre il elementi diagonali con tf.diag_part() (vedi documentation

2

Questa soluzione è probabilmente una soluzione, ma funziona.

>> sess = tensorflow.InteractiveSession() 
>> x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]]) 
>> x.initializer.run() 
>> z = tensorflow.pack([x[i,i] for i in range(3)]) 
>> z.eval() 
array([1, 5, 9], dtype=int32) 
+0

Purtroppo questa operazione può essere tremendamente lento, non so perché, però. – Literal

2

Attualmente è possibile estrarre elementi diagonali con tf.diag_part. Ecco il loro esempio:

# 'input' is [[1, 0, 0, 0] 
       [0, 2, 0, 0] 
       [0, 0, 3, 0] 
       [0, 0, 0, 4]] 

tf.diag_part(input) ==> [1, 2, 3, 4] 

Vecchio risposta (quando diag_part) non era disponibile (ancora rilevante, se si vuole ottenere qualcosa che non è disponibile ora):

Dopo aver esaminato se la math operations e tensor transformations, non sembra che tale operazione esista. Anche se è possibile estrarre questi dati con moltiplicazioni di matrice, non sarebbe efficiente (ottenere diagonale è O(n)).

Hai tre approcci, a partire da facile a difficile.

  1. Valutare il tensore, estratto di diagonale con NumPy, costruire una variabile con TF
  2. Usa tf.pack in modo Anurag ha suggerito (anche estrarre il valore 3 usando tf.shape
  3. Scrivi la tua op in C++, ricostruire TF e l'uso in modo nativo.
+0

Suoni ragionevoli. Grazie! – theaNO

0

Utilizzare l'operazione gather.

x = tensorflow.Variable([[1,2,3],[4,5,6],[7,8,9]]) 
x_flat = tf.reshape(x, [-1]) # flatten the matrix 
x_diag = tf.gather(x, [0, 3, 6]) 
0

A seconda del contesto, una maschera può essere un bel modo per `annullare' elementi fuori diagonale della matrice, soprattutto se si pensa a ridurre comunque:

mask = tf.diag(tf.ones([n])) 
y = tf.mul(mask,y) 
cost = -tf.reduce_sum(y) 
2

Utilizzare il tf.diag_part()

with tf.Session() as sess: 
    x = tf.ones(shape=[3, 3]) 
    x_diag = tf.diag_part(x) 
    print(sess.run(x_diag)) 
Problemi correlati