2016-05-07 16 views
11

È possibile rinominare l'ambito variabile di un determinato modello in tensorflow?Rinomina ambito variabile del modello salvato in TensorFlow

Per esempio, ho creato un modello di regressione logistica per le cifre MNIST, sulla base del tutorial:

with tf.variable_scope('my-first-scope'): 
    NUM_IMAGE_PIXELS = 784 
    NUM_CLASS_BINS = 10 
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS]) 
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS]) 

    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS,NUM_CLASS_BINS])) 
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS])) 

    y = tf.nn.softmax(tf.matmul(x,W) + b) 
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
    saver = tf.train.Saver([W, b]) 

... # some training happens 

saver.save(sess, 'my-model') 

Ora voglio ricaricare il modello salvato nel 'my-first-scope' portata variabile e quindi salvare di nuovo tutto per una nuova file e sotto un nuovo ambito variabile di 'my-second-scope'.

risposta

7

È possibile utilizzare tf.contrib.framework.list_variables e tf.contrib.framework.load_variable come segue per raggiungere il tuo obiettivo:

with tf.Graph().as_default(), tf.Session().as_default() as sess: 
    with tf.variable_scope('my-first-scope'): 
    NUM_IMAGE_PIXELS = 784 
    NUM_CLASS_BINS = 10 
    x = tf.placeholder(tf.float32, shape=[None, NUM_IMAGE_PIXELS]) 
    y_ = tf.placeholder(tf.float32, shape=[None, NUM_CLASS_BINS]) 

    W = tf.Variable(tf.zeros([NUM_IMAGE_PIXELS,NUM_CLASS_BINS])) 
    b = tf.Variable(tf.zeros([NUM_CLASS_BINS])) 

    y = tf.nn.softmax(tf.matmul(x,W) + b) 
    cross_entropy = tf.reduce_mean(-tf.reduce_sum(y_ * tf.log(y), reduction_indices=[1])) 
    saver = tf.train.Saver([W, b]) 
    sess.run(tf.global_variables_initializer()) 
    saver.save(sess, 'my-model') 

vars = tf.contrib.framework.list_variables('.') 
with tf.Graph().as_default(), tf.Session().as_default() as sess: 

    new_vars = [] 
    for name, shape in vars: 
    v = tf.contrib.framework.load_variable('.', name) 
    new_vars.append(tf.Variable(v, name=name.replace('my-first-scope', 'my-second-scope'))) 

    saver = tf.train.Saver(new_vars) 
    sess.run(tf.global_variables_initializer()) 
    saver.save(sess, 'my-new-model') 
+0

Ciò richiede che sia stato costruito il grafico e tutto utilizzando il nome dell'ambito precedente, poiché per ripristinare il checkpoint è necessario aver definito il grafico. Se si dispone solo del file di checkpoint, è possibile sostituire un nome di ambito al suo interno? – npit

18

in base alla risposta del keveman, ho creato uno script python, che è possibile eseguire per rinominare le variabili di qualsiasi posto di blocco tensorflow:

https://gist.github.com/batzner/7c24802dd9c5e15870b4b56e22135c96

È possibile sostituire sottostringhe nei nomi di variabile e aggiungere un prefisso a tutti i nomi. Chiamare lo script con

python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir 

con gli argomenti opzionali

--replace_from=substr --replace_to=substr --add_prefix=abc --dry_run 

Ecco la funzione centrale della sceneggiatura:

def rename(checkpoint_dir, replace_from, replace_to, add_prefix, dry_run=False): 
    checkpoint = tf.train.get_checkpoint_state(checkpoint_dir) 
    with tf.Session() as sess: 
     for var_name, _ in tf.contrib.framework.list_variables(checkpoint_dir): 
      # Load the variable 
      var = tf.contrib.framework.load_variable(checkpoint_dir, var_name) 

      # Set the new name 
      new_name = var_name 
      if None not in [replace_from, replace_to]: 
       new_name = new_name.replace(replace_from, replace_to) 
      if add_prefix: 
       new_name = add_prefix + new_name 

      if dry_run: 
       print('%s would be renamed to %s.' % (var_name, new_name)) 
      else: 
       print('Renaming %s to %s.' % (var_name, new_name)) 
       # Rename the variable 
       var = tf.Variable(var, name=new_name) 

     if not dry_run: 
      # Save the variables 
      saver = tf.train.Saver() 
      sess.run(tf.global_variables_initializer()) 
      saver.save(sess, checkpoint.model_checkpoint_path) 

Esempio:

python tensorflow_rename_variables.py --checkpoint_dir=path/to/dir --replace_from=scope1 --replace_to=scope1/model --add_prefix=abc/ 

rinominerà la variabile scope1/Variable1a abc/scope1/model/Variable1.

+0

Ricevo questo errore con lo script: ValueError: Impossibile trovare il file "checkpoint" oi checkpoint nella directory specificata ./fi – ryuzakinho

+1

@ryuzakinho, è necessario specificare una directory che contenga un file 'checkpoint'. Vedi https://www.tensorflow.org/programmers_guide/variables#checkpoint_files per maggiori informazioni. –

+1

In realtà, ho avuto write_state = False per qualche motivo. Quindi, non stava creando un file di checkpoint. – ryuzakinho