GPU overhead: why our MNIST model trains faster on CPU
Machine learning and GPUs feel inseparable — every tutorial, every cloud provider, every “getting started” guide points you toward GPU instances. The intuition is straightforward: data goes in, thousands of cores process it in parallel, training goes fast. But this assumes all those cores are actually being used — and for small workloads, they aren’t.
We trained our MNIST model on a local workstation with a Quadro RTX 5000 GPU and an Intel Core i9-10885H CPU, then again on CPU only. The CPU was faster: 4.4s vs 6.5s per epoch. This article explains why, and shows how we got training down to 0.14s per epoch by reducing overhead instead of upgrading hardware.
Let’s start by looking at what happens during training.
Breaking down one training step
Each training step runs four phases — forward pass, loss, backward pass, weight update:
```python
# Forward pass (xb: batch of images, yb: integer labels, bs: batch size, lr: learning rate)
z1 = xb @ w1 + b1                 # matrix multiply + bias (hidden layer)
a1 = np.maximum(0, z1)            # ReLU activation
z2 = a1 @ w2 + b2                 # matrix multiply + bias (output layer)
exp_z = np.exp(z2 - z2.max(axis=1, keepdims=True))
probs = exp_z / exp_z.sum(axis=1, keepdims=True)  # softmax → probabilities

# Loss gradient (softmax + cross-entropy)
dz2 = probs.copy()
dz2[np.arange(bs), yb] -= 1       # how far off were the predictions
dz2 /= bs

# Backward pass: compute gradients
dw2 = a1.T @ dz2                  # gradient for W2
da1 = dz2 @ w2.T                  # gradient flowing backward
dz1 = da1 * (z1 > 0)              # ReLU gradient
dw1 = xb.T @ dz1                  # gradient for W1

# Update weights
w1 -= lr * dw1
b1 -= lr * dz1.sum(axis=0)
w2 -= lr * dw2
b2 -= lr * dz2.sum(axis=0)
```
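One way to convince yourself that these four phases are correct is a finite-difference gradient check. The sketch below is self-contained NumPy on random data (not MNIST). `forward_loss`, `gradients`, and `numeric_grad` are hypothetical helper names, and the backward pass copies the formulas above. It compares the analytic gradient for a few entries of `w1` and `w2` against a central-difference estimate.

```python
import numpy as np

def forward_loss(xb, yb, w1, b1, w2, b2):
    # Same forward pass as the training step, plus the mean cross-entropy loss
    z1 = xb @ w1 + b1
    a1 = np.maximum(0, z1)
    z2 = a1 @ w2 + b2
    exp_z = np.exp(z2 - z2.max(axis=1, keepdims=True))
    probs = exp_z / exp_z.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(len(yb)), yb]).mean()
    return loss, z1, a1, probs

def gradients(xb, yb, w1, b1, w2, b2):
    # Analytic gradients, using the same formulas as the training step
    bs = len(yb)
    _, z1, a1, probs = forward_loss(xb, yb, w1, b1, w2, b2)
    dz2 = probs.copy()
    dz2[np.arange(bs), yb] -= 1
    dz2 /= bs
    dw2 = a1.T @ dz2
    dz1 = (dz2 @ w2.T) * (z1 > 0)
    dw1 = xb.T @ dz1
    return dw1, dw2

rng = np.random.default_rng(0)
bs = 32
xb = rng.random((bs, 784))                 # float64 keeps the numeric check clean
yb = rng.integers(0, 10, size=bs)
w1 = rng.normal(0, np.sqrt(2 / 784), (784, 128))
b1 = np.zeros(128)
w2 = rng.normal(0, np.sqrt(2 / 128), (128, 10))
b2 = np.zeros(10)

def numeric_grad(param, idx, eps=1e-6):
    # Central difference: nudge one weight up and down and re-evaluate the loss
    old = param[idx]
    param[idx] = old + eps
    loss_plus = forward_loss(xb, yb, w1, b1, w2, b2)[0]
    param[idx] = old - eps
    loss_minus = forward_loss(xb, yb, w1, b1, w2, b2)[0]
    param[idx] = old                       # restore the original weight
    return (loss_plus - loss_minus) / (2 * eps)

dw1, dw2 = gradients(xb, yb, w1, b1, w2, b2)
checks = []
for param, grad_analytic in [(w1, dw1), (w2, dw2)]:
    for _ in range(5):
        idx = tuple(int(rng.integers(0, s)) for s in param.shape)
        checks.append((grad_analytic[idx], numeric_grad(param, idx)))

for analytic, numeric in checks:
    print(f"analytic {analytic:+.8f}   numeric {numeric:+.8f}")
```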
Where does the time go? To find out, we wrapped each operation with time.perf_counter():
```python
t = time.perf_counter()
z1 = xb @ w1 + b1
timings["fwd: X @ W1 + b1"] += time.perf_counter() - t

t = time.perf_counter()
a1 = np.maximum(0, z1)
timings["fwd: ReLU"] += time.perf_counter() - t
# ... and so on for each operation
```
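Here is a minimal, runnable version of that harness. It assumes plain NumPy on CPU, where each call finishes before it returns, so no device synchronization is needed. The `timed` context manager is a hypothetical helper, not part of our original code. It accumulates per-label totals in a `defaultdict` and reports the average per step in microseconds.

```python
import time
from collections import defaultdict
from contextlib import contextmanager

import numpy as np

timings = defaultdict(float)               # label -> accumulated seconds

@contextmanager
def timed(label):
    # Hypothetical helper: add the wall-clock time of the enclosed block to timings[label]
    start = time.perf_counter()
    try:
        yield
    finally:
        timings[label] += time.perf_counter() - start

rng = np.random.default_rng(0)
xb = rng.random((32, 784), dtype=np.float32)
w1 = rng.standard_normal((784, 128), dtype=np.float32)
b1 = np.zeros(128, dtype=np.float32)

n_steps = 1000
for _ in range(n_steps):
    with timed("fwd: X @ W1 + b1"):
        z1 = xb @ w1 + b1
    with timed("fwd: ReLU"):
        a1 = np.maximum(0, z1)

# Average time per step, in microseconds
avg_us = {label: total / n_steps * 1e6 for label, total in timings.items()}
for label, us in avg_us.items():
    print(f"{label:20s} {us:8.1f} μs")
```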
When running on GPU, we had to call tf.test.experimental.sync_devices() between operations. Without it, GPU operations queue asynchronously and the timer only captures the dispatch, not the actual execution.
Here are the results for a single training step (batch_size=32, averaged over 1,000 steps, time in microseconds μs — millionths of a second, lower is better):
| Operation | CPU (NumPy) | GPU (TF) |
|---|---|---|
| fwd: X @ W1 + b1 (images × hidden weights + bias) | 162 μs | 541 μs |
| fwd: ReLU (zero out negatives) | 9 μs | 173 μs |
| fwd: a1 @ W2 + b2 (hidden × output weights + bias) | 19 μs | 478 μs |
| fwd: softmax + loss (probabilities + error) | 30 μs | 780 μs |
| bwd: gradients (how much each weight contributed to error) | 188 μs | 1,271 μs |
| update: W -= lr*dW (adjust weights to reduce error) | 345 μs | 1,204 μs |
| TOTAL | 751 μs | 4,447 μs |
What stands out first is that the CPU is faster on every single operation — 751 μs total vs 4,447 μs. The operation that should benefit most from GPU parallelization is the matrix multiply X @ W1 — it’s thousands of independent dot products that could run on thousands of cores simultaneously, exactly the kind of work GPUs are designed for.
Yet in the table above, it’s still slower on GPU: 162 μs on CPU vs 541 μs on GPU. The CPU is fast here because NumPy calls directly into BLAS (Basic Linear Algebra Subprograms), which are highly optimized C/Fortran routines that use CPU-specific SIMD instructions. For example, an AVX2 instruction operates on 8 single-precision floats at once, and FMA fuses a multiply and an add into one instruction. The GPU is slower not because its math is slower, but because each operation pays CUDA overhead (kernel launch, memory sync, context switching).
When this overhead takes more time than the actual computation, the GPU ends up slower despite being faster at math.
Let’s look at xb @ w1 + b1 and a1 @ w2 + b2 — the matrix multiplies from each layer — and why GPUs are designed to make them fast. Here’s the forward pass from our NumPy implementation:
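```python
import numpy as np

class HiddenLayer:
    def __init__(self, W, b):               # minimal constructor (assumed): W (128, 784), b (128,)
        self.W, self.b = W, b

    def forward(self, x):
        self.z = self.W @ x + self.b        # matrix multiply + bias
        self.out = np.maximum(0, self.z)    # ReLU activation
        return self.out

class OutputLayer:
    def __init__(self, W, b):               # minimal constructor (assumed): W (10, 128), b (10,)
        self.W, self.b = W, b

    def forward(self, x):
        self.z = self.W @ x + self.b        # matrix multiply + bias
        exp = np.exp(self.z - np.max(self.z))
        self.probs = exp / np.sum(exp)      # softmax → probabilities
        return self.probs
```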
Each @ is a matrix multiply. For a single image, self.W @ x multiplies the weight matrix by the image’s pixels. Each of the 128 neurons has 784 weights — one per input pixel.
Each neuron computes the dot product of its 784 weights with the 784 input values (image pixels), which takes 784 multiplications plus 783 additions. The layer has 128 neurons, and each produces one value, so the result is 128 activations from 128 independent dot products for one image:
```
        W (128 × 784)              x (784,)      result (128,)
┌────────────────────────┐       ┌──────┐        ┌──────┐
│ n1:   w1  w2  …  w784  │       │ p1   │        │ a1   │
│ n2:   w1  w2  …  w784  │   @   │ p2   │   =    │ a2   │
│ ...                    │       │ …    │        │ ...  │
│ n128: w1  w2  …  w784  │       │ p784 │        │ a128 │
└────────────────────────┘       └──────┘        └──────┘
       128 neurons              one image,    128 activations
                                784 pixels
```
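The picture can be checked numerically. For one image, `W @ x` is exactly 128 separate dot products, one per row of `W`. The sketch uses random values as stand-ins for real weights and pixels.

```python
import numpy as np

rng = np.random.default_rng(0)
x = rng.random(784)                    # one flattened 28×28 image (random stand-in)
W = rng.standard_normal((128, 784))    # one row of 784 weights per neuron
b = np.zeros(128)

z_matmul = W @ x + b                   # one matrix-vector product
z_loop = np.array([W[i] @ x + b[i] for i in range(W.shape[0])])  # 128 dot products

print(z_matmul.shape)                  # (128,)
print(np.allclose(z_matmul, z_loop))   # True
```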
But frameworks like Keras don’t process images one by one — they stack the entire batch into a single matrix and multiply it all at once. With batch size 32, X is 32 rows × 784 columns, and the same weight matrix produces 32 × 128 = 4,096 activations in one operation:
```
    X (32 × 784)              W1 (784 × 128)            result (32 × 128)
┌───────────────────┐     ┌──────────────────┐     ┌────────────────────┐
│ img1:  p1 … p784  │     │ n1   n2  …  n128 │     │ img1:  a1 … a128   │
│ img2:  p1 … p784  │  @  │ w    w   …  w    │  =  │ img2:  a1 … a128   │
│ ...               │     │ ...  ... …  ...  │     │ ...                │
│ img32: p1 … p784  │     │ w    w   …  w    │     │ img32: a1 … a128   │
└───────────────────┘     └──────────────────┘     └────────────────────┘
 32 images, 784 pixels     784 weights per neuron    32 × 128 activations
```
The weights are shared — it’s as if you had 32 copies of the same model running on 32 images simultaneously, except it’s done in one operation instead of 32 separate ones. One kernel launch, one read of the weight matrix from memory, and all 4,096 results computed in parallel.
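The "32 copies of the same model" view can be checked directly: one `X @ W1` call gives the same result as looping over the images one at a time. The data is random again. Here `W1` has shape (784, 128), the transpose of the (128, 784) single-image layout, so each column holds one neuron’s weights.

```python
import numpy as np

rng = np.random.default_rng(0)
X = rng.random((32, 784))              # 32 images, one per row (random stand-ins)
W1 = rng.standard_normal((784, 128))   # one column of 784 weights per neuron
b1 = np.zeros(128)

Z_batched = X @ W1 + b1                           # one call: all 32 × 128 outputs
Z_per_image = np.stack([x @ W1 + b1 for x in X])  # 32 separate calls

print(Z_batched.shape, Z_batched.size)       # (32, 128) 4096
print(np.allclose(Z_batched, Z_per_image))   # True
```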
Now for the interesting question: how many of those 4,096 dot products can the GPU compute at the same time? Every one of them is independent (image 5’s activation for neuron 20 doesn’t depend on any other result), so the GPU can assign each one to a separate thread.
The RTX 5000 has 48 Streaming Multiprocessors (SMs), each with 64 CUDA cores — 3,072 cores total. Each core runs at ~1.8 GHz and can do a multiply-add per cycle, giving us the theoretical peak: 3,072 cores × 1.8 billion cycles/sec × 2 ops/cycle = ~11 TFLOPS (trillion floating-point operations per second).
Our matrix multiply produces 4,096 dot products, and each requires 784 multiply-adds: 4,096 × 784 ≈ 3.2 million multiply-adds in total. Against an 11 TFLOPS peak, that’s about 0.3 μs of actual compute if each multiply-add counts as one operation, or about 0.6 μs if it counts as two, which is how the peak figure counts it. Either way, the work is on the order of 0.00003% of what the GPU can do in one second. The GPU finishes the math in well under a microsecond, then waits ~540 μs for the next kernel launch. It’s doing useful work for roughly 0.1% of the time or less.
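The same arithmetic fits in a few lines of Python. The specs (48 SMs, 64 cores per SM, ~1.8 GHz) are the ones quoted above. Real clocks vary with boost and power limits, so treat the results as order-of-magnitude estimates.

```python
# Back-of-envelope compute time for one (32, 784) @ (784, 128) matmul
sms, cores_per_sm, clock_hz = 48, 64, 1.8e9      # assumed specs from the text
cores = sms * cores_per_sm                       # 3,072 CUDA cores
peak_flops = cores * clock_hz * 2                # ~11.06e12: one FMA counts as 2 FLOPs

macs = 32 * 128 * 784                            # 4,096 dot products × 784 multiply-adds
t_one_op_us = macs / peak_flops * 1e6            # each multiply-add counted as 1 op
t_two_ops_us = 2 * macs / peak_flops * 1e6       # each multiply-add counted as 2 FLOPs

measured_us = 541                                # GPU time for X @ W1 + b1 from the table
print(f"{cores} cores, peak {peak_flops / 1e12:.2f} TFLOPS")
print(f"{macs:,} multiply-adds -> {t_one_op_us:.2f} to {t_two_ops_us:.2f} μs of compute")
print(f"busy fraction: {t_one_op_us / measured_us:.3%} to {t_two_ops_us / measured_us:.3%}")
```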
Where the time actually goes
Training a neural network on a GPU involves more than matrix multiplication. Every training step is a sequence of operations, and for small models, the actual math is the smallest part:
1. Kernel launch overhead. Every GPU operation — a matrix multiply, an activation function, a loss computation — is a kernel that the CPU must schedule and launch on the GPU. Each launch has a fixed overhead of roughly 5–15 microseconds. For a large matrix multiply that takes milliseconds, this is negligible. For our tiny 32×784 × 784×128 multiply that completes in microseconds, the launch overhead can exceed the computation itself. Our forward pass alone involves a matrix multiply, a bias add, a ReLU, another matrix multiply, another bias add, and a softmax — at least 6 kernel launches before we even start backpropagation.
2. Memory transfer latency. Data must travel from CPU memory to GPU memory, and results such as the loss must travel back. Each transfer has a fixed latency: the time to set up the DMA transfer, cross the PCIe bus, and signal completion. For batch size 32 with 784-dimensional float32 inputs, we’re transferring about 100 KB per batch. A PCIe 3.0 x16 link moves roughly 16 GB/s in each direction (about 32 GB/s for PCIe 4.0 x16), so the raw transfer takes only a few microseconds, while the setup overhead is an order of magnitude larger.
3. Python and framework overhead. Keras/TensorFlow add their own layer of indirection. Each operation passes through Python, the TensorFlow runtime, XLA compilation (a compiler that optimizes the computation graph on first run), memory allocation, and synchronization. For large operations, this overhead is invisible. For small operations, it is the bottleneck.
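The transfer-size estimate from item 2 can be written out as code. The bandwidth figures are rounded theoretical per-direction peaks (about 16 GB/s for PCIe 3.0 x16, about 32 GB/s for PCIe 4.0 x16), and measured throughput is lower. The fixed per-transfer setup cost is not modeled at all, which is why the raw number understates the real cost.

```python
def transfer_us(n_bytes, gb_per_s):
    # Raw time on the wire only; ignores DMA setup, driver calls, and synchronization
    return n_bytes / (gb_per_s * 1e9) * 1e6

batch_bytes = 32 * 784 * 4   # 32 images × 784 float32 pixels × 4 bytes = 100,352 bytes

# Approximate theoretical per-direction bandwidths (assumed, for illustration)
for link, gbps in [("PCIe 3.0 x16", 16), ("PCIe 4.0 x16", 32)]:
    print(f"{link}: {transfer_us(batch_bytes, gbps):.1f} μs for {batch_bytes:,} bytes")
```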
The good news is that we can measure exactly how much time each type of overhead takes.
Profiling the CUDA overhead
Our per-operation timing showed the GPU takes more time, and we calculated that the actual math takes just 0.3 μs — the GPU should be idle most of the time. So what exactly is it spending the other 540 μs on?
To find out, we used NVIDIA’s nsys (Nsight Systems) profiler, which intercepts every CUDA API call:
```shell
nsys profile -o keras-gpu-profile python mnist-keras.py
nsys stats --force-export=true keras-gpu-profile.nsys-rep
```
CUDA is NVIDIA’s software layer between your code and the GPU hardware. When TensorFlow wants to do a matrix multiply, it doesn’t talk to the GPU directly — it calls CUDA functions like “allocate memory,” “copy this data,” “launch this kernel.” Each of these calls goes through the driver and has its own overhead. The nsys profiler records every one of them, so we can see exactly where the time goes.
This time, instead of timing individual operations, we profiled an entire epoch of model.fit() — all 1,500 steps. The epoch took 6.5s on GPU. Here’s where that time went:
| Operation | Overhead type | Calls | Total time |
|---|---|---|---|
| cuCtxSetCurrent (context switching) | kernel launch | 98,168 | 0.87s |
| cuEventRecord (timing/sync events) | kernel launch | 22,786 | 0.57s |
| cuMemcpyDtoHAsync (GPU→CPU copies) | memory transfer | 6,144 | 0.56s |
| cuMemcpyHtoDAsync (CPU→GPU copies) | memory transfer | 5,678 | 0.27s |
| cuGraphLaunch (execute compiled graph) | kernel launch | 1,875 | 0.16s |
| cuLaunchKernel | kernel launch | 1,080 | 0.02s |
| Total CUDA overhead | | | ~2.7s |
None of these CUDA rows are the actual matrix multiply — they’re all management overhead around it. The actual compute (matrix multiplies, ReLUs, softmax) happens on the GPU after cuGraphLaunch dispatches the work, but it’s so fast for our tiny matrices that it doesn’t even show up as a significant item.
Putting it all together, here’s where the full 6.5s epoch went:
| Component | Time |
|---|---|
| CUDA overhead (table above) | ~2.7s |
| Python/framework overhead (not captured by nsys) | ~3.8s |
| Actual GPU math | negligible |
| Total epoch | 6.5s |
Essentially all of the 6.5s is overhead. Keras is already heavily optimized. It uses CUDA Graphs (cuGraphLaunch) to pre-compile the entire forward+backward pass and replay it without per-op dispatch. The 1,875 calls match one per batch across the 1,500 training batches plus the 375 validation batches. But even with that optimization, the GPU never gets a chance to earn back the overhead through faster compute.
For comparison, Keras on CPU runs the same epoch in 4.4s — faster than GPU’s 6.5s, because there’s no CUDA overhead at all. XLA compiles to native code and the CPU does the math directly in its own memory.
It’s important to stress that there’s no surprise when it comes to the actual math — the GPU is significantly faster, as expected. We timed the core operation — a single matrix multiply of (32, 784) @ (784, 128), one batch through the first layer:
| Device | Time per matrix multiply |
|---|---|
| GPU | 109 μs |
| CPU | 272 μs |
The GPU is 2.5x faster, but even these 109 μs already include the overhead of launching that single operation. The pure math would be ~0.3 μs (as we calculated from TFLOPS earlier), so the other ~108 μs is dispatch overhead for that one kernel. On CPU there’s no CUDA layer in between, so nearly all of the 272 μs is math.
A full training step involves multiple matrix multiplies plus activations, loss, gradients, and all the CUDA management around each of them. When we time a compiled train_step with all overhead included, the gap disappears entirely:
| Device | Time per step |
|---|---|
| GPU | 1.27 ms (~0.1 ms compute + ~1.2 ms overhead) |
| CPU | 1.30 ms (all compute, no overhead) |
The GPU does the math faster but spends the rest of the time on overhead, ending up at roughly the same speed per step. And in the full model.fit() pipeline — with data loading, metrics, and callbacks — the CPU actually wins: 4.4s vs 6.5s per epoch.
One interesting note: if you run nvidia-smi while training, you might see GPU utilization at 90–100%. That looks like the GPU is fully busy — so why is it slower than CPU?
Because nvidia-smi utilization reports the percentage of time at least one kernel is running on the GPU — not how many cores are active. If tiny kernels are being launched back-to-back with no idle gaps, it shows ~100% utilization even though the vast majority of cores are idle at any given moment. As we saw earlier, our model uses just 0.00003% of the GPU’s capacity.
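A toy model makes the distinction concrete. It is a simplification, not how NVML actually samples. Each kernel is assumed to occupy a fixed 3 of 48 SMs for 5 μs, and 100 of them are launched back to back. `utilization_and_occupancy` is a hypothetical helper that computes two numbers: the nvidia-smi-style figure (fraction of time at least one kernel is running) and the fraction of SM-time actually in use.

```python
def utilization_and_occupancy(kernels, total_sms, window_us):
    # kernels: list of (start_us, end_us, sms_used) tuples inside [0, window_us]
    # 1) nvidia-smi-style utilization: fraction of time at least one kernel runs
    busy_us = 0.0
    cur_start = cur_end = None
    for start, end, _ in sorted(kernels):
        if cur_end is None or start > cur_end:      # gap: close the previous interval
            if cur_end is not None:
                busy_us += cur_end - cur_start
            cur_start, cur_end = start, end
        else:                                       # overlapping or touching: extend it
            cur_end = max(cur_end, end)
    if cur_end is not None:
        busy_us += cur_end - cur_start
    # 2) Occupancy: fraction of the available SM-time actually used
    sm_us = sum((end - start) * sms for start, end, sms in kernels)
    return busy_us / window_us, sm_us / (total_sms * window_us)

# 100 tiny kernels launched back to back, each occupying 3 of 48 SMs for 5 μs
kernels = [(i * 5.0, i * 5.0 + 5.0, 3) for i in range(100)]
util, occ = utilization_and_occupancy(kernels, total_sms=48, window_us=500.0)
print(f"utilization {util:.0%}, SM occupancy {occ:.2%}")  # utilization 100%, SM occupancy 6.25%
```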
Making a single training run faster
Now that we understand the overhead, how do we reduce it? There are several angles of attack: reduce the number of steps (larger batch size), eliminate CUDA overhead (run on CPU), eliminate framework overhead (skip TensorFlow), compile the entire step into one operation (JAX JIT), or naively move to GPU (CuPy — spoiler: it’s worse). We tried all five.
1. Larger batch size
If the overhead is per-step, the obvious fix is fewer steps. A larger batch size means more samples per step, so the same epoch requires fewer steps, and each step pays the overhead tax only once regardless of batch size. With 48,000 training samples (60K minus 20% validation) and batch size 32, one epoch is 48,000 / 32 = 1,500 steps, where each step is one full forward + backward + update pass on one batch. With batch size 4096, it’s just 12 steps: 11 full batches plus a final partial batch of 2,944 samples.
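The step arithmetic can be written as code, together with a deliberately crude cost model in which every step pays the same fixed overhead. `steps_per_epoch` is a hypothetical helper. The ~1.2 ms per-step overhead is borrowed from the GPU train_step measurement earlier (batch size 32) and is an assumption here, because real per-step overhead isn’t perfectly constant across batch sizes. The same rounding gives 12,000 / 32 = 375 validation steps per epoch.

```python
import math

def steps_per_epoch(n_samples, batch_size):
    # A final partial batch still counts as a step, so round up
    return math.ceil(n_samples / batch_size)

n_total = 60_000
n_train = n_total - n_total // 5       # validation_split=0.2 -> 48,000 training samples
n_val = n_total - n_train              # 12,000 validation samples

per_step_overhead_s = 1.2e-3           # assumed: ~1.2 ms fixed overhead per step

for bs in [32, 128, 512, 1024, 2048, 4096]:
    steps = steps_per_epoch(n_train, bs)
    print(f"batch {bs:5d}: {steps:5d} steps, ~{steps * per_step_overhead_s:.3f}s of overhead")
```

Under this model, the overhead term shrinks from ~1.8s per epoch at batch size 32 to ~0.014s at 4096. That is why batch size matters so much when per-step cost dominates.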
Larger batches also give the GPU more work per step. At batch size 32 only ~3 of the 48 SMs are active, while batch size 2048 keeps ~42 busy. So larger batch sizes help in two ways: fewer overhead-paying steps per epoch, and more GPU cores actually working per step.
However, you can’t increase the batch size indiscriminately — it has a direct impact on model accuracy. Larger batches produce smoother but less frequent gradient updates, which can lead to worse generalization. The right batch size is something you need to experiment with for your specific model.
2. Keras on CPU — skip CUDA overhead
As we saw, just disabling the GPU with tf.config.set_visible_devices([], 'GPU') eliminates 2.7s of CUDA overhead. The call has to run at startup, before TensorFlow initializes the GPU. XLA compiles to native code and the CPU does the math directly.
3. Pure NumPy — skip framework overhead too
Going further, we can skip TensorFlow entirely. A pure NumPy training loop calls directly into BLAS routines with no framework dispatch overhead between operations.
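Below is a minimal sketch of such a loop, vectorized over minibatches so that each step is a handful of BLAS calls. It is self-contained and trains on synthetic data (random inputs, with labels from a random linear "teacher") rather than MNIST. `train_epoch` and the data setup are hypothetical and are not the code we benchmarked.

```python
import numpy as np

def train_epoch(X, y, params, lr, batch_size, rng):
    # One epoch of minibatch SGD; params (w1, b1, w2, b2) are updated in place
    w1, b1, w2, b2 = params
    order = rng.permutation(len(X))                  # shuffle once per epoch
    total_loss = 0.0
    for start in range(0, len(X), batch_size):
        idx = order[start:start + batch_size]
        xb, yb = X[idx], y[idx]
        bs = len(yb)
        # Forward pass
        z1 = xb @ w1 + b1
        a1 = np.maximum(0, z1)
        z2 = a1 @ w2 + b2
        exp_z = np.exp(z2 - z2.max(axis=1, keepdims=True))
        probs = exp_z / exp_z.sum(axis=1, keepdims=True)
        total_loss += -np.log(probs[np.arange(bs), yb]).sum()
        # Backward pass
        dz2 = probs.copy()
        dz2[np.arange(bs), yb] -= 1
        dz2 /= bs
        dw2 = a1.T @ dz2
        dz1 = (dz2 @ w2.T) * (z1 > 0)
        dw1 = xb.T @ dz1
        # Update weights in place
        w1 -= lr * dw1
        b1 -= lr * dz1.sum(axis=0)
        w2 -= lr * dw2
        b2 -= lr * dz2.sum(axis=0)
    return total_loss / len(X)                       # mean cross-entropy over the epoch

rng = np.random.default_rng(0)
X = rng.standard_normal((4096, 784)).astype(np.float32)     # synthetic "images"
teacher = rng.standard_normal((784, 10)).astype(np.float32)
y = (X @ teacher).argmax(axis=1)                            # learnable synthetic labels

params = (
    (rng.standard_normal((784, 128)) * np.sqrt(2 / 784)).astype(np.float32),
    np.zeros(128, dtype=np.float32),
    (rng.standard_normal((128, 10)) * np.sqrt(2 / 128)).astype(np.float32),
    np.zeros(10, dtype=np.float32),
)
losses = [train_epoch(X, y, params, lr=0.1, batch_size=512, rng=rng) for _ in range(5)]
print([round(float(l), 3) for l in losses])
```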
4. JAX JIT — compile the entire step
JAX’s jit compiler traces the entire training step and compiles it into a single optimized native function. Instead of Python dispatching each operation one by one, the compiled function runs all of them in one fused call with near-zero overhead.
5. CuPy — what if we just move NumPy to GPU?
We also tried the naive approach: swap import numpy as np for import cupy as cp and run the same code on GPU. Since our implementation processes samples one at a time in a Python loop, each tiny operation (a 128-element vector add, a 10-element softmax) becomes a separate GPU kernel launch. The result: 443 seconds for 5 epochs — over 6x slower than NumPy on CPU. Simply moving code to GPU without rethinking the access pattern makes things worse, not better.
Putting it all together
We benchmarked approaches 1–4 across batch sizes (CuPy was too slow to include). Each cell is the time for one epoch:
| Batch size | Keras GPU | Keras CPU | Pure NumPy | JAX (CPU, JIT) |
|---|---|---|---|---|
| 32 | 6.26s | 11.31s | 2.71s | 2.37s |
| 128 | 2.74s | 2.21s | 4.70s | 1.03s |
| 512 | 2.64s | 1.80s | 1.30s | 0.64s |
| 1024 | 2.47s | 1.43s | 0.87s | 0.63s |
| 2048 | 2.32s | 1.42s | 0.84s | 0.48s |
| 4096 | 2.32s | 1.39s | 1.03s | 0.14s |
Reading across each row shows the effect of switching approach (less overhead per step). Reading down each column shows the effect of larger batch sizes (fewer steps per epoch). The two effects compound.
Keras GPU bottoms out at ~2.3s regardless of batch size — the CUDA overhead has a fixed floor that larger batches can’t eliminate.
Keras CPU goes from 11.3s down to 1.4s — no CUDA overhead, but TensorFlow’s framework overhead sets its own floor.
Pure NumPy hits 0.84s at batch size 2048 — no framework, just BLAS calls. It slows back down at 4096 due to memory pressure.
JAX reaches 0.14s at batch size 4096 — 45x faster than Keras on GPU. JIT compilation fuses the entire step into one native call with near-zero overhead.
The lesson: the path to faster training isn’t bigger hardware — it’s less overhead. Each approach in the table removes a layer of overhead, and larger batch sizes reduce how many times you pay whatever overhead remains.
Using parallelization for hyperparameter sweeps
The approaches above make a single training run faster. But when you’re doing hyperparameter sweeps — trying 5 learning rates, or 4 architectures — each run is completely independent. They don’t share weights, gradients, or state.
JAX’s vmap (vectorized map) can train all 5 models simultaneously in a single forward and backward pass.
This is different from Keras’s batch processing, which runs many images through one model.
vmap runs the same images through many models, each with its own weights and learning rate, in one fused computation. Under the hood, XLA compiles this into a single program in which batched matrix multiplies handle all 5 forward passes and all 5 backward passes at once. There is no Python loop and no per-model dispatch overhead.
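To make the idea concrete without JAX, here is a NumPy analogue of what the vmapped forward pass computes. It illustrates the math, not XLA’s actual output. The 5 models’ weights are stacked along a new leading axis, and broadcasting matmul applies every model to the same shared batch in one call. The shapes (5 models, a batch of 32, 784→128→10) match the sweep in this section, and the data is random.

```python
import numpy as np

rng = np.random.default_rng(0)
n_models, batch = 5, 32
X = rng.random((batch, 784))                                       # one batch, shared by all models
W1 = rng.standard_normal((n_models, 784, 128)) * np.sqrt(2 / 784)  # one W1 per model
b1 = np.zeros((n_models, 1, 128))
W2 = rng.standard_normal((n_models, 128, 10)) * np.sqrt(2 / 128)
b2 = np.zeros((n_models, 1, 10))

# Batched: (32, 784) @ (5, 784, 128) broadcasts to (5, 32, 128), i.e. all models at once
H = np.maximum(0, X @ W1 + b1)
logits_all = H @ W2 + b2                                           # (5, 32, 10)

# Reference: the same computation as a Python loop, one model at a time
logits_loop = np.stack([
    np.maximum(0, X @ W1[m] + b1[m]) @ W2[m] + b2[m] for m in range(n_models)
])

print(logits_all.shape)                       # (5, 32, 10)
print(np.allclose(logits_all, logits_loop))   # True
```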
The tradeoff: all models must have the same architecture (same parameter shapes), and you need to rewrite the training loop in JAX. Here’s the core of the code:
```python
import jax
import jax.numpy as jnp
from jax import vmap, jit, grad, random

def init_params(key):
    # Same model as before: 784→128→10, He-initialized random weights
    k1, k2 = random.split(key)
    w1 = random.normal(k1, (784, 128)) * jnp.sqrt(2.0 / 784)
    b1 = jnp.zeros(128)
    w2 = random.normal(k2, (128, 10)) * jnp.sqrt(2.0 / 128)
    b2 = jnp.zeros(10)
    return (w1, b1, w2, b2)

def loss_fn(params, x, y):
    # Forward pass + cross-entropy loss, same math as our NumPy version
    w1, b1, w2, b2 = params
    h = jnp.maximum(0, x @ w1 + b1)           # hidden layer + ReLU
    logits = h @ w2 + b2                      # output layer
    log_probs = jax.nn.log_softmax(logits)    # numerically stable log-softmax
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])

def sgd_step(params, x, y, lr):
    # One training step: compute gradients, update weights
    grads = grad(loss_fn)(params, x, y)       # JAX auto-differentiates loss_fn
    return tuple(p - lr * g for p, g in zip(params, grads))

# 5 learning rates, 5 sets of weights, trained simultaneously
lr_array = jnp.array([0.001, 0.01, 0.1, 1.0, 10.0])
keys = random.split(random.PRNGKey(0), lr_array.shape[0])
batched_params = vmap(init_params)(keys)      # every array gets a leading axis of size 5

# Map over params and lr (axis 0); the batch (x, y) is shared by all 5 models
batched_step = jit(vmap(sgd_step, in_axes=(0, None, None, 0)))

# `batches` stands in for your data pipeline: x_batch (B, 784) floats, y_batch (B,) int labels
for x_batch, y_batch in batches:
    batched_params = batched_step(batched_params, x_batch, y_batch, lr_array)
```
For simpler cases, Python’s multiprocessing.Pool gets a similar result without rewriting anything in JAX. It runs 5 separate processes, each training one config with its own independent training loop:
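```python
import multiprocessing as mp

def train_one_config(args):
    # Train one model with one learning rate; runs in its own process.
    # Keras is imported inside the worker so each process initializes TensorFlow itself.
    import keras
    lr, X_train, y_train = args
    model = keras.Sequential([
        keras.Input(shape=(784,)),
        keras.layers.Dense(128, activation="relu"),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr),
                  loss="sparse_categorical_crossentropy")
    history = model.fit(X_train, y_train, epochs=10, batch_size=32,
                        validation_split=0.2, verbose=0)
    return lr, history.history

if __name__ == "__main__":
    import keras
    (X_train, y_train), _ = keras.datasets.mnist.load_data()
    X_train = X_train.reshape(-1, 784).astype("float32") / 255.0
    # "spawn" starts fresh interpreters instead of forking a process that already loaded TensorFlow
    with mp.get_context("spawn").Pool(5) as pool:
        results = pool.map(train_one_config,
                           [(lr, X_train, y_train) for lr in [0.001, 0.01, 0.1, 1.0, 10.0]])
```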
What to take away
- For small models, the CPU is faster than the GPU. Our MNIST model (784→128→10, ~102K parameters) trained in 4.4s per epoch on CPU vs 6.5s on GPU. The GPU’s math is faster (109 μs vs 272 μs per matrix multiply), but CUDA overhead (kernel launches, memory copies, context switching) adds 2.7s per epoch that the CPU simply doesn’t have.
- The GPU is barely working. Our matrix multiply uses 0.00003% of the RTX 5000’s capacity. At batch_size=32, only ~3 of 48 SMs are active. The GPU finishes the math in 0.3 μs, then waits ~540 μs for the next kernel launch.
- nvidia-smi utilization is misleading. It reports time-with-any-kernel-running, not core occupancy. Our model shows ~100% utilization while using 0.00003% of the GPU’s compute.
- Less overhead beats bigger hardware. By switching from Keras GPU to JAX JIT on CPU with batch_size=4096, we went from 6.5s to 0.14s per epoch: a 45x speedup on the same machine, with no GPU needed.
- Don’t reach for a GPU until your workload needs one. As a rough rule of thumb from this experiment, models with fewer than ~500K parameters trained with batch sizes under 256 are likely to run as fast or faster on a good CPU, at lower cost.