2015-11-17 14 views
47

Quando si lavora con il grafico globale predefinito, è possibile rimuovere i nodi dopo che sono stati aggiunti o, in alternativa, ripristinare il grafico predefinito da svuotare? Quando lavoro con TF in modo interattivo in IPython, mi ritrovo a dover riavviare il kernel ripetutamente. Mi piacerebbe essere in grado di sperimentare con i grafici più facilmente se possibile.Rimuovere i nodi dal grafico o ripristinare l'intero grafico predefinito

risposta

71

aggiornamento 11/2/2016

tf.reset_default_graph()

roba vecchia

C'è reset_default_graph, ma non fa parte della API pubblica (penso che dovrebbe essere, non qualcuno vuole file an issue su GitHub?)

La mia soluzione per ripristinare le cose è questa:

from tensorflow.python.framework import ops 
ops.reset_default_graph() 
sess = tf.InteractiveSession() 
30

Per impostazione predefinita, una sessione viene costruita attorno al grafico predefinito. Per evitare di lasciare nodi morti nella sessione, è necessario controllare il grafico predefinito o utilizzare un grafico esplicito.

  • Per cancellare il grafico di default, è possibile utilizzare la funzione di tf.reset_default_graph.

    tf.reset_default_graph() 
    sess = tf.InteractiveSession() 
    
  • È anche possibile costruire esplicitamente un grafico ed evitare di utilizzare quello predefinito. Se si utilizza un normale Session, sarà necessario creare completamente il grafico prima di costruire la sessione. Per InteractiveSession, si può semplicemente dichiarare il grafico e utilizzarlo come un contesto di dichiarare ulteriori cambiamenti:

    g = tf.Graph() 
    sess = tf.InteractiveSession(graph=g) 
    with g.asdefault(): 
        # Put variable declaration and other tf operation 
        # in the graph context 
        .... 
        b = tf.matmul(A, x) 
        .... 
    
    sess.run([b], ...) 
    

EDIT: Per le versioni recenti di tensorflow (1.0+), la funzione corretta è g.as_default.

+2

In tensorflow> = 1.0, è 'g.as_default()' –

+0

Ortografia: Nell'istruzione with hai dimenticato il carattere di sottolineatura in g.as_default() – user3750988

+0

così felice di aver trovato questa risposta. mi sta facendo impazzire ... –

3

Le celle di notebook IPython/Jupyter mantengono lo stato tra le corse di una cella.

Creare un grafico personalizzato:

def main(): 
    # Define your model 
    data = tf.placeholder(...) 
    model = ... 

with tf.Graph().as_default(): 
    main() 

ha funzionato una volta, il grafico viene ripulito.

Problemi correlati