ML in the browser: running and training models with Transformers.js and TensorFlow.js
Machine learning usually means Python, GPUs, and a server somewhere. But two JavaScript libraries let you skip all of that and run models directly in the browser tab your user already has open.
Transformers.js v4 loads pretrained Hugging Face models and runs inference using WebGPU or WASM — sentiment analysis, text generation, embeddings, all client-side. TensorFlow.js goes further: you can define a model architecture, train it on real data, and run predictions — all without leaving the browser.
This article walks through both. We’ll build two standalone demos: one that runs pretrained models with Transformers.js, and one that trains a convolutional neural network on MNIST with TensorFlow.js.
Part 1: Inference with Transformers.js v4
Transformers.js v4 shipped with a WebGPU runtime rewritten in C++, 10x faster build times, and support for ~15 new architectures including DeepSeek-V3 and Qwen3. The API stays simple: pick a task, point it at a model, and call it.
Setup
npm create vite@latest transformers-demo -- --template vanilla
cd transformers-demo
npm i @huggingface/transformers
Running a pipeline
The core abstraction is pipeline — you specify a task and a model, and it handles downloading, caching, and inference:
import { pipeline } from "@huggingface/transformers";
const classifier = await pipeline(
  "sentiment-analysis",
  "Xenova/distilbert-base-uncased-finetuned-sst-2-english"
);
const result = await classifier("I love how fast this is!");
// [{ label: "POSITIVE", score: 0.9998 }]
The first call downloads the model (~67MB for DistilBERT) and caches it in the browser’s Cache API. Subsequent loads are instant.
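That first download is noticeable on slow connections, so it's worth surfacing progress in the UI. The pipeline factory accepts a progress_callback option; a minimal sketch (the exact event fields can vary between versions):
const classifier = await pipeline(
  "sentiment-analysis",
  "Xenova/distilbert-base-uncased-finetuned-sst-2-english",
  {
    // Fires repeatedly while model files download
    progress_callback: (event) => {
      if (event.status === "progress") {
        console.log(`${event.file}: ${Math.round(event.progress)}%`);
      }
    },
  }
);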
Multiple tasks
The same pattern works across tasks. Text generation:
const generator = await pipeline("text-generation", "Xenova/gpt2");
const output = await generator("The future of AI is", {
  max_new_tokens: 50,
});
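For longer generations you usually want tokens as they are produced, not one blob at the end. Transformers.js provides a TextStreamer for this; a sketch, where outputEl is a hypothetical element in your page:
import { pipeline, TextStreamer } from "@huggingface/transformers";

const generator = await pipeline("text-generation", "Xenova/gpt2");
const streamer = new TextStreamer(generator.tokenizer, {
  skip_prompt: true,
  // Hypothetical UI hook: append each decoded chunk as it arrives
  callback_function: (text) => outputEl.append(text),
});
await generator("The future of AI is", { max_new_tokens: 50, streamer });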
Summarization:
const summarizer = await pipeline(
  "summarization",
  "Xenova/distilbart-cnn-6-6"
);
const summary = await summarizer(longText, { max_new_tokens: 60 });
Feature extraction (embeddings):
const extractor = await pipeline(
  "feature-extraction",
  "Xenova/all-MiniLM-L6-v2"
);
const embeddings = await extractor("Some text to embed");
// Tensor of shape [1, tokens, 384]
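The raw output is one embedding per token. For sentence-level embeddings, the pipeline's pooling and normalize options collapse that to a single unit-length vector, which makes cosine similarity a plain dot product:
const a = await extractor("How do I cook pasta?", { pooling: "mean", normalize: true });
const b = await extractor("Instructions for boiling noodles", { pooling: "mean", normalize: true });

// Normalized vectors: cosine similarity reduces to a dot product
let sim = 0;
for (let i = 0; i < a.data.length; i++) sim += a.data[i] * b.data[i];
console.log(sim.toFixed(3)); // closer to 1 means more similar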
WebGPU detection
Transformers.js v4 automatically uses WebGPU when available and falls back to WASM. You can check what’s available:
if (navigator.gpu) {
  const adapter = await navigator.gpu.requestAdapter();
  if (adapter) {
    console.log("WebGPU available — hardware acceleration enabled");
  }
} else {
  console.log("No WebGPU; Transformers.js will fall back to WASM");
}
In our demo app, we show a badge indicating whether the browser is using WebGPU or WASM, so the user knows what backend is running.
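You can also pin the backend instead of relying on auto-detection by passing a device option to pipeline. A sketch of how the badge logic might choose and report one (the device option is documented; the badge element is a hypothetical piece of our markup):
const device = navigator.gpu ? "webgpu" : "wasm";
const classifier = await pipeline(
  "sentiment-analysis",
  "Xenova/distilbert-base-uncased-finetuned-sst-2-english",
  { device }
);
// Hypothetical badge element in the demo's markup
document.querySelector("#backend-badge").textContent = device.toUpperCase();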
What this is good for
Client-side inference makes sense when:
- Privacy matters — data never leaves the device. Useful for text that users wouldn’t want sent to a server.
- Latency matters — no network round-trip. Once the model is cached, inference starts immediately.
- Cost matters — the user’s device does the compute. No GPU servers to provision.
- Offline matters — with cached models and WASM, everything works without a connection.
The tradeoff is model size. Browser-friendly models are typically quantized and distilled, so they won’t match a 70B parameter model running on a server. But for many tasks — classification, embeddings, short generation — the quality is more than sufficient.
Part 2: Training MNIST with TensorFlow.js
Running pretrained models is one thing. Training from scratch is another — and TensorFlow.js can do that too, entirely in the browser.
Our demo has an editable code editor where you define the model architecture in JavaScript, then train it on MNIST and test it by drawing digits. The full pipeline — data loading, model definition, training loop with live metrics, and interactive inference — runs client-side.
Setup
npm create vite@latest tfjs-mnist -- --template vanilla
cd tfjs-mnist
npm i @tensorflow/tfjs
Loading MNIST
MNIST is the classic dataset of 28×28 grayscale images of handwritten digits (0–9). The TensorFlow.js team hosts a browser-friendly version of 65,000 images as a single PNG sprite sheet plus a binary labels file:
import * as tf from "@tensorflow/tfjs";
const MNIST_IMAGES_URL =
  "https://storage.googleapis.com/learnjs-data/model-builder/mnist_images.png";
const MNIST_LABELS_URL =
  "https://storage.googleapis.com/learnjs-data/model-builder/mnist_labels_uint8";
The sprite sheet packs all 65,000 images into one PNG, one flattened 784-pixel image per row. We draw it to a canvas, read the pixel data, and extract each 28×28 image as a normalized float array:
const img = new Image();
img.crossOrigin = "";
img.onload = () => {
  const canvas = document.createElement("canvas");
  canvas.width = img.width; // 784: each row of the sprite is one flattened image
  canvas.height = img.height; // 65,000 rows, one per image
  const ctx = canvas.getContext("2d");
  ctx.drawImage(img, 0, 0);
  const imageData = ctx.getImageData(0, 0, canvas.width, canvas.height);
  const dataset = new Float32Array(65000 * 784);
  for (let i = 0; i < 65000; i++) {
    for (let j = 0; j < 784; j++) {
      // The sprite is grayscale (R = G = B), so read just the red channel
      dataset[i * 784 + j] = imageData.data[(i * 784 + j) * 4] / 255;
    }
  }
};
img.src = MNIST_IMAGES_URL; // set src after attaching onload
We split this into 55,000 training and 10,000 test images, reshape into [N, 28, 28, 1] tensors, and one-hot encode the labels.
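That split, reshape, and one-hot step is a few lines of tensor plumbing. A sketch, assuming the dataset array from above plus a labelIndices Uint8Array of class ids (0–9), which is an assumption about how you decode the labels file:
const NUM_TRAIN = 55000;
const NUM_TEST = 10000;

// Slice the flat pixel buffer and reshape to [N, 28, 28, 1]
const trainXs = tf.tensor4d(dataset.slice(0, NUM_TRAIN * 784), [NUM_TRAIN, 28, 28, 1]);
const testXs = tf.tensor4d(dataset.slice(NUM_TRAIN * 784), [NUM_TEST, 28, 28, 1]);

// One-hot encode integer class ids into [N, 10]
const trainLabels = tf.oneHot(tf.tensor1d(labelIndices.slice(0, NUM_TRAIN), "int32"), 10);
const testLabels = tf.oneHot(tf.tensor1d(labelIndices.slice(NUM_TRAIN), "int32"), 10);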
The model
Instead of hardcoding an architecture, the demo has an editable code editor where you write your model definition. The code receives tf as a parameter and must return a compiled model. The default is a small CNN:
const model = tf.sequential();
model.add(tf.layers.conv2d({
  inputShape: [28, 28, 1],
  filters: 16,
  kernelSize: 3,
  activation: "relu",
}));
model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
model.add(tf.layers.conv2d({
  filters: 32,
  kernelSize: 3,
  activation: "relu",
}));
model.add(tf.layers.maxPooling2d({ poolSize: 2 }));
model.add(tf.layers.flatten());
model.add(tf.layers.dense({ units: 64, activation: "relu" }));
model.add(tf.layers.dropout({ rate: 0.25 }));
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));

model.compile({
  optimizer: tf.train.adam(0.01),
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});

return model;
But you can replace it with anything — a simple dense network, a deeper CNN, different optimizers. The only constraints are that the input shape must be [28, 28, 1] and the output must be 10 classes with softmax. Try replacing the CNN with a pure dense network to see how much accuracy you lose:
const model = tf.sequential();
model.add(tf.layers.flatten({ inputShape: [28, 28, 1] }));
model.add(tf.layers.dense({ units: 128, activation: "relu" }));
model.add(tf.layers.dense({ units: 10, activation: "softmax" }));
model.compile({
  optimizer: tf.train.adam(0.01),
  loss: "categoricalCrossentropy",
  metrics: ["accuracy"],
});
return model;
Under the hood, we pass the user’s code to new Function("tf", code) and call it with the TensorFlow.js module. If the code throws or doesn’t return a valid model, the error is shown inline.
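That evaluation step, with inline error handling, might look like this sketch (buildModel and the error element are illustrative names, not the demo's actual code):
function buildModel(code) {
  try {
    // The editor code runs with `tf` in scope and must return a compiled model
    const factory = new Function("tf", code);
    const model = factory(tf);
    if (!model || typeof model.fit !== "function") {
      throw new Error("Code must return a compiled tf.LayersModel");
    }
    return model;
  } catch (err) {
    // Hypothetical inline error display
    document.querySelector("#editor-error").textContent = err.message;
    return null;
  }
}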
Training
TensorFlow.js model.fit works like Keras. The onEpochEnd callback gives us live metrics:
await model.fit(trainXs, trainLabels, {
  epochs: 5,
  batchSize: 128,
  validationData: [testXs, testLabels],
  callbacks: {
    onEpochEnd: (epoch, logs) => {
      console.log(
        `Epoch ${epoch + 1}: loss=${logs.loss.toFixed(4)} ` +
        `acc=${logs.acc.toFixed(4)} val_acc=${logs.val_acc.toFixed(4)}`
      );
    },
  },
});
Typical output after 5 epochs:
Epoch 1: loss=0.3012 acc=0.9102 val_acc=0.9701
Epoch 2: loss=0.0892 acc=0.9734 val_acc=0.9812
Epoch 3: loss=0.0623 acc=0.9812 val_acc=0.9851
Epoch 4: loss=0.0498 acc=0.9848 val_acc=0.9872
Epoch 5: loss=0.0412 acc=0.9875 val_acc=0.9889
98.9% validation accuracy in 5 epochs, trained entirely in the browser.
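If you want one final number rather than per-epoch logs, model.evaluate runs the whole test set through once:
// With one metric compiled, evaluate returns [loss, accuracy] as scalar tensors
const [lossT, accT] = model.evaluate(testXs, testLabels);
console.log(`Test accuracy: ${(accT.dataSync()[0] * 100).toFixed(2)}%`);
lossT.dispose();
accT.dispose();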
Interactive prediction
Once trained, we let the user draw a digit on a canvas and classify it in real time:
function predict() {
  const tensor = tf.tidy(() => {
    // Downscale the 200x200 drawing canvas to 28x28
    const small = document.createElement("canvas");
    small.width = 28;
    small.height = 28;
    small.getContext("2d").drawImage(drawCanvas, 0, 0, 28, 28);
    return model.predict(
      tf.browser
        .fromPixels(small.getContext("2d").getImageData(0, 0, 28, 28), 1)
        .toFloat()
        .div(255)
        .reshape([1, 28, 28, 1])
    );
  });
  const probs = tensor.dataSync();
  const predicted = probs.indexOf(Math.max(...probs));
  tensor.dispose();
  return { predicted, probs };
}
tf.tidy automatically cleans up intermediate tensors — important in a browser where you can’t rely on garbage collection for GPU memory. The drawing canvas is downscaled from 200×200 to 28×28 to match the training data resolution.
In our demo, prediction runs on every mouse-up, so the model classifies as you draw. A probability bar chart shows the confidence for each digit.
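Wiring that up is one event listener plus a render step. A sketch, where resultEl and renderChart stand in for the demo's actual output element and chart code:
drawCanvas.addEventListener("mouseup", () => {
  const { predicted, probs } = predict();
  resultEl.textContent = `Predicted: ${predicted}`;
  renderChart(probs); // hypothetical helper: one bar per digit, height = probability
});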
Memory management
TensorFlow.js tensors live outside JavaScript’s garbage collector. If you don’t dispose them, you’ll leak GPU/WASM memory. Three patterns to know:
- tf.tidy(() => { ... }) — disposes all tensors created inside, except the return value
- tensor.dispose() — manual cleanup for tensors you need to keep around
- tf.memory() — returns current tensor count and bytes, useful for debugging leaks (see the check below)
In the training code, we dispose the training and test tensors after model.fit completes:
trainXs.dispose();
trainLabels.dispose();
testXs.dispose();
testLabels.dispose();
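During development, comparing tf.memory() before and after a run is a quick leak check. A sketch, with runTraining standing in for the full train-and-dispose sequence:
const before = tf.memory().numTensors;
await runTraining(); // train, predict, and dispose as shown above
const after = tf.memory().numTensors;
if (after > before) {
  console.warn(`Leaked ${after - before} tensors`);
}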
Running the demos
Both demos are Vite apps. To try them:
# Transformers.js inference demo
cd transformers-js-demo
npm install && npm run dev
# TensorFlow.js MNIST training demo
cd tfjs-mnist-demo
npm install && npm run dev
The Transformers.js demo downloads models on first use (~67MB for DistilBERT, smaller for others). The TensorFlow.js demo downloads MNIST data (~13MB) and trains from scratch each time.
When to use which
| | Transformers.js | TensorFlow.js |
|---|---|---|
| Primary use | Run pretrained models | Train and run models |
| Model source | Hugging Face Hub (ONNX) | Define in code or load saved |
| Backend | WebGPU / WASM | WebGL / WebGPU / WASM / Node |
| Best for | NLP tasks, embeddings, classification | Custom models, fine-tuning, on-device training |
| Model size | Depends on HF model (10MB–2GB) | You control it |
They’re not mutually exclusive. You might use Transformers.js for NLP tasks where a pretrained model already exists, and TensorFlow.js when you need to train something custom or work with architectures not available on Hugging Face.
Both libraries prove the same point: the browser is a legitimate ML runtime. Not for everything — but for more than most people expect.