9 changes: 7 additions & 2 deletions transformer_engine/pytorch/module/base.py
@@ -1077,6 +1077,13 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
self.fp8_meta["num_gemms"] = num_gemms
self.init_fp8_meta_tensors(self.fp8_meta["recipe"])

# Force the transpose cache to be kept whenever the recipe is MXFP8 / MXFP4,
# regardless of whether we are currently inside an fp8_autocast region or not.
# Otherwise, reset_parameters() would disable columnwise_usage for params constructed
# inside `fp8_model_init` / `quantized_model_init`, leaving `_columnwise_data=None`.
if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4():
self.keep_fp8_weight_transpose_cache = True

if fp8_enabled:
# Set FP8 and other FP8 metadata
self.fp8_meta["num_gemms"] = num_gemms
@@ -1092,8 +1099,6 @@ def init_fp8_metadata(self, num_gemms: int = 1) -> None:
self.fp8_initialized = True

self.fp8_meta["recipe"] = FP8GlobalStateManager.get_fp8_recipe()
if self.fp8_meta["recipe"].mxfp8() or self.fp8_meta["recipe"].mxfp4():
self.keep_fp8_weight_transpose_cache = True

_current_recipe = self.fp8_meta["recipe"]
if _original_recipe is not None and not (
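For context, here is a minimal, self-contained sketch of the control flow after this change (the helper classes are hypothetical stand-ins; the real init_fp8_metadata in TransformerEngine does much more). It illustrates why deriving keep_fp8_weight_transpose_cache from the recipe before the fp8_enabled branch also covers modules built under fp8_model_init / quantized_model_init, i.e. outside any fp8_autocast region:

class _RecipeSketch:
    """Stand-in for a TE recipe object; only the predicates used above."""

    def __init__(self, name: str):
        self.name = name

    def mxfp8(self) -> bool:
        return self.name == "mxfp8"

    def mxfp4(self) -> bool:
        return self.name == "mxfp4"


class _ModuleSketch:
    """Stand-in for a TE base module, reduced to the fields relevant here."""

    def __init__(self, recipe: _RecipeSketch):
        self.fp8_meta = {"recipe": recipe}
        self.keep_fp8_weight_transpose_cache = False

    def init_fp8_metadata(self, fp8_enabled: bool) -> None:
        recipe = self.fp8_meta["recipe"]
        # Unconditional: runs even when fp8_enabled is False, so the columnwise
        # (transpose) weight copy survives for params quantized at construction time.
        if recipe.mxfp8() or recipe.mxfp4():
            self.keep_fp8_weight_transpose_cache = True
        if fp8_enabled:
            # FP8 metadata setup goes here; before this change, the flag above was
            # only set inside this branch and was skipped outside fp8_autocast.
            pass


module = _ModuleSketch(_RecipeSketch("mxfp8"))
module.init_fp8_metadata(fp8_enabled=False)  # e.g. constructed under fp8_model_init
assert module.keep_fp8_weight_transpose_cache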
7 changes: 5 additions & 2 deletions transformer_engine/pytorch/tensor/mxfp8_tensor.py
@@ -471,7 +471,10 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
scale_invs = [tensor._rowwise_scale_inv, tensor._columnwise_scale_inv]
split_sizes_for_scale = [split_size, split_size // MXFP8_BLOCK_SCALING_SIZE]
# Padding requirements: rowwise dim0 should be divisible by 128, columnwise dim0 should be divisible by 4
padding_multiples = [128, 4]
# NOTE: ROCm/HIP backend uses an unpadded scale-inv layout (see `MXFP8Quantizer.make_empty`),
# so applying the padding here would produce a per-shard scale-inv whose dim-0
# does not match the destination scale-inv allocated for the FSDP2 local shard.
padding_multiples = [128, 4] if not IS_HIP_EXTENSION else [1, 1]
for scale_inv, scale_split_size, pad_multiple in zip(
scale_invs, split_sizes_for_scale, padding_multiples
):
@@ -487,7 +490,7 @@ def __torch_dispatch__(cls, func, types, args, kwargs=None):
)
scale_inv_out = list(scale_inv_out) if scale_inv_out is not None else None
# Pad scale_inv_out to be a multiple of pad_multiple
if scale_inv_out is not None:
if scale_inv_out is not None and pad_multiple > 1:
for idx, split_scale_inv_out in enumerate(scale_inv_out):
current_shape = split_scale_inv_out.shape
pad_dim0 = (pad_multiple - current_shape[0] % pad_multiple) % pad_multiple
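As a worked illustration of the dim-0 padding above, here is a hedged, standalone sketch (the helper name pad_dim0 and the shard shapes are invented for the example; the actual __torch_dispatch__ code pads each split result in place as shown in the diff). On CUDA the rowwise / columnwise multiples are [128, 4]; on ROCm/HIP (IS_HIP_EXTENSION) the unpadded scale-inv layout means pad_multiple is 1 and no padding is applied:

import torch
import torch.nn.functional as F

def pad_dim0(scale_inv: torch.Tensor, pad_multiple: int) -> torch.Tensor:
    """Pad dim 0 of a 2D scale-inv shard up to a multiple of pad_multiple."""
    if pad_multiple <= 1:
        # ROCm/HIP path: the per-shard scale-inv already matches the destination layout.
        return scale_inv
    pad = (pad_multiple - scale_inv.shape[0] % pad_multiple) % pad_multiple
    if pad == 0:
        return scale_inv
    # Pad tuple is (left, right, top, bottom) for a 2D tensor: grow dim 0 at the bottom.
    return F.pad(scale_inv, (0, 0, 0, pad))

rowwise_shard = torch.zeros(96, 32, dtype=torch.uint8)  # made-up shard shape
assert pad_dim0(rowwise_shard, 128).shape[0] == 128      # CUDA rowwise multiple
assert pad_dim0(rowwise_shard, 4).shape[0] == 96         # already a multiple of 4
assert pad_dim0(rowwise_shard, 1).shape[0] == 96         # ROCm/HIP: unchanged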