The Exploding Gradient Problem: Causes, Symptoms, and Gradient Clipping
If a training run’s loss suddenly jumps to an enormous number and then becomes NaN within a handful of steps, after training normally up to that point, you’re very likely looking at exploding gradients. They are the mirror-image failure of vanishing gradients: the same chain-rule multiplication that can shrink gradients toward zero can, under different conditions, make them grow uncontrollably large instead.
The Mechanism: Multiplying Many Large Numbers Together
Just as vanishing gradients arise from repeatedly multiplying derivatives smaller than 1, exploding gradients arise from repeatedly multiplying derivatives larger than 1 across many layers — the product grows exponentially rather than shrinking.
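The chain rule makes this compounding explicit. As a sketch, assume a stack of n layers h_k = f_k(h_{k-1}) with a single scalar value at each layer and a loss L computed from h_n. Real layers pass vectors, so each factor becomes a Jacobian matrix, but the scalar case already shows how the magnitude scales:

```latex
\frac{\partial L}{\partial h_0}
  = \frac{\partial L}{\partial h_n}\prod_{k=1}^{n}\frac{\partial h_k}{\partial h_{k-1}}
\qquad\text{so if every } \left|\frac{\partial h_k}{\partial h_{k-1}}\right| = a:\qquad
\left|\frac{\partial L}{\partial h_0}\right| = a^{n}\left|\frac{\partial L}{\partial h_n}\right|
```

With a = 1.5 and n = 10, the factor is a^n = 1.5^10 ≈ 57.7; with a < 1 the same product shrinks toward zero instead, which is the vanishing-gradient case.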
```python
# Simulate a gradient's magnitude after passing back through 10 layers,
# where each layer's local derivative happens to be greater than 1
gradient = 1.0
for layer in range(10):
    local_derivative = 1.5  # greater than 1
    gradient *= local_derivative

print(gradient)  # 1.5**10 ≈ 57.7, and this compounds further with more layers
```

With more layers, or a larger local derivative, this quickly reaches astronomically large values. At that point the resulting weight update can be so large that it destroys the network’s learned weights in a single step, often producing inf or NaN values; this connects directly to the overflow issues covered in Numerical Computation.
Common Causes
Poor weight initialization. Weights initialized with too large a variance directly cause both forward-pass activations and backward-pass gradients to grow layer over layer, covered in Weight Initialization.
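To see the initialization effect concretely, here is a minimal pure-Python sketch. It assumes a stack of purely linear layers (no activation function, a deliberate simplification) with weights drawn from a normal distribution; the helper forward_activation_norms is hypothetical. With a weight standard deviation of 1.0, each layer multiplies the activation norm by roughly the square root of the layer width, while scaling the standard deviation down by 1/sqrt(width) keeps the norm roughly constant. The backward pass multiplies by the transposed weight matrices, so gradients grow the same way.

```python
import math
import random

def forward_activation_norms(std, width=50, depth=10, seed=0):
    """Pass a random input through `depth` linear layers whose weights are drawn
    from N(0, std^2), returning the activation norm after each layer."""
    rng = random.Random(seed)
    x = [rng.gauss(0.0, 1.0) for _ in range(width)]
    norms = []
    for _ in range(depth):
        W = [[rng.gauss(0.0, std) for _ in range(width)] for _ in range(width)]
        x = [sum(w * xj for w, xj in zip(row, x)) for row in W]  # x <- W @ x
        norms.append(math.sqrt(sum(v * v for v in x)))
    return norms

width = 50
too_large = forward_activation_norms(std=1.0, width=width)
scaled = forward_activation_norms(std=1.0 / math.sqrt(width), width=width)
print(f"std=1.0:         final activation norm {too_large[-1]:.3g}")
print(f"std=1/sqrt(50):  final activation norm {scaled[-1]:.3g}")
```

Because both runs use the same seed, the second run draws exactly the same random numbers scaled down by sqrt(50), so the two final norms differ by a factor of 50^5 (about 3 × 10^8) after 10 layers.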
Too high a learning rate. Even with healthy gradients, an excessively large learning rate can cause weight updates so large that the next forward pass produces dramatically larger activations, which then produce larger gradients, compounding into an unstable feedback loop — covered in Learning Rate.
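The learning-rate feedback loop can be reproduced on the simplest possible loss. This pure-Python sketch (run_gradient_descent is a hypothetical helper) runs plain gradient descent on f(w) = 0.5 * curvature * w**2, whose gradient is curvature * w. Each update multiplies w by (1 - lr * curvature), so for any learning rate above 2 / curvature the weight overshoots by more every step, and the gradient, being proportional to w, grows along with it.

```python
def run_gradient_descent(lr, curvature=1.0, w0=1.0, steps=20):
    """Minimize f(w) = 0.5 * curvature * w**2, recording |gradient| at each step."""
    w = w0
    grad_norms = []
    for _ in range(steps):
        grad = curvature * w       # df/dw
        grad_norms.append(abs(grad))
        w = w - lr * grad          # plain gradient descent update
    return w, grad_norms

for lr in (0.5, 2.5):
    w, grads = run_gradient_descent(lr)
    print(f"lr={lr}: final |w| = {abs(w):.3g}, last |grad| = {grads[-1]:.3g}")
```

With lr = 0.5 the weight shrinks by half each step and converges; with lr = 2.5 the multiplier is -1.5, so the gradient magnitude grows by 50% on every step, the same compounding loop described above.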
Recurrent architectures processing long sequences. RNNs, covered in Recurrent Neural Networks, effectively apply the same weight matrix repeatedly across every time step in a sequence — a long sequence means many repeated multiplications by the same matrix, making exploding (and vanishing) gradients a particularly acute, well-documented problem for this architecture family specifically.
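The repeated-matrix effect can be isolated with a small sketch. It assumes a linear recurrence h_t = W h_{t-1} with a 2×2 weight matrix and no nonlinearity or inputs (a simplification of a real RNN). For that recurrence, the gradient with respect to h_{t-1} is W^T times the gradient with respect to h_t, so backpropagating through T steps multiplies by the same W^T T times. The helper names are hypothetical; W is chosen as a scaled rotation so that each multiplication changes the gradient’s length by exactly the scale factor.

```python
import math

def scaled_rotation(scale, angle=0.3):
    """2x2 matrix that rotates a vector by `angle` radians and multiplies its length by `scale`."""
    cos_a, sin_a = math.cos(angle), math.sin(angle)
    return [[scale * cos_a, -scale * sin_a],
            [scale * sin_a, scale * cos_a]]

def backprop_through_time(W, grad, steps):
    """Push a gradient back through `steps` applications of the linear recurrence
    h_t = W @ h_{t-1}; each step computes dL/dh_{t-1} = W^T @ dL/dh_t."""
    WT = [[W[0][0], W[1][0]],
          [W[0][1], W[1][1]]]  # transpose of W
    for _ in range(steps):
        grad = [WT[0][0] * grad[0] + WT[0][1] * grad[1],
                WT[1][0] * grad[0] + WT[1][1] * grad[1]]
    return grad

def norm(v):
    return math.hypot(v[0], v[1])

for scale in (1.1, 0.9):
    W = scaled_rotation(scale)  # the same matrix is reused at every time step
    for steps in (10, 50, 100):
        g = backprop_through_time(W, [1.0, 0.0], steps)
        print(f"scale={scale}, steps={steps}: gradient norm = {norm(g):.3g}")
```

With a scale of 1.1 the gradient norm after 100 steps is 1.1^100 (roughly 1.4 × 10^4); with 0.9 it is 0.9^100 (roughly 2.7 × 10^-5). The same kind of matrix produces exploding or vanishing gradients depending only on whether the scale is above or below 1, and longer sequences make either outcome more extreme.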
Diagnosing Exploding Gradients
The clearest signal is watching the gradient norm (the overall magnitude of all gradients combined) over training steps.
```python
import torch

# Assumes `model` is an nn.Module and loss.backward() has already been called
total_norm = 0.0
for param in model.parameters():
    if param.grad is not None:
        param_norm = param.grad.detach().norm(2)  # L2 norm of this parameter's gradient
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5  # combined L2 norm over all parameters

print(f"Gradient norm: {total_norm}")
# A gradient norm that suddenly spikes to a very large value,
# right before loss becomes NaN, is the clear signature of this problem
```

Tracking this value across training (via TensorBoard or a similar tool) turns a confusing “training randomly broke” incident into a specific, visible, and diagnosable pattern: a gradient norm that stays reasonably bounded for many steps and then suddenly spikes right before the loss diverges.
The Standard Fix: Gradient Clipping
Gradient clipping caps the gradient norm at a maximum threshold before the weight update is applied. If the gradient’s norm exceeds the threshold, the entire gradient is scaled down by one shared factor so that its norm equals the threshold, which preserves its direction while limiting its magnitude.
```python
import torch

# Inside the training loop; assumes `model`, `loss`, and `optimizer` already exist
loss.backward()
# Clip after backward() and before optimizer.step(); the return value is the
# total gradient norm before clipping, which is also handy for logging
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```

This single line is one of the most common, effective, and cheap fixes for training instability. It is especially common in recurrent architectures and transformers, where it is routinely applied as a standard precaution rather than reached for only after a problem is observed.
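Under the hood, norm clipping is simple arithmetic. The following pure-Python sketch (clip_by_global_norm is a hypothetical function, not a library API) computes the combined L2 norm over all parameter gradients, the same quantity as the monitoring code earlier in this section, and rescales every gradient by one shared factor when that norm exceeds max_norm. The small eps term guards against division by zero; PyTorch’s implementation adds a similar small constant.

```python
import math

def clip_by_global_norm(grads, max_norm, eps=1e-6):
    """Rescale per-parameter gradients (lists of floats) so their combined L2 norm
    is at most max_norm. Returns (clipped_grads, total_norm_before_clipping)."""
    # Treat all gradients as one long vector when computing the norm
    total_norm = math.sqrt(sum(g * g for grad in grads for g in grad))
    # One shared factor for every entry preserves the overall direction;
    # min(1.0, ...) leaves gradients that are already small enough untouched
    scale = min(1.0, max_norm / (total_norm + eps))
    clipped = [[g * scale for g in grad] for grad in grads]
    return clipped, total_norm

clipped, total = clip_by_global_norm([[3.0, 4.0], [12.0]], max_norm=1.0)
print(total)    # norm before clipping
print(clipped)  # every entry scaled by roughly 1/13
```

Here the two parameter gradients have a combined norm of sqrt(9 + 16 + 144) = 13, so every entry is scaled by about 1/13; a gradient whose norm is already below the threshold passes through unchanged.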
Value Clipping vs. Norm Clipping
An alternative (less commonly recommended) approach clips each individual gradient value to a fixed range, rather than scaling the overall gradient vector proportionally.
```python
import torch

# Use one approach or the other, after loss.backward() and before optimizer.step()

# Norm clipping (preferred): rescales the whole gradient vector, preserving direction
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Value clipping (less common): clamps each element to [-clip_value, clip_value] independently
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```

Norm clipping is generally preferred because it preserves the direction of the gradient (still a valid, if scaled-down, descent direction), while value clipping can distort that direction by clamping large components while leaving small ones untouched.
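The direction difference is easy to measure. This pure-Python sketch uses a made-up two-component gradient with one large and one small entry, applies value clipping (clip_by_value is a hypothetical helper mirroring what element-wise clamping does) and a uniform norm rescale, and compares each result with the original using cosine similarity, where 1.0 means “same direction”.

```python
import math

def clip_by_value(grad, clip_value):
    """Clamp each component to [-clip_value, clip_value] independently."""
    return [max(-clip_value, min(clip_value, g)) for g in grad]

def cosine_similarity(u, v):
    dot = sum(a * b for a, b in zip(u, v))
    return dot / (math.sqrt(sum(a * a for a in u)) * math.sqrt(sum(b * b for b in v)))

grad = [10.0, 0.1]  # one large component, one small one
grad_norm = math.sqrt(sum(g * g for g in grad))
norm_clipped = [g * (1.0 / grad_norm) for g in grad]  # norm clipping to max_norm=1.0: a uniform rescale
value_clipped = clip_by_value(grad, clip_value=0.5)   # large entry clamped, small one untouched

print(cosine_similarity(grad, norm_clipped))   # ≈ 1.0: same direction
print(cosine_similarity(grad, value_clipped))  # below 1: direction has changed
```

Value clipping turns [10.0, 0.1] into [0.5, 0.1], so the small component now carries far more relative weight than it did in the original gradient, while the norm-clipped vector points exactly where the original did.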
Prevention Beyond Clipping
Gradient clipping treats the symptom effectively, but combining it with root-cause prevention gives the most stable training:
| Prevention | Addresses |
|---|---|
| Proper weight initialization | Prevents gradients from starting too large |
| Batch normalization | Keeps activations (and indirectly gradients) in a stable range throughout training |
| A conservative learning rate | Reduces the size of each individual weight update |
| Architectures with better gradient flow (LSTM/GRU for sequences) | Specifically designed to mitigate this issue in recurrent contexts, covered in LSTM and GRU |
Distinguishing Exploding Gradients From Other Causes of NaN Loss
Not every NaN loss traces back to exploding gradients, so it’s worth ruling out the alternatives covered in Numerical Computation before concluding that gradients are the culprit. A NaN on the very first training step, before any real learning has happened, more often points to a data or numerical problem (an unhandled missing value, a division by zero in a custom loss function, or an unstable computation such as a softmax computed without first subtracting the maximum input) than to exploding gradients, which typically build up over several training steps before finally producing NaN. Checking whether the gradient norm trended upward over several steps before the failure, or jumped straight to NaN from a normal-looking first step, is a fast way to tell which kind of problem you’re actually dealing with.
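That check can be written down as a rough triage rule. The sketch below is a hypothetical helper, not a standard tool: it takes the gradient norms logged once per step up to and including the failing step, and the rise_factor threshold of 10 is an arbitrary assumption you would tune for your own runs.

```python
import math

def triage_nan(grad_norms, rise_factor=10.0):
    """Hypothetical heuristic: classify a NaN failure from per-step gradient norms.

    grad_norms: norms logged once per step, oldest first, ending at the step
    where the loss went NaN (that last entry may itself be NaN or inf).
    """
    finite = [g for g in grad_norms if math.isfinite(g)]
    if not finite:
        # NaN from the very first logged step: nothing had time to build up
        return "check data and loss code"
    if max(finite) >= rise_factor * finite[0]:
        # The norm climbed well above its starting level before failing
        return "likely exploding gradients"
    # Norms looked normal, then NaN appeared abruptly
    return "check data and loss code"

print(triage_nan([1.0, 1.3, 6.0, 45.0, float("inf")]))
print(triage_nan([float("nan")]))
```

A steadily rising history points toward clipping, initialization, or the learning rate; an abrupt NaN after normal-looking norms points back toward the data pipeline and custom loss code.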
Summary
| Symptom | Cause | Fix |
|---|---|---|
| Loss suddenly jumps to a huge value or NaN | Gradients compounding to very large magnitude across layers | Gradient clipping (immediate fix) |
| Recurring instability across many training runs | Poor initialization, too-high learning rate, or architecture-specific issues (RNNs) | Address root cause: initialization, learning rate, or architecture choice |
Exploding gradients are the more dramatic, more immediately visible sibling of vanishing gradients — but they’re just as directly traceable to the chain rule’s repeated multiplication, and just as reliably fixed with well-established, standard techniques.