
GPU Memory Management

Make the most of limited GPU memory and fix CUDA out-of-memory errors with better allocation habits, gradient checkpointing, mixed precision, and other memory-efficient techniques for large models.

Understanding GPU Memory

GPU memory is often the tightest constraint in deep learning. Out-of-memory (OOM) errors are frustrating, but knowing where memory goes tells you which fix will actually help.

Memory Breakdown

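Your GPU memory is used by the following (the percentages are rough ranges that vary widely with model, batch size, and input size):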

  1. Model Parameters (10-40% typically)

    • Weights and biases
    • Relatively fixed size
  2. Gradients (10-40%)

    • Same size as parameters during training
    • Released by optimizer.zero_grad() (set to None by default since PyTorch 2.0)
  3. Activations (20-60%)

    • Intermediate layer outputs
    • Grows with batch size, input size (resolution or sequence length), and depth
    • Often the biggest memory consumer during training
  4. Optimizer States (varies)

    • Adam: 2x parameter size (first and second moment estimates)
    • SGD with momentum: 1x parameter size
    • Plain SGD: no extra state
  5. CUDA Context (~500MB-2GB)

    • CUDA runtime and framework kernels; not counted by torch.cuda.memory_allocated()
    • Largely fixed
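
The fixed parts of this breakdown can be estimated before training starts. A minimal sketch, assuming FP32 (4-byte) parameters and the optimizer multipliers listed above; estimate_static_memory_gb is a hypothetical helper, and activations and the CUDA context are deliberately left out because they depend on batch size, input shape, and setup:

```python
def estimate_static_memory_gb(num_params, bytes_per_param=4, optimizer="adam"):
    """Rough memory for parameters + gradients + optimizer state, in GB (1024**3 bytes).

    Hypothetical helper: ignores activations, the CUDA context, and allocator overhead.
    """
    # Extra state kept per parameter, as a multiple of the parameter size
    optimizer_multiplier = {"adam": 2, "sgd_momentum": 1, "sgd": 0}[optimizer]
    params = num_params * bytes_per_param
    grads = num_params * bytes_per_param  # one gradient value per parameter
    states = num_params * bytes_per_param * optimizer_multiplier
    return (params + grads + states) / 1024**3


# A 1-billion-parameter model trained in FP32 with Adam
print(f"{estimate_static_memory_gb(1_000_000_000):.1f} GB")  # 14.9 GB
```

That is 16 bytes per parameter before a single activation is stored, which is why a model whose FP32 weights fit comfortably can still run out of memory once training starts.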

Quick Diagnostics

Check Current Memory Usage

import torch

def print_gpu_memory():
    """Print current GPU memory usage"""
    if torch.cuda.is_available():
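        # Memory occupied by live tensors
        allocated = torch.cuda.memory_allocated() / 1024**3
        # Memory held by PyTorch's caching allocator (live tensors + cached blocks)
        reserved = torch.cuda.memory_reserved() / 1024**3
        # Peak allocation since start (or since torch.cuda.reset_peak_memory_stats())
        max_allocated = torch.cuda.max_memory_allocated() / 1024**3

        print(f"Allocated: {allocated:.2f} GB")
        print(f"Reserved:  {reserved:.2f} GB")
        print(f"Max Allocated: {max_allocated:.2f} GB")

        # Get total memory of the current GPU
        device = torch.cuda.current_device()
        total = torch.cuda.get_device_properties(device).total_memory / 1024**3
        print(f"Total GPU Memory: {total:.2f} GB")
        print(f"Utilization: {allocated/total*100:.1f}%")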

print_gpu_memory()

TensorFlow version:

import tensorflow as tf

def print_gpu_memory():
    """Print current GPU memory usage"""
    gpus = tf.config.list_physical_devices('GPU')
    if gpus:
        # Get memory info for the first GPU ('/physical_device:GPU:0' -> 'GPU:0')
        gpu = gpus[0]
        memory_info = tf.config.experimental.get_memory_info(gpu.name.replace('/physical_device:', ''))

        allocated = memory_info['current'] / 1024**3
        peak = memory_info['peak'] / 1024**3

        print(f"Current Allocated: {allocated:.2f} GB")
        print(f"Peak Allocated: {peak:.2f} GB")

        # By default TensorFlow reserves most of the GPU's memory up front, so tools
        # like nvidia-smi show far more than 'current'. To allocate on demand instead,
        # call tf.config.experimental.set_memory_growth(gpu, True) before any GPU op runs.
        print("Note: TensorFlow pre-allocates GPU memory unless memory growth is enabled")

print_gpu_memory()

Monitor During Training

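# Add to your training loop (assumes model, optimizer, criterion,
# train_loader and epochs are already defined)
for epoch in range(epochs):
    for batch_idx, (data, target) in enumerate(train_loader):
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()
        loss = criterion(model(data), target)
        loss.backward()
        optimizer.step()

        # Log memory every 100 batches
        if batch_idx % 100 == 0:
            print_gpu_memory()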

Out of Memory (OOM) Solutions

1. Reduce Batch Size (Easiest)

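import torch
from torch.utils.data import DataLoader

# Start large and reduce until one forward/backward pass fits
# (assumes model is already on the GPU and dataset/criterion are defined)
batch_sizes = [128, 64, 32, 16, 8]

for bs in batch_sizes:
    try:
        train_loader = DataLoader(dataset, batch_size=bs)
        model.train()
        # Test one batch
        data, target = next(iter(train_loader))
        output = model(data.cuda())
        loss = criterion(output, target.cuda())
        loss.backward()
        print(f"Batch size {bs} works!")
        break
    except torch.cuda.OutOfMemoryError:  # available in PyTorch >= 1.13
        print(f"Batch size {bs} - OOM")
    finally:
        # Drop the trial's tensors and gradients before the next attempt
        output = loss = None
        model.zero_grad(set_to_none=True)
        torch.cuda.empty_cache()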
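
Halving through a fixed list can leave up to half of the usable batch size unused. If you wrap the trial step above in a function that returns False on torch.cuda.OutOfMemoryError, a binary search finds the largest fitting batch size in about log2(high) trials. A sketch, where largest_fitting_batch_size and the fits predicate are hypothetical names, assuming that if a batch size fits, every smaller one does too:

```python
def largest_fitting_batch_size(fits, low=1, high=1024):
    """Return the largest batch size in [low, high] for which fits(bs) is True.

    Assumes fits() is monotonic (if bs fits, every smaller size fits too).
    Returns None when no size in the range fits.
    """
    best = None
    while low <= high:
        mid = (low + high) // 2
        if fits(mid):
            best = mid       # mid fits: remember it and try larger sizes
            low = mid + 1
        else:
            high = mid - 1   # mid runs out of memory: try smaller sizes
    return best


# Stand-in predicate for illustration: pretend anything up to 37 fits
print(largest_fitting_batch_size(lambda bs: bs <= 37))  # 37
```

Leave some headroom below the result: peak memory can grow after the first iteration, for example when Adam allocates its state on the first optimizer.step().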

2. Gradient Accumulation

Simulate larger batch sizes without using more memory:

# Effective batch size = batch_size * accumulation_steps
accumulation_steps = 4
optimizer.zero_grad()

for i, (data, target) in enumerate(train_loader):
    data, target = data.cuda(), target.cuda()

    # Forward pass
    output = model(data)
    loss = criterion(output, target)

    # Normalize loss to account for accumulation
    loss = loss / accumulation_steps
    loss.backward()

    # Update weights every accumulation_steps
    if (i + 1) % accumulation_steps == 0:
        optimizer.step()
        optimizer.zero_grad()
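
Dividing each micro-batch loss by accumulation_steps is what makes the accumulated gradient match the gradient of one large batch. A pure-Python check on a one-parameter linear model y = w * x with mean squared error, assuming equal-sized micro-batches and a mean-reduced loss (layers such as BatchNorm, which compute statistics per micro-batch, break exact equivalence):

```python
def mse_grad(w, xs, ys):
    """d/dw of mean((w*x - y)**2) over the given samples."""
    return sum(2 * (w * x - y) * x for x, y in zip(xs, ys)) / len(xs)


xs = [1.0, 2.0, 3.0, 4.0, 5.0, 6.0, 7.0, 8.0]
ys = [2.1, 3.9, 6.2, 7.8, 10.1, 12.2, 13.8, 16.1]
w = 0.5
accumulation_steps = 4
micro_size = len(xs) // accumulation_steps  # 2 samples per micro-batch

# Gradient of the loss over the full batch of 8 samples
full_batch_grad = mse_grad(w, xs, ys)

# Sum of micro-batch gradients, each computed on loss / accumulation_steps
accumulated_grad = 0.0
for step in range(accumulation_steps):
    start = step * micro_size
    batch_x, batch_y = xs[start:start + micro_size], ys[start:start + micro_size]
    accumulated_grad += mse_grad(w, batch_x, batch_y) / accumulation_steps

print(full_batch_grad, accumulated_grad)  # equal up to floating-point rounding
```

Without the division, the accumulated gradient would be accumulation_steps times too large, which acts like silently multiplying the learning rate.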

3. Mixed Precision Training

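Autocast runs eligible operations such as matmuls and convolutions in FP16, which mainly shrinks activation memory (weights, gradients, and optimizer states stay in FP32) and speeds up math on Tensor Core GPUs. Savings vary by model: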

import torch
from torch.amp import autocast, GradScaler  # PyTorch >= 2.3; older versions: torch.cuda.amp

model = model.cuda()
scaler = GradScaler("cuda")

for epoch in range(epochs):
    for data, target in train_loader:
        data, target = data.cuda(), target.cuda()
        optimizer.zero_grad()

        # Run forward pass in mixed precision
        with autocast(device_type="cuda", dtype=torch.float16):
            output = model(data)
            loss = criterion(output, target)

        # Scale the loss so small FP16 gradients don't underflow;
        # scaler.step() unscales them and skips the step on inf/NaN
        scaler.scale(loss).backward()
        scaler.step(optimizer)
        scaler.update()

TensorFlow version:

import tensorflow as tf

# Enable mixed precision globally
tf.keras.mixed_precision.set_global_policy('mixed_float16')

# Build and compile model (create_model() stands for your own model builder).
# Keras recommends keeping the final output in float32, e.g.
# tf.keras.layers.Activation('softmax', dtype='float32')
model = create_model()
optimizer = tf.keras.optimizers.Adam()

# Under the mixed_float16 policy, compile() wraps the optimizer
# in a LossScaleOptimizer automatically
model.compile(
    optimizer=optimizer,
    loss='sparse_categorical_crossentropy',
    metrics=['accuracy']
)

# Train normally - mixed precision handled automatically
model.fit(train_dataset, epochs=epochs)

Memory savings:

  • FP16 uses 2 bytes vs FP32’s 4 bytes
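  • Smaller activations often allow a noticeably larger batch size (usually less than 2x, since weights and optimizer states stay FP32)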
  • Minimal accuracy impact for most models
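
The 2-byte format also has far less range than FP32, which is why the PyTorch example uses GradScaler. Python's struct module can round-trip a value through IEEE 754 half precision (format 'e'), showing a tiny gradient underflow to zero and then survive once it is multiplied by a scale factor. A sketch; 65536 (2**16) matches GradScaler's default initial scale, but the scaler adjusts it during training:

```python
import struct


def to_fp16(value):
    """Round a Python float to the nearest IEEE 754 half-precision value."""
    return struct.unpack("<e", struct.pack("<e", value))[0]


print(struct.calcsize("<e"), struct.calcsize("<f"))  # 2 4  (bytes per FP16 / FP32 value)

grad = 1e-8
print(to_fp16(grad))  # 0.0: below FP16's smallest subnormal (~6e-8), so it underflows

scale = 65536.0  # 2**16
scaled = to_fp16(grad * scale)  # ~6.55e-4 is comfortably representable in FP16
recovered = scaled / scale      # unscale in full precision before the update
print(recovered)                # close to 1e-8 (FP16 keeps about 3 significant digits)
```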

4. Gradient Checkpointing

Trade compute for memory: store only selected activations during the forward pass and recompute the rest during the backward pass:

from torch.utils.checkpoint import checkpoint
import torch.nn as nn

class MyModel(nn.Module):
    def __init__(self):
        super().__init__()
        # Group memory-intensive layers into a block: checkpointing a block keeps
        # only its input and recomputes the internal activations during backward
        self.block = nn.Sequential(
            nn.Linear(1000, 1000), nn.ReLU(),
            nn.Linear(1000, 1000), nn.ReLU(),
        )
        self.head = nn.Linear(1000, 10)

    def forward(self, x):
        x = checkpoint(self.block, x, use_reentrant=False)
        return self.head(x)

TensorFlow version:

import tensorflow as tf

# TensorFlow has gradient checkpointing via tf.recompute_grad

class MyModel(tf.keras.Model):
    def __init__(self):
        super().__init__()
        self.block = tf.keras.Sequential([
            tf.keras.layers.Dense(1000, activation='relu'),
            tf.keras.layers.Dense(1000, activation='relu'),
        ])
        self.head = tf.keras.layers.Dense(10)
        # Wrap the memory-intensive block: its intermediate activations are
        # recomputed during backpropagation instead of being stored
        self.checkpointed_block = tf.recompute_grad(self.block)

    def call(self, x, training=False):
        x = self.checkpointed_block(x)
        return self.head(x)

Trade-off:

  • Reduces activation memory; how much depends on how much of the network is checkpointed
  • Costs roughly one extra forward pass per step, often ~20-30% longer training time
  • Best for very deep networks
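
Where the savings come from can be seen with a simplified count. Assume n identical layers that each save one activation: without checkpointing, all n are stored. Splitting the network into segments of k layers keeps only each segment's input, plus one segment's activations while it is recomputed, so the peak is about ceil(n/k) + k, minimized near k = sqrt(n). This toy model ignores differing layer sizes; torch.utils.checkpoint.checkpoint_sequential applies the same segmenting idea to an nn.Sequential:

```python
import math


def peak_stored_activations(num_layers, segment_size):
    """Toy model: segment inputs kept during forward + one segment recomputed in backward."""
    num_segments = math.ceil(num_layers / segment_size)
    return num_segments + segment_size


n = 100
best_k = min(range(1, n + 1), key=lambda k: peak_stored_activations(n, k))
print(best_k, peak_stored_activations(n, best_k))  # 10 20  (vs 100 without checkpointing)
```

The price is recomputing each segment's forward pass once, which is where the extra training time comes from.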

5. Clear Unused Variables

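# Free memory explicitly: del drops the last reference so PyTorch's caching
# allocator can reuse the block; empty_cache() additionally returns unused
# cached blocks to the driver (visible in nvidia-smi)
del large_tensor
torch.cuda.empty_cache()

# Don't keep unnecessary computation graphs
with torch.no_grad():
    # Operations here won't build a computation graph
    validation_output = model(val_data)

# Detach tensors you don't need gradients for
prediction = output.detach()

# Accumulate Python numbers, not tensors that keep their graph alive
running_loss += loss.item()  # not: running_loss += loss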

6. Optimize Model Size

import torch.nn as nn
import torchvision

# Use smaller models
model = torchvision.models.resnet18()  # Instead of resnet152

# Reduce hidden dimensions
model = nn.Transformer(
    d_model=512,            # Instead of 1024
    nhead=8,                # Instead of 16
    num_encoder_layers=6,   # Instead of 12
    num_decoder_layers=6,   # Instead of 12
)

# Use depthwise separable convolutions
from torch.nn import Conv2d

# Standard conv
conv = Conv2d(256, 256, kernel_size=3, padding=1)

# Depthwise separable (fewer parameters)
depthwise = Conv2d(256, 256, kernel_size=3, padding=1, groups=256)
pointwise = Conv2d(256, 256, kernel_size=1)
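
The parameter savings can be checked by hand. nn.Conv2d stores out_channels × (in_channels / groups) × kernel_h × kernel_w weights plus one bias per output channel; the helper below (hypothetical, not a PyTorch function) applies that formula to the layers above. Note that the depthwise/pointwise pair reduces parameter, gradient, and optimizer-state memory, but adds an intermediate activation, so activation memory does not shrink:

```python
def conv2d_params(in_ch, out_ch, kernel_size, groups=1, bias=True):
    """Parameter count of a square-kernel 2D convolution."""
    weights = out_ch * (in_ch // groups) * kernel_size * kernel_size
    return weights + (out_ch if bias else 0)


standard = conv2d_params(256, 256, 3)
separable = conv2d_params(256, 256, 3, groups=256) + conv2d_params(256, 256, 1)
print(standard, separable)  # 590080 68352  (about 8.6x fewer parameters)
```

With PyTorch installed, sum(p.numel() for p in conv.parameters()) should report the same counts for the modules above.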

Memory Optimization Strategies

Strategy 1: Multi-GPU Training

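Split each batch across GPUs (data parallelism). Each GPU still holds a full copy of the model, gradients, and optimizer states, so this reduces per-GPU activation memory for a given global batch size but does not help a model that is too large for one GPU; that case needs sharding, such as PyTorch's FullyShardedDataParallel (FSDP):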

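import os
import torch
import torch.nn as nn
import torch.distributed as dist
from torch.nn.parallel import DistributedDataParallel as DDP

# DataParallel: single process, simple but slower (use one or the other)
# model = nn.DataParallel(model)

# DistributedDataParallel (recommended): one process per GPU, launched with
#   torchrun --nproc_per_node=NUM_GPUS train.py
dist.init_process_group(backend='nccl')
local_rank = int(os.environ["LOCAL_RANK"])  # set by torchrun
torch.cuda.set_device(local_rank)

# Wrap model
model = model.cuda(local_rank)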
model = DDP(model, device_ids=[local_rank])
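
Because DistributedDataParallel replicates the parameters, gradients, and optimizer states on every GPU, adding GPUs does not shrink that per-GPU model state; fully sharded approaches such as PyTorch FSDP split it across GPUs instead. A back-of-the-envelope sketch, assuming FP32 training with Adam (16 bytes per parameter, following the breakdown at the top of this page) and ignoring activations, communication buffers, and the parameters FSDP temporarily gathers for each layer; per_gpu_model_state_gb is a hypothetical helper:

```python
def per_gpu_model_state_gb(num_params, num_gpus, sharded, bytes_per_param=16):
    """Parameters + gradients + Adam states per GPU, in GB (1024**3 bytes).

    16 bytes/param = 4 (FP32 weight) + 4 (gradient) + 8 (two Adam moments).
    Replicated (DDP-style): every GPU holds everything.
    Fully sharded (FSDP-style): each GPU holds about 1/num_gpus of it.
    """
    total_bytes = num_params * bytes_per_param
    per_gpu = total_bytes / num_gpus if sharded else total_bytes
    return per_gpu / 1024**3


for sharded in (False, True):
    gb = per_gpu_model_state_gb(1_000_000_000, num_gpus=8, sharded=sharded)
    print("sharded" if sharded else "replicated", f"{gb:.2f} GB")
# replicated 14.90 GB
# sharded 1.86 GB
```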

Strategy 2: CPU Offloading

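Keep data that doesn't need to be on the GPU in CPU memory. Activations saved for the backward pass can be offloaded with torch.autograd.graph.save_on_cpu, trading GPU memory for host-device transfer time; moving small metric computations to the CPU, as at the end of the loop, saves comparatively little: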

import torch

# Keep the model on the GPU
model.cuda()

for data, target in train_loader:
    data, target = data.cuda(), target.cuda()
    optimizer.zero_grad()

    # Offload activations saved for backward to pinned CPU memory;
    # they are copied back to the GPU when backward() needs them
    with torch.autograd.graph.save_on_cpu(pin_memory=True):
        output = model(data)
        loss = criterion(output, target)
    loss.backward()
    optimizer.step()

    # Compute accuracy on CPU from detached predictions
    pred = output.detach().argmax(dim=1).cpu()
    accuracy = (pred == target.cpu()).float().mean().item()

Strategy 3: In-Place Operations

Reduce memory by modifying tensors in-place:

# Instead of:
x = x + 1
x = torch.relu(x)

# Use in-place:
x += 1              # or x.add_(1)
x = torch.relu_(x)  # or nn.ReLU(inplace=True)

# In models:
self.relu = nn.ReLU(inplace=True)

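:::caution In-place operations can break autograd: if one overwrites a tensor saved for the backward pass, backward() raises "one of the variables needed for gradient computation has been modified by an inplace operation". Use them carefully and test gradients. :::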

Memory Profiling

PyTorch Memory Profiler

from torch.profiler import profile, ProfilerActivity

with profile(
    activities=[ProfilerActivity.CPU, ProfilerActivity.CUDA],
    profile_memory=True,
    record_shapes=True
) as prof:
    # Your training code (forward + backward)
    output = model(data)
    loss = criterion(output, target)
    loss.backward()

# Print memory stats
print(prof.key_averages().table(
    sort_by="cuda_memory_usage",
    row_limit=10
))

# Export for visualization
prof.export_chrome_trace("trace.json")

Memory Snapshot

import torch

# Start recording memory history (private API; may change between releases)
torch.cuda.memory._record_memory_history()

try:
    # Run your code
    output = model(data)
    loss = criterion(output, target)
    loss.backward()
except torch.cuda.OutOfMemoryError:
    # Dump memory snapshot for analysis; load the file at
    # https://pytorch.org/memory_viz to inspect it
    torch.cuda.memory._dump_snapshot("oom_snapshot.pickle")
    raise
finally:
    # Stop recording
    torch.cuda.memory._record_memory_history(enabled=None)

Best Practices Checklist

  1. Start with these settings:

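    torch.backends.cudnn.benchmark = True  # Speed, not memory: autotunes conv algorithms (may use extra workspace)
    torch.backends.cuda.matmul.allow_tf32 = True  # Faster FP32 matmuls on Ampere+ GPUs (speed, not memory)
  2. Use mixed precision where your model supports it (less activation memory, faster Tensor Core math)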

  3. Monitor memory throughout training (catch leaks early)

  4. Use gradient checkpointing for very deep models

  5. Release cached memory between experiments in the same process (after deleting the previous model and tensors):

    torch.cuda.empty_cache()
  6. Test batch size before long training runs

  7. Use torch.no_grad() during validation/inference:

    with torch.no_grad():
        val_loss = validate(model, val_loader)

Common OOM Causes & Fixes

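| Problem | Solution |
|---|---|
| OOM during the first forward pass | Reduce batch size or model size |
| OOM during the backward pass | Use gradient checkpointing or mixed precision |
| Memory grows over time | Likely a leak: look for tensors (and their graphs) kept in lists; store .item() or .detach() values |
| OOM only on large inputs | Use dynamic batching by input size |
| OOM after many epochs | Check for metrics accumulated as tensors (use .item()); empty_cache() alone will not fix a leak |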
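
Dynamic batching by input size can be as simple as sorting samples by length and filling each batch up to a fixed budget of padded elements (tokens, pixels, and so on). A minimal pure-Python sketch; token_budget_batches and max_tokens are hypothetical names, and padded size is taken to be batch size × longest sample in the batch. A single sample longer than the budget still gets its own batch, so it can still run out of memory:

```python
def token_budget_batches(lengths, max_tokens):
    """Group sample indices so each padded batch holds at most max_tokens elements."""
    # Sort by length so samples of similar size share a batch (less padding)
    order = sorted(range(len(lengths)), key=lambda i: lengths[i])
    batches, current, current_max = [], [], 0
    for i in order:
        new_max = max(current_max, lengths[i])
        if current and (len(current) + 1) * new_max > max_tokens:
            batches.append(current)  # close the batch before it exceeds the budget
            current, new_max = [], lengths[i]
        current.append(i)
        current_max = new_max
    if current:
        batches.append(current)
    return batches


lengths = [10, 200, 15, 190, 12, 50]
print(token_budget_batches(lengths, 400))  # [[0, 4, 2, 5], [3, 1]]
```

Short samples end up in large batches and long samples in small ones. The resulting index lists can be passed to a PyTorch DataLoader through its batch_sampler argument, which accepts any iterable of index lists.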

Quick Reference: Memory Reduction Techniques

Ordered by effectiveness vs effort:

  1. Mixed Precision Training - Roughly halves activation memory, a few lines of code ⭐
  2. Reduce Batch Size - Variable memory, 1 line of code
  3. Gradient Accumulation - Simulate larger batches, 10 lines of code
  4. Gradient Checkpointing - Large activation savings, roughly 20-30% slower
  5. Model Optimization - Variable, requires architecture changes
  6. Multi-GPU Training - Splits batches across GPUs; models too large for one GPU need sharding (e.g. FSDP)

:::tip Start here: Enable mixed precision and find the largest batch size that fits. This resolves most memory issues with minimal effort. :::