GPU overhead: why our MNIST model trains faster on CPU

Machine learning and GPUs feel inseparable — every tutorial, every cloud provider, every “getting started” guide points you toward GPU instances. The intuition is straightforward: data goes in, thousands of cores process it in parallel, training goes fast. But this assumes all those cores are actually being used — and for small workloads, they aren’t.

We trained our MNIST model on a local workstation with a Quadro RTX 5000 GPU and an Intel Core i9-10885H CPU, and then on CPU only. The CPU was faster — 4.4s vs 6.5s per epoch. This article will explain why, and show how we got training down to 0.14s by reducing overhead instead of upgrading hardware.

Let’s start by looking at what happens during training.

Breaking down one training step

Each training step runs four phases — forward pass, loss, backward pass, weight update:

# Forward pass
z1 = xb @ w1 + b1                                # matrix multiply + bias (hidden layer)
a1 = np.maximum(0, z1)                            # ReLU activation
z2 = a1 @ w2 + b2                                # matrix multiply + bias (output layer)
exp_z = np.exp(z2 - z2.max(axis=1, keepdims=True))
probs = exp_z / exp_z.sum(axis=1, keepdims=True)  # softmax → probabilities

# Loss gradient
dz2 = probs.copy()
dz2[np.arange(bs), yb] -= 1                       # how far off were the predictions
dz2 /= bs

# Backward pass — compute gradients
dw2 = a1.T @ dz2                                  # gradient for W2
da1 = dz2 @ w2.T                                  # gradient flowing backward
dz1 = da1 * (z1 > 0)                              # ReLU gradient
dw1 = xb.T @ dz1                                  # gradient for W1

# Update weights
w1 -= lr * dw1
b1 -= lr * dz1.sum(axis=0)
w2 -= lr * dw2
b2 -= lr * dz2.sum(axis=0)
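The four phases above can be strung together into a runnable step. Here's a self-contained sketch on random stand-in data (shapes match the article's 784→128→10 model; the He-style initialization is our assumption, not from the original code):

```python
import numpy as np

rng = np.random.default_rng(0)
bs, lr = 32, 0.1
# He-style init for 784 -> 128 -> 10 (init scheme is an assumption)
w1 = rng.standard_normal((784, 128)) * np.sqrt(2 / 784); b1 = np.zeros(128)
w2 = rng.standard_normal((128, 10)) * np.sqrt(2 / 128);  b2 = np.zeros(10)
xb = rng.standard_normal((bs, 784))          # random "images"
yb = rng.integers(0, 10, bs)                 # random labels

def train_step(w1, b1, w2, b2):
    # forward pass
    z1 = xb @ w1 + b1
    a1 = np.maximum(0, z1)
    z2 = a1 @ w2 + b2
    exp_z = np.exp(z2 - z2.max(axis=1, keepdims=True))
    probs = exp_z / exp_z.sum(axis=1, keepdims=True)
    loss = -np.log(probs[np.arange(bs), yb]).mean()
    # loss gradient + backward pass
    dz2 = probs.copy(); dz2[np.arange(bs), yb] -= 1; dz2 /= bs
    dw2 = a1.T @ dz2
    dz1 = (dz2 @ w2.T) * (z1 > 0)
    dw1 = xb.T @ dz1
    # weight update
    return (w1 - lr * dw1, b1 - lr * dz1.sum(axis=0),
            w2 - lr * dw2, b2 - lr * dz2.sum(axis=0)), loss

params, first_loss = train_step(w1, b1, w2, b2)
for _ in range(20):
    params, last_loss = train_step(*params)
```

Repeatedly stepping on the same batch overfits it, so the loss falls quickly — a cheap sanity check that the gradients are wired up correctly.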

Where does the time go? To find out, we wrapped each operation with time.perf_counter():

t = time.perf_counter()
z1 = xb @ w1 + b1
timings["fwd: X @ W1 + b1"] += time.perf_counter() - t

t = time.perf_counter()
a1 = np.maximum(0, z1)
timings["fwd: ReLU"] += time.perf_counter() - t
# ... and so on for each operation

When running on GPU, we had to add tf.test.experimental.sync_devices() between operations — without this, GPU operations queue asynchronously and the timing only captures the dispatch, not the actual execution.
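One way to keep that pattern tidy is a small context manager that accumulates per-op timings, with an optional sync hook for the GPU case (the helper itself is our own sketch, not part of the training code; on GPU you would pass `tf.test.experimental.sync_devices` as the hook):

```python
import time
from collections import defaultdict
from contextlib import contextmanager

timings = defaultdict(float)

@contextmanager
def timed(name, sync=None):
    # sync is e.g. tf.test.experimental.sync_devices on GPU, None on CPU
    if sync: sync()                      # drain any queued GPU work first
    t = time.perf_counter()
    yield
    if sync: sync()                      # wait for this op to actually finish
    timings[name] += time.perf_counter() - t

with timed("fwd: ReLU"):
    _ = max(0, -3)                       # stand-in for the real operation
```

Without the second `sync()`, asynchronous GPU timings measure only the dispatch, which is exactly the trap described above.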

Here are the results for a single training step (batch_size=32, averaged over 1,000 steps; times in microseconds, μs — millionths of a second; lower is better):

Operation                                                  | CPU (NumPy) | GPU (TF)
fwd: X @ W1 + b1 (images × hidden weights + bias)          | 162 μs      | 541 μs
fwd: ReLU (zero out negatives)                             | 9 μs        | 173 μs
fwd: a1 @ W2 + b2 (hidden × output weights + bias)         | 19 μs       | 478 μs
fwd: softmax + loss (probabilities + error)                | 30 μs       | 780 μs
bwd: gradients (how much each weight contributed to error) | 188 μs      | 1,271 μs
update: W -= lr*dW (adjust weights to reduce error)        | 345 μs      | 1,204 μs
TOTAL                                                      | 751 μs      | 4,447 μs

What stands out first is that the CPU is faster on every single operation — 751 μs total vs 4,447 μs. The operation that should benefit most from GPU parallelization is the matrix multiply X @ W1 — it’s thousands of independent dot products that could run on thousands of cores simultaneously, exactly the kind of work GPUs are designed for. Yet in the table above, it’s still slower on GPU: 162 μs on CPU vs 541 μs on GPU.

The CPU is fast here because NumPy calls directly into BLAS (Basic Linear Algebra Subprograms) — highly optimized C/Fortran routines that use CPU-specific SIMD instructions (AVX2 processes 8 floats per cycle, FMA fuses multiply+add into one instruction). The GPU is slower not because its math is slower, but because each operation pays CUDA overhead (kernel launch, memory sync, context switching). When this overhead takes more time than the actual computation, the GPU ends up slower despite being faster at math.

Let’s look at xb @ w1 + b1 and a1 @ w2 + b2 — the matrix multiplies from each layer — and why GPUs are designed to make them fast. Here’s the forward pass from our NumPy implementation:

class HiddenLayer:
    def forward(self, x):
        self.z = self.W @ x + self.b        # matrix multiply + bias
        self.out = np.maximum(0, self.z)     # ReLU activation
        return self.out

class OutputLayer:
    def forward(self, x):
        self.z = self.W @ x + self.b        # matrix multiply + bias
        exp = np.exp(self.z - np.max(self.z))
        self.probs = exp / np.sum(exp)       # softmax → probabilities
        return self.probs

Each @ is a matrix multiply. For a single image, self.W @ x multiplies the weight matrix by the image’s pixels. Each of the 128 neurons has 784 weights — one per input pixel. Each neuron computes a dot product of its 784 weights with the 784 input values (image pixels) — that’s 784 multiplications + 783 additions per neuron. The result is 128 activations, which comes from 128 neurons in the layer, each producing one value from its dot product. That’s 128 independent dot products for one image:

      x (784,)                    W (128 × 784)              result (128,)
 ┌─────────────┐           ┌──────────────────┐       ┌──────────────┐
 │ p1 p2 … p784│     @     │ n1:  w1 w2 … w784│   =   │ a1           │
 └─────────────┘           │ n2:  w1 w2 … w784│       │ a2           │
  one image, 784 pixels    │ ...              │       │ ...          │
                           │ n128:w1 w2 … w784│       │ a128         │
                           └──────────────────┘       └──────────────┘
                            128 neurons                 128 activations
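You can check this equivalence directly in NumPy: the single matrix-vector product is exactly 128 separate dot products (random stand-in data, same shapes as above):

```python
import numpy as np

rng = np.random.default_rng(0)
W = rng.standard_normal((128, 784))   # 128 neurons, 784 weights each
x = rng.standard_normal(784)          # one flattened image

out = W @ x                                            # all 128 at once
per_neuron = np.array([W[i] @ x for i in range(128)])  # one dot product per neuron
assert np.allclose(out, per_neuron)
```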

But frameworks like Keras don’t process images one by one — they stack the entire batch into a single matrix and multiply it all at once. With batch size 32, X is 32 rows × 784 columns, and the same weight matrix produces 32 × 128 = 4,096 activations in one operation:

      X (32 × 784)                W1 (784 × 128)            result (32 × 128)
 ┌─────────────────┐         ┌──────────────────┐       ┌──────────────────┐
 │ img1:  p1 … p784│         │  n1   n2  … n128 │       │ img1: a1  … a128 │
 │ img2:  p1 … p784│    @    │  w    w   …  w   │   =   │ img2: a1  … a128 │
 │ ...             │         │  ...  ...    ... │       │ ...              │
 │ img32: p1 … p784│         │  w    w   …  w   │       │ img32:a1  … a128 │
 └─────────────────┘         └──────────────────┘       └──────────────────┘
  32 images, 784 pixels each   784 weights per neuron     32 × 128 activations

The weights are shared — it’s as if you had 32 copies of the same model running on 32 images simultaneously, except it’s done in one operation instead of 32 separate ones. One kernel launch, one read of the weight matrix from memory, and all 4,096 results computed in parallel.
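In NumPy, "one operation instead of 32" is literally one call — the batched product matches a per-image loop exactly (random stand-in data):

```python
import numpy as np

rng = np.random.default_rng(1)
X = rng.standard_normal((32, 784))    # 32 flattened images
W1 = rng.standard_normal((784, 128))  # shared weights

batched = X @ W1                                    # one call, 32 × 128 activations
looped = np.stack([X[i] @ W1 for i in range(32)])   # 32 separate calls
assert np.allclose(batched, looped)
```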

Now, the interesting question is how many of those 4,096 dot products can the GPU compute at the same time? Since every one of them is completely independent — image 5’s activation for neuron 20 doesn’t depend on any other result — the GPU can assign each one to a separate thread.

The RTX 5000 has 48 Streaming Multiprocessors (SMs), each with 64 CUDA cores — 3,072 cores total. Each core runs at ~1.8 GHz and can do a multiply-add per cycle, giving us the theoretical peak: 3,072 cores × 1.8 billion cycles/sec × 2 ops/cycle = ~11 TFLOPS (trillion floating-point operations per second).

Our matrix multiply produces 4,096 dot products, each requiring 784 multiply-adds: 4,096 × 784 = ~3.2 million operations total. The GPU can do 11 trillion operations per second, but we’re only asking for 3.2 million — that’s 0.3 microseconds of actual compute, using just 0.00003% of the GPU’s capacity. The GPU finishes the math in 0.3 μs, then waits ~540 μs for the next kernel launch. It’s working for 0.06% of the time.
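The back-of-envelope arithmetic, following the article's accounting (one multiply-add counted as one operation for the workload, two ops per cycle for the peak):

```python
cores, clock_hz, ops_per_cycle = 3072, 1.8e9, 2
peak_ops = cores * clock_hz * ops_per_cycle   # ≈ 1.1e13 ops/s, ~11 TFLOPS

workload = 4096 * 784                         # dot products × multiply-adds ≈ 3.2M ops
compute_s = workload / peak_ops               # ≈ 3e-7 s, ~0.3 μs of actual math
utilization = compute_s / 541e-6              # vs the measured 541 μs op time ≈ 0.06%
```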

Where the time actually goes

Training a neural network on a GPU involves more than matrix multiplication. Every training step is a sequence of operations, and for small models, the actual math is the smallest part:

1. Kernel launch overhead. Every GPU operation — a matrix multiply, an activation function, a loss computation — is a kernel that the CPU must schedule and launch on the GPU. Each launch has a fixed overhead of roughly 5–15 microseconds. For a large matrix multiply that takes milliseconds, this is negligible. For our tiny 32×784 × 784×128 multiply that completes in microseconds, the launch overhead can exceed the computation itself. Our forward pass alone involves a matrix multiply, a bias add, a ReLU, another matrix multiply, another bias add, and a softmax — at least 6 kernel launches before we even start backpropagation.

2. Memory transfer latency. Data must travel from CPU memory to GPU memory (and gradients must travel back). This transfer has a fixed latency — the time to set up the DMA transfer, cross the PCIe bus, and signal completion. For batch size 32 with 784-dimensional inputs, we’re transferring about 100 KB per batch. The PCIe bus can move 32 GB/s, so the raw transfer is ~3 microseconds — but the setup overhead is 10–20x that.

3. Python and framework overhead. Keras/TensorFlow add their own layer of indirection. Each operation passes through Python, the TensorFlow runtime, XLA compilation (a compiler that optimizes the computation graph on first run), memory allocation, and synchronization. For large operations, this overhead is invisible. For small operations, it is the bottleneck.
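Adding the first two overhead sources up for one forward pass gives a sense of scale. This is an estimate, not a measurement — the launch cost and kernel count are the rough figures quoted above:

```python
launches = 6                  # matmul, bias add, ReLU, matmul, bias add, softmax
launch_us = 10.0              # mid-range of the 5–15 μs per kernel launch
copy_bytes = 32 * 784 * 4     # one float32 batch ≈ 100 KB
raw_copy_us = copy_bytes / 32e9 * 1e6   # PCIe at 32 GB/s ≈ 3 μs
compute_us = 0.3              # from the TFLOPS estimate

overhead_us = launches * launch_us + raw_copy_us   # ≈ 63 μs before framework costs
```

Even ignoring framework overhead entirely, the fixed costs are a couple of hundred times larger than the 0.3 μs of math they wrap.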

The good news is that we can measure exactly how much time each type of overhead takes.

Profiling the CUDA overhead

Our per-operation timing showed the GPU takes more time, and we calculated that the actual math takes just 0.3 μs — the GPU should be idle most of the time. So what exactly is it spending the other 540 μs on?

To find out, we used NVIDIA’s nsys (Nsight Systems) profiler, which intercepts every CUDA API call:

nsys profile -o keras-gpu-profile python mnist-keras.py
nsys stats --force-export=true keras-gpu-profile.nsys-rep

CUDA is NVIDIA’s software layer between your code and the GPU hardware. When TensorFlow wants to do a matrix multiply, it doesn’t talk to the GPU directly — it calls CUDA functions like “allocate memory,” “copy this data,” “launch this kernel.” Each of these calls goes through the driver and has its own overhead. The nsys profiler records every one of them, so we can see exactly where the time goes.

This time, instead of timing individual operations, we profiled an entire epoch of model.fit() — all 1,500 steps. The epoch took 6.5s on GPU. Here’s where that time went:

Operation                              | Overhead type   | Calls  | Total time
cuCtxSetCurrent (context switching)    | kernel launch   | 98,168 | 0.87s
cuEventRecord (timing/sync events)     | kernel launch   | 22,786 | 0.57s
cuMemcpyDtoHAsync (GPU→CPU copies)     | memory transfer | 6,144  | 0.56s
cuMemcpyHtoDAsync (CPU→GPU copies)     | memory transfer | 5,678  | 0.27s
cuGraphLaunch (execute compiled graph) | kernel launch   | 1,875  | 0.16s
cuLaunchKernel                         | kernel launch   | 1,080  | 0.02s
Total CUDA overhead                    |                 |        | ~2.7s

None of these CUDA rows are the actual matrix multiply — they’re all management overhead around it. The actual compute (matrix multiplies, ReLUs, softmax) happens on the GPU after cuGraphLaunch dispatches the work, but it’s so fast for our tiny matrices that it doesn’t even show up as a significant item.

Putting it all together, here’s where the full 6.5s epoch went:

Component                                        | Time
CUDA overhead (table above)                      | ~2.7s
Python/framework overhead (not captured by nsys) | ~3.8s
Actual GPU math                                  | negligible
Total epoch                                      | 6.5s

Essentially all of the 6.5s is overhead. Keras is already very optimized — it uses CUDA Graphs (cuGraphLaunch — 1,875 calls, one per batch) to pre-compile the entire forward+backward pass and replay it without per-op dispatch. But even with that optimization, the GPU never gets a chance to earn back the overhead through faster compute.

For comparison, Keras on CPU runs the same epoch in 4.4s — faster than GPU’s 6.5s, because there’s no CUDA overhead at all. XLA compiles to native code and the CPU does the math directly in its own memory.

It’s important to stress that there’s no surprise when it comes to the actual math — the GPU is significantly faster, as expected. We timed the core operation — a single matrix multiply of (32, 784) @ (784, 128), one batch through the first layer:

Device | Time per matrix multiply
GPU    | 109 μs
CPU    | 272 μs

The GPU is 2.5x faster — but note that even these 109 μs already include the CUDA overhead for launching that single operation. The pure math would be ~0.3 μs (as we calculated from TFLOPS earlier); the other ~108 μs is overhead for that one kernel launch. On CPU, the 272 μs is all math — no overhead layer in between.

A full training step involves multiple matrix multiplies plus activations, loss, gradients, and all the CUDA management around each of them. When we time a compiled train_step with all overhead included, the gap disappears entirely:

Device | Per step
GPU    | 1.27 ms (~0.1 ms compute + ~1.2 ms overhead)
CPU    | 1.30 ms (all compute, no overhead)

The GPU does the math faster but spends the rest of the time on overhead, ending up at roughly the same speed per step. And in the full model.fit() pipeline — with data loading, metrics, and callbacks — the CPU actually wins: 4.4s vs 6.5s per epoch.

One interesting note: if you run nvidia-smi while training, you might see GPU utilization at 90–100%. That looks like the GPU is fully busy — so why is it slower than CPU?

Because nvidia-smi utilization reports the percentage of time at least one kernel is running on the GPU — not how many cores are active. If tiny kernels are being launched back-to-back with no idle gaps, it shows ~100% utilization even though the vast majority of cores are idle at any given moment. As we saw earlier, our model uses just 0.00003% of the GPU’s capacity.
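A toy calculation shows how far the two metrics can diverge (the numbers here are illustrative, not measured):

```python
# nvidia-smi-style "utilization": fraction of time any kernel is resident
window_us, busy_us = 1000, 980           # tiny kernels launched back-to-back
utilization = busy_us / window_us        # 98% — looks "fully busy"

# occupancy: fraction of cores actually doing work at a given moment
active_cores, total_cores = 96, 3072     # roughly 3 of 48 SMs' worth
occupancy = active_cores / total_cores   # ~3% of the hardware
```

Both can be true at once: the GPU is almost never idle, and almost entirely idle.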

Making a single training run faster

Now that we understand the overhead, how do we reduce it? There are several angles of attack: reduce the number of steps (larger batch size), eliminate CUDA overhead (run on CPU), eliminate framework overhead (skip TensorFlow), compile the entire step into one operation (JAX JIT), or naively move to GPU (CuPy — spoiler: it’s worse). We tried all five.

1. Larger batch size

If the overhead is per-step, the obvious fix is: fewer steps. A larger batch size means more samples per step, so the same epoch requires fewer steps — and each step pays the overhead tax only once regardless of batch size. With 48,000 training samples (60K minus 20% validation) and batch size 32, one epoch is 48,000 / 32 = 1,500 steps — each step being one full forward + backward + update pass on one batch. With batch size 4,096, it’s just 12 steps (48,000 / 4,096 ≈ 11.7, rounded up for the final partial batch).
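The step counts follow directly — steps per epoch round up, since the last partial batch still runs:

```python
import math

samples = 48_000
for bs in (32, 128, 1024, 4096):
    steps = math.ceil(samples / bs)   # 32→1500, 128→375, 1024→47, 4096→12
    print(bs, steps)
```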

Larger batches also give the GPU more work per step: at batch size 32 only about 3 of the 48 SMs have work assigned, while batch size 2048 keeps roughly 42 busy. So larger batch sizes help in two ways: fewer overhead-paying steps per epoch, and more GPU cores actually working per step.

However, you can’t increase the batch size indiscriminately — it has a direct impact on model accuracy. Larger batches produce smoother but less frequent gradient updates, which can lead to worse generalization. The right batch size is something you need to experiment with for your specific model.

2. Keras on CPU — skip CUDA overhead

As we saw, just disabling the GPU (tf.config.set_visible_devices([], 'GPU')) eliminates 2.7s of CUDA overhead. XLA compiles to native code and the CPU does the math directly.

3. Pure NumPy — skip framework overhead too

Going further, we can skip TensorFlow entirely. A pure NumPy training loop calls directly into BLAS routines with no framework dispatch overhead between operations.

4. JAX JIT — compile the entire step

JAX’s jit compiler traces the entire training step and compiles it into a single optimized native function. Instead of Python dispatching each operation one by one, the compiled function runs all of them in one fused call with near-zero overhead.
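A minimal illustration of the pattern — a toy linear model, not the MNIST network; shapes and data here are made up:

```python
import jax
import jax.numpy as jnp

@jax.jit                                  # trace once, then replay compiled XLA code
def sgd_step(w, x, y, lr=0.1):
    loss = lambda w: jnp.mean((x @ w - y) ** 2)
    return w - lr * jax.grad(loss)(w)     # gradient + update fused into one call

x = jnp.ones((8, 4))
y = jnp.zeros(8)
w = jnp.ones(4)
w = sgd_step(w, x, y)   # first call compiles; subsequent calls have near-zero dispatch cost
```

The first invocation pays the tracing and compilation cost; every later call with the same shapes reuses the compiled function, which is why JIT pays off over a 1,500-step epoch.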

5. CuPy — what if we just move NumPy to GPU?

We also tried the naive approach: swap import numpy as np for import cupy as cp and run the same code on GPU. Since our implementation processes samples one at a time in a Python loop, each tiny operation (a 128-element vector add, a 10-element softmax) becomes a separate GPU kernel launch. The result: 443 seconds for 5 epochs — over 6x slower than NumPy on CPU. Simply moving code to GPU without rethinking the access pattern makes things worse, not better.
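The same per-call penalty exists on CPU, just much smaller — you can feel the CuPy failure mode in plain NumPy by comparing many tiny calls against one big one (our own demo, not the CuPy benchmark):

```python
import time
import numpy as np

rng = np.random.default_rng(0)
X = rng.standard_normal((4096, 784)).astype(np.float32)
W = rng.standard_normal((784, 128)).astype(np.float32)

t = time.perf_counter()
for i in range(4096):
    _ = X[i] @ W                  # 4,096 tiny calls: dispatch overhead every time
loop_s = time.perf_counter() - t

t = time.perf_counter()
_ = X @ W                         # one batched call: overhead paid once
batch_s = time.perf_counter() - t
```

On CPU the loop is merely slower; on GPU, where each call is a kernel launch, the same access pattern is catastrophic.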

Putting it all together

We benchmarked approaches 1–4 across batch sizes (CuPy was too slow to include). Each cell is the time for one epoch:

Batch size | Keras GPU | Keras CPU | Pure NumPy | JAX (CPU, JIT)
32         | 6.26s     | 11.31s    | 2.71s      | 2.37s
128        | 2.74s     | 2.21s     | 4.70s      | 1.03s
512        | 2.64s     | 1.80s     | 1.30s      | 0.64s
1024       | 2.47s     | 1.43s     | 0.87s      | 0.63s
2048       | 2.32s     | 1.42s     | 0.84s      | 0.48s
4096       | 2.32s     | 1.39s     | 1.03s      | 0.14s

Reading across each row shows the effect of switching approach (less overhead per step). Reading down each column shows the effect of larger batch sizes (fewer steps per epoch). The two effects compound.

Keras GPU bottoms out at ~2.3s regardless of batch size — the CUDA overhead has a fixed floor that larger batches can’t eliminate.

Keras CPU goes from 11.3s down to 1.4s — no CUDA overhead, but TensorFlow’s framework overhead sets its own floor.

Pure NumPy hits 0.84s at batch size 2048 — no framework, just BLAS calls. It slows back down at 4096 due to memory pressure.

JAX reaches 0.14s at batch size 4096 — 45x faster than Keras on GPU. JIT compilation fuses the entire step into one native call with near-zero overhead.

The lesson: the path to faster training isn’t bigger hardware — it’s less overhead. Each approach in the table removes a layer of overhead, and larger batch sizes reduce how many times you pay whatever overhead remains.

Using parallelization for hyperparameter sweeps

The approaches above make a single training run faster. But when you’re doing hyperparameter sweeps — trying 5 learning rates, or 4 architectures — each run is completely independent. They don’t share weights, gradients, or state.

JAX’s vmap (vectorized map) can train all 5 models simultaneously in a single forward and backward pass. This is different from Keras’s batch processing, which runs many images through one model. vmap runs the same images through many models — each with its own weights and learning rate — in one fused operation. Under the hood, XLA compiles this into a single kernel: one matrix multiply handles all 5 forward passes, one handles all 5 backward passes. No Python loop, no per-model overhead.

The tradeoff: all models must have the same architecture (same parameter shapes), and you need to rewrite the training loop in JAX. Here’s the full code:

import jax
import jax.numpy as jnp
from jax import vmap, jit, grad, random

def init_params(key):
    # Same model as before: 784→128→10, random weights
    k1, k2 = random.split(key)
    w1 = random.normal(k1, (784, 128)) * jnp.sqrt(2.0 / 784)
    b1 = jnp.zeros(128)
    w2 = random.normal(k2, (128, 10)) * jnp.sqrt(2.0 / 128)
    b2 = jnp.zeros(10)
    return (w1, b1, w2, b2)

def loss_fn(params, x, y):
    # Forward pass + cross-entropy loss — same math as our NumPy version
    w1, b1, w2, b2 = params
    h = jnp.maximum(0, x @ w1 + b1)       # hidden layer + ReLU
    logits = h @ w2 + b2                    # output layer
    logits = logits - jnp.max(logits, axis=-1, keepdims=True)  # stabilize, like the NumPy version
    log_probs = logits - jnp.log(jnp.sum(jnp.exp(logits), axis=-1, keepdims=True))
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])

def sgd_step(params, x, y, lr):
    # One training step: compute gradients, update weights
    grads = grad(loss_fn)(params, x, y)     # JAX auto-differentiates loss_fn
    return tuple(p - lr * g for p, g in zip(params, grads))

# 5 learning rates, 5 sets of weights, trained simultaneously
lr_array = jnp.array([0.001, 0.01, 0.1, 1.0, 10.0])
keys = random.split(random.PRNGKey(0), 5)
batched_params = vmap(init_params)(keys)   # one parameter set per learning rate

# map over axis 0 of params and lr; broadcast the same batch to all models
batched_step = jit(vmap(sgd_step, in_axes=(0, None, None, 0)))
batched_params = batched_step(batched_params, x_batch, y_batch, lr_array)

For simpler cases, Python’s multiprocessing.Pool achieves the same goal — 5 separate processes, each training one config:

import multiprocessing as mp

def train_one_config(args):
    # Train one model with one learning rate — runs in its own process
    lr, X_train, y_train = args
    model = keras.Sequential([
        keras.layers.Dense(128, activation="relu", input_shape=(784,)),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(optimizer=keras.optimizers.SGD(learning_rate=lr),
                  loss="sparse_categorical_crossentropy")
    history = model.fit(X_train, y_train, epochs=10, batch_size=32,
                        validation_split=0.2, verbose=0)
    return lr, history.history

if __name__ == "__main__":  # required: Pool re-imports this module in child processes
    with mp.Pool(5) as pool:
        results = pool.map(train_one_config,
                           [(lr, X_train, y_train) for lr in [0.001, 0.01, 0.1, 1.0, 10.0]])

What to take away

  1. For small models, the CPU is faster than the GPU. Our MNIST model (784→128→10, ~101K parameters) trained in 4.4s per epoch on CPU vs 6.5s on GPU. The GPU’s math is faster (109 μs vs 272 μs per matrix multiply), but CUDA overhead — kernel launches, memory copies, context switching — adds 2.7s per epoch that the CPU simply doesn’t have.
  2. The GPU is barely working. Our matrix multiply uses 0.00003% of the RTX 5000’s capacity. At batch_size=32, only ~3 of 48 SMs are active. The GPU finishes the math in 0.3 μs, then waits ~540 μs for the next kernel launch.
  3. nvidia-smi utilization is misleading. It reports time-with-any-kernel-running, not core occupancy. Our model shows ~100% utilization while using 0.00003% of the GPU’s compute.
  4. Less overhead beats bigger hardware. By switching from Keras GPU to JAX JIT on CPU with batch_size=4096, we went from 6.5s to 0.14s per epoch — a 45x speedup on the same machine, no GPU needed.
  5. Don’t reach for a GPU until your workload needs one. For models with fewer than ~500K parameters and batch sizes under 256, a fast CPU will be both cheaper and faster.