Tamaki Kojima ([email protected])
PyTorch 1.0 support
This is an alternative implementation of "Synchronized Multi-GPU Batch Normalization", which computes global batch statistics across GPUs instead of using locally computed ones. SyncBN is becoming important for tasks where the input images are large and multiple GPUs must be used to increase the effective minibatch size for training.
The code was inspired by Pytorch-Encoding and Inplace-ABN
- Unlike Pytorch-Encoding, you don't need a custom `nn.DataParallel`
- Unlike Inplace-ABN, you can simply replace your `nn.BatchNorm2d` with this module implementation, since it is not marked for inplace operation (see the conversion sketch after this list)
- You can plug it into an arbitrary module written in PyTorch to enable Synchronized BatchNorm
- Backward computation is rewritten and tested against the behavior of `nn.BatchNorm2d`
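For an existing network, the replacement can be automated by walking the module tree. The helper below is only a sketch: `convert_bn` is a hypothetical name, and it assumes `NN.BatchNorm2d` accepts the same `num_features` argument and uses a `state_dict` layout compatible with `nn.BatchNorm2d`.

```python
import torch.nn as nn
from modules import nn as NN  # synchronized BatchNorm from this repository


def convert_bn(module):
    """Recursively swap nn.BatchNorm2d children for NN.BatchNorm2d (sketch)."""
    for name, child in module.named_children():
        if isinstance(child, nn.BatchNorm2d):
            sync_bn = NN.BatchNorm2d(child.num_features)
            # Assumes the state_dict layout matches nn.BatchNorm2d, so the
            # affine parameters and running stats carry over unchanged.
            sync_bn.load_state_dict(child.state_dict())
            setattr(module, name, sync_bn)
        else:
            convert_bn(child)
    return module
```

After conversion the model can be wrapped in `nn.DataParallel` exactly as in the usage example further below.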
For PyTorch, please refer to https://pytorch.org/
NOTE: The code is tested only with PyTorch v1.0.0, CUDA 10 / cuDNN 7.4.2 on Ubuntu 18.04.
It utilizes the PyTorch JIT mechanism to compile the extension seamlessly, using ninja. Please install ninja-build before use:
```bash
sudo apt-get install ninja-build
```
Also install all the Python dependencies. For pip, run:
```bash
pip install -U -r requirements.txt
```
There is no need to build anything in advance; just run it and the JIT will take care of compilation. JIT C++ extensions are supported since PyTorch 0.4, however it is highly recommended to use PyTorch >= 1.0 due to the huge design changes.
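For reference, PyTorch's JIT extension mechanism boils down to a call like the one below. This is only a generic sketch of the mechanism; the extension name and source paths are placeholders, not the ones used by this repository.

```python
from torch.utils.cpp_extension import load

# Compiles the C++/CUDA sources with ninja on first use and caches the result,
# so there is no separate build step. Name and paths below are placeholders.
ext = load(
    name="syncbn_example",
    sources=["csrc/example.cpp", "csrc/example_kernel.cu"],
    verbose=True,
)
```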
Please refer to `test.py` for testing the difference between `nn.BatchNorm2d` and `modules.nn.BatchNorm2d`.
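The idea of such a test is to feed the same batch to a plain `nn.BatchNorm2d` on a single GPU and to the synchronized version under `nn.DataParallel`, then check that the outputs agree. The snippet below is a simplified sketch of that comparison, not the actual contents of `test.py`; it assumes `NN.BatchNorm2d` uses the same default initialization as `nn.BatchNorm2d`.

```python
import torch
import torch.nn as nn
from modules import nn as NN

num_gpu = torch.cuda.device_count()

# Reference: a plain BatchNorm2d sees the whole batch on one GPU.
ref_bn = nn.BatchNorm2d(3).cuda()
# Synchronized BatchNorm sees one slice per GPU but gathers global stats.
sync_bn = nn.DataParallel(NN.BatchNorm2d(3).cuda(), device_ids=range(num_gpu))

x = torch.rand(num_gpu * 4, 3, 8, 8).cuda()
y_ref = ref_bn(x)
y_sync = sync_bn(x)

# With correct synchronization the two outputs should match closely.
print("max abs diff:", (y_ref - y_sync).abs().max().item())
```

The usage example below shows how the module is plugged into a model under `nn.DataParallel`.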
```python
import torch
import torch.nn as nn

from modules import nn as NN

num_gpu = torch.cuda.device_count()

# Build a model that uses the synchronized BatchNorm modules.
model = nn.Sequential(
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
    nn.ReLU(inplace=True),
    nn.Conv2d(3, 3, 1, 1, bias=False),
    NN.BatchNorm2d(3),
).cuda()

# Standard nn.DataParallel is enough; no custom parallel wrapper is needed.
model = nn.DataParallel(model, device_ids=range(num_gpu))

x = torch.rand(num_gpu, 3, 2, 2).cuda()
z = model(x)
```
In the forward pass:

1. Compute the local sums $\sum_i x_i$ and $\sum_i x_i^2$ on each gpu.
2. Gather all local sums from the workers to the master and compute the global statistics
   $\mu = \frac{1}{N}\sum_i x_i$
   and
   $\sigma^2 = \frac{1}{N}\sum_i x_i^2 - \mu^2$,
   where $N$ is the total number of elements for each channel across all gpus,
   and then share these global stats back to all gpus and update `running_mean` and `running_var` by moving average using the global stats.
3. Forward batchnorm using the global stats by
   $\hat{x}_i = \frac{x_i - \mu}{\sqrt{\sigma^2 + \epsilon}}$
   and then
   $y_i = \gamma \hat{x}_i + \beta$
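As a device-free illustration of these three steps, the sketch below splits a batch into two chunks (standing in for the per-GPU work and the gather), accumulates per-chunk sums, and rebuilds the same statistics a single-device BatchNorm would compute over the whole batch. The sizes, chunk count, and `eps` value are illustrative only.

```python
import torch

# Pretend the batch is split across two "devices" along the batch dimension.
x = torch.randn(8, 3, 4, 4)
chunks = x.chunk(2, dim=0)

# Step 1: per-device sums over batch and spatial dims, one value per channel.
local = [(c.sum(dim=(0, 2, 3)), (c * c).sum(dim=(0, 2, 3))) for c in chunks]

# Step 2: "gather" the local sums and form the global statistics.
N = x.numel() / x.size(1)              # total number of elements per channel
sum_x = sum(s for s, _ in local)
sum_x2 = sum(s2 for _, s2 in local)
mean = sum_x / N
var = sum_x2 / N - mean * mean

# Step 3: normalize with the global stats (gamma = 1, beta = 0 here).
eps = 1e-5
x_hat = (x - mean[None, :, None, None]) / torch.sqrt(var[None, :, None, None] + eps)

# Sanity check: identical to normalizing with stats computed on the full batch.
mean_ref = x.mean(dim=(0, 2, 3), keepdim=True)
var_ref = ((x - mean_ref) ** 2).mean(dim=(0, 2, 3), keepdim=True)
x_hat_ref = (x - mean_ref) / torch.sqrt(var_ref + eps)
print(torch.allclose(x_hat, x_hat_ref, atol=1e-5))
```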
In the backward pass:

1. Compute the sums
   $\sum_i \frac{\partial \ell}{\partial y_i}$
   and
   $\sum_i \frac{\partial \ell}{\partial y_i} \hat{x}_i$
   on each gpu, then gather them at the master node to sum them up globally and normalize with $N$, where $N$ is the total number of elements for each channel. The global sums are then shared among all gpus.
2. Compute the gradients using the global stats:
   $\frac{\partial \ell}{\partial x_i} = \frac{1}{N\sqrt{\sigma^2 + \epsilon}} \left( N \frac{\partial \ell}{\partial \hat{x}_i} - \sum_j \frac{\partial \ell}{\partial \hat{x}_j} - \hat{x}_i \sum_j \frac{\partial \ell}{\partial \hat{x}_j} \hat{x}_j \right)$
   where
   $\frac{\partial \ell}{\partial \hat{x}_i} = \frac{\partial \ell}{\partial y_i} \gamma$
   and
   $\frac{\partial \ell}{\partial \gamma} = \sum_i \frac{\partial \ell}{\partial y_i} \hat{x}_i$
   and finally,
   $\frac{\partial \ell}{\partial \beta} = \sum_i \frac{\partial \ell}{\partial y_i}$
Note that in the implementation, the normalization with $N$ that appears in the equation of step (2) is performed at step (1), so the equations above and the implementation are not exactly the same, but they are mathematically equivalent.
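The gradient formula of step (2) can be checked against autograd on a small example. The sketch below does this for a single channel with a scalar $\gamma$ and a toy sum-of-squares loss; the seed, sizes, and loss are illustrative only.

```python
import torch

torch.manual_seed(0)
N, eps, gamma = 16, 1e-5, 2.0

x = torch.randn(N, requires_grad=True)

# Forward: batchnorm over a single channel with scale gamma (beta = 0).
mu = x.mean()
var = ((x - mu) ** 2).mean()          # biased variance, as used at train time
x_hat = (x - mu) / torch.sqrt(var + eps)
y = gamma * x_hat

# A toy scalar loss; autograd gives the reference gradient dl/dx.
loss = (y ** 2).sum()
loss.backward()

# Manual gradient using the global-stats formula:
# dl/dx_i = gamma / sqrt(var + eps) *
#           (dl/dy_i - mean_j(dl/dy_j) - x_hat_i * mean_j(dl/dy_j * x_hat_j))
with torch.no_grad():
    dy = 2 * y                        # dl/dy for the sum-of-squares loss
    dx_manual = gamma / torch.sqrt(var + eps) * (
        dy - dy.mean() - x_hat * (dy * x_hat).mean()
    )

print(torch.allclose(x.grad, dx_manual, atol=1e-5))
```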
You can go deeper into the above explanation at Kevin Zakka's blog.