2016-01-24 16 views
15

Sto tentando di ripristinare il modello TensorFlow. Ho seguito questo esempio: http://nasdag.github.io/blog/2016/01/19/classifying-bees-with-google-tensorflow/Ripristino del modello TensorFlow

Alla fine del codice nell'esempio ho aggiunto queste righe: sono stati creati

saver = tf.train.Saver() 
save_path = saver.save(sess, "model.ckpt") 
print("Model saved in file: %s" % save_path) 

due file: punto di controllo e model.ckpt.

In un nuovo file di python (tomas_bees_predict.py), ho questo codice:

import tensorflow as tf 

saver = tf.train.Saver() 

with tf.Session() as sess: 
    # Restore variables from disk. 
    saver.restore(sess, "model.ckpt") 
    print("Model restored.") 

Tuttavia quando eseguo il codice, ottengo questo errore:

Traceback (most recent call last): 
    File "tomas_bees_predict.py", line 3, in <module> 
    saver = tf.train.Saver() 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 705, in __init__ 
raise ValueError("No variables to save") 

ValueError: Nessuna variabile per salvare

C'è un modo per leggere il file mode.ckpt e vedere quali variabili vengono salvate? O forse qualcuno può aiutare a salvare il modello e ripristinarlo in base all'esempio sopra descritto?

EDIT 1:

penso provato a fare funzionare lo stesso codice in modo da ricreare la struttura del modello e mi è stato sempre l'errore. Penso che potrebbe essere correlato al fatto che il codice qui descritto non utilizza variabili denominate: http://nasdag.github.io/blog/2016/01/19/classifying-bees-with-google-tensorflow/

def weight_variable(shape): 
    initial = tf.truncated_normal(shape, stddev=0.1) 
    return tf.Variable(initial) 

def bias_variable(shape): 
    initial = tf.constant(0.1, shape=shape) 
    return tf.Variable(initial) 

così ho fatto questo esperimento. Ho scritto due versioni del codice (con e senza variabili nominative) per salvare il modello e il codice per ripristinare il modello.

tensor_save_named_vars.py:

import tensorflow as tf 

# Create some variables. 
v1 = tf.Variable(1, name="v1") 
v2 = tf.Variable(2, name="v2") 

# Add an op to initialize the variables. 
init_op = tf.initialize_all_variables() 

# Add ops to save and restore all the variables. 
saver = tf.train.Saver() 

# Later, launch the model, initialize the variables, do some work, save the 
# variables to disk. 
with tf.Session() as sess: 
    sess.run(init_op) 
    print "v1 = ", v1.eval() 
    print "v2 = ", v2.eval() 
    # Save the variables to disk. 
    save_path = saver.save(sess, "/tmp/model.ckpt") 
    print "Model saved in file: ", save_path 

tensor_save_not_named_vars.py:

import tensorflow as tf 

# Create some variables. 
v1 = tf.Variable(1) 
v2 = tf.Variable(2) 

# Add an op to initialize the variables. 
init_op = tf.initialize_all_variables() 

# Add ops to save and restore all the variables. 
saver = tf.train.Saver() 

# Later, launch the model, initialize the variables, do some work, save the 
# variables to disk. 
with tf.Session() as sess: 
    sess.run(init_op) 
    print "v1 = ", v1.eval() 
    print "v2 = ", v2.eval() 
    # Save the variables to disk. 
    save_path = saver.save(sess, "/tmp/model.ckpt") 
    print "Model saved in file: ", save_path 

tensor_restore.py:

import tensorflow as tf 

# Create some variables. 
v1 = tf.Variable(0, name="v1") 
v2 = tf.Variable(0, name="v2") 

# Add ops to save and restore all the variables. 
saver = tf.train.Saver() 

# Later, launch the model, use the saver to restore variables from disk, and 
# do some work with the model. 
with tf.Session() as sess: 
    # Restore variables from disk. 
    saver.restore(sess, "/tmp/model.ckpt") 
    print "Model restored." 
    print "v1 = ", v1.eval() 
    print "v2 = ", v2.eval() 

Ecco che cosa ottengo quando eseguo questo codice:

$ python tensor_save_named_vars.py 

I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 4 
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 4 
v1 = 1 
v2 = 2 
Model saved in file: /tmp/model.ckpt 

$ python tensor_restore.py 

I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 4 
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 4 
Model restored. 
v1 = 1 
v2 = 2 

$ python tensor_save_not_named_vars.py 

I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 4 
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 4 
v1 = 1 
v2 = 2 
Model saved in file: /tmp/model.ckpt 

