2016-03-01 16 views
11

Ho la seguente situazione:tensorflow: Ripristino variabili da da più punti di controllo

  • Ho 2 modelli scritti in 2 script separati:

  • Modello A è costituito da variabili a1, a2 e a3 , ed è scritto in A.py

  • Modello B costituito da variabili b1, b2, e b3, ed è scritto in B.py

In ognuno di A.py e B.py, ho un tf.train.Saver che salva il posto di blocco di tutte le variabili locali, e chiamiamola rispettivamente i file di checkpoint ckptA e ckptB .

Ora desidero creare un modello C che utilizzi a1 e b1. Posso fare in modo che lo stesso identico nome della variabile per a1 sia usato sia in A che in C usando var_scope (e lo stesso per b1).

La domanda è: come potrei caricare a1 e b1 da ckptA e ckptB nel modello di C? Ad esempio, sarebbe il seguente lavoro?

saver.restore(session, ckptA_location) 
saver.restore(session, ckptB_location) 

Si verificherebbe un errore se si tenta di ripristinare la stessa sessione due volte? Si lamenterebbe che non ci sono "slot" allocati per le variabili extra (b2, b3, a2, a3), o ripristinerebbe semplicemente le variabili che può, e si lamenterà solo se ci sono altre variabili in C non inizializzate?

Sto provando a scrivere del codice per testarlo ora, ma mi piacerebbe vedere un approccio canonico a questo problema, perché si incontra spesso questo quando si tenta di riutilizzare dei pesi pre-addestrati.

Grazie!

risposta

16

Si otterrà un tf.errors.NotFoundError se si è tentato di utilizzare un risparmiatore (per impostazione predefinita che rappresenta tutte le sei variabili) per ripristinare da un checkpoint che non contiene tutte le variabili rappresentate dal risparmiatore. (Nota comunque che sei libero di chiamare più volte Saver.restore() nella stessa sessione, per qualsiasi sottoinsieme delle variabili, purché tutte le variabili richieste siano presenti nel file corrispondente.)

L'approccio canonico è definire due istanze separate tf.train.Saver che coprono ogni sottoinsieme di variabili che è interamente contenuto in un singolo punto di controllo. Per esempio:

saver_a = tf.train.Saver([a1]) 
saver_b = tf.train.Saver([b1]) 

saver_a.restore(session, ckptA_location) 
saver_b.restore(session, ckptB_location) 

A seconda di come il codice è costruito, se si dispone di puntatori a tf.Variable oggetti chiamati a1 e b1 in ambito locale, si può smettere di leggere qui.

D'altra parte, se le variabili a1 e b1 sono definite in file separati, potrebbe essere necessario fare qualcosa di creativo per recuperare i puntatori a tali variabili. Anche se non è l'ideale, quello che la gente di solito fanno è quello di utilizzare un prefisso comune, per esempio nel modo seguente (assumendo che i nomi delle variabili sono "a1:0" e "b1:0" rispettivamente):

saver_a = tf.train.Saver([v for v in tf.all_variables() if v.name == "a1:0"]) 
saver_b = tf.train.Saver([v for v in tf.all_variables() if v.name == "b1:0"]) 

Una nota finale: non si deve fare sforzi eroici per garantire che le variabili abbiano gli stessi nomi in A e C. È possibile passare un dizionario nome- Variable come primo argomento al costruttore tf.train.Saver e quindi rimappare i nomi nel file di checkpoint sugli oggetti Variable nel codice. Ciò è utile se A.py e B.py hanno variabili con nome simile o se in C.py si desidera organizzare il codice modello da tali file in un tf.name_scope().

+0

i primi frammenti di codice si riferiscono a cose che dovrebbero essere scritte in C.py corrette? ed è scritto verso la fine della definizione del grafico dove a1 a2 a3 e b1 b2 b3 sono definiti in C.py? La mia domanda originale era che se solo a1 e b1 fossero definiti in C e nient'altro? –

+0

Scuse - ha aggiornato la risposta a questo caso. Se ridefinite le variabili in C.py, le cose sono molto più semplici! – mrry

+0

Un altro seguito: quando eseguiamo questa istruzione "saver_a.restore (session, ckptA_location)" saver_a viene istanziato con la singola variabile [a1] e nient'altro, tuttavia ckptA contiene valori per tutti a1, a2, a3. Stai dicendo che questo non sarebbe un problema, in quanto il risparmiatore cerca solo a1 nel ckpt e ripristina a1 nel modello C, ignorando a2 e a3? E un'ultima domanda: l'a1 in C sarà identificato come a1 nel ckptA fintanto che entrambi a1 sono istanziati (in C e A) con lo stesso nome giusto? –

Problemi correlati