Overview
This article will not illustrate the basic concept of Batch Normalization. Instead, it will focus on the implementation details of Batch Normalization in TensorFlow.
Introduction
Batch Normalization can make the convergence of neural networks easier, and sometimes can even improve the accuracy. The formula is shown below:

    outputTensor = gamma * (inputTensor - batchMean) / sqrt(batchVariance + epsilon) + beta

where batchMean and batchVariance are the mean and variance of the current mini-batch, gamma and beta are learnable scale and offset parameters, and epsilon is a small constant added for numerical stability. At inference time, the mini-batch statistics are replaced by averages collected over the training set.
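As a quick illustration, here is a minimal NumPy sketch of the formula above (the values, shapes, and function name are made up for demonstration); the statistics are computed per feature over the batch axis:

    import numpy as np

    def BatchNormForward(inputBatch, gamma, beta, epsilon=1e-5):
        # statistics over the batch axis, one value per feature/channel
        batchMean = inputBatch.mean(axis=0)
        batchVariance = inputBatch.var(axis=0)
        normalized = (inputBatch - batchMean) / np.sqrt(batchVariance + epsilon)
        return gamma * normalized + beta

    toyBatch = np.array([[1.0, 2.0], [3.0, 4.0], [5.0, 6.0]])
    print(BatchNormForward(toyBatch, gamma=np.ones(2), beta=np.zeros(2)))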
The idea is simple and elegant. However, when it comes to implementation, it becomes a little tricky because of the average parts (the average of the means and the average of the variances): computing such averages over the whole training set seems infeasible.
A common way to solve it is to invoke the Moving Average algorithm. In this case, you cannot merely wrap the Batch Normalization in a function and return its output tensor; you should also return the update operations of the Moving Average calculation, and session run such update operations after you backpropagate the neural network in each step. Therefore, you should expect that there is somehow an update operation after each training step (or at least an implicit one), no matter which Batch Normalization API you use.
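To make the Moving Average part concrete, here is a tiny plain-Python sketch (with made-up numbers) of the update that has to run after every training step; the running estimate is what gets used at inference time:

    decay = 0.99
    movingMean = 0.0                          # running estimate used at inference time
    batchMeans = [0.52, 0.48, 0.55, 0.50]     # per-step batch means (made-up values)

    for step, currentBatchMean in enumerate(batchMeans):
        # the update operation performs exactly this kind of assignment after each step
        movingMean = decay * movingMean + (1 - decay) * currentBatchMean
        print(step, movingMean)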
TensorFlow provides two Batch Normalization APIs: tf.nn.batch_normalization() and tf.layers.batch_normalization(). The following sections will use these different APIs to implement a function called BatchNormalization() and compare their performance.
tf.nn.batch_normalization:
This is the low-level API for Batch Normalization: you not only have to session run the update operations at each training step, but you also need to calculate the averages of the mean and variance, as well as create variables such as gamma and betta, all on your own. The code is shown as follows (ref: [1], [2]):
def BatchNormalization(isTraining_, currentStep_, inputTensor_, isConvLayer_, layerName_="BatchNorm"):
    with tf.variable_scope(layerName_):
        # Batch statistics: over (batch, height, width) for conv layers, over the batch only otherwise.
        if isConvLayer_:
            currentBatchMean, currentBatchVariance = tf.nn.moments(inputTensor_, [0, 1, 2])
        else:
            currentBatchMean, currentBatchVariance = tf.nn.moments(inputTensor_, [0])

        # Track the moving averages of the batch mean and variance.
        averageCalculator = tf.train.ExponentialMovingAverage(decay=0.99,
                                                              num_updates=currentStep_)
        updateVariablesOperation = averageCalculator.apply([currentBatchMean, currentBatchVariance])

        # Use the batch statistics during training, the moving averages during inference.
        totalMean = tf.cond(isTraining_,
                            lambda: currentBatchMean,
                            lambda: averageCalculator.average(currentBatchMean))
        totalVariance = tf.cond(isTraining_,
                                lambda: currentBatchVariance,
                                lambda: averageCalculator.average(currentBatchVariance))

        # Learnable scale (gamma) and offset (betta), one per channel.
        outputChannels = int(inputTensor_.shape[-1])
        gamma = tf.Variable(tf.ones([outputChannels]))
        betta = tf.Variable(tf.zeros([outputChannels]))
        epsilon = 1e-5
        outputTensor = tf.nn.batch_normalization(inputTensor_, mean=totalMean, variance=totalVariance,
                                                 offset=betta, scale=gamma, variance_epsilon=epsilon)
    return outputTensor, updateVariablesOperation
Note that, as remarked in ref [1], one should assign the current training step to num_updates in the construction of tf.train.ExponentialMovingAverage() to "prevent from averaging across non-existing iterations". I think this just means that the averaged variables are randomly initialized in the first steps, so their importance should be scaled down. If you don't assign num_updates, according to the TensorFlow documentation, the moving mean is simply calculated by:
totalMean = (1 - decay)*currentBatchMean + decay*totalMean
And if you assign the current training step to num_updates, the mean will be calculated as follows:
decay = min(decay, (1 + step) / (10 + step))
totalMean = (1 - decay)*currentBatchMean + decay*totalMean
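A small plain-Python sketch (just to visualize the schedule) shows that the effective decay starts small and warms up toward 0.99 as the training step grows, so the early, randomly initialized averages are quickly overwritten:

    decay = 0.99
    for step in [0, 1, 10, 100, 1000, 10000]:
        effectiveDecay = min(decay, (1 + step) / (10 + step))
        print(step, effectiveDecay)
    # step 0 -> 0.1, step 100 -> ~0.918, step 1000 and beyond -> capped at 0.99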
Usage:
You can build your net as follows:
class AlexnetBatchNorm(SubnetBase):
    def __init__(self, isTraining_, trainingStep_, input_, ...):
        self.isTraining = isTraining_
        self.trainingStep = trainingStep_
        self.input = input_
        ...

    def Build(self):
        net = ConvLayer(self.input, 3, 8, stride_=1, padding_='SAME', layerName_='conv1')
        net, updateVariablesOp1 = BatchNormalization(self.isTraining, self.trainingStep, net, isConvLayer_=True)
        net = tf.nn.relu(net)
        net = ConvLayer(net, 3, 16, stride_=1, padding_='SAME', layerName_='conv2')
        net, updateVariablesOp2 = BatchNormalization(self.isTraining, self.trainingStep, net, isConvLayer_=True)
        net = tf.nn.relu(net)
        ...
        updateOperations = tf.group(updateVariablesOp1, updateVariablesOp2, ...)
        return net, updateOperations
where ConvLayer() is a simple wrapper for a convolution layer, and isTraining_, trainingStep_, and input_ are the placeholders that will be fed when you session run.
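For completeness, a ConvLayer() wrapper compatible with the calls above might look like the following sketch (the exact signature and initializers are assumptions for illustration, not the actual code in the repository):

    def ConvLayer(inputTensor_, kernelSize_, numberOfFilters_, stride_=1, padding_='SAME', layerName_="Conv"):
        # hypothetical wrapper: a plain convolution without bias, since the Batch
        # Normalization that follows already provides the offset (betta)
        with tf.variable_scope(layerName_):
            inputChannels = int(inputTensor_.shape[-1])
            weights = tf.get_variable("weights",
                                      shape=[kernelSize_, kernelSize_, inputChannels, numberOfFilters_],
                                      initializer=tf.truncated_normal_initializer(stddev=0.1))
            return tf.nn.conv2d(inputTensor_, weights,
                                strides=[1, stride_, stride_, 1], padding=padding_)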
Finally, session run the update operation for each training step:
while step < MAX_TRAINING_STEPS:
    session.run(trainOp,
                feed_dict={self.net.isTraining : True,
                           self.net.trainingStep : step,
                           self.net.input : x,
                           ...})
    session.run(self.updateNetOp,
                feed_dict={self.net.isTraining : False,
                           self.net.trainingStep : step,
                           self.net.input : x,
                           ...})
You can refer to our GitHub repository (files: Train.py, src/subnet/AlexBatchNorm.py, src/layers/BasicLayers.py) for more details.
Performance:
The following figure shows the training and validation curves of models that apply Batch Normalization (the green and the pink curves) or not (the blue and the orange curves):
One can see that the model with
Batch Normalization converges very fast and even improves the result by a small percentage.
Recover:
In the above implementation, we have used tf.train.ExponentialMovingAverage to calculate the averages of the mean and variance. However, its documentation suggests that when you recover the graph from checkpoints, you should do something like:
variables_to_restore = ema.variables_to_restore()
saver = tf.train.Saver(variables_to_restore)
For Batch Normalization, however, it seems that we can recover the network as usual (probably tf.nn.batch_normalization() has already taken care of it?). If I try to recover the network as suggested above, I run into an error instead.
Therefore, one should recover the network simply as follows:
modelLoader = tf.train.Saver()
modelLoader.restore(session, PATH_TO_MODEL_CHECKPOINT)
One more proof of this is to re-train the model and see whether its loss starts from the same value as the pre-trained model:
As shown above, the blue curve is the validation curve of the pre-trained model, and the red curve comes from a model that was restored from the last step of the pre-trained model and trained again. You can see that they match perfectly. Therefore, we can conclude that the means and variances of the variables are perfectly recovered.
tf.layers.batch_normalization:
TensorFlow also provides a high-level API for Batch Normalization. However, its weird behavior ultimately made us decide not to use it.
The wrapper function of
Batch Normalization that applies this API is simply:
def BatchNormalization(isTraining_, inputTensor_, layerName_=None):
    return tf.layers.batch_normalization(inputTensor_, training=isTraining_, name=layerName_)
In this implementation, you can see that we don't need to specify whether the previous layer is a convolution layer, and the function just returns the output tensor. However, this does not mean that you don't need to update the network. The update operations are stored in the tf.GraphKeys.UPDATE_OPS collection, and you should pull them out and session run them after each training step. The documentation suggests that you can also declare dependencies between the update operations and the training operation, so that when you session run the training operation, it will automatically run the update operations for you:
updateOps = tf.get_collection(tf.GraphKeys.UPDATE_OPS)
optimizer = tf.train.AdamOptimizer(learning_rate=self.learningRate)
with tf.control_dependencies(updateOps):
    self.trainOp = optimizer.minimize(lossOp)

while step < MAX_TRAINING_STEPS:
    session.run(self.trainOp,
                feed_dict={self.net.isTraining : True,
                           self.net.trainingStep : step,
                           self.net.input : x,
                           ...})
Performance:
The following is a comparison of the two implementations.
The two upper curves are the train & val curves that apply the tf.layers.batch_normalization API, and the two lower curves are the train & val curves that apply the tf.nn.batch_normalization API. One can see that both of them converge to the same limit. However, the tf.layers.batch_normalization API goes through a small hump around training epochs 45~50. It is this strange behavior that made us finally choose the tf.nn.batch_normalization API.
Conclusion
This article concentrates on the implementation details of Batch Normalization in TensorFlow. Two approaches have been compared. Moreover, this article also shows that one does not need to recover the tf.train.ExponentialMovingAverage variables manually (as the documentation suggests) when one recovers the network.