
4. Tensor Parallelism as the Fourth Scaling Primitive

Splitting individual layers across GPUs

Part 4 of 6 in the Distributed Training series

4.1 Overview: Why Tensor Parallelism?

Data parallelism replicates the entire model on each GPU. For very large models, even the parameters themselves don't fit on a single device. Tensor parallelism addresses this by splitting individual layers across GPUs.

[Figure: a single linear layer split across three GPUs. Input X multiplies GPU 0's W[:, 0:n/3], GPU 1's W[:, n/3:2n/3], and GPU 2's W[:, 2n/3:n]; the partial outputs Y₀, Y₁, Y₂ are gathered or reduced as needed.]

Unlike data parallelism where each GPU holds a complete model copy, tensor parallelism partitions the weight matrices themselves. This enables training models whose parameters exceed single-GPU memory.
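
A back-of-the-envelope check of that memory argument (the model size, precision, and tensor-parallel degree below are illustrative assumptions, not figures from this series):

```python
# Illustrative memory arithmetic: a hypothetical 70B-parameter model
# stored in fp16, split across 8 tensor-parallel ranks.
params = 70e9
bytes_per_param = 2                      # fp16
full_gb = params * bytes_per_param / 1e9
per_gpu_gb = full_gb / 8                 # tensor-parallel degree 8

print(f"full: {full_gb:.1f} GB, per GPU: {per_gpu_gb:.1f} GB")
# full: 140.0 GB, per GPU: 17.5 GB
```

At these (assumed) sizes the weights alone exceed any single current accelerator, but each shard fits comfortably once split.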

4.2 Column Parallel Linear

The weight matrix is split along the output dimension. Each GPU computes a portion of the output:

import torch
import torch.nn as nn
import torch.nn.functional as F

# TensorParallelGroup and the collective helpers (copy_to_parallel,
# gather_from_parallel, scatter_to_parallel, reduce_from_parallel)
# are assumed to be defined earlier in the series.

class ColumnParallelLinear(nn.Module):
    """
    Linear split along output dim: Y = X @ [A1|A2|...|An]
    Each GPU computes Yi = X @ Ai
    """
    
    def __init__(self, in_features: int, out_features: int, tp: TensorParallelGroup,
                 bias: bool = True, gather_output: bool = True):
        super().__init__()
        assert out_features % tp.world_size == 0
        self.tp = tp
        self.gather_output = gather_output
        self.out_per_rank = out_features // tp.world_size
        
        self.weight = nn.Parameter(torch.empty(self.out_per_rank, in_features))
        self.bias = nn.Parameter(torch.empty(self.out_per_rank)) if bias else None
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        x = copy_to_parallel(x, self.tp)  # Identity fwd, all-reduce bwd
        out = F.linear(x, self.weight, self.bias)
        return gather_from_parallel(out, self.tp, dim=-1) if self.gather_output else out

Notice that:

  • There is no communication inside the layer; the optional gather at the output is the only forward-pass collective.
  • With gather_output=False, the output stays partitioned and can feed a row-parallel layer directly.
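
A minimal single-process sketch of this, with tensor slices standing in for the per-rank shards (so no distributed runtime is needed; sizes are arbitrary), showing that concatenating the per-shard outputs reproduces the full linear:

```python
import torch

# Column parallelism in one process: split W along its output dimension,
# compute each shard's output independently, then concatenate (the "gather").
torch.manual_seed(0)
world_size, in_f, out_f = 3, 8, 12
x = torch.randn(4, in_f)
w = torch.randn(out_f, in_f)          # full weight, PyTorch (out, in) layout

y_full = x @ w.t()                    # unsharded reference

# Each "rank" holds a contiguous block of W's output rows
y_parts = [x @ w_i.t() for w_i in w.chunk(world_size, dim=0)]  # no comm needed
y_gathered = torch.cat(y_parts, dim=-1)                        # optional gather

assert torch.allclose(y_full, y_gathered, atol=1e-6)
```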

4.3 Row Parallel Linear

The weight matrix is split along the input dimension. Partial results are summed via all-reduce:

class RowParallelLinear(nn.Module):
    """
    Linear split along input dim: Y = sum([X1|X2|...|Xn] @ [A1;A2;...;An])
    Each GPU computes Yi = Xi @ Ai, then all-reduce.
    """
    
    def __init__(self, in_features: int, out_features: int, tp: TensorParallelGroup,
                 bias: bool = True, input_is_parallel: bool = False):
        super().__init__()
        assert in_features % tp.world_size == 0
        self.tp = tp
        self.input_is_parallel = input_is_parallel
        self.in_per_rank = in_features // tp.world_size
        
        self.weight = nn.Parameter(torch.empty(out_features, self.in_per_rank))
        self.bias = nn.Parameter(torch.empty(out_features)) if bias else None
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        if not self.input_is_parallel:
            x = scatter_to_parallel(x, self.tp, dim=-1)
        out = F.linear(x, self.weight)
        out = reduce_from_parallel(out, self.tp)  # All-reduce sum
        if self.bias is not None:
            out = out + self.bias  # bias applied once, after the reduction
        return out

Notice that:

  • Communication is a single all-reduce at the layer output; the bias is added after the reduction so it is applied exactly once.
  • Paired with a preceding column-parallel layer, this is the standard layout for transformer MLP blocks.
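
The same single-process trick verifies row parallelism (tensor slices stand in for per-rank shards; a plain sum stands in for the all-reduce; sizes are arbitrary):

```python
import torch

# Row parallelism in one process: scatter the input along its feature dim,
# multiply each slice by the matching slice of W, then sum the partials.
torch.manual_seed(0)
world_size, in_f, out_f = 3, 12, 8
x = torch.randn(4, in_f)
w = torch.randn(out_f, in_f)
b = torch.randn(out_f)

y_full = x @ w.t() + b                               # unsharded reference

partials = [x_i @ w_i.t()
            for x_i, w_i in zip(x.chunk(world_size, dim=-1),
                                w.chunk(world_size, dim=1))]
y = torch.stack(partials).sum(0)                     # the all-reduce sum
y = y + b                                            # bias added once, after

assert torch.allclose(y_full, y, atol=1e-5)
```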

4.4 Tensor Parallel Attention

Attention heads are split across GPUs. Each GPU processes num_heads / world_size heads:

class TensorParallelAttention(nn.Module):
    def __init__(self, d_model: int, num_heads: int, tp: TensorParallelGroup):
        super().__init__()
        assert num_heads % tp.world_size == 0
        assert d_model % num_heads == 0
        self.tp = tp
        self.num_heads = num_heads // tp.world_size  # heads owned by this rank
        self.head_dim = d_model // num_heads
        self.scale = self.head_dim ** -0.5
        
        # QKV: column parallel (each rank projects only its own heads)
        self.qkv = ColumnParallelLinear(d_model, 3 * d_model, tp, gather_output=False)
        # Output: row parallel (all-reduce sums the partial projections)
        self.out_proj = RowParallelLinear(d_model, d_model, tp, input_is_parallel=True)
    
    def forward(self, x: torch.Tensor) -> torch.Tensor:
        B, S, _ = x.shape
        
        # Local projection: only this rank's heads, no communication
        qkv = self.qkv(x).view(B, S, 3, self.num_heads, self.head_dim)
        qkv = qkv.permute(2, 0, 3, 1, 4)  # (3, B, heads, S, head_dim)
        q, k, v = qkv.unbind(0)
        
        # Standard scaled dot-product attention over the local heads
        attn = (q @ k.transpose(-2, -1)) * self.scale
        attn = F.softmax(attn, dim=-1)
        out = (attn @ v).transpose(1, 2).reshape(B, S, -1)
        
        # Row-parallel projection performs the single all-reduce
        return self.out_proj(out)

Notice that:

  • Column-parallel QKV + row-parallel output keeps the intermediate activations partitioned across ranks.
  • The entire attention block needs exactly one forward-pass all-reduce, inside the output projection.
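
Head sharding can also be checked in a single process. The sketch below (plain tensors standing in for per-rank weight shards, a running sum standing in for the all-reduce, arbitrary sizes) confirms that per-rank heads plus a row-parallel output projection reproduce full multi-head attention:

```python
import torch
import torch.nn.functional as F

torch.manual_seed(0)
ws, B, S, H, hd = 2, 2, 5, 4, 8      # tp degree, batch, seq, heads, head_dim
d = H * hd

x = torch.randn(B, S, d)
wq, wk, wv, wo = (torch.randn(d, d) for _ in range(4))

def split_heads(t, n):               # (B, S, n*hd) -> (B, n, S, hd)
    return t.view(B, S, n, hd).transpose(1, 2)

def sdpa(q, k, v):                   # scaled dot-product attention
    a = F.softmax(q @ k.transpose(-2, -1) / hd ** 0.5, dim=-1)
    return a @ v

# Reference: all H heads computed in one place
full = sdpa(split_heads(x @ wq.t(), H),
            split_heads(x @ wk.t(), H),
            split_heads(x @ wv.t(), H))
ref = full.transpose(1, 2).reshape(B, S, d) @ wo.t()

# Sharded: each "rank" owns H/ws heads (rows of wq/wk/wv) and the
# matching input slice of wo; summing the partials is the all-reduce.
y = torch.zeros(B, S, d)
for r in range(ws):
    rows = slice(r * (H // ws) * hd, (r + 1) * (H // ws) * hd)
    o = sdpa(split_heads(x @ wq[rows].t(), H // ws),
             split_heads(x @ wk[rows].t(), H // ws),
             split_heads(x @ wv[rows].t(), H // ws))
    y += o.transpose(1, 2).reshape(B, S, -1) @ wo[:, rows].t()

assert torch.allclose(ref, y, atol=1e-4)
```

This works because attention heads never interact with each other before the output projection, so the split introduces no approximation.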

4.5 Communication Pattern

The key insight is that column-parallel and row-parallel layers can be composed without intermediate communication:

[Figure: a column-parallel layer (shards A₁, A₂, A₃ producing Y₁, Y₂, Y₃) feeds a row-parallel layer (shards B₁, B₂, B₃) with no communication in between; the partial sums are combined by a single all-reduce to form the output.]

This pairing is why transformer MLPs typically use a column-parallel first layer and a row-parallel second layer: the intermediate tensor stays partitioned, avoiding an extra round of communication between the two.
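
A single-process sketch of the pairing (tensor slices stand in for per-rank shards, a sum for the all-reduce; sizes and the ReLU activation are illustrative choices):

```python
import torch

# Column-parallel -> row-parallel MLP: the intermediate stays sharded,
# and one final all-reduce recovers the full output.
torch.manual_seed(0)
ws, d, h = 2, 8, 16                  # tp degree, model dim, hidden dim
x = torch.randn(4, d)
a = torch.randn(h, d)                # first layer, split along output dim
b = torch.randn(d, h)                # second layer, split along input dim

ref = torch.relu(x @ a.t()) @ b.t()  # unsharded reference

partials = [torch.relu(x @ ai.t()) @ bi.t()        # no comm between layers:
            for ai, bi in zip(a.chunk(ws, dim=0),  # the activation is element-
                              b.chunk(ws, dim=1))] # wise, so it commutes with
y = torch.stack(partials).sum(0)                   # the split; one all-reduce

assert torch.allclose(ref, y, atol=1e-5)
```

Note the elementwise activation is what makes this legal: it applies shard-by-shard to the partitioned intermediate without any exchange between ranks.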