The Exploding Gradient Problem: Causes, Symptoms, and Gradient Clipping

Why gradients can grow uncontrollably in deep networks, how to spot NaN losses caused by it, and how gradient clipping fixes it in practice.


If a training run’s loss suddenly jumps to an enormous number and then becomes NaN within a handful of steps, after training normally up to that point, exploding gradients are a very likely culprit. They are the mirror image of vanishing gradients: the same chain-rule multiplication that can shrink gradients toward zero can, under different conditions, make them grow uncontrollably large instead.


The Mechanism: Multiplying Many Large Numbers Together

Just as vanishing gradients arise from repeatedly multiplying derivatives smaller than 1, exploding gradients arise from repeatedly multiplying derivatives larger than 1 across many layers — the product grows exponentially rather than shrinking.

```python
# Simulating gradient magnitude after passing through 10 layers
# where each layer's local derivative happens to be greater than 1
gradient = 1.0
for _ in range(10):
    local_derivative = 1.5  # greater than 1
    gradient *= local_derivative
print(gradient)  # 1.5**10 ≈ 57.7; this compounds further with more layers
```

With more layers or a larger local derivative, the product can quickly reach astronomically large values. At that point the resulting weight update is so large that it destroys the network’s learned weights in a single step, often producing inf or NaN values; this connects directly to the overflow issues covered in Numerical Computation.
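To see how quickly depth compounds the effect, the loop above can be wrapped in a small function and pushed to more layers. This is a plain-Python sketch under a simplifying assumption (every layer contributes the same constant local derivative), and `backprop_magnitude` is a hypothetical helper name. It also shows a 64-bit float overflowing to `inf`, after which ordinary arithmetic produces NaN.

```python
import math


def backprop_magnitude(local_derivative: float, num_layers: int) -> float:
    """Hypothetical helper: multiply a unit upstream gradient by the same local derivative once per layer."""
    gradient = 1.0
    for _ in range(num_layers):
        gradient *= local_derivative
    return gradient


for depth in (10, 50, 100):
    print(f"{depth:>4} layers: {backprop_magnitude(1.5, depth):.3e}")

# 1.5**2000 is about 10**352, beyond the largest 64-bit float (about 1.8e308),
# so the repeated product overflows to inf instead of raising an error
overflowed = backprop_magnitude(1.5, 2000)
print(overflowed)  # inf

# Once inf appears, ordinary arithmetic such as inf - inf produces NaN
print(math.isnan(overflowed - overflowed))  # True
```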


Common Causes

Poor weight initialization. Weights initialized with too large a variance cause both forward-pass activations and backward-pass gradients to grow layer over layer, as covered in Weight Initialization.
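A rough illustration of the initialization effect, assuming a deep stack of square linear layers with Gaussian weights and no biases or nonlinearities (`activation_norm_after` is a hypothetical helper, and the width, depth, and standard deviations are arbitrary choices). With a weight standard deviation of 1/sqrt(width), each layer roughly preserves the signal's norm in expectation; doubling it doubles the norm at every layer. Because both runs use the same seed, the second network's weights are exactly twice the first's, so the final norms differ by exactly 2**20. For linear layers like these, the backward pass multiplies by the transposes of the same matrices, so gradients are scaled by the same per-layer factors.

```python
import math
import random


def activation_norm_after(depth: int, width: int, weight_std: float, seed: int = 0) -> float:
    """Hypothetical helper: push a unit-norm input through `depth` random linear layers
    (weights drawn from N(0, weight_std**2)) and return the output's L2 norm."""
    rng = random.Random(seed)
    x = [1.0 / math.sqrt(width)] * width  # input vector with L2 norm 1
    for _ in range(depth):
        weights = [[rng.gauss(0.0, weight_std) for _ in range(width)] for _ in range(width)]
        # Matrix-vector product: each output unit is a weighted sum of all inputs
        x = [sum(w * v for w, v in zip(row, x)) for row in weights]
    return math.sqrt(sum(v * v for v in x))


width, depth = 64, 20
well_scaled = activation_norm_after(depth, width, weight_std=1.0 / math.sqrt(width))
too_large = activation_norm_after(depth, width, weight_std=2.0 / math.sqrt(width))
print(f"std = 1/sqrt(width): output norm {well_scaled:.3g}")
print(f"std = 2/sqrt(width): output norm {too_large:.3g}")
print(f"ratio: {too_large / well_scaled:.0f}")  # 2**20 = 1048576
```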

Too high a learning rate. Even with healthy gradients, an excessively large learning rate can produce weight updates so large that the next forward pass yields much larger activations, which in turn produce larger gradients and even larger updates. The result is an unstable feedback loop, covered in Learning Rate.
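The feedback loop can be reproduced on the simplest possible objective, a one-parameter quadratic f(w) = 0.5 · c · w², whose gradient c · w grows with the weight itself. This is a toy sketch (the function name and the specific learning rates are illustrative assumptions): each step multiplies w by (1 - lr · c), so once lr exceeds 2 / c that factor's magnitude exceeds 1 and both the weight and its gradient grow geometrically.

```python
def gradient_descent_on_quadratic(lr: float, curvature: float = 1.0, w0: float = 1.0, steps: int = 20) -> list[float]:
    """Minimize f(w) = 0.5 * curvature * w**2 with plain gradient descent; return |gradient| at each step."""
    w = w0
    grad_magnitudes = []
    for _ in range(steps):
        grad = curvature * w  # a larger weight gives a larger gradient...
        grad_magnitudes.append(abs(grad))
        w -= lr * grad  # ...and a larger gradient gives a larger update
    return grad_magnitudes


stable = gradient_descent_on_quadratic(lr=0.5)    # factor 1 - 0.5 = 0.5: shrinks
unstable = gradient_descent_on_quadratic(lr=2.5)  # factor 1 - 2.5 = -1.5: grows
print(f"lr=0.5: final |grad| = {stable[-1]:.2e}")
print(f"lr=2.5: final |grad| = {unstable[-1]:.2e}")
```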

Recurrent architectures processing long sequences. RNNs, covered in Recurrent Neural Networks, apply the same weight matrix at every time step, so backpropagating through a long sequence means many repeated multiplications by that same matrix. This makes exploding (and vanishing) gradients a particularly acute and well-documented problem for this architecture family.
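A minimal sketch of backpropagation through time, assuming a linear RNN h_t = W h_{t-1} with a 2×2 recurrent matrix and no nonlinearity (a real RNN's activation derivative would add further factors, each at most 1 for tanh). The gradient with respect to h_0 is W^T applied once per time step, so its size over long sequences is governed by the largest eigenvalue magnitude of W: above 1 it explodes, below 1 it vanishes. The matrices and the helper name are illustrative choices.

```python
import math


def gradient_norm_through_time(recurrent_weight: list[list[float]], steps: int) -> float:
    """Hypothetical helper for a linear RNN h_t = W h_{t-1}: apply W^T `steps` times
    to an upstream gradient of ones and return the resulting L2 norm."""
    w_transposed = [[recurrent_weight[j][i] for j in range(2)] for i in range(2)]
    grad = [1.0, 1.0]
    for _ in range(steps):
        # One time step backward: multiply the gradient by the same W^T again
        grad = [sum(w_transposed[i][k] * grad[k] for k in range(2)) for i in range(2)]
    return math.hypot(*grad)


# Upper-triangular matrices, so the eigenvalues are the diagonal entries
exploding_w = [[1.1, 0.3], [0.0, 0.5]]  # largest |eigenvalue| = 1.1 > 1
vanishing_w = [[0.9, 0.3], [0.0, 0.5]]  # largest |eigenvalue| = 0.9 < 1

for steps in (10, 50, 100):
    print(steps, gradient_norm_through_time(exploding_w, steps), gradient_norm_through_time(vanishing_w, steps))
```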


Diagnosing Exploding Gradients

The clearest signal comes from tracking the gradient norm (the L2 norm of all parameter gradients combined into a single vector) over training steps.

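```python
import torch

# Run after loss.backward() and before optimizer.step(); `model` is your nn.Module
total_norm = 0.0
for param in model.parameters():
    if param.grad is not None:
        param_norm = param.grad.detach().norm(2)  # L2 norm of this parameter's gradient
        total_norm += param_norm.item() ** 2
total_norm = total_norm ** 0.5  # global norm: sqrt of the summed squared per-parameter norms
print(f"Gradient norm: {total_norm}")
# A gradient norm that suddenly spikes to a very large value,
# right before loss becomes NaN, is the clear signature of this problem
```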

Tracking this value across training (via TensorBoard or a similar tool) turns a confusing “training randomly broke” incident into a specific, visible, and diagnosable pattern — a gradient norm that stays reasonably bounded for many steps and then suddenly spikes right before the loss diverges.
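If the per-step norms are also kept in a plain list, flagging the spike can be automated. The sketch below uses a hypothetical helper and two arbitrary heuristics (compare each step against the median of the previous `window` steps, and call anything more than `factor` times that baseline a spike). It also flags non-finite norms explicitly, since NaN compares as False against any threshold and would otherwise slip through.

```python
import math
from statistics import median


def find_gradient_spikes(norms: list[float], window: int = 20, factor: float = 10.0) -> list[int]:
    """Hypothetical helper: return the steps whose gradient norm is non-finite or exceeds
    `factor` times the median norm of the preceding `window` steps."""
    spikes = []
    for step in range(window, len(norms)):
        if not math.isfinite(norms[step]):
            spikes.append(step)
            continue
        baseline = median(norms[step - window:step])
        if norms[step] > factor * baseline:
            spikes.append(step)
    return spikes


# Illustrative, made-up history: 30 steady steps, then the pattern described above
history = [1.0] * 30 + [2.0, 50.0, 1e6, float("nan")]
print(find_gradient_spikes(history))  # [31, 32, 33]
```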


The Standard Fix: Gradient Clipping

Gradient clipping caps the gradient norm at a maximum threshold before applying the weight update — if the actual gradient exceeds this threshold, it’s scaled down proportionally, preserving its direction but limiting its magnitude.

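```python
import torch

loss.backward()
# Clip before the optimizer step; the returned value is the total gradient norm
# measured before clipping, so it can also be logged for diagnosis
total_norm = torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)
optimizer.step()
```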

This single line is one of the most common, effective, and cheap fixes for training instability. In recurrent architectures and transformers it is routinely applied as a standard precaution, not just reached for after a problem has been observed.
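The proportional rescaling itself is simple enough to sketch without PyTorch. The hypothetical `clip_by_global_norm` below treats all gradients as one flat list: it computes their combined L2 norm and, only if that exceeds `max_norm`, multiplies every component by `max_norm / total_norm`. PyTorch's `clip_grad_norm_` works on per-parameter tensors in place and differs in details (for example, it adds a small epsilon to the denominator), so treat this as an illustration of the idea rather than the library's implementation.

```python
import math


def clip_by_global_norm(grads: list[float], max_norm: float) -> tuple[list[float], float]:
    """Scale `grads` so their combined L2 norm is at most `max_norm`; also return the pre-clipping norm."""
    total_norm = math.sqrt(sum(g * g for g in grads))
    if total_norm <= max_norm:
        return list(grads), total_norm  # already within bounds: leave unchanged
    scale = max_norm / total_norm
    return [g * scale for g in grads], total_norm


clipped, norm_before = clip_by_global_norm([30.0, -40.0], max_norm=1.0)
print(norm_before)  # 50.0
print(clipped)      # approximately [0.6, -0.8]: same direction, norm 1
```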


Value Clipping vs. Norm Clipping

An alternative (less commonly recommended) approach clips each individual gradient value to a fixed range, rather than scaling the overall gradient vector proportionally.

```python
import torch

# Norm clipping (preferred): scales the whole gradient vector, preserving direction
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

# Value clipping (less common; use instead of norm clipping, not in addition to it):
# clamps each gradient element independently to [-clip_value, clip_value]
torch.nn.utils.clip_grad_value_(model.parameters(), clip_value=0.5)
```

Norm clipping is generally preferred because it preserves the direction of the gradient (still a valid, if scaled-down, descent direction), while value clipping can distort the gradient’s direction by clipping some components more aggressively than others.
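A small numerical comparison makes the difference concrete. The sketch assumes a two-component gradient dominated by one large entry and uses plain-Python stand-ins for both clipping styles (the helper names are hypothetical); cosine similarity with the original gradient measures how far each method rotates it.

```python
import math


def cosine_similarity(a: list[float], b: list[float]) -> float:
    dot = sum(x * y for x, y in zip(a, b))
    return dot / (math.hypot(*a) * math.hypot(*b))


def clip_norm(grads: list[float], max_norm: float) -> list[float]:
    """Rescale the whole vector if its L2 norm exceeds max_norm (direction preserved)."""
    norm = math.hypot(*grads)
    return [g * max_norm / norm for g in grads] if norm > max_norm else list(grads)


def clip_value(grads: list[float], limit: float) -> list[float]:
    """Clamp each component to [-limit, limit] independently."""
    return [max(-limit, min(limit, g)) for g in grads]


grad = [10.0, 0.1]
by_norm = clip_norm(grad, max_norm=1.0)
by_value = clip_value(grad, limit=0.5)
print(by_norm, cosine_similarity(grad, by_norm))    # cosine ~1.0: direction kept
print(by_value, cosine_similarity(grad, by_value))  # [0.5, 0.1], cosine ~0.98: direction rotated
```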


Prevention Beyond Clipping

Gradient clipping treats the symptom effectively, but combining it with root-cause prevention gives the most stable training:

| Prevention | Addresses |
|---|---|
| Proper weight initialization | Prevents gradients from starting too large |
| Batch normalization | Keeps activations (and indirectly gradients) in a stable range throughout training |
| A conservative learning rate | Reduces the size of each individual weight update |
| Architectures with better gradient flow (LSTM/GRU for sequences) | Improve gradient flow in recurrent contexts (primarily aimed at vanishing gradients, so clipping is still commonly paired with them), covered in LSTM and GRU |

Distinguishing Exploding Gradients From Other Causes of NaN Loss

Not every NaN loss traces back to exploding gradients, so it’s worth ruling out the alternatives covered in Numerical Computation before blaming the gradients. A NaN on the very first training step, before any real learning has happened, more often points to a data or numerical problem: an unhandled missing value in the input, a division by zero in a custom loss function, or an unstable computation such as a softmax computed without first subtracting the maximum logit. Exploding gradients, by contrast, typically build up over several training steps before finally producing NaN. Checking whether the gradient norm trended upward over several steps before the failure, or went straight to NaN from a normal-looking first step, is a fast way to tell which kind of problem you’re actually dealing with.
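The check described above can be written as a rough triage function over the recorded per-step gradient norms. This is a sketch under stated assumptions: the helper name, the median baseline, and the 10x growth threshold are arbitrary choices, and the returned strings suggest where to look rather than delivering a diagnosis.

```python
import math
from statistics import median


def triage_nan_failure(grad_norms: list[float], growth_factor: float = 10.0) -> str:
    """Hypothetical helper: suggest a likely cause for a run whose gradient norm became NaN/inf."""
    first_bad = next((i for i, g in enumerate(grad_norms) if not math.isfinite(g)), None)
    if first_bad is None:
        return "no non-finite gradient norm recorded"
    if first_bad == 0:
        return "non-finite on the first step: check the data, custom loss code, and unstable ops"
    finite = grad_norms[:first_bad]
    if finite[-1] > growth_factor * median(finite):
        return "norm grew before failing: consistent with exploding gradients"
    return "no build-up before failing: check that step's batch and numerics before blaming gradients"


print(triage_nan_failure([float("nan")]))
print(triage_nan_failure([1.0, 1.2, 0.9, 5.0, 60.0, 900.0, float("inf")]))
print(triage_nan_failure([1.0, 1.1, 0.9, float("nan")]))
```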

Summary

| Symptom | Cause | Fix |
|---|---|---|
| Loss suddenly jumps to a huge value or NaN | Gradients compounding to very large magnitude across layers | Gradient clipping (immediate fix) |
| Recurring instability across many training runs | Poor initialization, too-high learning rate, or architecture-specific issues (RNNs) | Address the root cause: initialization, learning rate, or architecture choice |

Exploding gradients are the more dramatic, more immediately visible sibling of vanishing gradients — but they’re just as directly traceable to the chain rule’s repeated multiplication, and just as reliably fixed with well-established, standard techniques.