2016-06-10 24 views
8

Si prega di scusare l'ampiezza di questa domanda. Forse una volta che ne so di più forse potrei chiederlo in modo più specifico.strategia di ottimizzazione del codice tensorflow

Ho un pezzo di codice tensorflow sensibile alle prestazioni. Dal punto di vista di qualcuno che sa poco della programmazione gpu, mi piacerebbe sapere quali guide o strategie sarebbero un "buon punto di partenza" per ottimizzare il mio codice. (GPU singola)

Forse anche una lettura di quanto tempo è stato speso per ciascun PO tensorflow sarebbe bello ...

Ho una conoscenza vaga che

  • Alcune operazioni andare più veloce quando viene assegnato a una CPU anziché una gpu, ma non è chiaro quale sia
  • C'è una parte del software google chiamato "EEG" che ho letto su un documento
    che potrebbe essere un giorno aperto.

Ci possono essere anche altri fattori comuni in gioco che non sono a conoscenza di ..

+0

problema simile [qui] (http://stackoverflow.com/questions/36439483/how-to-get-the-time-consumed-to-execute-each- nodo-in-tensorflow-grafico). Fondamentalmente si passano specifiche opzioni a 'sess.run()' e si usa un oggetto TimeLine –

+0

Ahh un oggetto TimeLine ... questo è grosso modo quello che stavo andando per – user3391229

risposta

17

ho voluto dare una risposta più completa su come utilizzare il Timeline oggetto per ottenere il tempo di esecuzione per ogni nodo del grafo:

  • si utilizza un classico sess.run() ma specificando argomenti options e run_metadata
  • quindi si crea un oggetto Timeline con i dati run_metadata.step_stats

Ecco nel codice di esempio:

import tensorflow as tf 
from tensorflow.python.client import timeline 

x = tf.random_normal([1000, 1000]) 
y = tf.random_normal([1000, 1000]) 
res = tf.matmul(x, y) 

# Run the graph with full trace option 
with tf.Session() as sess: 
    run_options = tf.RunOptions(trace_level=tf.RunOptions.FULL_TRACE) 
    run_metadata = tf.RunMetadata() 
    sess.run(res, options=run_options, run_metadata=run_metadata) 

    # Create the Timeline object, and write it to a json 
    tl = timeline.Timeline(run_metadata.step_stats) 
    ctf = tl.generate_chrome_trace_format() 
    with open('timeline.json', 'w') as f: 
     f.write(ctf) 

È quindi possibile aprire Google Chrome, vai alla pagina chrome://tracing e caricare il file timeline.json. Si dovrebbe qualcosa di simile:

timeline

+0

Yup, questo è davvero molto utile. Assicurati di non impostare 'FULL_TRACE' ogni volta che chiami' sess.run' o rallenterai il tuo allenamento. Di solito lo chiamo ogni 100-1k passi. – Nova

Problemi correlati