Part 5 of 6 in the Distributed Training series
5.1 The Strategy Abstraction
As multiple approaches emerged (standard data parallelism, bucketed reduction, optimizer sharding, tensor parallelism), I introduced a strategy abstraction to separate training logic from distributed mechanics:
```python
class DistributedStrategy:
    def wrap_model(self, model: nn.Module) -> nn.Module:
        raise NotImplementedError

    def wrap_optimizer(self, optimizer_cls, model, **kwargs) -> Optimizer:
        return optimizer_cls(model.parameters(), **kwargs)

    def sync_gradients(self, model: nn.Module):
        pass  # Override if manual sync needed
```

Each strategy implements these hooks differently. The `DDPBucketedStrategy` wraps the model in my custom bucketed gradient sync and requires an explicit sync call after backward:
```python
class DDPBucketedStrategy(DistributedStrategy):
    def wrap_model(self, model):
        return DDPBucketed(model, bucket_size_mb=25.0)

    def sync_gradients(self, model):
        model.finish_gradient_synchronization()
```

The `ZeROStrategy` combines standard DDP for gradient sync with sharded optimizer state:
```python
class ZeROStrategy(DistributedStrategy):
    def wrap_model(self, model):
        return DDP(model, device_ids=[local_rank])

    def wrap_optimizer(self, optimizer_cls, model, **kwargs):
        return ShardedOptimizer(model.parameters(), optimizer_cls, **kwargs)
```

This made it possible to switch between approaches without rewriting the training loop:
```bash
# Single GPU
python train.py --strategy single

# PyTorch DDP
torchrun --nproc_per_node=4 train.py --strategy ddp

# Custom bucketed DDP
torchrun --nproc_per_node=4 train.py --strategy ddp_bucketed

# ZeRO optimizer sharding
torchrun --nproc_per_node=4 train.py --strategy zero
```

5.2 Simulating Multi-GPU with Gloo
All development was done on a single RTX 3060. To validate multi-process correctness without multi-GPU hardware, I used PyTorch's gloo backend, which supports CPU-based collective communication. The key insight: torchrun spawns multiple processes regardless of GPU count, and gloo lets those processes communicate over shared memory.
```python
import os

import torch
import torch.distributed as dist


def setup_distributed(backend="auto"):
    rank = int(os.environ["RANK"])
    local_rank = int(os.environ["LOCAL_RANK"])
    world_size = int(os.environ["WORLD_SIZE"])
    num_gpus = torch.cuda.device_count()

    # Auto-select backend based on hardware
    if backend == "auto":
        backend = "nccl" if num_gpus >= world_size else "gloo"

    if backend == "gloo":
        # All processes share GPU 0 (simulation mode)
        dist.init_process_group(backend="gloo")
        torch.cuda.set_device(0)
        device = torch.device("cuda:0")
    else:
        # Real multi-GPU: each process gets its own device
        torch.cuda.set_device(local_rank)
        device = torch.device(f"cuda:{local_rank}")
        dist.init_process_group(backend="nccl", device_id=device)

    return {"rank": rank, "world_size": world_size, "device": device}
```

This gives us four processes that believe they're on separate GPUs but actually share one physical device. The collective operations (`all_reduce`, `broadcast`, `all_gather`) work identically; only the transport layer differs. This means the distributed logic is fully exercised even without distributed hardware.
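The backend auto-selection rule is worth isolating: NCCL needs one GPU per process, so with fewer GPUs than processes we fall back to gloo simulation. As a tiny pure function (a hypothetical refactor; the name and signature are mine, not from the codebase), the rule becomes testable without any hardware:

```python
def pick_backend(num_gpus: int, world_size: int, requested: str = "auto") -> str:
    """Choose the process-group backend.

    nccl requires one GPU per process; with fewer GPUs than
    processes, fall back to gloo so every rank can share a device.
    (Hypothetical helper extracted from setup_distributed's logic.)
    """
    if requested != "auto":
        return requested
    return "nccl" if num_gpus >= world_size else "gloo"

print(pick_backend(num_gpus=4, world_size=4))  # nccl
print(pick_backend(num_gpus=1, world_size=4))  # gloo
```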
To run with simulated multi-GPU:
```bash
# Spawns 4 processes, all sharing GPU 0
torchrun --nproc_per_node=4 benchmark_strategies.py --quick
```

5.3 The Benchmark Harness
To validate all implementations systematically, I built a benchmark harness that tests each strategy under identical conditions. The harness measures three properties:
- **Correctness:** After N training steps, are weights identical across all ranks? Data-parallel strategies should converge to identical weights on every rank. For tensor parallelism and FSDP, weights are intentionally sharded, so differing per-rank sums are the correct behavior.
- **Memory:** Peak GPU memory during training, which reveals the overhead of each approach.
- **Throughput:** Steps per second and tokens per second. Not meaningful in simulation mode (all processes compete for one GPU), but the harness supports real multi-GPU benchmarking.
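The correctness check reduces to comparing per-rank weight sums against rank 0 within a tolerance. A framework-free sketch of that comparison (the function name and tolerance are illustrative; the harness works on gathered tensors instead of plain floats):

```python
def weights_synced(per_rank_sums, tol=0.5):
    """True if every rank's total weight sum agrees with rank 0's
    within `tol`. For sharded strategies (tensor parallel, FSDP),
    disagreement here is expected behavior, not a bug.
    (Illustrative helper, not the harness's actual code.)"""
    reference = per_rank_sums[0]
    return all(abs(s - reference) <= tol for s in per_rank_sums)

print(weights_synced([1024.03, 1024.02, 1024.05, 1024.01]))  # True
print(weights_synced([1024.0, 512.0, 512.0, 1024.0]))        # False
```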
The core benchmark loop:
```python
def run_benchmark(cfg, dist_info):
    # Create model based on strategy
    needs_sync = False
    if cfg.strategy == "ddp_flat":
        wrapper = DDPIndividualParameters(model)
        needs_sync = True
    elif cfg.strategy == "ddp_bucketed":
        wrapper = DDPBucketed(model, bucket_size_mb=25.0)
        needs_sync = True
    elif cfg.strategy == "tensor_parallel":
        model = TensorParallelTransformerLM(..., tp_group)

    # Optional: wrap optimizer with ZeRO sharding
    if cfg.use_sharded_optimizer:
        opt = ShardedOptimizer(params, torch.optim.AdamW, ...)

    # Training loop
    for step in range(cfg.num_steps):
        opt.zero_grad()
        for _ in range(grad_accum):
            loss = F.cross_entropy(model(x), y)
            loss.backward()
        if needs_sync:
            wrapper.finish_gradient_synchronization()
        opt.step()

    # Verify weight synchronization
    weight_sum = sum(p.sum() for p in model.parameters())
    all_sums = [torch.zeros_like(weight_sum) for _ in range(world_size)]
    dist.all_gather(all_sums, weight_sum)
    synced = all(abs(all_sums[0] - s) < 0.5 for s in all_sums)
```

The strategies tested:
| Strategy | Description |
|---|---|
| ddp_flat | Async per-parameter all-reduce with gradient hooks |
| ddp_bucketed | Bucketed gradient reduction, reverse parameter order |
| ddp_flat + ZeRO | Per-param sync with optimizer state sharding |
| ddp_bucketed + ZeRO | Bucketed sync with optimizer state sharding |
| tensor_parallel | Column/row parallel layers, attention head sharding |
| pytorch_ddp | Baseline comparison |
| pytorch_fsdp | Baseline comparison (parameter sharding) |
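The hook sequence all of these strategies share (wrap the model, wrap the optimizer, sync gradients after backward, then step) can be exercised end to end without a GPU or even PyTorch. A framework-free sketch with stub classes (every name here is illustrative, not taken from the harness):

```python
# Stubs standing in for the real model/optimizer; the call order
# mirrors the strategy hooks: wrap_model -> wrap_optimizer ->
# (per step) sync_gradients -> step.
class StubModel:
    def parameters(self):
        return ["p0", "p1"]

class StubOptimizer:
    def __init__(self, params, **kwargs):
        self.params, self.steps = list(params), 0
    def zero_grad(self):
        pass
    def step(self):
        self.steps += 1

class StubStrategy:
    """Minimal stand-in mirroring DistributedStrategy's interface."""
    calls = []  # records hook invocations for inspection
    def wrap_model(self, model):
        self.calls.append("wrap_model")
        return model
    def wrap_optimizer(self, optimizer_cls, model, **kwargs):
        self.calls.append("wrap_optimizer")
        return optimizer_cls(model.parameters(), **kwargs)
    def sync_gradients(self, model):
        self.calls.append("sync_gradients")

def train(strategy, num_steps=2):
    model = strategy.wrap_model(StubModel())
    opt = strategy.wrap_optimizer(StubOptimizer, model)
    for _ in range(num_steps):
        opt.zero_grad()
        # ... forward/backward would go here ...
        strategy.sync_gradients(model)  # no-op for PyTorch DDP strategies
        opt.step()
    return opt

opt = train(StubStrategy())
print(opt.steps)  # 2
```

The training loop never branches on which strategy it was given; that is the entire point of the abstraction.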