2016-02-02 40 views
26

Questa domanda riguarda l'accesso ai singoli elementi di un tensore, ad esempio [[1,2,3]]. Ho bisogno di accedere all'elemento interno [1,2,3] (Questo può essere eseguito usando .eval() o sess.run()) ma richiede più tempo quando la dimensione del tensore è enorme)Python tensorflow: accesso a singoli elementi in un tensore

È lì qualche metodo per fare lo stesso più velocemente?

Grazie in anticipo.

risposta

0

Ho il sospetto che sia il resto del calcolo che richiede tempo, piuttosto che accedere a un elemento.

Anche il risultato potrebbe richiedere una copia da qualsiasi memoria sia archiviata, quindi se è presente sulla scheda grafica, prima sarà necessario ricopiarlo in RAM e quindi accedere al proprio elemento. In tal caso, puoi saltarlo aggiungendo un'operazione di tensorflow per prendere il primo elemento e restituirlo solo.

36

Esistono due modi principali per accedere ai sottoinsiemi degli elementi in un tensore, che dovrebbero funzionare per l'esempio.

  1. utilizzare l'operatore di indicizzazione (basato su tf.slice()) per estrarre una porzione contigua del tensore.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 
    
    output = input[0, :] 
    print sess.run(output) # ==> [1 2 3] 
    

    L'operatore di indicizzazione supporta molte delle stesse specifiche di slice di NumPy.

  2. Utilizzare tf.gather() op per selezionare una sezione non contigua dal tensore.

    input = tf.constant([[1, 2, 3], [4, 5, 6], [7, 8, 9]]) 
    
    output = tf.gather(input, 0) 
    print sess.run(output) # ==> [1 2 3] 
    
    output = tf.gather(input, [0, 2]) 
    print sess.run(output) # ==> [[1 2 3] [7 8 9]] 
    

    noti che tf.gather() permette di selezionare solo le fette intere nella dimensione 0th (intere righe nell'esempio di matrice), quindi potrebbe essere necessario tf.reshape() o tf.transpose() l'input per ottenere gli elementi appropriati.

+1

"... quindi potrebbe essere necessario tf.reshape() o tf.transpose() l'input per ottenere gli elementi appropriati." -> o usa 'tf.gather_nd'? –

1

Semplicemente non è possibile ottenere il valore dell'elemento 0a di [[1,2,3]] senza run() - ning o eval() - ing un'operazione che sarebbe ottenerlo. Perché prima di 'correre' o 'eval', hai solo una descrizione di come ottenere questo elemento interno (perché TF utilizza grafici/calcoli simbolici). Quindi, anche se dovessi usare tf.gather/tf.slice, dovresti comunque ottenere i valori tramite eval/run. Vedi la risposta di @ mrry.

Problemi correlati