2016-03-24 16 views

risposta

54

In TensorFlow, i pesi addestrati sono rappresentati dagli oggetti tf.Variable. Se hai creato — ad es. chiamato v — da soli, è possibile ottenere il suo valore come un array NumPy chiamando sess.run(v) (dove sess è un tf.Session).

Se attualmente non si dispone di un puntatore allo tf.Variable, è possibile ottenere un elenco delle variabili trainabili nel grafico corrente chiamando tf.trainable_variables(). Questa funzione restituisce un elenco di tutti gli oggetti tf.Variable trainable nel grafico corrente e puoi selezionare quello che desideri facendo corrispondere la proprietà v.name. Ad esempio:

# Desired variable is called "tower_2/filter:0". 
var = [v for v in tf.trainable_variables() if v.name == "tower_2/filter:0"][0] 
+0

Grazie @mrry, se carico modello riguardava da qualsiasi supporto modello zoo da tensorflow posso accedere ai parametri addestrabili con stessa funzione ho provato ma tornato matrice vuota. Qualsiasi risposta per favore –

+3

Dipende dal meccanismo utilizzato per caricare il modello. Se usi il più recente 'tf.train.import_meta_graph()' allora 'tf.trainable_variables()' dovrebbe funzionare. Se usi la funzione 'tf.import_graph_def()' di livello inferiore, dovresti passare il nome della variabile nell'argomento facoltativo 'return_elements', e verrà restituito un tensore (che potrai quindi passare a' sess.run ( – mrry

+0

Grazie mille –

Problemi correlati