diff --git a/transformer_engine/pytorch/module/base.py b/transformer_engine/pytorch/module/base.py index 6524678fd..69c8d86ef 100644 --- a/transformer_engine/pytorch/module/base.py +++ b/transformer_engine/pytorch/module/base.py @@ -1077,6 +1077,13 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_meta["num_gemms"] = num_gemms self.init_fp8_meta_tensors(self.fp8_meta["recipe"]) + # Force the transpose cache to be kept whenever the recipe is MXFP8 / MXFP4, + # regardless of whether we are currently inside an fp8_autocast region or not. + # reset_parameters() would disable columnwise_usage for params constructed inside + # `fp8_model_init` / `quantized_model_init`, leaving `_columnwise_data=None`. + if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4(): + self.keep_fp8_weight_transpose_cache = True + if fp8_enabled: # Set FP8 and other FP8 metadata self.fp8_meta["num_gemms"] = num_gemms @@ -1092,8 +1099,6 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None: self.fp8_initialized = True self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe() - if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4(): - self.keep_fp8_weight_transpose_cache = True _current_recipe = self.fp8_meta["recipe"] if _original_recipe is not None and not ( diff --git a/transformer_engine/pytorch/tensor/mxfp8_tensor.py b/transformer_engine/pytorch/tensor/mxfp8_tensor.py index bd3d93e9f..387bc75fd 100644 --- a/transformer_engine/pytorch/tensor/mxfp8_tensor.py +++ b/transformer_engine/pytorch/tensor/mxfp8_tensor.py @@ -471,7 +471,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv] split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE] # Padding requirements: rowwise dim0 should be divisble by 128, columnwise dim0 should be divisble by 4 - padding_multiples = [128, 4] + # NOTE: ROCm/HIP backend uses an unpadded scale-inv layout 
(see `MXFP8Quantizer.make_empty`), + # so applying the padding here would produce a per-shard scale-inv whose dim-0 + # does not match the destination scale-inv allocated for the FSDP2 local shard. + padding_multiples = [128, 4] if not IS_HIP_EXTENSION else [1, 1] for scale_inv, scale_split_size, pad_multiple in zip( scale_invs, split_sizes_for_scale, padding_multiples ): @@ -487,7 +490,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None): ) scale_inv_out = list(scale_inv_out) if scale_inv_out is not None else None # Pad scale_inv_out to be a multiple of pad_multiple - if scale_inv_out is not None: + if scale_inv_out is not None and pad_multiple > 1: for idx, split_scale_inv_out in enumerate(scale_inv_out): current_shape = split_scale_inv_out.shape pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple