All You Need is a Good Init

Originally published in 2016

ICLR 2016 has a really interesting paper, "All You Need is a Good Init". In this post I will try to reproduce the authors' results in Torch.

General idea

So, what is the motivation? Batch Normalization (BN) helps, but it slows down training (the authors claim by about 30%). Can we do better without the additional overhead?

Yes, we can: by spending a little more time on smart weight initialization, we get faster training, the ability to use bigger learning rates, and better results.

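The algorithm from the paper, LSUV (layer-sequential unit-variance) initialization, where B_L is the output of layer L and W_L are its weights:

Pre-initialize network with orthonormal matrices as in Saxe et al. (2014)
for each layer L do
    T_i = 0
    while |Var(B_L) - 1.0| >= Tol_var and T_i < T_max do
        do forward pass with a mini-batch
        calculate Var(B_L)
        W_L = W_L / sqrt(Var(B_L))
        T_i = T_i + 1
    end while
end for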
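The rescaling step works because, ignoring the bias, the layer output is linear in its weights, and scaling a random variable by a constant c scales its variance by c squared. A short derivation, with x the (fixed) input to layer L coming from the already-initialized layers below it:

```latex
\operatorname{Var}\!\left(\frac{W_L}{\sqrt{\operatorname{Var}(B_L)}}\,x\right)
  = \frac{\operatorname{Var}(W_L\,x)}{\operatorname{Var}(B_L)}
  = \frac{\operatorname{Var}(B_L)}{\operatorname{Var}(B_L)} = 1,
\qquad B_L = W_L\,x
```

Since Var(B_L) is only estimated on one mini-batch, the variance on the next mini-batch is close to 1 but not exactly 1, which is why the loop repeats until it is within Tol_var or T_max iterations have been used.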

Here is the core of my Torch implementation:

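require 'nn'

-- LSUV: rescale each layer that has weights so that its output variance is ~1.
-- get_batch: function returning a mini-batch of inputs for the network
-- tol_var:   allowed deviation of the output variance from 1.0 (default 0.1)
-- t_max:     maximum number of rescalings per layer (default 10)
function nn.Sequential:lsuvInit(get_batch, tol_var, t_max)
   tol_var = tol_var or 0.1
   t_max   = t_max or 10

   -- listModules() returns all modules (including nested ones) in order,
   -- so layers are processed sequentially from input to output
   for _, m in ipairs(self:listModules()) do
      if m.weight ~= nil then
         local t_i = 1
         while true do
            local input = get_batch()
            self:forward(input)
            -- variance over all elements of this layer's output
            local var = torch.var(m.output)
            if math.abs(var - 1.0) < tol_var or t_i > t_max then
               break
            end
            -- dividing the weights by sqrt(var) divides the output variance
            -- by var (ignoring the bias term)
            m.weight:div(math.sqrt(var))
            t_i = t_i + 1
         end
      end
   end
end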
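To see the loop converge without Torch, here is a minimal sketch in plain Lua for a single bias-free fully connected layer. Everything in it (randMatrix, forward, variance, lsuvLayer, and the Gaussian noise standing in for real mini-batches) is hypothetical and only for illustration; it also skips the orthonormal pre-initialization and starts from deliberately badly scaled Gaussian weights instead.

```lua
-- Hypothetical, Torch-free sketch of the LSUV loop for ONE bias-free
-- fully connected layer. Matrices are plain Lua tables; the "data" is
-- Gaussian noise standing in for real mini-batches.

-- Standard normal sample via the Box-Muller transform
local function randn()
  local u1 = 1 - math.random()  -- in (0, 1], so log(u1) is finite
  local u2 = math.random()
  return math.sqrt(-2 * math.log(u1)) * math.cos(2 * math.pi * u2)
end

-- rows x cols matrix with independent N(0, scale^2) entries
local function randMatrix(rows, cols, scale)
  local m = {}
  for i = 1, rows do
    m[i] = {}
    for j = 1, cols do m[i][j] = scale * randn() end
  end
  return m
end

-- Bias-free linear layer: out[n][i] = sum_j W[i][j] * batch[n][j]
local function forward(W, batch)
  local out = {}
  for n, x in ipairs(batch) do
    out[n] = {}
    for i, row in ipairs(W) do
      local s = 0
      for j, w in ipairs(row) do s = s + w * x[j] end
      out[n][i] = s
    end
  end
  return out
end

-- Sample variance (divides by n - 1) over all elements of a 2-D table
local function variance(t)
  local sum, count = 0, 0
  for _, row in ipairs(t) do
    for _, v in ipairs(row) do sum, count = sum + v, count + 1 end
  end
  local mean, ss = sum / count, 0
  for _, row in ipairs(t) do
    for _, v in ipairs(row) do ss = ss + (v - mean) ^ 2 end
  end
  return ss / (count - 1)
end

-- Same stopping rule as lsuvInit: stop when |var - 1| < tolVar or after tMax rescalings.
-- Returns the last measured variance and the number of forward passes.
local function lsuvLayer(W, getBatch, tolVar, tMax)
  tolVar, tMax = tolVar or 0.1, tMax or 10
  local t = 1
  while true do
    local var = variance(forward(W, getBatch()))
    if math.abs(var - 1.0) < tolVar or t > tMax then return var, t end
    local s = math.sqrt(var)
    for _, row in ipairs(W) do
      for j = 1, #row do row[j] = row[j] / s end  -- W = W / sqrt(Var)
    end
    t = t + 1
  end
end

math.randomseed(42)
local W = randMatrix(20, 50, 3.0)  -- badly scaled: output variance ~ 50 * 3^2 = 450
local function getBatch() return randMatrix(64, 50, 1.0) end  -- 64 inputs of size 50
local before = variance(forward(W, getBatch()))
local var, passes = lsuvLayer(W, getBatch)
print(string.format("output variance: %.2f before, %.2f after (%d forward passes)",
  before, var, passes))
```

The first pass sees a variance of roughly 450 (50 inputs times a weight variance of 9), so at least one rescaling is always needed. Because each estimate comes from a random mini-batch, the following pass lands close to 1 rather than exactly on it, which is what the tolerance and the iteration cap are for.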

Usage (from the MNIST example):

require 'nn'
local nninit = require 'nninit'

local model = nn.Sequential()
-- initialize all convolutional and fully connected layers with nninit.orthogonal
model:add(nn.SpatialConvolutionMM(1, 32, 5, 5):init('weight', nninit.orthogonal, {gain = 'relu'}))
model:add(nn.ReLU())
...
model:add(nn.Linear(200, #classes):init('weight', nninit.orthogonal, {gain = 'relu'}))

-- run LSUV after the orthogonal initialization above
-- (opt, classes and get_batch are defined elsewhere in the example)
if opt.lsuv then
  model:lsuvInit(get_batch)
end

MNIST example

I ran the experiment with the following command (-f selects the full MNIST dataset: 60,000 training and 10,000 test images; -r sets the learning rate):

th mnist-example.lua --lsuv -r lr
Epoch   With LSUV   With LSUV    Without LSUV   With LSUV
        (lr=0.1)    (lr=0.05)    (lr=0.001)     (lr=0.001)
1       97.77%      96.69%       83.39%         78.28%
2       98.45%      97.94%       89.25%         87.75%
3       98.63%      98.37%       91.23%         91.19%
4       98.74%      98.57%       92.46%         92.82%
5       98.88%      98.72%       93.23%         93.81%
6       98.97%      98.75%       93.88%         94.53%
7       99.03%      98.86%       94.44%         95.06%
8       99.01%      98.86%       94.81%         95.4%
9       99.01%      98.9%        95.03%         95.87%
10      98.96%      98.91%       95.29%         96.15%

I did not train for 100 epochs as the authors of the original paper did. At first I thought these results showed that LSUV lets us use bigger learning rates, but then I realised that the MNIST baseline without LSUV does not use BN either, so this conclusion does not follow. The MNIST results only show that training with LSUV works and that the accuracies are comparable. Let's have a look at the CIFAR-10 experiment.

CIFAR example

I did not check the maximum accuracy we can achieve, only whether training is comparable in general, and it is. The test-set accuracy is shown in the plot.

References

Thanks to @ikostrikov for help with debugging.
