All You Need is a Good Init

Posted on Mon 09 May 2016 in misc

ICLR 2016 has a really interesting paper "All You Need is a Good Init". In this post I will try to repeat the results of the authors and will do that in Torch.

General idea

So, what is the motivation? Batch Normalization helps, but it slows down the training process (the authors claim it's 30%). Can we do better without additional overhead?

Yes, we can spend more time for smart weights initialization (not much more), but get benefits in training speed, ability to use bigger learning rates and better results.

Pre-initialize network with orthonormal matrices as in Saxe et al.(2014)
for each layer L do
    while |Var(B_L) - 1.0| >= Tol_var and (T_i) < T_max) do
        do Forward pass with a mini-batch
        calculate Var(B_L)
        W_L = W_L /sqrt(Var(B_L))
    end while
end for

What I found important for the implementation here:

  • L is a Convolutional or Fully Connected layer
  • For each layer we use a new minibatch.
  • We compute variance for the whole data in minibatch: Var(B_L)! (at first I thought that we compute variance feature-wise) ## Torch implementation
require 'nn'

nn.Sequential.lsuvInit = function (self, get_batch, tol_var, t_max)
   local tol_var = tol_var or 0.1
   local t_max   = t_max or 10

   for _,m in ipairs(self:listModules()) do
      if m.weight ~= nil then
         local t_i = 1
         while true do
            local input = get_batch()
            local var = torch.var(m.output)
            if torch.abs(var - 1.0) < tol_var or t_i > t_max then
            t_i = t_i + 1

Usage (from MNIST example):

require 'nn'

-- add nninit.orthogonal to all convolutional and fully connected layers
model:add(nn.SpatialConvolutionMM(1, 32, 5, 5):init('weight', nninit.orthogonal, {gain = 'relu'}))
model:add(nn.Linear(200, #classes):init('weight', nninit.orthogonal, {gain = 'relu'}))

--do LSUV after orthogonal init above
if opt.lsuv then

MNIST example

I used the following bash command to run the experiment (-f for full mnist dataset: 60 000 for training and 10 000 for testing):

th mnist-example.lua --lsuv -r lr
epoch with lsuv (lr=0.1) with lsuv (lr=0.05) without lsuv (lr=0.001) with lsuv (lr=0.001)
1 97.77% 96.69% 83.39% 78.28%
2 98.45% 97.94% 89.25% 87.75%
3 98.63% 98.37% 91.23% 91.19%
4 98.74% 98.57 92.46% 92.82%
5 98.88% 98.72% 93.23% 93.81%
6 98.97% 98.75% 93.88% 94.53%
7 99.03% 98.86% 94.44% 95.06%
8 99.01% 98.86% 94.81% 95.4%
9 99.01% 98.9% 95.03% 95.87%
10 98.96% 98.91 95.29% 96.15%

I did not wait for 100 epochs as the authors of the original paper did. At first, I thought that we can use bigger learning rates when we use LSUV, but then I realised that MNIST nolsuv case does not use BN, so, this is not true. And MNIST results just show us that training works and the accuracy rates are pretty comparable. Let's have a look at CIFAR-10 experiment.

CIFAR example

I did not check the limit of the accuracy we can achieve, but just checked if the training is comparable in general. And it is. Test dataset accuracy is on the pic.


Thanks for the debugging and help to @ikostrikov

If you want to ask me a question, you can find me here