diff --git a/examples/pytorch/comm_gemm_overlap/README.md b/examples/pytorch/comm_gemm_overlap/README.md
index b7ecb2d069..779d8eca09 100644
--- a/examples/pytorch/comm_gemm_overlap/README.md
+++ b/examples/pytorch/comm_gemm_overlap/README.md
@@ -9,6 +9,81 @@
 - Devices older than compute capability 9.0 require `UB_SKIPMC=1` in the environment in order to fall back on a less performant implementation based on CUDA Inter-Process Communication (IPC) handles.
+## Enabling overlap in your own module
+
+The example follows the same setup sequence that user code should use:
+
+1. Set `CUDA_DEVICE_MAX_CONNECTIONS=1` before creating the layer.
+2. Initialize `torch.distributed` and create the tensor-parallel process group.
+3. Call `te.module.base.initialize_ub(...)` with the local activation shape and tensor-parallel
+   size before constructing TE layers with userbuffer overlap enabled.
+4. Pass the tensor-parallel group, tensor-parallel size, and overlap flags to the TE layer.
+5. Call `te.module.base.destroy_ub()` before shutting down the process group.
+
+Minimal setup sketch:
+
+```python
+import os
+import torch
+import torch.distributed as dist
+import transformer_engine.pytorch as te
+
+os.environ["CUDA_DEVICE_MAX_CONNECTIONS"] = "1"
+
+dist.init_process_group(backend="nccl")
+tp_group = dist.group.WORLD
+tp_size = dist.get_world_size(tp_group)
+
+num_heads = 16
+head_dim = 128
+seq_length = 2048
+micro_batch_size = 4
+
+hidden_size = num_heads * head_dim
+batched_size = seq_length * micro_batch_size
+
+te.module.base.initialize_ub(
+    [batched_size, hidden_size],
+    tp_size,
+    quantization_modes=[te.module.base.UserBufferQuantizationMode.NONE],
+    dtype=torch.bfloat16,
+    bootstrap_backend="nccl",
+)
+
+layer = te.TransformerLayer(
+    hidden_size,
+    4 * hidden_size,
+    num_heads,
+    tp_group=tp_group,
+    tp_size=tp_size,
+    sequence_parallel=True,
+    fuse_qkv_params=True,
+    ub_tp_comm_overlap=True,
+    ub_overlap_ag=True,
+    ub_overlap_rs=True,
+    ub_bulk_wgrad=True,
+    ub_bulk_dgrad=True,
+    seq_length=seq_length,
+)
+
+# ... run forward/backward/optimizer steps ...
+
+te.module.base.destroy_ub()
+```
+
+`ub_tp_comm_overlap` is the top-level gate on `TransformerLayer`: when it is `False`, the
+layer disables the individual userbuffer overlap paths even if the per-path flags are `True`.
+For lower-level layers such as `Linear`, `LayerNormLinear`, `LayerNormMLP`, or
+`MultiheadAttention`, enable the relevant per-path flags directly (for example
+`ub_overlap_ag`, `ub_overlap_rs`, `ub_bulk_wgrad`, and `ub_bulk_dgrad`) and set the `ub_name`
+where the layer requires one.
+
+When replacing modules in a Hugging Face model, run the userbuffer initialization once before
+constructing the replacement TE modules. The replacement modules need the same tensor-parallel
+group, tensor-parallel size, sequence-parallel setting, and overlap flags shown above; the
+activation shape passed to `initialize_ub` should match the sequence length, micro-batch size,
+and hidden size used by the replaced blocks.
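The shape arithmetic in the setup sketch can be isolated into a small helper so that it is computed once from the replaced blocks' configuration. This is only an illustrative sketch: the name `ub_activation_shape` is hypothetical and not part of Transformer Engine, and it assumes the same `[seq_length * micro_batch_size, num_heads * head_dim]` layout that the sketch above passes as the first argument of `initialize_ub`.

```python
def ub_activation_shape(seq_length, micro_batch_size, num_heads, head_dim):
    """Return [batched_size, hidden_size], mirroring the setup sketch above.

    Hypothetical helper for illustration; not a Transformer Engine API.
    """
    if min(seq_length, micro_batch_size, num_heads, head_dim) <= 0:
        raise ValueError("all dimensions must be positive integers")
    batched_size = seq_length * micro_batch_size  # tokens per micro-batch
    hidden_size = num_heads * head_dim            # model hidden dimension
    return [batched_size, hidden_size]


# Same values as the setup sketch: 2048 tokens x 4 sequences, 16 heads x 128 dims.
shape = ub_activation_shape(2048, 4, 16, 128)
print(shape)  # [8192, 2048]
```

Deriving both integers from the same configuration values that build the replacement modules keeps the buffer shape in step with the blocks being swapped in.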
+
 ## Examples
 
 ### Single node, tensor-parallel LayerNormMLP: