Navigation

Learning Rate Optimization

Find optimal learning rates for deep learning models. Use learning rate finders, schedulers, and warmup strategies to achieve faster convergence and more stable training in PyTorch and TensorFlow.

Why Learning Rate Matters

The learning rate is arguably the most important hyperparameter in deep learning. Set it wrong, and your model either:

  • Too High: Diverges, produces NaN losses, or oscillates wildly
  • Too Low: Trains painfully slowly or gets stuck in poor local minima
  • Just Right: Converges quickly to good solutions

Quick Start: Finding Your Learning Rate

The Learning Rate Range Test

The most reliable method to find a good learning rate:

import torch
import matplotlib.pyplot as plt

def find_lr(model, train_loader, optimizer, criterion,
            start_lr=1e-7, end_lr=10, num_iter=100):
    """
    Perform learning rate range test.

    Runs up to ``num_iter`` training steps, multiplying the learning rate by
    a constant factor after every step so that it grows geometrically from
    ``start_lr`` towards ``end_lr``. The loss observed at each learning rate
    is recorded and the curve is saved to ``lr_range_test.png``.

    The test updates ``model`` and ``optimizer`` in place; snapshot their
    ``state_dict()`` beforehand if training should restart from the original
    weights afterwards.

    Args:
        model: Network to probe. Batches are moved to the device of its
            first parameter (left untouched if the model has no parameters).
        train_loader: Re-iterable yielding ``(data, target)`` batches; it is
            restarted if exhausted before ``num_iter`` steps.
        optimizer: Optimizer whose learning rate is overridden during the
            sweep (every param group receives the swept value).
        criterion: Callable ``criterion(output, target)`` returning a scalar
            loss tensor.
        start_lr: Learning rate used for the first step.
        end_lr: Upper end of the sweep.
        num_iter: Maximum number of training steps to run.

    Returns:
        ``(lrs, losses)``: two equal-length lists holding the learning rate
        used at each recorded step and the loss measured at that step.
    """
    lrs = []
    losses = []

    lr_mult = (end_lr / start_lr) ** (1 / num_iter)
    lr = start_lr

    def set_lr(value):
        # Sweep every param group; updating only group 0 would leave the
        # remaining groups at their original LR and skew the curve.
        for group in optimizer.param_groups:
            group['lr'] = value

    set_lr(lr)

    # Follow the model's device instead of assuming CUDA is available.
    first_param = next(iter(model.parameters()), None)
    device = None if first_param is None else first_param.device

    best_loss = float('inf')
    batch_iter = iter(train_loader)

    for _ in range(num_iter):
        try:
            data, target = next(batch_iter)
        except StopIteration:
            # Loader exhausted: start another pass over the data.
            batch_iter = iter(train_loader)
            data, target = next(batch_iter)

        if device is not None:
            data, target = data.to(device), target.to(device)

        optimizer.zero_grad()
        output = model(data)
        loss = criterion(output, target)
        # Read the scalar once; each .item() forces a device sync.
        loss_value = loss.item()

        # Stop if loss explodes
        if loss_value > 4 * best_loss or torch.isnan(loss):
            break

        best_loss = min(best_loss, loss_value)

        lrs.append(lr)
        losses.append(loss_value)

        loss.backward()
        optimizer.step()

        # Update learning rate
        lr *= lr_mult
        set_lr(lr)

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Range Test')
    plt.grid(True)
    plt.savefig('lr_range_test.png')

    return lrs, losses

# Usage
# lrs, losses = find_lr(model, train_loader, optimizer, criterion)
import tensorflow as tf
import matplotlib.pyplot as plt
import numpy as np

def find_lr(model, train_dataset, start_lr=1e-7, end_lr=10, num_iter=100,
            loss_fn=None):
    """
    Perform learning rate range test.

    Runs up to ``num_iter`` training steps with a fresh Adam optimizer,
    multiplying the learning rate by a constant factor after every step so
    that it grows geometrically from ``start_lr`` towards ``end_lr``. The
    loss observed at each learning rate is recorded and the curve is saved
    to ``lr_range_test.png``.

    The model's trainable variables are updated in place by the test.

    Args:
        model: Callable Keras-style model invoked as
            ``model(data, training=True)``.
        train_dataset: Re-iterable yielding ``(data, target)`` batches; it
            is restarted if exhausted before ``num_iter`` steps.
        start_lr: Learning rate used for the first step.
        end_lr: Upper end of the sweep.
        num_iter: Maximum number of training steps to run.
        loss_fn: Optional callable ``loss_fn(target, output)`` returning a
            scalar loss. Defaults to
            ``SparseCategoricalCrossentropy(from_logits=True)``.

    Returns:
        ``(lrs, losses)``: two equal-length lists of Python floats holding
        the learning rate used at each recorded step and the loss measured
        at that step.
    """
    lrs = []
    losses = []

    lr_mult = (end_lr / start_lr) ** (1 / num_iter)
    lr = start_lr

    optimizer = tf.keras.optimizers.Adam(learning_rate=lr)
    if loss_fn is None:
        loss_fn = tf.keras.losses.SparseCategoricalCrossentropy(from_logits=True)

    best_loss = float('inf')
    train_iter = iter(train_dataset)

    for _ in range(num_iter):
        try:
            data, target = next(train_iter)
        except StopIteration:
            # Dataset exhausted: start another pass over the data.
            train_iter = iter(train_dataset)
            data, target = next(train_iter)

        with tf.GradientTape() as tape:
            output = model(data, training=True)
            loss = loss_fn(target, output)

        # Convert once so the results hold plain floats, not np.float32.
        loss_value = float(loss.numpy())

        # Stop if loss explodes
        if loss_value > 4 * best_loss or np.isnan(loss_value):
            break

        best_loss = min(best_loss, loss_value)

        lrs.append(lr)
        losses.append(loss_value)

        grads = tape.gradient(loss, model.trainable_variables)
        optimizer.apply_gradients(zip(grads, model.trainable_variables))

        # Update learning rate
        lr *= lr_mult
        optimizer.learning_rate.assign(lr)

    # Plot results
    plt.figure(figsize=(10, 6))
    plt.plot(lrs, losses)
    plt.xscale('log')
    plt.xlabel('Learning Rate')
    plt.ylabel('Loss')
    plt.title('Learning Rate Range Test')
    plt.grid(True)
    plt.savefig('lr_range_test.png')

    return lrs, losses

# Usage
# lrs, losses = find_lr(model, train_dataset)

How to interpret:

  1. Look for the steepest downward slope in the loss curve
  2. Pick a learning rate from the middle of that slope
  3. Usually 10x smaller than where loss starts to increase

Learning Rate Schedules

1. Cosine Annealing

Smoothly decreases the learning rate following a cosine curve:

from torch.optim.lr_scheduler import CosineAnnealingLR

# Start at 1e-3 and anneal towards 1e-6 over 100 scheduler steps.
optimizer = torch.optim.AdamW(params=model.parameters(), lr=1e-3)
scheduler = CosineAnnealingLR(
    optimizer=optimizer,
    eta_min=1e-6,
    T_max=100,
)

for epoch in range(100):
    train(model, train_loader, optimizer)
    # One scheduler step per epoch, after that epoch's training.
    scheduler.step()
import tensorflow as tf

# Create cosine decay schedule.
# NOTE: `alpha` is the final LR expressed as a *fraction* of
# initial_learning_rate, not an absolute value: 1e-3 * 1e-3 = 1e-6.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3,
    decay_steps=100 * steps_per_epoch,  # 100 epochs
    alpha=1e-3  # Minimum learning rate = alpha * initial_learning_rate = 1e-6
)

optimizer = tf.keras.optimizers.AdamW(learning_rate=lr_schedule)

# Or use callback for epoch-based scheduling.
# Use this INSTEAD of the schedule above: compile the model with a plain
# float learning rate, because the callback sets the optimizer's LR directly.
# Same curve as above: decays from 1e-3 towards a 1e-6 floor.
cosine_callback = tf.keras.callbacks.LearningRateScheduler(
    lambda epoch: float(
        1e-6 + (1e-3 - 1e-6) * 0.5 * (1 + np.cos(np.pi * epoch / 100))
    )
)

