From 80ad4d9afa51db943a9335fcf2ddcb50177fa803 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Tue, 26 May 2026 15:08:32 +0200 Subject: [PATCH 1/3] [PyTorch Debug] Add scale_inv_std stat and skip NVFP4 layers in LogFp8TensorStats (#2801) - Register scale_inv_std (plus helper variance/numel/sum buffers using Welford reduction) for all FP8 recipes and NVFP4 in add_scale_inv_stats. Population variance keeps std=0 for delayed/current scaling where scale_inv is a single scalar. - Also wire scale_inv_min/max/std for NVFP4 (was previously only FP8 recipes). - LogFp8TensorStats.inspect_tensor now filters bare stats on NVFP4 layers with a warning instead of raising, so dual LogFp8TensorStats + LogNvfp4TensorStats configs work with overlapping (or catch-all) layer regexes. Recipe-prefixed FP8 stats (e.g. mxfp8_mse) are preserved for what-if comparisons. - Numerics test extended to validate scale_inv_min/max/std against torch.std(scale_inv, unbiased=False). Signed-off-by: Pawel Gadzinski --- tests/pytorch/debug/test_log.py | 99 +++++++++--- .../debug/features/log_fp8_tensor_stats.py | 114 ++++++++++---- .../debug/features/log_nvfp4_tensor_stats.py | 32 +++- .../debug/features/utils/stats_computation.py | 148 ++++++++++++++---- 4 files changed, 305 insertions(+), 88 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 055210f93a..7beed1a2e0 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -28,8 +28,8 @@ fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( - return_reason=True +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( + is_fp8_block_scaling_available(return_reason=True) ) nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True) @@ -61,6 +61,7 @@ "underflows%", "scale_inv_min", "scale_inv_max", + "scale_inv_std", "mse", ] @@ -87,7 +88,9 @@ all_stats.append(f"{r}_{stat}{columnwise_postfix}") -all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows% +all_stats.append( + "fp8_delayed_scaling_overflows%" +) # only delayed-scaling supports overflows% @contextlib.contextmanager @@ -221,7 +224,9 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling(): pytest.skip(reason_for_no_fp8_block_scaling) - log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats)) + log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format( + stats=", ".join(bare_stats) + ) with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: recipe_state = RecipeState.create( @@ -248,25 +253,50 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): debug_api.step() dequantized_tensor = quantized_tensor.dequantize() + if hasattr(quantized_tensor, "_scale_inv"): + scale_inv_rowwise = quantized_tensor._scale_inv.float() + else: + scale_inv_rowwise = quantized_tensor._rowwise_scale_inv.float() output = read_log(log_dir) for line in output.splitlines(): if "underflows%" in line: underflows = float(line.split("value=")[1]) expected = ( - ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100 + ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) + / tensor.numel() + * 100 ) assert underflows == pytest.approx(expected.cpu(), abs=1e-4) if "mse" in line: mse = float(line.split("value=")[1]) - expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean") + expected = torch.nn.functional.mse_loss( + dequantized_tensor, tensor, reduction="mean" + ) assert mse == pytest.approx(expected.cpu(), abs=1e-4) if "overflows%" in line: overflows = float(line.split("value=")[1]) expected = ( - (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100 + (abs(dequantized_tensor) > abs(tensor)).sum() + / dequantized_tensor.numel() + * 100 ) assert overflows == pytest.approx(expected.cpu(), abs=1e-4) + # Rowwise scale_inv stats only; logger formats with {:.4f} so abs<1e-4. + if "scale_inv_min" in line and "_columnwise" not in line: + value = float(line.split("value=")[1]) + assert value == pytest.approx( + scale_inv_rowwise.min().cpu().item(), abs=1e-4 + ) + if "scale_inv_max" in line and "_columnwise" not in line: + value = float(line.split("value=")[1]) + assert value == pytest.approx( + scale_inv_rowwise.max().cpu().item(), abs=1e-4 + ) + if "scale_inv_std" in line and "_columnwise" not in line: + value = float(line.split("value=")[1]) + expected = torch.std(scale_inv_rowwise, unbiased=False).cpu().item() + assert value == pytest.approx(expected, abs=1e-4) LOG_HIGH_PRECISION_CONFIG = """ @@ -328,7 +358,9 @@ def test_log_stats_numerics(feature_dirs, tensor_name): output = read_log(log_dir) max_over_orientations = tensor_name in ["activation", "weight"] - max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else "" + max_over_orientations_suffix = ( + "_max_over_orientations" if max_over_orientations else "" + ) # Track which stats were found to ensure all are present found_dims_1 = False @@ -336,8 +368,13 @@ def test_log_stats_numerics(feature_dirs, tensor_name): found_dynamic_range = False for line in output.splitlines(): - if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line: - max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) + if ( + f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" + in line + ): + max_blockwise_dynamic_range_block_size_4_dims_1 = float( + line.split("value=")[1] + ) if max_over_orientations: # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) expected = math.log2(A) - math.log2(B) @@ -349,9 +386,12 @@ def test_log_stats_numerics(feature_dirs, tensor_name): ) found_dims_1 = True elif ( - f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line + f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" + in line ): - max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1]) + max_blockwise_dynamic_range_block_size_4_dims_2 = float( + line.split("value=")[1] + ) # For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows expected = math.log2(A) - math.log2(B) assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx( @@ -403,7 +443,8 @@ def test_log_every_3_or_5_layers(layer, configs_dir, feature_dirs): with open( os.path.join( - temp_dir, "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log" + temp_dir, + "nvdlfw_inspect_statistics_logs/nvdlfw_inspect_globalrank-0.log", ), "r", ) as f: @@ -539,7 +580,9 @@ def test_log_grouped_gemm(feature_dirs): log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats)) with debug_session(log_all_stats_config, feature_dirs) as log_dir: - model = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16) + model = te.GroupedLinear( + 3, 128, 128, name="linear1", params_dtype=torch.bfloat16 + ) inp = torch.randn((1, 128, 128), dtype=torch.bfloat16).cuda() m_splits = [64, 32, 32] with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): @@ -572,7 +615,9 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 1: dims=1, max_over_orientations=False (rowwise only) # Rowwise blocks have uniform values -> dynamic_range should be 0 - stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False) + stat_config = BlockwiseDynamicRangeStat( + block_size=4, dims=1, max_over_orientations=False + ) result = compute_max_blockwise_dynamic_range(tensor, stat_config) assert result.item() == pytest.approx( 0.0, abs=1e-4 @@ -580,7 +625,9 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise) # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) - stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) + stat_config = BlockwiseDynamicRangeStat( + block_size=4, dims=1, max_over_orientations=True + ) result = compute_max_blockwise_dynamic_range(tensor, stat_config) expected = math.log2(A) - math.log2(B) assert result.item() == pytest.approx(expected, abs=1e-4), ( @@ -590,7 +637,9 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 3: dims=2, block_size=4 (4x4 tiles) # 2D blocks span multiple rows -> always have mixed values - stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False) + stat_config = BlockwiseDynamicRangeStat( + block_size=4, dims=2, max_over_orientations=False + ) result = compute_max_blockwise_dynamic_range(tensor, stat_config) expected = math.log2(A) - math.log2(B) assert result.item() == pytest.approx(expected, abs=1e-4), ( @@ -601,7 +650,9 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 4: Different block size # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon] # So max=A, min=epsilon (not B anymore) - stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True) + stat_config = BlockwiseDynamicRangeStat( + block_size=8, dims=1, max_over_orientations=True + ) result = compute_max_blockwise_dynamic_range(tensor, stat_config) expected = math.log2(A) - math.log2(epsilon) # min is epsilon, not B assert result.item() == pytest.approx( @@ -610,7 +661,9 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 5: Tensor with all uniform values -> dynamic_range should be 0 uniform_tensor = torch.ones(64, 64).cuda() * 42.0 - stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) + stat_config = BlockwiseDynamicRangeStat( + block_size=4, dims=1, max_over_orientations=True + ) result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config) assert result.item() == pytest.approx( 0.0, abs=1e-4 @@ -629,7 +682,9 @@ def test_compute_max_blockwise_dynamic_range_direct(): ).cuda() # Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100) - stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False) + stat_config = BlockwiseDynamicRangeStat( + block_size=2, dims=2, max_over_orientations=False + ) result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config) # Reshape to 3D [2, 2, 4] and compute - should give same result if flattening is correct @@ -715,7 +770,9 @@ def test_dump_tensors_sanity(feature_dirs): ), f"Expected QuantizedTensor, got {type(data['quantized'])}" # Verify tensor shapes and values match - assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" + assert ( + data["high_precision"].shape == tensor.shape + ), "high_precision shape mismatch" assert torch.equal( data["high_precision"], tensor ), "high_precision tensor values do not match original tensor" diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index d26f9ef7f6..fa0770a46e 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -10,19 +10,26 @@ import torch import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats +from nvdlfw_inspect.debug_features.log_tensor_stats import ( + LogTensorStats as BaseLogTensorStats, +) from nvdlfw_inspect.registry import Registry, api_method import transformer_engine_torch as tex from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS -from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter +from transformer_engine.debug.features.utils import ( + get_reduction_params, + next_enabled_iter, +) from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.float8_tensor import ( Float8Quantizer, Float8CurrentScalingQuantizer, ) from transformer_engine.pytorch.tensor.mxfp8_tensor import MXFP8Quantizer -from transformer_engine.pytorch.tensor.float8_blockwise_tensor import Float8BlockQuantizer +from transformer_engine.pytorch.tensor.float8_blockwise_tensor import ( + Float8BlockQuantizer, +) try: from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer @@ -33,7 +40,12 @@ NVFP4Quantizer = None -ALL_RECIPE_NAMES = ["fp8_delayed_scaling", "fp8_current_scaling", "mxfp8", "fp8_block_scaling"] +ALL_RECIPE_NAMES = [ + "fp8_delayed_scaling", + "fp8_current_scaling", + "mxfp8", + "fp8_block_scaling", +] def _get_recipe_name(quantizer: Optional[Quantizer]): @@ -57,7 +69,10 @@ def _get_new_quantizer(recipe_name, fp8_dtype): return Float8BlockQuantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) if recipe_name == "fp8_current_scaling": return Float8CurrentScalingQuantizer( - fp8_dtype=fp8_dtype, device=torch.device("cuda"), rowwise=True, columnwise=True + fp8_dtype=fp8_dtype, + device=torch.device("cuda"), + rowwise=True, + columnwise=True, ) if recipe_name == "mxfp8": return MXFP8Quantizer(fp8_dtype=fp8_dtype, rowwise=True, columnwise=True) @@ -119,10 +134,13 @@ class LogFp8TensorStats(BaseLogTensorStats): - overflows% - percentage of elements of tensor that were clipped to the max/min value of the FP8 range - supported only for fp8_delayed_scaling, - scale_inv_min - minimum of the inverse of the scaling factors, - scale_inv_max - maximum of the inverse of the scaling factors, + - scale_inv_std - population standard deviation of the inverse of the scaling factors; + useful for spotting clipping that min/max alone can miss (degenerate to 0 for + fp8_delayed_scaling / fp8_current_scaling since those use a single scalar scale). - mse - mean squared error of the quantized tensor and the original tensor = sum((quantized_tensor - original_tensor)**2) / num_elements, When collecting stats for the weight tensor with FP8 model parameters enabled, - only "scale_inv_min" and "scale_inv_max" are available. + only "scale_inv_min", "scale_inv_max" and "scale_inv_std" are available. All other statistics require access to the high precision tensor. tensors/tensors_struct: List[str] @@ -169,7 +187,9 @@ def check_if_stat_is_supported( columnwise = stat.endswith("_columnwise") if columnwise: stat = stat[: -len("_columnwise")] - recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe) + recipe_from_stat, _ = self.get_recipe_from_stat( + stat, default_recipe=current_recipe + ) stat_without_recipe = stat.replace(recipe_from_stat + "_", "") need_high_precision_tensor_stats = ["underflows%", "overflows%", "mse"] @@ -189,34 +209,44 @@ def check_if_stat_is_supported( ) if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: - raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") - - # Block any NVFP4 stats in LogFp8TensorStats (FP8-specific logic won't work) - # But allow recipe-prefixed FP8 stats like "mxfp8_underflows%" even with NVFP4 quantizer - if recipe_from_stat == "nvfp4": raise ValueError( - f"[NVTORCH INSPECT ERROR] Cannot compute NVFP4 stats '{stat}' in LogFp8TensorStats." - " FP8-specific statistics do not work with NVFP4. Use LogNvfp4TensorStats for" - " NVFP4-specific stats, or use FP8 recipe-prefixed stats (e.g.," - " 'mxfp8_underflows%', 'fp8_block_scaling_mse') for what-if FP8 comparisons." + f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}" ) - if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: + # NVFP4-resolved stats are filtered out before this point in inspect_tensor(). + assert recipe_from_stat != "nvfp4" + + if ( + recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] + and columnwise + ): raise ValueError( f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" " fp8_delayed_scaling and fp8_current_scaling." ) - if recipe_from_stat == "fp8_delayed_scaling" and stat_without_recipe == "overflows%": + if ( + recipe_from_stat == "fp8_delayed_scaling" + and stat_without_recipe == "overflows%" + ): return True - if recipe_from_stat in ["fp8_block_scaling"] and torch.cuda.get_device_capability()[0] < 9: + if ( + recipe_from_stat in ["fp8_block_scaling"] + and torch.cuda.get_device_capability()[0] < 9 + ): raise ValueError(f"Stat {stat} needs Hopper or later GPU.") if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10: raise ValueError(f"Stat {stat} needs Blackwell or later GPU.") - supported_stats = ["underflows%", "scale_inv_min", "scale_inv_max", "mse"] + supported_stats = [ + "underflows%", + "scale_inv_min", + "scale_inv_max", + "scale_inv_std", + "mse", + ] if stat_without_recipe not in supported_stats: raise ValueError( f"Stat {stat} contains an unsupported stat name: {stat_without_recipe}" @@ -252,9 +282,14 @@ def update_aux_dict( Needs to clean after usage, because it possibly change the usage of the quantized tensor. """ fp8_dtype = tex.DType.kFloat8E4M3 - if recipe_name in ["fp8_delayed_scaling", "fp8_current_scaling", "fp8_block_scaling"]: + if recipe_name in [ + "fp8_delayed_scaling", + "fp8_current_scaling", + "fp8_block_scaling", + ]: assert isinstance( - quantizer, (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer) + quantizer, + (Float8Quantizer, Float8CurrentScalingQuantizer, Float8BlockQuantizer), ) fp8_dtype = quantizer.dtype @@ -280,7 +315,8 @@ def update_aux_dict( finally: if isinstance(quantized_tensor, QuantizedTensor): quantized_tensor.update_usage( - rowwise_usage=old_rowwise_usage, columnwise_usage=old_columnwise_usage + rowwise_usage=old_rowwise_usage, + columnwise_usage=old_columnwise_usage, ) @api_method @@ -338,6 +374,27 @@ def inspect_tensor( recipe_name = _get_recipe_name(quantizer) + # If the layer uses NVFP4, drop bare stats (which would target the NVFP4 + # recipe that LogFp8TensorStats can't handle) but keep stats explicitly + # prefixed with an FP8 recipe (e.g. "mxfp8_mse") for what-if FP8 comparison. + if _nvfp4_available and isinstance(quantizer, NVFP4Quantizer): + kept_stats, dropped_stats = [], [] + for stat in config["stats"]: + if any(r in stat for r in ALL_RECIPE_NAMES): + kept_stats.append(stat) + else: + dropped_stats.append(stat) + if dropped_stats: + warnings.warn( + f"[LogFp8TensorStats] Skipping stats {dropped_stats} for layer " + f"'{layer_name}', tensor '{tensor_name}': layer uses NVFP4. Use " + "LogNvfp4TensorStats for NVFP4 stats, or prefix stats with an FP8 " + "recipe name (e.g. 'mxfp8_mse') for what-if FP8 comparisons." + ) + if not kept_stats: + return + config = {**config, "stats": kept_stats} + for stat in config["stats"]: self.check_if_stat_is_supported( stat, recipe_name, high_precision_tensor_provided=tensor is not None @@ -347,7 +404,9 @@ def inspect_tensor( end_step = config.get("end_step", None) start_end_list = config.get("start_end_list", None) if start_end_list is not None: - start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + start_end_list = tuple( + tuple(int(x) for x in interval) for interval in start_end_list + ) options = ( start_step, @@ -356,8 +415,8 @@ def inspect_tensor( "fp8", ) - skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group, tp_size + skip_reduction, reduction_group, reduce_within_microbatch = ( + get_reduction_params(tensor_name, tp_group, tp_size) ) STATS_BUFFERS.try_add_buffer( @@ -370,7 +429,8 @@ def inspect_tensor( ) recipes_in_stats = [ - self.get_recipe_from_stat(stat, default_recipe=recipe_name) for stat in config["stats"] + self.get_recipe_from_stat(stat, default_recipe=recipe_name) + for stat in config["stats"] ] with self.update_aux_dict( diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 8a76f4edcf..38c1578f9a 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -11,14 +11,21 @@ import torch import nvdlfw_inspect.api as debug_api -from nvdlfw_inspect.debug_features.log_tensor_stats import LogTensorStats as BaseLogTensorStats +from nvdlfw_inspect.debug_features.log_tensor_stats import ( + LogTensorStats as BaseLogTensorStats, +) from nvdlfw_inspect.registry import Registry, api_method from transformer_engine.debug.features.utils.stats_buffer import STATS_BUFFERS from transformer_engine.pytorch.tensor import Quantizer, QuantizedTensor from transformer_engine.pytorch.tensor.nvfp4_tensor import NVFP4Quantizer -from transformer_engine.debug.features.utils import get_reduction_params, next_enabled_iter -from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import NVFP4TensorStorage +from transformer_engine.debug.features.utils import ( + get_reduction_params, + next_enabled_iter, +) +from transformer_engine.pytorch.tensor.storage.nvfp4_tensor_storage import ( + NVFP4TensorStorage, +) @Registry.register_feature(namespace="transformer_engine") @@ -45,6 +52,10 @@ class LogNvfp4TensorStats(BaseLogTensorStats): List of statistics to collect. Available stats: - underflows% - percentage of non-zero elements clipped to 0 (from packed FP4 data) - mse - mean squared error = sum((quantized_tensor - original_tensor)**2) / num_elements + - scale_inv_min - minimum of the inverse of the scaling factors + - scale_inv_max - maximum of the inverse of the scaling factors + - scale_inv_std - population standard deviation of the inverse of the scaling factors; + useful for spotting clipping that min/max alone can miss tensors/tensors_struct: List[str] list of tensors to log @@ -85,13 +96,18 @@ class LogNvfp4TensorStats(BaseLogTensorStats): def check_if_stat_is_supported(self, stat: str): """Returns True if stat is supported, raises ValueError otherwise.""" + bare = stat[: -len("_columnwise")] if stat.endswith("_columnwise") else stat supported_stats = [ "underflows%", "mse", + "scale_inv_min", + "scale_inv_max", + "scale_inv_std", ] - if stat not in supported_stats: + if bare not in supported_stats: raise ValueError( f"Stat {stat} is not supported for NVFP4. Supported stats: {supported_stats}" + " (any of these may take an optional '_columnwise' suffix)" ) return True @@ -190,7 +206,9 @@ def inspect_tensor( end_step = config.get("end_step", None) start_end_list = config.get("start_end_list", None) if start_end_list is not None: - start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) + start_end_list = tuple( + tuple(int(x) for x in interval) for interval in start_end_list + ) options = ( start_step, @@ -199,8 +217,8 @@ def inspect_tensor( "nvfp4", ) - skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( - tensor_name, tp_group, tp_size + skip_reduction, reduction_group, reduce_within_microbatch = ( + get_reduction_params(tensor_name, tp_group, tp_size) ) # Add nvfp4_ prefix to all stats for internal use diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index b0002ffee6..98625e4a5d 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -16,7 +16,9 @@ class BlockwiseDynamicRangeStat( - namedtuple("BlockwiseDynamicRangeStat", ["block_size", "dims", "max_over_orientations"]) + namedtuple( + "BlockwiseDynamicRangeStat", ["block_size", "dims", "max_over_orientations"] + ) ): """Named tuple representing a blockwise dynamic range statistic configuration.""" @@ -99,7 +101,9 @@ def _compute_for_one_orientation(tensor): .reshape(-1, block_size, block_size) ) per_block_amax = tensor.amax(dim=(1, 2)) - per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2)) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin( + dim=(1, 2) + ) # Identify blocks that contain any non-zero element nonzero_blocks = per_block_amax != 0 @@ -115,7 +119,9 @@ def _compute_for_one_orientation(tensor): if max_over_orientations: return max( _compute_for_one_orientation(tensor_2d), # Rowwise orientation - _compute_for_one_orientation(tensor_2d.transpose(-2, -1)), # Columnwise orientation + _compute_for_one_orientation( + tensor_2d.transpose(-2, -1) + ), # Columnwise orientation ) return _compute_for_one_orientation(tensor_2d) @@ -125,7 +131,9 @@ def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) means = sums / numels - var = torch.sum(numels * (variances - torch.pow((means - mean), 2))) / torch.sum(numels) + var = torch.sum(numels * (variances - torch.pow((means - mean), 2))) / torch.sum( + numels + ) return var @@ -207,15 +215,26 @@ def _get(buffers, stat_name): } STATS = { - "min": (lambda x, aux_dict: torch.min(x), lambda buffers: min(_get(buffers, "min"))), - "max": (lambda x, aux_dict: torch.max(x), lambda buffers: max(_get(buffers, "max"))), - "sum": (lambda x, aux_dict: torch.sum(x), lambda buffers: sum(_get(buffers, "sum"))), + "min": ( + lambda x, aux_dict: torch.min(x), + lambda buffers: min(_get(buffers, "min")), + ), + "max": ( + lambda x, aux_dict: torch.max(x), + lambda buffers: max(_get(buffers, "max")), + ), + "sum": ( + lambda x, aux_dict: torch.sum(x), + lambda buffers: sum(_get(buffers, "sum")), + ), "mean": ( lambda x, aux_dict: torch.mean(x), lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")), ), "numel": ( - lambda x, aux_dict: x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel(), + lambda x, aux_dict: ( + x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel() + ), lambda buffers: sum(_get(buffers, "numel")), ), "l1_norm": ( @@ -236,7 +255,10 @@ def _get(buffers, stat_name): _get(buffers, "variance"), _get(buffers, "numel"), _get(buffers, "sum") ), ), - "cur_amax": (lambda x, aux_dict: x.abs().max(), lambda buffers: max(_get(buffers, "cur_amax"))), + "cur_amax": ( + lambda x, aux_dict: x.abs().max(), + lambda buffers: max(_get(buffers, "cur_amax")), + ), "dynamic_range_top": ( lambda x, aux_dict: _compute_dynamic_range_top(x), lambda buffers: max(_get(buffers, "dynamic_range_top")), @@ -252,7 +274,8 @@ def _get(buffers, stat_name): ), ), "dynamic_range": ( - lambda x, aux_dict: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x), + lambda x, aux_dict: _compute_dynamic_range_top(x) + - _compute_dynamic_range_bottom(x), lambda buffers: max(_get(buffers, "dynamic_range_top")) - min(_get(buffers, "dynamic_range_bottom")), ), @@ -280,7 +303,9 @@ def _get(buffers, stat_name): lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]) / x.numel() * 100, - lambda buffers: 100 * sum(_get(buffers, "overflows_num")) / sum(_get(buffers, "numel")), + lambda buffers: 100 + * sum(_get(buffers, "overflows_num")) + / sum(_get(buffers, "numel")), ), } @@ -290,7 +315,9 @@ def _get(buffers, stat_name): def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor: """Count the number of non-zero elements in the fp8 data.""" fp8_data = fp8_data.view(dtype=torch.uint8) - zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8) + zero_vals = torch.tensor( + [0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8 + ) return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum() @@ -300,7 +327,9 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False): # Stat names stat_num = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows_num{columnwise_suffix}" - stat_pct = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}" + stat_pct = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}" + ) stats_to_num[stat_num] = len(stats_to_num) stats_to_num[stat_pct] = len(stats_to_num) @@ -335,11 +364,13 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False): def add_scale_inv_stats(recipe_name: str, columnwise: bool = False): - """Register *both* scale-inv min and max stats for a given recipe. + """Register scale-inv min/max/std stats for a given recipe. - This replaces the earlier separate helpers and avoids duplicated boilerplate. + The std uses Welford's algorithm to combine partial variances across + microbatches/ranks, so helper buffers for variance/numel/sum are also + registered. Population variance (unbiased=False) is used so single-element + scale_inv tensors (delayed/current scaling) yield std=0 rather than NaN. """ - # Determine which attribute holds the scale-inverse tensor. def get_scale_inv(quantized_tensor, columnwise): if hasattr(quantized_tensor, "_scale_inv"): @@ -348,31 +379,76 @@ def get_scale_inv(quantized_tensor, columnwise): return getattr(quantized_tensor, "_columnwise_scale_inv") return getattr(quantized_tensor, "_rowwise_scale_inv") + def _prefix(): + return f"{recipe_name}{'_' if recipe_name != '' else ''}" + columnwise_suffix = "_columnwise" if columnwise else "" - # Prepare stat names. - stat_name_min = ( - f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_min{columnwise_suffix}" - ) - stat_name_max = ( - f"{recipe_name}{'_' if recipe_name != '' else ''}scale_inv_max{columnwise_suffix}" - ) + stat_name_min = f"{_prefix()}scale_inv_min{columnwise_suffix}" + stat_name_max = f"{_prefix()}scale_inv_max{columnwise_suffix}" + stat_name_std = f"{_prefix()}scale_inv_std{columnwise_suffix}" + stat_name_var = f"{_prefix()}scale_inv_variance{columnwise_suffix}" + stat_name_numel = f"{_prefix()}scale_inv_numel{columnwise_suffix}" + stat_name_sum = f"{_prefix()}scale_inv_sum{columnwise_suffix}" # Assign indices in `stats_to_num` (order matters — keep insertion order deterministic). - stats_to_num[stat_name_min] = len(stats_to_num) - stats_to_num[stat_name_max] = len(stats_to_num) + for name in ( + stat_name_min, + stat_name_max, + stat_name_std, + stat_name_var, + stat_name_numel, + stat_name_sum, + ): + stats_to_num[name] = len(stats_to_num) # Capture the attribute name inside lambdas via default args to avoid late binding. STATS[stat_name_min] = ( - lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(), + lambda x, aux_dict, _col=columnwise: get_scale_inv( + aux_dict[recipe_name], _col + ).min(), lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)), ) STATS[stat_name_max] = ( - lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(), + lambda x, aux_dict, _col=columnwise: get_scale_inv( + aux_dict[recipe_name], _col + ).max(), lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)), ) + STATS[stat_name_var] = ( + lambda x, aux_dict, _col=columnwise: torch.var( + get_scale_inv(aux_dict[recipe_name], _col).float(), unbiased=False + ), + lambda buffers, _sv=stat_name_var, _sn=stat_name_numel, _ss=stat_name_sum: compute_variance( + _get(buffers, _sv), _get(buffers, _sn), _get(buffers, _ss) + ), + ) + STATS[stat_name_numel] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv( + aux_dict[recipe_name], _col + ).numel(), + lambda buffers, _sn=stat_name_numel: sum(_get(buffers, _sn)), + ) + STATS[stat_name_sum] = ( + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col) + .float() + .sum(), + lambda buffers, _ss=stat_name_sum: sum(_get(buffers, _ss)), + ) + STATS[stat_name_std] = ( + lambda x, aux_dict, _col=columnwise: torch.std( + get_scale_inv(aux_dict[recipe_name], _col).float(), unbiased=False + ), + lambda buffers, _sv=stat_name_var, _sn=stat_name_numel, _ss=stat_name_sum: compute_std( + _get(buffers, _sv), _get(buffers, _sn), _get(buffers, _ss) + ), + ) DEPENDENCIES[stat_name_min] = {stat_name_min} DEPENDENCIES[stat_name_max] = {stat_name_max} + DEPENDENCIES[stat_name_numel] = {stat_name_numel} + DEPENDENCIES[stat_name_sum] = {stat_name_sum} + DEPENDENCIES[stat_name_var] = {stat_name_var, stat_name_numel, stat_name_sum} + DEPENDENCIES[stat_name_std] = {stat_name_var, stat_name_numel, stat_name_sum} def add_mse_stats(recipe_name: str, columnwise: bool = False): @@ -380,20 +456,22 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): columnwise_suffix = "_columnwise" if columnwise else "" stat_mse = f"{recipe_name}{'_' if recipe_name != '' else ''}mse{columnwise_suffix}" - stat_err = ( - f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}" - ) + stat_err = f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}" stats_to_num[stat_mse] = len(stats_to_num) stats_to_num[stat_err] = len(stats_to_num) STATS[stat_mse] = ( - lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="mean"), + lambda x, aux_dict: F.mse_loss( + x, aux_dict[recipe_name].dequantize(), reduction="mean" + ), lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)) / sum(_get(buffers, "numel")), ) STATS[stat_err] = ( - lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="sum"), + lambda x, aux_dict: F.mse_loss( + x, aux_dict[recipe_name].dequantize(), reduction="sum" + ), lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)), ) @@ -425,7 +503,9 @@ def add_max_blockwise_dynamic_range_stats( DEPENDENCIES[stat_key] = {stat_key} STATS[stat_key] = ( - lambda x, aux_dict, _stat_key=stat_key: compute_max_blockwise_dynamic_range(x, _stat_key), + lambda x, aux_dict, _stat_key=stat_key: compute_max_blockwise_dynamic_range( + x, _stat_key + ), lambda buffers, _stat_key=stat_key: max(_get(buffers, _stat_key)), ) @@ -505,3 +585,5 @@ def add_nvfp4_underflows_stats(): # Register NVFP4 stats add_nvfp4_underflows_stats() add_mse_stats("nvfp4") # Reuse existing MSE function +for _columnwise in [True, False]: + add_scale_inv_stats("nvfp4", _columnwise) From 441a793b35b5ae18100a781abfb2ec28f390caef Mon Sep 17 00:00:00 2001 From: "pre-commit-ci[bot]" <66853113+pre-commit-ci[bot]@users.noreply.github.com> Date: Tue, 26 May 2026 13:13:17 +0000 Subject: [PATCH 2/3] [pre-commit.ci] auto fixes from pre-commit.com hooks for more information, see https://pre-commit.ci --- tests/pytorch/debug/test_log.py | 84 +++++-------------- .../debug/features/log_fp8_tensor_stats.py | 34 ++------ .../debug/features/log_nvfp4_tensor_stats.py | 8 +- .../debug/features/utils/stats_computation.py | 63 ++++---------- 4 files changed, 52 insertions(+), 137 deletions(-) diff --git a/tests/pytorch/debug/test_log.py b/tests/pytorch/debug/test_log.py index 7beed1a2e0..14a087da5a 100644 --- a/tests/pytorch/debug/test_log.py +++ b/tests/pytorch/debug/test_log.py @@ -28,8 +28,8 @@ fp8_available, reason_for_no_fp8 = is_fp8_available(return_reason=True) mxfp8_available, reason_for_no_mxfp8 = is_mxfp8_available(return_reason=True) -fp8_block_scaling_available, reason_for_no_fp8_block_scaling = ( - is_fp8_block_scaling_available(return_reason=True) +fp8_block_scaling_available, reason_for_no_fp8_block_scaling = is_fp8_block_scaling_available( + return_reason=True ) nvfp4_available, reason_for_no_nvfp4 = is_nvfp4_available(return_reason=True) @@ -88,9 +88,7 @@ all_stats.append(f"{r}_{stat}{columnwise_postfix}") -all_stats.append( - "fp8_delayed_scaling_overflows%" -) # only delayed-scaling supports overflows% +all_stats.append("fp8_delayed_scaling_overflows%") # only delayed-scaling supports overflows% @contextlib.contextmanager @@ -224,9 +222,7 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): if not fp8_block_scaling_available and fp8_recipe == recipe.Float8BlockScaling(): pytest.skip(reason_for_no_fp8_block_scaling) - log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format( - stats=", ".join(bare_stats) - ) + log_only_bare_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(bare_stats)) with debug_session(log_only_bare_stats_config, feature_dirs) as log_dir: recipe_state = RecipeState.create( @@ -263,36 +259,26 @@ def test_log_quantized_stats_numerics(fp8_recipe, feature_dirs): if "underflows%" in line: underflows = float(line.split("value=")[1]) expected = ( - ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) - / tensor.numel() - * 100 + ((dequantized_tensor == 0).sum() - (tensor == 0).sum()) / tensor.numel() * 100 ) assert underflows == pytest.approx(expected.cpu(), abs=1e-4) if "mse" in line: mse = float(line.split("value=")[1]) - expected = torch.nn.functional.mse_loss( - dequantized_tensor, tensor, reduction="mean" - ) + expected = torch.nn.functional.mse_loss(dequantized_tensor, tensor, reduction="mean") assert mse == pytest.approx(expected.cpu(), abs=1e-4) if "overflows%" in line: overflows = float(line.split("value=")[1]) expected = ( - (abs(dequantized_tensor) > abs(tensor)).sum() - / dequantized_tensor.numel() - * 100 + (abs(dequantized_tensor) > abs(tensor)).sum() / dequantized_tensor.numel() * 100 ) assert overflows == pytest.approx(expected.cpu(), abs=1e-4) # Rowwise scale_inv stats only; logger formats with {:.4f} so abs<1e-4. if "scale_inv_min" in line and "_columnwise" not in line: value = float(line.split("value=")[1]) - assert value == pytest.approx( - scale_inv_rowwise.min().cpu().item(), abs=1e-4 - ) + assert value == pytest.approx(scale_inv_rowwise.min().cpu().item(), abs=1e-4) if "scale_inv_max" in line and "_columnwise" not in line: value = float(line.split("value=")[1]) - assert value == pytest.approx( - scale_inv_rowwise.max().cpu().item(), abs=1e-4 - ) + assert value == pytest.approx(scale_inv_rowwise.max().cpu().item(), abs=1e-4) if "scale_inv_std" in line and "_columnwise" not in line: value = float(line.split("value=")[1]) expected = torch.std(scale_inv_rowwise, unbiased=False).cpu().item() @@ -358,9 +344,7 @@ def test_log_stats_numerics(feature_dirs, tensor_name): output = read_log(log_dir) max_over_orientations = tensor_name in ["activation", "weight"] - max_over_orientations_suffix = ( - "_max_over_orientations" if max_over_orientations else "" - ) + max_over_orientations_suffix = "_max_over_orientations" if max_over_orientations else "" # Track which stats were found to ensure all are present found_dims_1 = False @@ -368,13 +352,8 @@ def test_log_stats_numerics(feature_dirs, tensor_name): found_dynamic_range = False for line in output.splitlines(): - if ( - f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" - in line - ): - max_blockwise_dynamic_range_block_size_4_dims_1 = float( - line.split("value=")[1] - ) + if f"max_blockwise_dynamic_range_block_size_4_dims_1{max_over_orientations_suffix}" in line: + max_blockwise_dynamic_range_block_size_4_dims_1 = float(line.split("value=")[1]) if max_over_orientations: # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) expected = math.log2(A) - math.log2(B) @@ -386,12 +365,9 @@ def test_log_stats_numerics(feature_dirs, tensor_name): ) found_dims_1 = True elif ( - f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" - in line + f"max_blockwise_dynamic_range_block_size_4_dims_2{max_over_orientations_suffix}" in line ): - max_blockwise_dynamic_range_block_size_4_dims_2 = float( - line.split("value=")[1] - ) + max_blockwise_dynamic_range_block_size_4_dims_2 = float(line.split("value=")[1]) # For 2D blocks (4x4 tiles), blocks always contain mixed values from different rows expected = math.log2(A) - math.log2(B) assert max_blockwise_dynamic_range_block_size_4_dims_2 == pytest.approx( @@ -580,9 +556,7 @@ def test_log_grouped_gemm(feature_dirs): log_all_stats_config = LOG_QUANTIZED_CONFIG_BASE.format(stats=", ".join(all_stats)) with debug_session(log_all_stats_config, feature_dirs) as log_dir: - model = te.GroupedLinear( - 3, 128, 128, name="linear1", params_dtype=torch.bfloat16 - ) + model = te.GroupedLinear(3, 128, 128, name="linear1", params_dtype=torch.bfloat16) inp = torch.randn((1, 128, 128), dtype=torch.bfloat16).cuda() m_splits = [64, 32, 32] with te.fp8_autocast(fp8_recipe=recipe.DelayedScaling()): @@ -615,9 +589,7 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 1: dims=1, max_over_orientations=False (rowwise only) # Rowwise blocks have uniform values -> dynamic_range should be 0 - stat_config = BlockwiseDynamicRangeStat( - block_size=4, dims=1, max_over_orientations=False - ) + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=False) result = compute_max_blockwise_dynamic_range(tensor, stat_config) assert result.item() == pytest.approx( 0.0, abs=1e-4 @@ -625,9 +597,7 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 2: dims=1, max_over_orientations=True (max of rowwise and columnwise) # Columnwise blocks have mixed values [A, B, B, B] -> dynamic_range = log2(A/B) - stat_config = BlockwiseDynamicRangeStat( - block_size=4, dims=1, max_over_orientations=True - ) + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) result = compute_max_blockwise_dynamic_range(tensor, stat_config) expected = math.log2(A) - math.log2(B) assert result.item() == pytest.approx(expected, abs=1e-4), ( @@ -637,9 +607,7 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 3: dims=2, block_size=4 (4x4 tiles) # 2D blocks span multiple rows -> always have mixed values - stat_config = BlockwiseDynamicRangeStat( - block_size=4, dims=2, max_over_orientations=False - ) + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=2, max_over_orientations=False) result = compute_max_blockwise_dynamic_range(tensor, stat_config) expected = math.log2(A) - math.log2(B) assert result.item() == pytest.approx(expected, abs=1e-4), ( @@ -650,9 +618,7 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 4: Different block size # With block_size=8, columnwise blocks contain [A, B, B, B, epsilon, epsilon, epsilon, epsilon] # So max=A, min=epsilon (not B anymore) - stat_config = BlockwiseDynamicRangeStat( - block_size=8, dims=1, max_over_orientations=True - ) + stat_config = BlockwiseDynamicRangeStat(block_size=8, dims=1, max_over_orientations=True) result = compute_max_blockwise_dynamic_range(tensor, stat_config) expected = math.log2(A) - math.log2(epsilon) # min is epsilon, not B assert result.item() == pytest.approx( @@ -661,9 +627,7 @@ def test_compute_max_blockwise_dynamic_range_direct(): # Test 5: Tensor with all uniform values -> dynamic_range should be 0 uniform_tensor = torch.ones(64, 64).cuda() * 42.0 - stat_config = BlockwiseDynamicRangeStat( - block_size=4, dims=1, max_over_orientations=True - ) + stat_config = BlockwiseDynamicRangeStat(block_size=4, dims=1, max_over_orientations=True) result = compute_max_blockwise_dynamic_range(uniform_tensor, stat_config) assert result.item() == pytest.approx( 0.0, abs=1e-4 @@ -682,9 +646,7 @@ def test_compute_max_blockwise_dynamic_range_direct(): ).cuda() # Compute on 2D tensor: 4 blocks of 2x2, max range is log2(1000/100) - stat_config = BlockwiseDynamicRangeStat( - block_size=2, dims=2, max_over_orientations=False - ) + stat_config = BlockwiseDynamicRangeStat(block_size=2, dims=2, max_over_orientations=False) result_2d = compute_max_blockwise_dynamic_range(tensor_2d, stat_config) # Reshape to 3D [2, 2, 4] and compute - should give same result if flattening is correct @@ -770,9 +732,7 @@ def test_dump_tensors_sanity(feature_dirs): ), f"Expected QuantizedTensor, got {type(data['quantized'])}" # Verify tensor shapes and values match - assert ( - data["high_precision"].shape == tensor.shape - ), "high_precision shape mismatch" + assert data["high_precision"].shape == tensor.shape, "high_precision shape mismatch" assert torch.equal( data["high_precision"], tensor ), "high_precision tensor values do not match original tensor" diff --git a/transformer_engine/debug/features/log_fp8_tensor_stats.py b/transformer_engine/debug/features/log_fp8_tensor_stats.py index fa0770a46e..f453b2a36a 100644 --- a/transformer_engine/debug/features/log_fp8_tensor_stats.py +++ b/transformer_engine/debug/features/log_fp8_tensor_stats.py @@ -187,9 +187,7 @@ def check_if_stat_is_supported( columnwise = stat.endswith("_columnwise") if columnwise: stat = stat[: -len("_columnwise")] - recipe_from_stat, _ = self.get_recipe_from_stat( - stat, default_recipe=current_recipe - ) + recipe_from_stat, _ = self.get_recipe_from_stat(stat, default_recipe=current_recipe) stat_without_recipe = stat.replace(recipe_from_stat + "_", "") need_high_precision_tensor_stats = ["underflows%", "overflows%", "mse"] @@ -209,32 +207,21 @@ def check_if_stat_is_supported( ) if recipe_from_stat != "" and recipe_from_stat not in ALL_RECIPE_NAMES: - raise ValueError( - f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}" - ) + raise ValueError(f"Stat {stat} contains an unsupported recipe name: {recipe_from_stat}") # NVFP4-resolved stats are filtered out before this point in inspect_tensor(). assert recipe_from_stat != "nvfp4" - if ( - recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] - and columnwise - ): + if recipe_from_stat in ["fp8_delayed_scaling", "fp8_current_scaling"] and columnwise: raise ValueError( f"Stat {stat} is not supported. Columnwise tensor statistics are not supported for" " fp8_delayed_scaling and fp8_current_scaling." ) - if ( - recipe_from_stat == "fp8_delayed_scaling" - and stat_without_recipe == "overflows%" - ): + if recipe_from_stat == "fp8_delayed_scaling" and stat_without_recipe == "overflows%": return True - if ( - recipe_from_stat in ["fp8_block_scaling"] - and torch.cuda.get_device_capability()[0] < 9 - ): + if recipe_from_stat in ["fp8_block_scaling"] and torch.cuda.get_device_capability()[0] < 9: raise ValueError(f"Stat {stat} needs Hopper or later GPU.") if recipe_from_stat == "mxfp8" and torch.cuda.get_device_capability()[0] < 10: @@ -404,9 +391,7 @@ def inspect_tensor( end_step = config.get("end_step", None) start_end_list = config.get("start_end_list", None) if start_end_list is not None: - start_end_list = tuple( - tuple(int(x) for x in interval) for interval in start_end_list - ) + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) options = ( start_step, @@ -415,8 +400,8 @@ def inspect_tensor( "fp8", ) - skip_reduction, reduction_group, reduce_within_microbatch = ( - get_reduction_params(tensor_name, tp_group, tp_size) + skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group, tp_size ) STATS_BUFFERS.try_add_buffer( @@ -429,8 +414,7 @@ def inspect_tensor( ) recipes_in_stats = [ - self.get_recipe_from_stat(stat, default_recipe=recipe_name) - for stat in config["stats"] + self.get_recipe_from_stat(stat, default_recipe=recipe_name) for stat in config["stats"] ] with self.update_aux_dict( diff --git a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py index 38c1578f9a..848dfa8ab7 100644 --- a/transformer_engine/debug/features/log_nvfp4_tensor_stats.py +++ b/transformer_engine/debug/features/log_nvfp4_tensor_stats.py @@ -206,9 +206,7 @@ def inspect_tensor( end_step = config.get("end_step", None) start_end_list = config.get("start_end_list", None) if start_end_list is not None: - start_end_list = tuple( - tuple(int(x) for x in interval) for interval in start_end_list - ) + start_end_list = tuple(tuple(int(x) for x in interval) for interval in start_end_list) options = ( start_step, @@ -217,8 +215,8 @@ def inspect_tensor( "nvfp4", ) - skip_reduction, reduction_group, reduce_within_microbatch = ( - get_reduction_params(tensor_name, tp_group, tp_size) + skip_reduction, reduction_group, reduce_within_microbatch = get_reduction_params( + tensor_name, tp_group, tp_size ) # Add nvfp4_ prefix to all stats for internal use diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 98625e4a5d..80c3041f41 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -16,9 +16,7 @@ class BlockwiseDynamicRangeStat( - namedtuple( - "BlockwiseDynamicRangeStat", ["block_size", "dims", "max_over_orientations"] - ) + namedtuple("BlockwiseDynamicRangeStat", ["block_size", "dims", "max_over_orientations"]) ): """Named tuple representing a blockwise dynamic range statistic configuration.""" @@ -101,9 +99,7 @@ def _compute_for_one_orientation(tensor): .reshape(-1, block_size, block_size) ) per_block_amax = tensor.amax(dim=(1, 2)) - per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin( - dim=(1, 2) - ) + per_block_amin = tensor.masked_fill(tensor == 0, float("inf")).amin(dim=(1, 2)) # Identify blocks that contain any non-zero element nonzero_blocks = per_block_amax != 0 @@ -119,9 +115,7 @@ def _compute_for_one_orientation(tensor): if max_over_orientations: return max( _compute_for_one_orientation(tensor_2d), # Rowwise orientation - _compute_for_one_orientation( - tensor_2d.transpose(-2, -1) - ), # Columnwise orientation + _compute_for_one_orientation(tensor_2d.transpose(-2, -1)), # Columnwise orientation ) return _compute_for_one_orientation(tensor_2d) @@ -131,9 +125,7 @@ def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) means = sums / numels - var = torch.sum(numels * (variances - torch.pow((means - mean), 2))) / torch.sum( - numels - ) + var = torch.sum(numels * (variances - torch.pow((means - mean), 2))) / torch.sum(numels) return var @@ -232,9 +224,7 @@ def _get(buffers, stat_name): lambda buffers: sum(_get(buffers, "sum")) / sum(_get(buffers, "numel")), ), "numel": ( - lambda x, aux_dict: ( - x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel() - ), + lambda x, aux_dict: (x.numel() if hasattr(x, "numel") else x.get_data_tensors()[0].numel()), lambda buffers: sum(_get(buffers, "numel")), ), "l1_norm": ( @@ -274,8 +264,7 @@ def _get(buffers, stat_name): ), ), "dynamic_range": ( - lambda x, aux_dict: _compute_dynamic_range_top(x) - - _compute_dynamic_range_bottom(x), + lambda x, aux_dict: _compute_dynamic_range_top(x) - _compute_dynamic_range_bottom(x), lambda buffers: max(_get(buffers, "dynamic_range_top")) - min(_get(buffers, "dynamic_range_bottom")), ), @@ -303,9 +292,7 @@ def _get(buffers, stat_name): lambda x, aux_dict: compute_fp8_delayed_scaling_overflows_num(x, aux_dict[""]) / x.numel() * 100, - lambda buffers: 100 - * sum(_get(buffers, "overflows_num")) - / sum(_get(buffers, "numel")), + lambda buffers: 100 * sum(_get(buffers, "overflows_num")) / sum(_get(buffers, "numel")), ), } @@ -315,9 +302,7 @@ def _get(buffers, stat_name): def count_nonzero_fp8(fp8_data: torch.Tensor) -> torch.Tensor: """Count the number of non-zero elements in the fp8 data.""" fp8_data = fp8_data.view(dtype=torch.uint8) - zero_vals = torch.tensor( - [0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8 - ) + zero_vals = torch.tensor([0, FP8_NEGATIVE_ZERO], device=fp8_data.device, dtype=torch.uint8) return fp8_data.numel() - torch.isin(fp8_data, zero_vals).sum() @@ -327,9 +312,7 @@ def add_underflows_stats(recipe_name: str, columnwise: bool = False): # Stat names stat_num = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows_num{columnwise_suffix}" - stat_pct = ( - f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}" - ) + stat_pct = f"{recipe_name}{'_' if recipe_name != '' else ''}underflows%{columnwise_suffix}" stats_to_num[stat_num] = len(stats_to_num) stats_to_num[stat_pct] = len(stats_to_num) @@ -403,15 +386,11 @@ def _prefix(): # Capture the attribute name inside lambdas via default args to avoid late binding. STATS[stat_name_min] = ( - lambda x, aux_dict, _col=columnwise: get_scale_inv( - aux_dict[recipe_name], _col - ).min(), + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).min(), lambda buffers, _sn=stat_name_min: min(_get(buffers, _sn)), ) STATS[stat_name_max] = ( - lambda x, aux_dict, _col=columnwise: get_scale_inv( - aux_dict[recipe_name], _col - ).max(), + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).max(), lambda buffers, _sn=stat_name_max: max(_get(buffers, _sn)), ) STATS[stat_name_var] = ( @@ -423,9 +402,7 @@ def _prefix(): ), ) STATS[stat_name_numel] = ( - lambda x, aux_dict, _col=columnwise: get_scale_inv( - aux_dict[recipe_name], _col - ).numel(), + lambda x, aux_dict, _col=columnwise: get_scale_inv(aux_dict[recipe_name], _col).numel(), lambda buffers, _sn=stat_name_numel: sum(_get(buffers, _sn)), ) STATS[stat_name_sum] = ( @@ -456,22 +433,20 @@ def add_mse_stats(recipe_name: str, columnwise: bool = False): columnwise_suffix = "_columnwise" if columnwise else "" stat_mse = f"{recipe_name}{'_' if recipe_name != '' else ''}mse{columnwise_suffix}" - stat_err = f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}" + stat_err = ( + f"{recipe_name}{'_' if recipe_name != '' else ''}total_square_error{columnwise_suffix}" + ) stats_to_num[stat_mse] = len(stats_to_num) stats_to_num[stat_err] = len(stats_to_num) STATS[stat_mse] = ( - lambda x, aux_dict: F.mse_loss( - x, aux_dict[recipe_name].dequantize(), reduction="mean" - ), + lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="mean"), lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)) / sum(_get(buffers, "numel")), ) STATS[stat_err] = ( - lambda x, aux_dict: F.mse_loss( - x, aux_dict[recipe_name].dequantize(), reduction="sum" - ), + lambda x, aux_dict: F.mse_loss(x, aux_dict[recipe_name].dequantize(), reduction="sum"), lambda buffers, _sn_err=stat_err: torch.sum(_get(buffers, _sn_err)), ) @@ -503,9 +478,7 @@ def add_max_blockwise_dynamic_range_stats( DEPENDENCIES[stat_key] = {stat_key} STATS[stat_key] = ( - lambda x, aux_dict, _stat_key=stat_key: compute_max_blockwise_dynamic_range( - x, _stat_key - ), + lambda x, aux_dict, _stat_key=stat_key: compute_max_blockwise_dynamic_range(x, _stat_key), lambda buffers, _stat_key=stat_key: max(_get(buffers, _stat_key)), ) From 4149e8150271d8808589eccd23591c226e42b5f8 Mon Sep 17 00:00:00 2001 From: Pawel Gadzinski Date: Fri, 29 May 2026 12:13:35 +0200 Subject: [PATCH 3/3] fix sign in parallel-axis Welford variance combination Parallel-group variance is Sigma n_i*(var_i + (mean_i - mean)^2) / N - the between-group term must be added, not subtracted. Single-group buffers hide the bug (mean_i = mean_global so the term is 0); it surfaces with scale_inv_std reduced across microbatches/ranks, where negative variance flows into sqrt() and yields NaN. Signed-off-by: Pawel Gadzinski --- transformer_engine/debug/features/utils/stats_computation.py | 2 +- 1 file changed, 1 insertion(+), 1 deletion(-) diff --git a/transformer_engine/debug/features/utils/stats_computation.py b/transformer_engine/debug/features/utils/stats_computation.py index 80c3041f41..aabb7d6959 100644 --- a/transformer_engine/debug/features/utils/stats_computation.py +++ b/transformer_engine/debug/features/utils/stats_computation.py @@ -125,7 +125,7 @@ def compute_variance(variances, numels, sums): """Welford algorithm is used for numerically stable distributed variance computation.""" mean = torch.sum(sums) / torch.sum(numels) means = sums / numels - var = torch.sum(numels * (variances - torch.pow((means - mean), 2))) / torch.sum(numels) + var = torch.sum(numels * (variances + torch.pow((means - mean), 2))) / torch.sum(numels) return var