Batch Normalization
Batch normalization is one of the most widely used techniques in deep learning. During training, it normalizes each feature (or channel) of a layer’s activations to zero mean and unit variance across the current mini-batch, then applies a learnable scale and shift. The result: faster training, tolerance of higher learning rates, reduced sensitivity to initialization, and a mild regularizing effect.
The Problem It Solves
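Without normalization, the distribution of each layer’s inputs shifts with every weight update to the layers before it, so each layer must keep adapting to a moving target. The original BatchNorm paper called this “internal covariate shift” and argued that it slows learning and demands careful tuning of learning rates and weight initialization. (Later analyses have questioned whether reducing this shift is the real reason BatchNorm helps, but its practical benefits are well established.)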
How Batch Normalization Works
For each mini-batch, BatchNorm treats every feature (for convolutional layers, every channel) independently.
Given the m values of one feature in the batch: {x₁, x₂, ..., xₘ}
1. Compute batch statistics:
   μ_B = (1/m) Σᵢ xᵢ   (batch mean)
   σ²_B = (1/m) Σᵢ (xᵢ - μ_B)²   (batch variance)
2. Normalize:
   x̂ᵢ = (xᵢ - μ_B) / √(σ²_B + ε)   (ε is a small constant for numerical stability; PyTorch’s default is 1e-5)
3. Scale and shift with learnable parameters γ and β:
   yᵢ = γ × x̂ᵢ + β
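The three steps can be sketched in NumPy as follows. This is a minimal illustration of the training-mode computation for a (batch, features) array, not PyTorch’s actual implementation. The function name `batch_norm_forward` is hypothetical. The sketch also checks that choosing γ = √(σ²_B + ε) and β = μ_B gives back the original input.

```python
import numpy as np


def batch_norm_forward(x, gamma, beta, eps=1e-5):
    """Training-mode batch normalization of a (batch, features) array.

    Statistics are computed separately for each feature (column) over the batch axis.
    """
    mu = x.mean(axis=0)                    # step 1: batch mean, shape (features,)
    var = x.var(axis=0)                    # step 1: batch variance, divides by m
    x_hat = (x - mu) / np.sqrt(var + eps)  # step 2: normalize
    y = gamma * x_hat + beta               # step 3: learnable scale and shift
    return y, mu, var


rng = np.random.default_rng(0)
x = rng.normal(loc=3.0, scale=2.0, size=(32, 4))  # 32 samples, 4 features

# gamma = 1, beta = 0: every feature comes out with mean ~0 and variance ~1.
y, mu, var = batch_norm_forward(x, gamma=np.ones(4), beta=np.zeros(4))
print(y.mean(axis=0), y.var(axis=0))

# gamma = sqrt(var + eps), beta = mu undoes the normalization.
y_undo, _, _ = batch_norm_forward(x, gamma=np.sqrt(var + 1e-5), beta=mu)
print(np.allclose(y_undo, x))  # True
```

In a real layer, γ and β start at 1 and 0 and are learned by gradient descent. They are fixed here only to show the two extremes.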
γ and β allow the network to undo the normalization when that helps: setting γ = √(σ²_B + ε) and β = μ_B recovers the original activations. More generally, the network learns the best scale and offset for each feature.
BatchNorm in PyTorch
import torch
import torch.nn as nn
# For fully connected layers (1D BatchNorm)
model = nn.Sequential(
    nn.Linear(256, 128),
    nn.BatchNorm1d(128),  # Normalize each of the 128 features over the batch
    nn.ReLU(),
    nn.Linear(128, 64),
    nn.BatchNorm1d(64),
    nn.ReLU(),
    nn.Linear(64, 10),
)
# For convolutional layers (2D BatchNorm)
conv_block = nn.Sequential(
    nn.Conv2d(in_channels=32, out_channels=64, kernel_size=3, padding=1),
    nn.BatchNorm2d(64),  # Normalize each channel over (batch, height, width)
    nn.ReLU(),
)
Training vs. Inference
During training, BatchNorm uses the current mini-batch statistics (mean, variance).
During inference, it uses running statistics accumulated during training: an exponential moving average of the batch means and variances.
# Training: uses batch statistics
model.train()
output = model(x)
# Inference: uses running statistics (frozen)
model.eval()
with torch.no_grad():
    output = model(x)
Running statistics are updated during training as:
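running_mean = (1 - momentum) × running_mean + momentum × batch_mean
PyTorch’s default is momentum=0.1, so each new batch contributes 10% to the running estimate. running_var is updated the same way, and PyTorch stores the unbiased batch variance in it. Keras uses the opposite convention: its momentum (default 0.99) weights the old running value.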
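To see the training/inference split and the momentum update together, here is a small NumPy sketch of a BatchNorm layer for (batch, features) input. The class `SimpleBatchNorm1d` is hypothetical. It mirrors PyTorch’s documented defaults as an assumption: running_mean starts at 0, running_var starts at 1, momentum=0.1 weights the new batch, and the running variance uses the unbiased estimate. Gradients and learnable-parameter updates are omitted.

```python
import numpy as np


class SimpleBatchNorm1d:
    """Hypothetical minimal BatchNorm layer for (batch, features) input.

    Assumes PyTorch-style conventions: running_mean starts at 0, running_var at 1,
    momentum weights the NEW batch, and the running variance stores the
    unbiased (ddof=1) batch variance. No gradients or parameter updates.
    """

    def __init__(self, num_features, momentum=0.1, eps=1e-5):
        self.gamma = np.ones(num_features)   # learnable scale (fixed here)
        self.beta = np.zeros(num_features)   # learnable shift (fixed here)
        self.running_mean = np.zeros(num_features)
        self.running_var = np.ones(num_features)
        self.momentum = momentum
        self.eps = eps
        self.training = True                 # plays the role of model.train() / model.eval()

    def __call__(self, x):
        if self.training:
            # Normalize with the current batch's statistics (biased variance)...
            mean, var = x.mean(axis=0), x.var(axis=0)
            # ...and fold them into the running estimates.
            m = self.momentum
            self.running_mean = (1 - m) * self.running_mean + m * mean
            self.running_var = (1 - m) * self.running_var + m * x.var(axis=0, ddof=1)
        else:
            # Inference: use the frozen running statistics.
            mean, var = self.running_mean, self.running_var
        x_hat = (x - mean) / np.sqrt(var + self.eps)
        return self.gamma * x_hat + self.beta


rng = np.random.default_rng(0)
bn = SimpleBatchNorm1d(3)
for _ in range(200):                         # "training" on data with mean 5, std 2
    bn(rng.normal(loc=5.0, scale=2.0, size=(64, 3)))

print(bn.running_mean)                       # each entry close to 5
print(bn.running_var)                        # each entry close to 4

bn.training = False
print(bn(np.array([[7.0, 7.0, 7.0]])))       # close to (7 - 5) / 2 = 1 per feature
```

Because inference uses the stored statistics, even a single sample gets a meaningful output. In training mode, a batch of one would normalize to all zeros.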
Where to Place BatchNorm
The original paper proposed: Conv → BatchNorm → Activation
Pre-activation ResNets reorder each block as: BatchNorm → Activation → Conv
Both work; the original ordering is more common:
# Standard ResNet block with BatchNorm
class ResBlock(nn.Module):
    def __init__(self, channels):
        super().__init__()
        # bias=False: BatchNorm subtracts the channel mean (cancelling any conv bias),
        # and its β provides the per-channel offset instead
        self.block = nn.Sequential(
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
            nn.ReLU(),
            nn.Conv2d(channels, channels, 3, padding=1, bias=False),
            nn.BatchNorm2d(channels),
        )
        self.relu = nn.ReLU()

    def forward(self, x):
        return self.relu(x + self.block(x))  # Residual connection
BatchNorm Alternatives
| Normalization | Normalizes Over | Best For |
|---|---|---|
| BatchNorm | Batch and spatial dimensions (per channel) | CNNs, large batch sizes |
| LayerNorm | Feature dimension (per sample) | Transformers, NLP |
| GroupNorm | Groups of channels and spatial dimensions (per sample) | Small batches, detection |
| InstanceNorm | Spatial dimensions (per channel per sample) | Style transfer |
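The “Normalizes Over” column can be made concrete on an (N, C, H, W) activation tensor. Each method applies the same zero-mean, unit-variance operation and differs only in which axes the statistics are pooled over. Below is a NumPy sketch. The helper names `normalize` and `group_norm` are hypothetical, and learnable γ and β are omitted.

```python
import numpy as np


def normalize(x, axes, eps=1e-5):
    """Zero-mean, unit-variance normalization over `axes` (gamma/beta omitted)."""
    mean = x.mean(axis=axes, keepdims=True)
    var = x.var(axis=axes, keepdims=True)
    return (x - mean) / np.sqrt(var + eps)


def group_norm(x, num_groups, eps=1e-5):
    """Split the C channels into groups; normalize each (sample, group) slice."""
    n, c, h, w = x.shape
    grouped = x.reshape(n, num_groups, c // num_groups, h, w)
    return normalize(grouped, axes=(2, 3, 4), eps=eps).reshape(n, c, h, w)


x = np.random.default_rng(0).normal(size=(8, 64, 16, 16))  # (N, C, H, W)

batch_norm = normalize(x, axes=(0, 2, 3))     # stats per channel, pooled over the batch
layer_norm = normalize(x, axes=(1, 2, 3))     # stats per sample
instance_norm = normalize(x, axes=(2, 3))     # stats per (sample, channel)
grp_norm = group_norm(x, num_groups=8)        # stats per (sample, group of 8 channels)
```

Only `batch_norm` depends on the other samples in the batch, which is why the other three are batch-size independent. Here LayerNorm pools over (C, H, W), as `nn.LayerNorm([64, 16, 16])` would. In Transformers, it is applied to the last (feature) dimension of each token. GroupNorm with one group reduces to this LayerNorm, and with one channel per group it reduces to InstanceNorm.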
# LayerNorm (used in Transformers)
layer_norm = nn.LayerNorm(normalized_shape=512)  # Normalize over the last dimension (size 512)
# GroupNorm (batch-size independent)
group_norm = nn.GroupNorm(num_groups=8, num_channels=64)  # Divide 64 channels into 8 groups
Key Benefits and Caveats
Benefits:
- Allows substantially higher learning rates (the original paper trained Inception with up to 30× the baseline rate)
- Reduces dependence on careful weight initialization
- Acts as a regularizer (often reduces need for dropout)
- Stabilizes training of very deep networks
Caveats:
- Degrades with very small batches (roughly fewer than 8–16 samples), because the batch statistics become too noisy → use GroupNorm instead
- Behaves differently in training and inference: running statistics must be maintained correctly, and forgetting model.eval() at test time silently degrades results
- Can interact poorly with dropout when both are used aggressively
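The small-batch caveat is easy to see numerically. With a batch of two, each feature’s normalized values are ±d / √(d² + ε), where d = (x₁ - x₂)/2. This is almost exactly ±1 whatever the inputs are. With a batch of one, the output is all zeros. The following NumPy sketch illustrates both cases. The helper `bn_train` is hypothetical, with γ = 1 and β = 0.

```python
import numpy as np


def bn_train(x, eps=1e-5):
    """Training-mode batch normalization per feature, with gamma = 1 and beta = 0."""
    return (x - x.mean(axis=0)) / np.sqrt(x.var(axis=0) + eps)


# Batch of 2: features whose two values differ by 1.8, 1.6 and 8.0...
x2 = np.array([[0.3, -1.2, 5.0],
               [2.1, 0.4, -3.0]])
print(bn_train(x2).round(3))  # ...all map to about -1 / +1: magnitude information is lost

# Batch of 1: the sample equals its own batch mean, so the output is all zeros.
x1 = np.array([[0.3, -1.2, 5.0]])
print(bn_train(x1))
```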
BatchNorm is standard in ResNets, EfficientNets, and most CNN architectures. LayerNorm has taken over in Transformers. Understanding which normalization to use and where is one of the key practical skills in building deep learning systems.