Why a bigger GPU is not always faster

We ran our MNIST notebook on three machines: a local workstation with a Quadro RTX 5000 GPU, a GCP VM with an NVIDIA L4 GPU (newer, 50% more memory, nearly 3x the raw compute), and a CPU-only GCP VM with 22 fast cores and no GPU at all.

The L4 wasn’t faster than the RTX 5000. And the CPU-only machine? It was 2-3x faster than both GPUs.

That felt wrong. This article is about why it isn’t.

The mental model that breaks

When we think about GPUs, we imagine a pipeline: data goes in, math happens in parallel, results come out. More cores, more TFLOPS → faster training. This model works at scale. It completely breaks down for small workloads.

A modern GPU like the L4 has 7,424 CUDA cores and 240 Tensor cores. Our MNIST model — a 784→128→10 dense network with ~101K parameters — uses a tiny fraction of them. It’s like renting a football stadium to host a dinner for four. The stadium isn’t slower than a restaurant, but the overhead of opening the gates, turning on the lights, and warming up the kitchen dominates the actual meal.
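That parameter count checks out with quick arithmetic: each dense layer contributes (inputs × outputs) weights plus one bias per output.

```python
# Parameter count for the 784 -> 128 -> 10 dense network:
# each layer has (n_in * n_out) weights plus n_out biases.
layers = [(784, 128), (128, 10)]
params = sum(n_in * n_out + n_out for n_in, n_out in layers)
print(params)  # 101770 -- the "~101K parameters" above
```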

Where the time actually goes

Training a neural network on a GPU involves more than matrix multiplication. Every training step is a sequence of operations, and for small models, the actual math is the smallest part:

1. Kernel launch overhead

Every GPU operation — a matrix multiply, an activation function, a loss computation — is a kernel that the CPU must schedule and launch on the GPU. Each launch has a fixed overhead of roughly 5–15 microseconds. For a large matrix multiply that takes milliseconds, this is negligible. For our tiny 32×784 times 784×128 multiply that completes in microseconds, the launch overhead can exceed the computation itself.

Our forward pass alone involves: a matrix multiply, a bias add, a ReLU, another matrix multiply, another bias add, and a softmax. That’s at least 6 kernel launches before we even start backpropagation.

2. Memory transfer latency

Data must travel from CPU memory to GPU memory (and gradients must travel back). This transfer has a fixed latency — the time it takes to set up the DMA transfer, cross the PCIe bus, and signal completion. For large batches, this latency is amortized across many samples. For batch size 32 with 784-dimensional inputs, we’re transferring about 100 KB per batch. The PCIe bus can move 32 GB/s, so the raw transfer is ~3 microseconds — but the setup overhead is 10–20x that.
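The arithmetic behind those figures, as a sketch (a batch of 32 float32 vectors of length 784, and a nominal 32 GB/s PCIe link; illustrative numbers, not measurements):

```python
batch, dim, bytes_per_float = 32, 784, 4
payload = batch * dim * bytes_per_float        # bytes per batch
print(payload)                                 # 100352 -> roughly 100 KB

pcie_bytes_per_s = 32e9                        # nominal one-direction PCIe bandwidth
raw_transfer_us = payload / pcie_bytes_per_s * 1e6
print(round(raw_transfer_us, 1))               # 3.1 -> microseconds of raw transfer
```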

3. Python and framework overhead

Keras/TensorFlow add their own layer of indirection. Each operation passes through Python, the TensorFlow runtime, XLA compilation (on first run), memory allocation, and synchronization. For large operations, this overhead is invisible. For small operations, it is the bottleneck.

4. GPU clock speeds and latency

Here’s a counterintuitive detail: workstation GPUs like the RTX 5000 often clock as high as or higher than datacenter GPUs, which are frequently down-clocked for thermal and power reasons. The RTX 5000 boosts to ~1,815 MHz and the L4 to ~2,040 MHz (similar in this case, but many datacenter GPUs clock notably lower). When a kernel is so small that it runs on a handful of cores, the clock speed of those cores matters more than the total core count.

The misleading utilization number

If you run nvidia-smi while training, you might see GPU utilization at 90–100%. That looks like the GPU is fully busy — so why isn’t the faster GPU faster?

Because nvidia-smi utilization doesn’t measure what you think it does. It reports the percentage of time at least one kernel is running on the GPU — not how many cores are active. If tiny kernels are being launched back-to-back with no idle gaps between them, it shows ~100% utilization even though the vast majority of cores are idle at any given moment.

It’s like measuring how busy a stadium is by checking “was at least one person inside during each minute?” — it’ll report 100% occupancy even if only 4 people are having dinner on the field.

What actually matters is compute throughput — how much of the GPU’s theoretical FLOPS are being delivered. Think of GPU cores as lanes on a highway:

RTX 5000 (3,072 cores):
[##...............................] ~5% compute throughput
 ↑ actual work

L4 (7,424 cores):
[#................................................................] ~2% compute throughput
 ↑ same work, even more idle cores

Both GPUs finish the actual math in roughly the same time — the matrix multiplications are too small to distribute across more than a few dozen cores effectively. The L4 has more cores sitting idle, but idle cores don’t slow things down. What does differ is everything around the math: driver overhead, memory subsystem behavior, framework initialization, and thermal characteristics.

To see the real picture, you’d need to profile with NVIDIA Nsight Systems or at minimum use nvidia-smi dmon -s u to check SM (streaming multiprocessor) occupancy. For our MNIST model, that number would be very low on both GPUs — despite the headline utilization saying 100%.

The real numbers

We ran the same notebook on three machines to see how they compare:

| Machine | Type | Specs | Cost |
| --- | --- | --- | --- |
| Local workstation | GPU | Quadro RTX 5000, 16 GB, ~11 TFLOPS | |
| GCP g2-standard-4 | GPU | NVIDIA L4, 24 GB, ~30 TFLOPS | ~$0.70/hr |
| GCP c3-highcpu-22 | CPU only | 22 vCPUs, no GPU | ~$0.77/hr |

Here’s how long each notebook experiment took:

| Experiment | L4 GPU | CPU-only (c3-highcpu-22) |
| --- | --- | --- |
| Learning rate sweep (4 models × 3 epochs) | 1m 03s | 28s |
| Batch size sweep (bs=1, 32, 256 × 3 epochs) | 7m 03s | 2m 44s |
| Network depth sweep (4 architectures × 3 epochs) | 1m 16s | 31s |

The CPU-only machine is 2-3x faster across every single experiment. Not marginally — decisively. Let’s look at why.

Walking through the notebook experiments

Learning rate sweep — 1m 03s (GPU) vs 28s (CPU)

for lr in [0.001, 0.01, 0.1, 1.0]:
    model = keras.Sequential([
        keras.layers.Dense(128, activation="relu", input_shape=(784,)),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.fit(X_train, y_train, epochs=3, batch_size=32, ...)

This trains 4 identical small models sequentially. Each step involves a 32×784 × 784×128 matrix multiply — about 3.2 million multiply-adds, or 6.4 million floating-point operations. The L4 can do 30 trillion operations per second, so the actual math takes roughly 0.2 microseconds. But each step also requires:

  • CPU schedules and launches ~12+ GPU kernels (forward: matmul, bias, ReLU, matmul, bias, softmax; backward: the same in reverse, plus gradient updates)
  • Each kernel launch costs ~5–15μs of overhead
  • Each batch transfer from CPU→GPU memory costs ~30–50μs of setup

That’s roughly 200μs of overhead per step for 0.2μs of useful math. The GPU spends 99.9% of its time waiting. With 60,000 training images at batch size 32, that’s 1,875 steps per epoch × 3 epochs × 4 models = 22,500 steps — each one paying this overhead tax.
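Totaling that tax across the sweep, with ~200μs as a rough per-step constant:

```python
steps_per_epoch = 60_000 // 32           # 1875 steps per epoch at batch size 32
total_steps = steps_per_epoch * 3 * 4    # 3 epochs x 4 models
print(total_steps)                       # 22500

overhead_s = total_steps * 200e-6        # ~200 microseconds of overhead per step
print(overhead_s)                        # 4.5 -> seconds of pure overhead
```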

On the CPU, there’s no transfer, no kernel launch, no synchronization. The CPU just does the matrix multiply directly in its own memory. The math takes longer per operation (CPUs have far fewer cores), but for matrices this small, the overhead savings dwarf the compute difference.

Batch size sweep — 7m 03s (GPU) vs 2m 44s (CPU)

for bs in [1, 32, 256]:
    model.fit(X_train, y_train, epochs=3, batch_size=bs, ...)

This is the most dramatic result, and the bs=1 configuration is the reason. Look at what batch size does to the step count:

| Batch size | Matrix multiply | Ops per step | Steps per epoch | Steps total (3 epochs) |
| --- | --- | --- | --- | --- |
| 1 | 1×784 × 784×128 | ~200K | 60,000 | 180,000 |
| 32 | 32×784 × 784×128 | ~6.4M | 1,875 | 5,625 |
| 256 | 256×784 × 784×128 | ~51M | 235 | 705 |
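The step counts fall out of a one-liner (ceiling division, since a final partial batch still counts as a step):

```python
import math

n_samples, epochs = 60_000, 3
for bs in (1, 32, 256):
    steps_per_epoch = math.ceil(n_samples / bs)
    print(bs, steps_per_epoch, steps_per_epoch * epochs)
# 1   -> 60000 steps/epoch -> 180000 total
# 32  ->  1875 steps/epoch ->   5625 total
# 256 ->   235 steps/epoch ->    705 total
```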

With batch size 1, each step is a single vector-matrix multiply — 200,000 operations, which is trivial for any processor. But on the GPU, each of those 180,000 steps still pays the full kernel launch and memory transfer overhead. At ~200μs overhead per step, that’s 36 seconds of pure overhead — just for scheduling, before any math happens. And the “math” itself takes nanoseconds per step.

On the CPU, batch size 1 is also slow (many steps means many Python loop iterations), but there’s no GPU round-trip per step. The CPU just reads from its own cache and computes.

The bs=256 case is less extreme (only 705 steps), but the matrix multiplies at 51 million operations are still too small to saturate the GPU — the L4 would need matrices 100-1000x larger to keep its 7,424 cores busy.

Network depth sweep — 1m 16s (GPU) vs 31s (CPU)

arch_configs = {
    "128": [128],           # 101K params
    "256": [256],           # 201K params
    "128→64": [128, 64],    # 109K params
    "256→128→64": [256, 128, 64],  # 234K params
}

Deeper networks make the GPU overhead worse, not better. Each additional layer adds more kernel launches per step — the 256→128→64 network needs ~18+ kernel launches per step (3 layers × forward + backward, each with matmul, bias, activation). More kernels per step means more overhead per step, while the individual matrix multiplies remain tiny.

The largest architecture here (256→128→64) has 234K parameters — about 1 MB of weights. Modern GPUs are designed for models with millions to billions of parameters. Even the “big” model in this experiment occupies less than 0.005% of the L4’s 24 GB memory.
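A quick check of that memory claim, using the ~234K parameter figure and float32 weights:

```python
params = 234_000
weight_bytes = params * 4                   # 4 bytes per float32 weight
print(weight_bytes / 1e6)                   # 0.936 -> just under 1 MB of weights
print(weight_bytes / (24 * 1e9) * 100)      # ~0.004% of the L4's 24 GB
```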

When would the GPU actually win?

For these experiments, the GPU would need fundamentally larger workloads:

| Change | Why it helps |
| --- | --- |
| Scale to 784→1024→512→256→10 (~1.4M params) | Larger matrices, more computation per step |
| Increase batch size to 512 or 1024 | Bigger matrix multiplies, better core utilization |
| Switch to convolutional layers on 28×28 images | Convolutions are more compute-dense than dense layers |
| Use a pretrained model like ResNet-18 (~11M params) | Enough parameters to actually occupy the GPU |

The crossover point is roughly: if your model has fewer than ~500K parameters and your batch size is under 256, a fast CPU will beat a GPU. Not because the GPU is slow at math — but because the GPU never gets to do any math. It’s all overhead.
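As a sketch, that rule of thumb could be encoded as a check (the function name and thresholds here are ours, taken from the rough figures above; tune them for your own hardware):

```python
def prefer_cpu(n_params: int, batch_size: int) -> bool:
    """Rough heuristic from this article: for tiny models with small
    batches, per-step overhead dominates and a fast CPU likely wins."""
    return n_params < 500_000 and batch_size < 256

print(prefer_cpu(101_770, 32))      # True  -- our MNIST setup
print(prefer_cpu(11_000_000, 512))  # False -- a ResNet-18-sized workload
```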

The three bottlenecks

It’s tempting to think of training speed as a single number determined by your hardware. In reality, there are three independent bottlenecks stacked on top of each other, and for small models, none of them are “how fast the GPU can multiply matrices.”

1. GPU overhead

Every training step pays a fixed tax: the CPU must schedule kernel launches (~5–15μs each), transfer data to GPU memory (~30–50μs setup), and synchronize after each operation. For our MNIST model, the actual matrix math takes ~0.2μs per step, but the overhead around it takes ~200μs. That’s a 1000:1 ratio of waiting to working.

This is why the CPU-only machine was 2-3x faster than the L4 GPU. On CPU, there’s no transfer, no kernel dispatch, no synchronization — the processor just reads from its own cache and computes. The math takes longer per operation (CPUs have far fewer parallel cores), but for matrices this small, eliminating the overhead more than compensates.

When this bottleneck matters: small models (<500K parameters), small batch sizes (<256), any workload where individual operations complete in microseconds.

2. Backend: TensorFlow vs JAX

Keras can run on different backends, and the choice has a surprising impact on overhead — even though the math is identical.

TensorFlow (the default backend) executes operations eagerly: each op — matrix multiply, bias add, ReLU, softmax — goes individually through Python → TF runtime → kernel launch. For our forward pass alone, that’s 6+ separate dispatches, each with its own overhead.

JAX takes a different approach. Its XLA compiler traces the entire computation graph and compiles it into a single fused kernel. Instead of 6 separate kernel launches, XLA can merge them into one. The first run is slower (compilation takes time), but subsequent steps are dramatically faster because most of the per-op overhead disappears.

For a model with 12+ kernels per step running 180,000 steps (the bs=1 case), the difference between “12 kernel launches per step” and “1 fused kernel per step” is enormous — potentially an order of magnitude less overhead.
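Putting numbers on that, with a nominal 10μs per launch (the midpoint of the 5–15μs range above):

```python
steps = 180_000            # the bs=1 case: 3 epochs at batch size 1
launch_us = 10             # nominal cost per kernel launch

eager_s = steps * 12 * launch_us / 1e6   # ~12 separate launches per step
fused_s = steps * 1 * launch_us / 1e6    # one fused XLA kernel per step
print(eager_s, fused_s)                  # 21.6 vs 1.8 seconds of launch overhead alone
```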

When this bottleneck matters: whenever per-step overhead is the dominant cost — which is exactly our situation. Switching backends doesn’t change the math, but it changes how many times you pay the dispatch tax per step.

3. Python single-thread speed

The training loop itself — model.fit() — is single-threaded Python. Each of the 180,000 steps (at bs=1) involves Python iteration, framework dispatch, metric bookkeeping, and callback handling. None of this benefits from more CPU cores — only clock speed and instructions-per-cycle matter.

This is why moving from a c3-highcpu-22 (22 cores) to a newer c4-highcpu-4 (4 cores, slightly higher clock) wouldn’t help much: you’d gain ~5% from the clock bump, but the extra 18 cores were never being used anyway.

When this bottleneck matters: high step counts (small batch sizes), many sequential model.fit() calls (hyperparameter sweeps), and any training loop with per-step Python callbacks.

How the bottlenecks interact

For our MNIST notebook, all three bottlenecks compound:

Each training step:
  Python loop overhead          ~10μs  ← bottleneck #3
  + Framework dispatch (TF)     ~50μs  ← bottleneck #2
  + GPU kernel launch + sync   ~150μs  ← bottleneck #1
  + Actual matrix math           ~0.2μs ← the "work"
  ≈ 210μs total, 99.9% overhead

A faster GPU only helps the 0.2μs. Switching to JAX could cut the 50μs + 150μs significantly by fusing operations. But the ~10μs of Python overhead per step is a hard floor — short of rewriting the training loop in C++, nothing eliminates it.
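The same budget in code, using the illustrative figures from the breakdown above:

```python
python_us, framework_us, gpu_us, math_us = 10, 50, 150, 0.2
step_us = python_us + framework_us + gpu_us + math_us
print(step_us)                            # 210.2 -> microseconds per step

overhead_frac = (step_us - math_us) / step_us
print(round(overhead_frac * 100, 1))      # 99.9 -> percent of each step that is overhead

# A 3x-faster GPU only shrinks math_us: 210.2 -> ~210.1 us. Negligible.
# Fusing ops attacks framework_us + gpu_us, the two largest terms.
```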

The practical takeaway: for small models, the path to faster training goes through reducing overhead (fewer steps, fused operations, CPU execution) — not through bigger hardware.

Try it yourself: running the notebook on a cloud GPU

If you want to reproduce these experiments on a GPU — or scale them up to see where the faster hardware actually wins — you can spin up a GCP VM with an NVIDIA L4 in a few commands. The whole process takes about 2 minutes. You’ll need the gcloud CLI installed and a GCP project with billing enabled.

1. Create the VM

gcloud compute instances create marimo-gpu \
  --zone=us-central1-a \
  --machine-type=g2-standard-4 \
  --accelerator=type=nvidia-l4,count=1 \
  --image-family=pytorch-2-7-cu128-ubuntu-2204-nvidia-570 \
  --image-project=deeplearning-platform-release \
  --maintenance-policy=TERMINATE \
  --boot-disk-size=200GB

This creates a g2-standard-4 instance (4 vCPUs, 16 GB RAM, 1x NVIDIA L4) with a deep learning image that has CUDA, cuDNN, and Python pre-installed. It costs about $0.70/hr — remember to delete it when you’re done.

2. Copy the notebook and install dependencies

gcloud compute scp mnist-training.py marimo-gpu:~ --zone=us-central1-a

gcloud compute ssh marimo-gpu --zone=us-central1-a \
  --command='pip install marimo'

The deep learning image already has CUDA drivers, Python, and NumPy. Marimo will detect and offer to install any missing packages (like Keras, TensorFlow, matplotlib) when you open the notebook.

3. Start marimo

gcloud compute ssh marimo-gpu --zone=us-central1-a \
  --command='nohup marimo edit ~/mnist-training.py \
    --host 0.0.0.0 --port 2718 &>/tmp/marimo.log &'

4. Open it in your browser

You have two options:

Option A: SSH tunnel (no firewall changes needed)

gcloud compute ssh marimo-gpu --zone=us-central1-a \
  -- -L 2718:localhost:2718 -N

Then open http://localhost:2718 in your browser.

Option B: Direct access (expose the port, locked to your IP)

# Get your public IP
MY_IP=$(curl -s ifconfig.me)

# Create a firewall rule restricted to your IP
gcloud compute firewall-rules create allow-marimo \
  --allow=tcp:2718 \
  --source-ranges=$MY_IP/32 \
  --target-tags=marimo

# Tag the VM
gcloud compute instances add-tags marimo-gpu \
  --zone=us-central1-a --tags=marimo

Then open http://<VM_EXTERNAL_IP>:2718. You can find the external IP in the output from step 1, or run gcloud compute instances describe marimo-gpu --zone=us-central1-a --format='value(networkInterfaces[0].accessConfigs[0].natIP)'.

5. Clean up when done

gcloud compute instances delete marimo-gpu --zone=us-central1-a
gcloud compute firewall-rules delete allow-marimo  # if you created one

Don’t forget this step — the VM charges by the minute whether you’re using it or not.

Running experiments faster with parallelization

So bigger hardware doesn’t help — but these experiments are still slow. The learning rate sweep takes 28 seconds on CPU, and the batch size sweep takes nearly 3 minutes. When you’re iterating on hyperparameters, that adds up. Can we do better without upgrading hardware?

Yes — because each learning rate config is completely independent. They don’t share weights, gradients, or state, so we can run them in parallel. (For this section we extend the sweep to 5 configs, adding lr=10.0, and 10 epochs.) There are two practical approaches, and they attack different bottlenecks from the list above.

Multiprocessing: run each config on its own CPU core

The simplest approach: spawn one process per config using Python’s multiprocessing.Pool. Each worker trains its own model independently, and the parent process collects results when everyone finishes.

import multiprocessing as mp

def train_one_config(args):
    lr, X_train, y_train = args
    model = keras.Sequential([
        keras.layers.Dense(128, activation="relu", input_shape=(784,)),
        keras.layers.Dense(10, activation="softmax"),
    ])
    model.compile(
        optimizer=keras.optimizers.SGD(learning_rate=lr),
        loss="sparse_categorical_crossentropy",
        metrics=["accuracy"],
    )
    history = model.fit(X_train, y_train, epochs=10, batch_size=32,
                        validation_split=0.2, verbose=0)
    return lr, history.history

with mp.Pool(5) as pool:
    results = pool.map(train_one_config,
                       [(lr, X_train, y_train) for lr in [0.001, 0.01, 0.1, 1.0, 10.0]])

With 5 workers on 5 cores, the entire sweep runs in roughly the time of one config — about a 4x speedup. The speedup isn’t a perfect 5x because of process creation overhead and OS scheduling, but it’s close.
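A rough model of why it isn’t a perfect 5x: configs run in “waves” of at most one per worker, so the ideal speedup depends on how evenly configs divide across workers, and process startup plus OS scheduling erode it further. This helper is a hypothetical sketch, not from the notebook:

```python
import math

def ideal_speedup(n_configs: int, n_workers: int) -> float:
    # Configs run in waves of up to n_workers at a time;
    # wall-clock is proportional to the number of waves.
    return n_configs / math.ceil(n_configs / n_workers)

print(ideal_speedup(5, 5))  # 5.0 -- the measured ~4x falls short of this ceiling
print(ideal_speedup(5, 4))  # 2.5 -- one straggler wave halves the win
```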

This attacks bottleneck #3 (Python single-thread speed) by side-stepping it entirely — each config gets its own Python process. But each worker still pays the full TensorFlow dispatch overhead per step (bottleneck #2).

JAX vmap: train all models in one vectorized pass

A more elegant approach uses JAX’s vmap (vectorized map) to train all 5 models simultaneously in a single forward and backward pass. Instead of a Python loop over configs, we stack all 5 sets of parameters into batched arrays and let vmap apply the same update step across all of them at once.

import jax
import jax.numpy as jnp
from jax import vmap, jit, grad, random

def init_params(key):
    k1, k2 = random.split(key)
    w1 = random.normal(k1, (784, 128)) * jnp.sqrt(2.0 / 784)
    b1 = jnp.zeros(128)
    w2 = random.normal(k2, (128, 10)) * jnp.sqrt(2.0 / 128)
    b2 = jnp.zeros(10)
    return (w1, b1, w2, b2)

def forward(params, x):
    w1, b1, w2, b2 = params
    h = jnp.maximum(0, x @ w1 + b1)  # ReLU
    return h @ w2 + b2               # logits

def loss_fn(params, x, y):
    logits = forward(params, x)
    log_probs = logits - jnp.log(jnp.sum(jnp.exp(logits), axis=-1, keepdims=True))
    return -jnp.mean(log_probs[jnp.arange(y.shape[0]), y])

def sgd_step(params, x, y, lr):
    grads = grad(loss_fn)(params, x, y)
    return tuple(p - lr * g for p, g in zip(params, grads))

# vmap over params and lr, broadcast x and y
batched_step = jit(vmap(sgd_step, in_axes=(0, None, None, 0)))

# Stack 5 sets of parameters and learning rates
keys = random.split(random.PRNGKey(42), 5)
batched_params = tuple(
    jnp.stack([init_params(k)[i] for k in keys])
    for i in range(4)
)
lr_array = jnp.array([0.001, 0.01, 0.1, 1.0, 10.0])

# Each call trains all 5 models on one batch simultaneously
batched_params = batched_step(batched_params, x_batch, y_batch, lr_array)

The key insight: vmap maps the update function over the parameter axis (each model has its own weights) and the learning rate axis, while broadcasting the training batch across all models — they all see the same data. Under the hood, XLA compiles this into a single fused kernel: one matrix multiply handles all 5 forward passes, one handles all 5 backward passes. There’s no Python loop, no per-model kernel launch overhead, and the memory layout is optimized for the hardware.
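The shape mechanics can be sketched in plain NumPy (a toy analogue of what vmap sets up, not the JAX code itself): stack the per-model weight matrices along a leading axis, and a single einsum computes all five forward passes at once.

```python
import numpy as np

n_models, batch, d_in, d_hidden = 5, 32, 784, 128
W1 = np.random.randn(n_models, d_in, d_hidden)   # 5 stacked first-layer weights
x = np.random.randn(batch, d_in)                 # one batch, shared by all models

# One einsum = 5 forward matmuls; this is the large, fusable workload
# that vmap hands to XLA instead of 5 tiny separate kernels.
h = np.einsum("bi,mio->mbo", x, W1)
print(h.shape)  # (5, 32, 128): per-model hidden pre-activations

# Equivalent to looping over models, but in one operation:
assert np.allclose(h[2], x @ W1[2])
```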

This attacks all three bottlenecks at once. Bottleneck #1 (GPU/kernel overhead): XLA fuses operations across models, so the hardware gets one large workload instead of many tiny ones. Bottleneck #2 (framework dispatch): JAX’s JIT compilation eliminates per-op Python→runtime dispatch. Bottleneck #3 (Python single-thread): there’s one compiled step per batch for all 5 models, not five separate Python iterations.

The tradeoff: all models must have the same architecture (same parameter shapes), and you need to rewrite the training loop in JAX. Multiprocessing has no such restriction.

Timing comparison

Here’s how the four approaches compare on the learning rate sweep (5 configs × 10 epochs, 8-core CPU):

[Chart: Wall-Clock Time — Learning Rate Sweep (5 configs × 10 epochs)]

Multiprocessing gives an easy ~4x win with zero code changes to the training logic. JAX’s vmap goes further — by eliminating Python overhead and fusing operations, it delivers nearly 6x speedup over the Keras baseline.

For our MNIST model, this brings a 28-second sweep down to under 5 seconds — fast enough for interactive exploration. The same principles apply at larger scale: hyperparameter sweeps are embarrassingly parallel, and the choice between process-level parallelism (multiprocessing, distributed training) and model-level vectorization (vmap) depends on whether your configs share the same architecture.

What to take away

  1. TFLOPS are a ceiling, not a floor. Your workload has to be large enough to approach that ceiling. Our notebook used ~0.001% of the L4’s theoretical throughput.
  2. Small models are bottlenecked by overhead, not compute. Kernel launches, memory transfers, and framework overhead dominate. A CPU-only machine was 2-3x faster than an L4 GPU for every experiment in our notebook.
  3. nvidia-smi utilization is misleading. It shows time-with-any-kernel-running, not core occupancy. A GPU can report 100% utilization while delivering 2% of its theoretical throughput. Use Nsight profiling or nvidia-smi dmon -s u to see actual SM occupancy.
  4. Match your hardware to your workload. For small models and prototyping, a fast CPU is not just cheaper — it’s genuinely faster. Save GPUs for models with millions of parameters.
  5. When you do use a GPU, feed it enough work. Larger batch sizes, larger models, or mixed-precision training can all help saturate the hardware. If your matrix multiplies complete in microseconds, the GPU never gets a chance to show what it can do.

Our MNIST notebook is a teaching tool — it’s designed to make concepts clear, not to stress hardware. The fact that a $0.77/hr CPU machine beats a $0.70/hr GPU machine on every experiment is the most practical lesson here: don’t reach for a GPU until your workload actually needs one.