
2. Data Parallelism as the First Scaling Primitive

Explicit gradient synchronization and the cost of communication

Part 2 of 6 in the Distributed Training series

2.1 Data Parallelism

The first distributed primitive I implemented was data parallelism.

I chose data parallelism because it is straightforward to implement for models that fit in GPU memory, and it scales throughput roughly linearly with worker count as long as communication overhead stays moderate.

Each process:

  1. holds a full copy of the model
  2. processes a different batch
  3. computes local gradients
  4. synchronizes gradients before the optimizer step
[Figure: three GPUs, each holding a full model copy. GPU 0–2 process batches 0–2, producing gradients 0–2, which are averaged with an all-reduce.]

This keeps replicas consistent while scaling throughput with the number of workers, assuming communication overhead is manageable.
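Why averaging keeps replicas consistent: for equal-sized shards, the mean of per-shard gradients equals the gradient computed on the full batch. A minimal NumPy simulation (no real process group; the "all-reduce" is just an in-memory average) illustrates this for a linear least-squares model:

```python
import numpy as np

rng = np.random.default_rng(0)
w = rng.normal(size=3)        # shared model weights (identical on every replica)
X = rng.normal(size=(12, 3))  # full batch of 12 examples
y = rng.normal(size=12)

def grad(w, X, y):
    # gradient of the mean squared error 0.5 * mean((Xw - y)^2)
    return X.T @ (X @ w - y) / len(y)

# split the batch across 3 simulated workers
shards = np.split(np.arange(12), 3)
local_grads = [grad(w, X[s], y[s]) for s in shards]

# "all-reduce": sum local gradients, then divide by world size
g_sync = sum(local_grads) / len(shards)

# matches the single-process gradient on the full batch
assert np.allclose(g_sync, grad(w, X, y))
```

Because every replica ends the step with the same averaged gradient, the optimizer applies the same update everywhere and the copies never drift.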

While frameworks like PyTorch DDP automate synchronization, implementing it manually let me inspect the behavior and experiment with batching multiple gradients into fewer collectives.

2.2 Explicit Gradient Synchronization

After local backpropagation, gradients are reduced and averaged across processes:

import torch.distributed as dist

def train_step_distributed(model, x, y, optimizer, loss_fn, world_size):
    logits = model(x)
    loss = loss_fn(logits, y)
    loss.backward()

    # Average gradients across all processes: reduce-sum, then divide
    # by the world size so every replica applies the same update.
    for param in model.parameters():
        if param.grad is not None:
            dist.all_reduce(param.grad, op=dist.ReduceOp.SUM)
            param.grad.div_(world_size)

    optimizer.step()
    optimizer.zero_grad()
    return loss.item()

This implementation made two properties explicit:

  1. Correctness is purely a function of synchronization. If every gradient is reduced exactly once and averaged, replicas stay in sync.
  2. Communication cost is paid every step, independent of model size or batch composition.

Once synchronization was correct, behavior matched single-process training, with communication as the only additional cost.

2.3 Communication on the Critical Path

The initial implementation performed one all-reduce per parameter tensor:

for param in model.parameters():
    dist.all_reduce(param.grad)  # one collective call per parameter tensor

For a transformer with hundreds of parameter tensors, this produced hundreds of collective operations per step. The dominant cost was not bandwidth but latency: each collective introduced a synchronization point, and those fixed costs accumulated quickly.
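A simple latency-bandwidth cost model makes the effect concrete. The numbers below (20 µs of fixed latency per collective, 100 GB/s effective bandwidth, 400 tensors totaling 100M fp32 parameters) are illustrative assumptions, not measurements:

```python
ALPHA = 20e-6            # assumed fixed latency per collective, seconds
BETA = 1 / 100e9         # assumed seconds per byte (100 GB/s effective bandwidth)
N_TENSORS = 400          # per-tensor all-reduce: one collective per parameter
TOTAL_BYTES = 100e6 * 4  # 100M fp32 parameters

def allreduce_time(n_collectives, total_bytes):
    # latency is paid once per collective; bandwidth cost depends
    # only on the total volume moved
    return n_collectives * ALPHA + total_bytes * BETA

per_tensor = allreduce_time(N_TENSORS, TOTAL_BYTES)  # ~12 ms
fused = allreduce_time(1, TOTAL_BYTES)               # ~4 ms
print(f"per-tensor: {per_tensor*1e3:.1f} ms, fused: {fused*1e3:.1f} ms")
```

Under these assumptions, the latency term alone (8 ms) is twice the bandwidth term (4 ms) in the per-tensor case, and it vanishes almost entirely once the collectives are merged.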

At this point, computation was no longer the bottleneck. The communication pattern was.

Accordingly, the next part of this series focuses on fusing gradients to reduce the number of collective operations per step.
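As a preview, the core idea can be sketched without a real process group: flatten all gradient tensors into one contiguous buffer, perform a single "all-reduce" over it, and copy the averaged values back. The function and variable names here are illustrative, and the averaging stands in for `dist.all_reduce`:

```python
import numpy as np

def fused_allreduce(grads_per_worker):
    """Simulated fused all-reduce: one collective over a flat buffer
    instead of one collective per tensor. grads_per_worker is a list
    (one entry per worker) of lists of gradient arrays with matching shapes."""
    shapes = [g.shape for g in grads_per_worker[0]]
    # flatten each worker's gradients into one contiguous buffer
    flat = [np.concatenate([g.ravel() for g in grads]) for grads in grads_per_worker]
    # single "all-reduce": average the flat buffers across workers
    avg = sum(flat) / len(flat)
    # unflatten the averaged buffer back into per-tensor views
    out, offset = [], 0
    for shape in shapes:
        n = int(np.prod(shape))
        out.append(avg[offset:offset + n].reshape(shape))
        offset += n
    return out

rng = np.random.default_rng(1)
workers = [[rng.normal(size=(2, 3)), rng.normal(size=(4,))] for _ in range(3)]
averaged = fused_allreduce(workers)
```

The result is identical to averaging each tensor separately, but the fixed per-collective latency is paid once rather than once per tensor.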