2016-05-02 62 views
5

Ho notato che ci sono già funzioni di normalizzazione batch nell'API per tensorflow. Una cosa che non capisco però, è come cambiare la procedura tra allenamento e test?Normalizzazione batch in tensorflow

Batch normalizzazione agisce in modo diverso durante la prova di quanto durante l'allenamento. Nello specifico si usa una media e una varianza fisse durante l'allenamento.

C'è qualche buon codice di esempio da qualche parte? Ho visto un po ', ma con le variabili di ambito ha ottenuto confuso

+0

considerare l'utilizzo di strati pre-definite dalla API di alto livello come ad esempio 'tf.contrib .layers'. – danijar

risposta

9

Hai ragione, il tf.nn.batch_normalization fornisce solo le funzionalità di base per l'attuazione di normalizzazione batch. Devi aggiungere la logica extra per tenere traccia dei mezzi e delle variazioni in movimento durante l'allenamento, e usare i mezzi e le varianze addestrati durante l'inferenza. Si può guardare a questo example per un'implementazione molto generale, ma una versione rapida che non utilizza gamma è qui:

beta = tf.Variable(tf.zeros(shape), name='beta') 
    moving_mean = tf.Variable(tf.zeros(shape), name='moving_mean', 
           trainable=False) 
    moving_variance = tf.Variable(tf.ones(shape), 
            name='moving_variance', 
            trainable=False) 
    control_inputs = [] 
    if is_training: 
    mean, variance = tf.nn.moments(image, [0, 1, 2]) 
    update_moving_mean = moving_averages.assign_moving_average(
     moving_mean, mean, self.decay) 
    update_moving_variance = moving_averages.assign_moving_average(
     moving_variance, variance, self.decay) 
    control_inputs = [update_moving_mean, update_moving_variance] 
    else: 
    mean = moving_mean 
    variance = moving_variance 
    with tf.control_dependencies(control_inputs): 
    return tf.nn.batch_normalization(
     image, mean=mean, variance=variance, offset=beta, 
     scale=None, variance_epsilon=0.001) 
+0

Grazie mille. Un'altra domanda veloce. Una versione con gamma è davvero più complicata? sembra che dovresti solo inizializzare un altro tf.Variabile per questo? Il resto del codice dovrebbe essere lo stesso, non dovrebbe? – user3358117

+0

Sì, puoi seguire l'implementazione più generale nel link che ho fornito per aggiungere 'gamma'. – keveman