Part 4 of 6 in the Distributed Training series
4.1 Overview: Why Tensor Parallelism?
Data parallelism replicates the entire model on each GPU. For very large models, even the parameters themselves don't fit on a single device. Tensor parallelism addresses this by splitting individual layers across GPUs.
Unlike data parallelism, where each GPU holds a complete model copy, tensor parallelism partitions the weight matrices themselves. This enables training models whose parameters exceed single-GPU memory.
4.2 Column Parallel Linear
The weight matrix is split along the output dimension. Each GPU computes a portion of the output:
```python
import torch
import torch.nn as nn
import torch.nn.functional as F


class ColumnParallelLinear(nn.Module):
    """
    Linear layer split along the output dim: Y = X @ [A1 | A2 | ... | An].
    Each GPU computes Yi = X @ Ai.
    """
    def __init__(self, in_features: int, out_features: int, tp: TensorParallelGroup,
                 bias: bool = True, gather_output: bool = True):
        super().__init__()
        assert out_features % tp.world_size == 0
        self.tp = tp
        self.gather_output = gather_output
        self.out_per_rank = out_features // tp.world_size
        self.weight = nn.Parameter(torch.empty(self.out_per_rank, in_features))
        self.bias = nn.Parameter(torch.empty(self.out_per_rank)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = copy_to_parallel(x, self.tp)  # identity forward, all-reduce backward
        out = F.linear(x, self.weight, self.bias)
        return gather_from_parallel(out, self.tp, dim=-1) if self.gather_output else out
```

Notice that:
- No intermediate communication until outputs are optionally gathered.
- Each GPU computes its output shard independently, so the forward pass proceeds fully in parallel across ranks.
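The column-parallel math can be checked without any distributed setup. The following is a single-process NumPy sketch (the variable names and shard count are illustrative, not part of the implementation above): splitting the weight along the output dimension, computing each shard's output independently, and concatenating reproduces the full linear layer exactly.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, in_features, out_features = 4, 8, 16
X = rng.standard_normal((2, in_features))             # a small input batch
A = rng.standard_normal((out_features, in_features))  # full weight matrix

shards = np.split(A, world_size, axis=0)              # split along the output dim
partials = [X @ A_i.T for A_i in shards]              # each "rank" works alone
Y = np.concatenate(partials, axis=-1)                 # the optional gather

assert np.allclose(Y, X @ A.T)                        # matches the full linear layer
```

Because the shards never interact until the concatenation, no communication is needed during the matmul itself, which is exactly the property the bullet points describe.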
4.3 Row Parallel Linear
The weight matrix is split along the input dimension. Partial results are summed via all-reduce:
```python
class RowParallelLinear(nn.Module):
    """
    Linear layer split along the input dim:
    Y = [X1 | X2 | ... | Xn] @ [A1; A2; ...; An] = sum_i Xi @ Ai.
    Each GPU computes Yi = Xi @ Ai, then the partials are all-reduced.
    """
    def __init__(self, in_features: int, out_features: int, tp: TensorParallelGroup,
                 bias: bool = True, input_is_parallel: bool = False):
        super().__init__()
        assert in_features % tp.world_size == 0
        self.tp = tp
        self.input_is_parallel = input_is_parallel
        self.in_per_rank = in_features // tp.world_size
        self.weight = nn.Parameter(torch.empty(out_features, self.in_per_rank))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.input_is_parallel:
            x = scatter_to_parallel(x, self.tp, dim=-1)
        out = F.linear(x, self.weight)
        out = reduce_from_parallel(out, self.tp)  # all-reduce sum of the partials
        # Bias is added once, after the reduction, not on every rank's partial.
        return out + self.bias if self.bias is not None else out
```

Notice that:
- Communication is reduced to a single all-reduce at the layer output.
- Very useful for MLP blocks in transformers.
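As with the column-parallel case, the row-parallel math can be sketched in plain NumPy (single process; the sharding here is illustrative). The input and the weight are split along the same inner dimension, each "rank" computes a partial product, and summing the partials, which is what the all-reduce does, recovers the full result. The bias is added once, after the sum.

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, in_features, out_features = 4, 16, 8
X = rng.standard_normal((2, in_features))
A = rng.standard_normal((out_features, in_features))
b = rng.standard_normal(out_features)

X_shards = np.split(X, world_size, axis=-1)  # split the input dim
A_shards = np.split(A, world_size, axis=1)   # split the matching weight dim
partials = [x_i @ a_i.T for x_i, a_i in zip(X_shards, A_shards)]
Y = sum(partials) + b                        # "all-reduce" = sum; bias added once

assert np.allclose(Y, X @ A.T + b)           # matches the full linear layer
```

Adding the bias on every rank before the reduce would count it `world_size` times, which is why the implementation above defers it until after `reduce_from_parallel`.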
4.4 Tensor Parallel Attention
Attention heads are split across GPUs. Each GPU processes num_heads / world_size heads:
```python
class TensorParallelAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, tp: TensorParallelGroup):
        super().__init__()
        assert num_heads % tp.world_size == 0
        self.tp = tp
        self.num_heads = num_heads // tp.world_size  # heads on this rank
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5
        # QKV: column parallel (splits heads across ranks)
        self.qkv = ColumnParallelLinear(d_model, 3 * d_model, tp, gather_output=False)
        # Output: row parallel (all-reduces the partial projections)
        self.out_proj = RowParallelLinear(d_model, d_model, tp, input_is_parallel=True)

    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, _ = x.shape
        qkv = self.qkv(x).view(B, S, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, S, head_dim)
        q, k, v = qkv.unbind(0)
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, -1)
        return self.out_proj(out)
```

Notice that:
- Column-parallel QKV + row-parallel output ensures intermediate activations remain partitioned.
- Communication is minimized even for large-scale multi-head attention.
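Splitting by heads is exact, not an approximation, because softmax attention is computed independently per head. A single-process NumPy sketch (shapes and the `attend` helper are illustrative) shows that computing a contiguous subset of heads on each "rank" and concatenating along the head axis reproduces full multi-head attention:

```python
import numpy as np

def softmax(x, axis=-1):
    e = np.exp(x - x.max(axis=axis, keepdims=True))
    return e / e.sum(axis=axis, keepdims=True)

def attend(q, k, v):
    # q, k, v: (batch, heads, seq, head_dim); softmax is applied per head
    scale = q.shape[-1] ** -0.5
    w = softmax((q @ k.transpose(0, 1, 3, 2)) * scale)
    return w @ v

rng = np.random.default_rng(0)
B, H, S, D, world_size = 2, 8, 5, 4, 4
q, k, v = (rng.standard_normal((B, H, S, D)) for _ in range(3))

full = attend(q, k, v)
hp = H // world_size  # heads per rank
parts = [attend(q[:, r*hp:(r+1)*hp], k[:, r*hp:(r+1)*hp], v[:, r*hp:(r+1)*hp])
         for r in range(world_size)]

assert np.allclose(np.concatenate(parts, axis=1), full)
```

This independence between heads is what lets the column-parallel QKV projection hand each rank its own heads with no cross-rank communication inside the attention computation.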
4.5 Communication Pattern
The key insight is that column-parallel and row-parallel layers can be composed without intermediate communication. This pairing is why transformer MLPs typically use a column-parallel first layer followed by a row-parallel second layer: the intermediate activation stays partitioned across ranks, and the only communication for the whole block is the single all-reduce at the row-parallel output.
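The composition works because the activation function is elementwise, so it can be applied to each rank's shard locally. A single-process NumPy sketch of the two-layer MLP (names and sizes are illustrative; ReLU stands in for any elementwise activation) confirms that one sum at the end recovers the full result:

```python
import numpy as np

rng = np.random.default_rng(0)
world_size, d_model, d_hidden = 4, 8, 32
X = rng.standard_normal((2, d_model))
W1 = rng.standard_normal((d_hidden, d_model))  # column-parallel: split output dim
W2 = rng.standard_normal((d_model, d_hidden))  # row-parallel: split input dim

relu = lambda x: np.maximum(x, 0)              # elementwise, so it shards cleanly
full = relu(X @ W1.T) @ W2.T                   # the unsharded reference MLP

W1_shards = np.split(W1, world_size, axis=0)
W2_shards = np.split(W2, world_size, axis=1)
# Each rank: column-parallel matmul, local activation, row-parallel matmul.
partials = [relu(X @ w1.T) @ w2.T for w1, w2 in zip(W1_shards, W2_shards)]

assert np.allclose(sum(partials), full)        # one all-reduce, no other comms
```

Had the first layer been row-parallel instead, the nonlinearity would sit between a sum and its inputs, forcing an all-reduce before the activation; ordering column-parallel first avoids that extra communication.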