$ python tensor_restore.py 
I tensorflow/core/common_runtime/local_device.cc:40] Local device intra op parallelism threads: 4 
I tensorflow/core/common_runtime/direct_session.cc:58] Direct session inter op parallelism threads: 4 
W tensorflow/core/common_runtime/executor.cc:1076] 0x7ff953881e40 Compute status: Not found: Tensor name "v2" not found in checkpoint files /tmp/model.ckpt 
    [[Node: save/restore_slice_1 = RestoreSlice[dt=DT_INT32, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice_1/tensor_name, save/restore_slice_1/shape_and_slice)]] 
W tensorflow/core/common_runtime/executor.cc:1076] 0x7ff953881e40 Compute status: Not found: Tensor name "v1" not found in checkpoint files /tmp/model.ckpt 
    [[Node: save/restore_slice = RestoreSlice[dt=DT_INT32, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice/tensor_name, save/restore_slice/shape_and_slice)]] 
Traceback (most recent call last): 
    File "tensor_restore.py", line 14, in <module> 
    saver.restore(sess, "/tmp/model.ckpt") 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 891, in restore 
    sess.run([self._restore_op_name], {self._filename_tensor_name: save_path}) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 368, in run 
    results = self._do_run(target_list, unique_fetch_targets, feed_dict_string) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/client/session.py", line 444, in _do_run 
    e.code) 
tensorflow.python.framework.errors.NotFoundError: Tensor name "v2" not found in checkpoint files /tmp/model.ckpt 
    [[Node: save/restore_slice_1 = RestoreSlice[dt=DT_INT32, preferred_shard=-1, _device="/job:localhost/replica:0/task:0/cpu:0"](_recv_save/Const_0, save/restore_slice_1/tensor_name, save/restore_slice_1/shape_and_slice)]] 
Caused by op u'save/restore_slice_1', defined at: 
    File "tensor_restore.py", line 8, in <module> 
    saver = tf.train.Saver() 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 713, in __init__ 
    restore_sequentially=restore_sequentially) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 432, in build 
    filename_tensor, vars_to_save, restore_sequentially, reshape) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 191, in _AddRestoreOps 
    values = self.restore_op(filename_tensor, vs, preferred_shard) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/training/saver.py", line 106, in restore_op 
    preferred_shard=preferred_shard) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/io_ops.py", line 189, in _restore_slice 
    preferred_shard, name=name) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/gen_io_ops.py", line 271, in _restore_slice 
    preferred_shard=preferred_shard, name=name) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/ops/op_def_library.py", line 664, in apply_op 
    op_def=op_def) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1834, in create_op 
    original_op=self._default_original_op, op_def=op_def) 
    File "/usr/local/lib/python2.7/site-packages/tensorflow/python/framework/ops.py", line 1043, in __init__ 
    self._traceback = _extract_stack() 

Quindi, forse, il codice originale (vedere il link esterno sopra) potrebbe essere modificato per qualcosa di simile:

def weight_variable(shape): 
    initial = tf.truncated_normal(shape, stddev=0.1) 
    weight_var = tf.Variable(initial, name="weight_var") 
    return weight_var 

def bias_variable(shape): 
    initial = tf.constant(0.1, shape=shape) 
    bias_var = tf.Variable(initial, name="bias_var") 
    return bias_var 

Ma allora la domanda Ho: è il ripristino delle variabili weight_var e bias_var sufficienti per implementare la previsione? Ho fatto l'allenamento sulla potente macchina con GPU e vorrei copiare il modello sul computer meno potente senza GPU per eseguire previsioni.

+0

Possibile duplicato di [tensorflow: Come ripristinare un modello salvato in precedenza (python)] (http: // stackoverflow.com/questions/33759623/tensorflow-how-to-restore-a-previously-saved-model-python) – mrry

+0

È simile, ma non è un duplicato. – Tomas

risposta

12

C'è una domanda simile qui: Tensorflow: how to save/restore a model? TLDR; è necessario ricreare struttura del modello con stessa sequenza di comandi tensorflow API prima di utilizzare oggetto Saver per ripristinare i pesi

Questo è ottimale, seguire Github issue #696 per il progresso sul rendere questo più facile

+1

Per la cronaca: il problema è stato chiuso più di un anno fa e sembra che l'archiviazione della struttura del modello sia ora supportata. – bluenote10

+0

"è necessario ricreare la struttura del modello utilizzando la stessa sequenza di comandi API TensorFlow prima di utilizzare l'oggetto Saver per ripristinare i pesi" puoi spiegarlo ulteriormente? – Chaine

+3

@Chaine c'è un'opzione migliore, utilizzare MetaGraph o il modello salvato –

1

assicurarsi che la dichiarazione di tf.train. Saver() è in con tf.Session() come sess

0

Questo problema deve essere causato dalle varianti dell'ambito del nome quando si crea una doppia rete.

mettere il comando:

tf.reset_default_graph()

prima di creare la rete

Problemi correlati