2016-06-12 40 views
5

Credo di avere difficoltà a capire come funzionano i grafici in tensorflow e come accedervi. La mia intuizione è che le linee sotto 'con grafico:' formeranno il grafico come una singola entità. Quindi, ho deciso di creare una classe che avrebbe costruito un grafico quando istanziato e avrebbe una funzione che avrebbe eseguito il grafico, come segue;Tensorflow: creazione di un grafico in una classe e esecuzione in ouside

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      prediction = ... 
      cost  = ... 
      optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(optimizer, feed_dict) 
      loss = sess.run(cost, feed_dict) 
      ... 
     return variables 

I passi successivi sono per creare un file principale che assemblare i parametri da passare alla classe, per costruire il grafico e poi di farlo funzionare;

#Main file 
... 
parameters_dict = { 'n_input': 28, 'learnRate': 0.001, ... } 

#Building graph 
G = Graph(parameters_dict) 
P = G.launchG(Input) 
... 

questo è molto elegante per me, ma non funziona del tutto (ovviamente). In effetti, sembra che le funzioni launchG non abbiano accesso ai nodi definiti nel grafico, che mi danno errore come;

---> 26 sess.run(optimizer, feed_dict) 

NameError: name 'optimizer' is not defined 

Forse è il mio pitone (e tensorflow) la comprensione che è troppo limitato, ma ero sotto la strana impressione che con il grafico (G) ha creato, in esecuzione la sessione con questo grafico come un argomento dovrebbe dare accesso ai nodi in esso, senza che io debba dare un accesso esplicito.

Qualsiasi illuminazione?

risposta

7

I nodi prediction, cost, e optimizer sono variabili locali creati nel metodo __init__, non è possibile accedere nel metodo launchG.

La soluzione più semplice sarebbe quella di dichiararli come attributi della vostra classe Graph:

class Graph(object): 

    #To build the graph when instantiated 
    def __init__(self, parameters): 
     self.graph = tf.Graph() 
     with self.graph.as_default(): 
      ... 
      self.prediction = ... 
      self.cost  = ... 
      self.optimizer = ... 
      ... 
    # To launch the graph 
    def launchG(self, inputs): 
     with tf.Session(graph=self.graph) as sess: 
      ... 
      sess.run(self.optimizer, feed_dict) 
      loss = sess.run(self.cost, feed_dict) 
      ... 
     return variables 

È anche possibile recuperare i nodi del grafo usando il loro nome esatto con graph.get_tensor_by_name e graph.get_operation_by_name.

Problemi correlati