Friday, December 29, 2017

Batch Normalization in TensorFlow

Overview

    This article does not explain the basic concept of Batch Normalization.  Instead, it focuses on the implementation details of Batch Normalization in TensorFlow.

Introduction

    Batch Normalization makes neural networks easier to converge, and sometimes it even improves accuracy.  The formula is shown below:
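
$$ y = \gamma \cdot \frac{x - \mathrm{E}[x]}{\sqrt{\mathrm{Var}[x] + \epsilon}} + \beta $$

where

$$ \mathrm{E}[x] = \mathrm{E}_{\mathcal{B}}[\mu_{\mathcal{B}}], \qquad \mathrm{Var}[x] = \mathrm{E}_{\mathcal{B}}[\sigma_{\mathcal{B}}^2] $$

and \mu_{\mathcal{B}} and \sigma_{\mathcal{B}}^2 are the mean and variance of a mini-batch \mathcal{B}.  During training, the statistics of the current mini-batch are used in place of E[x] and Var[x].
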
   The idea is simple and elegant.  However, the implementation is a little tricky because of the averaged terms (the average of the means and the average of the variances): computing them exactly over the whole training set is impractical, since it would take an extra pass over all the training data.
     A common solution is to use a moving average instead.  In that case, you cannot simply wrap Batch Normalization in a function that returns its output tensor; the function must also return the update operations of the moving averages, and you must run those update operations with session.run() after backpropagation at every step.  Therefore, whichever Batch Normalization API you use, you should expect some update operation to run after each training step (at least implicitly).
     TensorFlow provides two Batch Normalization APIs: tf.nn.batch_normalization() and tf.layers.batch_normalization().  The following sections use each of them to implement a function called BatchNormalization() and compare their performance.
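To make the two modes concrete, here is a minimal pure-Python sketch of the formula on scalar features.  The function names and toy batches are our own illustration, not TensorFlow code.  During training each batch is normalized with its own statistics, while the exact inference-time averages need the statistics of every training batch, which is what makes them expensive.

```python
import math


def mean_and_variance(values):
    """Mean and (biased) variance of one mini-batch of scalar features."""
    mean = sum(values) / len(values)
    variance = sum((x - mean) ** 2 for x in values) / len(values)
    return mean, variance


def batch_norm(values, mean, variance, gamma=1.0, beta=0.0, epsilon=1e-5):
    """y = gamma * (x - mean) / sqrt(variance + epsilon) + beta."""
    scale = gamma / math.sqrt(variance + epsilon)
    return [scale * (x - mean) + beta for x in values]


# Toy mini-batches standing in for the whole training set (hypothetical data).
batches = [[1.0, 2.0, 3.0, 4.0], [2.0, 4.0, 6.0, 8.0], [0.0, 1.0, 0.0, 1.0]]

# Training mode: each batch is normalized with its own statistics.
train_out = batch_norm(batches[0], *mean_and_variance(batches[0]))

# Inference mode: the exact "average of the means" and "average of the
# variances" require statistics from every training batch.
stats = [mean_and_variance(b) for b in batches]
population_mean = sum(m for m, _ in stats) / len(stats)
population_variance = sum(v for _, v in stats) / len(stats)
infer_out = batch_norm([3.0], population_mean, population_variance)
```

With a real training set, the list comprehension over all batches is exactly the extra pass that the moving average avoids.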

tf.nn.batch_normalization:

    This is the low-level API for Batch Normalization: not only do you have to run the update operations with session.run() at each training step, you also have to compute the averages of the mean and variance and create variables such as gamma and beta yourself.  The code is shown below (ref: [1], [2]):
import tensorflow as tf  # TensorFlow 1.x graph-mode API

def BatchNormalization(isTraining_, currentStep_, inputTensor_, isConvLayer_, layerName_="BatchNorm"):
        with tf.variable_scope(layerName_):
                # Per-channel statistics of the current mini-batch:
                # over N, H, W for NHWC conv outputs, over N for dense outputs.
                if isConvLayer_:
                        currentBatchMean, currentBatchVariance = tf.nn.moments(inputTensor_, [0, 1, 2])
                else:
                        currentBatchMean, currentBatchVariance = tf.nn.moments(inputTensor_, [0])

                # Moving averages of the batch statistics.  The returned update op
                # must be run explicitly after every training step.
                averageCalculator = tf.train.ExponentialMovingAverage(decay=0.99,
                                                                      num_updates=currentStep_)
                updateVariablesOperation = averageCalculator.apply( [currentBatchMean, currentBatchVariance] )

                # Training: use the batch statistics; inference: use the moving averages.
                totalMean = tf.cond(isTraining_,
                                    lambda: currentBatchMean, lambda: averageCalculator.average(currentBatchMean) )

                totalVariance = tf.cond(isTraining_,
                                        lambda: currentBatchVariance, lambda: averageCalculator.average(currentBatchVariance) )

                # Learnable per-channel scale (gamma) and shift (beta).
                outputChannels = int(inputTensor_.shape[-1])
                gamma = tf.Variable( tf.ones([outputChannels]) )
                beta = tf.Variable( tf.zeros([outputChannels]) )
                epsilon = 1e-5
                outputTensor = tf.nn.batch_normalization(inputTensor_, mean=totalMean, variance=totalVariance, offset=beta,
                                                         scale=gamma, variance_epsilon=epsilon)
                return outputTensor, updateVariablesOperation
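The isConvLayer_ switch only chooses which axes tf.nn.moments() averages over; in both cases the result is one mean and one variance per channel, taken over every axis except the last.  A framework-free sketch on nested lists (the helper names are hypothetical) shows that [0, 1, 2] for an NHWC tensor and [0] for an NC tensor are the same rule:

```python
def innermost_vectors(tensor):
    """Yield every innermost (channel) vector of a nested-list tensor."""
    if tensor and not isinstance(tensor[0], list):
        yield tensor
    else:
        for sub_tensor in tensor:
            yield from innermost_vectors(sub_tensor)


def per_channel_moments(tensor):
    """Mean and variance over all axes except the last one, i.e. what
    tf.nn.moments(x, [0, 1, 2]) gives for NHWC and tf.nn.moments(x, [0]) for NC."""
    vectors = list(innermost_vectors(tensor))
    count = len(vectors)
    channels = len(vectors[0])
    means = [sum(v[c] for v in vectors) / count for c in range(channels)]
    variances = [sum((v[c] - means[c]) ** 2 for v in vectors) / count
                 for c in range(channels)]
    return means, variances


# Conv output with shape [N=1, H=2, W=2, C=2]
conv_out = [[[[1.0, 10.0], [2.0, 20.0]],
             [[3.0, 30.0], [4.0, 40.0]]]]
conv_means, conv_vars = per_channel_moments(conv_out)

# Dense output with shape [N=2, C=2]
dense_out = [[1.0, 10.0], [3.0, 30.0]]
dense_means, dense_vars = per_channel_moments(dense_out)
```

Because the rule is always "every axis but the channel axis", a layer that defaults to normalizing over the last axis does not need to be told whether its input came from a convolution.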
Note that ref [1] suggests assigning the current training step to num_updates when constructing tf.train.ExponentialMovingAverage(), in order to "prevent from averaging across non-existing iterations".  I think this simply means that the moving average does not yet hold a real statistic in the first steps, so its importance should be scaled down early on.  If you don't set num_updates, according to the TensorFlow documentation, the mean is simply updated by:
totalMean = (1 - decay)*currentBatchMean + decay*totalMean

And if you assign the current training step to num_updates, the mean will be calculated as follows:
decay = min(decay, (1 + step)/(10+step) )
totalMean = (1 - decay)*currentBatchMean + decay*totalMean
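The effect of the cap can be checked numerically.  The following pure-Python sketch replays the two update rules above for ten steps in which every batch mean is 5.0.  It assumes the average starts at 0.0 before the first update and that the step counter starts at 0; both are assumptions made for illustration.

```python
def effective_decay(decay, num_updates=None):
    """Decay used by the update; capped by (1 + n) / (10 + n) when num_updates is given."""
    if num_updates is None:
        return decay
    return min(decay, (1.0 + num_updates) / (10.0 + num_updates))


def moving_average(batch_means, decay=0.99, use_num_updates=False):
    """Apply totalMean = (1 - d) * currentBatchMean + d * totalMean at every step.

    Assumption: totalMean starts at 0.0 before the first update.
    """
    total_mean = 0.0
    for step, current_batch_mean in enumerate(batch_means):
        d = effective_decay(decay, step if use_num_updates else None)
        total_mean = (1.0 - d) * current_batch_mean + d * total_mean
    return total_mean


# Ten steps in which every batch mean is exactly 5.0
without_steps = moving_average([5.0] * 10)
with_steps = moving_average([5.0] * 10, use_num_updates=True)
```

Without num_updates the average is still below 0.5 after ten steps, because it is dominated by its starting value; with num_updates the early decay is small (0.1 at step 0), so the average is already very close to 5.0.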

Usage:

You can build your net as follows:
class AlexnetBatchNorm(SubnetBase):
        def __init__(self, isTraining_, trainingStep_, input_, ...):
                self.isTraining = isTraining_
                self.trainingStep = trainingStep_
                self.input = input_
                ...

        def Build(self):
                net = ConvLayer(self.input, 3, 8, stride_=1, padding_='SAME', layerName_='conv1')
                net, updateVariablesOp1 = BatchNormalization(self.isTraining, self.trainingStep, net, isConvLayer_=True)
                net = tf.nn.relu(net)

                net = ConvLayer(net, 3, 16, stride_=1, padding_='SAME', layerName_='conv2')
                net, updateVariablesOp2 = BatchNormalization(self.isTraining, self.trainingStep, net, isConvLayer_=True)
                net = tf.nn.relu(net)

                ...

                updateOperations = tf.group(updateVariablesOp1, updateVariablesOp2, ...)
                return net, updateOperations
where ConvLayer() is a simple wrapper for a convolution layer, and isTraining_, trainingStep_ and input_ are placeholders that are fed when you call session.run().
     Finally, run the update operation with session.run() after each training step:
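while step < MAX_TRAINING_STEPS:
        # 1) Forward and backward pass, normalizing with the batch statistics.
        session.run( trainOp,
                     feed_dict={self.net.isTraining : True,
                                self.net.trainingStep : step,
                                self.net.input : x,
                                ...})

        # 2) Update the moving averages of the means and variances.
        session.run( self.updateNetOp,
                     feed_dict={self.net.isTraining : False,
                                self.net.trainingStep : step,
                                self.net.input : x,
                                ...})
        step += 1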

You can refer to our GitHub repository (files: Train.py, src/subnet/AlexBatchNorm.py, src/layers/BasicLayers.py) for more details.

Performance:

     The following figure shows the training and validation curves of models with Batch Normalization (the green and pink curves) and without it (the blue and orange curves):
One can see that the model with Batch Normalization converges much faster and even improves the result slightly.

Recover:

    In the above implementation, we used tf.train.ExponentialMovingAverage to calculate the averages of the means and variances.  However, its documentation suggests that when you restore the graph from a checkpoint, you should do something like:
# ema is the tf.train.ExponentialMovingAverage object (averageCalculator above)
variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
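For Batch Normalization, however, we can restore the network as usual.  The shadow variables created by ExponentialMovingAverage.apply() are ordinary non-trainable global variables, so the default tf.train.Saver() saves and restores them like any other variable; variables_to_restore() is meant for a different use case, namely loading the moving averages of trainable variables in place of the variables themselves.  When I tried to restore the network following the above suggestion, I got the following error:
Therefore, one should simply restore the network as usual: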
modelLoader = tf.train.Saver()
modelLoader.restore(session, PATH_TO_MODEL_CHECKPOINT)
    Another way to verify this is to resume training from the checkpoint and check whether the loss continues from the same value as the pre-trained model:
In the figure above, the blue curve is the validation curve of the pre-trained model, and the red curve is the model restored from the last step of the pre-trained model and trained further.  The two curves match perfectly, so we conclude that the moving means and variances are restored correctly.


tf.layers.batch_normalization:

    TensorFlow also provides a high-level API for Batch Normalization.  However, its odd behavior (described below) made us decide not to use it in the end.
    The wrapper function for Batch Normalization that uses this API is simply:
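import tensorflow as tf  # TensorFlow 1.x API

def BatchNormalization(isTraining_, inputTensor_, layerName_=None):
        # By default, normalizes over every axis except the last (channel) axis.
        return tf.layers.batch_normalization(inputTensor_, training=isTraining_, name=layerName_)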
    In this implementation, we don't need to specify whether the previous layer is a convolution layer, and the function returns only the output tensor.  However, this does not mean that you can skip updating the network.  The update operations are stored in the tf.GraphKeys.UPDATE_OPS collection, and you should fetch them and run them after each training step.  The documentation suggests declaring a control dependency between the update operations and the training operation, so that running the training operation with session.run() automatically runs the update operations as well:
# Collect the update ops after the whole network has been built.
updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.AdamOptimizer(learning_rate=self.learningRate)
# The training op now depends on the moving-average updates, so running it runs them too.
with tf.control_dependencies(updateOps):
        self.trainOp = optimizer.minimize(lossOp)

while step < MAX_TRAINING_STEPS:
        session.run( self.trainOp,
                     feed_dict={self.net.isTraining : True,
                                self.net.trainingStep : step,
                                self.net.input : x,
                                ...})
        step += 1
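For readers new to graph collections, the following framework-free analogy may help.  All names are hypothetical and this is not how TensorFlow is implemented internally: the layer registers its moving-average update once, when it is built, and a training step only refreshes the moving statistics if it also runs the registered updates, which is what the control dependency guarantees.

```python
# Hypothetical stand-in for the tf.GraphKeys.UPDATE_OPS collection.
update_ops = []


class ToyBatchNorm:
    """Scalar batch norm that registers its moving-average update at build time."""

    def __init__(self, momentum=0.99):
        self.momentum = momentum
        self.moving_mean = 0.0
        self.last_batch_mean = None
        update_ops.append(self.update_moving_mean)  # registered once, like UPDATE_OPS

    def forward(self, batch, training):
        if training:
            self.last_batch_mean = sum(batch) / len(batch)
            mean = self.last_batch_mean
        else:
            mean = self.moving_mean
        return [x - mean for x in batch]

    def update_moving_mean(self):
        if self.last_batch_mean is None:
            return  # nothing has been seen in training mode yet
        m = self.momentum
        self.moving_mean = m * self.moving_mean + (1.0 - m) * self.last_batch_mean


def train_step(layer, batch, run_update_ops):
    layer.forward(batch, training=True)
    # ... the gradient step on the weights, gamma and beta would go here ...
    if run_update_ops:  # the guarantee tf.control_dependencies(updateOps) provides
        for op in update_ops:
            op()


bn = ToyBatchNorm(momentum=0.5)
train_step(bn, [2.0, 4.0], run_update_ops=False)
stale = bn.moving_mean  # 0.0: the registered update never ran
train_step(bn, [2.0, 4.0], run_update_ops=True)
fresh = bn.moving_mean  # 0.5 * 0.0 + 0.5 * 3.0 = 1.5
```

Forgetting the dependency does not raise any error; the moving statistics silently stay at their initial values, which only shows up at inference time.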

Performance:

The following figure compares the two implementations.
The upper two curves are the training and validation curves with the tf.layers.batch_normalization API, and the lower two curves are those with the tf.nn.batch_normalization API.  Both converge to the same limit.  However, the tf.layers.batch_normalization version goes through a small hump around training epochs 45 to 50.  It is this strange behavior that made us choose the tf.nn.batch_normalization API in the end.

Conclusion

     This article concentrates on the implementation details of Batch Normalization in TensorFlow and compares two approaches.  It also shows that, when restoring a network, one does not need to restore the tf.train.ExponentialMovingAverage variables manually (as the documentation suggests).