model.fit(train_dataset, epochs=100, callbacks=[cosine_callback])

Pros:

  • Smooth decay prevents sudden performance drops
  • Works well for most architectures
  • Can help escape local minima when combined with warm restarts (e.g. PyTorch's CosineAnnealingWarmRestarts)

2. One Cycle Policy

Increases then decreases learning rate in one cycle:

from torch.optim.lr_scheduler import OneCycleLR

optimizer = torch.optim.SGD(params=model.parameters(), lr=0.1, momentum=0.9)

# The cycle spans 100 epochs' worth of batches, peaking at max_lr.
scheduler = OneCycleLR(
    optimizer,
    max_lr=0.1,
    total_steps=100 * len(train_loader),
)

for epoch in range(100):
    for batch in train_loader:
        train_step(batch)
        # Call after each batch! The schedule advances per step, not per epoch.
        scheduler.step()
import tensorflow as tf

# Calculate total steps
total_steps = 100 * steps_per_epoch


# One cycle schedule: warmup -> peak -> decay
def one_cycle_schedule(step, max_lr=0.1, num_steps=None, warmup_frac=0.3):
    """Return the one-cycle learning rate for a plain Python step counter.

    The LR ramps linearly from 0 to ``max_lr`` over the first
    ``warmup_frac`` of training, then decays linearly back to 0.

    Args:
        step: Current training step (Python number, not a tensor).
        max_lr: Peak learning rate reached at the end of warmup.
        num_steps: Total number of training steps; defaults to the
            module-level ``total_steps``.
        warmup_frac: Fraction of training spent warming up (0 < frac < 1).
    """
    if num_steps is None:
        num_steps = total_steps
    warmup_steps = num_steps * warmup_frac
    if step < warmup_steps:  # Warmup phase
        return max_lr * (step / warmup_steps)
    # Decay phase; clamp so steps past the end stay at 0 instead of going negative.
    progress = (step - warmup_steps) / (num_steps - warmup_steps)
    return max_lr * (1 - min(progress, 1.0))


class OneCycleSchedule(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Tensor-friendly one-cycle schedule that an optimizer can consume.

    A schedule must be a subclass overriding ``__call__``: assigning
    ``__call__`` on an instance is ignored by Python's call protocol, and
    a Python ``if`` cannot branch on a symbolic step tensor.
    """

    def __init__(self, max_lr, total_steps, warmup_frac=0.3):
        self.max_lr = max_lr
        self.total_steps = total_steps
        self.warmup_frac = warmup_frac

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        total = tf.cast(self.total_steps, tf.float32)
        warmup_steps = total * self.warmup_frac

        warmup_lr = self.max_lr * step / warmup_steps
        progress = tf.minimum((step - warmup_steps) / (total - warmup_steps), 1.0)
        decay_lr = self.max_lr * (1.0 - progress)

        # tf.where selects element-wise, so this works in graph mode too.
        return tf.where(step < warmup_steps, warmup_lr, decay_lr)

    def get_config(self):
        # Needed so optimizers using this schedule can be serialized.
        return {
            'max_lr': self.max_lr,
            'total_steps': self.total_steps,
            'warmup_frac': self.warmup_frac,
        }


lr_schedule = OneCycleSchedule(max_lr=0.1, total_steps=total_steps)

optimizer = tf.keras.optimizers.SGD(learning_rate=lr_schedule, momentum=0.9)


# Or use callback (compile the model with a plain float learning rate)
class OneCycleScheduler(tf.keras.callbacks.Callback):
    """Set the optimizer's learning rate before every training batch."""

    def __init__(self, max_lr, total_steps):
        super().__init__()
        self.max_lr = max_lr
        self.total_steps = total_steps
        self.step = 0

    def on_batch_begin(self, batch, logs=None):
        # Use this callback's own settings rather than module-level defaults.
        lr = one_cycle_schedule(self.step, self.max_lr, self.total_steps)
        self.model.optimizer.learning_rate.assign(lr)
        self.step += 1

Best for:

  • Fast convergence (often beats other schedules)
  • Limited training time/budget
  • When you know total training iterations upfront

3. Reduce on Plateau

Decreases LR when metrics stop improving:

from torch.optim.lr_scheduler import ReduceLROnPlateau

optimizer = torch.optim.Adam(params=model.parameters(), lr=1e-3)

# Watch a quantity that should go down ('min'); multiply the LR by 0.5
# after 5 epochs without improvement, never dropping below 1e-6.
scheduler = ReduceLROnPlateau(
    optimizer,
    mode='min',
    min_lr=1e-6,
    patience=5,
    factor=0.5,
)

for epoch in range(100):
    train(model, train_loader, optimizer)
    val_loss = validate(model, val_loader)
    # The scheduler needs the monitored value to decide whether to reduce.
    scheduler.step(metrics=val_loss)
import tensorflow as tf

optimizer = tf.keras.optimizers.Adam(learning_rate=1e-3)

# Use ReduceLROnPlateau callback: halve the LR after 5 epochs without
# improvement in val_loss, never dropping below 1e-6.
reduce_lr = tf.keras.callbacks.ReduceLROnPlateau(
    monitor='val_loss',
    min_lr=1e-6,
    patience=5,
    factor=0.5,
    verbose=1,
)

model.compile(loss='sparse_categorical_crossentropy', optimizer=optimizer)
# validation_data is required so that 'val_loss' exists for the callback.
model.fit(
    train_dataset,
    epochs=100,
    validation_data=val_dataset,
    callbacks=[reduce_lr],
)

Best for:

  • Unknown optimal training length
  • When validation loss is your primary metric
  • Conservative training approaches

Optimizer-Specific Tips

AdamW

# PyTorch
optimizer = torch.optim.AdamW(
    params=model.parameters(),
    betas=(0.9, 0.999),
    weight_decay=0.01,  # Regularization
    lr=1e-3,            # Good default
)
# TensorFlow
optimizer = tf.keras.optimizers.AdamW(
    beta_1=0.9,
    beta_2=0.999,
    weight_decay=0.01,   # Regularization
    learning_rate=1e-3,  # Good default
)

Typical LR ranges:

  • Transformers: 1e-4 to 5e-4
  • CNNs: 1e-3 to 3e-3
  • Small models: 1e-3 to 1e-2

SGD with Momentum

# PyTorch
optimizer = torch.optim.SGD(
    params=model.parameters(),
    momentum=0.9,
    nesterov=True,     # Often helps
    weight_decay=1e-4,
    lr=0.1,            # Usually 10-100x higher than Adam
)
# TensorFlow
optimizer = tf.keras.optimizers.SGD(
    momentum=0.9,
    nesterov=True,      # Often helps
    learning_rate=0.1,  # Usually 10-100x higher than Adam
)

# Note: TensorFlow handles weight decay separately
# Add to loss or use kernel_regularizer in layers

Typical LR ranges:

  • ResNets: 0.1 (with decay)
  • Small CNNs: 0.01 to 0.1

Warmup Strategy

Gradually increase learning rate at start of training:

import math

def get_lr_with_warmup(current_step, warmup_steps, max_lr, total_steps):
    """Calculate learning rate with warmup and cosine decay.

    Args:
        current_step: Zero-based index of the current training step.
        warmup_steps: Number of steps over which the LR ramps linearly
            from 0 to ``max_lr``.
        max_lr: Peak learning rate, reached when warmup ends.
        total_steps: Step at which the cosine decay reaches 0.

    Returns:
        The learning rate for ``current_step``. Steps at or beyond
        ``total_steps`` return 0.0.
    """
    if current_step < warmup_steps:
        # Linear warmup
        return max_lr * current_step / warmup_steps

    # Cosine decay. Guard the denominator so a schedule without a decay
    # phase (total_steps <= warmup_steps) does not divide by zero.
    decay_steps = max(1, total_steps - warmup_steps)
    # Clamp so the cosine does not swing back up past total_steps.
    progress = min(1.0, (current_step - warmup_steps) / decay_steps)
    return max_lr * 0.5 * (1 + math.cos(math.pi * progress))

# Usage in training loop
for step in range(total_steps):
    # Give the schedule the same horizon the loop runs for, so the cosine
    # decay finishes exactly when training does.
    lr = get_lr_with_warmup(step, warmup_steps=1000,
                           max_lr=1e-3, total_steps=total_steps)
    # Apply the LR to every param group before this step's update.
    for param_group in optimizer.param_groups:
        param_group['lr'] = lr
import tensorflow as tf
import numpy as np

class WarmupCosineDecay(tf.keras.optimizers.schedules.LearningRateSchedule):
    """Linear warmup to ``max_lr`` followed by cosine decay to 0.

    Args:
        warmup_steps: Number of steps over which the LR ramps linearly
            from 0 to ``max_lr``.
        max_lr: Peak learning rate, reached when warmup ends.
        total_steps: Step at which the cosine decay reaches 0; later steps
            stay at 0.
    """

    def __init__(self, warmup_steps, max_lr, total_steps):
        self.warmup_steps = warmup_steps
        self.max_lr = max_lr
        self.total_steps = total_steps

    def __call__(self, step):
        step = tf.cast(step, tf.float32)
        warmup_steps = tf.cast(self.warmup_steps, tf.float32)
        total_steps = tf.cast(self.total_steps, tf.float32)

        # Linear warmup (denominator guarded so warmup_steps=0 yields no NaN)
        warmup_lr = self.max_lr * step / tf.maximum(warmup_steps, 1.0)

        # Cosine decay. Guard the denominator and clamp progress to [0, 1]
        # so the LR neither divides by zero nor rises again past total_steps.
        decay_steps = tf.maximum(total_steps - warmup_steps, 1.0)
        progress = tf.clip_by_value((step - warmup_steps) / decay_steps, 0.0, 1.0)
        cosine_lr = self.max_lr * 0.5 * (1 + tf.cos(np.pi * progress))

        return tf.cond(step < warmup_steps, lambda: warmup_lr, lambda: cosine_lr)

    def get_config(self):
        # Needed so optimizers using this schedule can be serialized.
        return {
            'warmup_steps': self.warmup_steps,
            'max_lr': self.max_lr,
            'total_steps': self.total_steps,
        }

# Usage
# 1000 warmup steps out of 100000 total = 1% of training.
lr_schedule = WarmupCosineDecay(
    max_lr=1e-3,
    total_steps=100000,
    warmup_steps=1000,
)
# Passing the schedule object lets the optimizer query the LR each step.
optimizer = tf.keras.optimizers.Adam(learning_rate=lr_schedule)

Why warmup helps:

  • Prevents early training instability
  • Allows batch normalization statistics to stabilize
  • Essential for large batch training
  • Recommended warmup: 1-5% of total training steps

Common Issues & Solutions

Loss Goes to NaN

# Check for:
1. Learning rate too high → reduce by 10x
2. Gradient explosion → add gradient clipping
3. Bad initialization → use proper init (Xavier/He)

# Add gradient clipping: call after loss.backward() and before
# optimizer.step(), so the clipped gradients are the ones applied.
torch.nn.utils.clip_grad_norm_(model.parameters(), max_norm=1.0)

Training Too Slow

# Try:
1. Increase learning rate (use LR range test)
2. Use AdamW instead of SGD
3. Add learning rate warmup
4. Try OneCycleLR scheduler

Validation Loss Increases

# Solutions:
1. Reduce learning rate
2. Add/increase weight decay
3. Use learning rate decay schedule
4. Add dropout or other regularization

Real-World Example Configurations

Vision Transformers (ViT)

from torch.optim.lr_scheduler import CosineAnnealingLR

optimizer = torch.optim.AdamW(
    model.parameters(),
    lr=1e-3,
    weight_decay=0.05,
    betas=(0.9, 0.999)
)

# With warmup: ramp the LR linearly for the first few epochs, then hand
# over to cosine annealing for the remaining ones.
warmup_epochs = 5
warmup_scheduler = torch.optim.lr_scheduler.LinearLR(
    optimizer,
    start_factor=0.01,        # Begin at 1% of the base LR
    total_iters=warmup_epochs
)
cosine_scheduler = CosineAnnealingLR(
    optimizer,
    T_max=epochs - warmup_epochs,  # Cosine phase covers the remaining epochs
    eta_min=1e-6
)

# Call scheduler.step() once per epoch; it switches from warmup to cosine
# at the milestone.
scheduler = torch.optim.lr_scheduler.SequentialLR(
    optimizer,
    schedulers=[warmup_scheduler, cosine_scheduler],
    milestones=[warmup_epochs]
)
import tensorflow as tf

# NOTE: `alpha` is the final LR expressed as a *fraction* of
# initial_learning_rate, so 1e-3 here gives a floor of 1e-3 * 1e-3 = 1e-6.
lr_schedule = tf.keras.optimizers.schedules.CosineDecay(
    initial_learning_rate=1e-3,
    decay_steps=epochs * steps_per_epoch,
    alpha=1e-3
)

optimizer = tf.keras.optimizers.AdamW(
    learning_rate=lr_schedule,
    weight_decay=0.05,
    beta_1=0.9,
    beta_2=0.999
)

# With warmup: use WarmupCosineDecay class from above

ResNet/CNN

optimizer = torch.optim.SGD(
    params=model.parameters(),
    weight_decay=1e-4,
    momentum=0.9,
    lr=0.1,
)

# Step decay: multiply the LR by 0.1 every 30 epochs
scheduler = torch.optim.lr_scheduler.StepLR(
    optimizer=optimizer,
    gamma=0.1,
    step_size=30,
)
# SGD with momentum; the step decay is applied through a callback.
optimizer = tf.keras.optimizers.SGD(momentum=0.9, learning_rate=0.1)

# Step decay callback
def lr_schedule(epoch):
    """Return the step-decayed LR: 0.1, scaled by 0.1 per 30 completed epochs."""
    num_drops = epoch // 30
    return 0.1 * (0.1 ** num_drops)

# Keras calls the schedule function at the start of every epoch.
lr_callback = tf.keras.callbacks.LearningRateScheduler(schedule=lr_schedule)

model.compile(loss='categorical_crossentropy', optimizer=optimizer)
model.fit(
    train_dataset,
    callbacks=[lr_callback],
    epochs=100,
)

Fine-tuning Pretrained Models

# One param group per part of the network: a small LR for the pretrained
# backbone and a larger one for the new head. weight_decay applies to both.
optimizer = torch.optim.AdamW(
    [
        dict(params=model.backbone.parameters(), lr=1e-5),  # Pretrained
        dict(params=model.head.parameters(), lr=1e-3),      # New layers
    ],
    weight_decay=0.01,
)
# Different learning rates for different layers.
# Keras has no per-layer learning rate: setting `layer.learning_rate` is
# silently ignored. Use one optimizer per group of variables instead.
backbone_optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-5, weight_decay=0.01)
head_optimizer = tf.keras.optimizers.AdamW(learning_rate=1e-3, weight_decay=0.01)

# Split the trainable variables (assumes the head is the last two layers).
backbone_vars = [v for layer in model.layers[:-2] for v in layer.trainable_variables]
head_vars = [v for layer in model.layers[-2:] for v in layer.trainable_variables]


def train_step(data, target, loss_fn):
    """Run one update, applying each optimizer to its own variables."""
    with tf.GradientTape() as tape:
        output = model(data, training=True)
        loss = loss_fn(target, output)

    # One backward pass for both groups, then split the gradients.
    grads = tape.gradient(loss, backbone_vars + head_vars)
    backbone_grads = grads[:len(backbone_vars)]
    head_grads = grads[len(backbone_vars):]

    backbone_optimizer.apply_gradients(zip(backbone_grads, backbone_vars))
    head_optimizer.apply_gradients(zip(head_grads, head_vars))
    return loss

Key Takeaways

  • Always run a learning rate range test for new models/datasets
  • Start with proven configurations for your architecture type
  • Use warmup for the first 1-5% of training
  • OneCycleLR often converges fastest
  • AdamW is a safe default optimizer
  • Monitor training curves - they tell you if LR is wrong
  • When fine-tuning, use 10-100x lower learning rates

:::tip
Save the learning rate in your training logs! It’s essential for reproducing results and debugging issues.
:::